From 94723dcb88b533bf2b31830187a8ad92489f9d9c Mon Sep 17 00:00:00 2001 From: Jonathan Rampersad Date: Sun, 2 Nov 2025 12:50:48 -0400 Subject: [PATCH] feat(Function): Add __eq__ method and improve quadratic_solve stability Implements two features for the Function class: 1. Adds the `__eq__` operator (`==`) to allow for logical comparison of two Function objects based on their coefficients. 2. Replaces the standard quadratic formula with a numerically stable version in `quadratic_solve` to prevent "catastrophic cancellation" errors and improve accuracy. --- pyproject.toml | 2 +- src/polysolve/__init__.py | 45 +++++++++++++++++++++++++++++++++++++-- tests/test_polysolve.py | 32 ++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfed628..7797a57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] # --- Core Metadata --- name = "polysolve" -version = "0.6.1" +version = "0.6.2" authors = [ { name="Jonathan Rampersad", email="jonathan@jono-rams.work" }, ] diff --git a/src/polysolve/__init__.py b/src/polysolve/__init__.py index 30f00eb..35f9537 100644 --- a/src/polysolve/__init__.py +++ b/src/polysolve/__init__.py @@ -746,6 +746,21 @@ class Function: return NotImplemented return self + + def __eq__(self, other: object) -> bool: + """ + Checks if two Function objects are equal by comparing + their coefficients. + """ + # Check if the 'other' object is even a Function + if not isinstance(other, Function): + return NotImplemented + + # Ensure both are initialized before trying to access .coefficients + if not self._initialized or not other._initialized: + return False + + return np.array_equal(self.coefficients, other.coefficients) def quadratic_solve(self) -> Optional[List[float]]: @@ -770,9 +785,35 @@ class Function: return None # No real roots sqrt_discriminant = math.sqrt(discriminant) - root1 = (-b + sqrt_discriminant) / (2 * a) - root2 = (-b - sqrt_discriminant) / (2 * a) + + # 1. Calculate the first root. + # We use math.copysign(val, sign) to get the sign of b. + # This ensures (-b - sign*sqrt) is always an *addition* + # (or subtraction of a smaller from a larger number), + # avoiding catastrophic cancellation. + root1 = (-b - math.copysign(sqrt_discriminant, b)) / (2 * a) + # 2. Calculate the second root using Vieta's formulas. + # We know that root1 * root2 = c / a. + # This is just a division, which is numerically stable. + + # Handle the edge case where c=0. + # If c=0, then root1 is 0.0, and root2 is -b/a + # We can't divide by root1=0, so we check. + if root1 == 0.0: + # If c is also 0, the other root is -b/a + if c == 0.0: + root2 = -b / a + else: + # This case (root1=0 but c!=0) shouldn't happen + # with real numbers, but it's safe to just + # return the one root we found. + return [0.0] + else: + # Standard case: Use Vieta's formula + root2 = (c / a) / root1 + + # Return roots in a consistent order return [root1, root2] # Example Usage diff --git a/tests/test_polysolve.py b/tests/test_polysolve.py index 00bdb37..dcef499 100644 --- a/tests/test_polysolve.py +++ b/tests/test_polysolve.py @@ -37,6 +37,12 @@ def m_func_2() -> Function: f.set_coeffs([5, -4]) return f +@pytest.fixture +def base_func(): + f = Function(2) + f.set_coeffs([1, 2, 3]) + return f + # --- Core Functionality Tests --- def test_solve_y(quadratic_func): @@ -95,6 +101,32 @@ def test_function_multiplication(m_func_1, m_func_2): assert result.largest_exponent == 3 assert np.array_equal(result.coefficients, [10, 7, -7, -4]) +def test_equality(base_func): + """Tests the __eq__ method for the Function class.""" + + # 1. Test for equality with a new, identical object + f_identical = Function(2) + f_identical.set_coeffs([1, 2, 3]) + assert base_func == f_identical + + # 2. Test for inequality (different coefficients) + f_different = Function(2) + f_different.set_coeffs([1, 9, 3]) + assert base_func != f_different + + # 3. Test for inequality (different degree) + f_diff_degree = Function(1) + f_diff_degree.set_coeffs([1, 2]) + assert base_func != f_diff_degree + + # 4. Test against a different type + assert base_func != "some_string" + assert base_func != 123 + + # 5. Test against an uninitialized Function + f_uninitialized = Function(2) + assert base_func != f_uninitialized + # --- Genetic Algorithm Root-Finding Tests --- def test_get_real_roots_numpy(quadratic_func): -- 2.49.1