Skip to content

Commit 006a6cb

Browse files
committed
ENH add VarianceThreshold feature selection method
1 parent 36ba287 commit 006a6cb

File tree

5 files changed

+147
-0
lines changed

5 files changed

+147
-0
lines changed

doc/modules/feature_selection.rst

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,39 @@ for feature selection/dimensionality reduction on sample sets, either to
1212
improve estimators' accuracy scores or to boost their performance on very
1313
high-dimensional datasets.
1414

Removing features with low variance
===================================

:class:`VarianceThreshold` is a simple baseline approach to feature selection.
It removes all features whose variance doesn't meet some threshold.
By default, it removes all zero-variance features,
i.e. features that have the same value in all samples.

As an example, suppose that we have a dataset with boolean features,
and we want to remove all features that are either one or zero (on or off)
in more than 80% of the samples.
Boolean features are Bernoulli random variables,
and the variance of such variables is given by

.. math:: \mathrm{Var}[X] = p(1 - p)

so we can select using the threshold ``.8 * (1 - .8)``::

    >>> from sklearn.feature_selection import VarianceThreshold
    >>> X = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [0, 1, 0], [0, 1, 1]]
    >>> sel = VarianceThreshold(threshold=(.8 * (1 - .8)))
    >>> sel.fit_transform(X)
    array([[0, 1],
           [1, 0],
           [0, 0],
           [1, 1],
           [1, 0],
           [1, 1]])

As expected, ``VarianceThreshold`` has removed the first column,
which has a probability :math:`p = 5/6 > .8` of containing a zero.
47+
1548
Univariate feature selection
1649
============================
1750

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
0.15
66
=====
77

Changelog
---------

- Add predict method to :class:`cluster.AffinityPropagation` and
  :class:`cluster.MeanShift`, by `Mathieu Blondel`_.

- New unsupervised feature selection algorithm
  :class:`feature_selection.VarianceThreshold`, by `Lars Buitinck`_.
1117
.. _changes_0_14:
1218

1319
0.14

sklearn/feature_selection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from .univariate_selection import SelectFwe
1616
from .univariate_selection import GenericUnivariateSelect
1717

18+
from .variance_threshold import VarianceThreshold
19+
1820
from .rfe import RFE
1921
from .rfe import RFECV
2022

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from sklearn.utils.testing import (assert_array_equal, assert_equal,
2+
assert_raises)
3+
4+
import numpy as np
5+
from scipy.sparse import bsr_matrix, csc_matrix, csr_matrix
6+
7+
from sklearn.feature_selection import VarianceThreshold
8+
9+
data = [[0, 1, 2, 3, 4],
10+
[0, 2, 2, 3, 5],
11+
[1, 1, 2, 4, 0]]
12+
13+
14+
def test_zero_variance():
    """Test VarianceThreshold with default setting, zero variance."""
    # The default threshold (0.) should keep every column except the
    # constant one (index 2), for dense and all sparse input formats alike.
    candidates = (data, csr_matrix(data), csc_matrix(data), bsr_matrix(data))
    for X in candidates:
        selector = VarianceThreshold().fit(X)
        assert_array_equal([0, 1, 3, 4], selector.get_support(indices=True))

    # 1-d input and input where every feature is constant must both raise.
    for bad in ([0, 1, 2, 3], [[0, 1], [0, 1]]):
        assert_raises(ValueError, VarianceThreshold().fit, bad)
23+
24+
25+
def test_variance_threshold():
    """Test VarianceThreshold with custom variance."""
    # With threshold=.4, only one of the five feature columns survives,
    # both for dense lists and CSR input.
    for X in (data, csr_matrix(data)):
        reduced = VarianceThreshold(threshold=0.4).fit_transform(X)
        assert_equal((len(data), 1), reduced.shape)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Author: Lars Buitinck <[email protected]>
2+
# License: 3-clause BSD
3+
4+
import numpy as np
5+
from ..base import BaseEstimator
6+
from .base import SelectorMixin
7+
from ..utils import atleast2d_or_csr
8+
from ..utils.sparsefuncs import csr_mean_variance_axis0
9+
10+
11+
class VarianceThreshold(BaseEstimator, SelectorMixin):
    """Feature selector that removes all low-variance features.

    This feature selection algorithm looks only at the features (X), not the
    desired outputs (y), and can thus be used for unsupervised learning.

    Parameters
    ----------
    threshold : float, optional
        Features with a training-set variance lower than this threshold will
        be removed. The default is to keep all features with non-zero variance,
        i.e. remove the features that have the same value in all samples.

    Attributes
    ----------
    `variances_` : array, shape (n_features,)
        Variances of individual features.

    Examples
    --------
    The following dataset has integer features, two of which are the same
    in every sample. These are removed with the default setting for threshold::

        >>> X = [[0, 2, 0, 3], [0, 1, 4, 3], [0, 1, 1, 3]]
        >>> selector = VarianceThreshold()
        >>> selector.fit_transform(X)
        array([[2, 0],
               [1, 4],
               [1, 1]])
    """

    def __init__(self, threshold=0.):
        self.threshold = threshold

    def fit(self, X, y=None):
        """Learn empirical variances from X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Sample vectors from which to compute variances.

        y : any
            Ignored. This parameter exists only for compatibility with
            sklearn.pipeline.Pipeline.

        Returns
        -------
        self
        """
        # Validation coerces to a 2-d float64 array, or CSR for sparse input;
        # 1-d input is rejected here with a ValueError.
        X = atleast2d_or_csr(X, dtype=np.float64)

        if hasattr(X, "toarray"):
            # Sparse path: use the sparse-aware helper so the matrix is
            # never densified; it returns (means, variances) per column.
            self.variances_ = csr_mean_variance_axis0(X)[1]
        else:
            self.variances_ = np.var(X, axis=0)

        # Refuse to fit when transform() would drop every single feature.
        # Note: written as all(v <= t) rather than not any(v > t) on purpose;
        # the two forms differ when a variance is NaN.
        if np.all(self.variances_ <= self.threshold):
            msg = "No feature in X meets the variance threshold {0:.5f}"
            if X.shape[0] == 1:
                # A single sample always has zero variance everywhere.
                msg += " (X contains only one sample)"
            raise ValueError(msg.format(self.threshold))

        return self

    def _get_support_mask(self):
        # Boolean mask consumed by SelectorMixin: keep strictly-above-threshold
        # features only.
        mask = self.variances_ > self.threshold
        return mask

0 commit comments

Comments
 (0)