Skip to content

Commit a4d2d07

Browse files
authored
Fix add typehint without generic (#507)
* fix: acquisition.py * fix: constraint.py * fix: target_space.py * fix: domain_reduction.py * fix: logger.py * fix: observer.py * fix: bayesian_optimization.py * fix: util.py * fix * fix: dict -> Mapping * fix: allow Sequence * fix: diallow null * fix: allow scipy constraints * fix: revert ArrayLike * fix * fix: codecov * fix: deque, suggest * fix: allow null target_func * fix: nullable ConstraintModel.fun * fix: NonlinearConstraint * fix: docs * fix: deps. errors * chore: revert numpy version * tests: codecov * chore: rm myst-parse * fix: review(only single point) * fix: review(params type) * docs: add ext options comment * fix: rm unnecessary overload * fix: rm whitespace * fix: rm autodoc docstring hook
1 parent 2298f7c commit a4d2d07

File tree

12 files changed

+370
-162
lines changed

12 files changed

+370
-162
lines changed

bayes_opt/acquisition.py

Lines changed: 91 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
import abc
2424
import warnings
2525
from copy import deepcopy
26-
from numbers import Number
27-
from typing import TYPE_CHECKING, Callable
26+
from typing import TYPE_CHECKING, Any, Literal, NoReturn
2827

2928
import numpy as np
3029
from numpy.random import RandomState
@@ -41,8 +40,15 @@
4140
from bayes_opt.target_space import TargetSpace
4241

4342
if TYPE_CHECKING:
43+
from collections.abc import Callable
44+
45+
from numpy.typing import NDArray
46+
from scipy.optimize import OptimizeResult
47+
4448
from bayes_opt.constraint import ConstraintModel
4549

50+
Float = np.floating[Any]
51+
4652

4753
class AcquisitionFunction(abc.ABC):
4854
"""Base class for acquisition functions.
@@ -53,7 +59,7 @@ class AcquisitionFunction(abc.ABC):
5359
Set the random state for reproducibility.
5460
"""
5561

56-
def __init__(self, random_state=None):
62+
def __init__(self, random_state: int | RandomState | None = None) -> None:
5763
if random_state is not None:
5864
if isinstance(random_state, RandomState):
5965
self.random_state = random_state
@@ -64,7 +70,7 @@ def __init__(self, random_state=None):
6470
self.i = 0
6571

6672
@abc.abstractmethod
67-
def base_acq(self, *args, **kwargs):
73+
def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
6874
"""Provide access to the base acquisition function."""
6975

7076
def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> None:
@@ -80,10 +86,10 @@ def suggest(
8086
self,
8187
gp: GaussianProcessRegressor,
8288
target_space: TargetSpace,
83-
n_random=10_000,
84-
n_l_bfgs_b=10,
89+
n_random: int = 10_000,
90+
n_l_bfgs_b: int = 10,
8591
fit_gp: bool = True,
86-
):
92+
) -> NDArray[Float]:
8793
"""Suggest a promising point to probe next.
8894
8995
Parameters
@@ -123,7 +129,9 @@ def suggest(
123129
acq = self._get_acq(gp=gp, constraint=target_space.constraint)
124130
return self._acq_min(acq, target_space.bounds, n_random=n_random, n_l_bfgs_b=n_l_bfgs_b)
125131

126-
def _get_acq(self, gp: GaussianProcessRegressor, constraint: ConstraintModel | None = None) -> Callable:
132+
def _get_acq(
133+
self, gp: GaussianProcessRegressor, constraint: ConstraintModel | None = None
134+
) -> Callable[[NDArray[Float]], NDArray[Float]]:
127135
"""Prepare the acquisition function for minimization.
128136
129137
Transforms a base_acq Callable, which takes `mean` and `std` as
@@ -148,25 +156,36 @@ def _get_acq(self, gp: GaussianProcessRegressor, constraint: ConstraintModel | N
148156
dim = gp.X_train_.shape[1]
149157
if constraint is not None:
150158

151-
def acq(x):
159+
def acq(x: NDArray[Float]) -> NDArray[Float]:
152160
x = x.reshape(-1, dim)
153161
with warnings.catch_warnings():
154162
warnings.simplefilter("ignore")
163+
mean: NDArray[Float]
164+
std: NDArray[Float]
165+
p_constraints: NDArray[Float]
155166
mean, std = gp.predict(x, return_std=True)
156167
p_constraints = constraint.predict(x)
157168
return -1 * self.base_acq(mean, std) * p_constraints
158169
else:
159170

160-
def acq(x):
171+
def acq(x: NDArray[Float]) -> NDArray[Float]:
161172
x = x.reshape(-1, dim)
162173
with warnings.catch_warnings():
163174
warnings.simplefilter("ignore")
175+
mean: NDArray[Float]
176+
std: NDArray[Float]
164177
mean, std = gp.predict(x, return_std=True)
165178
return -1 * self.base_acq(mean, std)
166179

167180
return acq
168181

169-
def _acq_min(self, acq: Callable, bounds: np.ndarray, n_random=10_000, n_l_bfgs_b=10) -> np.ndarray:
182+
def _acq_min(
183+
self,
184+
acq: Callable[[NDArray[Float]], NDArray[Float]],
185+
bounds: NDArray[Float],
186+
n_random: int = 10_000,
187+
n_l_bfgs_b: int = 10,
188+
) -> NDArray[Float]:
170189
"""Find the maximum of the acquisition function.
171190
172191
Uses a combination of random sampling (cheap) and the 'L-BFGS-B'
@@ -200,13 +219,14 @@ def _acq_min(self, acq: Callable, bounds: np.ndarray, n_random=10_000, n_l_bfgs_
200219
raise ValueError(error_msg)
201220
x_min_r, min_acq_r = self._random_sample_minimize(acq, bounds, n_random=n_random)
202221
x_min_l, min_acq_l = self._l_bfgs_b_minimize(acq, bounds, n_x_seeds=n_l_bfgs_b)
222+
# Either n_random or n_l_bfgs_b is not 0 => at least one of x_min_r and x_min_l is not None
203223
if min_acq_r < min_acq_l:
204224
return x_min_r
205225
return x_min_l
206226

207227
def _random_sample_minimize(
208-
self, acq: Callable, bounds: np.ndarray, n_random: int
209-
) -> tuple[np.ndarray, float]:
228+
self, acq: Callable[[NDArray[Float]], NDArray[Float]], bounds: NDArray[Float], n_random: int
229+
) -> tuple[NDArray[Float] | None, float]:
210230
"""Random search to find the minimum of `acq` function.
211231
212232
Parameters
@@ -239,8 +259,8 @@ def _random_sample_minimize(
239259
return x_min, min_acq
240260

241261
def _l_bfgs_b_minimize(
242-
self, acq: Callable, bounds: np.ndarray, n_x_seeds: int = 10
243-
) -> tuple[np.ndarray, float]:
262+
self, acq: Callable[[NDArray[Float]], NDArray[Float]], bounds: NDArray[Float], n_x_seeds: int = 10
263+
) -> tuple[NDArray[Float] | None, float]:
244264
"""Random search to find the minimum of `acq` function.
245265
246266
Parameters
@@ -268,10 +288,12 @@ def _l_bfgs_b_minimize(
268288
return None, np.inf
269289
x_seeds = self.random_state.uniform(bounds[:, 0], bounds[:, 1], size=(n_x_seeds, bounds.shape[0]))
270290

271-
min_acq = None
291+
min_acq: float | None = None
292+
x_try: NDArray[Float]
293+
x_min: NDArray[Float]
272294
for x_try in x_seeds:
273295
# Find the minimum of minus the acquisition function
274-
res = minimize(acq, x_try, bounds=bounds, method="L-BFGS-B")
296+
res: OptimizeResult = minimize(acq, x_try, bounds=bounds, method="L-BFGS-B")
275297

276298
# See if success
277299
if not res.success:
@@ -317,7 +339,11 @@ class UpperConfidenceBound(AcquisitionFunction):
317339
"""
318340

319341
def __init__(
320-
self, kappa=2.576, exploration_decay=None, exploration_decay_delay=None, random_state=None
342+
self,
343+
kappa: float = 2.576,
344+
exploration_decay: float | None = None,
345+
exploration_decay_delay: int | None = None,
346+
random_state: int | RandomState | None = None,
321347
) -> None:
322348
if kappa < 0:
323349
error_msg = "kappa must be greater than or equal to 0."
@@ -328,7 +354,7 @@ def __init__(
328354
self.exploration_decay = exploration_decay
329355
self.exploration_decay_delay = exploration_decay_delay
330356

331-
def base_acq(self, mean, std):
357+
def base_acq(self, mean: NDArray[Float], std: NDArray[Float]) -> NDArray[Float]:
332358
"""Calculate the upper confidence bound.
333359
334360
Parameters
@@ -350,10 +376,10 @@ def suggest(
350376
self,
351377
gp: GaussianProcessRegressor,
352378
target_space: TargetSpace,
353-
n_random=10_000,
354-
n_l_bfgs_b=10,
379+
n_random: int = 10_000,
380+
n_l_bfgs_b: int = 10,
355381
fit_gp: bool = True,
356-
) -> np.ndarray:
382+
) -> NDArray[Float]:
357383
"""Suggest a promising point to probe next.
358384
359385
Parameters
@@ -432,14 +458,20 @@ class ProbabilityOfImprovement(AcquisitionFunction):
432458
Set the random state for reproducibility.
433459
"""
434460

435-
def __init__(self, xi, exploration_decay=None, exploration_decay_delay=None, random_state=None) -> None:
461+
def __init__(
462+
self,
463+
xi: float,
464+
exploration_decay: float | None = None,
465+
exploration_decay_delay: int | None = None,
466+
random_state: int | RandomState | None = None,
467+
) -> None:
436468
super().__init__(random_state=random_state)
437469
self.xi = xi
438470
self.exploration_decay = exploration_decay
439471
self.exploration_decay_delay = exploration_decay_delay
440472
self.y_max = None
441473

442-
def base_acq(self, mean, std):
474+
def base_acq(self, mean: NDArray[Float], std: NDArray[Float]) -> NDArray[Float]:
443475
"""Calculate the probability of improvement.
444476
445477
Parameters
@@ -473,10 +505,10 @@ def suggest(
473505
self,
474506
gp: GaussianProcessRegressor,
475507
target_space: TargetSpace,
476-
n_random=10_000,
477-
n_l_bfgs_b=10,
508+
n_random: int = 10_000,
509+
n_l_bfgs_b: int = 10,
478510
fit_gp: bool = True,
479-
) -> np.ndarray:
511+
) -> NDArray[Float]:
480512
"""Suggest a promising point to probe next.
481513
482514
Parameters
@@ -565,14 +597,20 @@ class ExpectedImprovement(AcquisitionFunction):
565597
Set the random state for reproducibility.
566598
"""
567599

568-
def __init__(self, xi, exploration_decay=None, exploration_decay_delay=None, random_state=None) -> None:
600+
def __init__(
601+
self,
602+
xi: float,
603+
exploration_decay: float | None = None,
604+
exploration_decay_delay: int | None = None,
605+
random_state: int | RandomState | None = None,
606+
) -> None:
569607
super().__init__(random_state=random_state)
570608
self.xi = xi
571609
self.exploration_decay = exploration_decay
572610
self.exploration_decay_delay = exploration_decay_delay
573611
self.y_max = None
574612

575-
def base_acq(self, mean, std):
613+
def base_acq(self, mean: NDArray[Float], std: NDArray[Float]) -> NDArray[Float]:
576614
"""Calculate the expected improvement.
577615
578616
Parameters
@@ -607,10 +645,10 @@ def suggest(
607645
self,
608646
gp: GaussianProcessRegressor,
609647
target_space: TargetSpace,
610-
n_random=10_000,
611-
n_l_bfgs_b=10,
648+
n_random: int = 10_000,
649+
n_l_bfgs_b: int = 10,
612650
fit_gp: bool = True,
613-
) -> np.ndarray:
651+
) -> NDArray[Float]:
614652
"""Suggest a promising point to probe next.
615653
616654
Parameters
@@ -701,19 +739,24 @@ class ConstantLiar(AcquisitionFunction):
701739
"""
702740

703741
def __init__(
704-
self, base_acquisition: AcquisitionFunction, strategy="max", random_state=None, atol=1e-5, rtol=1e-8
742+
self,
743+
base_acquisition: AcquisitionFunction,
744+
strategy: Literal["min", "mean", "max"] | float = "max",
745+
random_state: int | RandomState | None = None,
746+
atol: float = 1e-5,
747+
rtol: float = 1e-8,
705748
) -> None:
706749
super().__init__(random_state)
707750
self.base_acquisition = base_acquisition
708751
self.dummies = []
709-
if not isinstance(strategy, Number) and strategy not in ["min", "mean", "max"]:
752+
if not isinstance(strategy, float) and strategy not in ["min", "mean", "max"]:
710753
error_msg = f"Received invalid argument {strategy} for strategy."
711754
raise ValueError(error_msg)
712-
self.strategy = strategy
755+
self.strategy: Literal["min", "mean", "max"] | float = strategy
713756
self.atol = atol
714757
self.rtol = rtol
715758

716-
def base_acq(self, *args, **kwargs):
759+
def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
717760
"""Calculate the acquisition function.
718761
719762
Calls the base acquisition function's `base_acq` method.
@@ -774,10 +817,10 @@ def suggest(
774817
self,
775818
gp: GaussianProcessRegressor,
776819
target_space: TargetSpace,
777-
n_random=10_000,
778-
n_l_bfgs_b=10,
820+
n_random: int = 10_000,
821+
n_l_bfgs_b: int = 10,
779822
fit_gp: bool = True,
780-
) -> np.ndarray:
823+
) -> NDArray[Float]:
781824
"""Suggest a promising point to probe next.
782825
783826
Parameters
@@ -824,8 +867,9 @@ def suggest(
824867
# Create a copy of the target space
825868
dummy_target_space = self._copy_target_space(target_space)
826869

870+
dummy_target: float
827871
# Choose the dummy target value
828-
if isinstance(self.strategy, Number):
872+
if isinstance(self.strategy, float):
829873
dummy_target = self.strategy
830874
elif self.strategy == "min":
831875
dummy_target = target_space.target.min()
@@ -875,14 +919,16 @@ class GPHedge(AcquisitionFunction):
875919
Set the random state for reproducibility.
876920
"""
877921

