CIRCT 23.0.0git
Loading...
Searching...
No Matches
ArcCanonicalizer.cpp
Go to the documentation of this file.
1//===- ArcCanonicalizer.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// Simulation centric canonicalizations for non-arc operations and
9// canonicalizations that require efficient symbol lookups.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "llvm/Support/Debug.h"
26
27#define DEBUG_TYPE "arc-canonicalizer"
28
29namespace circt {
30namespace arc {
31#define GEN_PASS_DEF_ARCCANONICALIZER
32#include "circt/Dialect/Arc/ArcPasses.h.inc"
33} // namespace arc
34} // namespace circt
35
36using namespace circt;
37using namespace arc;
38
39//===----------------------------------------------------------------------===//
40// Datastructures
41//===----------------------------------------------------------------------===//
42
43namespace {
44
45/// A combination of SymbolCache and SymbolUserMap that also allows to add users
46/// and remove symbols on-demand.
47class SymbolHandler : public SymbolCache {
48public:
49 /// Return the users of the provided symbol operation.
50 ArrayRef<Operation *> getUsers(Operation *symbol) const {
51 auto it = userMap.find(symbol);
52 return it != userMap.end() ? it->second.getArrayRef()
53 : ArrayRef<Operation *>();
54 }
55
56 /// Return true if the given symbol has no uses.
57 bool useEmpty(Operation *symbol) {
58 return !userMap.count(symbol) || userMap[symbol].empty();
59 }
60
61 void addUser(Operation *def, Operation *user) {
62 assert(isa<mlir::SymbolOpInterface>(def));
63 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
64 symbolCache.insert(
65 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
66 userMap[def].insert(user);
67 }
68
69 void removeUser(Operation *def, Operation *user) {
70 assert(isa<mlir::SymbolOpInterface>(def));
71 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
72 userMap[def].remove(user);
73 if (userMap[def].empty())
74 userMap.erase(def);
75 }
76
77 void removeDefinitionAndAllUsers(Operation *def) {
78 assert(isa<mlir::SymbolOpInterface>(def));
79 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
80 userMap.erase(def);
81 }
82
83 void collectAllSymbolUses(Operation *symbolTableOp,
84 SymbolTableCollection &symbolTable) {
85 // NOTE: the following is almost 1-1 taken from the SymbolUserMap
86 // constructor. They made it difficult to extend the implementation by
87 // having a lot of members private and non-virtual methods.
88 SmallVector<Operation *> symbols;
89 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
90 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
91 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
92 assert(symbolUses && "expected uses to be valid");
93
94 for (const SymbolTable::SymbolUse &use : *symbolUses) {
95 symbols.clear();
96 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
97 symbols);
98 for (Operation *symbolOp : symbols)
99 userMap[symbolOp].insert(use.getUser());
100 }
101 }
102 };
103 // We just set `allSymUsesVisible` to false here because it isn't necessary
104 // for building the user map.
105 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
106 walkFn);
107 }
108
109private:
110 DenseMap<Operation *, SetVector<Operation *>> userMap;
111};
112
113/// A Listener keeping the provided SymbolHandler up-to-date. This is especially
114/// important for simplifications (e.g. DCE) the rewriter performs automatically
115/// that we cannot or do not want to turn off.
116class ArcListener : public mlir::RewriterBase::Listener {
117public:
118 explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}
119
120 void notifyOperationReplaced(Operation *op, Operation *replacement) override {
121 // If, e.g., a DefineOp is replaced with another DefineOp but with the same
122 // symbol, we don't want to drop the list of users.
123 auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
124 auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
125 if (symOp && symReplacement &&
126 symOp.getNameAttr() == symReplacement.getNameAttr())
127 return;
128
129 remove(op);
130 // TODO: if an operation is inserted that defines a symbol and the symbol
131 // already has uses, those users are not added.
132 add(replacement);
133 }
134
135 void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
136 remove(op);
137 }
138
139 void notifyOperationErased(Operation *op) override { remove(op); }
140
141 void notifyOperationInserted(Operation *op,
142 mlir::IRRewriter::InsertPoint) override {
143 // TODO: if an operation is inserted that defines a symbol and the symbol
144 // already has uses, those users are not added.
145 add(op);
146 }
147
148private:
149 FailureOr<Operation *> maybeGetDefinition(Operation *op) {
150 if (auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
151 auto symAttr =
152 dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
153 if (!symAttr)
154 return failure();
155 if (auto *def = handler->getDefinition(symAttr.getLeafReference()))
156 return def;
157 }
158 return failure();
159 }
160
161 void remove(Operation *op) {
162 auto maybeDef = maybeGetDefinition(op);
163 if (!failed(maybeDef))
164 handler->removeUser(*maybeDef, op);
165
166 if (isa<mlir::SymbolOpInterface>(op))
167 handler->removeDefinitionAndAllUsers(op);
168 }
169
170 void add(Operation *op) {
171 auto maybeDef = maybeGetDefinition(op);
172 if (!failed(maybeDef))
173 handler->addUser(*maybeDef, op);
174
175 if (auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
176 handler->addDefinition(defOp.getNameAttr(), op);
177 }
178
179 SymbolHandler *handler;
180};
181
/// Counters filled in by the symbol-aware patterns; the pass copies them into
/// its statistics once the first greedy rewrite has finished.
struct PatternStatistics {
  /// Total number of arc arguments erased by RemoveUnusedArcArgumentsPattern.
  unsigned removeUnusedArcArgumentsPatternNumArgsRemoved{0};
};
185
186} // namespace
187
188//===----------------------------------------------------------------------===//
189// Canonicalization patterns
190//===----------------------------------------------------------------------===//
191
192namespace {
193/// A rewrite pattern that has access to a symbol cache to access and modify the
194/// symbol-defining op and symbol users as well as a namespace to query new
195/// names. Each pattern has to make sure that the symbol handler is kept
196/// up-to-date no matter whether the pattern succeeds of fails.
197template <typename SourceOp>
198class SymOpRewritePattern : public OpRewritePattern<SourceOp> {
199public:
200 SymOpRewritePattern(MLIRContext *ctxt, SymbolHandler &symbolCache,
201 Namespace &names, PatternStatistics &stats,
202 mlir::PatternBenefit benefit = 1,
203 ArrayRef<StringRef> generatedNames = {})
204 : OpRewritePattern<SourceOp>(ctxt, benefit, generatedNames), names(names),
205 symbolCache(symbolCache), statistics(stats) {}
206
207protected:
208 Namespace &names;
209 SymbolHandler &symbolCache;
210 PatternStatistics &statistics;
211};
212
213class MemWritePortEnableAndMaskCanonicalizer
214 : public SymOpRewritePattern<MemoryWritePortOp> {
215public:
216 MemWritePortEnableAndMaskCanonicalizer(
217 MLIRContext *ctxt, SymbolHandler &symbolCache, Namespace &names,
218 PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
219 : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
220 arcMapping(arcMapping) {}
221 LogicalResult matchAndRewrite(MemoryWritePortOp op,
222 PatternRewriter &rewriter) const final;
223
224private:
225 DenseMap<StringAttr, StringAttr> &arcMapping;
226};
227
228struct CallPassthroughArc : public SymOpRewritePattern<CallOp> {
229 using SymOpRewritePattern::SymOpRewritePattern;
230 LogicalResult matchAndRewrite(CallOp op,
231 PatternRewriter &rewriter) const final;
232};
233
234struct RemoveUnusedArcs : public SymOpRewritePattern<DefineOp> {
235 using SymOpRewritePattern::SymOpRewritePattern;
236 LogicalResult matchAndRewrite(DefineOp op,
237 PatternRewriter &rewriter) const final;
238};
239
240struct ICMPCanonicalizer : public OpRewritePattern<comb::ICmpOp> {
241 using OpRewritePattern::OpRewritePattern;
242 LogicalResult matchAndRewrite(comb::ICmpOp op,
243 PatternRewriter &rewriter) const final;
244};
245
246struct CompRegCanonicalizer : public OpRewritePattern<seq::CompRegOp> {
247 using OpRewritePattern::OpRewritePattern;
248 LogicalResult matchAndRewrite(seq::CompRegOp op,
249 PatternRewriter &rewriter) const final;
250};
251
252struct RemoveUnusedArcArgumentsPattern : public SymOpRewritePattern<DefineOp> {
253 using SymOpRewritePattern::SymOpRewritePattern;
254 LogicalResult matchAndRewrite(DefineOp op,
255 PatternRewriter &rewriter) const final;
256};
257
258struct SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
259 using SymOpRewritePattern::SymOpRewritePattern;
260 LogicalResult matchAndRewrite(DefineOp op,
261 PatternRewriter &rewriter) const final;
262};
263
264struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
265 using OpRewritePattern::OpRewritePattern;
266 LogicalResult matchAndRewrite(VectorizeOp op,
267 PatternRewriter &rewriter) const final;
268};
269
270struct KeepOneVecOp : public OpRewritePattern<VectorizeOp> {
271 using OpRewritePattern::OpRewritePattern;
272 LogicalResult matchAndRewrite(VectorizeOp op,
273 PatternRewriter &rewriter) const final;
274};
275
276} // namespace
277
278//===----------------------------------------------------------------------===//
279// Helpers
280//===----------------------------------------------------------------------===//
281
282LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
283 SymbolHandler &symbolCache,
284 PatternRewriter &rewriter) {
285 auto defOp = cast<DefineOp>(symbolCache.getDefinition(
286 llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
287 .getLeafReference()));
288 if (defOp.isPassthrough()) {
289 symbolCache.removeUser(defOp, callOp);
290 rewriter.replaceOp(callOp, callOp.getArgOperands());
291 return success();
292 }
293 return failure();
294}
295
296LogicalResult updateInputOperands(VectorizeOp &vecOp,
297 const SmallVector<Value> &newOperands) {
298 // Set the new inputOperandSegments value
299 unsigned groupSize = vecOp.getResults().size();
300 unsigned numOfGroups = newOperands.size() / groupSize;
301 SmallVector<int32_t> newAttr(numOfGroups, groupSize);
302 vecOp.setInputOperandSegments(newAttr);
303 vecOp.getOperation()->setOperands(ValueRange(newOperands));
304 return success();
305}
306
307//===----------------------------------------------------------------------===//
308// Canonicalization pattern implementations
309//===----------------------------------------------------------------------===//
310
/// Simplify a MemoryWritePortOp whose enable is a constant returned by the
/// called arc (the arc result at index `op.getEnableIdx()`):
///  - enable == 0: the port never writes; erase it, and erase the arc too if
///    no users remain.
///  - enable all-ones: clear the port's enable flag. If an enable-free variant
///    of the arc is already recorded in `arcMapping`, the port is simply
///    redirected to it. Otherwise the arc is renamed and stripped of its
///    enable output, and a wrapper arc with the original name and signature is
///    created for all other users. The wrapper calls the renamed arc and
///    re-inserts a constant true at the enable position.
/// All symbol-handler updates are done explicitly alongside the IR changes.
LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
    MemoryWritePortOp op, PatternRewriter &rewriter) const {
  auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
  APInt enable;

  // Only ports with an enable whose arc output is a constant integer match.
  if (op.getEnable() &&
      mlir::matchPattern(
          defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
          mlir::m_ConstantInt(&enable))) {
    if (enable.isZero()) {
      // Never-enabled port: drop it, and the arc if it is now unused.
      symbolCache.removeUser(defOp, op);
      rewriter.eraseOp(op);
      if (symbolCache.useEmpty(defOp)) {
        symbolCache.removeDefinitionAndAllUsers(defOp);
        rewriter.eraseOp(defOp);
      }
      return success();
    }
    if (enable.isAllOnes()) {
      // An enable-free variant of this arc was already created by an earlier
      // application; just point the port at it.
      if (arcMapping.count(defOp.getNameAttr())) {
        auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
        // Remove the enable attribute
        rewriter.modifyOpInPlace(op, [&]() {
          op.setEnable(false);
          op.setArc(arcWithoutEnable.getValue());
        });
        symbolCache.removeUser(defOp, op);
        symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
        return success();
      }

      // `defOp` will be renamed to `newName` and lose its enable output. The
      // user list is captured first because the old definition is dropped
      // from the handler here and re-registered below.
      auto newName = names.newName(defOp.getName());
      auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
      symbolCache.removeDefinitionAndAllUsers(defOp);

      // Remove the enable attribute
      rewriter.modifyOpInPlace(op, [&]() {
        op.setEnable(false);
        op.setArc(newName);
      });

      // Queried after the enable flag was cleared, so presumably these are the
      // arc result types without the enable. NOTE(review): confirm against
      // MemoryWritePortOp::getArcResultTypes.
      auto newResultTypes = op.getArcResultTypes();

      // Create a new arc that acts as replacement for other users
      // (cloneWithoutRegions keeps the original name and function type; its
      // body calls the renamed arc and re-inserts a constant true enable).
      rewriter.setInsertionPoint(defOp);
      auto newDefOp = rewriter.cloneWithoutRegions(defOp);
      auto *block = rewriter.createBlock(
          &newDefOp.getBody(), newDefOp.getBody().end(),
          newDefOp.getArgumentTypes(),
          SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
      auto callOp = CallOp::create(rewriter, newDefOp.getLoc(), newResultTypes,
                                   newName, block->getArguments());
      SmallVector<Value> results(callOp->getResults());
      Value constTrue = hw::ConstantOp::create(rewriter, newDefOp.getLoc(),
                                               rewriter.getI1Type(), 1);
      // NOTE(review): getEnableIdx() is used after the enable flag was cleared;
      // this assumes the index does not depend on that flag — TODO confirm.
      results.insert(results.begin() + op.getEnableIdx(), constTrue);
      OutputOp::create(rewriter, newDefOp.getLoc(), results);

      // Remove the enable output from the current arc
      auto *terminator = defOp.getBodyBlock().getTerminator();
      rewriter.modifyOpInPlace(
          terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
      rewriter.modifyOpInPlace(defOp, [&]() {
        defOp.setName(newName);
        defOp.setFunctionType(
            rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
      });

      // Update symbol cache: `op` now calls the renamed `defOp`, every other
      // former user keeps the original name and thus calls the wrapper.
      symbolCache.addDefinition(defOp.getNameAttr(), defOp);
      symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
      symbolCache.addUser(defOp, callOp);
      for (auto *user : users)
        symbolCache.addUser(user == op ? defOp : newDefOp, user);

      // Remember: original name (wrapper) -> enable-free arc name.
      arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
      return success();
    }
  }
  return failure();
}
392
393LogicalResult
394CallPassthroughArc::matchAndRewrite(CallOp op,
395 PatternRewriter &rewriter) const {
396 return canonicalizePassthoughCall(op, symbolCache, rewriter);
397}
398
399LogicalResult
400RemoveUnusedArcs::matchAndRewrite(DefineOp op,
401 PatternRewriter &rewriter) const {
402 if (symbolCache.useEmpty(op)) {
403 op.getBody().walk([&](mlir::CallOpInterface user) {
404 if (auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
405 if (auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
406 symbolCache.removeUser(defOp, user);
407 });
408 symbolCache.removeDefinitionAndAllUsers(op);
409 rewriter.eraseOp(op);
410 return success();
411 }
412 return failure();
413}
414
415LogicalResult
416ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
417 PatternRewriter &rewriter) const {
418 auto getConstant = [&](const APInt &constant) -> Value {
419 return hw::ConstantOp::create(rewriter, op.getLoc(), constant);
420 };
421 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
422 if (llvm::all_equal(types) && !types.empty())
423 if (auto intType = dyn_cast<IntegerType>(*types.begin()))
424 return intType.getWidth();
425 return std::nullopt;
426 };
427 auto negate = [&](Value input) -> Value {
428 auto constTrue = hw::ConstantOp::create(rewriter, op.getLoc(), APInt(1, 1));
429 return comb::XorOp::create(rewriter, op.getLoc(), input, constTrue,
430 op.getTwoState());
431 };
432
433 APInt rhs;
434 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
435 if (auto concatOp = op.getLhs().getDefiningOp<comb::ConcatOp>()) {
436 if (auto optionalWidth =
437 sameWidthIntegers(concatOp->getOperands().getTypes())) {
438 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
439 op.getPredicate() == comb::ICmpPredicate::ne) &&
440 rhs.isAllOnes()) {
441 Value andOp = comb::AndOp::create(
442 rewriter, op.getLoc(), concatOp.getInputs(), op.getTwoState());
443 if (*optionalWidth == 1) {
444 if (op.getPredicate() == comb::ICmpPredicate::ne)
445 andOp = negate(andOp);
446 rewriter.replaceOp(op, andOp);
447 return success();
448 }
449 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
450 op, op.getPredicate(), andOp,
451 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
452 /*isSigned=*/false, /*implicitTrunc=*/true)),
453 op.getTwoState());
454 return success();
455 }
456
457 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
458 op.getPredicate() == comb::ICmpPredicate::eq) &&
459 rhs.isZero()) {
460 Value orOp = comb::OrOp::create(
461 rewriter, op.getLoc(), concatOp.getInputs(), op.getTwoState());
462 if (*optionalWidth == 1) {
463 if (op.getPredicate() == comb::ICmpPredicate::eq)
464 orOp = negate(orOp);
465 rewriter.replaceOp(op, orOp);
466 return success();
467 }
468 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
469 op, op.getPredicate(), orOp,
470 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
471 /*isSigned=*/false, /*implicitTrunc=*/true)),
472 op.getTwoState());
473 return success();
474 }
475 }
476 }
477 }
478 return failure();
479}
480
481LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
482 DefineOp op, PatternRewriter &rewriter) const {
483 BitVector toDelete(op.getNumArguments());
484 for (auto [i, arg] : llvm::enumerate(op.getArguments()))
485 if (arg.use_empty())
486 toDelete.set(i);
487
488 if (toDelete.none())
489 return failure();
490
491 // Collect the mutable callers in a first iteration. If there is a user that
492 // does not implement the interface, we have to abort the rewrite and have to
493 // make sure that we didn't change anything so far.
494 SmallVector<mlir::CallOpInterface> mutableUsers;
495 for (auto *user : symbolCache.getUsers(op)) {
496 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
497 if (!callOpMutable)
498 return failure();
499 mutableUsers.push_back(callOpMutable);
500 }
501
502 // Do the actual rewrites.
503 for (auto user : mutableUsers)
504 for (int i = toDelete.size() - 1; i >= 0; --i)
505 if (toDelete[i])
506 rewriter.modifyOpInPlace(
507 user, [&]() { user.getArgOperandsMutable().erase(i); });
508
509 bool matchFailure = false;
510 rewriter.modifyOpInPlace(op, [&]() {
511 if (failed(op.eraseArguments(toDelete))) {
512 matchFailure = true;
513 return;
514 }
515
516 op.setFunctionType(
517 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
518 });
519
520 if (matchFailure)
521 return rewriter.notifyMatchFailure(op, "failed to erase arguments");
522
523 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
524 return success();
525}
526
527LogicalResult
528SinkArcInputsPattern::matchAndRewrite(DefineOp op,
529 PatternRewriter &rewriter) const {
530 // First check that all users implement the interface we need to be able to
531 // modify the users.
532 auto users = symbolCache.getUsers(op);
533 if (llvm::any_of(
534 users, [](auto *user) { return !isa<mlir::CallOpInterface>(user); }))
535 return failure();
536
537 // Find all arguments that use constant operands only.
538 SmallVector<Operation *> stateConsts(op.getNumArguments());
539 bool first = true;
540 for (auto *user : users) {
541 auto callOp = cast<mlir::CallOpInterface>(user);
542 for (auto [constArg, input] :
543 llvm::zip(stateConsts, callOp.getArgOperands())) {
544 if (auto *constOp = input.getDefiningOp();
545 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
546 if (first) {
547 constArg = constOp;
548 continue;
549 }
550 if (constArg &&
551 constArg->getName() == input.getDefiningOp()->getName() &&
552 constArg->getAttrDictionary() ==
553 input.getDefiningOp()->getAttrDictionary())
554 continue;
555 }
556 constArg = nullptr;
557 }
558 first = false;
559 }
560
561 // Move the constants into the arc and erase the block arguments.
562 rewriter.setInsertionPointToStart(&op.getBodyBlock());
563 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
564 for (auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
565 if (!constArg)
566 continue;
567 auto *inlinedConst = rewriter.clone(*constArg);
568 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
569 toDelete.set(arg.getArgNumber());
570 }
571 op.getBodyBlock().eraseArguments(toDelete);
572 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
573 op.getResultTypes()));
574
575 // Rewrite all arc uses to not pass in the constant anymore.
576 for (auto *user : users) {
577 auto callOp = cast<mlir::CallOpInterface>(user);
578 SmallPtrSet<Value, 4> maybeUnusedValues;
579 SmallVector<Value> newInputs;
580 for (auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
581 if (toDelete[index])
582 maybeUnusedValues.insert(value);
583 else
584 newInputs.push_back(value);
585 }
586 rewriter.modifyOpInPlace(
587 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
588 for (auto value : maybeUnusedValues)
589 if (value.use_empty())
590 rewriter.eraseOp(value.getDefiningOp());
591 }
592
593 return success(toDelete.any());
594}
595
596LogicalResult
597CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op,
598 PatternRewriter &rewriter) const {
599 if (!op.getReset())
600 return failure();
601
602 // Because Arcilator supports constant zero reset values, skip them.
603 APInt constant;
604 if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
605 if (constant.isZero())
606 return failure();
607
608 Value newInput = comb::MuxOp::create(rewriter, op->getLoc(), op.getReset(),
609 op.getResetValue(), op.getInput());
610 rewriter.modifyOpInPlace(op, [&]() {
611 op.getInputMutable().set(newInput);
612 op.getResetMutable().clear();
613 op.getResetValueMutable().clear();
614 });
615
616 return success();
617}
618
/// Merge producer VectorizeOps into `vecOp`. When an input vector of `vecOp`
/// consists of exactly the results of another VectorizeOp (all of which have a
/// single use), possibly in a permuted order, the producer's body is cloned
/// into `vecOp`'s body. The producer's inputs (permuted accordingly) then
/// replace that input vector, and the producer is erased.
LogicalResult
MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
                                   PatternRewriter &rewriter) const {
  auto &currentBlock = vecOp.getBody().front();
  // Maps the producer's block arguments to the newly inserted arguments of
  // `currentBlock`.
  IRMapping argMapping;
  // Flattened input operands of the rewritten `vecOp`.
  SmallVector<Value> newOperands;
  SmallVector<VectorizeOp> vecOpsToRemove;
  bool canBeMerged = false;
  // Used to calculate the new positions of args after insertions and removals
  unsigned paddedBy = 0;

  for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
       argIdx < numArgs; ++argIdx) {
    auto inputVec = vecOp.getInputs()[argIdx];
    // Make sure that the input comes from a `VectorizeOp`
    // Ensure that the input vector matches the output of the `otherVecOp`
    // Make sure that the results of the otherVecOp have only one use
    auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
    if (!otherVecOp || otherVecOp == vecOp ||
        !llvm::all_of(otherVecOp.getResults(),
                      [](auto result) { return result.hasOneUse(); }) ||
        !llvm::all_of(inputVec, [&](auto result) {
          return result.template getDefiningOp<VectorizeOp>() == otherVecOp;
        })) {
      // Not mergeable: keep the input vector as is.
      newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
      continue;
    }

    // Here, all elements are from the same `VectorizeOp`.
    // If all elements of the input vector come from the same `VectorizeOp`
    // sort the vectors by their indices
    DenseMap<Value, size_t> resultIdxMap;
    for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults()))
      resultIdxMap[result] = resultIdx;

    SmallVector<Value> tempVec(inputVec.begin(), inputVec.end());
    llvm::sort(tempVec, [&](Value a, Value b) {
      return resultIdxMap[a] < resultIdxMap[b];
    });

    // Check if inputVec matches the result after sorting, i.e. it is a
    // permutation of all of otherVecOp's results.
    if (tempVec != SmallVector<Value>(otherVecOp.getResults().begin(),
                                      otherVecOp.getResults().end())) {
      newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
      continue;
    }

    // Position in inputVec -> index of the producer result found there.
    DenseMap<size_t, size_t> fromRealIdxToSortedIdx;
    for (auto [inIdx, in] : llvm::enumerate(inputVec))
      fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in];

    // Once this flag is set, the IR is modified below, so the pattern must no
    // longer return failure.
    canBeMerged = true;

    // If the results got shuffled, then shuffle the operands before merging.
    if (inputVec != otherVecOp.getResults()) {
      for (auto otherVecOpInputVec : otherVecOp.getInputs()) {
        // use the tempVec again instead of creating another one.
        tempVec = SmallVector<Value>(inputVec.size());
        for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec))
          tempVec[realIdx] =
              otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]];

        newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end());
      }

    } else
      newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
                         otherVecOp.getOperands().end());

    // Insert the producer's block arguments in front of the argument that
    // received its results. NOTE(review): this insertion is not reported
    // through the rewriter.
    auto &otherBlock = otherVecOp.getBody().front();
    for (auto &otherArg : otherBlock.getArguments()) {
      auto newArg = currentBlock.insertArgument(
          argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
      argMapping.map(otherArg, newArg);
      ++paddedBy;
    }

    rewriter.setInsertionPointToStart(&currentBlock);
    for (auto &op : otherBlock.without_terminator())
      rewriter.clone(op, argMapping);

    // Current position of the argument that received otherVecOp's results.
    unsigned argNewPos = paddedBy + argIdx;
    // Get the result of the return value and use it in all places the
    // `otherVecOp` results were used
    auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
    rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
                                argMapping.lookupOrDefault(retOp.getValue()));
    rewriter.modifyOpInPlace(vecOp,
                             [&]() { currentBlock.eraseArgument(argNewPos); });
    vecOpsToRemove.push_back(otherVecOp);
    // We erased an arg so the padding decreased by 1
    paddedBy--;
  }

  // We didn't change the IR as there were no vectors to merge
  if (!canBeMerged)
    return failure();

  // NOTE(review): the operand/segment update is not wrapped in
  // modifyOpInPlace.
  (void)updateInputOperands(vecOp, newOperands);

  // Erase dead VectorizeOps
  for (auto deadOp : vecOpsToRemove)
    rewriter.eraseOp(deadOp);

  return success();
}
727
728namespace llvm {
729static unsigned hashValue(const SmallVector<Value> &inputs) {
730 unsigned hash = hash_value(inputs.size());
731 for (auto input : inputs)
732 hash = hash_combine(hash, input);
733 return hash;
734}
735
736template <>
737struct DenseMapInfo<SmallVector<Value>> {
738 static inline SmallVector<Value> getEmptyKey() {
739 return SmallVector<Value>();
740 }
741
742 static inline SmallVector<Value> getTombstoneKey() {
743 return SmallVector<Value>();
744 }
745
746 static unsigned getHashValue(const SmallVector<Value> &inputs) {
747 return hashValue(inputs);
748 }
749
750 static bool isEqual(const SmallVector<Value> &lhs,
751 const SmallVector<Value> &rhs) {
752 return lhs == rhs;
753 }
754};
755} // namespace llvm
756
757LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
758 PatternRewriter &rewriter) const {
759 DenseMap<SmallVector<Value>, unsigned> inExists;
760 auto &currentBlock = vecOp.getBody().front();
761 SmallVector<Value> newOperands;
762 BitVector argsToRemove(vecOp.getInputs().size(), false);
763 for (size_t argIdx = 0; argIdx < vecOp.getInputs().size(); ++argIdx) {
764 auto input = SmallVector<Value>(vecOp.getInputs()[argIdx].begin(),
765 vecOp.getInputs()[argIdx].end());
766 if (auto in = inExists.find(input); in != inExists.end()) {
767 rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx),
768 currentBlock.getArgument(in->second));
769 argsToRemove.set(argIdx);
770 continue;
771 }
772 inExists[input] = argIdx;
773 newOperands.insert(newOperands.end(), input.begin(), input.end());
774 }
775
776 if (argsToRemove.none())
777 return failure();
778
779 rewriter.modifyOpInPlace(
780 vecOp, [&]() { currentBlock.eraseArguments(argsToRemove); });
781 return updateInputOperands(vecOp, newOperands);
782}
783
784//===----------------------------------------------------------------------===//
785// ArcCanonicalizerPass implementation
786//===----------------------------------------------------------------------===//
787
788namespace {
789struct ArcCanonicalizerPass
790 : public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
791 void runOnOperation() override;
792};
793} // namespace
794
795void ArcCanonicalizerPass::runOnOperation() {
796 MLIRContext &ctxt = getContext();
797 SymbolTableCollection symbolTable;
798 SymbolHandler cache;
799 cache.addDefinitions(getOperation());
800 cache.collectAllSymbolUses(getOperation(), symbolTable);
801 Namespace names;
802 names.add(cache);
803 DenseMap<StringAttr, StringAttr> arcMapping;
804
805 mlir::GreedyRewriteConfig config;
806 config.setRegionSimplificationLevel(
807 mlir::GreedySimplifyRegionLevel::Disabled);
808 config.setMaxIterations(10);
809 config.setUseTopDownTraversal(true);
810 ArcListener listener(&cache);
811 config.setListener(&listener);
812
813 PatternStatistics statistics;
814 RewritePatternSet symbolPatterns(&getContext());
815 symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
816 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
817 &getContext(), cache, names, statistics);
818 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
819 &getContext(), cache, names, statistics, arcMapping);
820
821 if (failed(mlir::applyPatternsGreedily(getOperation(),
822 std::move(symbolPatterns), config)))
823 return signalPassFailure();
824
825 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
826
827 RewritePatternSet patterns(&ctxt);
828 for (auto *dialect : ctxt.getLoadedDialects())
829 dialect->getCanonicalizationPatterns(patterns);
830 for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
831 op.getCanonicalizationPatterns(patterns, &ctxt);
832 patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
833 KeepOneVecOp>(&getContext());
834
835 // Don't test for convergence since it is often not reached.
836 (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns),
837 config);
838}
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
LogicalResult updateInputOperands(VectorizeOp &vecOp, const SmallVector< Value > &newOperands)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
SymbolCacheBase::Iterator end() override
Definition SymCache.h:125
create(data_type, value)
Definition hw.py:433
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::hash_code hash_value(const DenseSet< T > &set)
static unsigned hashValue(const SmallVector< Value > &inputs)
static SmallVector< Value > getTombstoneKey()
static bool isEqual(const SmallVector< Value > &lhs, const SmallVector< Value > &rhs)
static unsigned getHashValue(const SmallVector< Value > &inputs)