Skip to content

Commit

Permalink
Adds support for string and binary data processing in Colocated Python.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726704962
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
1 parent 60dcded commit 63b78ff
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import struct
import tempfile
import threading
import time
Expand All @@ -23,6 +25,7 @@
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
Expand All @@ -37,6 +40,7 @@
except (ModuleNotFoundError, ImportError):
raise unittest.SkipTest("tests depend on cloudpickle library")


def _colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
Expand Down Expand Up @@ -378,6 +382,99 @@ def get_global_state(x: jax.Array) -> jax.Array:
if "_testing_global_state" in colocated_python.__dict__:
del colocated_python._testing_global_state

def testStringProcessing(self):
  """Upper-cases a sharded string array inside a colocated Python function."""
  if xla_extension_version < 315:
    self.skipTest(
        "String support for colocated Python requires xla_extension_version"
        " >= 315"
    )
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  # The input below is split across two devices along mesh axis "x".
  if len(cpu_devices) < 2:
    self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

  @colocated_python.colocated_python
  def f(x):
    # Everything this function needs is created locally so that it does not
    # capture any state from the enclosing test method.
    to_upper = np.vectorize(
        lambda s: s.upper(), otypes=[np.dtypes.StringDType()]
    )
    # Process each addressable shard on the host, then place the result back
    # on the device the shard came from.
    upper_shards = [
        jax.device_put(
            to_upper(jax.device_get(shard.data)), device=shard.device
        )
        for shard in x.addressable_shards
    ]
    return jax.make_array_from_single_device_arrays(
        sharding=x.sharding, shape=x.shape, arrays=upper_shards
    )

  # Build a 2x2 string array and shard it over two CPU devices.
  lower_strings = np.array(
      [["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType()  # type: ignore
  )
  device_grid = np.array(cpu_devices[:2]).reshape(2, 1)
  mesh = jax.sharding.Mesh(device_grid, ("x", "y"))
  sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
  sharded_input = jax.device_put(lower_strings, device=sharding)

  # Invoke the colocated Python function and fetch its result to the host.
  result = jax.device_get(f(sharded_input))

  # Every string should come back with all letters upper-cased.
  expected = np.array(
      [["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
  )
  np.testing.assert_equal(result, expected)

def testBinaryDataProcessing(self):
  """Round-trips base64-encoded binary data through a colocated function."""
  if xla_extension_version < 315:
    self.skipTest(
        "String support for colocated Python requires xla_extension_version"
        " >= 315"
    )
  cpu_devices = _colocated_cpu_devices(jax.local_devices())
  if len(cpu_devices) == 0:
    self.skipTest("Need at least one CPU devices")

  @colocated_python.colocated_python
  def f(x):
    incremented_shards = []
    for shard in x.addressable_shards:
      host_array = jax.device_get(shard.data)
      # Element 0 holds base64 text that wraps two packed little-endian
      # 32-bit integers.
      packed_in = base64.b64decode(host_array[0].encode("utf-8"))
      first, second = struct.unpack("<ii", packed_in)
      # Increment both integers and re-encode them the same way.
      packed_out = struct.pack("<ii", first + 1, second + 1)
      encoded_out = base64.b64encode(packed_out).decode("utf-8")
      host_result = np.array([encoded_out], dtype=np.dtypes.StringDType())
      incremented_shards.append(
          jax.device_put(host_result, device=shard.device)
      )

    return jax.make_array_from_single_device_arrays(
        sharding=x.sharding, shape=x.shape, arrays=incremented_shards
    )

  # Binary payloads travel as strings: pack two integers, base64-encode the
  # bytes, and store the resulting text in a one-element string array.
  packed_input = struct.pack("<ii", 1001, 1002)
  input_string = base64.b64encode(packed_input).decode("utf-8")
  string_array = np.array([input_string], dtype=np.dtypes.StringDType())
  sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
  device_input = jax.device_put(string_array, device=sharding)

  result = jax.device_get(f(device_input))

  # Decoding the returned string should yield both integers incremented by 1.
  first_out, second_out = struct.unpack(
      "<ii", base64.b64decode(result[0].encode("utf-8"))
  )
  self.assertEqual(first_out, 1002)
  self.assertEqual(second_out, 1003)


if __name__ == "__main__":
  # Run the tests in this module through absltest using the JAX test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)

0 comments on commit 63b78ff

Please sign in to comment.