1818from itertools import permutations
1919from random import choice , randint
2020from statistics import median
21- from typing import Callable
2221
2322import numpy as np
2423
25- from .functions import R , S , constant , gauss , rectangular , sigmoid , singleton , step , trapezoid , triangular
24+ from .classes import Array
25+ from .functions import (
26+ Membership ,
27+ R ,
28+ S ,
29+ constant ,
30+ gauss ,
31+ rectangular ,
32+ sigmoid ,
33+ singleton ,
34+ step ,
35+ trapezoid ,
36+ triangular ,
37+ )
2638
2739np .seterr (all = "raise" )
2840functions = [step , rectangular ]
3345argument4_functions = [trapezoid ]
3446
3547
36- def normalize (target : np . ndarray , output_length : int = 16 ) -> np . ndarray :
48+ def normalize (target : Array , output_length : int = 16 ) -> Array :
3749 """Normalize and interpolate a numpy array.
3850
3951 Return an array of output_length and normalized values.
4052 """
41- min_val = np .min (target )
42- max_val = np .max (target )
53+ min_val = float ( np .min (target ) )
54+ max_val = float ( np .max (target ) )
4355 if min_val == max_val :
4456 return np .ones (output_length )
4557 normalized_array = (target - min_val ) / (max_val - min_val )
@@ -49,13 +61,12 @@ def normalize(target: np.ndarray, output_length: int = 16) -> np.ndarray:
4961 return normalized_array
5062
5163
52- def guess_function (target : np . ndarray ) -> Callable :
64+ def guess_function (target : Array ) -> Membership :
5365 normalized = normalize (target )
54- # trivial case
5566 return constant if np .all (normalized == 1 ) else singleton
5667
5768
58- def fitness (func : Callable , target : np . ndarray , certainty : int | None = None ) -> float :
69+ def fitness (func : Membership , target : Array , certainty : int | None = None ) -> float :
5970 """Compute the difference between the array and the function evaluated at the parameters.
6071
6172 if the error is 0, we have a perfect match: fitness -> 1
@@ -66,7 +77,7 @@ def fitness(func: Callable, target: np.ndarray, certainty: int | None = None) ->
6677 return result if certainty is None else round (result , certainty )
6778
6879
69- def seed_population (func : Callable , target : np . ndarray ) -> dict [tuple , float ]:
80+ def seed_population (func : Membership , target : Array ) -> dict [tuple , float ]:
7081 # create a random population of parameters
7182 params = [p for p in inspect .signature (func ).parameters .values () if p .kind == p .POSITIONAL_OR_KEYWORD ]
7283 seed_population = {}
@@ -106,7 +117,7 @@ def reproduce(parent1: tuple, parent2: tuple) -> tuple:
106117
107118
108119def guess_parameters (
109- func : Callable , target : np . ndarray , precision : int | None = None , certainty : int | None = None
120+ func : Membership , target : Array , precision : int | None = None , certainty : int | None = None
110121) -> tuple :
111122 """Find the best fitting parameters for a function, targetting an array.
112123
@@ -188,7 +199,7 @@ def best() -> tuple:
188199 return best ()
189200
190201
191- def shave (target : np . ndarray , components : dict [Callable , tuple ]) -> np . ndarray :
202+ def shave (target : Array , components : dict [Membership , tuple ]) -> Array :
192203 """Remove the membership functions from the target array."""
193204 result = np .zeros_like (target )
194205 for func , params in components .items ():
0 commit comments