Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c3f73b1

Browse files
completely refactored defuzzification and put in new defuzz.py
1 parent 931d571 commit c3f73b1

File tree

2 files changed

+128
-41
lines changed

2 files changed

+128
-41
lines changed

‎src/fuzzylogic/classes.py‎

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import numpy as np
2121

22+
from fuzzylogic import defuzz
23+
2224
from .combinators import MAX, MIN, bounded_sum, product, simple_disjoint_sum
2325
from .functions import Membership, inv, normalize
2426

@@ -481,50 +483,60 @@ def __eq__(self, other: object) -> bool:
481483
def __getitem__(self, key: Iterable[Set]) -> Set:
482484
return self.conditions[frozenset(key)]
483485

484-
def __call__(self, values: dict[Domain, float | int], method: str = "cog") -> float | None:
485-
"""Calculate the infered value based on different methods.
486-
Default is center of gravity (cog).
486+
def __call__(self, values: dict[Domain, float], method=defuzz.cog) -> float | None:
487+
"""
488+
Calculate the inferred crisp value based on the fuzzy rules.
489+
The 'method' parameter should be one of the static methods from the DefuzzMethod class.
487490
"""
488-
assert isinstance(values, dict), "Please make sure to pass a dict[Domain, float|int] as values."
489-
assert len(self.conditions) > 0, "No point in having a rule with no conditions, is there?"
491+
assert isinstance(values, dict), "Please pass a dict[Domain, float|int] as values."
492+
assert values, "No condition rules defined!"
493+
494+
# Extract common target domain and build list of (then_set, firing_strength)
495+
sample_then_set = next(iter(self.conditions.values()))
496+
target_domain = getattr(sample_then_set, "domain", None)
497+
assert target_domain, "Target domain must be defined."
498+
499+
target_weights: list[tuple[Set, float]] = []
500+
for if_sets, then_set in self.conditions.items():
501+
assert then_set.domain == target_domain, "All target sets must be in the same Domain."
502+
degrees = []
503+
for s in if_sets:
504+
assert s.domain is not None, "Domain must be defined for all fuzzy sets."
505+
degrees.append(s(values[s.domain]))
506+
firing_strength = min(degrees, default=0)
507+
if firing_strength > 0:
508+
target_weights.append((then_set, firing_strength))
509+
if not target_weights:
510+
return None
511+
512+
# For center-of-gravity / centroid:
513+
if method == defuzz.cog:
514+
return defuzz.cog(target_weights)
515+
516+
# For methods that rely on an aggregated membership function:
517+
points = list(target_domain.range)
518+
n = len(points)
519+
step = (
520+
(target_domain._high - target_domain._low) / (n - 1)
521+
if n > 1
522+
else (target_domain._high - target_domain._low)
523+
)
524+
525+
def aggregated_membership(x: float) -> float:
526+
# For each rule, limit its inferred output by its firing strength and then take the max
527+
return max(min(weight, then_set(x)) for then_set, weight in target_weights)
528+
490529
match method:
491-
case "cog":
492-
# iterate over the conditions and calculate the actual values and weights contributing to cog
493-
target_weights: list[tuple[Set, float]] = []
494-
target_domain = list(self.conditions.values())[0].domain
495-
assert target_domain is not None, "Target domain must be defined."
496-
for if_sets, then_set in self.conditions.items():
497-
actual_values: list[float] = []
498-
assert then_set.domain == target_domain, "All target sets must be in the same Domain."
499-
for s in if_sets:
500-
assert s.domain is not None, "Domains must be defined."
501-
actual_values.append(s(values[s.domain]))
502-
x = min(actual_values, default=0)
503-
if x > 0:
504-
target_weights.append((then_set, x))
505-
if not target_weights:
506-
return None
507-
sum_weights = 0
508-
sum_weighted_cogs: float = 0
509-
for then_set, weight in target_weights:
510-
sum_weighted_cogs += then_set.center_of_gravity() * weight
511-
sum_weights += weight
512-
index = sum_weighted_cogs / sum_weights
513-
return (target_domain._high - target_domain._low) / len( # type: ignore
514-
target_domain.range
515-
) * index + target_domain._low # type: ignore
516-
case "centroid": # centroid == center of mass == center of gravity for simple solids
517-
raise NotImplementedError("actually the same as 'cog' if densities are uniform.")
518-
case "bisector":
519-
raise NotImplementedError("Bisector method not implemented yet.")
520-
case "mom":
521-
raise NotImplementedError("Middle of max method not implemented yet.")
522-
case "som":
523-
raise NotImplementedError("Smallest of max method not implemented yet.")
524-
case "lom":
525-
raise NotImplementedError("Largest of max method not implemented yet.")
530+
case defuzz.bisector:
531+
return defuzz.bisector(aggregated_membership, points, step)
532+
case defuzz.mom:
533+
return defuzz.mom(aggregated_membership, points)
534+
case defuzz.som:
535+
return defuzz.som(aggregated_membership, points)
536+
case defuzz.lom:
537+
return defuzz.lom(aggregated_membership, points)
526538
case _:
527-
raise ValueError("Invalid method.")
539+
raise ValueError("Invalid defuzzification method specified.")
528540

529541

530542
def rule_from_table(table: str, references: dict[str, float]) -> Rule:

‎src/fuzzylogic/defuzz.py‎

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from .classes import Membership, Set
7+
8+
9+
def cog(target_weights: list[tuple[Set, float]]) -> float:
10+
"""
11+
Defuzzify using the center-of-gravity (or centroid) method.
12+
target_weights: list of tuples (then_set, weight)
13+
14+
The COG is defined by the formula:
15+
16+
COG = (∑ μi ×ばつ xi) / (∑ μi)
17+
18+
where:
19+
• μi is the membership value for the ith element,
20+
• xi is the corresponding value for the ith element in the output domain.
21+
22+
"""
23+
sum_weights = sum(weight for _, weight in target_weights)
24+
sum_weighted_cogs = sum(then_set.center_of_gravity() * weight for then_set, weight in target_weights)
25+
return sum_weighted_cogs / sum_weights
26+
27+
28+
def bisector(
29+
aggregated_membership: Membership,
30+
points: list[float],
31+
step: float,
32+
) -> float:
33+
"""
34+
Defuzzify via the bisector method.
35+
aggregated_membership: function mapping crisp value x -> membership degree (typically in [0,1])
36+
points: discretized points in the target domain
37+
step: spacing between points
38+
"""
39+
total_area = sum(aggregated_membership(x) * step for x in points)
40+
half_area = total_area / 2.0
41+
cumulative = 0.0
42+
for x in points:
43+
cumulative += aggregated_membership(x) * step
44+
if cumulative >= half_area:
45+
return x
46+
return points[-1]
47+
48+
49+
def mom(aggregated_membership: Membership, points: list[float]) -> float | None:
50+
"""
51+
Mean of Maxima (MOM): average the x-values where the aggregated membership is maximal.
52+
"""
53+
max_points = _get_max_points(aggregated_membership, points)
54+
return sum(max_points) / len(max_points) if max_points else None
55+
56+
57+
def som(aggregated_membership: Membership, points: list[float]) -> float | None:
58+
"""
59+
Smallest of Maxima: return the smallest x-value at which the aggregated membership is maximal.
60+
"""
61+
return min(_get_max_points(aggregated_membership, points), default=None)
62+
63+
64+
def lom(aggregated_membership: Membership, points: list[float]) -> float | None:
65+
"""
66+
Largest of Maxima: return the largest x-value at which the aggregated membership is maximal.
67+
"""
68+
return max(_get_max_points(aggregated_membership, points), default=None)
69+
70+
71+
def _get_max_points(aggregated_membership: Membership, points: list[float]) -> list[float]:
72+
values_points = [(x, aggregated_membership(x)) for x in points]
73+
max_value = max(y for (_, y) in values_points)
74+
tol = 1e-6 # tolerance for floating point comparisons
75+
return [x for (x, y) in values_points if abs(y - max_value) < tol]

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /