I just implemented this Kalman filter in Python + NumPy, keeping the Wikipedia notation.
It's a straightforward implementation of the original algorithm. My goals were to:
- develop skills related to implementing a scientific paper;
- keep it readable (which is why I used private methods for the intermediate results).
It includes a simple test case.
import numpy as np
"""
Kalman Filter in plain Python + Numpy
Filter State
- x = Updated State
- P = Updated State Uncertainty
- x_pred = Predicted State
- P_pred = Predicted State Uncertainty
Models
- F = Prediction Model
- H = Observation Model (Maps from State Space to Observation Space)
- B = Evolved Forcing Term
- Q = Evolution Noise
- R = Observation Noise
"""
class KF:
def __init__(self, x, P, F, B, H, Q, R):
# Init
self.x = x
self.x_pred = x
# Initial State Uncertainty
self.P = P
self.P_pred = P
# Prediction Model
self.F = F
# Observation Model
self.H = H
# Process Noise
self.Q = Q
# Observation Noise
self.R = R
def predict(self, u):
self.x_pred = F.dot(x) + B.dot(u)
self.P_pred = F.dot(P.dot(np.transpose(F))) + Q
def __innovation(self, y):
return y - self.H.dot(self.x_pred)
def __S(self):
return self.R + self.H.dot(self.P_pred.dot(np.transpose(self.H)))
def __K(self):
return self.P_pred.dot(np.transpose(self.H).dot(np.linalg.inv(self.__S())))
def __I(self, A):
if(A.shape[0] != A.shape[1]):
raise ValueError("[Identity] Not Square")
return np.identity(A.shape[0])
def update(self, y):
self.x = self.x_pred + self.__K().dot(self.__innovation(y))
temp = self.__K().dot(self.H)
self.P = (self.__I(temp) - temp).dot(self.P_pred.dot(np.transpose(self.__I(temp) - temp))) + self.__K().dot(self.R.dot(np.transpose(self.__K())))
def to_str(self):
return "x \n" + np.array2string(self.x) + "\n P \n" + np.array2string(self.P) + "\n x_pred \n" + np.array2string(self.x_pred) + "\n P_pred \n" + np.array2string(self.P_pred)
# Simple test case: every model matrix is its own 2x2 integer identity.
F, P, B, H, Q, R = (np.eye(2, dtype=int) for _ in range(6))
x = np.array([[1], [0]])  # initial state
y = np.array([[3], [5]])  # observation
u = np.array([[0], [0]])  # control input (none)

