Skip to content

Commit

Permalink
Fix tests with CPU backend
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Jan 5, 2025
1 parent ed25ada commit 0260213
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@

config.parse_flags_with_absl()

try:
jt.get_compute_capability(0)
except AttributeError:
# TODO(stephen-huan): add in jaxlib
jt.get_compute_capability = lambda _: np.inf


def setUpModule():
config.update("jax_enable_x64", True)
Expand Down
5 changes: 4 additions & 1 deletion tests/triton_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import numpy as np
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice
try:
from triton.language.extra.cuda import libdevice
except ImportError:
from triton.language.extra.cpu import libdevice


@triton.jit
Expand Down

0 comments on commit 0260213

Please sign in to comment.