11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/IR/Dominance.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Support/Debug.h"
20 #define DEBUG_TYPE "arc-group-resets-and-enables"
24 #define GEN_PASS_DEF_GROUPRESETSANDENABLES
25 #include "circt/Dialect/Arc/ArcPasses.h.inc"
29 using namespace circt;
// ResetGroupingPattern: merges result-less `scf.if` ops that sit in the clock
// tree body and share the same condition value into a single `scf.if` (the
// last one collected for that condition).
//
// NOTE(review): this listing appears to be missing lines from the original
// source: the struct header, the declaration of `changed`, a `continue` after
// the size check, the `else` before the else-block creation, the
// `changed = true` update, and closing braces. Restore them from upstream;
// they are intentionally not reconstructed here.
40 using OpRewritePattern::OpRewritePattern;
41 LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp,
42 PatternRewriter &rewriter)
const override {
// Bucket the result-less scf.if ops of the clock tree body by their condition
// value (llvm::MapVector iterates in insertion order).
45 llvm::MapVector<mlir::Value, SmallVector<scf::IfOp>> resetMap;
47 for (
auto ifOp : clockTreeOp.getBody().getOps<scf::IfOp>())
48 if (ifOp.getResults().empty())
49 resetMap[ifOp.getCondition()].push_back(ifOp);
// For every condition shared by more than one scf.if, fold all earlier ops
// into the last one collected for that condition.
55 for (
auto &[cond, oldOps] : resetMap) {
56 if (oldOps.size() <= 1)
58 scf::IfOp lastIfOp = oldOps.pop_back_val();
59 for (
auto thisOp : oldOps) {
// Drop this op's then-terminator and splice its then-block ops in front of
// the surviving op's current then-block contents.
// NOTE(review): `front()` is re-evaluated each iteration, so each merged body
// is prepended and the merged bodies end up in reverse collection order ahead
// of lastIfOp's own ops; the merged ops also move to lastIfOp's position.
// Confirm no conflicting memory effects are reordered by this.
61 rewriter.eraseOp(thisOp.thenBlock()->getTerminator());
62 rewriter.inlineBlockBefore(thisOp.thenBlock(),
63 &lastIfOp.thenBlock()->front());
// Only an existing else block has anything to move.
65 if (
auto *elseBlock = thisOp.elseBlock()) {
66 rewriter.eraseOp(elseBlock->getTerminator());
67 if (
auto *lastElseBlock = lastIfOp.elseBlock()) {
68 rewriter.inlineBlockBefore(elseBlock,
69 &lastIfOp.elseBlock()->front());
// The surviving op has no else block yet: create one terminated by an
// scf.yield and move this op's else-block ops in front of that yield.
71 lastElseBlock = rewriter.createBlock(&lastIfOp.getElseRegion());
72 rewriter.setInsertionPointToEnd(lastElseBlock);
73 auto yieldOp = rewriter.create<scf::YieldOp>(
74 lastElseBlock->getParentOp()->getLoc());
75 rewriter.inlineBlockBefore(thisOp.elseBlock(), yieldOp);
// Both blocks have been emptied into lastIfOp, so the op itself is erased.
78 rewriter.eraseOp(thisOp);
// NOTE(review): `changed` is neither declared nor set in the visible lines.
82 return success(changed);
// EnableGroupingPattern: in the clock tree body and in the then/else regions
// of its scf.if ops, collects conditional StateWriteOps by condition value
// and moves each group of two or more into a new scf.if on that condition,
// removing the per-write condition operand.
//
// NOTE(review): this listing appears to be missing lines from the original
// source: the struct header, the declaration of `changed`, a `continue` after
// the size check, an `else` between the two insertion-point calls, the
// `scf::IfOp ifOp =` binding of the created op, `changed = true`, and closing
// braces. They are intentionally not reconstructed here.
87 using OpRewritePattern::OpRewritePattern;
88 LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp,
89 PatternRewriter &rewriter)
const override {
// Regions to group in: the clock tree body plus both regions of each scf.if
// found in that body.
91 SmallVector<Region *> groupingRegions;
92 groupingRegions.push_back(&clockTreeOp.getBody());
93 for (
auto ifOp : clockTreeOp.getBody().getOps<scf::IfOp>()) {
94 groupingRegions.push_back(&ifOp.getThenRegion());
95 groupingRegions.push_back(&ifOp.getElseRegion());
99 for (
auto *region : groupingRegions) {
// Bucket this region's conditional writes by their condition value.
100 llvm::MapVector<mlir::Value, SmallVector<StateWriteOp>> enableMap;
101 for (
auto writeOp : region->getOps<StateWriteOp>()) {
102 if (writeOp.getCondition())
103 enableMap[writeOp.getCondition()].push_back(writeOp);
105 for (
auto &[enable, writeOps] : enableMap) {
// Grouping only pays off when at least two writes share a condition.
107 if (writeOps.size() <= 1)
// Place the new scf.if at the end of the region's last block: directly at the
// end when the parent op has the NoTerminator trait, otherwise just before
// that block's terminator.
109 if (region->getParentOp()->hasTrait<OpTrait::NoTerminator>())
110 rewriter.setInsertionPointToEnd(&region->back());
112 rewriter.setInsertionPoint(region->back().getTerminator());
114 rewriter.create<scf::IfOp>(writeOps[0].getLoc(), enable,
false);
// Move each write of the group into the new then-block and drop its
// condition operand; the enclosing scf.if now carries the condition.
115 for (
auto writeOp : writeOps) {
116 rewriter.modifyOpInPlace(writeOp, [&]() {
117 writeOp->moveBefore(ifOp.thenBlock()->getTerminator());
118 writeOp.getConditionMutable().erase(0);
124 return success(changed);
// Pulls operand-defining ops into `block`. For each op in `block` (and,
// transitively, each op moved in), an operand's defining op that lives in a
// different block but still inside `clockTreeOp` is moved directly before its
// user, provided that user dominates every user of the defining op.
//
// Returns `changed`, presumably true when any op was moved.
// NOTE(review): the `continue` after each guard, the `changed = true` update
// and the final `return changed;` are not present in this listing.
// NOTE(review): `block` is dereferenced without a null check, but
// GroupAssignmentsInIfPattern passes `ifOp.elseBlock()`, which can be null for
// an scf.if without an else region — confirm a guard exists.
130 bool groupInRegion(Block *block, Operation *clockTreeOp,
131 PatternRewriter *rewriter) {
132 bool changed =
false;
// Snapshot the block's ops first so that moving ops into the block does not
// disturb the iteration.
136 SmallVector<Operation *> worklist;
138 for (
auto &op : block->getOperations()) {
139 worklist.push_back(&op);
141 while (!worklist.empty()) {
142 Operation *op = worklist.pop_back_val();
// A fresh DominanceInfo is built for every popped op.
143 mlir::DominanceInfo dom(op);
144 for (
auto operand : op->getOperands()) {
145 Operation *definition = operand.getDefiningOp();
// Operands without a defining op (e.g. block arguments) are rejected.
146 if (definition ==
nullptr)
// Reject definitions already in the same block as the user, and definitions
// not nested inside the clock tree.
150 if (definition->getBlock() == op->getBlock() ||
151 !clockTreeOp->isAncestor(definition))
// Reject the move unless `op` dominates every user of `definition`, so the
// moved op still dominates all of its uses.
153 if (llvm::any_of(definition->getUsers(),
154 [&](
auto *user) { return !dom.dominates(op, user); }))
// Move through the rewriter so the pattern driver is notified of the change.
158 rewriter->modifyOpInPlace(definition,
159 [&]() { definition->moveBefore(op); });
// The moved op's own operands are candidates to pull in as well.
161 worklist.push_back(definition);
// GroupAssignmentsInIfPattern: for an scf.if nested in a ClockTreeOp, pulls
// operand-defining ops from elsewhere in the clock tree into the if's then and
// else blocks via groupInRegion.
168 using OpRewritePattern::OpRewritePattern;
169 LogicalResult matchAndRewrite(scf::IfOp ifOp,
170 PatternRewriter &rewriter)
const override {
// NOTE(review): getParentOfType returns a null op when the if is not inside a
// ClockTreeOp, and groupInRegion dereferences it via `isAncestor`; no null
// check is visible in this listing — confirm one exists.
174 auto clockTreeOp = ifOp->getParentOfType<ClockTreeOp>();
// `||` short-circuits: if the then block changed, the else block is not
// processed in this invocation.
// NOTE(review): `ifOp.elseBlock()` can be null when there is no else region.
179 bool changed = groupInRegion(ifOp.thenBlock(), clockTreeOp, &rewriter) ||
180 groupInRegion(ifOp.elseBlock(), clockTreeOp, &rewriter);
181 return success(changed);
192 struct GroupResetsAndEnablesPass
193 :
public arc::impl::GroupResetsAndEnablesBase<GroupResetsAndEnablesPass> {
195 void runOnOperation()
override;
196 LogicalResult runOnModel(ModelOp modelOp);
200 void GroupResetsAndEnablesPass::runOnOperation() {
201 for (
auto op : getOperation().getOps<ModelOp>())
202 if (failed(runOnModel(op)))
203 return signalPassFailure();
206 LogicalResult GroupResetsAndEnablesPass::runOnModel(ModelOp modelOp) {
207 LLVM_DEBUG(llvm::dbgs() <<
"Grouping resets and enables in `"
208 << modelOp.getName() <<
"`\n");
210 MLIRContext &context = getContext();
211 RewritePatternSet
patterns(&context);
212 patterns.add<ResetGroupingPattern, EnableGroupingPattern,
213 GroupAssignmentsInIfPattern>(&context);
215 if (failed(applyPatternsAndFoldGreedily(modelOp, std::move(
patterns))))
216 return emitError(modelOp.getLoc(),
217 "GroupResetsAndEnables: greedy rewriter did not converge");
// Factory for GroupResetsAndEnablesPass: returns a newly allocated instance.
// NOTE(review): this listing shows the `return` statement before the
// function's signature and omits the braces and the namespace qualification
// used by the pass header. Reassemble as
// `std::unique_ptr<mlir::Pass> <ns>::createGroupResetsAndEnablesPass() { ... }`
// with the qualification the header declares — not visible here.
223 return std::make_unique<GroupResetsAndEnablesPass>();
std::unique_ptr< mlir::Pass > createGroupResetsAndEnablesPass()
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.