diff --git a/python/ray/air/tests/test_object_extension.py b/python/ray/air/tests/test_object_extension.py
index 64600bafc69cd..b95f4c44a958b 100644
--- a/python/ray/air/tests/test_object_extension.py
+++ b/python/ray/air/tests/test_object_extension.py
@@ -1,6 +1,7 @@
 import types
 
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 import pytest
 
@@ -60,6 +61,55 @@ def test_arrow_pandas_roundtrip():
     assert t1.equals(t2)
 
 
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_isna():
+    arr = np.array([1, np.nan, 3, 4, 5, np.nan, 7, 8, 9], dtype=object)
+    ta = PythonObjectArray(arr)
+    np.testing.assert_array_equal(ta.isna(), pd.isna(arr))
+
+
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_take():
+    arr = np.array([1, 2, 3, 4, 5], dtype=object)
+    ta = PythonObjectArray(arr)
+    indices = [1, 2, 3]
+    np.testing.assert_array_equal(ta.take(indices).to_numpy(), arr[indices])
+    indices = [1, 2, -1]
+    np.testing.assert_array_equal(
+        ta.take(indices, allow_fill=True, fill_value=100).to_numpy(),
+        np.array([2, 3, 100]),
+    )
+
+
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_copy():
+    arr = np.array([1, 2, 3], dtype=object)
+    ta = PythonObjectArray(arr)
+    copied = ta.copy()
+    assert copied is not ta
+    assert copied.values is not ta.values
+    np.testing.assert_array_equal(copied.to_numpy(), arr)
+
+
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_concat():
+    arr1 = np.array([1, 2, 3, 4, 5], dtype=object)
+    arr2 = np.array([6, 7, 8, 9, 10], dtype=object)
+    ta1 = PythonObjectArray(arr1)
+    ta2 = PythonObjectArray(arr2)
+    concat_arr = PythonObjectArray._concat_same_type([ta1, ta2])
+    assert len(concat_arr) == arr1.shape[0] + arr2.shape[0]
+    np.testing.assert_array_equal(concat_arr.to_numpy(), np.concatenate([arr1, arr2]))
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/air/util/object_extensions/pandas.py b/python/ray/air/util/object_extensions/pandas.py
index ccd71e5dc89a4..dbc5732f350b8 100644
--- a/python/ray/air/util/object_extensions/pandas.py
+++ b/python/ray/air/util/object_extensions/pandas.py
@@ -5,7 +5,7 @@
 import pandas as pd
 import pyarrow as pa
 from pandas._libs import lib
-from pandas._typing import ArrayLike, Dtype, PositionalIndexer, npt
+from pandas._typing import ArrayLike, Dtype, PositionalIndexer, TakeIndexer, npt
 
 import ray.air.util.object_extensions.arrow
 from ray.util.annotations import PublicAPI
@@ -80,6 +80,51 @@ def __arrow_array__(self, type=None):
             self.values
         )
 
+    def isna(self) -> np.ndarray:
+        """Return a boolean mask marking the missing entries of ``self.values``."""
+        return pd.isna(self.values)
+
+    def take(
+        self,
+        indices: TakeIndexer,
+        *,
+        allow_fill: bool = False,
+        fill_value: typing.Any = None,
+    ) -> "PythonObjectArray":
+        """Return a new array holding the elements at ``indices``.
+
+        Args:
+            indices: Positions of the elements to select.
+            allow_fill: If True, ``-1`` entries in ``indices`` are filled with
+                ``fill_value`` instead of being looked up.
+            fill_value: Value used for filled positions. ``None`` falls back to
+                ``self.dtype.na_value`` when ``allow_fill`` is True.
+
+        Returns:
+            A new array of the same dtype built from the selected values.
+        """
+        if allow_fill and fill_value is None:
+            fill_value = self.dtype.na_value
+
+        # Use the public alias rather than reaching into ``pd.core``.
+        result = pd.api.extensions.take(
+            self.values, indices, allow_fill=allow_fill, fill_value=fill_value
+        )
+        return self._from_sequence(result, dtype=self.dtype)
+
+    def copy(self) -> "PythonObjectArray":
+        """Return a new array whose ``values`` container is not shared with this one."""
+        # Copy explicitly so independence does not rely on the constructor.
+        return type(self)(self.values.copy())
+
+    @classmethod
+    def _concat_same_type(
+        cls, to_concat: collections.abc.Sequence["PythonObjectArray"]
+    ) -> "PythonObjectArray":
+        """Concatenate the values of several arrays into a single new array."""
+        values_to_concat = [element.values for element in to_concat]
+        return cls(np.concatenate(values_to_concat))
+
 
 @PublicAPI(stability="alpha")
 @pd.api.extensions.register_extension_dtype