Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ArcCanonicalizer.cpp
Go to the documentation of this file.
1//===- ArcCanonicalizer.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// Simulation centric canonicalizations for non-arc operations and
9// canonicalizations that require efficient symbol lookups.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "llvm/Support/Debug.h"
26
27#define DEBUG_TYPE "arc-canonicalizer"
28
29namespace circt {
30namespace arc {
31#define GEN_PASS_DEF_ARCCANONICALIZER
32#include "circt/Dialect/Arc/ArcPasses.h.inc"
33} // namespace arc
34} // namespace circt
35
36using namespace circt;
37using namespace arc;
38
39//===----------------------------------------------------------------------===//
40// Datastructures
41//===----------------------------------------------------------------------===//
42
43namespace {
44
45/// A combination of SymbolCache and SymbolUserMap that also allows to add users
46/// and remove symbols on-demand.
47class SymbolHandler : public SymbolCache {
48public:
49 /// Return the users of the provided symbol operation.
50 ArrayRef<Operation *> getUsers(Operation *symbol) const {
51 auto it = userMap.find(symbol);
52 return it != userMap.end() ? it->second.getArrayRef()
53 : ArrayRef<Operation *>();
54 }
55
56 /// Return true if the given symbol has no uses.
57 bool useEmpty(Operation *symbol) {
58 return !userMap.count(symbol) || userMap[symbol].empty();
59 }
60
61 void addUser(Operation *def, Operation *user) {
62 assert(isa<mlir::SymbolOpInterface>(def));
63 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
64 symbolCache.insert(
65 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
66 userMap[def].insert(user);
67 }
68
69 void removeUser(Operation *def, Operation *user) {
70 assert(isa<mlir::SymbolOpInterface>(def));
71 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
72 userMap[def].remove(user);
73 if (userMap[def].empty())
74 userMap.erase(def);
75 }
76
77 void removeDefinitionAndAllUsers(Operation *def) {
78 assert(isa<mlir::SymbolOpInterface>(def));
79 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
80 userMap.erase(def);
81 }
82
83 void collectAllSymbolUses(Operation *symbolTableOp,
84 SymbolTableCollection &symbolTable) {
85 // NOTE: the following is almost 1-1 taken from the SymbolUserMap
86 // constructor. They made it difficult to extend the implementation by
87 // having a lot of members private and non-virtual methods.
88 SmallVector<Operation *> symbols;
89 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
90 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
91 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
92 assert(symbolUses && "expected uses to be valid");
93
94 for (const SymbolTable::SymbolUse &use : *symbolUses) {
95 symbols.clear();
96 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
97 symbols);
98 for (Operation *symbolOp : symbols)
99 userMap[symbolOp].insert(use.getUser());
100 }
101 }
102 };
103 // We just set `allSymUsesVisible` to false here because it isn't necessary
104 // for building the user map.
105 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
106 walkFn);
107 }
108
109private:
110 DenseMap<Operation *, SetVector<Operation *>> userMap;
111};
112
113/// A Listener keeping the provided SymbolHandler up-to-date. This is especially
114/// important for simplifications (e.g. DCE) the rewriter performs automatically
115/// that we cannot or do not want to turn off.
116class ArcListener : public mlir::RewriterBase::Listener {
117public:
118 explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}
119
120 void notifyOperationReplaced(Operation *op, Operation *replacement) override {
121 // If, e.g., a DefineOp is replaced with another DefineOp but with the same
122 // symbol, we don't want to drop the list of users.
123 auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
124 auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
125 if (symOp && symReplacement &&
126 symOp.getNameAttr() == symReplacement.getNameAttr())
127 return;
128
129 remove(op);
130 // TODO: if an operation is inserted that defines a symbol and the symbol
131 // already has uses, those users are not added.
132 add(replacement);
133 }
134
135 void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
136 remove(op);
137 }
138
139 void notifyOperationErased(Operation *op) override { remove(op); }
140
141 void notifyOperationInserted(Operation *op,
142 mlir::IRRewriter::InsertPoint) override {
143 // TODO: if an operation is inserted that defines a symbol and the symbol
144 // already has uses, those users are not added.
145 add(op);
146 }
147
148private:
149 FailureOr<Operation *> maybeGetDefinition(Operation *op) {
150 if (auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
151 auto symAttr =
152 dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
153 if (!symAttr)
154 return failure();
155 if (auto *def = handler->getDefinition(symAttr.getLeafReference()))
156 return def;
157 }
158 return failure();
159 }
160
161 void remove(Operation *op) {
162 auto maybeDef = maybeGetDefinition(op);
163 if (!failed(maybeDef))
164 handler->removeUser(*maybeDef, op);
165
166 if (isa<mlir::SymbolOpInterface>(op))
167 handler->removeDefinitionAndAllUsers(op);
168 }
169
170 void add(Operation *op) {
171 auto maybeDef = maybeGetDefinition(op);
172 if (!failed(maybeDef))
173 handler->addUser(*maybeDef, op);
174
175 if (auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
176 handler->addDefinition(defOp.getNameAttr(), op);
177 }
178
179 SymbolHandler *handler;
180};
181
/// Counters shared between the patterns and the pass that reports them.
struct PatternStatistics {
  /// Number of arc arguments erased by RemoveUnusedArcArgumentsPattern.
  unsigned removeUnusedArcArgumentsPatternNumArgsRemoved{0};
};
185
186} // namespace
187
188//===----------------------------------------------------------------------===//
189// Canonicalization patterns
190//===----------------------------------------------------------------------===//
191
192namespace {
193/// A rewrite pattern that has access to a symbol cache to access and modify the
194/// symbol-defining op and symbol users as well as a namespace to query new
195/// names. Each pattern has to make sure that the symbol handler is kept
196/// up-to-date no matter whether the pattern succeeds of fails.
197template <typename SourceOp>
198class SymOpRewritePattern : public OpRewritePattern<SourceOp> {
199public:
200 SymOpRewritePattern(MLIRContext *ctxt, SymbolHandler &symbolCache,
201 Namespace &names, PatternStatistics &stats,
202 mlir::PatternBenefit benefit = 1,
203 ArrayRef<StringRef> generatedNames = {})
204 : OpRewritePattern<SourceOp>(ctxt, benefit, generatedNames), names(names),
205 symbolCache(symbolCache), statistics(stats) {}
206
207protected:
208 Namespace &names;
209 SymbolHandler &symbolCache;
210 PatternStatistics &statistics;
211};
212
213class MemWritePortEnableAndMaskCanonicalizer
214 : public SymOpRewritePattern<MemoryWritePortOp> {
215public:
216 MemWritePortEnableAndMaskCanonicalizer(
217 MLIRContext *ctxt, SymbolHandler &symbolCache, Namespace &names,
218 PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
219 : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
220 arcMapping(arcMapping) {}
221 LogicalResult matchAndRewrite(MemoryWritePortOp op,
222 PatternRewriter &rewriter) const final;
223
224private:
225 DenseMap<StringAttr, StringAttr> &arcMapping;
226};
227
228struct CallPassthroughArc : public SymOpRewritePattern<CallOp> {
229 using SymOpRewritePattern::SymOpRewritePattern;
230 LogicalResult matchAndRewrite(CallOp op,
231 PatternRewriter &rewriter) const final;
232};
233
234struct RemoveUnusedArcs : public SymOpRewritePattern<DefineOp> {
235 using SymOpRewritePattern::SymOpRewritePattern;
236 LogicalResult matchAndRewrite(DefineOp op,
237 PatternRewriter &rewriter) const final;
238};
239
240struct ICMPCanonicalizer : public OpRewritePattern<comb::ICmpOp> {
241 using OpRewritePattern::OpRewritePattern;
242 LogicalResult matchAndRewrite(comb::ICmpOp op,
243 PatternRewriter &rewriter) const final;
244};
245
246struct CompRegCanonicalizer : public OpRewritePattern<seq::CompRegOp> {
247 using OpRewritePattern::OpRewritePattern;
248 LogicalResult matchAndRewrite(seq::CompRegOp op,
249 PatternRewriter &rewriter) const final;
250};
251
252struct RemoveUnusedArcArgumentsPattern : public SymOpRewritePattern<DefineOp> {
253 using SymOpRewritePattern::SymOpRewritePattern;
254 LogicalResult matchAndRewrite(DefineOp op,
255 PatternRewriter &rewriter) const final;
256};
257
258struct SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
259 using SymOpRewritePattern::SymOpRewritePattern;
260 LogicalResult matchAndRewrite(DefineOp op,
261 PatternRewriter &rewriter) const final;
262};
263
264struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
265 using OpRewritePattern::OpRewritePattern;
266 LogicalResult matchAndRewrite(VectorizeOp op,
267 PatternRewriter &rewriter) const final;
268};
269
270struct KeepOneVecOp : public OpRewritePattern<VectorizeOp> {
271 using OpRewritePattern::OpRewritePattern;
272 LogicalResult matchAndRewrite(VectorizeOp op,
273 PatternRewriter &rewriter) const final;
274};
275
276} // namespace
277
278//===----------------------------------------------------------------------===//
279// Helpers
280//===----------------------------------------------------------------------===//
281
282LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
283 SymbolHandler &symbolCache,
284 PatternRewriter &rewriter) {
285 auto defOp = cast<DefineOp>(symbolCache.getDefinition(
286 llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
287 .getLeafReference()));
288 if (defOp.isPassthrough()) {
289 symbolCache.removeUser(defOp, callOp);
290 rewriter.replaceOp(callOp, callOp.getArgOperands());
291 return success();
292 }
293 return failure();
294}
295
296LogicalResult updateInputOperands(VectorizeOp &vecOp,
297 const SmallVector<Value> &newOperands) {
298 // Set the new inputOperandSegments value
299 unsigned groupSize = vecOp.getResults().size();
300 unsigned numOfGroups = newOperands.size() / groupSize;
301 SmallVector<int32_t> newAttr(numOfGroups, groupSize);
302 vecOp.setInputOperandSegments(newAttr);
303 vecOp.getOperation()->setOperands(ValueRange(newOperands));
304 return success();
305}
306
307//===----------------------------------------------------------------------===//
308// Canonicalization pattern implementations
309//===----------------------------------------------------------------------===//
310
311LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
312 MemoryWritePortOp op, PatternRewriter &rewriter) const {
313 auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
314 APInt enable;
315
316 if (op.getEnable() &&
317 mlir::matchPattern(
318 defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
319 mlir::m_ConstantInt(&enable))) {
320 if (enable.isZero()) {
321 symbolCache.removeUser(defOp, op);
322 rewriter.eraseOp(op);
323 if (symbolCache.useEmpty(defOp)) {
324 symbolCache.removeDefinitionAndAllUsers(defOp);
325 rewriter.eraseOp(defOp);
326 }
327 return success();
328 }
329 if (enable.isAllOnes()) {
330 if (arcMapping.count(defOp.getNameAttr())) {
331 auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
332 // Remove the enable attribute
333 rewriter.modifyOpInPlace(op, [&]() {
334 op.setEnable(false);
335 op.setArc(arcWithoutEnable.getValue());
336 });
337 symbolCache.removeUser(defOp, op);
338 symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
339 return success();
340 }
341
342 auto newName = names.newName(defOp.getName());
343 auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
344 symbolCache.removeDefinitionAndAllUsers(defOp);
345
346 // Remove the enable attribute
347 rewriter.modifyOpInPlace(op, [&]() {
348 op.setEnable(false);
349 op.setArc(newName);
350 });
351
352 auto newResultTypes = op.getArcResultTypes();
353
354 // Create a new arc that acts as replacement for other users
355 rewriter.setInsertionPoint(defOp);
356 auto newDefOp = rewriter.cloneWithoutRegions(defOp);
357 auto *block = rewriter.createBlock(
358 &newDefOp.getBody(), newDefOp.getBody().end(),
359 newDefOp.getArgumentTypes(),
360 SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
361 auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
362 newName, block->getArguments());
363 SmallVector<Value> results(callOp->getResults());
364 Value constTrue = rewriter.create<hw::ConstantOp>(
365 newDefOp.getLoc(), rewriter.getI1Type(), 1);
366 results.insert(results.begin() + op.getEnableIdx(), constTrue);
367 rewriter.create<OutputOp>(newDefOp.getLoc(), results);
368
369 // Remove the enable output from the current arc
370 auto *terminator = defOp.getBodyBlock().getTerminator();
371 rewriter.modifyOpInPlace(
372 terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
373 rewriter.modifyOpInPlace(defOp, [&]() {
374 defOp.setName(newName);
375 defOp.setFunctionType(
376 rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
377 });
378
379 // Update symbol cache
380 symbolCache.addDefinition(defOp.getNameAttr(), defOp);
381 symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
382 symbolCache.addUser(defOp, callOp);
383 for (auto *user : users)
384 symbolCache.addUser(user == op ? defOp : newDefOp, user);
385
386 arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
387 return success();
388 }
389 }
390 return failure();
391}
392
393LogicalResult
394CallPassthroughArc::matchAndRewrite(CallOp op,
395 PatternRewriter &rewriter) const {
396 return canonicalizePassthoughCall(op, symbolCache, rewriter);
397}
398
399LogicalResult
400RemoveUnusedArcs::matchAndRewrite(DefineOp op,
401 PatternRewriter &rewriter) const {
402 if (symbolCache.useEmpty(op)) {
403 op.getBody().walk([&](mlir::CallOpInterface user) {
404 if (auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
405 if (auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
406 symbolCache.removeUser(defOp, user);
407 });
408 symbolCache.removeDefinitionAndAllUsers(op);
409 rewriter.eraseOp(op);
410 return success();
411 }
412 return failure();
413}
414
415LogicalResult
416ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
417 PatternRewriter &rewriter) const {
418 auto getConstant = [&](const APInt &constant) -> Value {
419 return rewriter.create<hw::ConstantOp>(op.getLoc(), constant);
420 };
421 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
422 if (llvm::all_equal(types) && !types.empty())
423 if (auto intType = dyn_cast<IntegerType>(*types.begin()))
424 return intType.getWidth();
425 return std::nullopt;
426 };
427 auto negate = [&](Value input) -> Value {
428 auto constTrue = rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(1, 1));
429 return rewriter.create<comb::XorOp>(op.getLoc(), input, constTrue,
430 op.getTwoState());
431 };
432
433 APInt rhs;
434 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
435 if (auto concatOp = op.getLhs().getDefiningOp<comb::ConcatOp>()) {
436 if (auto optionalWidth =
437 sameWidthIntegers(concatOp->getOperands().getTypes())) {
438 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
439 op.getPredicate() == comb::ICmpPredicate::ne) &&
440 rhs.isAllOnes()) {
441 Value andOp = rewriter.create<comb::AndOp>(
442 op.getLoc(), concatOp.getInputs(), op.getTwoState());
443 if (*optionalWidth == 1) {
444 if (op.getPredicate() == comb::ICmpPredicate::ne)
445 andOp = negate(andOp);
446 rewriter.replaceOp(op, andOp);
447 return success();
448 }
449 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
450 op, op.getPredicate(), andOp,
451 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
452 /*isSigned=*/false, /*implicitTrunc=*/true)),
453 op.getTwoState());
454 return success();
455 }
456
457 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
458 op.getPredicate() == comb::ICmpPredicate::eq) &&
459 rhs.isZero()) {
460 Value orOp = rewriter.create<comb::OrOp>(
461 op.getLoc(), concatOp.getInputs(), op.getTwoState());
462 if (*optionalWidth == 1) {
463 if (op.getPredicate() == comb::ICmpPredicate::eq)
464 orOp = negate(orOp);
465 rewriter.replaceOp(op, orOp);
466 return success();
467 }
468 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
469 op, op.getPredicate(), orOp,
470 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
471 /*isSigned=*/false, /*implicitTrunc=*/true)),
472 op.getTwoState());
473 return success();
474 }
475 }
476 }
477 }
478 return failure();
479}
480
481LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
482 DefineOp op, PatternRewriter &rewriter) const {
483 BitVector toDelete(op.getNumArguments());
484 for (auto [i, arg] : llvm::enumerate(op.getArguments()))
485 if (arg.use_empty())
486 toDelete.set(i);
487
488 if (toDelete.none())
489 return failure();
490
491 // Collect the mutable callers in a first iteration. If there is a user that
492 // does not implement the interface, we have to abort the rewrite and have to
493 // make sure that we didn't change anything so far.
494 SmallVector<mlir::CallOpInterface> mutableUsers;
495 for (auto *user : symbolCache.getUsers(op)) {
496 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
497 if (!callOpMutable)
498 return failure();
499 mutableUsers.push_back(callOpMutable);
500 }
501
502 // Do the actual rewrites.
503 for (auto user : mutableUsers)
504 for (int i = toDelete.size() - 1; i >= 0; --i)
505 if (toDelete[i])
506 user.getArgOperandsMutable().erase(i);
507
508 if (failed(op.eraseArguments(toDelete)))
509 return failure();
510 op.setFunctionType(
511 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
512
513 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
514 return success();
515}
516
517LogicalResult
518SinkArcInputsPattern::matchAndRewrite(DefineOp op,
519 PatternRewriter &rewriter) const {
520 // First check that all users implement the interface we need to be able to
521 // modify the users.
522 auto users = symbolCache.getUsers(op);
523 if (llvm::any_of(
524 users, [](auto *user) { return !isa<mlir::CallOpInterface>(user); }))
525 return failure();
526
527 // Find all arguments that use constant operands only.
528 SmallVector<Operation *> stateConsts(op.getNumArguments());
529 bool first = true;
530 for (auto *user : users) {
531 auto callOp = cast<mlir::CallOpInterface>(user);
532 for (auto [constArg, input] :
533 llvm::zip(stateConsts, callOp.getArgOperands())) {
534 if (auto *constOp = input.getDefiningOp();
535 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
536 if (first) {
537 constArg = constOp;
538 continue;
539 }
540 if (constArg &&
541 constArg->getName() == input.getDefiningOp()->getName() &&
542 constArg->getAttrDictionary() ==
543 input.getDefiningOp()->getAttrDictionary())
544 continue;
545 }
546 constArg = nullptr;
547 }
548 first = false;
549 }
550
551 // Move the constants into the arc and erase the block arguments.
552 rewriter.setInsertionPointToStart(&op.getBodyBlock());
553 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
554 for (auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
555 if (!constArg)
556 continue;
557 auto *inlinedConst = rewriter.clone(*constArg);
558 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
559 toDelete.set(arg.getArgNumber());
560 }
561 op.getBodyBlock().eraseArguments(toDelete);
562 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
563 op.getResultTypes()));
564
565 // Rewrite all arc uses to not pass in the constant anymore.
566 for (auto *user : users) {
567 auto callOp = cast<mlir::CallOpInterface>(user);
568 SmallPtrSet<Value, 4> maybeUnusedValues;
569 SmallVector<Value> newInputs;
570 for (auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
571 if (toDelete[index])
572 maybeUnusedValues.insert(value);
573 else
574 newInputs.push_back(value);
575 }
576 rewriter.modifyOpInPlace(
577 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
578 for (auto value : maybeUnusedValues)
579 if (value.use_empty())
580 rewriter.eraseOp(value.getDefiningOp());
581 }
582
583 return success(toDelete.any());
584}
585
586LogicalResult
587CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op,
588 PatternRewriter &rewriter) const {
589 if (!op.getReset())
590 return failure();
591
592 // Because Arcilator supports constant zero reset values, skip them.
593 APInt constant;
594 if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
595 if (constant.isZero())
596 return failure();
597
598 Value newInput = rewriter.create<comb::MuxOp>(
599 op->getLoc(), op.getReset(), op.getResetValue(), op.getInput());
600 rewriter.modifyOpInPlace(op, [&]() {
601 op.getInputMutable().set(newInput);
602 op.getResetMutable().clear();
603 op.getResetValueMutable().clear();
604 });
605
606 return success();
607}
608
609LogicalResult
610MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
611 PatternRewriter &rewriter) const {
612 auto &currentBlock = vecOp.getBody().front();
613 IRMapping argMapping;
614 SmallVector<Value> newOperands;
615 SmallVector<VectorizeOp> vecOpsToRemove;
616 bool canBeMerged = false;
617 // Used to calculate the new positions of args after insertions and removals
618 unsigned paddedBy = 0;
619
620 for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
621 argIdx < numArgs; ++argIdx) {
622 auto inputVec = vecOp.getInputs()[argIdx];
623 // Make sure that the input comes from a `VectorizeOp`
624 // Ensure that the input vector matches the output of the `otherVecOp`
625 // Make sure that the results of the otherVecOp have only one use
626 auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
627 if (!otherVecOp || otherVecOp == vecOp ||
628 !llvm::all_of(otherVecOp.getResults(),
629 [](auto result) { return result.hasOneUse(); }) ||
630 !llvm::all_of(inputVec, [&](auto result) {
631 return result.template getDefiningOp<VectorizeOp>() == otherVecOp;
632 })) {
633 newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
634 continue;
635 }
636
637 // Here, all elements are from the same `VectorizeOp`.
638 // If all elements of the input vector come from the same `VectorizeOp`
639 // sort the vectors by their indices
640 DenseMap<Value, size_t> resultIdxMap;
641 for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults()))
642 resultIdxMap[result] = resultIdx;
643
644 SmallVector<Value> tempVec(inputVec.begin(), inputVec.end());
645 llvm::sort(tempVec, [&](Value a, Value b) {
646 return resultIdxMap[a] < resultIdxMap[b];
647 });
648
649 // Check if inputVec matches the result after sorting.
650 if (tempVec != SmallVector<Value>(otherVecOp.getResults().begin(),
651 otherVecOp.getResults().end())) {
652 newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
653 continue;
654 }
655
656 DenseMap<size_t, size_t> fromRealIdxToSortedIdx;
657 for (auto [inIdx, in] : llvm::enumerate(inputVec))
658 fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in];
659
660 // If this flag is set that means we changed the IR so we cannot return
661 // failure
662 canBeMerged = true;
663
664 // If the results got shuffled, then shuffle the operands before merging.
665 if (inputVec != otherVecOp.getResults()) {
666 for (auto otherVecOpInputVec : otherVecOp.getInputs()) {
667 // use the tempVec again instead of creating another one.
668 tempVec = SmallVector<Value>(inputVec.size());
669 for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec))
670 tempVec[realIdx] =
671 otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]];
672
673 newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end());
674 }
675
676 } else
677 newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
678 otherVecOp.getOperands().end());
679
680 auto &otherBlock = otherVecOp.getBody().front();
681 for (auto &otherArg : otherBlock.getArguments()) {
682 auto newArg = currentBlock.insertArgument(
683 argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
684 argMapping.map(otherArg, newArg);
685 ++paddedBy;
686 }
687
688 rewriter.setInsertionPointToStart(&currentBlock);
689 for (auto &op : otherBlock.without_terminator())
690 rewriter.clone(op, argMapping);
691
692 unsigned argNewPos = paddedBy + argIdx;
693 // Get the result of the return value and use it in all places the
694 // the `otherVecOp` results were used
695 auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
696 rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
697 argMapping.lookupOrDefault(retOp.getValue()));
698 currentBlock.eraseArgument(argNewPos);
699 vecOpsToRemove.push_back(otherVecOp);
700 // We erased an arg so the padding decreased by 1
701 paddedBy--;
702 }
703
704 // We didn't change the IR as there were no vectors to merge
705 if (!canBeMerged)
706 return failure();
707
708 (void)updateInputOperands(vecOp, newOperands);
709
710 // Erase dead VectorizeOps
711 for (auto deadOp : vecOpsToRemove)
712 rewriter.eraseOp(deadOp);
713
714 return success();
715}
716
717namespace llvm {
718static unsigned hashValue(const SmallVector<Value> &inputs) {
719 unsigned hash = hash_value(inputs.size());
720 for (auto input : inputs)
721 hash = hash_combine(hash, input);
722 return hash;
723}
724
725template <>
726struct DenseMapInfo<SmallVector<Value>> {
727 static inline SmallVector<Value> getEmptyKey() {
728 return SmallVector<Value>();
729 }
730
731 static inline SmallVector<Value> getTombstoneKey() {
732 return SmallVector<Value>();
733 }
734
735 static unsigned getHashValue(const SmallVector<Value> &inputs) {
736 return hashValue(inputs);
737 }
738
739 static bool isEqual(const SmallVector<Value> &lhs,
740 const SmallVector<Value> &rhs) {
741 return lhs == rhs;
742 }
743};
744} // namespace llvm
745
746LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
747 PatternRewriter &rewriter) const {
748 DenseMap<SmallVector<Value>, unsigned> inExists;
749 auto &currentBlock = vecOp.getBody().front();
750 SmallVector<Value> newOperands;
751 BitVector argsToRemove(vecOp.getInputs().size(), false);
752 for (size_t argIdx = 0; argIdx < vecOp.getInputs().size(); ++argIdx) {
753 auto input = SmallVector<Value>(vecOp.getInputs()[argIdx].begin(),
754 vecOp.getInputs()[argIdx].end());
755 if (auto in = inExists.find(input); in != inExists.end()) {
756 rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx),
757 currentBlock.getArgument(in->second));
758 argsToRemove.set(argIdx);
759 continue;
760 }
761 inExists[input] = argIdx;
762 newOperands.insert(newOperands.end(), input.begin(), input.end());
763 }
764
765 if (argsToRemove.none())
766 return failure();
767
768 currentBlock.eraseArguments(argsToRemove);
769 return updateInputOperands(vecOp, newOperands);
770}
771
772//===----------------------------------------------------------------------===//
773// ArcCanonicalizerPass implementation
774//===----------------------------------------------------------------------===//
775
776namespace {
777struct ArcCanonicalizerPass
778 : public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
779 void runOnOperation() override;
780};
781} // namespace
782
783void ArcCanonicalizerPass::runOnOperation() {
784 MLIRContext &ctxt = getContext();
785 SymbolTableCollection symbolTable;
786 SymbolHandler cache;
787 cache.addDefinitions(getOperation());
788 cache.collectAllSymbolUses(getOperation(), symbolTable);
789 Namespace names;
790 names.add(cache);
791 DenseMap<StringAttr, StringAttr> arcMapping;
792
793 mlir::GreedyRewriteConfig config;
794 config.setRegionSimplificationLevel(
795 mlir::GreedySimplifyRegionLevel::Disabled);
796 config.setMaxIterations(10);
797 config.setUseTopDownTraversal(true);
798 ArcListener listener(&cache);
799 config.setListener(&listener);
800
801 PatternStatistics statistics;
802 RewritePatternSet symbolPatterns(&getContext());
803 symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
804 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
805 &getContext(), cache, names, statistics);
806 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
807 &getContext(), cache, names, statistics, arcMapping);
808
809 if (failed(mlir::applyPatternsGreedily(getOperation(),
810 std::move(symbolPatterns), config)))
811 return signalPassFailure();
812
813 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
814
815 RewritePatternSet patterns(&ctxt);
816 for (auto *dialect : ctxt.getLoadedDialects())
817 dialect->getCanonicalizationPatterns(patterns);
818 for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
819 op.getCanonicalizationPatterns(patterns, &ctxt);
820 patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
821 KeepOneVecOp>(&getContext());
822
823 // Don't test for convergence since it is often not reached.
824 (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns),
825 config);
826}
827
828std::unique_ptr<mlir::Pass> arc::createArcCanonicalizerPass() {
829 return std::make_unique<ArcCanonicalizerPass>();
830}
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
LogicalResult updateInputOperands(VectorizeOp &vecOp, const SmallVector< Value > &newOperands)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
SymbolCacheBase::Iterator end() override
Definition SymCache.h:125
create(data_type, value)
Definition hw.py:433
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static unsigned hashValue(const SmallVector< Value > &inputs)
static SmallVector< Value > getTombstoneKey()
static bool isEqual(const SmallVector< Value > &lhs, const SmallVector< Value > &rhs)
static unsigned getHashValue(const SmallVector< Value > &inputs)