Skip to content

Commit

Permalink
Merge pull request #26472 from jakevdp:jnp-einsum
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726580373
  • Loading branch information
Google-ML-Automation committed Feb 13, 2025
2 parents 229aa65 + 7ab7b21 commit 5ebb7eb
Show file tree
Hide file tree
Showing 4 changed files with 586 additions and 551 deletions.
4 changes: 2 additions & 2 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from jax._src import effects
from jax._src.lax import lax
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src.numpy import einsum as jnp_einsum
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
Expand Down Expand Up @@ -1267,7 +1267,7 @@ def fake_dim(d):
contract_operands.append(operands[idx[0]])
return contract_operands, contractions

lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
jnp_einsum._poly_einsum_handlers[_DimExpr] = _einsum_contract_path

# To implement shape-constraint checking we use a shape assertion primitive.
# shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
Expand Down
Loading

0 comments on commit 5ebb7eb

Please sign in to comment.