Describe the bug
(This is more of a feature request than a bug, and not a very pressing one, so feel free to ignore.)
On the thread of #204, it is possible to use triton-cpu with numpy/jax with the following `Pointer` shims (note that in the case of jax on gpu, it's possible to use jax-triton; see e.g. jax-ml/jax-triton#322 for an extension to cpu). A rough sketch of such a shim is just below.
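For reference, this is roughly the kind of shim I mean (a minimal sketch only: the `Pointer`/`wrap` names are mine, not from the #204 thread, and I'm assuming the compiled cpu launcher only ever needs `.data_ptr()` and a torch-style `.dtype` from a tensor-like argument):

```python
import numpy as np
import torch  # only used because triton's signature logic expects torch-style dtypes


class Pointer:
    """Duck-types the small subset of torch.Tensor the launcher touches."""

    def __init__(self, ptr: int, dtype: torch.dtype):
        self._ptr = ptr
        self.dtype = dtype

    def data_ptr(self) -> int:
        return self._ptr


def wrap(x) -> Pointer:
    # Only covers dtypes whose numpy name matches a torch attribute (float32, int32, ...).
    if isinstance(x, np.ndarray):
        # numpy keeps the raw host address on .ctypes.data
        return Pointer(x.ctypes.data, getattr(torch, str(x.dtype)))
    # jax arrays on cpu expose a raw host address via unsafe_buffer_pointer()
    return Pointer(x.unsafe_buffer_pointer(), getattr(torch, str(x.dtype)))
```

With that, a compiled-path launch looks like `my_kernel[(grid,)](wrap(x), wrap(out), n)` (hypothetical kernel name).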
However, when `TRITON_INTERPRET=1`, the amount of boilerplate required drastically increases; a sketch of what that boilerplate looks like is below, after the interpreter snippet. (This could probably be written more efficiently with `jax.device_put` and `jax.device_get`.) (An explicit `main` method is used to work around triton-lang#5484.) This seems to be mostly a consequence of these lines in the interpreter:
triton-cpu/python/triton/runtime/interpreter.py
Lines 1048 to 1073 in daa7eb0
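For concreteness, the boilerplate I mean is roughly the following round-trip through torch at every call site (a sketch only: `my_kernel` is a stand-in, and I'm assuming, based on the interpreter lines above, that arguments have to arrive as torch tensors; the explicit `main` is the triton-lang#5484 workaround mentioned above):

```python
import numpy as np
import jax.numpy as jnp
import torch


def main():
    x = jnp.arange(8, dtype=jnp.float32)
    out = jnp.zeros_like(x)

    # Copy jax -> numpy -> torch by hand before the launch, since the
    # interpreter path appears to expect torch tensors it can move to host.
    x_t = torch.from_numpy(np.asarray(x).copy())
    out_t = torch.from_numpy(np.asarray(out).copy())

    my_kernel[(1,)](x_t, out_t, 8, BLOCK=8)

    # ...and copy torch -> numpy -> jax afterwards to get the result back.
    return jnp.asarray(out_t.numpy())


if __name__ == "__main__":
    main()
```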
It would be nice if the interpreter could support jax/numpy without all the boilerplate, especially because the interpreter lowers to numpy on cpu anyways. It would be extra nice if passing jax/numpy arrays "just worked" like pytorch tensors.
As I primarily write jax, this is not-so-relevant for me as jax has jax-triton and pallas (which has its own interpret mode). But given that (roughly) numpy : cpu :: pytorch : gpus, it would be nice if numpy was "blessed" for the cpu backend.
I would submit a PR, but it seems triton assumes things are torch tensor-like in all sorts of places, in a much more global manner than #205. Naively, it might be possible to simply add additional checks when the kernel is being executed (`.data_ptr()`, `.unsafe_buffer_pointer()`, `.ctypes.data`), but there's too much I don't understand about triton's organization (for example, what is `TensorWrapper` doing in `jit.py`, and why does it have torch semantics?). A rough sketch of the kind of check I mean is below, after the snippets.
triton-cpu/third_party/cpu/backend/driver.py
Lines 206 to 224 in daa7eb0
triton-cpu/python/triton/runtime/jit.py
Lines 895 to 929 in daa7eb0
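To make the "additional checks" idea concrete, the pointer extraction could look something like this (a sketch only; the helper name is mine, and exactly where it would live in `driver.py`/`jit.py` is the part I don't understand):

```python
def _arg_to_ptr(arg) -> int:
    if hasattr(arg, "data_ptr"):               # torch.Tensor (current behaviour)
        return arg.data_ptr()
    if hasattr(arg, "unsafe_buffer_pointer"):  # jax.Array on cpu
        return arg.unsafe_buffer_pointer()
    if hasattr(arg, "ctypes"):                 # numpy.ndarray
        return arg.ctypes.data
    raise TypeError(f"cannot extract a host pointer from {type(arg)}")
```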
Environment details
triton-cpu: daa7eb0