Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit c26acfb

Browse files
eli-goodfriend
authored and
John Davis
committed
LogitNet weights (#34)
* expose max_features to users * off by one in def of nx vs ne for max_features * init needed for use in batchlearner * max_features in class init * tests check number of nonzero coefs * want less than or equals * remove unnecessary max_features arg * use sample weights during fit of lognet, not just cv * want weights in score calc too * simplified weights logic * test with sparse matrix
1 parent 3f8b4a2 commit c26acfb

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

‎__init__.py‎

Whitespace-only changes.

‎glmnet/logistic.py‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,15 @@ def _fit(self, X, y, sample_weight=None, relative_penalties=None):
266266
# keep the original order.
267267
_y = (y[:, None] == self.classes_).astype(np.float64, order='F')
268268

269+
# use sample weights, making sure all weights are positive
270+
# this is inspired by the R wrapper for glmnet, in lognet.R
271+
if sample_weight is not None:
272+
weight_gt_0 = sample_weight > 0
273+
sample_weight = sample_weight[weight_gt_0]
274+
_y = _y[weight_gt_0, :]
275+
X = X[weight_gt_0, :]
276+
_y = _y * np.expand_dims(sample_weight, 1)
277+
269278
# we need some sort of "offset" array for glmnet
270279
# an array of shape (n_examples, n_classes)
271280
offset = np.zeros((X.shape[0], n_classes), dtype=np.float64,

‎glmnet/tests/test_logistic.py‎

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from scipy.sparse import csr_matrix
99

1010
from sklearn.datasets import make_classification
11-
from sklearn.metrics import accuracy_score
12-
from sklearn.utils import estimator_checks
11+
from sklearn.metrics import accuracy_score, f1_score
12+
from sklearn.utils import estimator_checks, class_weight
1313
from sklearn.utils.testing import ignore_warnings
1414

1515
from util import sanity_check_logistic
@@ -212,6 +212,29 @@ def test_max_features(self):
212212
num_features = np.count_nonzero(m.coef_, axis=1)
213213
self.assertTrue(np.all(num_features <= max_features))
214214

215+
def test_use_sample_weights(self):
216+
x, y = self.multinomial[1]
217+
class_0_idx = np.where(y==0)
218+
to_drop = class_0_idx[0][:-3]
219+
to_keep = np.ones(len(y), dtype=bool)
220+
to_keep[to_drop] = False
221+
y = y[to_keep]
222+
x = x[to_keep, :]
223+
sample_weight = class_weight.compute_sample_weight('balanced', y)
224+
sample_weight[0] = 0.
225+
226+
unweighted = LogitNet(random_state=2, scoring='f1_micro')
227+
unweighted = unweighted.fit(x, y)
228+
unweighted_acc = f1_score(y, unweighted.predict(x), sample_weight=sample_weight,
229+
average='micro')
230+
231+
weighted = LogitNet(random_state=2, scoring='f1_micro')
232+
weighted = weighted.fit(x, y, sample_weight=sample_weight)
233+
weighted_acc = f1_score(y, weighted.predict(x), sample_weight=sample_weight,
234+
average='micro')
235+
236+
self.assertTrue(weighted_acc >= unweighted_acc)
237+
215238
def check_accuracy(y, y_hat, at_least, **other_params):
216239
score = accuracy_score(y, y_hat)
217240
msg = "expected accuracy of {}, got: {} with {}".format(at_least, score, other_params)

0 commit comments

Comments
(0)

Page converted by AltStyle (→ original)