diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 8e1d676a78df..aa03261ea258 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +import struct import tempfile import threading import time @@ -23,6 +25,7 @@ import jax from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version from jax.experimental import colocated_python from jax.experimental.colocated_python import serialization from jax.extend.ifrt_programs import ifrt_programs @@ -37,6 +40,7 @@ except (ModuleNotFoundError, ImportError): raise unittest.SkipTest("tests depend on cloudpickle library") + def _colocated_cpu_devices( devices: Sequence[jax.Device], ) -> Sequence[jax.Device]: @@ -378,6 +382,99 @@ def get_global_state(x: jax.Array) -> jax.Array: if "_testing_global_state" in colocated_python.__dict__: del colocated_python._testing_global_state + def testStringProcessing(self): + if xla_extension_version < 315: + self.skipTest( + "String support for colocated Python requires xla_extension_version" + " >= 315" + ) + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + if len(cpu_devices) < 2: + self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") + + @colocated_python.colocated_python + def f(x): + out_arrays = [] + upper_caser = np.vectorize( + lambda x: x.upper(), otypes=[np.dtypes.StringDType()] + ) + for shard in x.addressable_shards: + np_array = jax.device_get(shard.data) + out_np_array = upper_caser(np_array) + out_arrays.append(jax.device_put(out_np_array, device=shard.device)) + return jax.make_array_from_single_device_arrays( + sharding=x.sharding, shape=x.shape, arrays=out_arrays + ) + + # Make a string array. + numpy_string_array = np.array( + [["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore + ) + mesh = jax.sharding.Mesh( + np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y") + ) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x")) + x = jax.device_put(numpy_string_array, device=sharding) + + # Run the colocated Python function with the string array as input. + out = f(x) + out = jax.device_get(out) + + # Should have gotten the strings with all upper case letters. + np.testing.assert_equal( + out, + np.array( + [["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType() + ), + ) + + def testBinaryDataProcessing(self): + if xla_extension_version < 315: + self.skipTest( + "String support for colocated Python requires xla_extension_version" + " >= 315" + ) + cpu_devices = _colocated_cpu_devices(jax.local_devices()) + if len(cpu_devices) < 1: + self.skipTest("Need at least one CPU devices") + + @colocated_python.colocated_python + def f(x): + out_arrays = [] + for shard in x.addressable_shards: + np_array = jax.device_get(shard.data) + input_ints = struct.unpack( + "