From b415df2983921cad3dae689fa6fe2a6ac954c654 Mon Sep 17 00:00:00 2001 From: Jonathan Rampersad Date: Fri, 5 Dec 2025 13:47:29 -0400 Subject: [PATCH] feat: add complex root finding and dynamic CUDA shared memory optimization Major update extending the library to solve for complex roots and optimizing GPU performance using Shared Memory. Complex Number Support: - Implemented `_solve_complex_cuda` and `_solve_complex_numpy` to find roots in the complex plane. - Added specialized CUDA kernels (`_FITNESS_KERNEL_COMPLEX`, `_FITNESS_KERNEL_COMPLEX_DYNAMIC`) handling complex arithmetic (multiplication/addition) directly on the GPU. - Updated `Function` class and `set_coeffs` to handle `np.complex128` data types. - Updated `quadratic_solve` to return complex roots using `cmath`. CUDA Performance & Optimization: - Implemented Dynamic Shared Memory kernels (`extern __shared__`) to cache polynomial coefficients on the GPU block, significantly reducing global memory latency. - Added intelligent fallback logic: The solver checks `MaxSharedMemoryPerBlock`. If the polynomial is too large for Shared Memory, it falls back to the standard Global Memory kernel to prevent crashes. - Split complex coefficients into separate Real and Imaginary arrays for CUDA kernel efficiency. Polynomial Logic: - Added `_strip_leading_zeros` helper to ensure polynomial degree is correctly maintained after arithmetic operations (e.g., preventing `0x^2 + x` from being treated as degree 2). - Updated `__init__` to allow direct coefficient injection. GA Algorithm: - Updated crossover logic to support 2D search space (Real + Imaginary) for complex solutions. - Refined fitness function to explicitly handle `isinf`/`isnan` for numerical stability. --- pyproject.toml | 2 +- src/polysolve/__init__.py | 836 +++++++++++++++++++++++++++++++------- tests/test_polysolve.py | 38 ++ 3 files changed, 733 insertions(+), 143 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f0c515f..11a2088 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] # --- Core Metadata --- name = "polysolve" -version = "0.6.3" +version = "0.7.0" authors = [ { name="Jonathan Rampersad", email="jonathan@jono-rams.work" }, ] diff --git a/src/polysolve/__init__.py b/src/polysolve/__init__.py index 4155638..0843e5f 100644 --- a/src/polysolve/__init__.py +++ b/src/polysolve/__init__.py @@ -17,10 +17,10 @@ except ImportError: # The CUDA kernels for the fitness function _FITNESS_KERNEL_FLOAT = """ extern "C" __global__ void fitness_kernel( - const double* coefficients, + const double* __restrict__ coefficients, int num_coefficients, - const double* x_vals, - double* ranks, + const double* __restrict__ x_vals, + double* __restrict__ ranks, int size, double y_val) { @@ -28,18 +28,189 @@ extern "C" __global__ void fitness_kernel( if (idx < size) { double ans = coefficients[0]; + double x = x_vals[idx]; + for (int i = 1; i < num_coefficients; ++i) { - ans = ans * x_vals[idx] + coefficients[i]; + ans = ans * x + coefficients[i]; } ans -= y_val; - ranks[idx] = (ans == 0) ? 
1.7976931348623157e+308 : fabs(1.0 / ans); + + if (isinf(ans) || isnan(ans)) { + ranks[idx] = 0.0; + } else { + ranks[idx] = 1.0 / (fabs(ans) + 1e-15); + } } } """ -@numba.jit(nopython=True, fastmath=True, parallel=True) +_FITNESS_KERNEL_FLOAT_DYNAMIC = """ +extern "C" __global__ void fitness_kernel_shared( + const double* __restrict__ coefficients, + int num_coefficients, + const double* __restrict__ x_vals, + double* __restrict__ ranks, + int size, + double y_val) +{ + // Dynamic Shared Memory declaration + extern __shared__ double s_coeffs[]; + + for (int i = threadIdx.x; i < num_coefficients; i += blockDim.x) { + s_coeffs[i] = coefficients[i]; + } + + __syncthreads(); + + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < size) + { + double ans = s_coeffs[0]; + double x = x_vals[idx]; + + for (int i = 1; i < num_coefficients; ++i) + { + ans = ans * x + s_coeffs[i]; + } + + ans -= y_val; + + if (isinf(ans) || isnan(ans)) { + ranks[idx] = 0.0; + } else { + ranks[idx] = 1.0 / (fabs(ans) + 1e-15); + } + } +} +""" + +_FITNESS_KERNEL_COMPLEX = """ +struct Complex { + double r; + double i; +}; + +__device__ Complex c_add(Complex a, Complex b) { + return {a.r + b.r, a.i + b.i}; +} + +__device__ Complex c_mul(Complex a, Complex b) { + return { + a.r * b.r - a.i * b.i, + a.r * b.i + a.i * b.r + }; +} + +__device__ double c_abs(Complex a) { + return sqrt(a.r * a.r + a.i * a.i); +} + +extern "C" __global__ void fitness_kernel_complex( + const double* __restrict__ coeffs_real, + const double* __restrict__ coeffs_imag, + int num_coefficients, + const double* __restrict__ sol_real, + const double* __restrict__ sol_imag, + double* __restrict__ ranks, + int size, + double y_real, + double y_imag) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < size) + { + Complex x = {sol_real[idx], sol_imag[idx]}; + Complex ans = {coeffs_real[0], coeffs_imag[0]}; + + for (int i = 1; i < num_coefficients; ++i) + { + Complex c = {coeffs_real[i], coeffs_imag[i]}; + ans = c_mul(ans, x); + ans = c_add(ans, c); + } + + Complex diff = {ans.r - y_real, ans.i - y_imag}; + + if (isinf(diff.r) || isinf(diff.i) || isnan(diff.r) || isnan(diff.i)) { + ranks[idx] = 0.0; + } else { + double modulus = hypot(diff.r, diff.i); + ranks[idx] = 1.0 / (modulus + 1e-15); + } + } +} +""" + +_FITNESS_KERNEL_COMPLEX_DYNAMIC = """ +struct Complex { + double r; + double i; +}; + +__device__ Complex c_add(Complex a, Complex b) { + return {a.r + b.r, a.i + b.i}; +} + +__device__ Complex c_mul(Complex a, Complex b) { + return { + a.r * b.r - a.i * b.i, + a.r * b.i + a.i * b.r + }; +} + +__device__ double c_abs(Complex a) { + return sqrt(a.r * a.r + a.i * a.i); +} + +extern "C" __global__ void fitness_kernel_complex_shared( + const double* __restrict__ coeffs_real, + const double* __restrict__ coeffs_imag, + int num_coefficients, + const double* __restrict__ sol_real, + const double* __restrict__ sol_imag, + double* __restrict__ ranks, + int size, + double y_real, + double y_imag) +{ + // Dynamic Shared Memory declaration + extern __shared__ double s_memory[]; + + for (int i = threadIdx.x; i < num_coefficients; i += blockDim.x) { + s_memory[2 * i] = coeffs_real[i]; + s_memory[2 * i + 1] = coeffs_imag[i]; + } + + __syncthreads(); + + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < size) + { + Complex x = {sol_real[idx], sol_imag[idx]}; + Complex ans = {s_memory[0], s_memory[1]}; + + for (int i = 1; i < num_coefficients; ++i) + { + Complex c = {s_memory[2 * i], s_memory[2 * i + 1]}; + ans = c_mul(ans, x); + ans = 
c_add(ans, c);
+        }
+
+        Complex diff = {ans.r - y_real, ans.i - y_imag};
+
+        if (isinf(diff.r) || isinf(diff.i) || isnan(diff.r) || isnan(diff.i)) {
+            ranks[idx] = 0.0;
+        } else {
+            double modulus = hypot(diff.r, diff.i);
+            ranks[idx] = 1.0 / (modulus + 1e-15);
+        }
+    }
+}
+"""
+
+@numba.jit(nopython=True, fastmath=True, parallel=True, cache=True)
 def _calculate_ranks_numba(solutions, coefficients, y_val, ranks):
     """
     A Numba-jitted, parallel function to calculate fitness.
@@ -59,10 +230,36 @@ def _calculate_ranks_numba(solutions, coefficients, y_val, ranks):
 
         ans -= y_val
 
-        if ans == 0.0:
-            ranks[idx] = 1.7976931348623157e+308 # np.finfo(float).max
-        else:
-            ranks[idx] = abs(1.0 / ans)
+        ranks[idx] = 1.0 / (abs(ans) + 1e-15)
+
+
+@numba.jit(nopython=True, fastmath=True, parallel=True, cache=True)
+def _calculate_ranks_complex_numba(solutions, coefficients, y_val, ranks):
+    """
+    Parallel fitness calculation for complex numbers on the CPU.
+    Solutions and coefficients must be of type complex128.
+    """
+    num_coefficients = coefficients.shape[0]
+    data_size = solutions.shape[0]
+
+    for idx in numba.prange(data_size):
+        x_val = solutions[idx]
+
+        # Initialize with the leading coefficient
+        ans = coefficients[0]
+
+        # Horner's method
+        for i in range(1, num_coefficients):
+            ans = ans * x_val + coefficients[i]
+
+        ans -= y_val
+
+        # Rank by the modulus (magnitude):
+        # abs(z) for a complex number returns sqrt(a^2 + b^2)
+        modulus = abs(ans)
+
+        ranks[idx] = 1.0 / (modulus + 1e-15)
+
 
 @dataclass
 class GA_Options:
@@ -104,7 +301,10 @@ class GA_Options:
         groups roots more aggressively. A larger number (e.g., 7) is
         more precise but may return multiple near-identical roots.
         Default: 5
+    find_complex (bool): Whether to find complex roots as well. Default: True
     """
+    min_range: float = 0.0  # Retained for backwards compatibility; no longer used
+    max_range: float = 0.0  # Retained for backwards compatibility; no longer used
     num_of_generations: int = 10
     data_size: int = 100000
     mutation_strength: float = 0.01
@@ -114,6 +314,7 @@ class GA_Options:
     selection_percentile: float = 0.66
     blend_alpha: float = 0.5
     root_precision: int = 5
+    find_complex: bool = True
 
     def __post_init__(self):
         """Validates the GA options after initialization."""
@@ -141,6 +342,8 @@ class GA_Options:
                 UserWarning,
                 stacklevel=2
             )
+        if self.min_range != 0.0 or self.max_range != 0.0:
+            warnings.warn("min_range and max_range are no longer used; the search range is now derived from Cauchy's bound.", UserWarning, stacklevel=2)
 
 def _get_cauchy_bound(coeffs: np.ndarray) -> float:
     """
@@ -151,6 +354,9 @@ def _get_cauchy_bound(coeffs: np.ndarray) -> float:
     R = 1 + max(|c_n-1/c_n|, |c_n-2/c_n|, ..., |c_0/c_n|)
     Where c_n is the leading coefficient (coeffs[0]).
     """
+    if len(coeffs) <= 1:
+        return 1000.0
+
     # Normalize all coefficients by the leading coefficient
     normalized_coeffs = np.abs(coeffs[1:] / coeffs[0])
 
@@ -164,25 +370,31 @@ class Function:
     """
     Represents an exponential function (polynomial) of the form:
     c_0*x^n + c_1*x^(n-1) + ... + c_n
     """
-    def __init__(self, largest_exponent: int):
+    def __init__(self, largest_exponent: Optional[int] = None, coefficients: Optional[List[Union[int, float, complex]]] = None):
         """
         Initializes a function with its highest degree.
 
         Args:
             largest_exponent (int): The largest exponent (n) in the function.
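+            coefficients (Optional[List[Union[int, float, complex]]]): Optional
+                list of coefficients. If provided, they are applied immediately
+                via set_coeffs, and largest_exponent may be omitted.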
""" - if not isinstance(largest_exponent, int) or largest_exponent < 0: - raise ValueError("largest_exponent must be a non-negative integer.") self._largest_exponent = largest_exponent - self.coefficients: Optional[np.ndarray] = None - self._initialized = False + if coefficients is not None: + self.set_coeffs(coefficients) + # Verify user provided exponent matches if they provided both + if largest_exponent is not None and self._largest_exponent != largest_exponent: + raise ValueError("Provided largest_exponent does not match coefficient list length.") + elif largest_exponent is not None: + self.coefficients = None + self._initialized = False + else: + raise ValueError("Must provide either coefficients or largest_exponent.") - def set_coeffs(self, coefficients: List[Union[int, float]]): + def set_coeffs(self, coefficients: List[Union[int, float, complex]]): """ Sets the coefficients of the polynomial. Args: - coefficients (List[Union[int, float]]): A list of integer or float + coefficients (List[Union[int, float]]): A list of integer, float or complex coefficients. The list size must be largest_exponent + 1. @@ -198,13 +410,17 @@ class Function: if coefficients[0] == 0 and self._largest_exponent > 0: raise ValueError("The first constant (for the largest exponent) cannot be 0.") - # Check if any coefficient is a float - is_float = any(isinstance(c, float) for c in coefficients) + # Check for complex, then float, then int + is_complex = any(isinstance(c, complex) for c in coefficients) # Choose the dtype based on the input - target_dtype = np.float64 if is_float else np.int64 + if is_complex: + target_dtype = np.complex128 + else: + target_dtype = np.float64 self.coefficients = np.array(coefficients, dtype=target_dtype) + self._largest_exponent = len(coefficients) - 1 self._initialized = True def _check_initialized(self): @@ -305,7 +521,7 @@ class Function: return function - def get_real_roots(self, options: GA_Options = GA_Options(), use_cuda: bool = False) -> np.ndarray: + def get_real_roots(self, options: Optional[GA_Options] = None, use_cuda: bool = False) -> np.ndarray: """ Uses a genetic algorithm to find the approximate real roots of the function (where y=0). @@ -317,9 +533,32 @@ class Function: np.ndarray: An array of approximate root values. """ self._check_initialized() + if options is None: + options = GA_Options() + import copy + safe_options = copy.copy(options) + safe_options.find_complex = False + return self.solve_x(0.0, safe_options, use_cuda) + + + def get_roots(self, options: Optional[GA_Options] = None, use_cuda: bool = False) -> np.ndarray: + """ + Uses a genetic algorithm to find the approximate roots of the function (where y=0). + + Args: + options (GA_Options): Configuration for the genetic algorithm. + use_cuda (bool): If True, attempts to use CUDA for acceleration. + + Returns: + np.ndarray: An array of approximate root values. + """ + self._check_initialized() + if options is None: + options = GA_Options() return self.solve_x(0.0, options, use_cuda) - def solve_x(self, y_val: float, options: GA_Options = GA_Options(), use_cuda: bool = False) -> np.ndarray: + + def solve_x(self, y_val: Union[float, complex], options: Optional[GA_Options] = None, use_cuda: bool = False) -> np.ndarray: """ Uses a genetic algorithm to find x-values for a given y-value. @@ -332,18 +571,46 @@ class Function: np.ndarray: An array of approximate x-values. 
""" self._check_initialized() - if use_cuda and _CUPY_AVAILABLE: - return self._solve_x_cuda(y_val, options) + if options is None: + options = GA_Options() + if options.find_complex: + target_y = complex(y_val) + + if use_cuda and _CUPY_AVAILABLE: + return self._solve_complex_cuda(target_y, options) + else: + if use_cuda: + # Warn if user wanted CUDA but it's not available + warnings.warn( + "use_cuda=True was specified, but CuPy is not installed. " + "Falling back to NumPy (CPU) for complex roots.", + UserWarning + ) + return self._solve_complex_numpy(target_y, options) else: - if use_cuda: - warnings.warn( - "use_cuda=True was specified, but CuPy is not installed. " - "Falling back to NumPy (CPU). For GPU acceleration, " - "install with 'pip install polysolve[cuda]'.", - UserWarning - ) - - return self._solve_x_numpy(y_val, options) + if isinstance(y_val, complex): + if y_val.imag != 0: + warnings.warn( + "Complex y_val passed but options.find_complex is False. " + "The imaginary part of y_val will be ignored.", + UserWarning + ) + target_y = float(y_val.real) + else: + target_y = float(y_val) + + if use_cuda and _CUPY_AVAILABLE: + return self._solve_x_cuda(target_y, options) + else: + if use_cuda: + warnings.warn( + "use_cuda=True was specified, but CuPy is not installed. " + "Falling back to NumPy (CPU). For GPU acceleration, " + "install with 'pip install polysolve[cuda]'.", + UserWarning + ) + + return self._solve_x_numpy(target_y, options) def _solve_x_numpy(self, y_val: float, options: GA_Options) -> np.ndarray: """Genetic algorithm implementation using NumPy (CPU).""" @@ -358,19 +625,25 @@ class Function: mutation_size = int(data_size * mutation_ratio) random_size = data_size - elite_size - crossover_size - mutation_size + # Pre-calculate indices for slicing the destination array + idx_elite_end = elite_size + idx_cross_end = idx_elite_end + crossover_size + idx_mut_end = idx_cross_end + mutation_size + bound = _get_cauchy_bound(self.coefficients) min_r = -bound max_r = bound # Create initial random solutions - solutions = np.random.uniform(min_r, max_r, data_size) + src_solutions = np.random.uniform(min_r, max_r, data_size) + dst_solutions = np.empty(data_size, dtype=np.float64) # Pre-allocate ranks array ranks = np.empty(data_size, dtype=np.float64) for _ in range(options.num_of_generations): # Calculate fitness for all solutions (vectorized) - _calculate_ranks_numba(solutions, self.coefficients, y_val, ranks) + _calculate_ranks_numba(src_solutions, self.coefficients, y_val, ranks) parent_pool_size = int(data_size * options.selection_percentile) @@ -385,15 +658,13 @@ class Function: # --- Create the next generation --- # 1. Elitism: Keep the best solutions as-is - elite_solutions = solutions[elite_indices] + dst_solutions[:elite_size] = src_solutions[elite_indices] # 2. 
Crossover: Breed two parents to create a child # Select from the fitter PARENT POOL - parents1_idx = np.random.choice(parent_pool_indices, crossover_size) - parents2_idx = np.random.choice(parent_pool_indices, crossover_size) + parents1 = src_solutions[np.random.choice(parent_pool_indices, crossover_size)] + parents2 = src_solutions[np.random.choice(parent_pool_indices, crossover_size)] - parents1 = solutions[parents1_idx] - parents2 = solutions[parents2_idx] # Blend Crossover (BLX-alpha) alpha = options.blend_alpha @@ -409,34 +680,25 @@ class Function: new_max = p_max + (alpha * parent_range) # Create a new random child within the expanded range - crossover_solutions = np.random.uniform(new_min, new_max, crossover_size) + dst_solutions[idx_elite_end:idx_cross_end] = np.random.uniform(new_min, new_max) # 3. Mutation: # Select from the full list (indices 0 to data_size-1) - mutation_candidates = solutions[np.random.randint(0, data_size, mutation_size)] + mutation_candidates = src_solutions[np.random.randint(0, data_size, mutation_size)] # Use mutation_strength - mutation_factors = np.random.uniform( - 1 - options.mutation_strength, - 1 + options.mutation_strength, - mutation_size - ) - mutated_solutions = mutation_candidates * mutation_factors + noise = np.random.normal(0, options.mutation_strength, mutation_size) + dst_solutions[idx_cross_end:idx_mut_end] = mutation_candidates + noise # 4. New Randoms: Add new blood to prevent getting stuck - random_solutions = np.random.uniform(min_r, max_r, random_size) + dst_solutions[idx_mut_end:] = np.random.uniform(min_r, max_r, random_size) # Assemble the new generation - solutions = np.concatenate([ - elite_solutions, - crossover_solutions, - mutated_solutions, - random_solutions - ]) + src_solutions, dst_solutions = dst_solutions, src_solutions # --- Final Step: Return the best results --- # After all generations, do one last ranking to find the best solutions - _calculate_ranks_numba(solutions, self.coefficients, y_val, ranks) + _calculate_ranks_numba(src_solutions, self.coefficients, y_val, ranks) # 1. Define quality based on the user's desired precision # (e.g., precision=5 -> rank > 1e6, precision=8 -> rank > 1e9) @@ -444,7 +706,7 @@ class Function: quality_threshold = 10**(options.root_precision + 1) # 2. Get all solutions that meet this quality threshold - high_quality_solutions = solutions[ranks > quality_threshold] + high_quality_solutions = src_solutions[ranks > quality_threshold] if high_quality_solutions.size == 0: # No roots found that meet the quality, return empty @@ -457,6 +719,110 @@ class Function: unique_roots = np.unique(rounded_solutions) return np.sort(unique_roots) + + + def _solve_complex_numpy(self, y_val: complex, options: GA_Options) -> np.ndarray: + elite_ratio = options.elite_ratio + crossover_ratio = options.crossover_ratio + mutation_ratio = options.mutation_ratio + + data_size = options.data_size + + elite_size = int(data_size * elite_ratio) + crossover_size = int(data_size * crossover_ratio) + mutation_size = int(data_size * mutation_ratio) + random_size = data_size - elite_size - crossover_size - mutation_size + + # Pre-calculate indices for slicing the destination array + idx_elite_end = elite_size + idx_cross_end = idx_elite_end + crossover_size + idx_mut_end = idx_cross_end + mutation_size + + bound = _get_cauchy_bound(self.coefficients) + min_r = -bound + max_r = bound + + # 3. 
Initialize Population (Complex128)
+        real_part = np.random.uniform(min_r, max_r, data_size)
+        imag_part = np.random.uniform(min_r, max_r, data_size)
+        src_solutions = real_part + 1j * imag_part
+
+        dst_solutions = np.empty(data_size, dtype=np.complex128)
+
+        # Cast coefficients to complex128 for Numba compatibility
+        coeffs_complex = self.coefficients.astype(np.complex128)
+        ranks = np.empty(data_size, dtype=np.float64)
+
+        for _ in range(options.num_of_generations):
+            # Calculate fitness for all solutions (vectorized)
+            _calculate_ranks_complex_numba(src_solutions, coeffs_complex, y_val, ranks)
+
+            parent_pool_size = int(data_size * options.selection_percentile)
+
+            # 1. Get indices for the elite solutions (O(N) operation)
+            # We find the 'elite_size'-th largest element.
+            elite_indices = np.argpartition(-ranks, elite_size)[:elite_size]
+
+            # 2. Get indices for the parent pool (O(N) operation)
+            # We find the 'parent_pool_size'-th largest element.
+            parent_pool_indices = np.argpartition(-ranks, parent_pool_size)[:parent_pool_size]
+
+            # --- Create the next generation ---
+
+            # 1. Elitism: Keep the best solutions as-is
+            dst_solutions[:elite_size] = src_solutions[elite_indices]
+
+            # 2. Crossover: Breed two parents to create a child
+            # Select from the fitter PARENT POOL
+            p1 = src_solutions[np.random.choice(parent_pool_indices, crossover_size)]
+            p2 = src_solutions[np.random.choice(parent_pool_indices, crossover_size)]
+
+            # Calculate difference vectors
+            diff_real = p2.real - p1.real
+            diff_imag = p2.imag - p1.imag
+
+            alpha = options.blend_alpha
+
+            # Generate independent weights for the Real and Imaginary parts
+            # This creates a 2D search area instead of a 1D line
+            u_real = np.random.uniform(-alpha, 1.0 + alpha, crossover_size)
+            u_imag = np.random.uniform(-alpha, 1.0 + alpha, crossover_size)
+
+            child_real = p1.real + (u_real * diff_real)
+            child_imag = p1.imag + (u_imag * diff_imag)
+
+            dst_solutions[idx_elite_end:idx_cross_end] = child_real + 1j * child_imag
+
+            # 3. Mutation:
+            # Select from the full list (indices 0 to data_size-1)
+            mut_candidates = src_solutions[np.random.randint(0, data_size, mutation_size)]
+
+            noise_real = np.random.normal(0, options.mutation_strength, mutation_size)
+            noise_imag = np.random.normal(0, options.mutation_strength, mutation_size)
+
+            dst_solutions[idx_cross_end:idx_mut_end] = (mut_candidates.real + noise_real) + 1j * (mut_candidates.imag + noise_imag)
+
+            # 4. New Randoms: Add new blood to prevent getting stuck
+            rand_real = np.random.uniform(min_r, max_r, random_size)
+            rand_imag = np.random.uniform(min_r, max_r, random_size)
+            dst_solutions[idx_mut_end:] = rand_real + 1j * rand_imag
+
+            # Assemble the new generation
+            src_solutions, dst_solutions = dst_solutions, src_solutions
+
+        # 5. 
Final Ranking & Clustering + _calculate_ranks_complex_numba(src_solutions, coeffs_complex, y_val, ranks) + quality_threshold = 10**(options.root_precision + 1) + high_quality_solutions = src_solutions[ranks > quality_threshold] + + if high_quality_solutions.size == 0: return np.array([]) + + # Rounding complex numbers: round real and imag separately + rounded_real = np.round(high_quality_solutions.real, options.root_precision) + rounded_imag = np.round(high_quality_solutions.imag, options.root_precision) + + return np.unique(rounded_real + 1j * rounded_imag) + def _solve_x_cuda(self, y_val: float, options: GA_Options) -> np.ndarray: """Genetic algorithm implementation using CuPy (GPU/CUDA).""" @@ -472,97 +838,115 @@ class Function: mutation_size = int(data_size * mutation_ratio) random_size = data_size - elite_size - crossover_size - mutation_size - # ALWAYS cast coefficients to float64 for the kernel. - fitness_gpu = cupy.RawKernel(_FITNESS_KERNEL_FLOAT, 'fitness_kernel') - d_coefficients = cupy.array(self.coefficients, dtype=cupy.float64) - bound = _get_cauchy_bound(self.coefficients) min_r = -bound max_r = bound # Create initial random solutions on the GPU - d_solutions = cupy.random.uniform( - min_r, max_r, options.data_size, dtype=cupy.float64 + d_src_solutions = cupy.random.uniform( + min_r, max_r, data_size, dtype=cupy.float64 ) - d_ranks = cupy.empty(options.data_size, dtype=cupy.float64) - # Configure kernel launch parameters + d_dst_solutions = cupy.empty(data_size, dtype=cupy.float64) + + d_ranks = cupy.empty(data_size, dtype=cupy.float64) + + d_coefficients = cupy.array(self.coefficients, dtype=cupy.float64) + + # Calculate Shared Memory Size + num_coeffs = len(self.coefficients) + required_shared_mem_bytes = num_coeffs * 8 + + device = cupy.cuda.Device() + max_shared_mem = device.attributes['MaxSharedMemoryPerBlock'] + + use_shared_mem = True + + if required_shared_mem_bytes > max_shared_mem: + # The polynomial is too big for the cache! + # We must fall back to the slower Global Memory kernel to prevent a crash. + use_shared_mem = False + warnings.warn( + f"Polynomial degree ({num_coeffs}) exceeds GPU Shared Memory limit " + f"({max_shared_mem} bytes). Falling back to Global Memory (slower).", + UserWarning + ) + + # Kernel Setup + if use_shared_mem: + fitness_gpu = cupy.RawKernel(_FITNESS_KERNEL_FLOAT_DYNAMIC, 'fitness_kernel_shared') + kwargs = {'shared_mem': required_shared_mem_bytes} + else: + fitness_gpu = cupy.RawKernel(_FITNESS_KERNEL_FLOAT, 'fitness_kernel') + kwargs = {} + threads_per_block = 512 blocks_per_grid = (options.data_size + threads_per_block - 1) // threads_per_block + # Indices for slicing the destination buffer + idx_elite_end = elite_size + idx_cross_end = idx_elite_end + crossover_size + idx_mut_end = idx_cross_end + mutation_size + for i in range(options.num_of_generations): # Run the fitness kernel on the GPU + fitness_gpu( (blocks_per_grid,), (threads_per_block,), - (d_coefficients, d_coefficients.size, d_solutions, d_ranks, d_solutions.size, y_val) + (d_coefficients, d_coefficients.size, d_src_solutions, d_ranks, d_src_solutions.size, y_val), + **kwargs ) # Sort solutions by rank on the GPU sorted_indices = cupy.argsort(-d_ranks) - d_solutions = d_solutions[sorted_indices] + d_sorted_src_solutions = d_src_solutions[sorted_indices] # --- Create the next generation --- # 1. Elitism - d_elite_solutions = d_solutions[:elite_size] + d_dst_solutions[:elite_size] = d_sorted_src_solutions[:elite_size] # 2. 
Crossover parent_pool_size = int(data_size * options.selection_percentile) # Select from the fitter PARENT POOL - parent1_indices = cupy.random.randint(0, parent_pool_size, crossover_size) - parent2_indices = cupy.random.randint(0, parent_pool_size, crossover_size) + p1_indices = cupy.random.randint(0, parent_pool_size, crossover_size) + p2_indices = cupy.random.randint(0, parent_pool_size, crossover_size) # Get parents directly from the sorted solutions array using the pool-sized indices - d_parents1 = d_solutions[parent1_indices] - d_parents2 = d_solutions[parent2_indices] + d_p1 = d_sorted_src_solutions[p1_indices] + d_p2 = d_sorted_src_solutions[p2_indices] # Blend Crossover (BLX-alpha) alpha = options.blend_alpha - # Find min/max for all parent pairs - d_p_min = cupy.minimum(d_parents1, d_parents2) - d_p_max = cupy.maximum(d_parents1, d_parents2) + diff = d_p2 - d_p1 + u = cupy.random.uniform(-alpha, 1.0 + alpha, crossover_size) - # Calculate range (I) - d_parent_range = d_p_max - d_p_min - - # Calculate new min/max for the expanded range - d_new_min = d_p_min - (alpha * d_parent_range) - d_new_max = d_p_max + (alpha * d_parent_range) - - # Create a new random child within the expanded range - d_crossover_solutions = cupy.random.uniform(d_new_min, d_new_max, crossover_size) + d_dst_solutions[idx_elite_end:idx_cross_end] = d_p1 + (u * diff) # 3. Mutation # Select from the full list (indices 0 to data_size-1) mutation_indices = cupy.random.randint(0, data_size, mutation_size) - d_mutation_candidates = d_solutions[mutation_indices] + d_mutation_candidates = d_sorted_src_solutions[mutation_indices] - # Use mutation_strength (the new name) - d_mutation_factors = cupy.random.uniform( - 1 - options.mutation_strength, - 1 + options.mutation_strength, - mutation_size - ) - d_mutated_solutions = d_mutation_candidates * d_mutation_factors + # Use mutation_strength + noise = cupy.random.normal(0, options.mutation_strength, mutation_size) + d_dst_solutions[idx_cross_end:idx_mut_end] = d_mutation_candidates + noise # 4. New Randoms - d_random_solutions = cupy.random.uniform( + d_dst_solutions[idx_mut_end:] = cupy.random.uniform( min_r, max_r, random_size, dtype=cupy.float64 ) # Assemble the new generation - d_solutions = cupy.concatenate([ - d_elite_solutions, - d_crossover_solutions, - d_mutated_solutions, - d_random_solutions - ]) + # d_dst becomes the new source for the next generation + d_src_solutions, d_dst_solutions = d_dst_solutions, d_src_solutions # --- Final Step: Return the best results --- # After all generations, do one last ranking to find the best solutions fitness_gpu( (blocks_per_grid,), (threads_per_block,), - (d_coefficients, d_coefficients.size, d_solutions, d_ranks, d_solutions.size, y_val) + (d_coefficients, d_coefficients.size, d_src_solutions, d_ranks, d_src_solutions.size, y_val), + **kwargs ) # 1. Define quality based on the user's desired precision @@ -571,7 +955,7 @@ class Function: quality_threshold = 10**(options.root_precision + 1) # 2. 
Get all solutions that meet this quality threshold - d_high_quality_solutions = d_solutions[d_ranks > quality_threshold] + d_high_quality_solutions = d_src_solutions[d_ranks > quality_threshold] if d_high_quality_solutions.size == 0: return np.array([]) @@ -587,6 +971,165 @@ class Function: return final_solutions_gpu.get() + def _solve_complex_cuda(self, y_val: complex, options: GA_Options) -> np.ndarray: + elite_ratio = options.elite_ratio + crossover_ratio = options.crossover_ratio + mutation_ratio = options.mutation_ratio + + data_size = options.data_size + + elite_size = int(data_size * elite_ratio) + crossover_size = int(data_size * crossover_ratio) + mutation_size = int(data_size * mutation_ratio) + random_size = data_size - elite_size - crossover_size - mutation_size + + # 1. Prepare Coefficients (Split into Real/Imag for the Kernel) + # We pass real and imag arrays separately to avoid struct alignment issues + coeffs = self.coefficients.astype(np.complex128) + d_coeffs_real = cupy.array(coeffs.real, dtype=cupy.float64) + d_coeffs_imag = cupy.array(coeffs.imag, dtype=cupy.float64) + + d_y_real = cupy.float64(y_val.real) + d_y_imag = cupy.float64(y_val.imag) + + bound = _get_cauchy_bound(self.coefficients) + min_r = -bound + max_r = bound + + real_part = cupy.random.uniform(min_r, max_r, data_size, dtype=cupy.float64) + imag_part = cupy.random.uniform(min_r, max_r, data_size, dtype=cupy.float64) + d_src_solutions = real_part + 1j * imag_part + + d_dst_solutions = cupy.empty(data_size, dtype=cupy.complex128) + d_ranks = cupy.empty(data_size, dtype=cupy.float64) + + # Calculate Shared Memory Size + num_coeffs = len(self.coefficients) + required_shared_mem_bytes = (num_coeffs * 8) * 2 + + device = cupy.cuda.Device() + max_shared_mem = device.attributes['MaxSharedMemoryPerBlock'] + + use_shared_mem = True + + if required_shared_mem_bytes > max_shared_mem: + # The polynomial is too big for the cache! + # We must fall back to the slower Global Memory kernel to prevent a crash. + use_shared_mem = False + warnings.warn( + f"Polynomial degree ({num_coeffs}) exceeds GPU Shared Memory limit " + f"({max_shared_mem} bytes). Falling back to Global Memory (slower).", + UserWarning + ) + + # Kernel Setup + if use_shared_mem: + fitness_gpu = cupy.RawKernel(_FITNESS_KERNEL_COMPLEX_DYNAMIC, 'fitness_kernel_complex_shared') + kwargs = {'shared_mem': required_shared_mem_bytes} + else: + fitness_gpu = cupy.RawKernel(_FITNESS_KERNEL_COMPLEX, 'fitness_kernel_complex') + kwargs = {} + + threads_per_block = 512 + blocks_per_grid = (options.data_size + threads_per_block - 1) // threads_per_block + + idx_elite_end = elite_size + idx_cross_end = idx_elite_end + crossover_size + idx_mut_end = idx_cross_end + mutation_size + + for _ in range(options.num_of_generations): + d_real_cont = cupy.ascontiguousarray(d_src_solutions.real) + d_imag_cont = cupy.ascontiguousarray(d_src_solutions.imag) + + fitness_gpu( + (blocks_per_grid,), (threads_per_block,), + (d_coeffs_real, d_coeffs_imag, d_coeffs_real.size, + d_real_cont, d_imag_cont, d_ranks, data_size, + d_y_real, d_y_imag), + **kwargs + ) + + # Sort (using d_ranks) + sorted_indices = cupy.argsort(-d_ranks) + d_sorted_src_solutions = d_src_solutions[sorted_indices] + + # 1. Elite: Keep the best + d_dst_solutions[:elite_size] = d_sorted_src_solutions[:elite_size] + + # 2. 
Crossover: Blend Crossover (BLX-alpha)
+            # Select parents from the top percentile
+            parent_pool_size = int(data_size * options.selection_percentile)
+
+            # Randomly pair parents
+            p1_indices = cupy.random.randint(0, parent_pool_size, crossover_size)
+            p2_indices = cupy.random.randint(0, parent_pool_size, crossover_size)
+
+            p1 = d_sorted_src_solutions[p1_indices]
+            p2 = d_sorted_src_solutions[p2_indices]
+
+            # Calculate difference vectors
+            diff_real = p2.real - p1.real
+            diff_imag = p2.imag - p1.imag
+
+            alpha = options.blend_alpha
+
+            # Generate independent weights for the Real and Imaginary parts
+            # This creates a 2D search area instead of a 1D line
+            u_real = cupy.random.uniform(-alpha, 1.0 + alpha, crossover_size)
+            u_imag = cupy.random.uniform(-alpha, 1.0 + alpha, crossover_size)
+
+            child_real = p1.real + (u_real * diff_real)
+            child_imag = p1.imag + (u_imag * diff_imag)
+
+            # Apply Crossover
+            d_dst_solutions[idx_elite_end:idx_cross_end] = child_real + 1j * child_imag
+
+            # 3. Mutation: Perturb existing solutions
+            # Pick random candidates from the full population
+            mut_indices = cupy.random.randint(0, data_size, mutation_size)
+            mut_candidates = d_sorted_src_solutions[mut_indices]
+
+            # Generate independent Gaussian noise for the Real and Imaginary
+            # parts (standard deviation = mutation_strength)
+            noise_real = cupy.random.normal(0, options.mutation_strength, mutation_size)
+            noise_imag = cupy.random.normal(0, options.mutation_strength, mutation_size)
+
+            # Apply Mutation: Perturb Real/Imag independently so candidates can move off the real line
+            d_dst_solutions[idx_cross_end:idx_mut_end] = (mut_candidates.real + noise_real) + 1j * (mut_candidates.imag + noise_imag)
+
+            # 4. Random Injection: Fresh genetic material
+            rand_real = cupy.random.uniform(min_r, max_r, random_size, dtype=cupy.float64)
+            rand_imag = cupy.random.uniform(min_r, max_r, random_size, dtype=cupy.float64)
+            d_dst_solutions[idx_mut_end:] = rand_real + 1j * rand_imag
+
+            d_src_solutions, d_dst_solutions = d_dst_solutions, d_src_solutions
+
+        d_real_cont = cupy.ascontiguousarray(d_src_solutions.real)
+        d_imag_cont = cupy.ascontiguousarray(d_src_solutions.imag)
+
+        # Final Rank
+        fitness_gpu(
+            (blocks_per_grid,), (threads_per_block,),
+            (d_coeffs_real, d_coeffs_imag, d_coeffs_real.size,
+             d_real_cont, d_imag_cont, d_ranks, data_size,
+             d_y_real, d_y_imag),
+            **kwargs
+        )
+
+        # Filtering & Return
+        quality_threshold = 10**(options.root_precision + 1)
+        d_high_quality_solutions = d_src_solutions[d_ranks > quality_threshold]
+
+        if d_high_quality_solutions.size == 0: return np.array([])
+
+        rounded_real = cupy.round(d_high_quality_solutions.real, options.root_precision)
+        rounded_imag = cupy.round(d_high_quality_solutions.imag, options.root_precision)
+
+        d_unique = cupy.unique(rounded_real + 1j * rounded_imag)
+
+        return cupy.asnumpy(d_unique)
+
+
     def __str__(self) -> str:
         """Returns a human-readable string representation of the function."""
         self._check_initialized()
@@ -642,10 +1185,17 @@ class Function:
         other._check_initialized()
 
         new_coefficients = np.polyadd(self.coefficients, other.coefficients)
+        new_coefficients = self._strip_leading_zeros(new_coefficients)
 
         result_func = Function(len(new_coefficients) - 1)
         result_func.set_coeffs(new_coefficients.tolist())
         return result_func
+
+    def _strip_leading_zeros(self, coeffs: np.ndarray) -> np.ndarray:
+        # Drop leading (near-)zero coefficients so the degree stays correct
+        # (e.g., 0x^2 + x is degree 1, not degree 2)
+        while len(coeffs) > 1 and np.isclose(coeffs[0], 0):
+            coeffs = coeffs[1:]
+        return coeffs
 
     def __sub__(self, other: 'Function') -> 'Function':
         """Subtracts another 
Function object from this one."""
         self._check_initialized()
         other._check_initialized()
 
         new_coefficients = np.polysub(self.coefficients, other.coefficients)
+        new_coefficients = self._strip_leading_zeros(new_coefficients)
 
         result_func = Function(len(new_coefficients) - 1)
         result_func.set_coeffs(new_coefficients.tolist())
         return result_func
 
-    def _multiply_by_scalar(self, scalar: Union[int, float]) -> 'Function':
+    def _multiply_by_scalar(self, scalar: Union[int, float, complex]) -> 'Function':
         """Helper method to multiply the function by a scalar constant."""
         self._check_initialized()
 
@@ -668,6 +1219,7 @@ class Function:
             return result_func
 
         new_coefficients = self.coefficients * scalar
+        new_coefficients = self._strip_leading_zeros(new_coefficients)
 
         result_func = Function(self._largest_exponent)
         result_func.set_coeffs(new_coefficients.tolist())
@@ -680,6 +1232,7 @@ class Function:
 
         # np.polymul performs convolution of coefficients to multiply polynomials
         new_coefficients = np.polymul(self.coefficients, other.coefficients)
+        new_coefficients = self._strip_leading_zeros(new_coefficients)
 
         # The degree of the resulting polynomial is derived from the new coefficients
         new_degree = len(new_coefficients) - 1
@@ -688,7 +1241,7 @@ class Function:
         result_func.set_coeffs(new_coefficients.tolist())
         return result_func
 
-    def __mul__(self, other: Union['Function', int, float]) -> 'Function':
-        """Multiplies the function by a scalar constant."""
-        if isinstance(other, (int, float)):
+    def __mul__(self, other: Union['Function', int, float, complex]) -> 'Function':
+        """Multiplies the function by a scalar constant or another Function."""
+        if isinstance(other, (int, float, complex)):
             return self._multiply_by_scalar(other)
         elif isinstance(other, Function):
             return self._multiply_by_functions(other)
         else:
             return NotImplemented
 
-    def __rmul__(self, scalar: Union[int, float]) -> 'Function':
+    def __rmul__(self, scalar: Union[int, float, complex]) -> 'Function':
         """Handles scalar multiplication from the right (e.g., 3 * func)."""
         return self.__mul__(scalar)
 
-    def __imul__(self, other: Union['Function', int, float]) -> 'Function':
+    def __imul__(self, other: Union['Function', int, float, complex]) -> 'Function':
         """Performs in-place multiplication by a scalar (func *= 3)."""
 
         self._check_initialized()
 
-        if isinstance(other, (int, float)):
+        if isinstance(other, (int, float, complex)):
             if other == 0:
                 self.coefficients = np.array([0], dtype=self.coefficients.dtype)
                 self._largest_exponent = 0
@@ -737,18 +1290,21 @@ class Function:
         if not self._initialized or not other._initialized:
             return False
 
-        return np.array_equal(self.coefficients, other.coefficients)
+        c1 = self._strip_leading_zeros(self.coefficients)
+        c2 = self._strip_leading_zeros(other.coefficients)
+
+        if c1.shape != c2.shape:
+            return False
+
+        return np.allclose(c1, c2)
 
-    def quadratic_solve(self) -> Optional[List[float]]:
+    def quadratic_solve(self) -> Optional[List[complex]]:
         """
-        Calculates the real roots of a quadratic function using the quadratic formula.
-
-        Args:
-            f (Function): A Function object of degree 2.
+        Calculates the roots (real or complex) of a quadratic function.
 
         Returns:
-            Optional[List[float]]: A list containing the two real roots, or None if there are no real roots.
+            Optional[List[complex]]: A list containing the two roots.
         """
         self._check_initialized()
         if self.largest_exponent != 2:
@@ -759,30 +1315,16 @@ class Function:
 
         discriminant = (b**2) - (4*a*c)
         sqrt_discriminant = cmath.sqrt(discriminant)
+
+        # b.real works for both real and complex coefficients
+        if b.real >= 0:
+            sign_b = 1
+        else:
+            sign_b = -1
 
-        # 1. Calculate the first root.
-        # We use math.copysign(val, sign) to get the sign of b.
-        # This ensures (-b - sign*sqrt) is always an *addition*
-        # (or subtraction of a smaller from a larger number),
-        # avoiding catastrophic cancellation.
-        root1 = (-b - math.copysign(sqrt_discriminant, b)) / (2 * a)
+        root1 = (-b - sign_b * sqrt_discriminant) / (2 * a)
 
-        # 2. Calculate the second root using Vieta's formulas.
-        # We know that root1 * root2 = c / a.
-        # This is just a division, which is numerically stable.
-
-        # Handle the edge case where c=0.
-        # If c=0, then root1 is 0.0, and root2 is -b/a
-        # We can't divide by root1=0, so we check.
-        if root1 == 0.0:
-            # If c is also 0, the other root is -b/a
-            if c == 0.0:
-                root2 = -b / a
-            else:
-                # This case (root1=0 but c!=0) shouldn't happen
-                # with real numbers, but it's safe to just
-                # return the one root we found.
-                return [0.0]
+        if abs(root1) < 1e-15:
+            root2 = (-b + sign_b * sqrt_discriminant) / (2 * a)
         else:
             # Standard case: Use Vieta's formula
             root2 = (c / a) / root1
@@ -811,24 +1353,34 @@ if __name__ == '__main__':
     ddf1 = f1.nth_derivative(2)
     print(f"Second derivative of f1: {ddf1}")
 
+    fc = Function(2, coefficients=[1, 2, 2])
+    print(f"\nFunction fc: {fc}")
+
     # --- Root Finding ---
     # 1. Analytical solution for quadratic
     roots_analytic = f1.quadratic_solve()
-    print(f"Analytic roots of f1: {roots_analytic}") # Expected: -1, 2.5
+    print(f"\nAnalytic roots of f1: {roots_analytic}") # Expected: -1, 2.5
+    c_roots_analytic = fc.quadratic_solve()
+    print(f"Analytic roots of fc: {c_roots_analytic}") # Expected: -1-j, -1+j
 
     # 2. Genetic algorithm solution
-    ga_opts = GA_Options(num_of_generations=20, data_size=50000)
-    print("\nFinding roots with Genetic Algorithm (CPU)...")
+    ga_opts = GA_Options(num_of_generations=100, data_size=100000, root_precision=3, selection_percentile=0.75)
+    print("\nFinding real roots with Genetic Algorithm (CPU)...")
     roots_ga_cpu = f1.get_real_roots(ga_opts)
-    print(f"Approximate roots from GA (CPU): {roots_ga_cpu}")
-    print("(Note: GA provides approximations around the true roots)")
+    print(f"Approximate real roots from GA (CPU): {roots_ga_cpu}")
+    print("\nFinding all roots of fc with Genetic Algorithm (CPU)...")
+    c_roots_ga_cpu = fc.get_roots(ga_opts)
+    print(f"Approximate roots of fc from GA (CPU): {c_roots_ga_cpu}")
 
     # 3. CUDA accelerated genetic algorithm
     if _CUPY_AVAILABLE:
-        print("\nFinding roots with Genetic Algorithm (CUDA)...")
+        print("\nFinding real roots with Genetic Algorithm (GPU)...")
         # Since this PC has an RTX 4060 Ti, we can use the CUDA version.
         roots_ga_gpu = f1.get_real_roots(ga_opts, use_cuda=True)
-        print(f"Approximate roots from GA (GPU): {roots_ga_gpu}")
+        print(f"Approximate real roots from GA (GPU): {roots_ga_gpu}")
+        print("\nFinding all roots of fc with Genetic Algorithm (GPU)...")
+        c_roots_ga_gpu = fc.get_roots(ga_opts, use_cuda=True)
+        print(f"Approximate roots of fc from GA (GPU): {c_roots_ga_gpu}")
     else:
         print("\nSkipping CUDA example: CuPy library not found or no compatible GPU.")
 
diff --git a/tests/test_polysolve.py b/tests/test_polysolve.py
index dcef499..32803a4 100644
--- a/tests/test_polysolve.py
+++ b/tests/test_polysolve.py
@@ -43,6 +43,11 @@ def base_func():
     f.set_coeffs([1, 2, 3])
     return f
 
+@pytest.fixture
+def complex_func():
+    f = Function(2, [1, 2, 2])
+    return f
+
 # --- Core Functionality Tests ---
 
 def test_solve_y(quadratic_func):
@@ -162,3 +167,36 @@ def test_get_real_roots_cuda(quadratic_func):
 
     # Verify that the CUDA implementation also finds the correct roots within tolerance.
npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
 
+def test_get_roots_numpy(complex_func):
+    """
+    Tests that the NumPy-based genetic algorithm approximates the complex roots correctly.
+    """
+    # Using more generations for higher accuracy in testing
+    ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
+
+    roots = complex_func.get_roots(ga_opts, use_cuda=False)
+
+    # Check that the algorithm found values close to the two known roots.
+    # The order is not guaranteed, so both arrays are sorted before comparing.
+    expected_roots = np.array([-1.0-1.0j, -1.0+1.0j])
+
+    npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
+
+
+@pytest.mark.skipif(not _CUPY_AVAILABLE, reason="CuPy is not installed, skipping CUDA test.")
+def test_get_roots_cuda(complex_func):
+    """
+    Tests that the CUDA-based genetic algorithm approximates the complex roots correctly.
+    This test implicitly verifies that the CUDA kernel is functioning.
+    It will be skipped automatically if CuPy is not available.
+    """
+
+    ga_opts = GA_Options(num_of_generations=50, data_size=200000, selection_percentile=0.66, root_precision=3)
+
+    roots = complex_func.get_roots(ga_opts, use_cuda=True)
+
+    expected_roots = np.array([-1.0-1.0j, -1.0+1.0j])
+
+    # Verify that the CUDA implementation also finds the correct roots within tolerance.
+    npt.assert_allclose(np.sort(roots), np.sort(expected_roots), atol=1e-2)
+
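
Usage sketch of the new API (illustrative only; assumes polysolve 0.7.0 as
bumped above is installed). The expected values follow from the analytic
roots of x^2 + 2x + 2, namely -1 - 1j and -1 + 1j.

    import numpy as np
    from polysolve import Function, GA_Options

    # f(x) = x^2 + 2x + 2, built via the new coefficient-injection constructor.
    fc = Function(2, coefficients=[1, 2, 2])

    # Analytic path: quadratic_solve now returns complex roots via cmath.
    print(fc.quadratic_solve())     # [(-1-1j), (-1+1j)]

    # GA path: find_complex=True (the default) searches the complex plane;
    # pass use_cuda=True to select the shared-memory CUDA kernels when CuPy
    # is available.
    opts = GA_Options(num_of_generations=50, data_size=200000, root_precision=3)
    print(fc.get_roots(opts))       # approx. [-1.-1.j, -1.+1.j]

    # Real-only search forces find_complex off internally; fc has no real
    # roots, so the result should be an empty array.
    print(fc.get_real_roots(opts))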