Skip to content

Commit

Permalink
Merge pull request #26523 from andportnoy:aportnoy/mosaic-gpu-dialect…
Browse files Browse the repository at this point in the history
…-hasattr-import

PiperOrigin-RevId: 726900934
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
2 parents 12d533f + ea5eb49 commit 4df5961
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def _fragmented_array_from_ir(

# TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
# Jaxlib doesn't contain the latest Mosaic GPU dialect bindings.
WaitOp = mgpu.WaitOp if jax.version._version == jax.lib.__version__ else None
ArriveExpectTxOp = mgpu.ArriveExpectTxOp if jax.version._version == jax.lib.__version__ else None
WaitOp = getattr(mgpu, "WaitOp", None)
ArriveExpectTxOp = getattr(mgpu, "ArriveExpectTxOp", None)

def _register_lowering(
op: str | Type[ir.OpView] | None
Expand Down

0 comments on commit 4df5961

Please sign in to comment.