All checks were successful
Publish Python Package to PyPI / deploy (push) Successful in 12s
The previous GA logic was returning the "top N" solutions, which led to test failures when the algorithm correctly converged on only one of all possible roots (e.g., returning 1000 variations of -1.0).
This commit fixes the root-finding logic to correctly identify and return *all* unique, high-quality roots:
1. **feat(api):** Adds `root_precision` to `GA_Options`. This new parameter (default: 5) allows the user to control the number of decimal places for clustering unique roots.
2. **fix(ga):** Replaces the flawed "top N" logic in both `_solve_x_numpy` and `_solve_x_cuda`. The new process is:
* Dynamically sets a `quality_threshold` based on the user's `root_precision` (e.g., `precision=5` requires a rank > `1e6`).
* Filters the *entire* final population for all solutions that meet this quality threshold.
* Rounds these high-quality solutions to `root_precision`.
* Returns only the `np.unique()` results.
This ensures the solver returns all distinct roots that meet the accuracy requirements, rather than just the top N variations of a single root.
Reviewed-on: #19
Co-authored-by: Jonathan Rampersad <rampersad.jonathan@gmail.com>
Co-committed-by: Jonathan Rampersad <rampersad.jonathan@gmail.com>
137 lines
4.7 KiB
Python
137 lines
4.7 KiB
Python
import pytest
|
|
import numpy as np
|
|
|
|
# Try to import cupy to check for CUDA availability
|
|
try:
|
|
import cupy
|
|
_CUPY_AVAILABLE = True
|
|
except ImportError:
|
|
_CUPY_AVAILABLE = False
|
|
|
|
from polysolve import Function, GA_Options
|
|
|
|
@pytest.fixture
|
|
def quadratic_func() -> Function:
|
|
"""Provides a standard quadratic function: 2x^2 - 3x - 5."""
|
|
f = Function(largest_exponent=2)
|
|
f.set_coeffs([2, -3, -5])
|
|
return f
|
|
|
|
@pytest.fixture
|
|
def linear_func() -> Function:
|
|
"""Provides a standard linear function: x + 10."""
|
|
f = Function(largest_exponent=1)
|
|
f.set_coeffs([1, 10])
|
|
return f
|
|
|
|
@pytest.fixture
|
|
def m_func_1() -> Function:
|
|
f = Function(2)
|
|
f.set_coeffs([2, 3, 1])
|
|
return f
|
|
|
|
@pytest.fixture
|
|
def m_func_2() -> Function:
|
|
f = Function(1)
|
|
f.set_coeffs([5, -4])
|
|
return f
|
|
|
|
# --- Core Functionality Tests ---
|
|
|
|
def test_solve_y(quadratic_func):
|
|
"""Tests if the function correctly evaluates y for a given x."""
|
|
assert quadratic_func.solve_y(5) == 30.0
|
|
assert quadratic_func.solve_y(0) == -5.0
|
|
assert quadratic_func.solve_y(-1) == 0.0
|
|
|
|
def test_derivative(quadratic_func):
|
|
"""Tests the calculation of the function's derivative."""
|
|
derivative = quadratic_func.derivative()
|
|
assert derivative.largest_exponent == 1
|
|
# The derivative of 2x^2 - 3x - 5 is 4x - 3
|
|
assert np.array_equal(derivative.coefficients, [4, -3])
|
|
|
|
def test_nth_derivative(quadratic_func):
|
|
"""Tests the calculation of the function's 2nd derivative."""
|
|
derivative = quadratic_func.nth_derivative(2)
|
|
assert derivative.largest_exponent == 0
|
|
# The derivative of 2x^2 - 3x - 5 is 4x - 3
|
|
assert np.array_equal(derivative.coefficients, [4])
|
|
|
|
def test_quadratic_solve(quadratic_func):
|
|
"""Tests the analytical quadratic solver for exact roots."""
|
|
roots = quadratic_func.quadratic_solve()
|
|
# Sorting ensures consistent order for comparison
|
|
assert sorted(roots) == [-1.0, 2.5]
|
|
|
|
# --- Arithmetic Operation Tests ---
|
|
|
|
def test_addition(quadratic_func, linear_func):
|
|
"""Tests the addition of two Function objects."""
|
|
# (2x^2 - 3x - 5) + (x + 10) = 2x^2 - 2x + 5
|
|
result = quadratic_func + linear_func
|
|
assert result.largest_exponent == 2
|
|
assert np.array_equal(result.coefficients, [2, -2, 5])
|
|
|
|
def test_subtraction(quadratic_func, linear_func):
|
|
"""Tests the subtraction of two Function objects."""
|
|
# (2x^2 - 3x - 5) - (x + 10) = 2x^2 - 4x - 15
|
|
result = quadratic_func - linear_func
|
|
assert result.largest_exponent == 2
|
|
assert np.array_equal(result.coefficients, [2, -4, -15])
|
|
|
|
def test_scalar_multiplication(linear_func):
|
|
"""Tests the multiplication of a Function object by a scalar."""
|
|
# (x + 10) * 3 = 3x + 30
|
|
result = linear_func * 3
|
|
assert result.largest_exponent == 1
|
|
assert np.array_equal(result.coefficients, [3, 30])
|
|
|
|
def test_function_multiplication(m_func_1, m_func_2):
|
|
"""Tests the multiplication of two Function objects."""
|
|
# (2x^2 + 3x + 1) * (5x -4) = 10x^3 + 7x^2 - 7x -4
|
|
result = m_func_1 * m_func_2
|
|
assert result.largest_exponent == 3
|
|
assert np.array_equal(result.coefficients, [10, 7, -7, -4])
|
|
|
|
# --- Genetic Algorithm Root-Finding Tests ---
|
|
|
|
def test_get_real_roots_numpy(quadratic_func):
|
|
"""
|
|
Tests that the NumPy-based genetic algorithm approximates the roots correctly.
|
|
"""
|
|
# Using more generations for higher accuracy in testing
|
|
ga_opts = GA_Options(num_of_generations=50, data_size=200000, root_precision=3)
|
|
|
|
roots = quadratic_func.get_real_roots(ga_opts, use_cuda=False)
|
|
|
|
# Check if the algorithm found values close to the two known roots.
|
|
# We don't know which order they'll be in, so we check for presence.
|
|
expected_roots = np.array([-1.0, 2.5])
|
|
|
|
# Check that at least one found root is close to -1.0
|
|
assert np.any(np.isclose(roots, expected_roots[0], atol=1e-2))
|
|
|
|
# Check that at least one found root is close to 2.5
|
|
assert np.any(np.isclose(roots, expected_roots[1], atol=1e-2))
|
|
|
|
|
|
@pytest.mark.skipif(not _CUPY_AVAILABLE, reason="CuPy is not installed, skipping CUDA test.")
|
|
def test_get_real_roots_cuda(quadratic_func):
|
|
"""
|
|
Tests that the CUDA-based genetic algorithm approximates the roots correctly.
|
|
This test implicitly verifies that the CUDA kernel is functioning.
|
|
It will be skipped automatically if CuPy is not available.
|
|
"""
|
|
|
|
ga_opts = GA_Options(num_of_generations=50, data_size=200000, root_precision=3)
|
|
|
|
roots = quadratic_func.get_real_roots(ga_opts, use_cuda=True)
|
|
|
|
expected_roots = np.array([-1.0, 2.5])
|
|
|
|
# Verify that the CUDA implementation also finds the correct roots within tolerance.
|
|
assert np.any(np.isclose(roots, expected_roots[0], atol=1e-2))
|
|
assert np.any(np.isclose(roots, expected_roots[1], atol=1e-2))
|
|
|