Skip to content

Commit 1c1538a

Browse files
committed
ENH move utility function into dedicated file, not __init__.py
1 parent 35b1d39 commit 1c1538a

File tree

2 files changed

+56
-50
lines changed

2 files changed

+56
-50
lines changed

sklearn/utils/__init__.py

Lines changed: 3 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
assert_all_finite, array2d, atleast2d_or_csc,
1212
atleast2d_or_csr, warn_if_not_float,
1313
check_random_state)
14+
from class_weight import compute_class_weight
1415

1516
__all__ = ["murmurhash3_32", "as_float_array", "check_arrays", "safe_asarray",
1617
"assert_all_finite", "array2d", "atleast2d_or_csc",
17-
"atleast2d_or_csr", "warn_if_not_float", "check_random_state"]
18+
"atleast2d_or_csr", "warn_if_not_float", "check_random_state",
19+
"compute_class_weight"]
1820

1921
# Make sure that DeprecationWarning get printed
2022
warnings.simplefilter("always", DeprecationWarning)
@@ -346,52 +348,3 @@ def gen_even_slices(n, n_packs):
346348

347349
class ConvergenceWarning(Warning):
348350
"Custom warning to capture convergence problems"
349-
350-
351-
def compute_class_weight(class_weight, classes, y):
352-
"""Estimate class weights for unbalanced datasets.
353-
354-
Parameters
355-
----------
356-
class_weight : dict, 'auto' or None
357-
If 'auto', class weights will be given inverse proportional
358-
to the frequency of the class in the data.
359-
If a dictionary is given, keys are classes and values
360-
are corresponding class weights.
361-
If None is given, the class weights will be uniform.
362-
classes : list
363-
List of the classes occuring in the data, as given by
364-
``np.unique(y_org)`` with ``y_org`` the original class labels.
365-
y : array-like, shape=(n_samples,), dtype=int
366-
Array of class indices per sample;
367-
0 <= y[i] < n_classes for i in range(n_samples).
368-
369-
370-
Returns
371-
-------
372-
class_weight_vect : ndarray, shape=(n_classes,)
373-
Array with class_weight_vect[i] the weight for i-th class
374-
(as determined by sorting).
375-
"""
376-
if class_weight is None or len(class_weight) == 0:
377-
# uniform class weights
378-
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
379-
elif class_weight == 'auto':
380-
# proportional to the number of samples in the class
381-
weight = np.array([1.0 / np.sum(y == i) for i in classes],
382-
dtype=np.float64, order='C')
383-
weight *= classes.shape[0] / np.sum(weight)
384-
else:
385-
# user-defined dictionary
386-
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
387-
if not isinstance(class_weight, dict):
388-
raise ValueError("class_weight must be dict, 'auto', or None,"
389-
" got: %r" % class_weight)
390-
for c in class_weight:
391-
i = np.searchsorted(classes, c)
392-
if classes[i] != c:
393-
raise ValueError("Class label %d not present." % c)
394-
else:
395-
weight[i] = class_weight[c]
396-
397-
return weight

sklearn/utils/class_weight.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Authors: Andreas Mueller
2+
# License: Simplified BSD
3+
4+
import numpy as np
5+
6+
7+
def compute_class_weight(class_weight, classes, y):
8+
"""Estimate class weights for unbalanced datasets.
9+
10+
Parameters
11+
----------
12+
class_weight : dict, 'auto' or None
13+
If 'auto', class weights will be given inverse proportional
14+
to the frequency of the class in the data.
15+
If a dictionary is given, keys are classes and values
16+
are corresponding class weights.
17+
If None is given, the class weights will be uniform.
18+
classes : list
19+
List of the classes occuring in the data, as given by
20+
``np.unique(y_org)`` with ``y_org`` the original class labels.
21+
y : array-like, shape=(n_samples,), dtype=int
22+
Array of class indices per sample;
23+
0 <= y[i] < n_classes for i in range(n_samples).
24+
25+
26+
Returns
27+
-------
28+
class_weight_vect : ndarray, shape=(n_classes,)
29+
Array with class_weight_vect[i] the weight for i-th class
30+
(as determined by sorting).
31+
"""
32+
if class_weight is None or len(class_weight) == 0:
33+
# uniform class weights
34+
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
35+
elif class_weight == 'auto':
36+
# proportional to the number of samples in the class
37+
weight = np.array([1.0 / np.sum(y == i) for i in classes],
38+
dtype=np.float64, order='C')
39+
weight *= classes.shape[0] / np.sum(weight)
40+
else:
41+
# user-defined dictionary
42+
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
43+
if not isinstance(class_weight, dict):
44+
raise ValueError("class_weight must be dict, 'auto', or None,"
45+
" got: %r" % class_weight)
46+
for c in class_weight:
47+
i = np.searchsorted(classes, c)
48+
if classes[i] != c:
49+
raise ValueError("Class label %d not present." % c)
50+
else:
51+
weight[i] = class_weight[c]
52+
53+
return weight

0 commit comments

Comments
 (0)