Standardize root outputs as numpy arrays.

This commit is contained in:
2025-12-06 08:49:35 -04:00
parent b415df2983
commit 117e43a984

View File

@@ -1127,7 +1127,9 @@ class Function:
d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
return cupy.asnumpy(d_unique)
# Sort the unique roots and copy back to CPU
final_solutions_gpu = cupy.sort(d_unique)
return final_solutions_gpu.get()
def __str__(self) -> str:
@@ -1299,7 +1301,7 @@ class Function:
return np.allclose(c1, c2)
def quadratic_solve(self) -> Optional[List[complex]]:
def quadratic_solve(self) -> Optional[List[Union[complex, float]]]:
"""
Calculates the roots (real or complex) of a quadratic function.
@@ -1328,9 +1330,14 @@ class Function:
else:
# Standard case: Use Vieta's formula
root2 = (c / a) / root1
roots = np.array([root1, root2])
roots.sort()
# Return roots in a consistent order
return [root1, root2]
if np.all(np.abs(roots.imag) < 1e-15):
return roots.real.astype(np.float64)
return roots
# Example Usage
if __name__ == '__main__':