kf = KF(x, P, F, B, H, Q, R)
kf.predict(u)
kf.update(y)
print("State = " + kf.to_str())
Output
State = x
[[2.33333333]
[6.66666667]]
P
[[0.66666667 0. ]
[0. 0.66666667]]
x_pred
[[1.]
[0.]]
P_pred
[[2. 0.]
[0. 2.]]
Validation against filterpy, as suggested by @AlexV:
# Cross-check: run one predict/update cycle with filterpy's KalmanFilter,
# reusing the module-level x, F, H, P, R, Q and y from the simple test case.
from filterpy.kalman import KalmanFilter
my_filter = KalmanFilter(dim_x=2, dim_z=2)
my_filter.x = x # initial state
my_filter.F = F # state transition (prediction) model
my_filter.H = H # Measurement function
my_filter.P = P # initial state covariance
my_filter.R = R # measurement (observation) noise covariance
my_filter.Q = Q # process noise covariance
# NOTE(review): no B / u is passed here; this presumably matches the custom
# filter only because the simple test uses u = 0. Confirm against the filterpy docs.
my_filter.predict()
my_filter.update(y)
print(np.array2string(my_filter.x))
Output
[[2.33333333]
[6.66666667]]
-
(Comment, score 3) I assume you did this for the learning experience, but I would nevertheless like to bring filterpy to your attention. You could use it as a reference to test against, but that depends on how much you trust the other implementation ;-). — AlexV, commented Apr 25, 2019 at 15:07
1 Answer
Ignore for the moment the cookbook and filterpy.kalman, and let's assume that your implementation is nominally numerically correct. You had already done one manual check to validate that this is correct, but I recommend expanding your automated tests for more confidence.
Add PEP 484 type hints. I also suggest moving your model descriptions into inline comments on your constructor parameters, even though those comments will not show up in the class docstring. Also, the docstring belongs inside the class, not before it.
In predict, F, x, B, P and Q incorrectly reference globals instead of your instance members (B is never even stored on the instance). Refer to the instance members instead, and move all of your global code into functions.
Replace all of your transpose() calls with .T.
Replace all of your .dot() calls with @ matrix multiplication operators.
Don't use __ name-mangling prefixes. You can just leave your members public.
Since five of your model matrices (F, B, H, Q and R) are constant, I recommend that you store them as read-only views.
S and K are simple accessors, so could benefit from being @property.
The identity matrix I never changes, so move it from a method to a member variable built in the constructor. You have enough information at construction time to know its dimensions.
In this expression:
self.P_pred.dot(np.transpose(self.H).dot(np.linalg.inv(self.__S())))
inv() is not the best approach. Read about solve: it removes one of your matrix multiplications and is backed by the LAPACK routine gesv, so it will be fast.
The variable temp should capture one more term (the I subtraction).
Consider replacing to_str with __str__ so that any string cast will show your string render.
Because you have a simple class with a mix of mutable and immutable members, one simple robustness and performance improvement is to define __slots__.
All together (omitting the original implementation),
import numpy as np
class KalmanFilter:
    """Kalman Filter in plain Python + Numpy (Wikipedia notation).

    ``predict`` propagates the estimate through the model. ``update`` folds in
    an observation using the Joseph-form covariance update.
    """

    __slots__ = (
        'B', 'F', 'H', 'I', 'P', 'P_pred', 'Q', 'R', 'x', 'x_pred',
    )

    def __init__(
        self,
        x: np.ndarray,  # updated state
        P: np.ndarray,  # updated state uncertainty
        F: np.ndarray,  # prediction model
        B: np.ndarray,  # evolved forcing term
        H: np.ndarray,  # observation model (maps from state space to observation space)
        Q: np.ndarray,  # evolution noise
        R: np.ndarray,  # observation noise
    ) -> None:
        self.x = x
        self.x_pred = x  # predicted state
        self.P = P
        self.P_pred = P  # predicted state uncertainty

        # All of these model matrices are constant, so store read-only views
        fbhqr = F, B, H, Q, R
        self.F, self.B, self.H, self.Q, self.R = views = [data.view() for data in fbhqr]
        for view in views:
            view.flags.writeable = False

        # The identity is only ever combined with k@H, and k always has the
        # same leading dimension as P_pred, so k@H must be square
        if P.shape[0] != H.shape[1]:
            raise ValueError('P0 and H1 do not imply a square matrix')
        self.I = np.eye(P.shape[0], dtype=P.dtype)
        self.I.flags.writeable = False

    def predict(self, u: np.ndarray) -> None:
        """Project the updated state and covariance one step forward."""
        self.x_pred = self.F @ self.x + self.B @ u
        self.P_pred = self.F @ self.P @ self.F.T + self.Q

    def innovation(self, y: np.ndarray) -> np.ndarray:
        """Return the measurement residual y - H x_pred."""
        return y - self.H @ self.x_pred

    @property
    def S(self) -> np.ndarray:
        """Innovation covariance R + H P_pred H^T."""
        return self.R + self.H @ self.P_pred @ self.H.T

    @property
    def K(self) -> np.ndarray:
        """Kalman gain P_pred H^T S^-1."""
        # Equivalent to:
        #     return self.P_pred @ self.H.T @ np.linalg.inv(self.S)
        # but solves S^T K^T = H P_pred^T instead of forming the inverse.
        return np.linalg.solve(self.S.T, self.H @ self.P_pred.T).T

    def update(self, y: np.ndarray) -> None:
        """Fold the observation ``y`` into the predicted estimate."""
        k = self.K  # computed once; the property re-solves on every access
        temp = self.I - k @ self.H
        self.x = self.x_pred + k @ self.innovation(y)
        # Joseph form: (I - K H) P_pred (I - K H)^T + K R K^T
        self.P = temp @ self.P_pred @ temp.T + k @ self.R @ k.T

    def __str__(self) -> str:
        # Explicit newlines keep the rendering independent of source indentation.
        return (
            f'x\n{self.x}\n'
            f'P\n{self.P}\n'
            f'x_pred\n{self.x_pred}\n'
            f'P_pred\n{self.P_pred}'
        )
