CIRCT  20.0.0git
GroupResetsAndEnables.cpp
Go to the documentation of this file.
1 //===- GroupResetsAndEnables.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "mlir/IR/Dominance.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "arc-group-resets-and-enables"
21 
22 namespace circt {
23 namespace arc {
24 #define GEN_PASS_DEF_GROUPRESETSANDENABLES
25 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 } // namespace arc
27 } // namespace circt
28 
29 using namespace circt;
30 using namespace arc;
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Rewrite Patterns
35 //===----------------------------------------------------------------------===//
36 
37 namespace {
38 
39 struct ResetGroupingPattern : public OpRewritePattern<ClockTreeOp> {
40  using OpRewritePattern::OpRewritePattern;
41  LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp,
42  PatternRewriter &rewriter) const override {
43  // Group similar resets into single IfOps
44  // Create a list of reset values and map from them to the states they reset
45  llvm::MapVector<mlir::Value, SmallVector<scf::IfOp>> resetMap;
46 
47  for (auto ifOp : clockTreeOp.getBody().getOps<scf::IfOp>())
48  if (ifOp.getResults().empty())
49  resetMap[ifOp.getCondition()].push_back(ifOp);
50 
51  // TODO: Check that conflicting memory effects aren't being reordered
52 
53  // Combine IfOps
54  bool changed = false;
55  for (auto &[cond, oldOps] : resetMap) {
56  if (oldOps.size() <= 1)
57  continue;
58  scf::IfOp lastIfOp = oldOps.pop_back_val();
59  for (auto thisOp : oldOps) {
60  // Inline the before and after region inside the original If
61  rewriter.eraseOp(thisOp.thenBlock()->getTerminator());
62  rewriter.inlineBlockBefore(thisOp.thenBlock(),
63  &lastIfOp.thenBlock()->front());
64  // Check we're not inlining an empty block
65  if (auto *elseBlock = thisOp.elseBlock()) {
66  rewriter.eraseOp(elseBlock->getTerminator());
67  if (auto *lastElseBlock = lastIfOp.elseBlock()) {
68  rewriter.inlineBlockBefore(elseBlock,
69  &lastIfOp.elseBlock()->front());
70  } else {
71  lastElseBlock = rewriter.createBlock(&lastIfOp.getElseRegion());
72  rewriter.setInsertionPointToEnd(lastElseBlock);
73  auto yieldOp = rewriter.create<scf::YieldOp>(
74  lastElseBlock->getParentOp()->getLoc());
75  rewriter.inlineBlockBefore(thisOp.elseBlock(), yieldOp);
76  }
77  }
78  rewriter.eraseOp(thisOp);
79  changed = true;
80  }
81  }
82  return success(changed);
83  }
84 };
85 
86 struct EnableGroupingPattern : public OpRewritePattern<ClockTreeOp> {
87  using OpRewritePattern::OpRewritePattern;
88  LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp,
89  PatternRewriter &rewriter) const override {
90  // Amass regions that we want to group enables in
91  SmallVector<Region *> groupingRegions;
92  groupingRegions.push_back(&clockTreeOp.getBody());
93  for (auto ifOp : clockTreeOp.getBody().getOps<scf::IfOp>()) {
94  groupingRegions.push_back(&ifOp.getThenRegion());
95  groupingRegions.push_back(&ifOp.getElseRegion());
96  }
97 
98  bool changed = false;
99  for (auto *region : groupingRegions) {
100  llvm::MapVector<mlir::Value, SmallVector<StateWriteOp>> enableMap;
101  for (auto writeOp : region->getOps<StateWriteOp>()) {
102  if (writeOp.getCondition())
103  enableMap[writeOp.getCondition()].push_back(writeOp);
104  }
105  for (auto &[enable, writeOps] : enableMap) {
106  // Only group if multiple writes share an enable
107  if (writeOps.size() <= 1)
108  continue;
109  if (region->getParentOp()->hasTrait<OpTrait::NoTerminator>())
110  rewriter.setInsertionPointToEnd(&region->back());
111  else
112  rewriter.setInsertionPoint(region->back().getTerminator());
113  scf::IfOp ifOp =
114  rewriter.create<scf::IfOp>(writeOps[0].getLoc(), enable, false);
115  for (auto writeOp : writeOps) {
116  rewriter.modifyOpInPlace(writeOp, [&]() {
117  writeOp->moveBefore(ifOp.thenBlock()->getTerminator());
118  writeOp.getConditionMutable().erase(0);
119  });
120  }
121  changed = true;
122  }
123  }
124  return success(changed);
125  }
126 };
127 
128 /// Where possible without domination issues, group assignments inside IfOps and
129 /// return true if any operations were moved.
130 bool groupInRegion(Block *block, Operation *clockTreeOp,
131  PatternRewriter *rewriter) {
132  bool changed = false;
133  if (!block)
134  return false;
135 
136  SmallVector<Operation *> worklist;
137  // Don't walk as we don't want nested ops in order to restrict to IfOps
138  for (auto &op : block->getOperations()) {
139  worklist.push_back(&op);
140  }
141  while (!worklist.empty()) {
142  Operation *op = worklist.pop_back_val();
143  mlir::DominanceInfo dom(op);
144  for (auto operand : op->getOperands()) {
145  Operation *definition = operand.getDefiningOp();
146  if (definition == nullptr)
147  continue;
148  // Skip if the operand is already defined in this block or is
149  // defined out of the clock tree
150  if (definition->getBlock() == op->getBlock() ||
151  !clockTreeOp->isAncestor(definition))
152  continue;
153  if (llvm::any_of(definition->getUsers(),
154  [&](auto *user) { return !dom.dominates(op, user); }))
155  continue;
156  // For some currently unknown reason, just calling moveBefore
157  // directly has the same output but is much slower
158  rewriter->modifyOpInPlace(definition,
159  [&]() { definition->moveBefore(op); });
160  changed = true;
161  worklist.push_back(definition);
162  }
163  }
164  return changed;
165 }
166 
167 struct GroupAssignmentsInIfPattern : public OpRewritePattern<scf::IfOp> {
168  using OpRewritePattern::OpRewritePattern;
169  LogicalResult matchAndRewrite(scf::IfOp ifOp,
170  PatternRewriter &rewriter) const override {
171  // Pull values only used in certain reset/enable cases into the appropriate
172  // IfOps
173  // Skip anything not in a ClockTreeOp
174  auto clockTreeOp = ifOp->getParentOfType<ClockTreeOp>();
175  if (!clockTreeOp)
176  return failure();
177  // Group assignments in each region and keep track of whether either
178  // grouping made changes
179  bool changed = groupInRegion(ifOp.thenBlock(), clockTreeOp, &rewriter) ||
180  groupInRegion(ifOp.elseBlock(), clockTreeOp, &rewriter);
181  return success(changed);
182  }
183 };
184 
185 } // namespace
186 
187 //===----------------------------------------------------------------------===//
188 // Pass Infrastructure
189 //===----------------------------------------------------------------------===//
190 
191 namespace {
192 struct GroupResetsAndEnablesPass
193  : public arc::impl::GroupResetsAndEnablesBase<GroupResetsAndEnablesPass> {
194 
195  void runOnOperation() override;
196  LogicalResult runOnModel(ModelOp modelOp);
197 };
198 } // namespace
199 
200 void GroupResetsAndEnablesPass::runOnOperation() {
201  for (auto op : getOperation().getOps<ModelOp>())
202  if (failed(runOnModel(op)))
203  return signalPassFailure();
204 }
205 
206 LogicalResult GroupResetsAndEnablesPass::runOnModel(ModelOp modelOp) {
207  LLVM_DEBUG(llvm::dbgs() << "Grouping resets and enables in `"
208  << modelOp.getName() << "`\n");
209 
210  MLIRContext &context = getContext();
211  RewritePatternSet patterns(&context);
212  patterns.add<ResetGroupingPattern, EnableGroupingPattern,
213  GroupAssignmentsInIfPattern>(&context);
214 
215  if (failed(applyPatternsAndFoldGreedily(modelOp, std::move(patterns))))
216  return emitError(modelOp.getLoc(),
217  "GroupResetsAndEnables: greedy rewriter did not converge");
218 
219  return success();
220 }
221 
222 std::unique_ptr<Pass> arc::createGroupResetsAndEnablesPass() {
223  return std::make_unique<GroupResetsAndEnablesPass>();
224 }
std::unique_ptr< mlir::Pass > createGroupResetsAndEnablesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21