diff --git a/src/polysolve/__init__.py b/src/polysolve/__init__.py index 0843e5f..6fe8e75 100644 --- a/src/polysolve/__init__.py +++ b/src/polysolve/__init__.py @@ -1127,7 +1127,9 @@ class Function: d_unique = cupy.unique(rounded_real + 1j * rounded_imag) - return cupy.asnumpy(d_unique) + # Sort the unique roots and copy back to CPU + final_solutions_gpu = cupy.sort(d_unique) + return final_solutions_gpu.get() def __str__(self) -> str: @@ -1299,7 +1301,7 @@ class Function: return np.allclose(c1, c2) - def quadratic_solve(self) -> Optional[List[complex]]: + def quadratic_solve(self) -> Optional[List[Union[complex, float]]]: """ Calculates the roots (real or complex) of a quadratic function. @@ -1328,9 +1330,14 @@ class Function: else: # Standard case: Use Vieta's formula root2 = (c / a) / root1 + + roots = np.array([root1, root2]) + roots.sort() - # Return roots in a consistent order - return [root1, root2] + if np.all(np.abs(roots.imag) < 1e-15): + return roots.real.astype(np.float64) + + return roots # Example Usage if __name__ == '__main__':