Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[MemoryBanking] Support multi-dimension memory banking #8033

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/circt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ std::unique_ptr<mlir::Pass> createStripDebugInfoWithPredPass(
std::unique_ptr<mlir::Pass> createMaximizeSSAPass();
std::unique_ptr<mlir::Pass> createInsertMergeBlocksPass();
std::unique_ptr<mlir::Pass> createPrintOpCountPass();
std::unique_ptr<mlir::Pass>
createMemoryBankingPass(std::optional<unsigned> bankingFactor = std::nullopt);
std::unique_ptr<mlir::Pass> createMemoryBankingPass(
std::optional<unsigned> bankingFactor = std::nullopt,
std::optional<unsigned> bankingDimension = std::nullopt);
std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass();

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 3 additions & 1 deletion include/circt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def MemoryBanking : Pass<"memory-banking", "::mlir::func::FuncOp"> {
let constructor = "circt::createMemoryBankingPass()";
let options = [
Option<"bankingFactor", "banking-factor", "unsigned", /*default=*/"1",
"Use this banking factor for all memories being partitioned">
"Use this banking factor for all memories being partitioned">,
Option<"bankingDimension", "dimension", "unsigned", /*default=*/"0",
"The dimension along which to bank the memory. For rank=1, must be 0.">
];
let dependentDialects = ["mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::affine::AffineDialect"];
}
Expand Down
100 changes: 62 additions & 38 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

namespace circt {
#define GEN_PASS_DEF_MEMORYBANKING
Expand All @@ -42,7 +41,8 @@ struct MemoryBankingPass
: public circt::impl::MemoryBankingBase<MemoryBankingPass> {
MemoryBankingPass(const MemoryBankingPass &other) = default;
explicit MemoryBankingPass(
std::optional<unsigned> bankingFactor = std::nullopt) {}
std::optional<unsigned> bankingFactor = std::nullopt,
std::optional<unsigned> bankingDimension = std::nullopt) {}

void runOnOperation() override;

Expand All @@ -67,26 +67,29 @@ DenseSet<Value> collectMemRefs(mlir::affine::AffineParallelOp parOp) {
}

MemRefType computeBankedMemRefType(MemRefType originalType,
uint64_t bankingFactor) {
uint64_t bankingFactor,
unsigned bankingDimension) {
ArrayRef<int64_t> originalShape = originalType.getShape();
assert(!originalShape.empty() && "memref shape should not be empty");
assert(originalType.getRank() == 1 &&
"currently only support one dimension memories");
SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
assert(newShape.front() % bankingFactor == 0 &&

assert(bankingDimension < originalType.getRank() &&
"dimension must be within the memref rank");
assert(originalShape[bankingDimension] % bankingFactor == 0 &&
"memref shape must be evenly divided by the banking factor");
newShape.front() /= bankingFactor;
SmallVector<int64_t, 4> newShape(originalShape.begin(), originalShape.end());
newShape[bankingDimension] /= bankingFactor;
MemRefType newMemRefType =
MemRefType::get(newShape, originalType.getElementType(),
originalType.getLayout(), originalType.getMemorySpace());

return newMemRefType;
}

SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor) {
SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
unsigned bankingDimension) {
MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
MemRefType newMemRefType =
computeBankedMemRefType(originalMemRefType, bankingFactor);
MemRefType newMemRefType = computeBankedMemRefType(
originalMemRefType, bankingFactor, bankingDimension);
SmallVector<Value, 4> banks;
if (auto blockArgMem = dyn_cast<BlockArgument>(originalMem)) {
Block *block = blockArgMem.getOwner();
Expand Down Expand Up @@ -132,24 +135,31 @@ SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor) {
struct BankAffineLoadPattern
: public OpRewritePattern<mlir::affine::AffineLoadOp> {
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
: OpRewritePattern<mlir::affine::AffineLoadOp>(context),
bankingFactor(bankingFactor), memoryToBanks(memoryToBanks) {}
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks) {}

LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
PatternRewriter &rewriter) const override {
Location loc = loadOp.getLoc();
auto banks = memoryToBanks[loadOp.getMemref()];
Value loadIndex = loadOp.getIndices().front();
auto modMap =
AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
auto loadIndices = loadOp.getIndices();
int64_t memrefRank = loadOp.getMemRefType().getRank();
auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
{rewriter.getAffineDimExpr(bankingDimension) % bankingFactor});
auto divMap = AffineMap::get(
1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
memrefRank, 0,
{rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)});

Value bankIndex = rewriter.create<affine::AffineApplyOp>(
loc, modMap, loadIndex); // assuming one-dim
Value bankIndex =
rewriter.create<affine::AffineApplyOp>(loc, modMap, loadIndices);
Value offset =
rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndex);
rewriter.create<affine::AffineApplyOp>(loc, divMap, loadIndices);
SmallVector<Value, 4> newIndices(loadIndices.begin(), loadIndices.end());
newIndices[bankingDimension] = offset;

SmallVector<Type> resultTypes = {loadOp.getResult().getType()};

Expand All @@ -165,8 +175,8 @@ struct BankAffineLoadPattern
for (unsigned i = 0; i < bankingFactor; ++i) {
Region &caseRegion = switchOp.getCaseRegions()[i];
rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
Value bankedLoad =
rewriter.create<mlir::affine::AffineLoadOp>(loc, banks[i], offset);
Value bankedLoad = rewriter.create<mlir::affine::AffineLoadOp>(
loc, banks[i], newIndices);
rewriter.create<scf::YieldOp>(loc, bankedLoad);
}

Expand All @@ -186,19 +196,22 @@ struct BankAffineLoadPattern

private:
uint64_t bankingFactor;
unsigned bankingDimension;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
};

