Files
PolySolve/tests/test_polysolve.py
Jonathan Rampersad b415df2983 feat: add complex root finding and dynamic CUDA shared memory optimization
Major update extending the library to solve for complex roots and optimizing GPU performance using Shared Memory.

Complex Number Support:
- Implemented `_solve_complex_cuda` and `_solve_complex_numpy` to find roots in the complex plane.
- Added specialized CUDA kernels (`_FITNESS_KERNEL_COMPLEX`, `_FITNESS_KERNEL_COMPLEX_DYNAMIC`) handling complex arithmetic (multiplication/addition) directly on the GPU.
- Updated `Function` class and `set_coeffs` to handle `np.complex128` data types.
- Updated `quadratic_solve` to return complex roots using `cmath`.

CUDA Performance & Optimization:
- Implemented Dynamic Shared Memory kernels (`extern __shared__`) to cache polynomial coefficients on the GPU block, significantly reducing global memory latency.
- Added intelligent fallback logic: The solver checks `MaxSharedMemoryPerBlock`. If the polynomial is too large for Shared Memory, it falls back to the standard Global Memory kernel to prevent crashes.
- Split complex coefficients into separate Real and Imaginary arrays for CUDA kernel efficiency.

Polynomial Logic:
- Added `_strip_leading_zeros` helper to ensure polynomial degree is correctly maintained after arithmetic operations (e.g., preventing `0x^2 + x` from being treated as degree 2).
- Updated `__init__` to allow direct coefficient injection.

GA Algorithm:
- Updated crossover logic to support 2D search space (Real + Imaginary) for complex solutions.
- Refined fitness function to explicitly handle `isinf`/`isnan` for numerical stability.
2025-12-05 13:47:29 -04:00

203 lines
6.9 KiB
Python

