From 026021356df7286b3933a34c746c32e924581442 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Sat, 21 Dec 2024 03:30:43 -0800 Subject: [PATCH] Fix tests with CPU backend --- tests/triton_call_test.py | 6 ++++++ tests/triton_test.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index fba77c97..b1a0ba4b 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -29,6 +29,12 @@ config.parse_flags_with_absl() +try: + jt.get_compute_capability(0) +except AttributeError: + # TODO(stephen-huan): add in jaxlib + jt.get_compute_capability = lambda _: np.inf + def setUpModule(): config.update("jax_enable_x64", True) diff --git a/tests/triton_test.py b/tests/triton_test.py index 89c7ff82..dc3be2a5 100644 --- a/tests/triton_test.py +++ b/tests/triton_test.py @@ -20,7 +20,10 @@ import numpy as np import triton import triton.language as tl -from triton.language.extra.cuda import libdevice +try: + from triton.language.extra.cuda import libdevice +except ImportError: + from triton.language.extra.cpu import libdevice @triton.jit