// Replace the original store operations with stores into the newly created memory banks
struct BankAffineStorePattern
: public OpRewritePattern<mlir::affine::AffineStoreOp> {
BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Operation *> &opsToErase,
DenseSet<Operation *> &processedOps)
: OpRewritePattern<mlir::affine::AffineStoreOp>(context),
bankingFactor(bankingFactor), memoryToBanks(memoryToBanks),
opsToErase(opsToErase), processedOps(processedOps) {}
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks), opsToErase(opsToErase),
processedOps(processedOps) {}

LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
PatternRewriter &rewriter) const override {
Expand All @@ -207,17 +220,22 @@ struct BankAffineStorePattern
}
Location loc = storeOp.getLoc();
auto banks = memoryToBanks[storeOp.getMemref()];
Value storeIndex = storeOp.getIndices().front();
auto storeIndices = storeOp.getIndices();
int64_t memrefRank = storeOp.getMemRefType().getRank();

auto modMap =
AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor});
auto modMap = AffineMap::get(
/*dimCount=*/memrefRank, /*symbolCount=*/0,
{rewriter.getAffineDimExpr(bankingDimension) % bankingFactor});
auto divMap = AffineMap::get(
1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)});
memrefRank, 0,
{rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)});

Value bankIndex = rewriter.create<affine::AffineApplyOp>(
loc, modMap, storeIndex); // assuming one-dim
Value bankIndex =
rewriter.create<affine::AffineApplyOp>(loc, modMap, storeIndices);
Value offset =
rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndex);
rewriter.create<affine::AffineApplyOp>(loc, divMap, storeIndices);
SmallVector<Value, 4> newIndices(storeIndices.begin(), storeIndices.end());
newIndices[bankingDimension] = offset;

SmallVector<Type> resultTypes = {};

Expand All @@ -234,7 +252,7 @@ struct BankAffineStorePattern
Region &caseRegion = switchOp.getCaseRegions()[i];
rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock());
rewriter.create<mlir::affine::AffineStoreOp>(
loc, storeOp.getValueToStore(), banks[i], offset);
loc, storeOp.getValueToStore(), banks[i], newIndices);
rewriter.create<scf::YieldOp>(loc);
}

Expand All @@ -252,6 +270,7 @@ struct BankAffineStorePattern

private:
uint64_t bankingFactor;
unsigned bankingDimension;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Operation *> &opsToErase;
DenseSet<Operation *> &processedOps;
Expand Down Expand Up @@ -358,17 +377,21 @@ void MemoryBankingPass::runOnOperation() {
getOperation().walk([&](mlir::affine::AffineParallelOp parOp) {
DenseSet<Value> memrefsInPar = collectMemRefs(parOp);

for (auto memrefVal : memrefsInPar)
memoryToBanks[memrefVal] = createBanks(memrefVal, bankingFactor);
for (auto memrefVal : memrefsInPar) {
if (!memoryToBanks.contains(memrefVal))
memoryToBanks[memrefVal] =
cgyurgyik marked this conversation as resolved.
Show resolved Hide resolved
createBanks(memrefVal, bankingFactor, bankingDimension);
}
});

auto *ctx = &getContext();
RewritePatternSet patterns(ctx);

DenseSet<Operation *> processedOps;
patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, memoryToBanks);
patterns.add<BankAffineStorePattern>(ctx, bankingFactor, memoryToBanks,
opsToErase, processedOps);
patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks);
patterns.add<BankAffineStorePattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks, opsToErase, processedOps);
patterns.add<BankReturnPattern>(ctx, memoryToBanks);

