
Commit eaab848

DOC: Clarify recommended usage of fit_transform() vs fit().transform() in TargetEncoder (scikit-learn#32347)

1 parent: 3eb55e2

File tree: 4 files changed, +42 −24 lines changed


doc/modules/preprocessing.rst

Lines changed: 21 additions & 18 deletions

@@ -936,34 +936,37 @@ cardinality categories are location based such as zip code or region.
 where :math:`L_i` is the set of observations with category :math:`i` and
 :math:`n_i` is the number of observations with category :math:`i`.
 
+.. note::
+    In :class:`TargetEncoder`, `fit(X, y).transform(X)` does not equal `fit_transform(X, y)`.
 
 :meth:`~TargetEncoder.fit_transform` internally relies on a :term:`cross fitting`
 scheme to prevent target information from leaking into the train-time
 representation, especially for non-informative high-cardinality categorical
-variables, and help prevent the downstream model from overfitting spurious
-correlations. Note that as a result, `fit(X, y).transform(X)` does not equal
-`fit_transform(X, y)`. In :meth:`~TargetEncoder.fit_transform`, the training
-data is split into *k* folds (determined by the `cv` parameter) and each fold is
-encoded using the encodings learnt using the other *k-1* folds. The following
-diagram shows the :term:`cross fitting` scheme in
+variables (features with many unique categories where each category appears
+only a few times), and help prevent the downstream model from overfitting spurious
+correlations. In :meth:`~TargetEncoder.fit_transform`, the training data is split into
+*k* folds (determined by the `cv` parameter) and each fold is encoded using the
+encodings learnt using the *other k-1* folds. For this reason, training data should
+always be trained and transformed with `fit_transform(X_train, y_train)`.
+
+This diagram shows the :term:`cross fitting` scheme in
 :meth:`~TargetEncoder.fit_transform` with the default `cv=5`:
 
 .. image:: ../images/target_encoder_cross_validation.svg
   :width: 600
   :align: center
 
-:meth:`~TargetEncoder.fit_transform` also learns a 'full data' encoding using
-the whole training set. This is never used in
-:meth:`~TargetEncoder.fit_transform` but is saved to the attribute `encodings_`,
-for use when :meth:`~TargetEncoder.transform` is called. Note that the encodings
-learned for each fold during the :term:`cross fitting` scheme are not saved to
-an attribute.
-
-The :meth:`~TargetEncoder.fit` method does **not** use any :term:`cross fitting`
-schemes and learns one encoding on the entire training set, which is used to
-encode categories in :meth:`~TargetEncoder.transform`.
-This encoding is the same as the 'full data'
-encoding learned in :meth:`~TargetEncoder.fit_transform`.
+The :meth:`~TargetEncoder.fit` method does **not** use any :term:`cross fitting` schemes
+and learns one encoding on the entire training set. It is discouraged to use this
+method because it can introduce data leakage as mentioned above. Use
+:meth:`~TargetEncoder.fit_transform` instead.
+
+During :meth:`~TargetEncoder.fit_transform`, the encoder learns category
+encodings from the full training data and stores them in the
+:attr:`~TargetEncoder.encodings_` attribute. The intermediate encodings learned
+for each fold during the :term:`cross fitting` process are temporary and not
+saved. The stored encodings can then be used to transform test data with
+`encoder.transform(X_test)`.
 
 .. note::
     :class:`TargetEncoder` considers missing values, such as `np.nan` or `None`,

examples/preprocessing/plot_target_encoder.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 .. note::
     `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
     cross fitting scheme is used in `fit_transform` for encoding. See the
-    :ref:`User Guide <target_encoder>`. for details.
+    :ref:`User Guide <target_encoder>` for details.
 """
 
 # Authors: The scikit-learn developers

examples/preprocessing/plot_target_encoder_cross_val.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 and the target. To prevent overfitting, :meth:`TargetEncoder.fit_transform` uses
 an internal :term:`cross fitting` scheme to encode the training data to be used
 by a downstream model. This scheme involves splitting the data into *k* folds
-and encoding each fold using the encodings learnt using the other *k-1* folds.
+and encoding each fold using the encodings learnt using the *other k-1* folds.
 In this example, we demonstrate the importance of the cross
 fitting procedure to prevent overfitting.
 """
@@ -140,7 +140,7 @@
 # %%
 # While :meth:`TargetEncoder.fit_transform` uses an internal
 # :term:`cross fitting` scheme to learn encodings for the training set,
-# :meth:`TargetEncoder.transform` itself does not.
+# :meth:`TargetEncoder.fit` followed by :meth:`TargetEncoder.transform` does not.
 # It uses the complete training set to learn encodings and to transform the
 # categorical features. Thus, we can use :meth:`TargetEncoder.fit` followed by
 # :meth:`TargetEncoder.transform` to disable the :term:`cross fitting`. This

sklearn/preprocessing/_target_encoder.py

Lines changed: 18 additions & 3 deletions

@@ -218,6 +218,14 @@ def __init__(
     def fit(self, X, y):
         """Fit the :class:`TargetEncoder` to X and y.
 
+        It is discouraged to use this method because it can introduce data leakage.
+        Use `fit_transform` on training data instead.
+
+        .. note::
+            `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
+            :term:`cross fitting` scheme is used in `fit_transform` for encoding.
+            See the :ref:`User Guide <target_encoder>` for details.
+
         Parameters
         ----------
         X : array-like of shape (n_samples, n_features)
@@ -236,12 +244,16 @@ def fit(self, X, y):
 
     @_fit_context(prefer_skip_nested_validation=True)
     def fit_transform(self, X, y):
-        """Fit :class:`TargetEncoder` and transform X with the target encoding.
+        """Fit :class:`TargetEncoder` and transform `X` with the target encoding.
+
+        This method uses a :term:`cross fitting` scheme to prevent target leakage
+        and overfitting in downstream predictors. It is the recommended method for
+        encoding training data.
 
         .. note::
             `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
             :term:`cross fitting` scheme is used in `fit_transform` for encoding.
-            See the :ref:`User Guide <target_encoder>`. for details.
+            See the :ref:`User Guide <target_encoder>` for details.
 
         Parameters
         ----------
@@ -314,10 +326,13 @@ def fit_transform(self, X, y):
     def transform(self, X):
         """Transform X with the target encoding.
 
+        This method internally uses the `encodings_` attribute learnt during
+        :meth:`TargetEncoder.fit_transform` to transform test data.
+
        .. note::
            `fit(X, y).transform(X)` does not equal `fit_transform(X, y)` because a
            :term:`cross fitting` scheme is used in `fit_transform` for encoding.
-           See the :ref:`User Guide <target_encoder>`. for details.
+           See the :ref:`User Guide <target_encoder>` for details.
 
        Parameters
        ----------