def simple_test() -> None:
    """Run the question's hand-computable case through both implementations.

    All model matrices are the 2x2 identity, x = [1, 0], y = [3, 5] and u = 0.
    The expected updated state is [7/3, 10/3] with P = (2/3) I.
    """
    # Exact numpy rendering of the expected state. It is spelled out line by line
    # because numpy pads float columns with spaces and indents the second row.
    expected_str = (
        'x\n'
        '[[2.33333333]\n'
        ' [3.33333333]]\n'
        'P\n'
        '[[0.66666667 0.        ]\n'
        ' [0.         0.66666667]]\n'
        'x_pred\n'
        '[[1]\n'
        ' [0]]\n'
        'P_pred\n'
        '[[2 0]\n'
        ' [0 2]]'
    )
    for cls in (KalmanFilterOP, KalmanFilter):
        F = np.eye(2, dtype=np.int32)
        P = F.copy()
        B = F.copy()
        H = F.copy()
        Q = F.copy()
        R = F.copy()
        x = np.array([[1], [0]])
        y = np.array([[3], [5]])
        u = np.array([[0], [0]])
        kf = cls(x, P, F, B, H, Q, R)
        kf.predict(u)
        kf.update(y)
        assert np.allclose(
            kf.P,
            (
                (2/3, 0),
                (0, 2/3),
            ), rtol=0, atol=1e-14,
        )
        assert np.array_equal(
            kf.P_pred,
            (
                (2, 0),
                (0, 2),
            ),
        )
        assert np.allclose(
            kf.x.T, (7/3, 10/3),
            rtol=0, atol=1e-14,
        )
        # Non-equivalent: this rendering differs from the question's printed
        # output. NOTE(review): it assumes KalmanFilterOP (not visible here)
        # defines a matching __str__.
        assert str(kf) == expected_str
def equivalence_test() -> None:
    """Cross-check KalmanFilter against KalmanFilterOP on random models.

    Ten seeded random models are each driven through five predict/update
    steps by both implementations. Every tracked member must agree to a
    relative tolerance of 1e-12 after each step.
    """
    tracked = ('B', 'F', 'H', 'K', 'P', 'P_pred', 'Q', 'R', 'S', 'x', 'x_pred')
    rng = np.random.default_rng(seed=0)
    for _ in range(10):
        # Draw order fixes which part of the random stream each model consumes.
        model = rng.random(size=(6, 2, 2))
        x0 = rng.random(size=(2, 1))
        inputs = rng.random(size=(5, 2, 2, 1))

        histories = []
        for cls in (KalmanFilter, KalmanFilterOP):
            f, p, b, h, q, r = model.copy()
            kf = cls(x=x0.copy(), P=p, F=f, B=b, H=h, Q=q, R=r)
            history = []
            for u, y in inputs.copy():
                kf.predict(u)
                kf.update(y)
                history.append([(name, getattr(kf, name)) for name in tracked])
            histories.append(history)

        for step_a, step_b in zip(*histories):
            for (name, value_a), (_, value_b) in zip(step_a, step_b):
                assert np.allclose(value_a, value_b, atol=0, rtol=1e-12), name
if __name__ == '__main__':
    # Randomized cross-check first, then the hand-computed reference case.
    for check in (equivalence_test, simple_test):
        check()