|
43 | 43 |
|
44 | 44 | from abc import ABC, abstractmethod
|
45 | 45 | from dataclasses import dataclass, field
|
| 46 | +from inspect import signature |
46 | 47 | from types import TracebackType
|
47 | 48 | from typing import (
|
48 | 49 | TYPE_CHECKING,
|
|
102 | 103 | StrOrBytes,
|
103 | 104 | WarningType,
|
104 | 105 | )
|
105 |
| -from ..utils import GenericWrapper |
| 106 | +from ..utils import GenericWrapper, import_object |
106 | 107 | from .authentication import MySQLAuthenticator
|
107 | 108 | from .charsets import Charset, charsets
|
108 | 109 | from .protocol import MySQLProtocol
|
@@ -181,6 +182,7 @@ def __init__(
|
181 | 182 | raw: bool = False,
|
182 | 183 | kerberos_auth_mode: Optional[str] = None,
|
183 | 184 | krb_service_principal: Optional[str] = None,
|
| 185 | + fido_callback: Optional[Union[str, Callable[[str], None]]] = None, |
184 | 186 | webauthn_callback: Optional[Union[str, Callable[[str], None]]] = None,
|
185 | 187 | allow_local_infile: bool = DEFAULT_CONFIGURATION["allow_local_infile"],
|
186 | 188 | allow_local_infile_in_path: Optional[str] = DEFAULT_CONFIGURATION[
|
@@ -259,8 +261,10 @@ def __init__(
|
259 | 261 | self._in_transaction: bool = False
|
260 | 262 | self._oci_config_file: Optional[str] = None
|
261 | 263 | self._oci_config_profile: Optional[str] = None
|
262 |
| - self._fido_callback: Optional[Union[str, Callable[[str], None]]] = None |
263 |
| - self._webauthn_callback: Optional[Union[str, Callable[[str], None]]] = None |
| 264 | + self._fido_callback: Optional[Union[str, Callable[[str], None]]] = fido_callback |
| 265 | + self._webauthn_callback: Optional[ |
| 266 | + Union[str, Callable[[str], None]] |
| 267 | + ] = webauthn_callback |
264 | 268 |
|
265 | 269 | self.converter: Optional[MySQLConverter] = None
|
266 | 270 |
|
@@ -588,6 +592,41 @@ def _validate_tls_versions(self) -> None:
|
588 | 592 | elif invalid_tls_versions:
|
589 | 593 | raise AttributeError(TLS_VERSION_ERROR.format(tls_ver, TLS_VERSIONS))
|
590 | 594 |
|
| 595 | + @staticmethod |
| 596 | + def _validate_callable( |
| 597 | + option_name: str, callback: Union[str, Callable], num_args: int = 0 |
| 598 | + ) -> None: |
| 599 | + """Validates if it's a Python callable. |
| 600 | +
|
| 601 | + Args: |
| 602 | + option_name (str): Connection option name. |
| 603 | + callback (str or callable): The fully qualified path to the callable or |
| 604 | + a callable. |
| 605 | + num_args (int): Number of positional arguments allowed. |
| 606 | +
|
| 607 | + Raises: |
| 608 | + ProgrammingError: If `callback` is not valid or wrong number of positional |
| 609 | + arguments. |
| 610 | +
|
| 611 | + .. versionadded:: 8.2.0 |
| 612 | + """ |
| 613 | + if isinstance(callback, str): |
| 614 | + try: |
| 615 | + callback = import_object(callback) |
| 616 | + except ValueError as err: |
| 617 | + raise ProgrammingError(f"{err}") from err |
| 618 | + |
| 619 | + if not callable(callback): |
| 620 | + raise ProgrammingError(f"Expected a callable for '{option_name}'") |
| 621 | + |
| 622 | + # Check if the callable signature has <num_args> positional arguments |
| 623 | + num_params = len(signature(callback).parameters) |
| 624 | + if num_params != num_args: |
| 625 | + raise ProgrammingError( |
| 626 | + f"'{option_name}' requires {num_args} positional argument, but the " |
| 627 | + f"callback provided has {num_params}" |
| 628 | + ) |
| 629 | + |
591 | 630 | @property
|
592 | 631 | @abstractmethod
|
593 | 632 | def connection_id(self) -> Optional[int]:
|
|
0 commit comments