A lot of boilerplate for TRITON_INTERPRET=1 without torch #206

Closed
stephen-huan opened this issue Dec 24, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@stephen-huan

Describe the bug

(This is more of a feature request than a bug, and not a very pressing one, so feel free to ignore.)

Following the thread in #204, it is possible to use triton-cpu with numpy/jax via the following Pointer shims. With jax:

import jax.numpy as jnp
from jax import Array

import triton
import triton.language as tl


class Pointer:

    def __init__(self, data: Array) -> None:
        self.data = data
        self.dtype = data.dtype

    def data_ptr(self) -> int:
        return self.data.unsafe_buffer_pointer()


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = jnp.ones(10)
    output = jnp.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)

and analogously with numpy:

import numpy as np

import triton
import triton.language as tl


class Pointer:

    def __init__(self, data: np.ndarray) -> None:
        self.data = data
        self.dtype = data.dtype

    def data_ptr(self) -> int:
        return self.data.ctypes.data


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


if __name__ == "__main__":
    x = np.ones(10)
    output = np.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)

(note that in the case of jax on gpu, it's possible to use jax-triton; see e.g. jax-ml/jax-triton#322 for an extension to cpu).

However, when TRITON_INTERPRET=1, the amount of boilerplate required increases drastically. With jax:

import os

os.environ["TRITON_INTERPRET"] = "1"


import jax
import jax.numpy as jnp
from jax import Array

import triton
import triton.language as tl


class Data:

    def __init__(self, data: Array) -> None:
        self.data = data

    def copy_(self, other: Array) -> None:
        self.data = other


class Pointer:

    def __init__(self, data: Array) -> None:
        self.data = Data(data)
        self.dtype = data.dtype
        self.ptr = data.unsafe_buffer_pointer()
        self.device = data.devices().pop()

    def data_ptr(self) -> int:
        return self.ptr

    def cpu(self) -> "Pointer":
        return self.to(jax.devices(backend="cpu")[0])

    def to(self, device) -> "Pointer":
        return Pointer(self.data.data.to_device(device))


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


def main():
    x = jnp.ones(10)
    output = jnp.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)


if __name__ == "__main__":
    main()

(this could probably be written more efficiently with jax.device_put and jax.device_get.)
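For illustration, the device-movement methods on Pointer could be written with jax.device_put, along the lines of the following sketch (just an illustration of the parenthetical above, not tested against the interpreter):

    def cpu(self) -> "Pointer":
        # jax.device_put moves the underlying array to the CPU device
        return Pointer(jax.device_put(self.data.data, jax.devices(backend="cpu")[0]))

    def to(self, device) -> "Pointer":
        # generic device transfer, used by the interpreter's copy-back path
        return Pointer(jax.device_put(self.data.data, device))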

The numpy version needs analogous boilerplate:

import os

os.environ["TRITON_INTERPRET"] = "1"


import numpy as np

import triton
import triton.language as tl


class Data:

    def __init__(self, data: np.ndarray) -> None:
        self.data = data

    def copy_(self, other: np.ndarray) -> None:
        self.data = other


class Pointer:

    def __init__(self, data: np.ndarray) -> None:
        self.data = Data(data)
        self.dtype = data.dtype
        self.ptr = data.ctypes.data
        self.device = 0

    def data_ptr(self) -> int:
        return self.ptr

    def cpu(self) -> "Pointer":
        return self

    def to(self, device) -> "Pointer":
        return self


@triton.jit
def kernel(x_ptr, output_ptr) -> None:
    tl.store(output_ptr, tl.load(x_ptr))


def main():
    x = np.ones(10)
    output = np.zeros(10)
    kernel[lambda _: (1,)](Pointer(x), Pointer(output))
    print(x)
    print(output)


if __name__ == "__main__":
    main()

(an explicit main function is used to work around triton-lang#5484).

This seems to be mostly a consequence of these lines in the interpreter, which call .cpu() on every argument that has a data_ptr method and then copy the host results back with .to(device) and .data.copy_(...); hence the Data wrapper and the cpu/to/device shims above.

def _init_args_hst(self, args_dev, kwargs):
    args_hst = []
    for arg in args_dev:
        if hasattr(arg, "data_ptr"):
            args_hst.append(arg.cpu())
        else:
            args_hst.append(arg)

    # Process keyword arguments
    kwargs_hst = {}
    for key, value in kwargs.items():
        if hasattr(value, "data_ptr"):
            kwargs_hst[key] = value.cpu()
        else:
            kwargs_hst[key] = value
    return args_hst, kwargs_hst

def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
    for arg_dev, arg_hst in zip(args_dev, args_hst):
        if hasattr(arg_dev, "data_ptr"):
            arg_dev.data.copy_(arg_hst.to(arg_dev.device).data)

    # Restore keyword arguments
    for key, kwarg_dev in kwargs.items():
        kwarg_hst = kwargs_hst[key]
        if hasattr(kwarg_dev, "data_ptr"):
            kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

It would be nice if the interpreter could support jax/numpy without all the boilerplate, especially because the interpreter lowers to numpy on cpu anyway. It would be extra nice if passing jax/numpy arrays "just worked" like pytorch tensors do.
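For example, the interpreter could in principle special-case host arrays with a duck-typed helper along these lines (purely a sketch; _to_host is a hypothetical function, not existing interpreter code):

import numpy as np


def _to_host(arg):
    # plain numpy arrays already live in host memory, so they could be
    # passed through as-is, with no cpu()/copy_ round trip needed
    if isinstance(arg, np.ndarray):
        return arg
    # torch-like tensors keep the existing behavior: copy to host
    if hasattr(arg, "data_ptr"):
        return arg.cpu()
    return arg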

As I primarily write jax, this is not especially pressing for me, since jax has jax-triton and pallas (which has its own interpret mode). But given that (roughly) numpy : cpu :: pytorch : gpu, it would be nice if numpy were "blessed" for the cpu backend.

I would submit a PR, but it seems triton assumes things are torch-tensor-like in all sorts of places, in a much more global manner than #205. Naively, it might be possible to simply add additional checks when the kernel is executed (.data_ptr(), .unsafe_buffer_pointer(), .ctypes.data), but there's too much I don't understand about triton's organization (for example, what is TensorWrapper doing in jit.py, and why does it have torch semantics?). For reference, the pointer handling and the TensorWrapper class in question:

PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!PyLong_Check(ret)) {{
        PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
        ptr_info.valid = false;
        return ptr_info;
    }}
    ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret);
    if(!ptr_info.dev_ptr) {{
        return ptr_info;
    }}
    Py_DECREF(ret); // Thanks ChatGPT!
    return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");

class TensorWrapper:

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        self.data = base.data
        self.device = base.device
        self.shape = self.base.shape

    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, i):
        return self.base.stride(i)

    def __str__(self) -> str:
        return f"TensorWrapper[{self.dtype}]({self.base})"

    def element_size(self):
        return self.base.element_size()

    def cpu(self):
        return TensorWrapper(self.base.cpu(), self.dtype)

    def copy_(self, other):
        self.base.copy_(other.base)

    def clone(self):
        return TensorWrapper(self.base.clone(), self.dtype)

    def to(self, device):
        return TensorWrapper(self.base.to(device), self.dtype)

    def new_empty(self, sizes):
        return TensorWrapper(self.base.new_empty(sizes), self.dtype)

Environment details

triton-cpu: daa7eb0

@minjang
Collaborator

minjang commented Dec 25, 2024

Again, you'd want to raise this issue in the upstream. The interpreter is maintained by the upstream, not by triton-cpu.

@stephen-huan
Author

Opened as triton-lang#5493. Sorry for all the duplicates!
