Revert "Revert "[Backend] Improve dot support to target FMA (#4516)"" #3056

Merged (2 commits) on Dec 21, 2024
24 changes: 24 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -348,12 +348,18 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape);

SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);

Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);

size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content);
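
Review note: the hunk above only declares the new compile-time overloads of `delinearize` and `linearize`. The following is a minimal sketch of what such a pair could compute, not the code from this PR. It assumes that `order[0]` names the fastest-varying dimension, and it uses `std::vector` in place of `llvm::ArrayRef`. The names `linearizeSketch` and `delinearizeSketch` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the static linearize overload.
// Assumption: order[0] is the fastest-varying dimension.
inline std::size_t linearizeSketch(const std::vector<unsigned> &multiDim,
                                   const std::vector<unsigned> &shape,
                                   const std::vector<unsigned> &order) {
  assert(multiDim.size() == shape.size() && order.size() == shape.size());
  std::size_t linear = 0;
  // Walk from the slowest-varying dimension to the fastest one.
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    linear = linear * shape[*it] + multiDim[*it];
  return linear;
}

// Hypothetical stand-in for the static delinearize overload (the inverse).
inline std::vector<unsigned>
delinearizeSketch(std::size_t linear, const std::vector<unsigned> &shape,
                  const std::vector<unsigned> &order) {
  assert(order.size() == shape.size());
  std::vector<unsigned> multiDim(shape.size());
  // Peel off the fastest-varying dimension first.
  for (unsigned dim : order) {
    multiDim[dim] = static_cast<unsigned>(linear % shape[dim]);
    linear /= shape[dim];
  }
  return multiDim;
}
```

Under these assumptions, index `{1, 0, 2}` in shape `{2, 3, 4}` maps to different linear offsets depending on `order`, and `delinearizeSketch` recovers the same index from either offset.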

@@ -496,6 +502,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
return ret;
}

/// Extend 2d shared object to 3d.
///
/// If the tensor has 3 dimensions, returns the original shared object.
/// If the tensor shape is [M, N], returns a shared object describing shape
/// [1, M, N].
///
/// This function is used to simplify processing of 2d and 3d dot operands,
/// particularly in the conversion of the local_load operation.
///
/// \param rewriter
/// \param loc
/// \param smemObj
/// \param shape shape of a tensor represented by smemObj
/// \returns shared object describing 3d tensor
SharedMemoryObject
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
SharedMemoryObject smemObj,
ArrayRef<int64_t> shape);

// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);

llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace gpu
} // namespace triton
} // namespace mlir
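
Review note: only the declarations of `expandMatrixShapeWithBatch` and `expandMatrixOrderWithBatch` appear in this hunk. The sketch below illustrates the behaviour that the `getExpandedSharedMemoryObject` comment suggests for shapes (`[M, N]` becomes `[1, M, N]`). The order handling is an assumption: each existing dimension index is shifted by one and the new batch dimension 0 is appended as the slowest-varying one. The names `expandShapeWithBatchSketch` and `expandOrderWithBatchSketch` are hypothetical, and `std::vector` stands in for the LLVM container types.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: prepend a unit batch dimension to a 2d shape.
// A 3d shape is returned unchanged.
template <typename T>
std::vector<T> expandShapeWithBatchSketch(const std::vector<T> &s) {
  assert(s.size() == 2 || s.size() == 3);
  if (s.size() == 3)
    return s;
  return {T(1), s[0], s[1]};
}

// Hypothetical stand-in: adapt a 2d order to the expanded 3d shape.
// Assumption: old dimension i becomes dimension i + 1, and the batch
// dimension (0) is placed last, i.e. it varies slowest.
inline std::vector<unsigned>
expandOrderWithBatchSketch(const std::vector<unsigned> &o) {
  assert(o.size() == 2 || o.size() == 3);
  if (o.size() == 3)
    return o;
  return {o[0] + 1, o[1] + 1, 0u};
}
```

With this sketch, a row-major 2d order `{1, 0}` becomes `{2, 1, 0}`, so 2d and 3d dot operands can share one code path downstream.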