Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 1a41080

Browse files
committed
ENH: Surface more errors from the glmnet solver
The glmnet solver uses integer codes to communicate errors and warnings. The error code returned by then solver is saved to the attribute `jerr` after fitting a model. Negative values denote warnings such as convergence issues, positive values denote fatal conditions such as memory allocation problems, and a value of zero is used when the solver runs successfully without error. Initially we translated the convergence warnings into more complete messages from the numeric code. For all other errors, we just returned something opaque like "glmnet error no. 123." In this PR, we add some additional messages and raise the relevant type of warning or exception.
1 parent f7f2cc2 commit 1a41080

File tree

6 files changed

+94
-19
lines changed

6 files changed

+94
-19
lines changed

‎glmnet/errors.py‎

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import warnings
2+
3+
4+
def _check_glmnet_error_flag(jerr, n_lambda):
5+
"""Check the error flag. Issue warning on convergence errors (jerr < 0)
6+
and exception on anything else."""
7+
8+
if jerr == 0:
9+
return
10+
11+
if jerr > 0:
12+
_fatal_errors(jerr, n_lambda)
13+
14+
if jerr < 0:
15+
_convergence_errors(jerr, n_lambda)
16+
17+
18+
def _fatal_errors(jerr, n_lambda):
19+
if jerr == 7777:
20+
raise ValueError("All predictors have zero variance "
21+
"(glmnet error no. 7777).")
22+
23+
if jerr == 10000:
24+
raise ValueError("At least one value of relative_penalties must be "
25+
"positive (glmnet error no. 10000).")
26+
27+
if jerr < 7777:
28+
raise RuntimeError("Memory allocation error (glmnet error no. {})."
29+
.format(jerr))
30+
31+
else:
32+
raise RuntimeError("Fatal glmnet error no. {}.".format(jerr))
33+
34+
35+
def _convergence_errors(jerr, n_lambda):
36+
if abs(jerr) <= n_lambda:
37+
warnings.warn("Model did not converge for smaller values of lambda, "
38+
"returning solution for the largest {} values."
39+
.format(-1 * (jerr - 1)), RuntimeWarning)
40+
else:
41+
warnings.warn("Non-fatal glmnet error no. {}.".format(jerr),
42+
RuntimeWarning)
43+

‎glmnet/linear.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from sklearn.metrics import r2_score
88
from sklearn.utils import check_array, check_X_y
99

10+
from .errors import _check_glmnet_error_flag
1011
from _glmnet import elnet, spelnet, solns
1112
from .util import (_fix_lambda_path,
12-
_check_glmnet_error_flag,
1313
_check_user_lambda,
1414
_interpolate_model,
1515
_score_lambda_path)
@@ -313,7 +313,7 @@ def _fit(self, X, y, sample_weight, relative_penalties):
313313

314314
# raises RuntimeError if self.jerr_ is nonzero
315315
self.jerr_ = jerr
316-
_check_glmnet_error_flag(self.jerr_)
316+
_check_glmnet_error_flag(self.jerr_, n_lambda)
317317

318318
self.lambda_path_ = self.lambda_path_[:self.n_lambda_]
319319
self.lambda_path_ = _fix_lambda_path(self.lambda_path_)

‎glmnet/logistic.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from sklearn.utils import check_array, check_X_y
1010
from sklearn.utils.multiclass import check_classification_targets
1111

12+
from .errors import _check_glmnet_error_flag
1213
from _glmnet import lognet, splognet, lsolns
1314
from .util import (_fix_lambda_path,
14-
_check_glmnet_error_flag,
1515
_check_user_lambda,
1616
_interpolate_model,
1717
_score_lambda_path)
@@ -362,7 +362,7 @@ def _fit(self, X, y, sample_weight=None, relative_penalties=None):
362362

363363
# raises RuntimeError if self.jerr_ is nonzero
364364
self.jerr_ = jerr
365-
_check_glmnet_error_flag(self.jerr_)
365+
_check_glmnet_error_flag(self.jerr_, n_lambda)
366366

