|
11 | 11 | assert_all_finite, array2d, atleast2d_or_csc, |
12 | 12 | atleast2d_or_csr, warn_if_not_float, |
13 | 13 | check_random_state) |
| 14 | +from class_weight import compute_class_weight |
14 | 15 |
|
15 | 16 | __all__ = ["murmurhash3_32", "as_float_array", "check_arrays", "safe_asarray", |
16 | 17 | "assert_all_finite", "array2d", "atleast2d_or_csc", |
17 | | - "atleast2d_or_csr", "warn_if_not_float", "check_random_state"] |
| 18 | + "atleast2d_or_csr", "warn_if_not_float", "check_random_state", |
| 19 | + "compute_class_weight"] |
18 | 20 |
|
19 | 21 | # Make sure that DeprecationWarning get printed |
20 | 22 | warnings.simplefilter("always", DeprecationWarning) |
@@ -346,52 +348,3 @@ def gen_even_slices(n, n_packs): |
346 | 348 |
|
347 | 349 | class ConvergenceWarning(Warning): |
348 | 350 | "Custom warning to capture convergence problems" |
349 | | - |
350 | | - |
351 | | -def compute_class_weight(class_weight, classes, y): |
352 | | - """Estimate class weights for unbalanced datasets. |
353 | | -
|
354 | | - Parameters |
355 | | - ---------- |
356 | | - class_weight : dict, 'auto' or None |
357 | | - If 'auto', class weights will be given inverse proportional |
358 | | - to the frequency of the class in the data. |
359 | | - If a dictionary is given, keys are classes and values |
360 | | - are corresponding class weights. |
361 | | - If None is given, the class weights will be uniform. |
362 | | - classes : list |
363 | | - List of the classes occuring in the data, as given by |
364 | | - ``np.unique(y_org)`` with ``y_org`` the original class labels. |
365 | | - y : array-like, shape=(n_samples,), dtype=int |
366 | | - Array of class indices per sample; |
367 | | - 0 <= y[i] < n_classes for i in range(n_samples). |
368 | | -
|
369 | | -
|
370 | | - Returns |
371 | | - ------- |
372 | | - class_weight_vect : ndarray, shape=(n_classes,) |
373 | | - Array with class_weight_vect[i] the weight for i-th class |
374 | | - (as determined by sorting). |
375 | | - """ |
376 | | - if class_weight is None or len(class_weight) == 0: |
377 | | - # uniform class weights |
378 | | - weight = np.ones(classes.shape[0], dtype=np.float64, order='C') |
379 | | - elif class_weight == 'auto': |
380 | | - # proportional to the number of samples in the class |
381 | | - weight = np.array([1.0 / np.sum(y == i) for i in classes], |
382 | | - dtype=np.float64, order='C') |
383 | | - weight *= classes.shape[0] / np.sum(weight) |
384 | | - else: |
385 | | - # user-defined dictionary |
386 | | - weight = np.ones(classes.shape[0], dtype=np.float64, order='C') |
387 | | - if not isinstance(class_weight, dict): |
388 | | - raise ValueError("class_weight must be dict, 'auto', or None," |
389 | | - " got: %r" % class_weight) |
390 | | - for c in class_weight: |
391 | | - i = np.searchsorted(classes, c) |
392 | | - if classes[i] != c: |
393 | | - raise ValueError("Class label %d not present." % c) |
394 | | - else: |
395 | | - weight[i] = class_weight[c] |
396 | | - |
397 | | - return weight |
0 commit comments