Skip to content

Commit 42fa09e

Browse files
MAINT Add parameter validation to LatentDirichletAllocation (scikit-learn#24212)
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 69066c6 commit 42fa09e

File tree

3 files changed

+23
-39
lines changed

3 files changed

+23
-39
lines changed

sklearn/decomposition/_lda.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
# Author: Chyi-Kwei Yau
1212
# Author: Matthew D. Hoffman (original onlineldavb implementation)
13+
from numbers import Integral, Real
1314

1415
import numpy as np
1516
import scipy.sparse as sp
@@ -21,6 +22,7 @@
2122
from ..utils.validation import check_non_negative
2223
from ..utils.validation import check_is_fitted
2324
from ..utils.fixes import delayed
25+
from ..utils._param_validation import Interval, StrOptions
2426

2527
from ._online_lda_fast import (
2628
mean_change,
@@ -320,6 +322,25 @@ class conditional densities to the data and using Bayes' rule.
320322
[0.15297572, 0.00362644, 0.44412786, 0.39568399, 0.003586 ]])
321323
"""
322324

325+
_parameter_constraints: dict = {
326+
"n_components": [Interval(Integral, 0, None, closed="neither")],
327+
"doc_topic_prior": [None, Interval(Real, 0, 1, closed="both")],
328+
"topic_word_prior": [None, Interval(Real, 0, 1, closed="both")],
329+
"learning_method": [StrOptions({"batch", "online"})],
330+
"learning_decay": [Interval(Real, 0, 1, closed="both")],
331+
"learning_offset": [Interval(Real, 1.0, None, closed="left")],
332+
"max_iter": [Interval(Integral, 0, None, closed="left")],
333+
"batch_size": [Interval(Integral, 0, None, closed="neither")],
334+
"evaluate_every": [Interval(Integral, None, None, closed="neither")],
335+
"total_samples": [Interval(Real, 0, None, closed="neither")],
336+
"perp_tol": [Interval(Real, 0, None, closed="left")],
337+
"mean_change_tol": [Interval(Real, 0, None, closed="left")],
338+
"max_doc_update_iter": [Interval(Integral, 0, None, closed="left")],
339+
"n_jobs": [None, Integral],
340+
"verbose": ["verbose"],
341+
"random_state": ["random_state"],
342+
}
343+
323344
def __init__(
324345
self,
325346
n_components=10,
@@ -357,26 +378,6 @@ def __init__(
357378
self.verbose = verbose
358379
self.random_state = random_state
359380

360-
def _check_params(self):
361-
"""Check model parameters."""
362-
if self.n_components <= 0:
363-
raise ValueError("Invalid 'n_components' parameter: %r" % self.n_components)
364-
365-
if self.total_samples <= 0:
366-
raise ValueError(
367-
"Invalid 'total_samples' parameter: %r" % self.total_samples
368-
)
369-
370-
if self.learning_offset < 0:
371-
raise ValueError(
372-
"Invalid 'learning_offset' parameter: %r" % self.learning_offset
373-
)
374-
375-
if self.learning_method not in ("batch", "online"):
376-
raise ValueError(
377-
"Invalid 'learning_method' parameter: %r" % self.learning_method
378-
)
379-
380381
def _init_latent_vars(self, n_features):
381382
"""Initialize latent variables."""
382383

@@ -559,7 +560,7 @@ def partial_fit(self, X, y=None):
559560
self
560561
Partially fitted estimator.
561562
"""
562-
self._check_params()
563+
self._validate_params()
563564
first_time = not hasattr(self, "components_")
564565
X = self._check_non_neg_array(
565566
X, reset_n_features=first_time, whom="LatentDirichletAllocation.partial_fit"
@@ -609,7 +610,7 @@ def fit(self, X, y=None):
609610
self
610611
Fitted estimator.
611612
"""
612-
self._check_params()
613+
self._validate_params()
613614
X = self._check_non_neg_array(
614615
X, reset_n_features=True, whom="LatentDirichletAllocation.fit"
615616
)

sklearn/decomposition/tests/test_online_lda.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,6 @@ def test_lda_fit_transform(method):
152152
assert_array_almost_equal(X_fit, X_trans, 4)
153153

154154

155-
def test_invalid_params():
156-
# test `_check_params` method
157-
X = np.ones((5, 10))
158-
159-
invalid_models = (
160-
("n_components", LatentDirichletAllocation(n_components=0)),
161-
("learning_method", LatentDirichletAllocation(learning_method="unknown")),
162-
("total_samples", LatentDirichletAllocation(total_samples=0)),
163-
("learning_offset", LatentDirichletAllocation(learning_offset=-1)),
164-
)
165-
for param, model in invalid_models:
166-
regex = r"^Invalid %r parameter" % param
167-
with pytest.raises(ValueError, match=regex):
168-
model.fit(X)
169-
170-
171155
def test_lda_negative_input():
172156
# test pass dense matrix with sparse negative input.
173157
X = np.full((5, 10), -1.0)

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
467467
"ClassifierChain",
468468
"DictionaryLearning",
469469
"HashingVectorizer",
470-
"LatentDirichletAllocation",
471470
"MiniBatchDictionaryLearning",
472471
"MultiTaskElasticNet",
473472
"MultiTaskLasso",

0 commit comments

Comments
 (0)