Skip to content

Commit

Permalink
[LAYOUTS] Implement getElemsPerThread generically (#5841)
Browse files Browse the repository at this point in the history
While doing so, we remove the SliceEncodingAttr hack!
  • Loading branch information
lezcano authored Feb 7, 2025
1 parent 7da6d0b commit 61b5674
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 332 deletions.
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
StringRef getName() final { return "<SharedMemory>"; }
};

// Convert a distributed layout to a linear encoding
LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape);

unsigned getTotalElemsPerThread(Type type);

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
Type eltTy);
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);

SmallVector<unsigned> getElemsPerThread(Type type);

Expand Down
21 changes: 13 additions & 8 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -501,13 +501,17 @@ We call each individual tile "rep".
InterfaceMethod<"Return total element size per thread.",
"unsigned",
"getTotalElemsPerThread",
(ins "ArrayRef<int64_t>":$tensorShape,
"Type":$eltTy)>,
(ins "ArrayRef<int64_t>":$shape),
/*defaultImplementation=*/[{
return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
}]>,
InterfaceMethod<"Return element size per thread in each dimension.",
"SmallVector<unsigned>",
"getElemsPerThread",
(ins "ArrayRef<int64_t>":$tensorShape,
"Type":$eltTy)>,
(ins "ArrayRef<int64_t>":$shape),
/*defaultImplementation=*/[{
return toLinearEncoding($_self, shape).getElemsPerThread(shape);
}]>,
// Interface for the meta information about the multiple thread hierarchy.
InterfaceMethod<"Get the shape of the warps per CTA.",
"SmallVector<unsigned>",
Expand Down Expand Up @@ -577,8 +581,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];

code extraDistributedDeclaration = extraBaseClassDeclaration # [{
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
// Implemented in subclasses
SmallVector<unsigned> getRepOrder() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
Expand Down Expand Up @@ -613,6 +616,10 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
let parameters = (ins LinearLayoutParam:$linearLayout);

let extraClassDeclaration = extraDistributedDeclaration # [{
// Generic distributed encoding methods
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;

SmallVector<unsigned> getContigPerThread() const;
SmallVector<unsigned> getOrder() const;

Expand Down Expand Up @@ -965,7 +972,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
return true;
}
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
Expand Down Expand Up @@ -1095,7 +1101,6 @@ Row |
return true;
}
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
Type elemType, int kWidth, int opIdx) const;
Expand Down
Loading

0 comments on commit 61b5674

Please sign in to comment.