Skip to content

Commit

Permalink
[Data] PythonObjectArray missing methods causing serialization failur…
Browse files Browse the repository at this point in the history
…es + tests (#49202)

## Why are these changes needed?

Discussed with @richardliaw and decide to open this PR based on
#48737 with tests

## Related issue number

N/A

## Checks

- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: akavalar <[email protected]>
Signed-off-by: Xingyu Long <[email protected]>
Co-authored-by: akavalar <[email protected]>
  • Loading branch information
xingyu-long and akavalar authored Dec 13, 2024
1 parent bc41605 commit 3073fe7
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
38 changes: 38 additions & 0 deletions python/ray/air/tests/test_object_extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import types

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

Expand Down Expand Up @@ -60,6 +61,43 @@ def test_arrow_pandas_roundtrip():
assert t1.equals(t2)


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_isna():
arr = np.array([1, np.nan, 3, 4, 5, np.nan, 7, 8, 9], dtype=object)
ta = PythonObjectArray(arr)
np.testing.assert_array_equal(ta.isna(), pd.isna(arr))


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_take():
arr = np.array([1, 2, 3, 4, 5], dtype=object)
ta = PythonObjectArray(arr)
indices = [1, 2, 3]
np.testing.assert_array_equal(ta.take(indices).to_numpy(), arr[indices])
indices = [1, 2, -1]
np.testing.assert_array_equal(
ta.take(indices, allow_fill=True, fill_value=100).to_numpy(),
np.array([2, 3, 100]),
)


@pytest.mark.skipif(
not _object_extension_type_allowed(), reason="Object extension not supported."
)
def test_pandas_python_object_concat():
arr1 = np.array([1, 2, 3, 4, 5], dtype=object)
arr2 = np.array([6, 7, 8, 9, 10], dtype=object)
ta1 = PythonObjectArray(arr1)
ta2 = PythonObjectArray(arr2)
concat_arr = PythonObjectArray._concat_same_type([ta1, ta2])
assert len(concat_arr) == arr1.shape[0] + arr2.shape[0]
np.testing.assert_array_equal(concat_arr.to_numpy(), np.concatenate([arr1, arr2]))


if __name__ == "__main__":
import sys

Expand Down
30 changes: 29 additions & 1 deletion python/ray/air/util/object_extensions/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import pyarrow as pa
from pandas._libs import lib
from pandas._typing import ArrayLike, Dtype, PositionalIndexer, npt
from pandas._typing import ArrayLike, Dtype, PositionalIndexer, TakeIndexer, npt

import ray.air.util.object_extensions.arrow
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -80,6 +80,34 @@ def __arrow_array__(self, type=None):
self.values
)

def isna(self) -> np.ndarray:
return pd.isnull(self.values)

def take(
self,
indices: TakeIndexer,
*,
allow_fill: bool = False,
fill_value: typing.Any = None,
) -> "PythonObjectArray":
if allow_fill and fill_value is None:
fill_value = self.dtype.na_value

result = pd.core.algorithms.take(
self.values, indices, allow_fill=allow_fill, fill_value=fill_value
)
return self._from_sequence(result, dtype=self.dtype)

def copy(self) -> "PythonObjectArray":
return PythonObjectArray(self.values)

@classmethod
def _concat_same_type(
cls, to_concat: collections.abc.Sequence["PythonObjectArray"]
) -> "PythonObjectArray":
values_to_concat = [element.values for element in to_concat]
return cls(np.concatenate(values_to_concat))


@PublicAPI(stability="alpha")
@pd.api.extensions.register_extension_dtype
Expand Down

0 comments on commit 3073fe7

Please sign in to comment.