feat(Function): Add __eq__ method and improve quadratic_solve stability #23
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
[project]
|
[project]
|
||||||
# --- Core Metadata ---
|
# --- Core Metadata ---
|
||||||
name = "polysolve"
|
name = "polysolve"
|
||||||
version = "0.6.1"
|
version = "0.6.2"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Jonathan Rampersad", email="jonathan@jono-rams.work" },
|
{ name="Jonathan Rampersad", email="jonathan@jono-rams.work" },
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -746,6 +746,21 @@ class Function:
|
|||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if two Function objects are equal by comparing
|
||||||
|
their coefficients.
|
||||||
|
"""
|
||||||
|
# Check if the 'other' object is even a Function
|
||||||
|
if not isinstance(other, Function):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
# Ensure both are initialized before trying to access .coefficients
|
||||||
|
if not self._initialized or not other._initialized:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return np.array_equal(self.coefficients, other.coefficients)
|
||||||
|
|
||||||
|
|
||||||
def quadratic_solve(self) -> Optional[List[float]]:
|
def quadratic_solve(self) -> Optional[List[float]]:
|
||||||
@@ -770,9 +785,35 @@ class Function:
|
|||||||
return None # No real roots
|
return None # No real roots
|
||||||
|
|
||||||
sqrt_discriminant = math.sqrt(discriminant)
|
sqrt_discriminant = math.sqrt(discriminant)
|
||||||
root1 = (-b + sqrt_discriminant) / (2 * a)
|
|
||||||
root2 = (-b - sqrt_discriminant) / (2 * a)
|
# 1. Calculate the first root.
|
||||||
|
# We use math.copysign(val, sign) to get the sign of b.
|
||||||
|
# This ensures (-b - sign*sqrt) is always an *addition*
|
||||||
|
# (or subtraction of a smaller from a larger number),
|
||||||
|
# avoiding catastrophic cancellation.
|
||||||
|
root1 = (-b - math.copysign(sqrt_discriminant, b)) / (2 * a)
|
||||||
|
|
||||||
|
# 2. Calculate the second root using Vieta's formulas.
|
||||||
|
# We know that root1 * root2 = c / a.
|
||||||
|
# This is just a division, which is numerically stable.
|
||||||
|
|
||||||
|
# Handle the edge case where c=0.
|
||||||
|
# If c=0, then root1 is 0.0, and root2 is -b/a
|
||||||
|
# We can't divide by root1=0, so we check.
|
||||||
|
if root1 == 0.0:
|
||||||
|
# If c is also 0, the other root is -b/a
|
||||||
|
if c == 0.0:
|
||||||
|
root2 = -b / a
|
||||||
|
else:
|
||||||
|
# This case (root1=0 but c!=0) shouldn't happen
|
||||||
|
# with real numbers, but it's safe to just
|
||||||
|
# return the one root we found.
|
||||||
|
return [0.0]
|
||||||
|
else:
|
||||||
|
# Standard case: Use Vieta's formula
|
||||||
|
root2 = (c / a) / root1
|
||||||
|
|
||||||
|
# Return roots in a consistent order
|
||||||
return [root1, root2]
|
return [root1, root2]
|
||||||
|
|
||||||
# Example Usage
|
# Example Usage
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ def m_func_2() -> Function:
|
|||||||
f.set_coeffs([5, -4])
|
f.set_coeffs([5, -4])
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_func():
|
||||||
|
f = Function(2)
|
||||||
|
f.set_coeffs([1, 2, 3])
|
||||||
|
return f
|
||||||
|
|
||||||
# --- Core Functionality Tests ---
|
# --- Core Functionality Tests ---
|
||||||
|
|
||||||
def test_solve_y(quadratic_func):
|
def test_solve_y(quadratic_func):
|
||||||
@@ -95,6 +101,32 @@ def test_function_multiplication(m_func_1, m_func_2):
|
|||||||
assert result.largest_exponent == 3
|
assert result.largest_exponent == 3
|
||||||
assert np.array_equal(result.coefficients, [10, 7, -7, -4])
|
assert np.array_equal(result.coefficients, [10, 7, -7, -4])
|
||||||
|
|
||||||
|
def test_equality(base_func):
|
||||||
|
"""Tests the __eq__ method for the Function class."""
|
||||||
|
|
||||||
|
# 1. Test for equality with a new, identical object
|
||||||
|
f_identical = Function(2)
|
||||||
|
f_identical.set_coeffs([1, 2, 3])
|
||||||
|
assert base_func == f_identical
|
||||||
|
|
||||||
|
# 2. Test for inequality (different coefficients)
|
||||||
|
f_different = Function(2)
|
||||||
|
f_different.set_coeffs([1, 9, 3])
|
||||||
|
assert base_func != f_different
|
||||||
|
|
||||||
|
# 3. Test for inequality (different degree)
|
||||||
|
f_diff_degree = Function(1)
|
||||||
|
f_diff_degree.set_coeffs([1, 2])
|
||||||
|
assert base_func != f_diff_degree
|
||||||
|
|
||||||
|
# 4. Test against a different type
|
||||||
|
assert base_func != "some_string"
|
||||||
|
assert base_func != 123
|
||||||
|
|
||||||
|
# 5. Test against an uninitialized Function
|
||||||
|
f_uninitialized = Function(2)
|
||||||
|
assert base_func != f_uninitialized
|
||||||
|
|
||||||
# --- Genetic Algorithm Root-Finding Tests ---
|
# --- Genetic Algorithm Root-Finding Tests ---
|
||||||
|
|
||||||
def test_get_real_roots_numpy(quadratic_func):
|
def test_get_real_roots_numpy(quadratic_func):
|
||||||
|
|||||||
Reference in New Issue
Block a user