From 56e19a8b16472d70890420285352dd1ebeca01ab Mon Sep 17 00:00:00 2001
From: akavalar
Date: Wed, 13 Nov 2024 22:15:25 -0800
Subject: [PATCH 1/2] missing methods

Signed-off-by: akavalar
---
Review notes (free-form text between the "---" marker and the diff is not
part of the commit message):

* Adds four methods next to __arrow_array__ in object_extensions/pandas.py:
  isna, take, copy and _concat_same_type. The TakeIndexer import is added
  only for the annotation on take's `indices` parameter.
* isna returns pd.isnull(self.values).
* take replaces fill_value with self.dtype.na_value only when allow_fill is
  truthy AND fill_value is None, then calls pd.core.algorithms.take on
  self.values and rebuilds the result with
  self._from_sequence(result, dtype=self.dtype).
* copy returns PythonObjectArray(self.values). Whether the underlying buffer
  is duplicated depends on the PythonObjectArray constructor, which is not
  shown in this patch -- TODO confirm before relying on copy() for isolation.
* _concat_same_type collects each element's .values, joins them with
  np.concatenate and passes the result to cls(...).

 .../ray/air/util/object_extensions/pandas.py | 30 ++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/python/ray/air/util/object_extensions/pandas.py b/python/ray/air/util/object_extensions/pandas.py
index ccd71e5dc89a..dbc5732f350b 100644
--- a/python/ray/air/util/object_extensions/pandas.py
+++ b/python/ray/air/util/object_extensions/pandas.py
@@ -5,7 +5,7 @@
 import pandas as pd
 import pyarrow as pa
 from pandas._libs import lib
-from pandas._typing import ArrayLike, Dtype, PositionalIndexer, npt
+from pandas._typing import ArrayLike, Dtype, PositionalIndexer, TakeIndexer, npt
 
 import ray.air.util.object_extensions.arrow
 from ray.util.annotations import PublicAPI
@@ -80,6 +80,34 @@ def __arrow_array__(self, type=None):
             self.values
         )
 
+    def isna(self) -> np.ndarray:
+        return pd.isnull(self.values)
+
+    def take(
+        self,
+        indices: TakeIndexer,
+        *,
+        allow_fill: bool = False,
+        fill_value: typing.Any = None,
+    ) -> "PythonObjectArray":
+        if allow_fill and fill_value is None:
+            fill_value = self.dtype.na_value
+
+        result = pd.core.algorithms.take(
+            self.values, indices, allow_fill=allow_fill, fill_value=fill_value
+        )
+        return self._from_sequence(result, dtype=self.dtype)
+
+    def copy(self) -> "PythonObjectArray":
+        return PythonObjectArray(self.values)
+
+    @classmethod
+    def _concat_same_type(
+        cls, to_concat: collections.abc.Sequence["PythonObjectArray"]
+    ) -> "PythonObjectArray":
+        values_to_concat = [element.values for element in to_concat]
+        return cls(np.concatenate(values_to_concat))
+
 
 @PublicAPI(stability="alpha")
 @pd.api.extensions.register_extension_dtype

From f145fd05fd11bf7058f816f53ceaefdd9657491a Mon Sep 17 00:00:00 2001
From: Xingyu Long
Date: Tue, 10 Dec 2024 18:39:32 -0800
Subject: [PATCH 2/2] add tests for some ops of
PythonObjectArray

Signed-off-by: Xingyu Long
---
Review notes (free-form text between the "---" marker and the diff is not
part of the commit message):

* Adds three tests to test_object_extension.py; each is skipped when
  _object_extension_type_allowed() returns a falsy value.
* test_pandas_python_object_isna builds an object-dtype array containing
  np.nan entries and checks PythonObjectArray.isna() against pd.isna() on
  the same input.
* test_pandas_python_object_take checks a plain positional take, then a take
  with allow_fill=True and fill_value=100 where index -1 is expected to
  yield the fill value (expected result [2, 3, 100]).
* test_pandas_python_object_concat checks both the length and the contents
  of PythonObjectArray._concat_same_type([ta1, ta2]) against np.concatenate.
* None of the tests added here calls copy(), and take() with allow_fill=True
  but no explicit fill_value is not exercised either.

 python/ray/air/tests/test_object_extension.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/python/ray/air/tests/test_object_extension.py b/python/ray/air/tests/test_object_extension.py
index 64600bafc69c..b95f4c44a958 100644
--- a/python/ray/air/tests/test_object_extension.py
+++ b/python/ray/air/tests/test_object_extension.py
@@ -1,6 +1,7 @@
 import types
 
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 import pytest
 
@@ -60,6 +61,43 @@ def test_arrow_pandas_roundtrip():
     assert t1.equals(t2)
 
 
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_isna():
+    arr = np.array([1, np.nan, 3, 4, 5, np.nan, 7, 8, 9], dtype=object)
+    ta = PythonObjectArray(arr)
+    np.testing.assert_array_equal(ta.isna(), pd.isna(arr))
+
+
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_take():
+    arr = np.array([1, 2, 3, 4, 5], dtype=object)
+    ta = PythonObjectArray(arr)
+    indices = [1, 2, 3]
+    np.testing.assert_array_equal(ta.take(indices).to_numpy(), arr[indices])
+    indices = [1, 2, -1]
+    np.testing.assert_array_equal(
+        ta.take(indices, allow_fill=True, fill_value=100).to_numpy(),
+        np.array([2, 3, 100]),
+    )
+
+
+@pytest.mark.skipif(
+    not _object_extension_type_allowed(), reason="Object extension not supported."
+)
+def test_pandas_python_object_concat():
+    arr1 = np.array([1, 2, 3, 4, 5], dtype=object)
+    arr2 = np.array([6, 7, 8, 9, 10], dtype=object)
+    ta1 = PythonObjectArray(arr1)
+    ta2 = PythonObjectArray(arr2)
+    concat_arr = PythonObjectArray._concat_same_type([ta1, ta2])
+    assert len(concat_arr) == arr1.shape[0] + arr2.shape[0]
+    np.testing.assert_array_equal(concat_arr.to_numpy(), np.concatenate([arr1, arr2]))
+
+
 if __name__ == "__main__":
     import sys
 