Standardize root outputs as numpy arrays.
This commit is contained in:
@@ -1127,7 +1127,9 @@ class Function:
|
|||||||
|
|
||||||
d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
|
d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
|
||||||
|
|
||||||
return cupy.asnumpy(d_unique)
|
# Sort the unique roots and copy back to CPU
|
||||||
|
final_solutions_gpu = cupy.sort(d_unique)
|
||||||
|
return final_solutions_gpu.get()
|
||||||
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@@ -1299,7 +1301,7 @@ class Function:
|
|||||||
return np.allclose(c1, c2)
|
return np.allclose(c1, c2)
|
||||||
|
|
||||||
|
|
||||||
def quadratic_solve(self) -> Optional[List[complex]]:
|
def quadratic_solve(self) -> Optional[List[Union[complex, float]]]:
|
||||||
"""
|
"""
|
||||||
Calculates the roots (real or complex) of a quadratic function.
|
Calculates the roots (real or complex) of a quadratic function.
|
||||||
|
|
||||||
@@ -1328,9 +1330,14 @@ class Function:
|
|||||||
else:
|
else:
|
||||||
# Standard case: Use Vieta's formula
|
# Standard case: Use Vieta's formula
|
||||||
root2 = (c / a) / root1
|
root2 = (c / a) / root1
|
||||||
|
|
||||||
|
roots = np.array([root1, root2])
|
||||||
|
roots.sort()
|
||||||
|
|
||||||
# Return roots in a consistent order
|
if np.all(np.abs(roots.imag) < 1e-15):
|
||||||
return [root1, root2]
|
return roots.real.astype(np.float64)
|
||||||
|
|
||||||
|
return roots
|
||||||
|
|
||||||
# Example Usage
|
# Example Usage
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user