Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make jnp.asarray lower to asarray_p in the simplest cases. #26244

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jan 31, 2025

This is an experiment toward a fix to #25745 and #18020.

Basically, we currently lower jnp.asarray to nothing in many cases, and this leads to strange behavior under some transformations; for example:

In [1]: import jax

In [2]: import numpy as np

In [3]: out = jax.vmap(jax.numpy.asarray)(np.arange(4))

In [4]: type(out)
Out[4]: numpy.ndarray

The fix here ensures that asarray lowers to a primitive, so that these kinds of corner cases behave as expected.

Many of the test failures here should be fixed by work related to #25931, which removes calls to jnp.asarray from JAX API implementations.

@jakevdp jakevdp self-assigned this Jan 31, 2025
@jakevdp jakevdp marked this pull request as draft January 31, 2025 19:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant