diff --git a/include/circt/Transforms/Passes.h b/include/circt/Transforms/Passes.h index 2a13b48cfb3c..df75c6ec96c9 100644 --- a/include/circt/Transforms/Passes.h +++ b/include/circt/Transforms/Passes.h @@ -44,8 +44,9 @@ std::unique_ptr createStripDebugInfoWithPredPass( std::unique_ptr createMaximizeSSAPass(); std::unique_ptr createInsertMergeBlocksPass(); std::unique_ptr createPrintOpCountPass(); -std::unique_ptr -createMemoryBankingPass(std::optional bankingFactor = std::nullopt); +std::unique_ptr createMemoryBankingPass( + std::optional bankingFactor = std::nullopt, + std::optional bankingDimension = std::nullopt); std::unique_ptr createIndexSwitchToIfPass(); //===----------------------------------------------------------------------===// diff --git a/include/circt/Transforms/Passes.td b/include/circt/Transforms/Passes.td index 67e25d2db4eb..8c25c0f41c5f 100644 --- a/include/circt/Transforms/Passes.td +++ b/include/circt/Transforms/Passes.td @@ -128,7 +128,9 @@ def MemoryBanking : Pass<"memory-banking", "::mlir::func::FuncOp"> { let constructor = "circt::createMemoryBankingPass()"; let options = [ Option<"bankingFactor", "banking-factor", "unsigned", /*default=*/"1", - "Use this banking factor for all memories being partitioned"> + "Use this banking factor for all memories being partitioned">, + Option<"bankingDimension", "dimension", "unsigned", /*default=*/"0", + "The dimension along which to bank the memory. For rank=1, must be 0."> ]; let dependentDialects = ["mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::affine::AffineDialect"]; } diff --git a/lib/Transforms/MemoryBanking.cpp b/lib/Transforms/MemoryBanking.cpp index 2337d2ae6874..c4e974d0a64b 100644 --- a/lib/Transforms/MemoryBanking.cpp +++ b/lib/Transforms/MemoryBanking.cpp @@ -24,7 +24,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" namespace circt { #define GEN_PASS_DEF_MEMORYBANKING @@ -42,7 +41,8 @@ struct MemoryBankingPass : public circt::impl::MemoryBankingBase { MemoryBankingPass(const MemoryBankingPass &other) = default; explicit MemoryBankingPass( - std::optional bankingFactor = std::nullopt) {} + std::optional bankingFactor = std::nullopt, + std::optional bankingDimension = std::nullopt) {} void runOnOperation() override; @@ -67,15 +67,17 @@ DenseSet collectMemRefs(mlir::affine::AffineParallelOp parOp) { } MemRefType computeBankedMemRefType(MemRefType originalType, - uint64_t bankingFactor) { + uint64_t bankingFactor, + unsigned bankingDimension) { ArrayRef originalShape = originalType.getShape(); assert(!originalShape.empty() && "memref shape should not be empty"); - assert(originalType.getRank() == 1 && - "currently only support one dimension memories"); - SmallVector newShape(originalShape.begin(), originalShape.end()); - assert(newShape.front() % bankingFactor == 0 && + + assert(bankingDimension < originalType.getRank() && + "dimension must be within the memref rank"); + assert(originalShape[bankingDimension] % bankingFactor == 0 && "memref shape must be evenly divided by the banking factor"); - newShape.front() /= bankingFactor; + SmallVector newShape(originalShape.begin(), originalShape.end()); + newShape[bankingDimension] /= bankingFactor; MemRefType newMemRefType = MemRefType::get(newShape, originalType.getElementType(), originalType.getLayout(), originalType.getMemorySpace()); @@ -83,20 +85,19 @@ MemRefType computeBankedMemRefType(MemRefType originalType, return newMemRefType; } -SmallVector createBanks(Value originalMem, uint64_t bankingFactor) { +SmallVector createBanks(Value originalMem, uint64_t bankingFactor, + unsigned bankingDimension) { MemRefType originalMemRefType = cast(originalMem.getType()); - MemRefType newMemRefType = - computeBankedMemRefType(originalMemRefType, bankingFactor); + MemRefType newMemRefType = computeBankedMemRefType( + originalMemRefType, bankingFactor, bankingDimension); SmallVector banks; if (auto blockArgMem = dyn_cast(originalMem)) { Block *block = blockArgMem.getOwner(); unsigned blockArgNum = blockArgMem.getArgNumber(); - SmallVector banksType; - for (unsigned i = 0; i < bankingFactor; ++i) { + for (unsigned i = 0; i < bankingFactor; ++i) block->insertArgument(blockArgNum + 1 + i, newMemRefType, blockArgMem.getLoc()); - } auto blockArgs = block->getArguments().slice(blockArgNum + 1, bankingFactor); @@ -132,24 +133,31 @@ SmallVector createBanks(Value originalMem, uint64_t bankingFactor) { struct BankAffineLoadPattern : public OpRewritePattern { BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, + unsigned bankingDimension, DenseMap> &memoryToBanks) : OpRewritePattern(context), - bankingFactor(bankingFactor), memoryToBanks(memoryToBanks) {} + bankingFactor(bankingFactor), bankingDimension(bankingDimension), + memoryToBanks(memoryToBanks) {} LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp, PatternRewriter &rewriter) const override { Location loc = loadOp.getLoc(); auto banks = memoryToBanks[loadOp.getMemref()]; - Value loadIndex = loadOp.getIndices().front(); - auto modMap = - AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor}); + auto loadIndices = loadOp.getIndices(); + int64_t memrefRank = loadOp.getMemRefType().getRank(); + auto modMap = AffineMap::get( + /*dimCount=*/memrefRank, /*symbolCount=*/0, + {rewriter.getAffineDimExpr(bankingDimension) % bankingFactor}); auto divMap = AffineMap::get( - 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)}); + memrefRank, 0, + {rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)}); - Value bankIndex = rewriter.create( - loc, modMap, loadIndex); // assuming one-dim + Value bankIndex = + rewriter.create(loc, modMap, loadIndices); Value offset = - rewriter.create(loc, divMap, loadIndex); + rewriter.create(loc, divMap, loadIndices); + SmallVector newIndices(loadIndices.begin(), loadIndices.end()); + newIndices[bankingDimension] = offset; SmallVector resultTypes = {loadOp.getResult().getType()}; @@ -165,8 +173,8 @@ struct BankAffineLoadPattern for (unsigned i = 0; i < bankingFactor; ++i) { Region &caseRegion = switchOp.getCaseRegions()[i]; rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock()); - Value bankedLoad = - rewriter.create(loc, banks[i], offset); + Value bankedLoad = rewriter.create( + loc, banks[i], newIndices); rewriter.create(loc, bankedLoad); } @@ -186,6 +194,7 @@ struct BankAffineLoadPattern private: uint64_t bankingFactor; + unsigned bankingDimension; DenseMap> &memoryToBanks; }; @@ -193,12 +202,14 @@ struct BankAffineLoadPattern struct BankAffineStorePattern : public OpRewritePattern { BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, + unsigned bankingDimension, DenseMap> &memoryToBanks, DenseSet &opsToErase, DenseSet &processedOps) : OpRewritePattern(context), - bankingFactor(bankingFactor), memoryToBanks(memoryToBanks), - opsToErase(opsToErase), processedOps(processedOps) {} + bankingFactor(bankingFactor), bankingDimension(bankingDimension), + memoryToBanks(memoryToBanks), opsToErase(opsToErase), + processedOps(processedOps) {} LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp, PatternRewriter &rewriter) const override { @@ -207,17 +218,22 @@ struct BankAffineStorePattern } Location loc = storeOp.getLoc(); auto banks = memoryToBanks[storeOp.getMemref()]; - Value storeIndex = storeOp.getIndices().front(); + auto storeIndices = storeOp.getIndices(); + int64_t memrefRank = storeOp.getMemRefType().getRank(); - auto modMap = - AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor}); + auto modMap = AffineMap::get( + /*dimCount=*/memrefRank, /*symbolCount=*/0, + {rewriter.getAffineDimExpr(bankingDimension) % bankingFactor}); auto divMap = AffineMap::get( - 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)}); + memrefRank, 0, + {rewriter.getAffineDimExpr(bankingDimension).floorDiv(bankingFactor)}); - Value bankIndex = rewriter.create( - loc, modMap, storeIndex); // assuming one-dim + Value bankIndex = + rewriter.create(loc, modMap, storeIndices); Value offset = - rewriter.create(loc, divMap, storeIndex); + rewriter.create(loc, divMap, storeIndices); + SmallVector newIndices(storeIndices.begin(), storeIndices.end()); + newIndices[bankingDimension] = offset; SmallVector resultTypes = {}; @@ -234,7 +250,7 @@ struct BankAffineStorePattern Region &caseRegion = switchOp.getCaseRegions()[i]; rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock()); rewriter.create( - loc, storeOp.getValueToStore(), banks[i], offset); + loc, storeOp.getValueToStore(), banks[i], newIndices); rewriter.create(loc); } @@ -252,6 +268,7 @@ struct BankAffineStorePattern private: uint64_t bankingFactor; + unsigned bankingDimension; DenseMap> &memoryToBanks; DenseSet &opsToErase; DenseSet &processedOps; @@ -358,17 +375,22 @@ void MemoryBankingPass::runOnOperation() { getOperation().walk([&](mlir::affine::AffineParallelOp parOp) { DenseSet memrefsInPar = collectMemRefs(parOp); - for (auto memrefVal : memrefsInPar) - memoryToBanks[memrefVal] = createBanks(memrefVal, bankingFactor); + for (auto memrefVal : memrefsInPar) { + auto [it, inserted] = + memoryToBanks.insert(std::make_pair(memrefVal, SmallVector{})); + if (inserted) + it->second = createBanks(memrefVal, bankingFactor, bankingDimension); + } }); auto *ctx = &getContext(); RewritePatternSet patterns(ctx); DenseSet processedOps; - patterns.add(ctx, bankingFactor, memoryToBanks); - patterns.add(ctx, bankingFactor, memoryToBanks, - opsToErase, processedOps); + patterns.add(ctx, bankingFactor, bankingDimension, + memoryToBanks); + patterns.add(ctx, bankingFactor, bankingDimension, + memoryToBanks, opsToErase, processedOps); patterns.add(ctx, memoryToBanks); GreedyRewriteConfig config; @@ -390,7 +412,8 @@ void MemoryBankingPass::runOnOperation() { namespace circt { std::unique_ptr -createMemoryBankingPass(std::optional bankingFactor) { - return std::make_unique(bankingFactor); +createMemoryBankingPass(std::optional bankingFactor, + std::optional bankingDimension) { + return std::make_unique(bankingFactor, bankingDimension); } } // namespace circt diff --git a/test/Transforms/memory_banking_multi_dim.mlir b/test/Transforms/memory_banking_multi_dim.mlir new file mode 100644 index 000000000000..d8c32da9089b --- /dev/null +++ b/test/Transforms/memory_banking_multi_dim.mlir @@ -0,0 +1,75 @@ +// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2 dimension=1" | FileCheck %s --check-prefix RANK2-BANKDIM1 + +// RANK2-BANKDIM1: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1 mod 2)> +// RANK2-BANKDIM1: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d1 floordiv 2)> + +// RANK2-BANKDIM1-LABEL: func.func @rank_two_bank_dim1( +// RANK2-BANKDIM1-SAME: %[[VAL_0:arg0]]: memref<8x3xf32>, +// RANK2-BANKDIM1-SAME: %[[VAL_1:arg1]]: memref<8x3xf32>, +// RANK2-BANKDIM1-SAME: %[[VAL_2:arg2]]: memref<8x3xf32>, +// RANK2-BANKDIM1-SAME: %[[VAL_3:arg3]]: memref<8x3xf32>) -> (memref<8x3xf32>, memref<8x3xf32>) { +// RANK2-BANKDIM1: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 +// RANK2-BANKDIM1: %[[VAL_5:.*]] = memref.alloc() : memref<8x3xf32> +// RANK2-BANKDIM1: %[[VAL_6:.*]] = memref.alloc() : memref<8x3xf32> +// RANK2-BANKDIM1: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) { +// RANK2-BANKDIM1: affine.parallel (%[[VAL_8:.*]]) = (0) to (6) { +// RANK2-BANKDIM1: %[[VAL_9:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: %[[VAL_10:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: %[[VAL_11:.*]] = scf.index_switch %[[VAL_9]] -> f32 +// RANK2-BANKDIM1: case 0 { +// RANK2-BANKDIM1: %[[VAL_12:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield %[[VAL_12]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: case 1 { +// RANK2-BANKDIM1: %[[VAL_13:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_10]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield %[[VAL_13]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: default { +// RANK2-BANKDIM1: scf.yield %[[VAL_4]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: %[[VAL_14:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: %[[VAL_15:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: %[[VAL_16:.*]] = scf.index_switch %[[VAL_14]] -> f32 +// RANK2-BANKDIM1: case 0 { +// RANK2-BANKDIM1: %[[VAL_17:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_7]], %[[VAL_15]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield %[[VAL_17]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: case 1 { +// RANK2-BANKDIM1: %[[VAL_18:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_7]], %[[VAL_15]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield %[[VAL_18]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: default { +// RANK2-BANKDIM1: scf.yield %[[VAL_4]] : f32 +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: %[[VAL_19:.*]] = arith.mulf %[[VAL_11]], %[[VAL_16]] : f32 +// RANK2-BANKDIM1: %[[VAL_20:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: %[[VAL_21:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]], %[[VAL_8]]) +// RANK2-BANKDIM1: scf.index_switch %[[VAL_20]] +// RANK2-BANKDIM1: case 0 { +// RANK2-BANKDIM1: affine.store %[[VAL_19]], %[[VAL_5]]{{\[}}%[[VAL_7]], %[[VAL_21]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: case 1 { +// RANK2-BANKDIM1: affine.store %[[VAL_19]], %[[VAL_6]]{{\[}}%[[VAL_7]], %[[VAL_21]]] : memref<8x3xf32> +// RANK2-BANKDIM1: scf.yield +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: default { +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: } +// RANK2-BANKDIM1: return %[[VAL_5]], %[[VAL_6]] : memref<8x3xf32>, memref<8x3xf32> +// RANK2-BANKDIM1: } + +func.func @rank_two_bank_dim1(%arg0: memref<8x6xf32>, %arg1: memref<8x6xf32>) -> (memref<8x6xf32>) { + %mem = memref.alloc() : memref<8x6xf32> + affine.parallel (%i) = (0) to (8) { + affine.parallel (%j) = (0) to (6) { + %1 = affine.load %arg0[%i, %j] : memref<8x6xf32> + %2 = affine.load %arg1[%i, %j] : memref<8x6xf32> + %3 = arith.mulf %1, %2 : f32 + affine.store %3, %mem[%i, %j] : memref<8x6xf32> + } + } + return %mem : memref<8x6xf32> +} +