Skip to content

Commit

Permalink
Avoid call to asarray in jnp.einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 31, 2025
1 parent bf96717 commit d9f9870
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9744,11 +9744,12 @@ def einsum(

contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision,
preferred_element_type, _dot_general, out_sharding)
jit_einsum = jax.named_call(jit_einsum, name=spec)
operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
return jit_einsum(operand_arrays, contractions, precision,
preferred_element_type, _dot_general, out_sharding)


# Enable other modules to override einsum_contact_path.
Expand Down Expand Up @@ -9843,7 +9844,7 @@ def _removechars(s, chars):


def _einsum(
operands: Sequence,
operands: List[jax.Array],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
Expand All @@ -9859,7 +9860,6 @@ def _einsum(
"`out_sharding` argument of `einsum` only supports NamedSharding"
" instances. Please file a bug if this is not enough for your use case.")
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
operands = list(map(asarray, operands))
if preferred_element_type is None:
preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
else:
Expand Down

0 comments on commit d9f9870

Please sign in to comment.