Skip to content

Commit

Permalink
In progress. Adds support for string processing in Colocated Python.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725683017
  • Loading branch information
Google-ML-Automation committed Feb 13, 2025
1 parent 60dcded commit 213a990
Showing 1 changed file with 133 additions and 1 deletion.
134 changes: 133 additions & 1 deletion tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import logging
import struct
import tempfile
import threading
import time
from typing import Sequence
from typing import Sequence, assert_never
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
Expand All @@ -37,6 +41,7 @@
except (ModuleNotFoundError, ImportError):
raise unittest.SkipTest("tests depend on cloudpickle library")


def _colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
Expand Down Expand Up @@ -378,6 +383,133 @@ def get_global_state(x: jax.Array) -> jax.Array:
if "_testing_global_state" in colocated_python.__dict__:
del colocated_python._testing_global_state

def testStringProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 2:
self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

@colocated_python.colocated_python
def f(x):
logging.error(
"2DO x: %s, dtype: %s, sharding: %s", x, x.dtype, x.sharding
)

out_arrays = []
string_processor = np.vectorize(
lambda x: x.upper(), otypes=[np.dtypes.StringDType()]
)
for shard in x.addressable_shards:
logging.error("2DO shard: %s, dtype: %s", shard.data, shard.data.dtype)
np_array = jax.device_get(shard.data)
logging.error("2DO np_array: %s, type: %s", np_array, np_array.dtype)
out_np_array = string_processor(np_array)
logging.error(
"2DO out_np_array: %s, type: %s", out_np_array, out_np_array.dtype
)
out_jax_array = jax.device_put(out_np_array, device=shard.device)
logging.error(
"2DO out_jax_array: %s, type: %s",
out_jax_array,
out_jax_array.dtype,
)
out_arrays.append(out_jax_array)

out = jax.make_array_from_single_device_arrays(
sharding=x.sharding, shape=x.shape, arrays=out_arrays
)
return out

numpy_string_array = np.array(
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
)
mesh = jax.sharding.Mesh(
np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y")
)
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
x = jax.device_put(numpy_string_array, device=sharding)
logging.error("2DO input_array: %s, dtype: %s", x, x.dtype)
logging.error(
"2DO input_array.shard0: %s, dtype: %s",
x.addressable_shards[0].data,
x.addressable_shards[0].data.dtype,
)
logging.error(
"2DO input_array.shard1: %s, dtype: %s",
x.addressable_shards[1].data,
x.addressable_shards[1].data.dtype,
)

out = f(x)
out = jax.device_get(out)
logging.info("2DO out: %s", out)
np.testing.assert_equal(
out,
np.array(
[["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
),
)

def testBinaryDataProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 1:
self.skipTest("Need at least one CPU devices")

@colocated_python.colocated_python
def f(x):
logging.error(
"2DO x: %s, dtype: %s, sharding: %s", x, x.dtype, x.sharding
)
out_arrays = []
for shard in x.addressable_shards:
logging.error("2DO shard: %s, dtype: %s", shard.data, shard.data.dtype)
np_array = jax.device_get(shard.data)
logging.error("2DO np_array: %s, type: %s", np_array, np_array.dtype)
input_ints = struct.unpack("<ii", base64.b64decode(np_array[0].encode("utf-8")))
output_string = base64.b64encode(struct.pack(
"<ii", input_ints[0] + 1, input_ints[1] + 1
)).decode("utf-8")
out_np_array = np.array([output_string], dtype=np.dtypes.StringDType())
logging.error(
"2DO out_np_array: %s, type: %s", out_np_array, out_np_array.dtype
)
out_jax_array = jax.device_put(out_np_array, device=shard.device)
logging.error(
"2DO out_jax_array: %s, type: %s",
out_jax_array,
out_jax_array.dtype,
)
out_arrays.append(out_jax_array)

out = jax.make_array_from_single_device_arrays(
sharding=x.sharding, shape=x.shape, arrays=out_arrays
)
return out

input_string = base64.b64encode(struct.pack("<ii", 1001, 1002)).decode(
"utf-8"
)
numpy_string_array = np.array([input_string], dtype=np.dtypes.StringDType())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
x = jax.device_put(numpy_string_array, device=sharding)
logging.error("2DO input_array: %s, dtype: %s", x, x.dtype)

out = f(x)
out = jax.device_get(out)
out_ints = struct.unpack("<ii", base64.b64decode(out[0].encode("utf-8")))
logging.info("2DO out ints: %s", out_ints)
self.assertEqual(out_ints[0], 1002)
self.assertEqual(out_ints[1], 1003)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 213a990

Please sign in to comment.