GreedyRewriteConfig config;
Expand All @@ -390,7 +413,8 @@ void MemoryBankingPass::runOnOperation() {

namespace circt {
std::unique_ptr<mlir::Pass>
createMemoryBankingPass(std::optional<unsigned> bankingFactor) {
return std::make_unique<MemoryBankingPass>(bankingFactor);
createMemoryBankingPass(std::optional<unsigned> bankingFactor,
std::optional<unsigned> bankingDimension) {
return std::make_unique<MemoryBankingPass>(bankingFactor, bankingDimension);
}
} // namespace circt
75 changes: 75 additions & 0 deletions test/Transforms/memory_banking_multi_dim.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2 dimension=1" | FileCheck %s --check-prefix RANK2-BANKDIM1

// RANK2-BANKDIM1: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1 mod 2)>
// RANK2-BANKDIM1: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d1 floordiv 2)>

// RANK2-BANKDIM1-LABEL: func.func @rank_two_bank_dim1(
// RANK2-BANKDIM1-SAME: %[[VAL_0:arg0]]: memref<8x3xf32>,
// RANK2-BANKDIM1-SAME: %[[VAL_1:arg1]]: memref<8x3xf32>,
// RANK2-BANKDIM1-SAME: %[[VAL_2:arg2]]: memref<8x3xf32>,
// RANK2-BANKDIM1-SAME: %[[VAL_3:arg3]]: memref<8x3xf32>) -> (memref<8x3xf32>, memref<8x3xf32>) {
// RANK2-BANKDIM1: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
// RANK2-BANKDIM1: %[[VAL_5:.*]] = memref.alloc() : memref<8x3xf32>
// RANK2-BANKDIM1: %[[VAL_6:.*]] = memref.alloc() : memref<8x3xf32>
// RANK2-BANKDIM1: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) {
// RANK2-BANKDIM1: affine.parallel (%[[VAL_8:.*]]) = (0) to (6) {
// RANK2-BANKDIM1: %[[VAL_9:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: %[[VAL_10:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: %[[VAL_11:.*]] = scf.index_switch %[[VAL_9]] -> f32
// RANK2-BANKDIM1: case 0 {
// RANK2-BANKDIM1: %[[VAL_12:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield %[[VAL_12]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: case 1 {
// RANK2-BANKDIM1: %[[VAL_13:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_10]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield %[[VAL_13]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: default {
// RANK2-BANKDIM1: scf.yield %[[VAL_4]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: %[[VAL_14:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: %[[VAL_15:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: %[[VAL_16:.*]] = scf.index_switch %[[VAL_14]] -> f32
// RANK2-BANKDIM1: case 0 {
// RANK2-BANKDIM1: %[[VAL_17:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_7]], %[[VAL_15]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield %[[VAL_17]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: case 1 {
// RANK2-BANKDIM1: %[[VAL_18:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_7]], %[[VAL_15]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield %[[VAL_18]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: default {
// RANK2-BANKDIM1: scf.yield %[[VAL_4]] : f32
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: %[[VAL_19:.*]] = arith.mulf %[[VAL_11]], %[[VAL_16]] : f32
// RANK2-BANKDIM1: %[[VAL_20:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: %[[VAL_21:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]])
// RANK2-BANKDIM1: scf.index_switch %[[VAL_20]]
// RANK2-BANKDIM1: case 0 {
// RANK2-BANKDIM1: affine.store %[[VAL_19]], %[[VAL_5]]{{\[}}%[[VAL_7]], %[[VAL_21]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: case 1 {
// RANK2-BANKDIM1: affine.store %[[VAL_19]], %[[VAL_6]]{{\[}}%[[VAL_7]], %[[VAL_21]]] : memref<8x3xf32>
// RANK2-BANKDIM1: scf.yield
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: default {
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: }
// RANK2-BANKDIM1: return %[[VAL_5]], %[[VAL_6]] : memref<8x3xf32>, memref<8x3xf32>
// RANK2-BANKDIM1: }

// Test input: multiplies two 8x6 f32 memrefs element-wise inside nested
// affine.parallel loops and stores the products into a freshly allocated
// 8x6 result, which is returned.
//
// Under the RUN line's `banking-factor=2 dimension=1`, the expected output
// above splits every memref (both arguments and the allocation) into two
// 8x3 banks along dimension 1: column index %j selects bank (%j mod 2) and
// is rewritten to (%j floordiv 2) within that bank, while the row index %i
// is passed through unchanged.
func.func @rank_two_bank_dim1(%arg0: memref<8x6xf32>, %arg1: memref<8x6xf32>) -> (memref<8x6xf32>) {
%mem = memref.alloc() : memref<8x6xf32>
// Outer loop walks dimension 0 (rows); it is not the banked dimension.
affine.parallel (%i) = (0) to (8) {
// Inner loop walks dimension 1 (columns), the dimension being banked.
affine.parallel (%j) = (0) to (6) {
%1 = affine.load %arg0[%i, %j] : memref<8x6xf32>
%2 = affine.load %arg1[%i, %j] : memref<8x6xf32>
%3 = arith.mulf %1, %2 : f32
affine.store %3, %mem[%i, %j] : memref<8x6xf32>
}
}
return %mem : memref<8x6xf32>
}

Loading