[Pipeliner] Enable automatic loop fusion (#5726)
This PR turns on automatic loop fusion in the CUDA pass pipelines for compute capability >= 8.0. Automatic loop fusion is only enabled for simple loop nests (one outer loop, one inner loop), and only when the user requests fusion with
`tl.range(..., fuse=True)` in the frontend.
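
For reference, a minimal sketch of the kind of loop nest this targets. The kernel and its parameters are illustrative and not part of this PR; only the `tl.range(..., fuse=True)` hint is the one described above.

```python
import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(x_ptr, out_ptr, M, K, BLOCK_K: tl.constexpr):
    # One outer loop, one inner loop: the shape of nest that automatic
    # loop fusion currently accepts. fuse=True asks the compiler to fuse
    # the nest so the inner-loop body can be pipelined.
    for m in tl.range(0, M, fuse=True):
        acc = tl.zeros((BLOCK_K,), dtype=tl.float32)
        for k in tl.range(0, K, BLOCK_K):
            offs = k + tl.arange(0, BLOCK_K)
            acc += tl.load(x_ptr + m * K + offs, mask=offs < K, other=0.0)
        tl.store(out_ptr + m, tl.sum(acc, axis=0))
```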

This PR also rewrites the persistent matmul examples to use loop nests.
This is cleaner, and it will also enable more powerful and flexible
optimizations of loop nests in the future. Primarily, it hides the
brittleness of the pipeliner behind a single layer inside the compiler,
so ideally the brittleness needs to be dealt with only once and stays
hidden from users.
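
To illustrate the structure, here is a simplified schematic of a persistent matmul written as a loop nest. This is not the updated tutorial code; the names, argument layout, and the absence of masking (tile-divisible shapes assumed) are all illustrative.

```python
import triton
import triton.language as tl

@triton.jit
def persistent_matmul_sketch(a_ptr, b_ptr, c_ptr, M, N, K,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                             BLOCK_K: tl.constexpr, NUM_SMS: tl.constexpr):
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = tl.cdiv(M, BLOCK_M) * num_pid_n
    # Outer loop: each persistent program walks its share of output tiles.
    for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, fuse=True):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Inner loop: accumulate over K. After fusion this body is what
        # the pipeliner sees, with the tile bookkeeping folded around it.
        for k in tl.range(0, K, BLOCK_K):
            offs_k = k + tl.arange(0, BLOCK_K)
            a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
            b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
            acc = tl.dot(a, b, acc)
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :],
                 acc.to(tl.float16))
```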

To achieve this, several things have been added to loop fusion:

1. To avoid generating the inner loop inside a conditional, loop nest
fusion will "speculate" the length of the inner loop, essentially
generating one branch where the inner loop is absent and one where the
inner loop is known to execute at least once (see the sketch after this
list).
2. Codegen of the loop induction variables has been slightly altered to
better match the expectations of the scheduler, pipeliner(s), and
`optimize-accumulator-init`.
3. Codegen of loop iter args has been altered to generate fewer SSA
dependencies between the prologue, inner loop, and epilogue, making it
more likely for pipelining to succeed. E.g., inner loop iter args that
can be initialized outside the loop and reset in the epilogue are now
handled that way, rather than being set in the prologue.
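
As a rough illustration of point 1, here is Python-level pseudocode of the fused loop shape with speculation. It is not the compiler's actual codegen, and the helper names are placeholders.

```python
def fused_nest(outer_n, inner_n, prologue, inner_body, epilogue):
    # Speculated branch: the inner loop never executes, so only the
    # outer-loop prologue/epilogue work is emitted.
    if inner_n <= 0:
        for i in range(outer_n):
            prologue(i)
            epilogue(i)
        return
    # Speculated branch: the inner loop is known to execute at least
    # once, so the fused single loop never wraps the inner body in a
    # conditional, which keeps it pipelineable.
    for t in range(outer_n * inner_n):
        i, j = divmod(t, inner_n)
        if j == 0:
            prologue(i)
        inner_body(i, j)
        if j == inner_n - 1:
            epilogue(i)
```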

Some other things in this PR:

* Fixed a bug in the pipeline expander
* Added AxisInfo implementation for `ub::PoisonOp`

I verified the performance of the rewritten persistent matmul kernels on
H100 and Blackwell.

Performance of `09-persistent-matmul.py` on H100.

Before (2 runs)

```
root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py 
M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
273.146 4025.362 ROOT
├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_
├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_
├─ 283.506 2666.310 cublas [M=8192, N=8192, K=512]
│  └─ nan 2666.310 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas
├─ 223.326 307.709 matmul_kernel [M=8192, N=8192, K=512]
├─ 259.293 265.027 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 238.500 288.133 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 258.738 265.594 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
└─ 295.529 232.531 torch [M=8192, N=8192, K=512]
   └─ nan 232.531 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas

Legend (Metric: tflop16/s (inc) Min: 223.33 Max: 295.53)
█ 288.31 - 295.53
█ 273.87 - 288.31
█ 259.43 - 273.87
█ 244.99 - 259.43
█ 230.55 - 244.99
█ 223.33 - 230.55

name User code    ◀  Only in left graph    ▶  Only in right graph

root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py 
M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
273.367 4022.105 ROOT
├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_
├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_
├─ 284.284 2659.011 cublas [M=8192, N=8192, K=512]
│  └─ nan 2659.011 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas
├─ 221.823 309.795 matmul_kernel [M=8192, N=8192, K=512]
├─ 254.755 269.748 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 240.774 285.411 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 259.109 265.214 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
└─ 295.100 232.868 torch [M=8192, N=8192, K=512]
   └─ nan 232.868 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas

Legend (Metric: tflop16/s (inc) Min: 221.82 Max: 295.10)
█ 287.77 - 295.10
█ 273.12 - 287.77
█ 258.46 - 273.12
█ 243.81 - 258.46
█ 229.15 - 243.81
█ 221.82 - 229.15

name User code    ◀  Only in left graph    ▶  Only in right graph

```

After (2 runs):

```
root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py 
M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
274.040 4012.227 ROOT
├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_
├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_
├─ 285.369 2648.904 cublas [M=8192, N=8192, K=512]
│  └─ nan 2648.904 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas
├─ 217.548 315.881 matmul_kernel [M=8192, N=8192, K=512]
├─ 262.312 261.976 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 244.740 280.785 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 255.113 269.368 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
└─ 292.108 235.253 torch [M=8192, N=8192, K=512]
   └─ nan 235.253 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas

Legend (Metric: tflop16/s (inc) Min: 217.55 Max: 292.11)
█ 284.65 - 292.11
█ 269.74 - 284.65
█ 254.83 - 269.74
█ 239.92 - 254.83
█ 225.00 - 239.92
█ 217.55 - 225.00

name User code    ◀  Only in left graph    ▶  Only in right graph

root@dev-0:~/code/triton$ python python/tutorials/09-persistent-matmul.py 
M=32, N=32, K=32 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
M=8192, N=8192, K=512 verification naive vs: torch: ✅ cublas: ✅ persistent: ✅ TMA persistent: ✅ Tensor descriptor persistent: ✅ 
274.997 3998.267 ROOT
├─ nan 0.031 _ZN2at6native18elementwise_kernelILi128ELi4EZNS0_22gpu_kernel_impl_nocastIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE8_clEvEUlN3c104HalfEE_EEvS4_RKT_EUliE_EEviT1_
├─ nan 0.027 _ZN2at6native54_GLOBAL__N__a236ace4_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIN3c104HalfEfLm4EPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS9_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE1_clEvEUlfE_EEvRNS_18TensorIteratorBaseET2_T3_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIS7_fLi4ES9_SO_SH_EEvSJ_SK_RKSL_T4_EUlifE_EEviNS_15PhiloxCudaStateET1_SK_
├─ 285.498 2647.706 cublas [M=8192, N=8192, K=512]
│  └─ nan 2647.706 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas
├─ 217.884 315.394 matmul_kernel [M=8192, N=8192, K=512]
├─ 262.534 261.755 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=512]
├─ 246.617 278.649 matmul_kernel_persistent [M=8192, N=8192, K=512]
├─ 262.525 261.764 matmul_kernel_tma_persistent [M=8192, N=8192, K=512]
└─ 295.007 232.942 torch [M=8192, N=8192, K=512]
   └─ nan 232.942 sm90_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_off_kernel__5x_cublas

Legend (Metric: tflop16/s (inc) Min: 217.88 Max: 295.01)
█ 287.29 - 295.01
█ 271.87 - 287.29
█ 256.45 - 271.87
█ 241.02 - 256.45
█ 225.60 - 241.02
█ 217.88 - 225.60

name User code    ◀  Only in left graph    ▶  Only in right graph

```
Mogball authored Feb 7, 2025
1 parent 61b5674 commit 0cb0140
Showing 16 changed files with 867 additions and 326 deletions.
24 changes: 24 additions & 0 deletions lib/Analysis/AxisInfo.cpp
@@ -1,4 +1,5 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

@@ -269,6 +270,28 @@ class ConstantOpAxisInfoVisitor final
}
};

class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
public:
using AxisInfoVisitorImpl::AxisInfoVisitorImpl;

AxisInfo
getAxisInfo(ub::PoisonOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
// Poison values are never accessed, thus assume optimistic values.
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
unsigned rank = shape.getRank();
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
/*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
/*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
}

return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
/*constancy=*/{1});
}
};

template <typename OpTy>
class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
public:
Expand Down Expand Up @@ -1012,6 +1035,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<PoisonOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor>();
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
AddSubOpAxisInfoVisitor<arith::AddIOp>,
22 changes: 12 additions & 10 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
@@ -3,6 +3,7 @@
#include <algorithm>
#include <numeric>

#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -97,16 +98,17 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(

addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
triton::TritonDialect, cf::ControlFlowDialect,
scf::SCFDialect>([&](Operation *op) {
bool hasLegalRegions = true;
for (auto &region : op->getRegions()) {
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
}
if (hasLegalRegions && typeConverter.isLegal(op)) {
return true;
}
return false;
});
scf::SCFDialect, ub::UBDialect>(
[&](Operation *op) {
bool hasLegalRegions = true;
for (auto &region : op->getRegions()) {
hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
}
if (hasLegalRegions && typeConverter.isLegal(op)) {
return true;
}
return false;
});

// We have requirements for the data layouts
addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
2 changes: 2 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -2,6 +2,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -859,6 +860,7 @@ class ConvertTritonToTritonGPU
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
populateCFPatterns(typeConverter, patterns);
patterns.insert<GenericOpPattern<ub::PoisonOp>>(typeConverter, context);

auto inti = llvm::APSInt(32, false);
auto i32_ty = IntegerType::get(mod->getContext(), 32);