
Commit 7c835d5

MAINT Add parameter validation to HashingVectorizer. (scikit-learn#24181)
Co-authored-by: jeremie du boisberranger <[email protected]>
1 parent: d7c978b

File tree: 2 files changed (+30, -13 lines)


sklearn/feature_extraction/text.py

Lines changed: 30 additions & 12 deletions
@@ -15,7 +15,6 @@
 from collections import defaultdict
 from collections.abc import Mapping
 from functools import partial
-import numbers
 from numbers import Integral, Real
 from operator import itemgetter
 import re
@@ -631,7 +630,7 @@ class HashingVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator):
         'strict', meaning that a UnicodeDecodeError will be raised. Other
         values are 'ignore' and 'replace'.

-    strip_accents : {'ascii', 'unicode'}, default=None
+    strip_accents : {'ascii', 'unicode'} or callable, default=None
         Remove accents and perform other character normalization
         during the preprocessing step.
         'ascii' is a fast method that only works on characters that have
@@ -664,7 +663,7 @@ class HashingVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator):
         will be removed from the resulting tokens.
         Only applies if ``analyzer == 'word'``.

-    token_pattern : str, default=r"(?u)\\b\\w\\w+\\b"
+    token_pattern : str or None, default=r"(?u)\\b\\w\\w+\\b"
         Regular expression denoting what constitutes a "token", only used
         if ``analyzer == 'word'``. The default regexp selects tokens of 2
         or more alphanumeric characters (punctuation is completely ignored
@@ -740,6 +739,25 @@ class HashingVectorizer(TransformerMixin, _VectorizerMixin, BaseEstimator):
     (4, 16)
     """

+    _parameter_constraints: dict = {
+        "input": [StrOptions({"filename", "file", "content"})],
+        "encoding": [str],
+        "decode_error": [StrOptions({"strict", "ignore", "replace"})],
+        "strip_accents": [StrOptions({"ascii", "unicode"}), None, callable],
+        "lowercase": ["boolean"],
+        "preprocessor": [callable, None],
+        "tokenizer": [callable, None],
+        "stop_words": [StrOptions({"english"}), list, None],
+        "token_pattern": [str, None],
+        "ngram_range": [tuple],
+        "analyzer": [StrOptions({"word", "char", "char_wb"}), callable],
+        "n_features": [Interval(Integral, 1, np.iinfo(np.int32).max, closed="left")],
+        "binary": ["boolean"],
+        "norm": [StrOptions({"l1", "l2"}), None],
+        "alternate_sign": ["boolean"],
+        "dtype": "no_validation",  # delegate to numpy
+    }
+
     def __init__(
         self,
         *,
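
Note: StrOptions and Interval are constraint helpers importable from scikit-learn's private sklearn.utils._param_validation module (presumably already imported elsewhere in text.py). As a hedged illustration of how the declarations above behave — this is a private API that may change between releases — each constraint object exposes an is_satisfied_by check:

from numbers import Integral

import numpy as np
from sklearn.utils._param_validation import Interval, StrOptions

# The constraint on `norm`: only the strings "l1" and "l2" are accepted.
norm_constraint = StrOptions({"l1", "l2"})
print(norm_constraint.is_satisfied_by("l2"))  # True
print(norm_constraint.is_satisfied_by("l3"))  # False

# The constraint on `n_features`: closed="left" includes the lower bound 1
# and excludes the upper bound (the largest 32-bit signed integer).
n_features_constraint = Interval(Integral, 1, np.iinfo(np.int32).max, closed="left")
print(n_features_constraint.is_satisfied_by(2**20))  # True
print(n_features_constraint.is_satisfied_by(0))      # False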
@@ -796,6 +814,8 @@ def partial_fit(self, X, y=None):
         self : object
             HashingVectorizer instance.
         """
+        # TODO: only validate during the first call
+        self._validate_params()
         return self

     def fit(self, X, y=None):
@@ -814,6 +834,8 @@ def fit(self, X, y=None):
         self : object
             HashingVectorizer instance.
         """
+        self._validate_params()
+
         # triggers a parameter validation
         if isinstance(X, str):
             raise ValueError(
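
With _validate_params() wired into both fit and partial_fit, an invalid constructor argument now fails fast with a descriptive message instead of surfacing later inside transform. A minimal sketch of the new behavior, assuming a scikit-learn build that includes this commit (the raised InvalidParameterError subclasses both ValueError and TypeError):

from sklearn.feature_extraction.text import HashingVectorizer

vec = HashingVectorizer(norm="l3")   # constructing alone does not validate
try:
    vec.fit(["some document"])       # fit() triggers self._validate_params()
except (TypeError, ValueError) as exc:
    print(exc)  # e.g. "The 'norm' parameter of HashingVectorizer must be ..."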
@@ -935,7 +957,7 @@ class CountVectorizer(_VectorizerMixin, BaseEstimator):
         Remove accents and perform other character normalization
         during the preprocessing step.
         'ascii' is a fast method that only works on characters that have
-        an direct ASCII mapping.
+        a direct ASCII mapping.
         'unicode' is a slightly slower method that works on any characters.
         None (default) does nothing.

941963
@@ -1359,12 +1381,8 @@ def fit_transform(self, raw_documents, y=None):

         if not self.fixed_vocabulary_:
             n_doc = X.shape[0]
-            max_doc_count = (
-                max_df if isinstance(max_df, numbers.Integral) else max_df * n_doc
-            )
-            min_doc_count = (
-                min_df if isinstance(min_df, numbers.Integral) else min_df * n_doc
-            )
+            max_doc_count = max_df if isinstance(max_df, Integral) else max_df * n_doc
+            min_doc_count = min_df if isinstance(min_df, Integral) else min_df * n_doc
             if max_doc_count < min_doc_count:
                 raise ValueError("max_df corresponds to < documents than min_df")
             if max_features is not None:
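
The simplified max_df/min_df logic keeps the existing semantics: an integer is an absolute document count, a float is a fraction of the number of documents. A quick illustration with CountVectorizer:

from sklearn.feature_extraction.text import CountVectorizer

docs = ["apple banana", "apple cherry", "apple durian", "banana cherry"]

# max_df=0.5 -> max_doc_count = 0.5 * 4 = 2.0: terms appearing in more than
# two of the four documents are pruned from the vocabulary.
vec = CountVectorizer(max_df=0.5).fit(docs)
print(sorted(vec.vocabulary_))  # 'apple' (in 3 of 4 docs) is pruned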
@@ -1771,11 +1789,11 @@ class TfidfVectorizer(CountVectorizer):
         'strict', meaning that a UnicodeDecodeError will be raised. Other
         values are 'ignore' and 'replace'.

-    strip_accents : {'ascii', 'unicode'}, default=None
+    strip_accents : {'ascii', 'unicode'} or callable, default=None
         Remove accents and perform other character normalization
         during the preprocessing step.
         'ascii' is a fast method that only works on characters that have
-        an direct ASCII mapping.
+        a direct ASCII mapping.
         'unicode' is a slightly slower method that works on any characters.
         None (default) does nothing.

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
@@ -465,7 +465,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):

 PARAM_VALIDATION_ESTIMATORS_TO_IGNORE = [
     "DictionaryLearning",
-    "HashingVectorizer",
     "MiniBatchDictionaryLearning",
     "MultiTaskElasticNet",
     "MultiTaskLasso",
