Skip to content

Commit

Permalink
[FIRRTL] LowerXMR: process all modules (#8168)
Browse files Browse the repository at this point in the history
This changes LowerXMR to process all modules instead of just those that
are reachable from the top level module.
  • Loading branch information
youngar authored Feb 1, 2025
1 parent 6f78c42 commit d070590
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 55 deletions.
114 changes: 59 additions & 55 deletions lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,69 +300,73 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase<LowerXMRPass> {
SmallVector<FModuleOp> publicModules;

// Traverse the modules in post order.
for (auto node : llvm::post_order(&instanceGraph)) {
auto module = dyn_cast<FModuleOp>(*node->getModule());
if (!module)
continue;
LLVM_DEBUG(llvm::dbgs()
<< "Traversing module:" << module.getModuleNameAttr() << "\n");

moduleStates.insert({module, ModuleState(module)});
DenseSet<InstanceGraphNode *> visited;
for (auto *root : instanceGraph) {
for (auto *node : llvm::post_order_ext(root, visited)) {
auto module = dyn_cast<FModuleOp>(*node->getModule());
if (!module)
continue;
LLVM_DEBUG(llvm::dbgs() << "Traversing module:"
<< module.getModuleNameAttr() << "\n");

if (module.isPublic())
publicModules.push_back(module);
moduleStates.insert({module, ModuleState(module)});

auto result = module.walk([&](Operation *op) {
if (transferFunc(op).failed())
return WalkResult::interrupt();
return WalkResult::advance();
});
if (module.isPublic())
publicModules.push_back(module);

if (result.wasInterrupted())
return signalPassFailure();
auto result = module.walk([&](Operation *op) {
if (transferFunc(op).failed())
return WalkResult::interrupt();
return WalkResult::advance();
});

// Clear any enabled layers.
module.setLayersAttr(ArrayAttr::get(module.getContext(), {}));

// Since we walk operations pre-order and not along dataflow edges,
// ref.sub may not be resolvable when we encounter them (they're not just
// unification). This can happen when refs go through an output port or
// input instance result and back into the design. Handle these by walking
// them, resolving what we can, until all are handled or nothing can be
// resolved.
while (!indexingOps.empty()) {
// Grab the set of unresolved ref.sub's.
decltype(indexingOps) worklist;
worklist.swap(indexingOps);

for (auto op : worklist) {
auto inputEntry =
getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false);
// If we can't resolve, add back and move on.
if (!inputEntry)
indexingOps.push_back(op);
else
addReachingSendsEntry(op.getResult(), op.getOperation(),
inputEntry);
}
// If nothing was resolved, give up.
if (worklist.size() == indexingOps.size()) {
auto op = worklist.front();
getRemoteRefSend(op.getInput());
op.emitError(
"indexing through probe of unknown origin (input probe?)")
.attachNote(op.getInput().getLoc())
.append("indexing through this reference");
if (result.wasInterrupted())
return signalPassFailure();
}
}

// Record all the RefType ports to be removed later.
size_t numPorts = module.getNumPorts();
for (size_t portNum = 0; portNum < numPorts; ++portNum)
if (isa<RefType>(module.getPortType(portNum))) {
setPortToRemove(module, portNum, numPorts);
// Clear any enabled layers.
module.setLayersAttr(ArrayAttr::get(module.getContext(), {}));

// Since we walk operations pre-order and not along dataflow edges,
// ref.sub may not be resolvable when we encounter them (they're not
// just unification). This can happen when refs go through an output
// port or input instance result and back into the design. Handle these
// by walking them, resolving what we can, until all are handled or
// nothing can be resolved.
while (!indexingOps.empty()) {
// Grab the set of unresolved ref.sub's.
decltype(indexingOps) worklist;
worklist.swap(indexingOps);

for (auto op : worklist) {
auto inputEntry =
getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false);
// If we can't resolve, add back and move on.
if (!inputEntry)
indexingOps.push_back(op);
else
addReachingSendsEntry(op.getResult(), op.getOperation(),
inputEntry);
}
// If nothing was resolved, give up.
if (worklist.size() == indexingOps.size()) {
auto op = worklist.front();
getRemoteRefSend(op.getInput());
op.emitError(
"indexing through probe of unknown origin (input probe?)")
.attachNote(op.getInput().getLoc())
.append("indexing through this reference");
return signalPassFailure();
}
}

// Record all the RefType ports to be removed later.
size_t numPorts = module.getNumPorts();
for (size_t portNum = 0; portNum < numPorts; ++portNum)
if (isa<RefType>(module.getPortType(portNum))) {
setPortToRemove(module, portNum, numPorts);
}
}
}

LLVM_DEBUG({
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/FIRRTL/lowerXMR.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,26 @@ firrtl.circuit "Foo" {
}
}
}

// -----
// Test that all modules are reached and updated.

// CHECK-LABEL: firrtl.circuit "PF"
firrtl.circuit "PF" {
// CHECK: @Child()
firrtl.module @Child(out %p: !firrtl.probe<uint<1>>) {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%0 = firrtl.ref.send %c1_ui1 : !firrtl.uint<1>
firrtl.ref.define %p, %0 : !firrtl.probe<uint<1>>
}
// CHECK: @PF()
firrtl.module @PF(out %p: !firrtl.probe<uint<1>>) {
%c_p = firrtl.instance c @Child(out p: !firrtl.probe<uint<1>>)
firrtl.ref.define %p, %c_p : !firrtl.probe<uint<1>>
}
// CHECK: @Other()
firrtl.module @Other(out %p: !firrtl.probe<uint<1>>) {
%c_p = firrtl.instance c @Child(out p: !firrtl.probe<uint<1>>)
firrtl.ref.define %p, %c_p : !firrtl.probe<uint<1>>
}
}

0 comments on commit d070590

Please sign in to comment.