@@ -10,6 +10,7 @@
 
 # Author: Chyi-Kwei Yau
 # Author: Matthew D. Hoffman (original onlineldavb implementation)
+from numbers import Integral, Real
 
 import numpy as np
 import scipy.sparse as sp
@@ -21,6 +22,7 @@
 from ..utils.validation import check_non_negative
 from ..utils.validation import check_is_fitted
 from ..utils.fixes import delayed
+from ..utils._param_validation import Interval, StrOptions
 
 from ._online_lda_fast import (
     mean_change,
@@ -320,6 +322,25 @@ class conditional densities to the data and using Bayes' rule.
            [0.15297572, 0.00362644, 0.44412786, 0.39568399, 0.003586  ]])
     """
 
+    _parameter_constraints: dict = {
+        "n_components": [Interval(Integral, 0, None, closed="neither")],
+        "doc_topic_prior": [None, Interval(Real, 0, 1, closed="both")],
+        "topic_word_prior": [None, Interval(Real, 0, 1, closed="both")],
+        "learning_method": [StrOptions({"batch", "online"})],
+        "learning_decay": [Interval(Real, 0, 1, closed="both")],
+        "learning_offset": [Interval(Real, 1.0, None, closed="left")],
+        "max_iter": [Interval(Integral, 0, None, closed="left")],
+        "batch_size": [Interval(Integral, 0, None, closed="neither")],
+        "evaluate_every": [Interval(Integral, None, None, closed="neither")],
+        "total_samples": [Interval(Real, 0, None, closed="neither")],
+        "perp_tol": [Interval(Real, 0, None, closed="left")],
+        "mean_change_tol": [Interval(Real, 0, None, closed="left")],
+        "max_doc_update_iter": [Interval(Integral, 0, None, closed="left")],
+        "n_jobs": [None, Integral],
+        "verbose": ["verbose"],
+        "random_state": ["random_state"],
+    }
+
     def __init__(
         self,
         n_components=10,
@@ -357,26 +378,6 @@ def __init__(
         self.verbose = verbose
         self.random_state = random_state
 
-    def _check_params(self):
-        """Check model parameters."""
-        if self.n_components <= 0:
-            raise ValueError("Invalid 'n_components' parameter: %r" % self.n_components)
-
-        if self.total_samples <= 0:
-            raise ValueError(
-                "Invalid 'total_samples' parameter: %r" % self.total_samples
-            )
-
-        if self.learning_offset < 0:
-            raise ValueError(
-                "Invalid 'learning_offset' parameter: %r" % self.learning_offset
-            )
-
-        if self.learning_method not in ("batch", "online"):
-            raise ValueError(
-                "Invalid 'learning_method' parameter: %r" % self.learning_method
-            )
-
     def _init_latent_vars(self, n_features):
         """Initialize latent variables."""
 
@@ -559,7 +560,7 @@ def partial_fit(self, X, y=None):
         self
             Partially fitted estimator.
         """
-        self._check_params()
+        self._validate_params()
         first_time = not hasattr(self, "components_")
         X = self._check_non_neg_array(
             X, reset_n_features=first_time, whom="LatentDirichletAllocation.partial_fit"
@@ -609,7 +610,7 @@ def fit(self, X, y=None):
         self
             Fitted estimator.
         """
-        self._check_params()
+        self._validate_params()
         X = self._check_non_neg_array(
             X, reset_n_features=True, whom="LatentDirichletAllocation.fit"
         )
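
The removed _check_params helper is superseded by the declarative constraints above: _validate_params() (provided by BaseEstimator) checks every constructor argument against _parameter_constraints when fit or partial_fit is called. Below is a minimal usage sketch of the resulting behaviour; it is illustrative only and not part of this commit, and the exact exception subclass and message depend on the scikit-learn version, though the validation error derives from ValueError.

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

# Small non-negative count matrix; parameter validation happens before any fitting work.
X = np.random.RandomState(0).randint(0, 5, size=(20, 10))

try:
    # "warm" is not in StrOptions({"batch", "online"}), so fit() raises immediately.
    LatentDirichletAllocation(learning_method="warm").fit(X)
except ValueError as exc:
    print(exc)

try:
    # n_components=0 violates Interval(Integral, 0, None, closed="neither").
    LatentDirichletAllocation(n_components=0).fit(X)
except ValueError as exc:
    print(exc)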