Standardize root outputs as numpy arrays.
This commit is contained in:
@@ -1127,7 +1127,9 @@ class Function:
|
||||
|
||||
d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
|
||||
|
||||
return cupy.asnumpy(d_unique)
|
||||
# Sort the unique roots and copy back to CPU
|
||||
final_solutions_gpu = cupy.sort(d_unique)
|
||||
return final_solutions_gpu.get()
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -1299,7 +1301,7 @@ class Function:
|
||||
return np.allclose(c1, c2)
|
||||
|
||||
|
||||
def quadratic_solve(self) -> Optional[List[complex]]:
|
||||
def quadratic_solve(self) -> Optional[List[Union[complex, float]]]:
|
||||
"""
|
||||
Calculates the roots (real or complex) of a quadratic function.
|
||||
|
||||
@@ -1329,8 +1331,13 @@ class Function:
|
||||
# Standard case: Use Vieta's formula
|
||||
root2 = (c / a) / root1
|
||||
|
||||
# Return roots in a consistent order
|
||||
return [root1, root2]
|
||||
roots = np.array([root1, root2])
|
||||
roots.sort()
|
||||
|
||||
if np.all(np.abs(roots.imag) < 1e-15):
|
||||
return roots.real.astype(np.float64)
|
||||
|
||||
return roots
|
||||
|
||||
# Example Usage
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user