Add CPU backend
stephen-huan committed Jan 5, 2025
1 parent 95404f9 commit ed25ada
Showing 2 changed files with 172 additions and 36 deletions.
14 changes: 0 additions & 14 deletions jax_triton/__init__.py
@@ -24,7 +24,6 @@
"__version_info__",
]

from jax._src.lib import gpu_triton
from jax_triton import utils
from jax_triton.triton_lib import triton_call
from jax.experimental.pallas import cdiv
@@ -33,17 +32,4 @@
from jax_triton.version import __version__
from jax_triton.version import __version_info__

try:
get_compute_capability = gpu_triton.get_compute_capability
get_serialized_metadata = gpu_triton.get_serialized_metadata
except AttributeError:
raise ImportError(
"jax-triton requires JAX to be installed with GPU support. The "
"installation page on the JAX documentation website includes "
"instructions for installing a supported version:\n"
"https://jax.readthedocs.io/en/latest/installation.html"
)
else:
del gpu_triton # Not part of the API.

# trailer
194 changes: 172 additions & 22 deletions jax_triton/triton_lib.py
@@ -62,7 +62,12 @@
import triton.backends.amd.compiler as hb
except ImportError:
hb = None
pass

try:
  import triton.backends.cpu.compiler as cpub
except ImportError:
  cpub = None


os.environ["TRITON_CACHE_DIR"] = ""
@@ -170,6 +175,13 @@ def get_hip_backend(device, compute_capability):
backend = hb.HIPBackend(target)
return backend

def get_cpu_backend(device, compute_capability):
  arch = _triton.llvm.get_cpu_tripple()
  arch = arch.split("-")[0]
  target = cpub.GPUTarget('cpu', arch, 0)
  backend = cpub.CPUBackend(target)
  return backend

@dataclasses.dataclass
class CompilationResult:
binary: str
@@ -181,8 +193,8 @@ class CompilationResult:

def compile_ttir_inplace(
ttir,
backend: [cb.CUDABackend | hb.HIPBackend],
options: [cb.CUDAOptions | hb.HIPOptions],
backend: cb.CUDABackend | hb.HIPBackend | cpub.CPUBackend,
options: cb.CUDAOptions | hb.HIPOptions | cpub.CPUOptions,
compute_capability,
platform
):
@@ -201,6 +213,13 @@ def compile_ttir_inplace(
options,
compute_capability,
)
  elif platform == 'cpu':
    return compile_ttir_to_asm_inplace(
        ttir,
        backend,
        options,
        compute_capability,
    )
else:
raise ValueError(
"Unsupported device."
@@ -322,6 +341,70 @@ def compile_ttir_to_hsaco_inplace(
llir=llir,
)

def compile_ttir_to_asm_inplace(
    ttir,
    cpu_backend: cpub.CPUBackend,
    cpu_options: cpub.CPUOptions,
    compute_capability,
) -> CompilationResult:
  if cpu_options.debug:
    print(ttir)
  try:
    metadata = {}
    opt_ttir = cpu_backend.make_ttir(ttir, metadata, cpu_options)
    ttcir = cpu_backend.make_ttcir(
        opt_ttir,
        metadata,
        cpu_options
    )
  except RuntimeError as e:
    ttir.dump()
    raise ValueError("TTIR->TTCIR pass failed!") from e
  if cpu_options.debug:
    print(ttcir)
  try:
    tttcir = cpu_backend.make_tttcir(
        ttcir,
        metadata,
        cpu_options
    )
  except RuntimeError as e:
    ttcir.dump()
    raise ValueError("TTCIR->TTTCIR pass failed!") from e
  if cpu_options.debug:
    print(tttcir)
  try:
    llir = cpu_backend.make_llir(
        tttcir,
        metadata,
        cpu_options
    )
  except RuntimeError as e:
    tttcir.dump()
    raise ValueError("TTTCIR->LLIR pass failed!") from e
  shared_mem_bytes = metadata["shared"]
  if cpu_options.debug:
    print(llir)
  asm = cpu_backend.make_asm(
      llir,
      metadata,
      cpu_options
  )
  if cpu_options.debug:
    print(asm)
  name = metadata["name"]
  cluster_dims = metadata["cluster_dims"]
  tttcir = str(tttcir) if _JAX_TRITON_DUMP_DIR else None
  llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
  return CompilationResult(
      binary=asm,
      name=name,
      shared_mem_bytes=shared_mem_bytes,
      cluster_dims=cluster_dims,
      ttgir=tttcir,
      llir=llir,
  )

_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


@@ -690,6 +773,12 @@ def prune_configs(configs, named_args, **kwargs):
platform="rocm",
)

mlir.register_lowering(
    triton_kernel_call_p,
    functools.partial(triton_kernel_call_lowering, get_cpu_backend),
    platform="cpu",
)

class ShapeDtype(Protocol):

@property
@@ -827,23 +916,84 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
if input_output_aliases is None:
input_output_aliases = {}

out_flat = triton_kernel_call_p.bind(
*array_args,
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
custom_call_target_name=custom_call_target_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
compute_capability=compute_capability,
enable_fp_fusion=enable_fp_fusion,
input_output_aliases=tuple(input_output_aliases.items()),
zeroed_outputs=zeroed_outputs,
debug=debug,
serialized_metadata=serialized_metadata,
**metaparams,
)
  if triton.runtime.driver.active.get_current_target().backend != "cpu":
    out_flat = triton_kernel_call_p.bind(
        *array_args,
        fn=kernel,
        scalar_args=tuple(scalar_args),
        name=name,
        custom_call_target_name=custom_call_target_name,
        out_shapes=tuple(flat_out_shapes),
        grid=grid,
        num_warps=num_warps,
        num_stages=num_stages,
        num_ctas=num_ctas,
        compute_capability=compute_capability,
        enable_fp_fusion=enable_fp_fusion,
        input_output_aliases=tuple(input_output_aliases.items()),
        zeroed_outputs=zeroed_outputs,
        debug=debug,
        serialized_metadata=serialized_metadata,
        **metaparams,
    )
  else:
    if isinstance(kernel, autotuner.Autotuner):
      for config in kernel.configs:
        if config.pre_hook is not None:
          raise NotImplementedError("`pre_hook` is not supported")

    class Pointer:

      def __init__(self, x):
        self.x = x
        self.dtype = x.dtype

      def data_ptr(self):
        return self.x.unsafe_buffer_pointer()

    def to_triton_arg(arg):
      if arg.ndim == 0:
        dtypes = {
            jnp.bool.dtype: bool,
            jnp.int32.dtype: int,
            jnp.int64.dtype: int,
            jnp.float32.dtype: float,
            jnp.float64.dtype: float,
        }
        if arg.dtype not in dtypes:
          raise ValueError(f"Invalid argument {arg} with type {arg.dtype}.")
        return dtypes[arg.dtype](arg)
      else:
        return Pointer(arg)

    def callback(flat_args, outputs):
      kernel[lambda meta: normalize_grid(grid, metaparams | meta)](
          *map(to_triton_arg, flat_args),
          *map(Pointer, outputs),
          **metaparams,
      )
      return outputs

    # FIXME(stephen-huan): doesn't take into account kernel's meta
    config_zeroed_outputs = zeroed_outputs
    if callable(zeroed_outputs):
      config_zeroed_outputs = config_zeroed_outputs(metaparams)

    output_input_aliases = {}
    for input_idx, output_idx in input_output_aliases.items():
      if output_idx in output_input_aliases:
        # TODO(stephen-huan): not sure how to handle this properly
        raise NotImplementedError(
            "Multiple inputs aliased to the same output is not supported."
        )
      output_input_aliases[output_idx] = flat_args[input_idx]
      if output_idx in config_zeroed_outputs:
        flat_args[input_idx] = flat_args[input_idx].at[:].set(0)

    out_shapes = tuple(flat_out_shapes)
    outputs = [
        output_input_aliases.get(i, jnp.zeros(shape.shape, shape.dtype))
        for i, shape in enumerate(out_shapes)
    ]
    out_flat = jax.pure_callback(callback, out_shapes, flat_args, outputs)
  return tree_util.tree_unflatten(out_tree, out_flat)
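
For reference, a minimal usage sketch of `jax_triton.triton_call` after this change, modeled on the `add` example referenced in the `def add(...)` hunk header above. It assumes an environment where triton-cpu is installed so that `triton.runtime.driver.active.get_current_target().backend == "cpu"`, in which case the call should take the new `jax.pure_callback` path rather than the GPU custom-call lowering; the kernel, block size, and shapes are illustrative, not part of this commit.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, block_size: tl.constexpr):
  # Each program instance handles one contiguous block of `block_size` elements.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(output_ptr + offsets, x + y)

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8  # illustrative; assumed to divide x.size evenly
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size,
  )

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, 16, dtype=jnp.float32)
print(add(x, y))  # expected: [ 8. 10. 12. 14. 16. 18. 20. 22.]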
