Skip to content

Commit 1dc23d7

Browse files
ENH Makes OneToOneFeatureMixin and ClassNamePrefixFeaturesOutMixin public (scikit-learn#24688)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent d4306ba commit 1dc23d7

File tree

29 files changed

+108
-73
lines changed

29 files changed

+108
-73
lines changed

doc/developers/develop.rst

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,18 @@ scikit-learn introduces the `set_output` API for configuring transformers to
647647
output pandas DataFrames. The `set_output` API is automatically defined if the
648648
transformer defines :term:`get_feature_names_out` and subclasses
649649
:class:`base.TransformerMixin`. :term:`get_feature_names_out` is used to get the
650-
column names of pandas output. You can opt-out of the `set_output` API by
651-
setting `auto_wrap_output_keys=None` when defining a custom subclass::
650+
column names of pandas output.
651+
652+
:class:`base.OneToOneFeatureMixin` and
653+
:class:`base.ClassNamePrefixFeaturesOutMixin` are helpful mixins for defining
654+
:term:`get_feature_names_out`. :class:`base.OneToOneFeatureMixin` is useful when
655+
the transformer has a one-to-one correspondence between input features and output
656+
features, such as :class:`~preprocessing.StandardScaler`.
657+
:class:`base.ClassNamePrefixFeaturesOutMixin` is useful when the transformer
658+
needs to generate its own feature names out, such as :class:`~decomposition.PCA`.
659+
660+
You can opt-out of the `set_output` API by setting `auto_wrap_output_keys=None`
661+
when defining a custom subclass::
652662

653663
class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
654664

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Base classes
3434
base.DensityMixin
3535
base.RegressorMixin
3636
base.TransformerMixin
37+
base.OneToOneFeatureMixin
38+
base.ClassNamePrefixFeaturesOutMixin
3739
feature_selection.SelectorMixin
3840

3941
Functions

doc/whats_new/v1.2.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ Changelog
115115
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
116116
where 123456 is the *pull request* number, not the issue number.
117117
118+
:mod:`sklearn.base`
119+
-------------------
120+
121+
- |Enhancement| Introduces :class:`base.OneToOneFeatureMixin` and
122+
:class:`base.ClassNamePrefixFeaturesOutMixin` mixins that define
123+
:term:`get_feature_names_out` for common transformer use cases.
124+
:pr:`24688` by `Thomas Fan`_.
125+
118126
:mod:`sklearn.calibration`
119127
..........................
120128

sklearn/base.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,10 @@ class TransformerMixin(_SetOutputMixin):
811811
If :term:`get_feature_names_out` is defined, then `BaseEstimator` will
812812
automatically wrap `transform` and `fit_transform` to follow the `set_output`
813813
API. See the :ref:`developer_api_set_output` for details.
814+
815+
:class:`base.OneToOneFeatureMixin` and
816+
:class:`base.ClassNamePrefixFeaturesOutMixin` are helpful mixins for
817+
defining :term:`get_feature_names_out`.
814818
"""
815819

816820
def fit_transform(self, X, y=None, **fit_params):
@@ -847,11 +851,11 @@ def fit_transform(self, X, y=None, **fit_params):
847851
return self.fit(X, y, **fit_params).transform(X)
848852

849853

850-
class _OneToOneFeatureMixin:
854+
class OneToOneFeatureMixin:
851855
"""Provides `get_feature_names_out` for simple transformers.
852856
853-
Assumes there's a 1-to-1 correspondence between input features
854-
and output features.
857+
This mixin assumes there's a 1-to-1 correspondence between input features
858+
and output features, such as :class:`~preprocessing.StandardScaler`.
855859
"""
856860

857861
def get_feature_names_out(self, input_features=None):
@@ -877,15 +881,26 @@ def get_feature_names_out(self, input_features=None):
877881
return _check_feature_names_in(self, input_features)
878882

879883

880-
class _ClassNamePrefixFeaturesOutMixin:
884+
class ClassNamePrefixFeaturesOutMixin:
881885
"""Mixin class for transformers that generate their own names by prefixing.
882886
883-
Assumes that `_n_features_out` is defined for the estimator.
887+
This mixin is useful when the transformer needs to generate its own feature
888+
names out, such as :class:`~decomposition.PCA`. For example, if
889+
:class:`~decomposition.PCA` outputs 3 features, then the generated feature
890+
names out are: `["pca0", "pca1", "pca2"]`.
891+
892+
This mixin assumes that a `_n_features_out` attribute is defined when the
893+
transformer is fitted. `_n_features_out` is the number of output features
894+
that the transformer will return in `transform` or `fit_transform`.
884895
"""
885896

886897
def get_feature_names_out(self, input_features=None):
887898
"""Get output feature names for transformation.
888899
900+
The feature names out will be prefixed by the lowercased class name. For
901+
example, if the transformer outputs 3 features, then the feature names
902+
out are: `["class_name0", "class_name1", "class_name2"]`.
903+
889904
Parameters
890905
----------
891906
input_features : array-like of str or None, default=None

sklearn/cluster/_agglomerative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy import sparse
1616
from scipy.sparse.csgraph import connected_components
1717

18-
from ..base import BaseEstimator, ClusterMixin, _ClassNamePrefixFeaturesOutMixin
18+
from ..base import BaseEstimator, ClusterMixin, ClassNamePrefixFeaturesOutMixin
1919
from ..metrics.pairwise import paired_distances
2020
from ..metrics.pairwise import _VALID_METRICS
2121
from ..metrics import DistanceMetric
@@ -1100,7 +1100,7 @@ def fit_predict(self, X, y=None):
11001100

11011101

11021102
class FeatureAgglomeration(
1103-
_ClassNamePrefixFeaturesOutMixin, AgglomerativeClustering, AgglomerationTransform
1103+
ClassNamePrefixFeaturesOutMixin, AgglomerativeClustering, AgglomerationTransform
11041104
):
11051105
"""Agglomerate features.
11061106

sklearn/cluster/_birch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TransformerMixin,
1616
ClusterMixin,
1717
BaseEstimator,
18-
_ClassNamePrefixFeaturesOutMixin,
18+
ClassNamePrefixFeaturesOutMixin,
1919
)
2020
from ..utils.extmath import row_norms
2121
from ..utils._param_validation import Interval
@@ -357,7 +357,7 @@ def radius(self):
357357

358358

359359
class Birch(
360-
_ClassNamePrefixFeaturesOutMixin, ClusterMixin, TransformerMixin, BaseEstimator
360+
ClassNamePrefixFeaturesOutMixin, ClusterMixin, TransformerMixin, BaseEstimator
361361
):
362362
"""Implements the BIRCH clustering algorithm.
363363

sklearn/cluster/_kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
BaseEstimator,
2323
ClusterMixin,
2424
TransformerMixin,
25-
_ClassNamePrefixFeaturesOutMixin,
25+
ClassNamePrefixFeaturesOutMixin,
2626
)
2727
from ..metrics.pairwise import euclidean_distances
2828
from ..metrics.pairwise import _euclidean_distances
@@ -813,7 +813,7 @@ def _labels_inertia_threadpool_limit(
813813

814814

815815
class _BaseKMeans(
816-
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, ClusterMixin, BaseEstimator, ABC
816+
ClassNamePrefixFeaturesOutMixin, TransformerMixin, ClusterMixin, BaseEstimator, ABC
817817
):
818818
"""Base class for KMeans and MiniBatchKMeans"""
819819

sklearn/cross_decomposition/_pls.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from ..base import BaseEstimator, RegressorMixin, TransformerMixin
1717
from ..base import MultiOutputMixin
18-
from ..base import _ClassNamePrefixFeaturesOutMixin
18+
from ..base import ClassNamePrefixFeaturesOutMixin
1919
from ..utils import check_array, check_consistent_length
2020
from ..utils.fixes import sp_version
2121
from ..utils.fixes import parse_version
@@ -159,7 +159,7 @@ def _svd_flip_1d(u, v):
159159

160160

161161
class _PLS(
162-
_ClassNamePrefixFeaturesOutMixin,
162+
ClassNamePrefixFeaturesOutMixin,
163163
TransformerMixin,
164164
RegressorMixin,
165165
MultiOutputMixin,
@@ -901,7 +901,7 @@ def __init__(
901901
)
902902

903903

904-
class PLSSVD(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
904+
class PLSSVD(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
905905
"""Partial Least Square SVD.
906906
907907
This transformer simply performs a SVD on the cross-covariance matrix

sklearn/decomposition/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
import numpy as np
1212
from scipy import linalg
1313

14-
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
14+
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
1515
from ..utils.validation import check_is_fitted
1616
from abc import ABCMeta, abstractmethod
1717

1818

1919
class _BasePCA(
20-
_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator, metaclass=ABCMeta
20+
ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator, metaclass=ABCMeta
2121
):
2222
"""Base class for PCA methods.
2323

sklearn/decomposition/_dict_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy import linalg
1616
from joblib import Parallel, effective_n_jobs
1717

18-
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
18+
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
1919
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
2020
from ..utils import deprecated
2121
from ..utils._param_validation import Hidden, Interval, StrOptions
@@ -1152,7 +1152,7 @@ def dict_learning_online(
11521152
return dictionary
11531153

11541154

1155-
class _BaseSparseCoding(_ClassNamePrefixFeaturesOutMixin, TransformerMixin):
1155+
class _BaseSparseCoding(ClassNamePrefixFeaturesOutMixin, TransformerMixin):
11561156
"""Base class from SparseCoder and DictionaryLearning algorithms."""
11571157

11581158
def __init__(

0 commit comments

Comments
 (0)