feat: Added function * function multiplication

This commit is contained in:
2025-06-17 14:26:26 -04:00
parent 8656b558b4
commit ee414ea0dc
2 changed files with 72 additions and 20 deletions

View File

@ -1,7 +1,7 @@
import math
import numpy as np
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union
import warnings
# Attempt to import CuPy for CUDA acceleration.
@ -393,11 +393,10 @@ class Function:
result_func.set_coeffs(new_coefficients.tolist())
return result_func
def __mul__(self, scalar: int) -> 'Function':
"""Multiplies the function by a scalar constant."""
self._check_initialized()
if not isinstance(scalar, (int, float)):
return NotImplemented
def _multiply_by_scalar(self, scalar: Union[int, float]) -> 'Function':
"""Helper method to multiply the function by a scalar constant."""
self._check_initialized() # It's good practice to check here too
if scalar == 0:
result_func = Function(0)
result_func.set_coeffs([0])
@ -409,19 +408,39 @@ class Function:
result_func.set_coeffs(new_coefficients.tolist())
return result_func
def __rmul__(self, scalar: int) -> 'Function':
def _multiply_by_function(self, other: 'Function') -> 'Function':
"""Helper method for polynomial multiplication (Function * Function)."""
self._check_initialized()
other._check_initialized()
# np.polymul performs convolution of coefficients to multiply polynomials
new_coefficients = np.polymul(self.coefficients, other.coefficients)
# The degree of the resulting polynomial is derived from the new coefficients
new_degree = len(new_coefficients) - 1
result_func = Function(new_degree)
result_func.set_coeffs(new_coefficients.tolist())
return result_func
def __mul__(self, other: Union['Function', int, float]) -> 'Function':
"""Multiplies the function by a scalar constant."""
if isinstance(other, (int, float)):
return self._multiply_by_scalar(other)
elif isinstance(other, self.__class__):
return self._multiply_by_function(other)
else:
return NotImplemented
def __rmul__(self, scalar: Union[int, float]) -> 'Function':
"""Handles scalar multiplication from the right (e.g., 3 * func)."""
return self.__mul__(scalar)
def __imul__(self, scalar: int) -> 'Function':
def __imul__(self, other: Union['Function', int, float]) -> 'Function':
"""Performs in-place multiplication by a scalar (func *= 3)."""
self._check_initialized()
if not isinstance(scalar, (int, float)):
return NotImplemented
if scalar == 0:
raise ValueError("Cannot multiply a function by 0.")
self.coefficients *= scalar
self.coefficients *= other
return self
@ -507,3 +526,17 @@ if __name__ == '__main__':
# Multiplication: (x + 10) * 3 = 3x + 30
f_mul = f2 * 3
print(f"f2 * 3 = {f_mul}")
# f3 represents 2x^2 + 3x + 1
f3 = Function(2)
f3.set_coeffs([2, 3, 1])
print(f"Function f3: {f3}")
# f4 represents 5x - 4
f4 = Function(1)
f4.set_coeffs([5, -4])
print(f"Function f4: {f4}")
# Multiply the two functions
product_func = f3 * f4
print(f"f3 * f4 = {product_func}")

View File

@ -24,6 +24,18 @@ def linear_func() -> Function:
f.set_coeffs([1, 10])
return f
@pytest.fixture
def m_func_1() -> Function:
f = Function(2)
f.set_coeffs([2, 3, 1])
return f
@pytest.fixture
def m_func_2() -> Function:
f = Function(1)
f.set_coeffs([5, -4])
return f
# --- Core Functionality Tests ---
def test_solve_y(quadratic_func):
@ -68,13 +80,20 @@ def test_subtraction(quadratic_func, linear_func):
assert result.largest_exponent == 2
assert np.array_equal(result.coefficients, [2, -4, -15])
def test_multiplication(linear_func):
def test_scalar_multiplication(linear_func):
"""Tests the multiplication of a Function object by a scalar."""
# (x + 10) * 3 = 3x + 30
result = linear_func * 3
assert result.largest_exponent == 1
assert np.array_equal(result.coefficients, [3, 30])
def test_function_multiplication(m_func_1, m_func_2):
"""Tests the multiplication of two Function objects."""
# (2x^2 + 3x + 1) * (5x -4) = 10x^3 + 7x^2 - 7x -4
result = m_func_1 * m_func_2
assert result.largest_exponent == 3
assert np.array_equal(result.coefficients, [19, 7, -7, -4])
# --- Genetic Algorithm Root-Finding Tests ---
def test_get_real_roots_numpy(quadratic_func):