Skip to content

Commit 09278ed

Browse files
authored
BUG: retain NAs in ufunc on ArrowEA (pandas-dev#62908)
1 parent 36f5e25 commit 09278ed

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

pandas/core/arrays/arrow/array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,14 @@ def __arrow_array__(self, type=None):
829829
"""Convert myself to a pyarrow ChunkedArray."""
830830
return self._pa_array
831831

832+
def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
833+
# Need to wrap np.array results GH#62800
834+
result = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
835+
if type(self) is ArrowExtensionArray:
836+
# Exclude ArrowStringArray
837+
return type(self)._from_sequence(result)
838+
return result
839+
832840
def __array__(
833841
self, dtype: NpDtype | None = None, copy: bool | None = None
834842
) -> np.ndarray:

pandas/tests/extension/test_arrow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3800,3 +3800,13 @@ def test_cast_pontwise_result_decimal_nan():
38003800

38013801
pa_type = result.dtype.pyarrow_dtype
38023802
assert pa.types.is_decimal(pa_type)
3803+
3804+
3805+
def test_ufunc_retains_missing():
3806+
# GH#62800
3807+
ser = pd.Series([0.1, pd.NA], dtype="float64[pyarrow]")
3808+
3809+
result = np.sin(ser)
3810+
3811+
expected = pd.Series([np.sin(0.1), pd.NA], dtype="float64[pyarrow]")
3812+
tm.assert_series_equal(result, expected)

pandas/tests/series/test_npfuncs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_log_arrow_backed_missing_value(using_nan_is_na):
4343
ser = Series([1, 2, None], dtype="float64[pyarrow]")
4444
if using_nan_is_na:
4545
result = np.log(ser)
46-
expected = np.log(Series([1, 2, None], dtype="float64"))
46+
expected = np.log(Series([1, 2, None], dtype="float64[pyarrow]"))
4747
tm.assert_series_equal(result, expected)
4848
else:
4949
# we get cast to object which raises

0 commit comments

Comments
 (0)