# Reconstructed from a whitespace-collapsed git patch to
# tests/colocated_python_test.py (commit 213a99074f48, PiperOrigin-RevId
# 725683017: "In progress. Adds support for string processing in Colocated
# Python."). The same patch adds the module-level imports `base64`, `logging`,
# `struct`, `typing.assert_never` and `jax._src.lib.xla_extension_version`,
# and begins a second test, `testBinaryDataProcessing`, that is cut off in the
# visible text.
def testStringProcessing(self):
  """Upper-cases a sharded string array inside a colocated Python function.

  A 2x2 NumPy `StringDType` array is placed on the first two colocated CPU
  devices (arranged as a (2, 1) mesh and partitioned with
  `PartitionSpec("x")`), upper-cased shard by shard inside a
  `colocated_python` function, fetched back, and compared with the expected
  upper-case array.

  The test is skipped when `xla_extension_version` is below 315 or when fewer
  than two colocated CPU devices are available.
  """
  if xla_extension_version < 315:
    self.skipTest(
        "String support for colocated Python requires xla_extension_version"
        " >= 315"
    )
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  if len(cpu_devices) < 2:
    self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

  @colocated_python.colocated_python
  def f(x):
    # The dtype and the vectorized helper are created inside `f` (as in the
    # original) so the function does not close over objects from the test.
    upper_caser = np.vectorize(
        lambda s: s.upper(), otypes=[np.dtypes.StringDType()]
    )
    # Upper-case every addressable shard and put the result back on that
    # shard's own device.
    out_arrays = [
        jax.device_put(
            upper_caser(jax.device_get(shard.data)), device=shard.device
        )
        for shard in x.addressable_shards
    ]
    # Reassemble a single array with the input's shape and sharding.
    return jax.make_array_from_single_device_arrays(
        sharding=x.sharding, shape=x.shape, arrays=out_arrays
    )

  numpy_string_array = np.array(
      [["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType()  # type: ignore
  )
  mesh = jax.sharding.Mesh(
      np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y")
  )
  sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
  x = jax.device_put(numpy_string_array, device=sharding)

  out = jax.device_get(f(x))
  np.testing.assert_equal(
      out,
      np.array(
          [["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
      ),
  )
self.skipTest("Need at least one CPU devices") + + @colocated_python.colocated_python + def f(x): + logging.error( + "2DO x: %s, dtype: %s, sharding: %s", x, x.dtype, x.sharding + ) + out_arrays = [] + for shard in x.addressable_shards: + logging.error("2DO shard: %s, dtype: %s", shard.data, shard.data.dtype) + np_array = jax.device_get(shard.data) + logging.error("2DO np_array: %s, type: %s", np_array, np_array.dtype) + input_ints = struct.unpack("