diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 5ba8d8af0c..ed9b8e2c7a 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -245,18 +245,21 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef shape, std::max(1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]))}; } break; + case OpIdx::OperandC: { + auto shapePerWarp = getShapeC(); + int64_t numRepBatch = + rank == 3 ? std::max(1, shape[0] / + (shapePerWarp[0] * warpsPerCTA[0])) + : 1; + return {numRepBatch, + std::max(1, shape[rank - 2] / (shapePerWarp[rank - 2] * + warpsPerCTA[rank - 2])), + std::max(1, shape[rank - 1] / (shapePerWarp[rank - 1] * + warpsPerCTA[rank - 1]))}; + } break; } - auto shapePerWarp = getShapeC(); - int64_t numRepBatch = - rank == 3 - ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) - : 1; - return {numRepBatch, - std::max(1, shape[rank - 2] / (shapePerWarp[rank - 2] * - warpsPerCTA[rank - 2])), - std::max(1, shape[rank - 1] / (shapePerWarp[rank - 1] * - warpsPerCTA[rank - 1]))}; + llvm_unreachable("unexpected opIdx"); } unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand( @@ -279,6 +282,8 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand( // dpas operands scalar are evenly sharded to each work item. return (totalElem / threadsPerWar) * product(rep); } break; + case OpIdx::OperandC: + llvm_unreachable("unexpected OperandC"); } llvm_unreachable("unexpected opIdx"); } @@ -350,8 +355,9 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const { return {shapeB[rank - 2] / threadsPerWarp[0], shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]}; } break; + default: + llvm_unreachable("unexpected opIdx"); } - llvm_unreachable("unexpected opIdx"); } SmallVector DpasEncodingAttr::getContigPerThread() const { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 943784ba2e..350f71f3dc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -397,6 +397,8 @@ struct ConvertLayoutOpConversion repInner = repetitions[1]; repClusterOuter = repCluster[rank - 1]; } break; + default: + llvm_unreachable("unexpected opIdx"); } // TODO: Operands B requires extra steps to combine [8, 16] to [16, 16]. diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 73eae00cd6..894c56d240 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -861,6 +861,8 @@ struct LoadOpConversion i32_val(outer * repOuterStride + rep * repStride)); offsetY = i32_val(k * repKStride); } break; + default: + llvm_unreachable("unexpected opIdx"); } offsetX = add(offsetX, offsetBaseX); @@ -942,6 +944,8 @@ struct LoadOpConversion k + row}] = bitcast(loadVal, unpackedDPASOperandType); } break; + default: + llvm_unreachable("unexpected opIdx"); } } }