Standardize root outputs as numpy arrays.

This commit is contained in:
2025-12-06 08:49:35 -04:00
parent b415df2983
commit 117e43a984

View File

@@ -1127,7 +1127,9 @@ class Function:
d_unique = cupy.unique(rounded_real + 1j * rounded_imag) d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
return cupy.asnumpy(d_unique) # Sort the unique roots and copy back to CPU
final_solutions_gpu = cupy.sort(d_unique)
return final_solutions_gpu.get()
def __str__(self) -> str: def __str__(self) -> str:
@@ -1299,7 +1301,7 @@ class Function:
return np.allclose(c1, c2) return np.allclose(c1, c2)
def quadratic_solve(self) -> Optional[List[complex]]: def quadratic_solve(self) -> Optional[List[Union[complex, float]]]:
""" """
Calculates the roots (real or complex) of a quadratic function. Calculates the roots (real or complex) of a quadratic function.
@@ -1328,9 +1330,14 @@ class Function:
else: else:
# Standard case: Use Vieta's formula # Standard case: Use Vieta's formula
root2 = (c / a) / root1 root2 = (c / a) / root1
roots = np.array([root1, root2])
roots.sort()
# Return roots in a consistent order if np.all(np.abs(roots.imag) < 1e-15):
return [root1, root2] return roots.real.astype(np.float64)
return roots
# Example Usage # Example Usage
if __name__ == '__main__': if __name__ == '__main__':