import pytest
import numpy as np
import numpy.testing as npt
# Try to import cupy to check for CUDA availability
try:
import cupy
_CUPY_AVAILABLE = True
except ImportError:
_CUPY_AVAILABLE = False
from polysolve import Function, GA_Options
@pytest.fixture
def quadratic_func() -> Function:
"""Provides a standard quadratic function: 2x^2 - 3x - 5."""
f = Function(largest_exponent=2)
f.set_coeffs([2, -3, -5])
return f
@pytest.fixture
def linear_func() -> Function:
"""Provides a standard linear function: x + 10."""
f = Function(largest_exponent=1)
f.set_coeffs([1, 10])
return f
@pytest.fixture
def m_func_1() -> Function:
f = Function(2)
f.set_coeffs([2, 3, 1])
return f
@pytest.fixture
def m_func_2() -> Function:
f = Function(1)
f.set_coeffs([5, -4])
return f
@pytest.fixture
def base_func():
f = Function(2)
f.set_coeffs([1, 2, 3])
return f
@pytest.fixture
def complex_func():
f = Function(2, [1, 2, 2])
return f
# --- Core Functionality Tests ---
def test_solve_y(quadratic_func):
"""Tests if the function correctly evaluates y for a given x."""
assert quadratic_func.solve_y(5) == 30.0
assert quadratic_func.solve_y(0) == -5.0
assert quadratic_func.solve_y(-1) == 0.0
def test_derivative(quadratic_func):
"""Tests the calculation of the function's derivative."""
derivative = quadratic_func.derivative()
assert derivative.largest_exponent == 1
# The derivative of 2x^2 - 3x - 5 is 4x - 3
assert np.array_equal(derivative.coefficients, [4, -3])
def test_nth_derivative(quadratic_func):
"""Tests the calculation of the function's 2nd derivative."""
derivative = quadratic_func.nth_derivative(2)
assert derivative.largest_exponent == 0
# The derivative of 2x^2 - 3x - 5 is 4x - 3
assert np.array_equal(derivative.coefficients, [4])
def test_quadratic_solve(quadratic_func):
"""Tests the analytical quadratic solver for exact roots."""
roots = quadratic_func.quadratic_solve()
# Sorting ensures consistent order for comparison
assert sorted(roots) == [-1.0, 2.5]
# --- Arithmetic Operation Tests ---
def test_addition(quadratic_func, linear_func):
"""Tests the addition of two Function objects."""
# (2x^2 - 3x - 5) + (x + 10) = 2x^2 - 2x + 5
result = quadratic_func + linear_func
assert result.largest_exponent == 2
assert np.array_equal(result.coefficients, [2, -2, 5])
def test_subtraction(quadratic_func, linear_func):
"""Tests the subtraction of two Function objects."""
# (2x^2 - 3x - 5) - (x + 10) = 2x^2 - 4x - 15
result = quadratic_func - linear_func
assert result.largest_exponent == 2
assert np.array_equal(result.coefficients, [2, -4, -15])
def test_scalar_multiplication(linear_func):
"""Tests the multiplication of a Function object by a scalar."""
# (x + 10) * 3 = 3x + 30
result = linear_func * 3
assert result.largest_exponent == 1
assert np.array_equal(result.coefficients, [3, 30])
def test_function_multiplication(m_func_1, m_func_2):
"""Tests the multiplication of two Function objects."""
# (2x^2 + 3x + 1) * (5x -4) = 10x^3 + 7x^2 - 7x -4
result = m_func_1 * m_func_2
assert result.largest_exponent == 3
assert np.array_equal(result.coefficients, [10, 7, -7, -4])
def test_equality(base_func):
"""Tests the __eq__ method for the Function class."""
# 1. Test for equality with a new, identical object
f_identical = Function(2)
f_identical.set_coeffs([1, 2, 3])
assert base_func == f_identical
# 2. Test for inequality (different coefficients)
f_different = Function(2)
f_different.set_coeffs([1, 9, 3])
assert base_func != f_different
# 3. Test for inequality (different degree)
f_diff_degree = Function(1)
f_diff_degree.set_coeffs([1, 2])
assert base_func != f_diff_degree
# 4. Test against a different type
assert base_func != "some_string"
assert base_func != 123
# 5. Test against an uninitialized Function
f_uninitialized = Function(2)
assert base_func != f_uninitialized
# --- Genetic Algorithm Root-Finding Tests ---
def test_get_real_roots_numpy(quadratic_func):
"""
Tests that the NumPy-based genetic algorithm approximates the roots correctly.
"""
# Using more generations for higher accuracy in testing
ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
roots = quadratic_func.get_real_roots(ga_opts, use_cuda=False)
# Check if the algorithm found values close to the two known roots.
# We don't know which order they'll be in, so we check for presence.
expected_roots = np.array([-1.0, 2.5])
npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
@pytest.mark.skipif(not _CUPY_AVAILABLE, reason="CuPy is not installed, skipping CUDA test.")
def test_get_real_roots_cuda(quadratic_func):
"""
Tests that the CUDA-based genetic algorithm approximates the roots correctly.
This test implicitly verifies that the CUDA kernel is functioning.
It will be skipped automatically if CuPy is not available.
"""
ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
roots = quadratic_func.get_real_roots(ga_opts, use_cuda=True)
expected_roots = np.array([-1.0, 2.5])
# Verify that the CUDA implementation also finds the correct roots within tolerance.
npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
def test_get_roots_numpy(complex_func):
"""
Tests that the NumPy-based genetic algorithm approximates the roots correctly.
"""
# Using more generations for higher accuracy in testing
ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
roots = complex_func.get_roots(ga_opts, use_cuda=False)
# Check if the algorithm found values close to the two known roots.
# We don't know which order they'll be in, so we check for presence.
expected_roots = np.array([-1.0-1.j, -1.0+1.j])
npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
@pytest.mark.skipif(not _CUPY_AVAILABLE, reason="CuPy is not installed, skipping CUDA test.")
def test_get_roots_cuda(complex_func):
"""
Tests that the CUDA-based genetic algorithm approximates the roots correctly.
This test implicitly verifies that the CUDA kernel is functioning.
It will be skipped automatically if CuPy is not available.
"""
ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
roots = complex_func.get_roots(ga_opts, use_cuda=True)
expected_roots = np.array([-1.0-1.j, -1+1.j])
# Verify that the CUDA implementation also finds the correct roots within tolerance.
npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)