Skip to content

Commit c48e566

Browse files
authored
Fix min window type check (#523)
* fix: replace dict with Mapping * fix: replace list with Sequence * fix: add type hint * fix: does not accept None
1 parent f63372e commit c48e566

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

bayes_opt/domain_reduction.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
from abc import ABC, abstractmethod
11+
from collections.abc import Iterable, Mapping, Sequence
1112
from typing import TYPE_CHECKING, Any
1213
from warnings import warn
1314

@@ -16,8 +17,6 @@
1617
from bayes_opt.target_space import TargetSpace
1718

1819
if TYPE_CHECKING:
19-
from collections.abc import Iterable, Mapping, Sequence
20-
2120
from numpy.typing import NDArray
2221

2322
Float = np.floating[Any]
@@ -66,12 +65,14 @@ def __init__(
6665
gamma_osc: float = 0.7,
6766
gamma_pan: float = 1.0,
6867
eta: float = 0.9,
69-
minimum_window: NDArray[Float] | Sequence[float] | float | Mapping[str, float] | None = 0.0,
68+
minimum_window: NDArray[Float] | Sequence[float] | Mapping[str, float] | float = 0.0,
7069
) -> None:
7170
self.gamma_osc = gamma_osc
7271
self.gamma_pan = gamma_pan
7372
self.eta = eta
74-
if isinstance(minimum_window, dict):
73+
74+
self.minimum_window_value: NDArray[Float] | Sequence[float] | float
75+
if isinstance(minimum_window, Mapping):
7576
self.minimum_window_value = [
7677
item[1] for item in sorted(minimum_window.items(), key=lambda x: x[0])
7778
]
@@ -90,8 +91,9 @@ def initialize(self, target_space: TargetSpace) -> None:
9091
self.original_bounds = np.copy(target_space.bounds)
9192
self.bounds = [self.original_bounds]
9293

94+
self.minimum_window: NDArray[Float] | Sequence[float]
9395
# Set the minimum window to an array of length bounds
94-
if isinstance(self.minimum_window_value, (list, np.ndarray)):
96+
if isinstance(self.minimum_window_value, (Sequence, np.ndarray)):
9597
if len(self.minimum_window_value) != len(target_space.bounds):
9698
error_msg = "Length of minimum_window must be the same as the number of parameters"
9799
raise ValueError(error_msg)

0 commit comments

Comments
 (0)