diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index 71fd573ff1e3..c2879c9a74ed 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -497,23 +497,19 @@ assignMemoryLayouts(scf::ForOp &forOp,
     if (!op.hasAttr(mlir::triton::kLoopStageAttrName))
       continue;
 
-    // Check stage for uses. If any direct use is in a different stage, treat it
+    // Check stage for uses. If the first use is in a different stage, treat it
     // as a pipelined load.
-    bool isPipelined = false;
     auto [sLoad, _cLoad] = tt::getStageCluster(&op);
-    auto directUsers = getDirectUserInBlock(&op);
-    LDBG("DirectUser for load " << op);
-    for (auto user : directUsers) {
-      LDBG(" - use: " << *user);
-      if (!user->hasAttr(mlir::triton::kLoopStageAttrName))
-        continue;
-      auto [stage, _cluster] = tt::getStageCluster(user);
-      if (stage != sLoad) {
-        isPipelined = true;
-        break;
-      }
-    }
-    if (!isPipelined)
+    Operation *firstUse = getFirstUseOfPipelinedLoad(&op);
+    LDBG("first use for load " << op);
+    // NOTE(review): getFirstUseOfPipelinedLoad is defined outside this hunk;
+    // skip the load if it reports no use, as the previous loop over direct
+    // users did, rather than dereferencing a null Operation below.
+    if (!firstUse)
+      continue;
+    LDBG(" - use: " << *firstUse);
+    auto firstUseStageCluster = tt::maybeGetStageCluster(firstUse);
+    if (!firstUseStageCluster || firstUseStageCluster->first == sLoad)
       continue;
 
     // Try to set shared encoding etc for the pipelined load.
diff --git a/test/TritonGPU/matmul-loop-pipeline.mlir b/test/TritonGPU/matmul-loop-pipeline.mlir
index f8042feee9bf..1a91ab022a78 100644
--- a/test/TritonGPU/matmul-loop-pipeline.mlir
+++ b/test/TritonGPU/matmul-loop-pipeline.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -tritongpu-pipeline | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-pipeline | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
 
@@ -28,3 +28,25 @@ tt.func public @softmax_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32},
 }
 
 }
+
+// -----
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90"} {
+
+// The scalar load's first use (%1) is in the load's own stage (0); only the
+// second addf is in stage 1. The CHECK below expects a tt.load of %arg0 to
+// still be present in the output of the pipelining pass.
+// CHECK-LABEL: @scalar_load
+tt.func public @scalar_load(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32, %arg3: f32) -> f32 {
+  %c1_i32 = arith.constant 1 : i32
+  %2 = scf.for %i = %arg1 to %arg2 step %c1_i32 iter_args(%k = %arg3) -> f32 : i32 {
+    // CHECK: tt.load %arg0
+    %0 = tt.load %arg0 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.ptr<f32>
+    %1 = arith.addf %0, %k {loop.cluster = 1 : i32, loop.stage = 0 : i32} : f32
+    %2 = arith.addf %1, %k {loop.cluster = 0 : i32, loop.stage = 1 : i32} : f32
+    scf.yield %2 : f32
+  } {num_stages = 2 : i32}
+  tt.return %2 : f32
+}
+
+}