diff --git a/python/ray/air/tests/test_object_extension.py b/python/ray/air/tests/test_object_extension.py index 64600bafc69cd..b95f4c44a958b 100644 --- a/python/ray/air/tests/test_object_extension.py +++ b/python/ray/air/tests/test_object_extension.py @@ -1,6 +1,7 @@ import types import numpy as np +import pandas as pd import pyarrow as pa import pytest @@ -60,6 +61,43 @@ def test_arrow_pandas_roundtrip(): assert t1.equals(t2) +@pytest.mark.skipif( + not _object_extension_type_allowed(), reason="Object extension not supported." +) +def test_pandas_python_object_isna(): + arr = np.array([1, np.nan, 3, 4, 5, np.nan, 7, 8, 9], dtype=object) + ta = PythonObjectArray(arr) + np.testing.assert_array_equal(ta.isna(), pd.isna(arr)) + + +@pytest.mark.skipif( + not _object_extension_type_allowed(), reason="Object extension not supported." +) +def test_pandas_python_object_take(): + arr = np.array([1, 2, 3, 4, 5], dtype=object) + ta = PythonObjectArray(arr) + indices = [1, 2, 3] + np.testing.assert_array_equal(ta.take(indices).to_numpy(), arr[indices]) + indices = [1, 2, -1] + np.testing.assert_array_equal( + ta.take(indices, allow_fill=True, fill_value=100).to_numpy(), + np.array([2, 3, 100]), + ) + + +@pytest.mark.skipif( + not _object_extension_type_allowed(), reason="Object extension not supported." +) +def test_pandas_python_object_concat(): + arr1 = np.array([1, 2, 3, 4, 5], dtype=object) + arr2 = np.array([6, 7, 8, 9, 10], dtype=object) + ta1 = PythonObjectArray(arr1) + ta2 = PythonObjectArray(arr2) + concat_arr = PythonObjectArray._concat_same_type([ta1, ta2]) + assert len(concat_arr) == arr1.shape[0] + arr2.shape[0] + np.testing.assert_array_equal(concat_arr.to_numpy(), np.concatenate([arr1, arr2])) + + if __name__ == "__main__": import sys