
Commit 62e9bb8

chang authored and jnothman committed
Feature: Implement PowerTransformer (scikit-learn#10210)
1 parent 78ccdd1 commit 62e9bb8

File tree

9 files changed: +599 −31 lines changed


doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
@@ -1200,6 +1200,7 @@ Model validation
    preprocessing.OneHotEncoder
    preprocessing.CategoricalEncoder
    preprocessing.PolynomialFeatures
+   preprocessing.PowerTransformer
    preprocessing.QuantileTransformer
    preprocessing.RobustScaler
    preprocessing.StandardScaler
@@ -1217,6 +1218,7 @@ Model validation
    preprocessing.quantile_transform
    preprocessing.robust_scale
    preprocessing.scale
+   preprocessing.power_transform


.. _random_projection_ref:

doc/modules/preprocessing.rst

Lines changed: 50 additions & 2 deletions
@@ -261,6 +261,9 @@ defined by :math:`\phi` followed by removal of the mean in that space.
 Non-linear transformation
 =========================
 
+Mapping to a Uniform distribution
+---------------------------------
+
 Like scalers, :class:`QuantileTransformer` puts all features into the same,
 known range or distribution. However, by performing a rank transformation, it
 smooths out unusual distributions and is less influenced by outliers than
@@ -299,8 +302,53 @@ This can be confirmed on an independent testing set with similar remarks::
     ...  # doctest: +ELLIPSIS +SKIP
     array([ 0.01...,  0.25...,  0.46...,  0.60...,  0.94...])
 
-It is also possible to map the transformed data to a normal distribution by
-setting ``output_distribution='normal'``::
+Mapping to a Gaussian distribution
+----------------------------------
+
+In many modeling scenarios, normality of the features in a dataset is desirable.
+Power transforms are a family of parametric, monotonic transformations that aim
+to map data from any distribution to as close to a Gaussian distribution as
+possible in order to stabilize variance and minimize skewness.
+
+:class:`PowerTransformer` currently provides one such power transformation,
+the Box-Cox transform. The Box-Cox transform is given by:
+
+.. math::
+    y_i^{(\lambda)} =
+    \begin{cases}
+    \dfrac{y_i^\lambda - 1}{\lambda} & \text{if } \lambda \neq 0, \\[8pt]
+    \ln{(y_i)} & \text{if } \lambda = 0,
+    \end{cases}
+
+Box-Cox can only be applied to strictly positive data. The transformation is
+parameterized by :math:`\lambda`, which is determined through maximum likelihood
+estimation. Here is an example of using Box-Cox to map samples drawn from a
+lognormal distribution to a normal distribution::
+
+    >>> pt = preprocessing.PowerTransformer(method='box-cox')
+    >>> X_lognormal = np.random.RandomState(616).lognormal(size=(3, 3))
+    >>> X_lognormal  # doctest: +ELLIPSIS
+    array([[ 1.28...,  1.18...,  0.84...],
+           [ 0.94...,  1.60...,  0.38...],
+           [ 1.35...,  0.21...,  1.09...]])
+    >>> pt.fit_transform(X_lognormal)  # doctest: +ELLIPSIS
+    array([[ 0.49...,  0.17..., -0.15...],
+           [-0.05...,  0.58..., -0.57...],
+           [ 0.69..., -0.84...,  0.10...]])
+
+Below are examples of Box-Cox applied to various probability distributions.
+Note that when applied to certain distributions, Box-Cox achieves very
+Gaussian-like results, but with others, it is ineffective. This highlights
+the importance of visualizing the data before and after transformation.
+
+.. figure:: ../auto_examples/preprocessing/images/sphx_glr_plot_power_transformer_001.png
+   :target: ../auto_examples/preprocessing/plot_power_transformer.html
+   :align: center
+   :scale: 100
+
+It is also possible to map data to a normal distribution using
+:class:`QuantileTransformer` by setting ``output_distribution='normal'``.
+Using the earlier example with the iris dataset::
 
     >>> quantile_transformer = preprocessing.QuantileTransformer(
     ...     output_distribution='normal', random_state=0)
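
The doctest above uses the class API; this commit also exports a functional
form, ``power_transform`` (see the sklearn/preprocessing/__init__.py hunks
below). A minimal sketch, assuming the function simply mirrors
``PowerTransformer(method='box-cox').fit_transform``::

    import numpy as np
    from sklearn.preprocessing import power_transform

    X_lognormal = np.random.RandomState(616).lognormal(size=(3, 3))
    # assumed equivalent to PowerTransformer(method='box-cox').fit_transform
    print(power_transform(X_lognormal, method='box-cox'))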

doc/whats_new/v0.20.rst

Lines changed: 9 additions & 3 deletions
@@ -60,6 +60,14 @@ Preprocessing
   the maximum value in the features. :issue:`9151` by
   :user:`Vighnesh Birodkar <vighneshbirodkar>` and `Joris Van den Bossche`_.
 
+- Added :class:`preprocessing.PowerTransformer`, which implements the Box-Cox
+  power transformation, allowing users to map data from any distribution to a
+  Gaussian distribution. This is useful as a variance-stabilizing transformation
+  in situations where normality and homoscedasticity are desirable.
+  :issue:`10210` by :user:`Eric Chang <ericchang00>` and
+  :user:`Maniteja Nandana <maniteja123>`.
+
+
 Model evaluation
 
 - Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
@@ -223,16 +231,14 @@ Feature Extraction
   throw an exception if ``max_patches`` was greater than or equal to the number
   of all possible patches rather than simply returning the number of possible
   patches. :issue:`10100` by :user:`Varun Agrawal <varunagrawal>`
-
+
 - Fixed a bug in :class:`feature_extraction.text.CountVectorizer`,
   :class:`feature_extraction.text.TfidfVectorizer`,
   :class:`feature_extraction.text.HashingVectorizer` to support 64 bit sparse
   array indexing necessary to process large datasets with more than 2·10⁹ tokens
   (words or n-grams). :issue:`9147` by :user:`Claes-Fredrik Mannby <mannby>`
   and `Roman Yurchak`_.
 
-
-
 API changes summary
 -------------------

examples/preprocessing/plot_all_scaling.py

Lines changed: 44 additions & 20 deletions
@@ -29,8 +29,10 @@
 other in the way to estimate the parameters used to shift and scale each
 feature.
 
-``QuantileTransformer`` provides a non-linear transformation in which distances
-between marginal outliers and inliers are shrunk.
+``QuantileTransformer`` provides non-linear transformations in which distances
+between marginal outliers and inliers are shrunk. ``PowerTransformer`` provides
+non-linear transformations in which data is mapped to a normal distribution to
+stabilize variance and minimize skewness.
 
 Unlike the previous transformations, normalization refers to a per sample
 transformation instead of a per feature transformation.
@@ -59,7 +61,8 @@
 from sklearn.preprocessing import StandardScaler
 from sklearn.preprocessing import RobustScaler
 from sklearn.preprocessing import Normalizer
-from sklearn.preprocessing.data import QuantileTransformer
+from sklearn.preprocessing import QuantileTransformer
+from sklearn.preprocessing import PowerTransformer
 
 from sklearn.datasets import fetch_california_housing
 
@@ -84,14 +87,16 @@
      MaxAbsScaler().fit_transform(X)),
     ('Data after robust scaling',
      RobustScaler(quantile_range=(25, 75)).fit_transform(X)),
-    ('Data after quantile transformation (uniform pdf)',
-     QuantileTransformer(output_distribution='uniform')
-     .fit_transform(X)),
+    ('Data after power transformation (Box-Cox)',
+     PowerTransformer(method='box-cox').fit_transform(X)),
     ('Data after quantile transformation (gaussian pdf)',
      QuantileTransformer(output_distribution='normal')
     .fit_transform(X)),
+    ('Data after quantile transformation (uniform pdf)',
+     QuantileTransformer(output_distribution='uniform')
+     .fit_transform(X)),
    ('Data after sample-wise L2 normalizing',
-     Normalizer().fit_transform(X))
+     Normalizer().fit_transform(X)),
 ]
 
 # scale the output between 0 and 1 for the colorbar
@@ -286,6 +291,35 @@ def make_plot(item_idx):
 
 make_plot(4)
 
+##############################################################################
+# PowerTransformer (Box-Cox)
+# --------------------------
+#
+# ``PowerTransformer`` applies a power transformation to each
+# feature to make the data more Gaussian-like. Currently,
+# ``PowerTransformer`` implements the Box-Cox transform. It differs from
+# QuantileTransformer (Gaussian output) in that it does not map the
+# data to a zero-mean, unit-variance Gaussian distribution. Instead, Box-Cox
+# finds the optimal scaling factor to stabilize variance and minimize skewness
+# through maximum likelihood estimation. Note that Box-Cox can only be applied
+# to positive, non-zero data. Income and number of households happen to be
+# strictly positive, but if negative values are present, a constant can be
+# added to each feature to shift it into the positive range - this is known as
+# the two-parameter Box-Cox transform.
+
+make_plot(5)
+
+##############################################################################
+# QuantileTransformer (Gaussian output)
+# -------------------------------------
+#
+# ``QuantileTransformer`` has an additional ``output_distribution`` parameter
+# allowing to match a Gaussian distribution instead of a uniform distribution.
+# Note that this non-parametric transformer introduces saturation artifacts
+# for extreme values.
+
+make_plot(6)
+
 ###################################################################
 # QuantileTransformer (uniform output)
 # ------------------------------------
@@ -302,18 +336,7 @@ def make_plot(item_idx):
 # any outlier by setting them to the a priori defined range boundaries (0 and
 # 1).
 
-make_plot(5)
-
-##############################################################################
-# QuantileTransformer (Gaussian output)
-# -------------------------------------
-#
-# ``QuantileTransformer`` has an additional ``output_distribution`` parameter
-# allowing to match a Gaussian distribution instead of a uniform distribution.
-# Note that this non-parametric transformer introduces saturation artifacts
-# for extreme values.
-
-make_plot(6)
+make_plot(7)
 
 ##############################################################################
 # Normalizer
@@ -326,5 +349,6 @@ def make_plot(item_idx):
 # transformed data only lie in the positive quadrant. This would not be the
 # case if some original features had a mix of positive and negative values.
 
-make_plot(7)
+make_plot(8)
+
 plt.show()
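
The new comment block above mentions the two-parameter Box-Cox transform for
data containing negative values. A minimal sketch of that idea, assuming only
that the shift is applied manually before fitting (the shift constant is
illustrative, not part of this commit)::

    import numpy as np
    from sklearn.preprocessing import PowerTransformer

    rng = np.random.RandomState(0)
    X = rng.normal(loc=0.0, scale=1.0, size=(100, 1))  # contains negatives
    shift = 1e-6 - X.min(axis=0)  # constant making every value > 0
    X_trans = PowerTransformer(method='box-cox').fit_transform(X + shift)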
examples/preprocessing/plot_power_transformer.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+"""
+==========================================================
+Using PowerTransformer to apply the Box-Cox transformation
+==========================================================
+
+This example demonstrates the use of the Box-Cox transform through
+:class:`preprocessing.PowerTransformer` to map data from various distributions
+to a normal distribution.
+
+Box-Cox is useful as a transformation in modeling problems where
+homoscedasticity and normality are desired. Below are examples of Box-Cox
+applied to six different probability distributions: Lognormal, Chi-squared,
+Weibull, Gaussian, Uniform, and Bimodal.
+
+Note that the transformation successfully maps the data to a normal
+distribution when applied to certain datasets, but is ineffective with others.
+This highlights the importance of visualizing the data before and after
+transformation.
+"""
+
+# Author: Eric Chang <[email protected]>
+# License: BSD 3 clause
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from sklearn.preprocessing import PowerTransformer, minmax_scale
+
+print(__doc__)
+
+
+N_SAMPLES = 3000
+FONT_SIZE = 6
+BINS = 100
+
+
+pt = PowerTransformer(method='box-cox')
+rng = np.random.RandomState(304)
+size = (N_SAMPLES, 1)
+
+
+# lognormal distribution
+X_lognormal = rng.lognormal(size=size)
+
+# chi-squared distribution
+df = 3
+X_chisq = rng.chisquare(df=df, size=size)
+
+# weibull distribution
+a = 50
+X_weibull = rng.weibull(a=a, size=size)
+
+# gaussian distribution
+loc = 100
+X_gaussian = rng.normal(loc=loc, size=size)
+
+# uniform distribution
+X_uniform = rng.uniform(low=0, high=1, size=size)
+
+# bimodal distribution
+loc_a, loc_b = 100, 105
+X_a, X_b = rng.normal(loc=loc_a, size=size), rng.normal(loc=loc_b, size=size)
+X_bimodal = np.concatenate([X_a, X_b], axis=0)
+
+
+# create plots
+distributions = [
+    ('Lognormal', X_lognormal),
+    ('Chi-squared', X_chisq),
+    ('Weibull', X_weibull),
+    ('Gaussian', X_gaussian),
+    ('Uniform', X_uniform),
+    ('Bimodal', X_bimodal)
+]
+
+colors = ['firebrick', 'darkorange', 'goldenrod',
+          'seagreen', 'royalblue', 'darkorchid']
+
+fig, axes = plt.subplots(nrows=4, ncols=3)
+axes = axes.flatten()
+axes_idxs = [(0, 3), (1, 4), (2, 5), (6, 9), (7, 10), (8, 11)]
+axes_list = [(axes[i], axes[j]) for i, j in axes_idxs]
+
+
+for distribution, color, axes in zip(distributions, colors, axes_list):
+    name, X = distribution
+    # scale all distributions to a strictly positive range (Box-Cox needs > 0)
+    X = minmax_scale(X, feature_range=(1e-10, 10))
+
+    # perform power transform
+    X_trans = pt.fit_transform(X)
+    lmbda = round(pt.lambdas_[0], 2)
+
+    ax_original, ax_trans = axes
+
+    ax_original.hist(X, color=color, bins=BINS)
+    ax_original.set_title(name, fontsize=FONT_SIZE)
+    ax_original.tick_params(axis='both', which='major', labelsize=FONT_SIZE)
+
+    ax_trans.hist(X_trans, color=color, bins=BINS)
+    ax_trans.set_title(r'{} after Box-Cox, $\lambda$ = {}'.format(name, lmbda),
+                       fontsize=FONT_SIZE)
+    ax_trans.tick_params(axis='both', which='major', labelsize=FONT_SIZE)
+
+
+plt.tight_layout()
+plt.show()
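
The example above reads ``pt.lambdas_`` after fitting, so the estimated
Box-Cox exponent is exposed per feature as an array (the attribute name comes
from the diff itself). A short usage sketch::

    import numpy as np
    from sklearn.preprocessing import PowerTransformer

    pt = PowerTransformer(method='box-cox')
    pt.fit(np.random.RandomState(304).lognormal(size=(100, 2)))
    print(pt.lambdas_)  # one estimated lambda per feature/column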

sklearn/preprocessing/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -21,7 +21,9 @@
 from .data import maxabs_scale
 from .data import minmax_scale
 from .data import quantile_transform
+from .data import power_transform
 from .data import OneHotEncoder
+from .data import PowerTransformer
 from .data import CategoricalEncoder
 
 from .data import PolynomialFeatures
@@ -48,6 +50,7 @@
     'Normalizer',
     'OneHotEncoder',
     'CategoricalEncoder',
+    'PowerTransformer',
     'RobustScaler',
     'StandardScaler',
     'add_dummy_feature',
@@ -60,4 +63,5 @@
     'minmax_scale',
     'label_binarize',
     'quantile_transform',
+    'power_transform',
 ]
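
A quick sketch exercising the new exports. That the function and the class
produce identical output is an assumption here, based on the existing
``quantile_transform``/``QuantileTransformer`` pairing rather than anything
shown in this diff::

    import numpy as np
    from sklearn.preprocessing import PowerTransformer, power_transform

    X = np.random.RandomState(0).lognormal(size=(5, 2))  # strictly positive
    out_class = PowerTransformer(method='box-cox').fit_transform(X)
    out_func = power_transform(X, method='box-cox')
    assert np.allclose(out_class, out_func)  # assumed wrapper relationship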