367367
# glmnet may not return the requested number of lambda values, so we
368368
# need to trim the trailing zeros from the returned path so

‎glmnet/tests/test_errors.py‎

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import unittest
2+
3+
from glmnet.errors import _check_glmnet_error_flag
4+
5+
6+
class TestErrors(unittest.TestCase):
7+
8+
def test_zero_jerr(self):
9+
# This should not raise any warnings or exceptions.
10+
_check_glmnet_error_flag(0, n_lambda=100)
11+
12+
def test_convergence_err(self):
13+
msg = ("Model did not converge for smaller values of lambda, "
14+
"returning solution for the largest 75 values.")
15+
with self.assertWarns(RuntimeWarning, msg=msg):
16+
_check_glmnet_error_flag(-76, n_lambda=100)
17+
18+
def test_zero_var_err(self):
19+
msg = "All predictors have zero variance (glmnet error no. 7777)."
20+
with self.assertRaises(ValueError, msg=msg):
21+
_check_glmnet_error_flag(7777, n_lambda=100)
22+
23+
def test_all_negative_rel_penalty(self):
24+
msg = ("At least one value of relative_penalties must be positive, "
25+
"(glmnet error no. 10000).")
26+
with self.assertRaises(ValueError, msg=msg):
27+
_check_glmnet_error_flag(10000, n_lambda=100)
28+
29+
def test_memory_allocation_err(self):
30+
msg = "Memory allocation error (glmnet error no. 1234)."
31+
with self.assertRaises(RuntimeError, msg=msg):
32+
_check_glmnet_error_flag(1234, n_lambda=100)
33+
34+
def test_other_fatal_err(self):
35+
msg = "Fatal glmnet error no. 8888."
36+
with self.assertRaises(RuntimeError, msg=msg):
37+
_check_glmnet_error_flag(8888, msg)

‎glmnet/tests/test_linear.py‎

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ def test_with_single_var(self):
6363
m = m.fit(x, y)
6464
self.check_r2_score(y, m.predict(x), 0.90)
6565

66+
def test_with_no_predictor_variance(self):
67+
x = np.ones((500, 1))
68+
y = np.random.rand(500)
69+
70+
m = ElasticNet(random_state=561)
71+
msg = "All predictors have zero variance (glmnet error no. 7777)."
72+
with self.assertRaises(ValueError, msg=msg):
73+
m.fit(x, y)
74+
6675
def test_relative_penalties(self):
6776
m1 = ElasticNet(random_state=4328)
6877
m2 = ElasticNet(random_state=4328)
@@ -84,7 +93,7 @@ def test_relative_penalties(self):
8493

8594
# verify that the unpenalized coef ests exceed the penalized ones
8695
# in absolute value
87-
assert(np.all(np.abs(m1.coef_) <= np.abs(m2.coef_)))
96+
assert(np.all(np.abs(m1.coef_) <= np.abs(m2.coef_)))
8897

8998
def test_alphas(self):
9099
x, y = self.inputs[0]

‎glmnet/util.py‎

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,20 +126,6 @@ def _fix_lambda_path(lambda_path):
126126
return lambda_path
127127

128128

129-
def _check_glmnet_error_flag(jerr):
130-
"""Check the error flag. Issue warning on convergence errors (jerr < 0)
131-
and exception on anything else."""
132-
133-
if jerr and jerr != 0:
134-
if jerr < 0:
135-
import warnings
136-
msg = "glmnet did not converge for some values of lambda {}"
137-
warnings.warn(msg.format(jerr), RuntimeWarning)
138-
else:
139-
msg = "glmnet error no. {}"
140-
raise RuntimeError(msg.format(jerr))
141-
142-
143129
def _check_user_lambda(lambda_path, lambda_best=None, lamb=None):
144130
"""Verify the user-provided value of lambda is acceptable and ensure this
145131
is a 1-d array.

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /