Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for string and binary data processing in Colocated Python. #26530

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import struct
import tempfile
import threading
import time
Expand All @@ -23,6 +25,7 @@
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import colocated_python
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
Expand All @@ -37,6 +40,7 @@
except (ModuleNotFoundError, ImportError):
raise unittest.SkipTest("tests depend on cloudpickle library")


def _colocated_cpu_devices(
devices: Sequence[jax.Device],
) -> Sequence[jax.Device]:
Expand Down Expand Up @@ -378,6 +382,99 @@ def get_global_state(x: jax.Array) -> jax.Array:
if "_testing_global_state" in colocated_python.__dict__:
del colocated_python._testing_global_state

def testStringProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 2:
self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}")

@colocated_python.colocated_python
def f(x):
out_arrays = []
upper_caser = np.vectorize(
lambda x: x.upper(), otypes=[np.dtypes.StringDType()]
)
for shard in x.addressable_shards:
np_array = jax.device_get(shard.data)
out_np_array = upper_caser(np_array)
out_arrays.append(jax.device_put(out_np_array, device=shard.device))
return jax.make_array_from_single_device_arrays(
sharding=x.sharding, shape=x.shape, arrays=out_arrays
)

# Make a string array.
numpy_string_array = np.array(
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
)
mesh = jax.sharding.Mesh(
np.array(cpu_devices[:2]).reshape((2, 1)), ("x", "y")
)
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
x = jax.device_put(numpy_string_array, device=sharding)

# Run the colocated Python function with the string array as input.
out = f(x)
out = jax.device_get(out)

# Should have gotten the strings with all upper case letters.
np.testing.assert_equal(
out,
np.array(
[["ABCD", "EFGH"], ["IJKL", "MNOP"]], dtype=np.dtypes.StringDType()
),
)

def testBinaryDataProcessing(self):
if xla_extension_version < 315:
self.skipTest(
"String support for colocated Python requires xla_extension_version"
" >= 315"
)
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if len(cpu_devices) < 1:
self.skipTest("Need at least one CPU devices")

@colocated_python.colocated_python
def f(x):
out_arrays = []
for shard in x.addressable_shards:
np_array = jax.device_get(shard.data)
input_ints = struct.unpack(
"<ii", base64.b64decode(np_array[0].encode("ascii"))
)
output_string = base64.b64encode(
struct.pack("<ii", input_ints[0] + 1, input_ints[1] + 1)
).decode("ascii")
out_np_array = np.array([output_string], dtype=np.dtypes.StringDType())
out_arrays.append(jax.device_put(out_np_array, device=shard.device))

out = jax.make_array_from_single_device_arrays(
sharding=x.sharding, shape=x.shape, arrays=out_arrays
)
return out

# Make the input array with the binary data that packs two integers as ascii
# string.
input_string = base64.b64encode(struct.pack("<ii", 1001, 1002)).decode(
"ascii"
)
numpy_string_array = np.array([input_string], dtype=np.dtypes.StringDType())
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
x = jax.device_put(numpy_string_array, device=sharding)

out = f(x)
out = jax.device_get(out)

# Should have gotten the binary data with the incremented integers as a
# ascii string.
out_ints = struct.unpack("<ii", base64.b64decode(out[0].encode("ascii")))
self.assertEqual(out_ints[0], 1002)
self.assertEqual(out_ints[1], 1003)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Loading