
[TRITON_RAISE_BP]: scf.for init arg list rewrite causes invalid IR in tt.broadcast op in loop body #3254

Closed
etiotto opened this issue Jan 23, 2025 · 1 comment · Fixed by #3252

etiotto commented Jan 23, 2025

Reduced test:

  tt.func @kernel(%arg0 : !tt.ptr<bf16>, %arg1 : i32) {
    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>

    %splat_arg0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x256x!tt.ptr<bf16>>
    %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>

    %c0 = arith.constant 0 : index
    %c12 = arith.constant 12 : index
    %c3 = arith.constant 3 : index
    %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<1x256x!tt.ptr<bf16>>) {
        %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
        scf.yield %ptr : tensor<1x256x!tt.ptr<bf16>>
    }
    tt.return
  }

Error:

triton-opt -triton-raise-block-pointer ~/tmp/test1.mlir 

/home/jovyan/tmp/test1.mlir:12:17: error: 'tt.broadcast' op operand #0 must be ranked tensor of floating-point or integer or ptr values, but got '!tt.ptr<tensor<1x256xbf16>>'
        %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
...
@etiotto etiotto self-assigned this Jan 23, 2025

etiotto commented Jan 23, 2025

The issue here is triggered by the rewrite of the tt.addptr outside the loop into a tt.make_tensor_ptr. The SSA value yielded by that tt.addptr operation (%2) is used in the init_arg list of the scf.for operation, so that init argument is rewritten to use the block pointer yielded by the tt.make_tensor_ptr operation just mentioned.

Unfortunately the loop body contains a tt.broadcast operation, which cannot accept a block pointer as its input!

We have a couple of possible solutions:

  1. we can avoid rewriting the init_arg when the loop contains an operation that uses that value and cannot be rewritten to use a block pointer (e.g. tt.broadcast), or

  2. we can pass both the original SSA value (%2) and the rewritten value yielded by the tt.make_tensor_ptr, and use the appropriate init_arg depending on the operation in the loop that references it (that is, use the original init_arg for tt.broadcast and the block pointer argument for operations that can accept it, such as load/store and addptr).

I will go with solution (1) for now.