878-
def __init__(self, base_acquisitions: list[AcquisitionFunction], random_state=None) -> None:
922+
def __init__(
923+
self, base_acquisitions: list[AcquisitionFunction], random_state: int | RandomState | None = None
924+
) -> None:
879925
super().__init__(random_state)
880926
self.base_acquisitions = base_acquisitions
881927
self.n_acq = len(self.base_acquisitions)
882928
self.gains = np.zeros(self.n_acq)
883929
self.previous_candidates = None
884930

885-
def base_acq(self, *args, **kwargs):
931+
def base_acq(self, *args: Any, **kwargs: Any) -> NoReturn:
886932
"""Raise an error, since the base acquisition function is ambiguous."""
887933
msg = (
888934
"GPHedge base acquisition function is ambiguous."
@@ -909,10 +955,10 @@ def suggest(
909955
self,
910956
gp: GaussianProcessRegressor,
911957
target_space: TargetSpace,
912-
n_random=10_000,
913-
n_l_bfgs_b=10,
958+
n_random: int = 10_000,
959+
n_l_bfgs_b: int = 10,
914960
fit_gp: bool = True,
915-
) -> np.ndarray:
961+
) -> NDArray[Float]:
916962
"""Suggest a promising point to probe next.
917963
918964
Parameters

0 commit comments

Comments
 (0)