Skip to content

Commit

Permalink
MNIST Arch TestCase
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Jan 24, 2025
1 parent 3f984f7 commit 0ec53c5
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 55 deletions.
42 changes: 0 additions & 42 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,41 +103,6 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
pm.addPass(createOperationBalancer());
}

void tosaToArithPipelineBuilder(OpPassManager &pm) {
// TOSA to linalg
::mlir::heir::tosaToLinalg(pm);

// Bufferize
::mlir::heir::oneShotBufferize(pm);

// Affine
pm.addNestedPass<FuncOp>(createConvertLinalgToAffineLoopsPass());
pm.addNestedPass<FuncOp>(memref::createExpandStridedMetadataPass());
pm.addNestedPass<FuncOp>(affine::createAffineExpandIndexOpsPass());
pm.addNestedPass<FuncOp>(memref::createExpandOpsPass());
pm.addPass(createExpandCopyPass());
pm.addNestedPass<FuncOp>(affine::createSimplifyAffineStructuresPass());
pm.addNestedPass<FuncOp>(affine::createAffineLoopNormalizePass(true));
pm.addPass(memref::createFoldMemRefAliasOpsPass());

// Affine loop optimizations
pm.addNestedPass<FuncOp>(
affine::createLoopFusionPass(0, 0, true, affine::FusionMode::Greedy));
pm.addNestedPass<FuncOp>(affine::createAffineLoopNormalizePass(true));
pm.addPass(createForwardStoreToLoad());
pm.addPass(affine::createAffineParallelizePass());
pm.addPass(createFullLoopUnroll());
pm.addPass(createForwardStoreToLoad());
pm.addNestedPass<FuncOp>(createRemoveUnusedMemRef());

// Cleanup
pm.addPass(createMemrefGlobalReplacePass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createSCCPPass());
pm.addPass(createCSEPass());
pm.addPass(createSymbolDCEPass());
}

void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
const RLWEScheme scheme) {
Expand Down Expand Up @@ -283,11 +248,4 @@ RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(const RLWEScheme scheme) {
lattigo::createConfigureCryptoContext(configureCryptoContextOptions));
};
}

void registerTosaToArithPipeline() {
PassPipelineRegistration<>(
"tosa-to-arith", "Arithmetic modules to arith tfhe-rs pipeline.",
[](OpPassManager &pm) { tosaToArithPipelineBuilder(pm); });
}

} // namespace mlir::heir
4 changes: 0 additions & 4 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ struct MlirToRLWEPipelineOptions
using RLWEPipelineBuilder =
std::function<void(OpPassManager &, const MlirToRLWEPipelineOptions &)>;

void tosaToArithPipelineBuilder(OpPassManager &pm);

void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
RLWEScheme scheme);
Expand All @@ -62,8 +60,6 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(RLWEScheme scheme);

RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(RLWEScheme scheme);

void registerTosaToArithPipeline();

} // namespace mlir::heir

#endif // LIB_PIPELINES_ARITHMETICPIPELINEREGISTRATION_H_
7 changes: 7 additions & 0 deletions lib/Pipelines/BooleanPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ struct TosaToBooleanTfheOptions
llvm::cl::init("main")};
};

struct TosaToArithTfheOptions
: public PassPipelineOptions<TosaToArithTfheOptions> {
PassOptions::Option<bool> unroll{
*this, "full-unroll", llvm::cl::desc("Full unroll all loops."),
llvm::cl::init(true)};
};

struct TosaToBooleanJaxiteOptions : public TosaToBooleanTfheOptions {
PassOptions::Option<int> parallelism{
*this, "parallelism",
Expand Down
5 changes: 3 additions & 2 deletions lib/Pipelines/PipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void oneShotBufferize(OpPassManager &manager) {
manager.addPass(createCanonicalizerPass());
}

void tosaPipelineBuilder(OpPassManager &manager) {
void tosaPipelineBuilder(OpPassManager &manager, bool unroll) {
// TOSA to linalg
tosaToLinalg(manager);
// Bufferize
Expand All @@ -78,7 +78,8 @@ void tosaPipelineBuilder(OpPassManager &manager) {
manager.addPass(memref::createFoldMemRefAliasOpsPass());
manager.addPass(createExpandCopyPass());
manager.addPass(createExtractLoopBodyPass());
manager.addPass(createUnrollAndForwardPass());
if (unroll) manager.addPass(createUnrollAndForwardPass());

// Cleanup
manager.addPass(createMemrefGlobalReplacePass());
arith::ArithIntRangeNarrowingOptions options;
Expand Down
2 changes: 1 addition & 1 deletion lib/Pipelines/PipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void tosaToLinalg(OpPassManager &manager);

void oneShotBufferize(OpPassManager &manager);

void tosaPipelineBuilder(OpPassManager &manager);
void tosaPipelineBuilder(OpPassManager &manager, bool unroll);

void polynomialToLLVMPipelineBuilder(OpPassManager &manager);

Expand Down
17 changes: 17 additions & 0 deletions tests/Transforms/tosa_to_tfhe/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
tags_override = {
"mnist_2fc.mlir": [
"nofastbuild",
"notap",
"manual",
],
},
test_file_exts = ["mlir"],
)
24 changes: 24 additions & 0 deletions tests/Transforms/tosa_to_tfhe/mnist_2fc.mlir

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ int main(int argc, char **argv) {
// Register internal pipeline
#endif

registerTosaToArithPipeline();

// Dialect conversion passes in HEIR
mod_arith::registerModArithToArithPasses();
mlir::heir::arith::registerArithToModArithPasses();
Expand All @@ -332,10 +330,13 @@ int main(int argc, char **argv) {
secret::registerBufferizableOpInterfaceExternalModels(registry);
rns::registerExternalRNSTypeInterfaces(registry);

PassPipelineRegistration<>("heir-tosa-to-arith",
"Run passes to lower TOSA models with stripped "
"quant types to arithmetic",
::mlir::heir::tosaPipelineBuilder);
PassPipelineRegistration<TosaToArithTfheOptions>(
"heir-tosa-to-arith",
"Run passes to lower TOSA models with stripped "
"quant types to arithmetic",
[](OpPassManager &pm, const TosaToArithTfheOptions &options) {
::mlir::heir::tosaPipelineBuilder(pm, options.unroll);
});

PassPipelineRegistration<>(
"heir-polynomial-to-llvm",
Expand Down

0 comments on commit 0ec53c5

Please sign in to comment.