Skip to content

Commit fe8e719

Browse files
jatinshaharjoly
authored and committed
Add sample_weight parameter to metrics.log_loss
- Also modified binary output & multiclass tests in test_sample_weight_invariance to test for prediction inputs as probabilities
- Updated What's New
1 parent 22cafa6 commit fe8e719

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,8 @@ Enhancements
4444
descent for :class:`linear_model.Lasso`, :class:`linear_model.ElasticNet`
4545
and related. By `Manoj Kumar`_.
4646

47-
- Add ``sample_weight`` parameter to `metrics.jaccard_similarity_score`.
48-
By `Jatin Shah`.
49-
50-
47+
- Add ``sample_weight`` parameter to `metrics.jaccard_similarity_score` and
48+
`metrics.log_loss`. By `Jatin Shah`.
5149

5250
Documentation improvements
5351
..........................

sklearn/metrics/classification.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ def _check_clf_targets(y_true, y_pred):
9292
return y_type, y_true, y_pred
9393

9494

95+
def _weighted_sum(sample_score, sample_weight, normalize=False):
96+
if normalize:
97+
return np.average(sample_score, weights=sample_weight)
98+
elif sample_weight is not None:
99+
return np.dot(sample_score, sample_weight)
100+
else:
101+
return sample_score.sum()
102+
103+
95104
def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
96105
"""Accuracy classification score.
97106
@@ -159,14 +168,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
159168
else:
160169
score = y_true == y_pred
161170

162-
if normalize:
163-
if sample_weight is not None:
164-
return np.average(score, weights=sample_weight)
165-
return np.mean(score)
166-
else:
167-
if sample_weight is not None:
168-
return np.dot(score, sample_weight)
169-
return np.sum(score)
171+
return _weighted_sum(score, sample_weight, normalize)
170172

171173

172174
def confusion_matrix(y_true, y_pred, labels=None):
@@ -344,13 +346,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
344346
else:
345347
score = y_true == y_pred
346348

347-
if normalize:
348-
return np.average(score, weights=sample_weight)
349-
else:
350-
if sample_weight is not None:
351-
return np.dot(score, sample_weight)
352-
else:
353-
return np.sum(score)
349+
return _weighted_sum(score, sample_weight, normalize)
354350

355351

356352
def matthews_corrcoef(y_true, y_pred):
@@ -1317,7 +1313,7 @@ def hamming_loss(y_true, y_pred, classes=None):
13171313
raise ValueError("{0} is not supported".format(y_type))
13181314

13191315

1320-
def log_loss(y_true, y_pred, eps=1e-15, normalize=True):
1316+
def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None):
13211317
"""Log loss, aka logistic loss or cross-entropy loss.
13221318
13231319
This is the loss function used in (multinomial) logistic regression
@@ -1345,6 +1341,9 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True):
13451341
If true, return the mean loss per sample.
13461342
Otherwise, return the sum of the per-sample losses.
13471343
1344+
sample_weight : array-like of shape = [n_samples], optional
1345+
Sample weights.
1346+
13481347
Returns
13491348
-------
13501349
loss : float
@@ -1393,8 +1392,9 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True):
13931392

13941393
# Renormalize
13951394
Y /= Y.sum(axis=1)[:, np.newaxis]
1396-
loss = -(T * np.log(Y)).sum()
1397-
return loss / T.shape[0] if normalize else loss
1395+
loss = -(T * np.log(Y)).sum(axis=1)
1396+
1397+
return _weighted_sum(loss, sample_weight, normalize)
13981398

13991399

14001400
def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):

sklearn/metrics/tests/test_common.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@
136136

137137
THRESHOLDED_METRICS = {
138138
"log_loss": log_loss,
139+
"unnormalized_log_loss": partial(log_loss, normalize=False),
140+
139141
"hinge_loss": hinge_loss,
140142

141143
"roc_auc_score": roc_auc_score,
@@ -239,6 +241,7 @@
239241
# Threshold-based metrics with "multilabel-indicator" format support
240242
THRESHOLDED_MULTILABEL_METRICS = [
241243
"log_loss",
244+
"unnormalized_log_loss",
242245

243246
"roc_auc_score", "weighted_roc_auc", "samples_roc_auc",
244247
"micro_roc_auc", "macro_roc_auc",
@@ -315,7 +318,6 @@
315318
"confusion_matrix",
316319
"hamming_loss",
317320
"hinge_loss",
318-
"log_loss",
319321
"matthews_corrcoef_score",
320322
]
321323

@@ -532,7 +534,7 @@ def test_invariance_string_vs_numbers_labels():
532534
"invariance test".format(name))
533535

534536
for name, metric in THRESHOLDED_METRICS.items():
535-
if name in ("log_loss", "hinge_loss"):
537+
if name in ("log_loss", "hinge_loss", "unnormalized_log_loss"):
536538
measure_with_number = metric(y1, y2)
537539
measure_with_str = metric(y1_str, y2)
538540
assert_array_equal(measure_with_number, measure_with_str,
@@ -968,23 +970,31 @@ def test_sample_weight_invariance(n_samples=50):
968970
random_state = check_random_state(0)
969971
y_true = random_state.randint(0, 2, size=(n_samples, ))
970972
y_pred = random_state.randint(0, 2, size=(n_samples, ))
973+
y_score = random_state.random_sample(size=(n_samples,))
971974
for name in ALL_METRICS:
972975
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
973976
name in METRIC_UNDEFINED_MULTICLASS):
974977
continue
975978
metric = ALL_METRICS[name]
976-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
979+
if name in THRESHOLDED_METRICS:
980+
yield check_sample_weight_invariance, name, metric, y_true, y_score
981+
else:
982+
yield check_sample_weight_invariance, name, metric, y_true, y_pred
977983

978984
# multiclass
979985
random_state = check_random_state(0)
980986
y_true = random_state.randint(0, 5, size=(n_samples, ))
981987
y_pred = random_state.randint(0, 5, size=(n_samples, ))
988+
y_score = random_state.random_sample(size=(n_samples, 5))
982989
for name in ALL_METRICS:
983990
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
984991
name in METRIC_UNDEFINED_MULTICLASS):
985992
continue
986993
metric = ALL_METRICS[name]
987-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
994+
if name in THRESHOLDED_METRICS:
995+
yield check_sample_weight_invariance, name, metric, y_true, y_score
996+
else:
997+
yield check_sample_weight_invariance, name, metric, y_true, y_pred
988998

989999
# multilabel sequence
9901000
y_true = 2 * [(1, 2, ), (1, ), (0, ), (0, 1), (1, 2)]

0 commit comments

Comments (0)