feat: Added function * function multiplication

This commit is contained in:
2025-06-17 14:26:26 -04:00
parent 8656b558b4
commit ee414ea0dc
2 changed files with 72 additions and 20 deletions

View File

@ -1,7 +1,7 @@
import math import math
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import List, Optional, Union
import warnings import warnings
# Attempt to import CuPy for CUDA acceleration. # Attempt to import CuPy for CUDA acceleration.
@ -393,11 +393,10 @@ class Function:
result_func.set_coeffs(new_coefficients.tolist()) result_func.set_coeffs(new_coefficients.tolist())
return result_func return result_func
def __mul__(self, scalar: int) -> 'Function': def _multiply_by_scalar(self, scalar: Union[int, float]) -> 'Function':
"""Multiplies the function by a scalar constant.""" """Helper method to multiply the function by a scalar constant."""
self._check_initialized() self._check_initialized() # It's good practice to check here too
if not isinstance(scalar, (int, float)):
return NotImplemented
if scalar == 0: if scalar == 0:
result_func = Function(0) result_func = Function(0)
result_func.set_coeffs([0]) result_func.set_coeffs([0])
@ -409,19 +408,39 @@ class Function:
result_func.set_coeffs(new_coefficients.tolist()) result_func.set_coeffs(new_coefficients.tolist())
return result_func return result_func
def __rmul__(self, scalar: int) -> 'Function': def _multiply_by_function(self, other: 'Function') -> 'Function':
"""Helper method for polynomial multiplication (Function * Function)."""
self._check_initialized()
other._check_initialized()
# np.polymul performs convolution of coefficients to multiply polynomials
new_coefficients = np.polymul(self.coefficients, other.coefficients)
# The degree of the resulting polynomial is derived from the new coefficients
new_degree = len(new_coefficients) - 1
result_func = Function(new_degree)
result_func.set_coeffs(new_coefficients.tolist())
return result_func
def __mul__(self, other: Union['Function', int, float]) -> 'Function':
"""Multiplies the function by a scalar constant."""
if isinstance(other, (int, float)):
return self._multiply_by_scalar(other)
elif isinstance(other, self.__class__):
return self._multiply_by_function(other)
else:
return NotImplemented
def __rmul__(self, scalar: Union[int, float]) -> 'Function':
"""Handles scalar multiplication from the right (e.g., 3 * func).""" """Handles scalar multiplication from the right (e.g., 3 * func)."""
return self.__mul__(scalar) return self.__mul__(scalar)
def __imul__(self, scalar: int) -> 'Function': def __imul__(self, other: Union['Function', int, float]) -> 'Function':
"""Performs in-place multiplication by a scalar (func *= 3).""" """Performs in-place multiplication by a scalar (func *= 3)."""
self._check_initialized()
if not isinstance(scalar, (int, float)):
return NotImplemented
if scalar == 0:
raise ValueError("Cannot multiply a function by 0.")
self.coefficients *= scalar self.coefficients *= other
return self return self
@ -507,3 +526,17 @@ if __name__ == '__main__':
# Multiplication: (x + 10) * 3 = 3x + 30 # Multiplication: (x + 10) * 3 = 3x + 30
f_mul = f2 * 3 f_mul = f2 * 3
print(f"f2 * 3 = {f_mul}") print(f"f2 * 3 = {f_mul}")
# f3 represents 2x^2 + 3x + 1
f3 = Function(2)
f3.set_coeffs([2, 3, 1])
print(f"Function f3: {f3}")
# f4 represents 5x - 4
f4 = Function(1)
f4.set_coeffs([5, -4])
print(f"Function f4: {f4}")
# Multiply the two functions
product_func = f3 * f4
print(f"f3 * f4 = {product_func}")

View File

@ -24,6 +24,18 @@ def linear_func() -> Function:
f.set_coeffs([1, 10]) f.set_coeffs([1, 10])
return f return f
@pytest.fixture
def m_func_1() -> Function:
f = Function(2)
f.set_coeffs([2, 3, 1])
return f
@pytest.fixture
def m_func_2() -> Function:
f = Function(1)
f.set_coeffs([5, -4])
return f
# --- Core Functionality Tests --- # --- Core Functionality Tests ---
def test_solve_y(quadratic_func): def test_solve_y(quadratic_func):
@ -68,13 +80,20 @@ def test_subtraction(quadratic_func, linear_func):
assert result.largest_exponent == 2 assert result.largest_exponent == 2
assert np.array_equal(result.coefficients, [2, -4, -15]) assert np.array_equal(result.coefficients, [2, -4, -15])
def test_multiplication(linear_func): def test_scalar_multiplication(linear_func):
"""Tests the multiplication of a Function object by a scalar.""" """Tests the multiplication of a Function object by a scalar."""
# (x + 10) * 3 = 3x + 30 # (x + 10) * 3 = 3x + 30
result = linear_func * 3 result = linear_func * 3
assert result.largest_exponent == 1 assert result.largest_exponent == 1
assert np.array_equal(result.coefficients, [3, 30]) assert np.array_equal(result.coefficients, [3, 30])
def test_function_multiplication(m_func_1, m_func_2):
"""Tests the multiplication of two Function objects."""
# (2x^2 + 3x + 1) * (5x -4) = 10x^3 + 7x^2 - 7x -4
result = m_func_1 * m_func_2
assert result.largest_exponent == 3
assert np.array_equal(result.coefficients, [19, 7, -7, -4])
# --- Genetic Algorithm Root-Finding Tests --- # --- Genetic Algorithm Root-Finding Tests ---
def test_get_real_roots_numpy(quadratic_func): def test_get_real_roots_numpy(quadratic_func):