Skip to content

Commit

Permalink
Merge pull request #26243 from jakevdp:einsum-asarray
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722455518
  • Loading branch information
Google-ML-Automation committed Feb 3, 2025
2 parents af84143 + 4e30a08 commit 57fa372
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9744,11 +9744,12 @@ def einsum(

contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)

einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
jit_einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
return einsum(operands, contractions, precision,
preferred_element_type, _dot_general, out_sharding)
jit_einsum = jax.named_call(jit_einsum, name=spec)
operand_arrays = list(util.ensure_arraylike_tuple("einsum", operands))
return jit_einsum(operand_arrays, contractions, precision,
preferred_element_type, _dot_general, out_sharding)


# Enable other modules to override einsum_contact_path.
Expand Down Expand Up @@ -9843,7 +9844,7 @@ def _removechars(s, chars):


def _einsum(
operands: Sequence,
operands: list[jax.Array],
contractions: Sequence[tuple[tuple[int, ...], frozenset[str], str]],
precision,
preferred_element_type,
Expand All @@ -9859,7 +9860,6 @@ def _einsum(
"`out_sharding` argument of `einsum` only supports NamedSharding"
" instances. Please file a bug if this is not enough for your use case.")
dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
operands = list(map(asarray, operands))
if preferred_element_type is None:
preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True)
else:
Expand Down

0 comments on commit 57fa372

Please sign in to comment.