CIRCT  19.0.0git
ArcCanonicalizer.cpp
Go to the documentation of this file.
1 //===- ArcCanonicalizer.cpp -------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // Simulation centric canonicalizations for non-arc operations and
9 // canonicalizations that require efficient symbol lookups.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
19 #include "circt/Support/SymCache.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "arc-canonicalizer"
28 
29 namespace circt {
30 namespace arc {
31 #define GEN_PASS_DEF_ARCCANONICALIZER
32 #include "circt/Dialect/Arc/ArcPasses.h.inc"
33 } // namespace arc
34 } // namespace circt
35 
36 using namespace circt;
37 using namespace arc;
38 
39 //===----------------------------------------------------------------------===//
40 // Datastructures
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 
45 /// A combination of SymbolCache and SymbolUserMap that also allows to add users
46 /// and remove symbols on-demand.
47 class SymbolHandler : public SymbolCache {
48 public:
49  /// Return the users of the provided symbol operation.
50  ArrayRef<Operation *> getUsers(Operation *symbol) const {
51  auto it = userMap.find(symbol);
52  return it != userMap.end() ? it->second.getArrayRef() : std::nullopt;
53  }
54 
55  /// Return true if the given symbol has no uses.
56  bool useEmpty(Operation *symbol) {
57  return !userMap.count(symbol) || userMap[symbol].empty();
58  }
59 
60  void addUser(Operation *def, Operation *user) {
61  assert(isa<mlir::SymbolOpInterface>(def));
62  if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
63  symbolCache.insert(
64  {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
65  userMap[def].insert(user);
66  }
67 
68  void removeUser(Operation *def, Operation *user) {
69  assert(isa<mlir::SymbolOpInterface>(def));
70  if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
71  userMap[def].remove(user);
72  if (userMap[def].empty())
73  userMap.erase(def);
74  }
75 
76  void removeDefinitionAndAllUsers(Operation *def) {
77  assert(isa<mlir::SymbolOpInterface>(def));
78  symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
79  userMap.erase(def);
80  }
81 
82  void collectAllSymbolUses(Operation *symbolTableOp,
83  SymbolTableCollection &symbolTable) {
84  // NOTE: the following is almost 1-1 taken from the SymbolUserMap
85  // constructor. They made it difficult to extend the implementation by
86  // having a lot of members private and non-virtual methods.
87  SmallVector<Operation *> symbols;
88  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
89  for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
90  auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
91  assert(symbolUses && "expected uses to be valid");
92 
93  for (const SymbolTable::SymbolUse &use : *symbolUses) {
94  symbols.clear();
95  (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
96  symbols);
97  for (Operation *symbolOp : symbols)
98  userMap[symbolOp].insert(use.getUser());
99  }
100  }
101  };
102  // We just set `allSymUsesVisible` to false here because it isn't necessary
103  // for building the user map.
104  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
105  walkFn);
106  }
107 
108 private:
109  DenseMap<Operation *, SetVector<Operation *>> userMap;
110 };
111 
112 /// A Listener keeping the provided SymbolHandler up-to-date. This is especially
113 /// important for simplifications (e.g. DCE) the rewriter performs automatically
114 /// that we cannot or do not want to turn off.
115 class ArcListener : public mlir::RewriterBase::Listener {
116 public:
117  explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}
118 
119  void notifyOperationReplaced(Operation *op, Operation *replacement) override {
120  // If, e.g., a DefineOp is replaced with another DefineOp but with the same
121  // symbol, we don't want to drop the list of users.
122  auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
123  auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
124  if (symOp && symReplacement &&
125  symOp.getNameAttr() == symReplacement.getNameAttr())
126  return;
127 
128  remove(op);
129  // TODO: if an operation is inserted that defines a symbol and the symbol
130  // already has uses, those users are not added.
131  add(replacement);
132  }
133 
134  void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
135  remove(op);
136  }
137 
138  void notifyOperationErased(Operation *op) override { remove(op); }
139 
140  void notifyOperationInserted(Operation *op,
141  mlir::IRRewriter::InsertPoint) override {
142  // TODO: if an operation is inserted that defines a symbol and the symbol
143  // already has uses, those users are not added.
144  add(op);
145  }
146 
147 private:
148  FailureOr<Operation *> maybeGetDefinition(Operation *op) {
149  if (auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
150  auto symAttr =
151  dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
152  if (!symAttr)
153  return failure();
154  if (auto *def = handler->getDefinition(symAttr.getLeafReference()))
155  return def;
156  }
157  return failure();
158  }
159 
160  void remove(Operation *op) {
161  auto maybeDef = maybeGetDefinition(op);
162  if (!failed(maybeDef))
163  handler->removeUser(*maybeDef, op);
164 
165  if (isa<mlir::SymbolOpInterface>(op))
166  handler->removeDefinitionAndAllUsers(op);
167  }
168 
169  void add(Operation *op) {
170  auto maybeDef = maybeGetDefinition(op);
171  if (!failed(maybeDef))
172  handler->addUser(*maybeDef, op);
173 
174  if (auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
175  handler->addDefinition(defOp.getNameAttr(), op);
176  }
177 
178  SymbolHandler *handler;
179 };
180 
/// Counters gathered by the rewrite patterns and reported as pass statistics.
struct PatternStatistics {
  /// Total number of arc arguments removed by RemoveUnusedArcArgumentsPattern.
  unsigned removeUnusedArcArgumentsPatternNumArgsRemoved{0};
};
184 
185 } // namespace
186 
187 //===----------------------------------------------------------------------===//
188 // Canonicalization patterns
189 //===----------------------------------------------------------------------===//
190 
191 namespace {
192 /// A rewrite pattern that has access to a symbol cache to access and modify the
193 /// symbol-defining op and symbol users as well as a namespace to query new
194 /// names. Each pattern has to make sure that the symbol handler is kept
195 /// up-to-date no matter whether the pattern succeeds of fails.
196 template <typename SourceOp>
197 class SymOpRewritePattern : public OpRewritePattern<SourceOp> {
198 public:
199  SymOpRewritePattern(MLIRContext *ctxt, SymbolHandler &symbolCache,
200  Namespace &names, PatternStatistics &stats,
201  mlir::PatternBenefit benefit = 1,
202  ArrayRef<StringRef> generatedNames = {})
203  : OpRewritePattern<SourceOp>(ctxt, benefit, generatedNames), names(names),
204  symbolCache(symbolCache), statistics(stats) {}
205 
206 protected:
207  Namespace &names;
208  SymbolHandler &symbolCache;
209  PatternStatistics &statistics;
210 };
211 
212 class MemWritePortEnableAndMaskCanonicalizer
213  : public SymOpRewritePattern<MemoryWritePortOp> {
214 public:
215  MemWritePortEnableAndMaskCanonicalizer(
216  MLIRContext *ctxt, SymbolHandler &symbolCache, Namespace &names,
217  PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
218  : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
219  arcMapping(arcMapping) {}
220  LogicalResult matchAndRewrite(MemoryWritePortOp op,
221  PatternRewriter &rewriter) const final;
222 
223 private:
224  DenseMap<StringAttr, StringAttr> &arcMapping;
225 };
226 
227 struct CallPassthroughArc : public SymOpRewritePattern<CallOp> {
228  using SymOpRewritePattern::SymOpRewritePattern;
229  LogicalResult matchAndRewrite(CallOp op,
230  PatternRewriter &rewriter) const final;
231 };
232 
233 struct RemoveUnusedArcs : public SymOpRewritePattern<DefineOp> {
234  using SymOpRewritePattern::SymOpRewritePattern;
235  LogicalResult matchAndRewrite(DefineOp op,
236  PatternRewriter &rewriter) const final;
237 };
238 
239 struct ICMPCanonicalizer : public OpRewritePattern<comb::ICmpOp> {
240  using OpRewritePattern::OpRewritePattern;
241  LogicalResult matchAndRewrite(comb::ICmpOp op,
242  PatternRewriter &rewriter) const final;
243 };
244 
245 struct CompRegCanonicalizer : public OpRewritePattern<seq::CompRegOp> {
246  using OpRewritePattern::OpRewritePattern;
247  LogicalResult matchAndRewrite(seq::CompRegOp op,
248  PatternRewriter &rewriter) const final;
249 };
250 
251 struct RemoveUnusedArcArgumentsPattern : public SymOpRewritePattern<DefineOp> {
252  using SymOpRewritePattern::SymOpRewritePattern;
253  LogicalResult matchAndRewrite(DefineOp op,
254  PatternRewriter &rewriter) const final;
255 };
256 
257 struct SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
258  using SymOpRewritePattern::SymOpRewritePattern;
259  LogicalResult matchAndRewrite(DefineOp op,
260  PatternRewriter &rewriter) const final;
261 };
262 
263 struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
264  using OpRewritePattern::OpRewritePattern;
265  LogicalResult matchAndRewrite(VectorizeOp op,
266  PatternRewriter &rewriter) const final;
267 };
268 
269 struct KeepOneVecOp : public OpRewritePattern<VectorizeOp> {
270  using OpRewritePattern::OpRewritePattern;
271  LogicalResult matchAndRewrite(VectorizeOp op,
272  PatternRewriter &rewriter) const final;
273 };
274 
275 } // namespace
276 
277 //===----------------------------------------------------------------------===//
278 // Helpers
279 //===----------------------------------------------------------------------===//
280 
281 LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
282  SymbolHandler &symbolCache,
283  PatternRewriter &rewriter) {
284  auto defOp = cast<DefineOp>(symbolCache.getDefinition(
285  callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference()));
286  if (defOp.isPassthrough()) {
287  symbolCache.removeUser(defOp, callOp);
288  rewriter.replaceOp(callOp, callOp.getArgOperands());
289  return success();
290  }
291  return failure();
292 }
293 
294 LogicalResult updateInputOperands(VectorizeOp &vecOp,
295  const SmallVector<Value> &newOperands) {
296  // Set the new inputOperandSegments value
297  unsigned groupSize = vecOp.getResults().size();
298  unsigned numOfGroups = newOperands.size() / groupSize;
299  SmallVector<int32_t> newAttr(numOfGroups, groupSize);
300  vecOp.setInputOperandSegments(newAttr);
301  vecOp.getOperation()->setOperands(ValueRange(newOperands));
302  return success();
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Canonicalization pattern implementations
307 //===----------------------------------------------------------------------===//
308 
309 LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
310  MemoryWritePortOp op, PatternRewriter &rewriter) const {
311  auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
312  APInt enable;
313 
314  if (op.getEnable() &&
315  mlir::matchPattern(
316  defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
317  mlir::m_ConstantInt(&enable))) {
318  if (enable.isZero()) {
319  symbolCache.removeUser(defOp, op);
320  rewriter.eraseOp(op);
321  if (symbolCache.useEmpty(defOp)) {
322  symbolCache.removeDefinitionAndAllUsers(defOp);
323  rewriter.eraseOp(defOp);
324  }
325  return success();
326  }
327  if (enable.isAllOnes()) {
328  if (arcMapping.count(defOp.getNameAttr())) {
329  auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
330  // Remove the enable attribute
331  rewriter.modifyOpInPlace(op, [&]() {
332  op.setEnable(false);
333  op.setArc(arcWithoutEnable.getValue());
334  });
335  symbolCache.removeUser(defOp, op);
336  symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
337  return success();
338  }
339 
340  auto newName = names.newName(defOp.getName());
341  auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
342  symbolCache.removeDefinitionAndAllUsers(defOp);
343 
344  // Remove the enable attribute
345  rewriter.modifyOpInPlace(op, [&]() {
346  op.setEnable(false);
347  op.setArc(newName);
348  });
349 
350  auto newResultTypes = op.getArcResultTypes();
351 
352  // Create a new arc that acts as replacement for other users
353  rewriter.setInsertionPoint(defOp);
354  auto newDefOp = rewriter.cloneWithoutRegions(defOp);
355  auto *block = rewriter.createBlock(
356  &newDefOp.getBody(), newDefOp.getBody().end(),
357  newDefOp.getArgumentTypes(),
358  SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
359  auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
360  newName, block->getArguments());
361  SmallVector<Value> results(callOp->getResults());
362  Value constTrue = rewriter.create<hw::ConstantOp>(
363  newDefOp.getLoc(), rewriter.getI1Type(), 1);
364  results.insert(results.begin() + op.getEnableIdx(), constTrue);
365  rewriter.create<OutputOp>(newDefOp.getLoc(), results);
366 
367  // Remove the enable output from the current arc
368  auto *terminator = defOp.getBodyBlock().getTerminator();
369  rewriter.modifyOpInPlace(
370  terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
371  rewriter.modifyOpInPlace(defOp, [&]() {
372  defOp.setName(newName);
373  defOp.setFunctionType(
374  rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
375  });
376 
377  // Update symbol cache
378  symbolCache.addDefinition(defOp.getNameAttr(), defOp);
379  symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
380  symbolCache.addUser(defOp, callOp);
381  for (auto *user : users)
382  symbolCache.addUser(user == op ? defOp : newDefOp, user);
383 
384  arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
385  return success();
386  }
387  }
388  return failure();
389 }
390 
391 LogicalResult
392 CallPassthroughArc::matchAndRewrite(CallOp op,
393  PatternRewriter &rewriter) const {
394  return canonicalizePassthoughCall(op, symbolCache, rewriter);
395 }
396 
397 LogicalResult
398 RemoveUnusedArcs::matchAndRewrite(DefineOp op,
399  PatternRewriter &rewriter) const {
400  if (symbolCache.useEmpty(op)) {
401  op.getBody().walk([&](mlir::CallOpInterface user) {
402  if (auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
403  if (auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
404  symbolCache.removeUser(defOp, user);
405  });
406  symbolCache.removeDefinitionAndAllUsers(op);
407  rewriter.eraseOp(op);
408  return success();
409  }
410  return failure();
411 }
412 
413 LogicalResult
414 ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
415  PatternRewriter &rewriter) const {
416  auto getConstant = [&](const APInt &constant) -> Value {
417  return rewriter.create<hw::ConstantOp>(op.getLoc(), constant);
418  };
419  auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
420  if (llvm::all_equal(types) && !types.empty())
421  if (auto intType = dyn_cast<IntegerType>(*types.begin()))
422  return intType.getWidth();
423  return std::nullopt;
424  };
425  auto negate = [&](Value input) -> Value {
426  auto constTrue = rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(1, 1));
427  return rewriter.create<comb::XorOp>(op.getLoc(), input, constTrue,
428  op.getTwoState());
429  };
430 
431  APInt rhs;
432  if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
433  if (auto concatOp = op.getLhs().getDefiningOp<comb::ConcatOp>()) {
434  if (auto optionalWidth =
435  sameWidthIntegers(concatOp->getOperands().getTypes())) {
436  if ((op.getPredicate() == comb::ICmpPredicate::eq ||
437  op.getPredicate() == comb::ICmpPredicate::ne) &&
438  rhs.isAllOnes()) {
439  Value andOp = rewriter.create<comb::AndOp>(
440  op.getLoc(), concatOp.getInputs(), op.getTwoState());
441  if (*optionalWidth == 1) {
442  if (op.getPredicate() == comb::ICmpPredicate::ne)
443  andOp = negate(andOp);
444  rewriter.replaceOp(op, andOp);
445  return success();
446  }
447  rewriter.replaceOpWithNewOp<comb::ICmpOp>(
448  op, op.getPredicate(), andOp,
449  getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
450  op.getTwoState());
451  return success();
452  }
453 
454  if ((op.getPredicate() == comb::ICmpPredicate::ne ||
455  op.getPredicate() == comb::ICmpPredicate::eq) &&
456  rhs.isZero()) {
457  Value orOp = rewriter.create<comb::OrOp>(
458  op.getLoc(), concatOp.getInputs(), op.getTwoState());
459  if (*optionalWidth == 1) {
460  if (op.getPredicate() == comb::ICmpPredicate::eq)
461  orOp = negate(orOp);
462  rewriter.replaceOp(op, orOp);
463  return success();
464  }
465  rewriter.replaceOpWithNewOp<comb::ICmpOp>(
466  op, op.getPredicate(), orOp,
467  getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
468  op.getTwoState());
469  return success();
470  }
471  }
472  }
473  }
474  return failure();
475 }
476 
477 LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
478  DefineOp op, PatternRewriter &rewriter) const {
479  BitVector toDelete(op.getNumArguments());
480  for (auto [i, arg] : llvm::enumerate(op.getArguments()))
481  if (arg.use_empty())
482  toDelete.set(i);
483 
484  if (toDelete.none())
485  return failure();
486 
487  // Collect the mutable callers in a first iteration. If there is a user that
488  // does not implement the interface, we have to abort the rewrite and have to
489  // make sure that we didn't change anything so far.
490  SmallVector<mlir::CallOpInterface> mutableUsers;
491  for (auto *user : symbolCache.getUsers(op)) {
492  auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
493  if (!callOpMutable)
494  return failure();
495  mutableUsers.push_back(callOpMutable);
496  }
497 
498  // Do the actual rewrites.
499  for (auto user : mutableUsers)
500  for (int i = toDelete.size() - 1; i >= 0; --i)
501  if (toDelete[i])
502  user.getArgOperandsMutable().erase(i);
503 
504  op.eraseArguments(toDelete);
505  op.setFunctionType(
506  rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
507 
508  statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
509  return success();
510 }
511 
512 LogicalResult
513 SinkArcInputsPattern::matchAndRewrite(DefineOp op,
514  PatternRewriter &rewriter) const {
515  // First check that all users implement the interface we need to be able to
516  // modify the users.
517  auto users = symbolCache.getUsers(op);
518  if (llvm::any_of(
519  users, [](auto *user) { return !isa<mlir::CallOpInterface>(user); }))
520  return failure();
521 
522  // Find all arguments that use constant operands only.
523  SmallVector<Operation *> stateConsts(op.getNumArguments());
524  bool first = true;
525  for (auto *user : users) {
526  auto callOp = cast<mlir::CallOpInterface>(user);
527  for (auto [constArg, input] :
528  llvm::zip(stateConsts, callOp.getArgOperands())) {
529  if (auto *constOp = input.getDefiningOp();
530  constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
531  if (first) {
532  constArg = constOp;
533  continue;
534  }
535  if (constArg &&
536  constArg->getName() == input.getDefiningOp()->getName() &&
537  constArg->getAttrDictionary() ==
538  input.getDefiningOp()->getAttrDictionary())
539  continue;
540  }
541  constArg = nullptr;
542  }
543  first = false;
544  }
545 
546  // Move the constants into the arc and erase the block arguments.
547  rewriter.setInsertionPointToStart(&op.getBodyBlock());
548  llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
549  for (auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
550  if (!constArg)
551  continue;
552  auto *inlinedConst = rewriter.clone(*constArg);
553  rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
554  toDelete.set(arg.getArgNumber());
555  }
556  op.getBodyBlock().eraseArguments(toDelete);
557  op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
558  op.getResultTypes()));
559 
560  // Rewrite all arc uses to not pass in the constant anymore.
561  for (auto *user : users) {
562  auto callOp = cast<mlir::CallOpInterface>(user);
563  SmallPtrSet<Value, 4> maybeUnusedValues;
564  SmallVector<Value> newInputs;
565  for (auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
566  if (toDelete[index])
567  maybeUnusedValues.insert(value);
568  else
569  newInputs.push_back(value);
570  }
571  rewriter.modifyOpInPlace(
572  callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
573  for (auto value : maybeUnusedValues)
574  if (value.use_empty())
575  rewriter.eraseOp(value.getDefiningOp());
576  }
577 
578  return success(toDelete.any());
579 }
580 
581 LogicalResult
582 CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op,
583  PatternRewriter &rewriter) const {
584  if (!op.getReset())
585  return failure();
586 
587  // Because Arcilator supports constant zero reset values, skip them.
588  APInt constant;
589  if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
590  if (constant.isZero())
591  return failure();
592 
593  Value newInput = rewriter.create<comb::MuxOp>(
594  op->getLoc(), op.getReset(), op.getResetValue(), op.getInput());
595  rewriter.modifyOpInPlace(op, [&]() {
596  op.getInputMutable().set(newInput);
597  op.getResetMutable().clear();
598  op.getResetValueMutable().clear();
599  });
600 
601  return success();
602 }
603 
604 LogicalResult
605 MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
606  PatternRewriter &rewriter) const {
607  auto &currentBlock = vecOp.getBody().front();
608  IRMapping argMapping;
609  SmallVector<Value> newOperands;
610  SmallVector<VectorizeOp> vecOpsToRemove;
611  bool canBeMerged = false;
612  // Used to calculate the new positions of args after insertions and removals
613  unsigned paddedBy = 0;
614 
615  for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
616  argIdx < numArgs; ++argIdx) {
617  auto inputVec = vecOp.getInputs()[argIdx];
618  // Make sure that the input comes from a `VectorizeOp`
619  // Ensure that the input vector matches the output of the `otherVecOp`
620  // Make sure that the results of the otherVecOp have only one use
621  auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
622  if (!otherVecOp || inputVec != otherVecOp.getResults() ||
623  otherVecOp == vecOp ||
624  !llvm::all_of(otherVecOp.getResults(),
625  [](auto result) { return result.hasOneUse(); })) {
626  newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
627  continue;
628  }
629  // If this flag is set that means we changed the IR so we cannot return
630  // failure
631  canBeMerged = true;
632  newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
633  otherVecOp.getOperands().end());
634 
635  auto &otherBlock = otherVecOp.getBody().front();
636  for (auto &otherArg : otherBlock.getArguments()) {
637  auto newArg = currentBlock.insertArgument(
638  argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
639  argMapping.map(otherArg, newArg);
640  ++paddedBy;
641  }
642 
643  rewriter.setInsertionPointToStart(&currentBlock);
644  for (auto &op : otherBlock.without_terminator())
645  rewriter.clone(op, argMapping);
646 
647  unsigned argNewPos = paddedBy + argIdx;
648  // Get the result of the return value and use it in all places the
649  // the `otherVecOp` results were used
650  auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
651  rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
652  argMapping.lookupOrDefault(retOp.getValue()));
653  currentBlock.eraseArgument(argNewPos);
654  vecOpsToRemove.push_back(otherVecOp);
655  // We erased an arg so the padding decreased by 1
656  paddedBy--;
657  }
658 
659  // We didn't change the IR as there were no vectors to merge
660  if (!canBeMerged)
661  return failure();
662 
663  (void)updateInputOperands(vecOp, newOperands);
664 
665  // Erase dead VectorizeOps
666  for (auto deadOp : vecOpsToRemove)
667  rewriter.eraseOp(deadOp);
668 
669  return success();
670 }
671 
672 namespace llvm {
673 static unsigned hashValue(const SmallVector<Value> &inputs) {
674  unsigned hash = hash_value(inputs.size());
675  for (auto input : inputs)
676  hash = hash_combine(hash, input);
677  return hash;
678 }
679 
680 template <>
681 struct DenseMapInfo<SmallVector<Value>> {
682  static inline SmallVector<Value> getEmptyKey() {
683  return SmallVector<Value>();
684  }
685 
686  static inline SmallVector<Value> getTombstoneKey() {
687  return SmallVector<Value>();
688  }
689 
690  static unsigned getHashValue(const SmallVector<Value> &inputs) {
691  return hashValue(inputs);
692  }
693 
694  static bool isEqual(const SmallVector<Value> &lhs,
695  const SmallVector<Value> &rhs) {
696  return lhs == rhs;
697  }
698 };
699 } // namespace llvm
700 
701 LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
702  PatternRewriter &rewriter) const {
703  BitVector argsToRemove(vecOp.getInputs().size(), false);
704  DenseMap<SmallVector<Value>, unsigned> inExists;
705  auto &currentBlock = vecOp.getBody().front();
706  SmallVector<Value> newOperands;
707  unsigned shuffledBy = 0;
708  bool changed = false;
709  for (auto [argIdx, inputVec] : llvm::enumerate(vecOp.getInputs())) {
710  auto input = SmallVector<Value>(inputVec.begin(), inputVec.end());
711  if (auto in = inExists.find(input); in != inExists.end()) {
712  argsToRemove.set(argIdx);
713  rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx - shuffledBy),
714  currentBlock.getArgument(in->second));
715  currentBlock.eraseArgument(argIdx - shuffledBy);
716  ++shuffledBy;
717  changed = true;
718  continue;
719  }
720  inExists[input] = argIdx;
721  newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
722  }
723 
724  if (!changed)
725  return failure();
726  return updateInputOperands(vecOp, newOperands);
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // ArcCanonicalizerPass implementation
731 //===----------------------------------------------------------------------===//
732 
733 namespace {
734 struct ArcCanonicalizerPass
735  : public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
736  void runOnOperation() override;
737 };
738 } // namespace
739 
740 void ArcCanonicalizerPass::runOnOperation() {
741  MLIRContext &ctxt = getContext();
742  SymbolTableCollection symbolTable;
743  SymbolHandler cache;
744  cache.addDefinitions(getOperation());
745  cache.collectAllSymbolUses(getOperation(), symbolTable);
746  Namespace names;
747  names.add(cache);
748  DenseMap<StringAttr, StringAttr> arcMapping;
749 
750  mlir::GreedyRewriteConfig config;
751  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
752  config.maxIterations = 10;
753  config.useTopDownTraversal = true;
754  ArcListener listener(&cache);
755  config.listener = &listener;
756 
757  PatternStatistics statistics;
758  RewritePatternSet symbolPatterns(&getContext());
759  symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
760  RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
761  &getContext(), cache, names, statistics);
762  symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
763  &getContext(), cache, names, statistics, arcMapping);
764 
765  if (failed(mlir::applyPatternsAndFoldGreedily(
766  getOperation(), std::move(symbolPatterns), config)))
767  return signalPassFailure();
768 
769  numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
770 
771  RewritePatternSet patterns(&ctxt);
772  for (auto *dialect : ctxt.getLoadedDialects())
773  dialect->getCanonicalizationPatterns(patterns);
774  for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
775  op.getCanonicalizationPatterns(patterns, &ctxt);
776  patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
777  KeepOneVecOp>(&getContext());
778 
779  // Don't test for convergence since it is often not reached.
780  (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
781  config);
782 }
783 
784 std::unique_ptr<mlir::Pass> arc::createArcCanonicalizerPass() {
785  return std::make_unique<ArcCanonicalizerPass>();
786 }
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
LogicalResult updateInputOperands(VectorizeOp &vecOp, const SmallVector< Value > &newOperands)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:46
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
SymbolCacheBase::Iterator end() override
Definition: SymCache.h:125
def create(data_type, value)
Definition: hw.py:393
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
Definition: FieldRef.h:92
static unsigned hashValue(const SmallVector< Value > &inputs)
static SmallVector< Value > getTombstoneKey()
static SmallVector< Value > getEmptyKey()
static bool isEqual(const SmallVector< Value > &lhs, const SmallVector< Value > &rhs)
static unsigned getHashValue(const SmallVector< Value > &inputs)