Global mutable arrays do not work with AOT #26354

Open
ayaka14732 opened this issue Feb 6, 2025 · 0 comments
Labels
bug Something isn't working

@ayaka14732
Member

Description

Repro:

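```python
import jax
import jax.numpy as jnp
from jax._src.core import mutable_array  # private API in jax 0.5.0

a = jnp.uint32(0)
a_ref = mutable_array(a)  # global mutable array, closed over by f

def f(x):
    a_ref[()] = 7  # writing to a global mutable array
    return jnp.isfinite(x)

x = jnp.float32(4.)

# AOT path: lower, compile, then call the compiled executable.
lowered = jax.jit(f).lower(x)
compiled = lowered.compile()
print(compiled(x))  # raises ValueError
```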
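For contrast, a minimal sketch of the same AOT pipeline with the mutable-array write removed. The function name `g` is hypothetical, and this assumes (as the title implies) that the failure is specific to the closed-over mutable array rather than to `lower`/`compile` in general.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Same computation as f, but without touching a global mutable array.
    return jnp.isfinite(x)

x = jnp.float32(4.)

lowered = jax.jit(g).lower(x)  # trace and lower for the given input
compiled = lowered.compile()   # compile the lowered module with XLA
out = compiled(x)              # call the compiled executable directly
print(out)
```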

Error output:

```
Traceback (most recent call last):
  File "/home/ayx/dev/checkify/16.py", line 16, in <module>
    print(compiled(x))
          ~~~~~~~~^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/stages.py", line 608, in __call__
    return self._call(*args, **kwargs)
           ~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 3123, in aot_cache_miss
    outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
                                ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/stages.py", line 577, in call
    out_flat = params.executable.call(*args_flat)
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 3096, in call
    check_array_xla_sharding_layout_match(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self._kept_var_idx)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/dev/checkify/venv/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py", line 3226, in check_array_xla_sharding_layout_match
    for arg, xs, xl, name in safe_zip(
                             ~~~~~~~~^
        args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is longer than argument 1
```

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.13.0rc3 (main, Oct  2 2024, 17:18:08) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.10.11-1rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.10.11-1rodete2 (2024-10-16)', machine='x86_64')
```