diff --git a/src/polysolve/__init__.py b/src/polysolve/__init__.py index 885d048..a342ed0 100644 --- a/src/polysolve/__init__.py +++ b/src/polysolve/__init__.py @@ -1,7 +1,7 @@ import math import numpy as np from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Union import warnings # Attempt to import CuPy for CUDA acceleration. @@ -392,36 +392,55 @@ class Function: result_func = Function(len(new_coefficients) - 1) result_func.set_coeffs(new_coefficients.tolist()) return result_func - - def __mul__(self, scalar: int) -> 'Function': - """Multiplies the function by a scalar constant.""" - self._check_initialized() - if not isinstance(scalar, (int, float)): - return NotImplemented + + def _multiply_by_scalar(self, scalar: Union[int, float]) -> 'Function': + """Helper method to multiply the function by a scalar constant.""" + self._check_initialized() # It's good practice to check here too + if scalar == 0: result_func = Function(0) result_func.set_coeffs([0]) return result_func - + new_coefficients = self.coefficients * scalar - + result_func = Function(self._largest_exponent) result_func.set_coeffs(new_coefficients.tolist()) return result_func - def __rmul__(self, scalar: int) -> 'Function': + def _multiply_by_function(self, other: 'Function') -> 'Function': + """Helper method for polynomial multiplication (Function * Function).""" + self._check_initialized() + other._check_initialized() + + # np.polymul performs convolution of coefficients to multiply polynomials + new_coefficients = np.polymul(self.coefficients, other.coefficients) + + # The degree of the resulting polynomial is derived from the new coefficients + new_degree = len(new_coefficients) - 1 + + result_func = Function(new_degree) + result_func.set_coeffs(new_coefficients.tolist()) + return result_func + + def __mul__(self, other: Union['Function', int, float]) -> 'Function': + """Multiplies the function by a scalar constant.""" + if isinstance(other, (int, float)): + return self._multiply_by_scalar(other) + elif isinstance(other, self.__class__): + return self._multiply_by_function(other) + else: + return NotImplemented + + def __rmul__(self, scalar: Union[int, float]) -> 'Function': """Handles scalar multiplication from the right (e.g., 3 * func).""" + return self.__mul__(scalar) - def __imul__(self, scalar: int) -> 'Function': + def __imul__(self, other: Union['Function', int, float]) -> 'Function': """Performs in-place multiplication by a scalar (func *= 3).""" - self._check_initialized() - if not isinstance(scalar, (int, float)): - return NotImplemented - if scalar == 0: - raise ValueError("Cannot multiply a function by 0.") - - self.coefficients *= scalar + + self.coefficients *= other return self @@ -506,4 +525,18 @@ if __name__ == '__main__': # Multiplication: (x + 10) * 3 = 3x + 30 f_mul = f2 * 3 - print(f"f2 * 3 = {f_mul}") \ No newline at end of file + print(f"f2 * 3 = {f_mul}") + + # f3 represents 2x^2 + 3x + 1 + f3 = Function(2) + f3.set_coeffs([2, 3, 1]) + print(f"Function f3: {f3}") + + # f4 represents 5x - 4 + f4 = Function(1) + f4.set_coeffs([5, -4]) + print(f"Function f4: {f4}") + + # Multiply the two functions + product_func = f3 * f4 + print(f"f3 * f4 = {product_func}") diff --git a/tests/test_polysolve.py b/tests/test_polysolve.py index 4bf0f11..e0a502a 100644 --- a/tests/test_polysolve.py +++ b/tests/test_polysolve.py @@ -24,6 +24,18 @@ def linear_func() -> Function: f.set_coeffs([1, 10]) return f +@pytest.fixture +def m_func_1() -> Function: + f = Function(2) + f.set_coeffs([2, 3, 1]) + return f + +@pytest.fixture +def m_func_2() -> Function: + f = Function(1) + f.set_coeffs([5, -4]) + return f + # --- Core Functionality Tests --- def test_solve_y(quadratic_func): @@ -68,13 +80,20 @@ def test_subtraction(quadratic_func, linear_func): assert result.largest_exponent == 2 assert np.array_equal(result.coefficients, [2, -4, -15]) -def test_multiplication(linear_func): +def test_scalar_multiplication(linear_func): """Tests the multiplication of a Function object by a scalar.""" # (x + 10) * 3 = 3x + 30 result = linear_func * 3 assert result.largest_exponent == 1 assert np.array_equal(result.coefficients, [3, 30]) +def test_function_multiplication(m_func_1, m_func_2): + """Tests the multiplication of two Function objects.""" + # (2x^2 + 3x + 1) * (5x -4) = 10x^3 + 7x^2 - 7x -4 + result = m_func_1 * m_func_2 + assert result.largest_exponent == 3 + assert np.array_equal(result.coefficients, [19, 7, -7, -4]) + # --- Genetic Algorithm Root-Finding Tests --- def test_get_real_roots_numpy(quadratic_func):