Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bf93043

Browse files
add type hint to AOLS methods
1 parent 17c64cb commit bf93043

File tree

1 file changed

+54
-11
lines changed

1 file changed

+54
-11
lines changed

‎sysidentpy/model_structure_selection/accelerated_orthogonal_least_squares.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Authors:
44
# Wilson Rocha Lacerda Junior <wilsonrljr@outlook.com>
55
# License: BSD 3 clause
6-
from typing import Tuple, Union
6+
from typing import Tuple, Union, Optional
77

88
import numpy as np
99
from numpy import linalg as LA
@@ -262,7 +262,7 @@ def aols(
262262
theta = theta[theta != 0]
263263
return theta.reshape(-1, 1), pivv, residual_norm
264264

265-
def fit(self, *, X=None, y=None):
265+
def fit(self, *, X: Optional[np.ndarray] =None, y: Optional[np.ndarray] =None):
266266
"""Fit polynomial NARMAX model using AOLS algorithm.
267267
268268
The 'fit' function allows a friendly usage by the user.
@@ -343,7 +343,14 @@ def fit(self, *, X=None, y=None):
343343
] # just to use the `results` method. Will be changed in next update.
344344
return self
345345

346-
def predict(self, *, X=None, y=None, steps_ahead=None, forecast_horizon=None):
346+
def predict(
347+
self,
348+
*,
349+
X: Optional[np.ndarray] = None,
350+
y: Optional[np.ndarray] = None,
351+
steps_ahead: Optional[int] = None,
352+
forecast_horizon: int = 0,
353+
) -> np.ndarray:
347354
"""Return the predicted values given an input.
348355
349356
The predict function allows a friendly usage by the user.
@@ -401,7 +408,9 @@ def predict(self, *, X=None, y=None, steps_ahead=None, forecast_horizon=None):
401408
yhat = np.concatenate([y[: self.max_lag], yhat], axis=0)
402409
return yhat
403410

404-
def _one_step_ahead_prediction(self, X, y):
411+
def _one_step_ahead_prediction(
412+
self, X: Optional[np.ndarray], y: Optional[np.ndarray]
413+
) -> np.ndarray:
405414
"""Perform the 1-step-ahead prediction of a model.
406415
407416
Parameters
@@ -435,7 +444,12 @@ def _one_step_ahead_prediction(self, X, y):
435444
yhat = super()._one_step_ahead_prediction(X_base)
436445
return yhat.reshape(-1, 1)
437446

438-
def _n_step_ahead_prediction(self, X, y, steps_ahead):
447+
def _n_step_ahead_prediction(
448+
self,
449+
X: Optional[np.ndarray],
450+
y: Optional[np.ndarray],
451+
steps_ahead: Optional[int],
452+
) -> np.ndarray:
439453
"""Perform the n-steps-ahead prediction of a model.
440454
441455
Parameters
@@ -455,7 +469,12 @@ def _n_step_ahead_prediction(self, X, y, steps_ahead):
455469
yhat = super()._n_step_ahead_prediction(X, y, steps_ahead)
456470
return yhat
457471

458-
def _model_prediction(self, X, y_initial, forecast_horizon=None):
472+
def _model_prediction(
473+
self,
474+
X: Optional[np.ndarray],
475+
y_initial: Optional[np.ndarray],
476+
forecast_horizon: int = 1,
477+
) -> np.ndarray:
459478
"""Perform the infinity steps-ahead simulation of a model.
460479
461480
Parameters
@@ -481,7 +500,12 @@ def _model_prediction(self, X, y_initial, forecast_horizon=None):
481500
f"model_type must be NARMAX, NAR or NFIR. Got {self.model_type}"
482501
)
483502

484-
def _narmax_predict(self, X, y_initial, forecast_horizon):
503+
def _narmax_predict(
504+
self,
505+
X: Optional[np.ndarray],
506+
y_initial: Optional[np.ndarray],
507+
forecast_horizon: int = 1,
508+
) -> np.ndarray:
485509
if len(y_initial) < self.max_lag:
486510
raise ValueError(
487511
"Insufficient initial condition elements! Expected at least"
@@ -499,11 +523,18 @@ def _narmax_predict(self, X, y_initial, forecast_horizon):
499523
y_output = super()._narmax_predict(X, y_initial, forecast_horizon)
500524
return y_output
501525

502-
def _nfir_predict(self, X, y_initial):
526+
def _nfir_predict(
527+
self, X: Optional[np.ndarray], y_initial: Optional[np.ndarray]
528+
) -> np.ndarray:
503529
y_output = super()._nfir_predict(X, y_initial)
504530
return y_output
505531

506-
def _basis_function_predict(self, X, y_initial, forecast_horizon=None):
532+
def _basis_function_predict(
533+
self,
534+
X: Optional[np.ndarray],
535+
y_initial: Optional[np.ndarray],
536+
forecast_horizon: int = 1,
537+
) -> np.ndarray:
507538
if X is not None:
508539
forecast_horizon = X.shape[0]
509540
else:
@@ -515,7 +546,13 @@ def _basis_function_predict(self, X, y_initial, forecast_horizon=None):
515546
yhat = super()._basis_function_predict(X, y_initial, forecast_horizon)
516547
return yhat.reshape(-1, 1)
517548

518-
def _basis_function_n_step_prediction(self, X, y, steps_ahead, forecast_horizon):
549+
def _basis_function_n_step_prediction(
550+
self,
551+
X: Optional[np.ndarray],
552+
y: Optional[np.ndarray],
553+
steps_ahead: Optional[int],
554+
forecast_horizon: int,
555+
) -> np.ndarray:
519556
"""Perform the n-steps-ahead prediction of a model.
520557
521558
Parameters
@@ -548,7 +585,13 @@ def _basis_function_n_step_prediction(self, X, y, steps_ahead, forecast_horizon)
548585
)
549586
return yhat.reshape(-1, 1)
550587

551-
def _basis_function_n_steps_horizon(self, X, y, steps_ahead, forecast_horizon):
588+
def _basis_function_n_steps_horizon(
589+
self,
590+
X: Optional[np.ndarray],
591+
y: Optional[np.ndarray],
592+
steps_ahead: Optional[int],
593+
forecast_horizon: int,
594+
) -> np.ndarray:
552595
yhat = super()._basis_function_n_steps_horizon(
553596
X, y, steps_ahead, forecast_horizon
554597
)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /