Skip to content

Commit

Permalink
add tests for some ops of PythonObjectArray
Browse files Browse the repository at this point in the history
Signed-off-by: Xingyu Long <[email protected]>
  • Loading branch information
xingyu-long committed Dec 13, 2024
1 parent 56e19a8 commit f145fd0
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/ray/air/tests/test_object_extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import types

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

Expand Down Expand Up @@ -60,6 +61,43 @@ def test_arrow_pandas_roundtrip():
assert t1.equals(t2)


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_isna():
arr = np.array([1, np.nan, 3, 4, 5, np.nan, 7, 8, 9], dtype=object)
ta = PythonObjectArray(arr)
np.testing.assert_array_equal(ta.isna(), pd.isna(arr))


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_take():
arr = np.array([1, 2, 3, 4, 5], dtype=object)
ta = PythonObjectArray(arr)
indices = [1, 2, 3]
np.testing.assert_array_equal(ta.take(indices).to_numpy(), arr[indices])
indices = [1, 2, -1]
np.testing.assert_array_equal(
ta.take(indices, allow_fill=True, fill_value=100).to_numpy(),
np.array([2, 3, 100]),
)


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_concat():
arr1 = np.array([1, 2, 3, 4, 5], dtype=object)
arr2 = np.array([6, 7, 8, 9, 10], dtype=object)
ta1 = PythonObjectArray(arr1)
ta2 = PythonObjectArray(arr2)
concat_arr = PythonObjectArray._concat_same_type([ta1, ta2])
assert len(concat_arr) == arr1.shape[0] + arr2.shape[0]
np.testing.assert_array_equal(concat_arr.to_numpy(), np.concatenate([arr1, arr2]))


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit f145fd0

Please sign in to comment.