|
15 | 15 |
|
16 | 16 | if TYPE_CHECKING: |
17 | 17 | from pandas._typing import ( |
18 | | - ArrayLike, |
19 | 18 | npt, |
20 | 19 | ) |
21 | 20 |
|
| 21 | + from pandas.core.arrays.base import ExtensionArray |
| 22 | + |
22 | 23 |
|
23 | 24 | def to_numpy_dtype_inference( |
24 | | - arr: ArrayLike, dtype: npt.DTypeLike | None, na_value, hasna: bool |
25 | | -) -> tuple[npt.DTypeLike, Any]: |
| 25 | + arr: ExtensionArray, dtype: npt.DTypeLike | None, na_value, hasna: bool |
| 26 | +) -> tuple[np.dtype | None, Any]: |
| 27 | + result_dtype: np.dtype | None |
| 28 | + inferred_numeric_dtype = False |
26 | 29 | if dtype is None and is_numeric_dtype(arr.dtype): |
27 | | - dtype_given = False |
| 30 | + inferred_numeric_dtype = True |
28 | 31 | if hasna: |
29 | 32 | if arr.dtype.kind == "b": |
30 | | - dtype = np.dtype(np.object_) |
| 33 | + result_dtype = np.dtype(np.object_) |
31 | 34 | else: |
32 | 35 | if arr.dtype.kind in "iu": |
33 | | - dtype = np.dtype(np.float64) |
| 36 | + result_dtype = np.dtype(np.float64) |
34 | 37 | else: |
35 | | - dtype = arr.dtype.numpy_dtype # type: ignore[union-attr] |
| 38 | + result_dtype = arr.dtype.numpy_dtype # type: ignore[attr-defined] |
36 | 39 | if na_value is lib.no_default: |
37 | 40 | na_value = np.nan |
38 | 41 | else: |
39 | | - dtype = arr.dtype.numpy_dtype # type: ignore[union-attr] |
| 42 | + result_dtype = arr.dtype.numpy_dtype # type: ignore[attr-defined] |
40 | 43 | elif dtype is not None: |
41 | | - dtype = np.dtype(dtype) |
42 | | - dtype_given = True |
| 44 | + result_dtype = np.dtype(dtype) |
43 | 45 | else: |
44 | | - dtype_given = True |
| 46 | + result_dtype = None |
45 | 47 |
|
46 | 48 | if na_value is lib.no_default: |
47 | | - if dtype is None or not hasna: |
| 49 | + if result_dtype is None or not hasna: |
48 | 50 | na_value = arr.dtype.na_value |
49 | | - elif dtype.kind == "f": # type: ignore[union-attr] |
| 51 | + elif result_dtype.kind == "f": |
50 | 52 | na_value = np.nan |
51 | | - elif dtype.kind == "M": # type: ignore[union-attr] |
| 53 | + elif result_dtype.kind == "M": |
52 | 54 | na_value = np.datetime64("nat") |
53 | | - elif dtype.kind == "m": # type: ignore[union-attr] |
| 55 | + elif result_dtype.kind == "m": |
54 | 56 | na_value = np.timedelta64("nat") |
55 | 57 | else: |
56 | 58 | na_value = arr.dtype.na_value |
57 | 59 |
|
58 | | - if not dtype_given and hasna: |
| 60 | + if inferred_numeric_dtype and hasna: |
59 | 61 | try: |
60 | | - np_can_hold_element(dtype, na_value) # type: ignore[arg-type] |
| 62 | + np_can_hold_element(result_dtype, na_value) # type: ignore[arg-type] |
61 | 63 | except LossySetitemError: |
62 | | - dtype = np.dtype(np.object_) |
63 | | - return dtype, na_value |
| 64 | + result_dtype = np.dtype(np.object_) |
| 65 | + return result_dtype, na_value |
0 commit comments