CIRCT 20.0.0git
Loading...
Searching...
No Matches
ArcCanonicalizer.cpp
Go to the documentation of this file.
1//===- ArcCanonicalizer.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// Simulation centric canonicalizations for non-arc operations and
9// canonicalizations that require efficient symbol lookups.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25#include "llvm/Support/Debug.h"
26
27#define DEBUG_TYPE "arc-canonicalizer"
28
29namespace circt {
30namespace arc {
31#define GEN_PASS_DEF_ARCCANONICALIZER
32#include "circt/Dialect/Arc/ArcPasses.h.inc"
33} // namespace arc
34} // namespace circt
35
36using namespace circt;
37using namespace arc;
38
39//===----------------------------------------------------------------------===//
40// Datastructures
41//===----------------------------------------------------------------------===//
42
43namespace {
44
45/// A combination of SymbolCache and SymbolUserMap that also allows to add users
46/// and remove symbols on-demand.
47class SymbolHandler : public SymbolCache {
48public:
49 /// Return the users of the provided symbol operation.
50 ArrayRef<Operation *> getUsers(Operation *symbol) const {
51 auto it = userMap.find(symbol);
52 return it != userMap.end() ? it->second.getArrayRef() : std::nullopt;
53 }
54
55 /// Return true if the given symbol has no uses.
56 bool useEmpty(Operation *symbol) {
57 return !userMap.count(symbol) || userMap[symbol].empty();
58 }
59
60 void addUser(Operation *def, Operation *user) {
61 assert(isa<mlir::SymbolOpInterface>(def));
62 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
63 symbolCache.insert(
64 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
65 userMap[def].insert(user);
66 }
67
68 void removeUser(Operation *def, Operation *user) {
69 assert(isa<mlir::SymbolOpInterface>(def));
70 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
71 userMap[def].remove(user);
72 if (userMap[def].empty())
73 userMap.erase(def);
74 }
75
76 void removeDefinitionAndAllUsers(Operation *def) {
77 assert(isa<mlir::SymbolOpInterface>(def));
78 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
79 userMap.erase(def);
80 }
81
82 void collectAllSymbolUses(Operation *symbolTableOp,
83 SymbolTableCollection &symbolTable) {
84 // NOTE: the following is almost 1-1 taken from the SymbolUserMap
85 // constructor. They made it difficult to extend the implementation by
86 // having a lot of members private and non-virtual methods.
87 SmallVector<Operation *> symbols;
88 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
89 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
90 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
91 assert(symbolUses && "expected uses to be valid");
92
93 for (const SymbolTable::SymbolUse &use : *symbolUses) {
94 symbols.clear();
95 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
96 symbols);
97 for (Operation *symbolOp : symbols)
98 userMap[symbolOp].insert(use.getUser());
99 }
100 }
101 };
102 // We just set `allSymUsesVisible` to false here because it isn't necessary
103 // for building the user map.
104 SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
105 walkFn);
106 }
107
108private:
109 DenseMap<Operation *, SetVector<Operation *>> userMap;
110};
111
/// A Listener keeping the provided SymbolHandler up-to-date. This is especially
/// important for simplifications (e.g. DCE) the rewriter performs automatically
/// that we cannot or do not want to turn off.
///
/// Every notification becomes `remove`/`add` calls on the handler:
/// - ops implementing CallOpInterface with a symbol callee are tracked as users
///   of that callee's definition;
/// - ops implementing SymbolOpInterface are tracked as definitions.
class ArcListener : public mlir::RewriterBase::Listener {
public:
  /// `handler` is not owned; the caller must keep it alive while the listener
  /// is in use.
  explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}

  /// `op` is replaced by the single operation `replacement`.
  void notifyOperationReplaced(Operation *op, Operation *replacement) override {
    // If, e.g., a DefineOp is replaced with another DefineOp but with the same
    // symbol, we don't want to drop the list of users.
    // NOTE(review): SymbolHandler keys users by the defining Operation*, so the
    // users are not transferred to `replacement` here. Also, if the rewriter
    // additionally reports `op` as erased, notifyOperationErased will still
    // drop the definition. Confirm this early return behaves as intended.
    auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
    auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
    if (symOp && symReplacement &&
        symOp.getNameAttr() == symReplacement.getNameAttr())
      return;

    remove(op);
    // TODO: if an operation is inserted that defines a symbol and the symbol
    // already has uses, those users are not added.
    add(replacement);
  }

  /// `op` is replaced by plain values, so only its bookkeeping is dropped.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
    remove(op);
  }

  void notifyOperationErased(Operation *op) override { remove(op); }

  void notifyOperationInserted(Operation *op,
                               mlir::IRRewriter::InsertPoint) override {
    // TODO: if an operation is inserted that defines a symbol and the symbol
    // already has uses, those users are not added.
    add(op);
  }

private:
  /// If `op` is a call with a symbol callee whose definition is known to the
  /// handler, return that definition. Otherwise return failure.
  FailureOr<Operation *> maybeGetDefinition(Operation *op) {
    if (auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
      auto symAttr =
          dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
      if (!symAttr)
        return failure();
      if (auto *def = handler->getDefinition(symAttr.getLeafReference()))
        return def;
    }
    return failure();
  }

  /// Drop `op` as a user of its callee and, if it defines a symbol, drop that
  /// definition together with all of its users.
  void remove(Operation *op) {
    auto maybeDef = maybeGetDefinition(op);
    if (!failed(maybeDef))
      handler->removeUser(*maybeDef, op);

    if (isa<mlir::SymbolOpInterface>(op))
      handler->removeDefinitionAndAllUsers(op);
  }

  /// Register `op` as a user of its callee and, if it defines a symbol, as the
  /// definition of that symbol.
  void add(Operation *op) {
    auto maybeDef = maybeGetDefinition(op);
    if (!failed(maybeDef))
      handler->addUser(*maybeDef, op);

    if (auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
      handler->addDefinition(defOp.getNameAttr(), op);
  }

  SymbolHandler *handler;
};
180
/// Counters accumulated by the symbol-aware patterns and copied into the pass
/// statistics once the patterns have run.
struct PatternStatistics {
  /// Total number of arc arguments dropped by RemoveUnusedArcArgumentsPattern.
  unsigned removeUnusedArcArgumentsPatternNumArgsRemoved{0};
};
184
185} // namespace
186
187//===----------------------------------------------------------------------===//
188// Canonicalization patterns
189//===----------------------------------------------------------------------===//
190
191namespace {
192/// A rewrite pattern that has access to a symbol cache to access and modify the
193/// symbol-defining op and symbol users as well as a namespace to query new
194/// names. Each pattern has to make sure that the symbol handler is kept
195/// up-to-date no matter whether the pattern succeeds of fails.
196template <typename SourceOp>
197class SymOpRewritePattern : public OpRewritePattern<SourceOp> {
198public:
199 SymOpRewritePattern(MLIRContext *ctxt, SymbolHandler &symbolCache,
200 Namespace &names, PatternStatistics &stats,
201 mlir::PatternBenefit benefit = 1,
202 ArrayRef<StringRef> generatedNames = {})
203 : OpRewritePattern<SourceOp>(ctxt, benefit, generatedNames), names(names),
204 symbolCache(symbolCache), statistics(stats) {}
205
206protected:
207 Namespace &names;
208 SymbolHandler &symbolCache;
209 PatternStatistics &statistics;
210};
211
212class MemWritePortEnableAndMaskCanonicalizer
213 : public SymOpRewritePattern<MemoryWritePortOp> {
214public:
215 MemWritePortEnableAndMaskCanonicalizer(
216 MLIRContext *ctxt, SymbolHandler &symbolCache, Namespace &names,
217 PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
218 : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
219 arcMapping(arcMapping) {}
220 LogicalResult matchAndRewrite(MemoryWritePortOp op,
221 PatternRewriter &rewriter) const final;
222
223private:
224 DenseMap<StringAttr, StringAttr> &arcMapping;
225};
226
227struct CallPassthroughArc : public SymOpRewritePattern<CallOp> {
228 using SymOpRewritePattern::SymOpRewritePattern;
229 LogicalResult matchAndRewrite(CallOp op,
230 PatternRewriter &rewriter) const final;
231};
232
233struct RemoveUnusedArcs : public SymOpRewritePattern<DefineOp> {
234 using SymOpRewritePattern::SymOpRewritePattern;
235 LogicalResult matchAndRewrite(DefineOp op,
236 PatternRewriter &rewriter) const final;
237};
238
239struct ICMPCanonicalizer : public OpRewritePattern<comb::ICmpOp> {
240 using OpRewritePattern::OpRewritePattern;
241 LogicalResult matchAndRewrite(comb::ICmpOp op,
242 PatternRewriter &rewriter) const final;
243};
244
245struct CompRegCanonicalizer : public OpRewritePattern<seq::CompRegOp> {
246 using OpRewritePattern::OpRewritePattern;
247 LogicalResult matchAndRewrite(seq::CompRegOp op,
248 PatternRewriter &rewriter) const final;
249};
250
251struct RemoveUnusedArcArgumentsPattern : public SymOpRewritePattern<DefineOp> {
252 using SymOpRewritePattern::SymOpRewritePattern;
253 LogicalResult matchAndRewrite(DefineOp op,
254 PatternRewriter &rewriter) const final;
255};
256
257struct SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
258 using SymOpRewritePattern::SymOpRewritePattern;
259 LogicalResult matchAndRewrite(DefineOp op,
260 PatternRewriter &rewriter) const final;
261};
262
263struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
264 using OpRewritePattern::OpRewritePattern;
265 LogicalResult matchAndRewrite(VectorizeOp op,
266 PatternRewriter &rewriter) const final;
267};
268
269struct KeepOneVecOp : public OpRewritePattern<VectorizeOp> {
270 using OpRewritePattern::OpRewritePattern;
271 LogicalResult matchAndRewrite(VectorizeOp op,
272 PatternRewriter &rewriter) const final;
273};
274
275} // namespace
276
277//===----------------------------------------------------------------------===//
278// Helpers
279//===----------------------------------------------------------------------===//
280
281LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp,
282 SymbolHandler &symbolCache,
283 PatternRewriter &rewriter) {
284 auto defOp = cast<DefineOp>(symbolCache.getDefinition(
285 llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
286 .getLeafReference()));
287 if (defOp.isPassthrough()) {
288 symbolCache.removeUser(defOp, callOp);
289 rewriter.replaceOp(callOp, callOp.getArgOperands());
290 return success();
291 }
292 return failure();
293}
294
295LogicalResult updateInputOperands(VectorizeOp &vecOp,
296 const SmallVector<Value> &newOperands) {
297 // Set the new inputOperandSegments value
298 unsigned groupSize = vecOp.getResults().size();
299 unsigned numOfGroups = newOperands.size() / groupSize;
300 SmallVector<int32_t> newAttr(numOfGroups, groupSize);
301 vecOp.setInputOperandSegments(newAttr);
302 vecOp.getOperation()->setOperands(ValueRange(newOperands));
303 return success();
304}
305
306//===----------------------------------------------------------------------===//
307// Canonicalization pattern implementations
308//===----------------------------------------------------------------------===//
309
/// Simplify a memory write port whose enable comes from a constant returned by
/// the port's arc.
///
/// - Constant-zero enable: the port can never write, so it is erased. If that
///   leaves the arc without users, the arc is erased as well.
/// - Constant-one enable: the port's enable flag is cleared and the port is
///   pointed at an arc variant without the enable output. If such a variant
///   was already created (recorded in `arcMapping`), it is reused. Otherwise
///   the original arc is renamed and its enable output is dropped. A clone
///   that keeps the original name is then created; it calls the renamed arc
///   and re-inserts a constant-true enable, so the remaining users are
///   unaffected.
LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
    MemoryWritePortOp op, PatternRewriter &rewriter) const {
  auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
  APInt enable;

  // Inspect the arc output that feeds the enable and check for a constant.
  if (op.getEnable() &&
      mlir::matchPattern(
          defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
          mlir::m_ConstantInt(&enable))) {
    if (enable.isZero()) {
      // Never enabled: the port is dead; also drop the arc if now unused.
      symbolCache.removeUser(defOp, op);
      rewriter.eraseOp(op);
      if (symbolCache.useEmpty(defOp)) {
        symbolCache.removeDefinitionAndAllUsers(defOp);
        rewriter.eraseOp(defOp);
      }
      return success();
    }
    if (enable.isAllOnes()) {
      // Reuse an enable-less variant created by an earlier rewrite.
      if (arcMapping.count(defOp.getNameAttr())) {
        auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
        // Remove the enable attribute
        rewriter.modifyOpInPlace(op, [&]() {
          op.setEnable(false);
          op.setArc(arcWithoutEnable.getValue());
        });
        symbolCache.removeUser(defOp, op);
        symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
        return success();
      }

      // `newName` becomes the name of the enable-less arc (the original
      // defOp). Users are snapshotted before the definition is dropped from
      // the cache; they are re-registered at the end.
      auto newName = names.newName(defOp.getName());
      auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
      symbolCache.removeDefinitionAndAllUsers(defOp);

      // Remove the enable attribute
      rewriter.modifyOpInPlace(op, [&]() {
        op.setEnable(false);
        op.setArc(newName);
      });

      // NOTE(review): presumably excludes the enable output now that the
      // port's enable flag is cleared -- confirm against getArcResultTypes.
      auto newResultTypes = op.getArcResultTypes();

      // Create a new arc that acts as replacement for other users
      rewriter.setInsertionPoint(defOp);
      auto newDefOp = rewriter.cloneWithoutRegions(defOp);
      auto *block = rewriter.createBlock(
          &newDefOp.getBody(), newDefOp.getBody().end(),
          newDefOp.getArgumentTypes(),
          SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
      auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
                                            newName, block->getArguments());
      // Re-insert a constant-true enable at its original output position.
      SmallVector<Value> results(callOp->getResults());
      Value constTrue = rewriter.create<hw::ConstantOp>(
          newDefOp.getLoc(), rewriter.getI1Type(), 1);
      results.insert(results.begin() + op.getEnableIdx(), constTrue);
      rewriter.create<OutputOp>(newDefOp.getLoc(), results);

      // Remove the enable output from the current arc
      auto *terminator = defOp.getBodyBlock().getTerminator();
      rewriter.modifyOpInPlace(
          terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
      rewriter.modifyOpInPlace(defOp, [&]() {
        defOp.setName(newName);
        defOp.setFunctionType(
            rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
      });

      // Update symbol cache: this port now uses the renamed arc, and all other
      // former users use the clone (which kept the original name).
      symbolCache.addDefinition(defOp.getNameAttr(), defOp);
      symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
      symbolCache.addUser(defOp, callOp);
      for (auto *user : users)
        symbolCache.addUser(user == op ? defOp : newDefOp, user);

      // Remember: original name -> enable-less variant name.
      arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
      return success();
    }
  }
  return failure();
}
391
392LogicalResult
393CallPassthroughArc::matchAndRewrite(CallOp op,
394 PatternRewriter &rewriter) const {
395 return canonicalizePassthoughCall(op, symbolCache, rewriter);
396}
397
398LogicalResult
399RemoveUnusedArcs::matchAndRewrite(DefineOp op,
400 PatternRewriter &rewriter) const {
401 if (symbolCache.useEmpty(op)) {
402 op.getBody().walk([&](mlir::CallOpInterface user) {
403 if (auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
404 if (auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
405 symbolCache.removeUser(defOp, user);
406 });
407 symbolCache.removeDefinitionAndAllUsers(op);
408 rewriter.eraseOp(op);
409 return success();
410 }
411 return failure();
412}
413
414LogicalResult
415ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
416 PatternRewriter &rewriter) const {
417 auto getConstant = [&](const APInt &constant) -> Value {
418 return rewriter.create<hw::ConstantOp>(op.getLoc(), constant);
419 };
420 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
421 if (llvm::all_equal(types) && !types.empty())
422 if (auto intType = dyn_cast<IntegerType>(*types.begin()))
423 return intType.getWidth();
424 return std::nullopt;
425 };
426 auto negate = [&](Value input) -> Value {
427 auto constTrue = rewriter.create<hw::ConstantOp>(op.getLoc(), APInt(1, 1));
428 return rewriter.create<comb::XorOp>(op.getLoc(), input, constTrue,
429 op.getTwoState());
430 };
431
432 APInt rhs;
433 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
434 if (auto concatOp = op.getLhs().getDefiningOp<comb::ConcatOp>()) {
435 if (auto optionalWidth =
436 sameWidthIntegers(concatOp->getOperands().getTypes())) {
437 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
438 op.getPredicate() == comb::ICmpPredicate::ne) &&
439 rhs.isAllOnes()) {
440 Value andOp = rewriter.create<comb::AndOp>(
441 op.getLoc(), concatOp.getInputs(), op.getTwoState());
442 if (*optionalWidth == 1) {
443 if (op.getPredicate() == comb::ICmpPredicate::ne)
444 andOp = negate(andOp);
445 rewriter.replaceOp(op, andOp);
446 return success();
447 }
448 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
449 op, op.getPredicate(), andOp,
450 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
451 /*isSigned=*/false, /*implicitTrunc=*/true)),
452 op.getTwoState());
453 return success();
454 }
455
456 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
457 op.getPredicate() == comb::ICmpPredicate::eq) &&
458 rhs.isZero()) {
459 Value orOp = rewriter.create<comb::OrOp>(
460 op.getLoc(), concatOp.getInputs(), op.getTwoState());
461 if (*optionalWidth == 1) {
462 if (op.getPredicate() == comb::ICmpPredicate::eq)
463 orOp = negate(orOp);
464 rewriter.replaceOp(op, orOp);
465 return success();
466 }
467 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
468 op, op.getPredicate(), orOp,
469 getConstant(APInt(*optionalWidth, rhs.getZExtValue(),
470 /*isSigned=*/false, /*implicitTrunc=*/true)),
471 op.getTwoState());
472 return success();
473 }
474 }
475 }
476 }
477 return failure();
478}
479
480LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
481 DefineOp op, PatternRewriter &rewriter) const {
482 BitVector toDelete(op.getNumArguments());
483 for (auto [i, arg] : llvm::enumerate(op.getArguments()))
484 if (arg.use_empty())
485 toDelete.set(i);
486
487 if (toDelete.none())
488 return failure();
489
490 // Collect the mutable callers in a first iteration. If there is a user that
491 // does not implement the interface, we have to abort the rewrite and have to
492 // make sure that we didn't change anything so far.
493 SmallVector<mlir::CallOpInterface> mutableUsers;
494 for (auto *user : symbolCache.getUsers(op)) {
495 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
496 if (!callOpMutable)
497 return failure();
498 mutableUsers.push_back(callOpMutable);
499 }
500
501 // Do the actual rewrites.
502 for (auto user : mutableUsers)
503 for (int i = toDelete.size() - 1; i >= 0; --i)
504 if (toDelete[i])
505 user.getArgOperandsMutable().erase(i);
506
507 op.eraseArguments(toDelete);
508 op.setFunctionType(
509 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
510
511 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
512 return success();
513}
514
515LogicalResult
516SinkArcInputsPattern::matchAndRewrite(DefineOp op,
517 PatternRewriter &rewriter) const {
518 // First check that all users implement the interface we need to be able to
519 // modify the users.
520 auto users = symbolCache.getUsers(op);
521 if (llvm::any_of(
522 users, [](auto *user) { return !isa<mlir::CallOpInterface>(user); }))
523 return failure();
524
525 // Find all arguments that use constant operands only.
526 SmallVector<Operation *> stateConsts(op.getNumArguments());
527 bool first = true;
528 for (auto *user : users) {
529 auto callOp = cast<mlir::CallOpInterface>(user);
530 for (auto [constArg, input] :
531 llvm::zip(stateConsts, callOp.getArgOperands())) {
532 if (auto *constOp = input.getDefiningOp();
533 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
534 if (first) {
535 constArg = constOp;
536 continue;
537 }
538 if (constArg &&
539 constArg->getName() == input.getDefiningOp()->getName() &&
540 constArg->getAttrDictionary() ==
541 input.getDefiningOp()->getAttrDictionary())
542 continue;
543 }
544 constArg = nullptr;
545 }
546 first = false;
547 }
548
549 // Move the constants into the arc and erase the block arguments.
550 rewriter.setInsertionPointToStart(&op.getBodyBlock());
551 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
552 for (auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
553 if (!constArg)
554 continue;
555 auto *inlinedConst = rewriter.clone(*constArg);
556 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
557 toDelete.set(arg.getArgNumber());
558 }
559 op.getBodyBlock().eraseArguments(toDelete);
560 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
561 op.getResultTypes()));
562
563 // Rewrite all arc uses to not pass in the constant anymore.
564 for (auto *user : users) {
565 auto callOp = cast<mlir::CallOpInterface>(user);
566 SmallPtrSet<Value, 4> maybeUnusedValues;
567 SmallVector<Value> newInputs;
568 for (auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
569 if (toDelete[index])
570 maybeUnusedValues.insert(value);
571 else
572 newInputs.push_back(value);
573 }
574 rewriter.modifyOpInPlace(
575 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
576 for (auto value : maybeUnusedValues)
577 if (value.use_empty())
578 rewriter.eraseOp(value.getDefiningOp());
579 }
580
581 return success(toDelete.any());
582}
583
584LogicalResult
585CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op,
586 PatternRewriter &rewriter) const {
587 if (!op.getReset())
588 return failure();
589
590 // Because Arcilator supports constant zero reset values, skip them.
591 APInt constant;
592 if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
593 if (constant.isZero())
594 return failure();
595
596 Value newInput = rewriter.create<comb::MuxOp>(
597 op->getLoc(), op.getReset(), op.getResetValue(), op.getInput());
598 rewriter.modifyOpInPlace(op, [&]() {
599 op.getInputMutable().set(newInput);
600 op.getResetMutable().clear();
601 op.getResetValueMutable().clear();
602 });
603
604 return success();
605}
606
/// Merge producer VectorizeOps into this VectorizeOp.
///
/// An input vector qualifies when all of the following hold:
/// - it consists of exactly the results of a single other VectorizeOp,
///   possibly in a different order;
/// - each result of that other op has exactly one use.
/// In that case the producer's body is cloned into this op's body, its inputs
/// (re-shuffled to match the lane order) replace the input vector, and the
/// producer is erased.
///
/// NOTE(review): block arguments are inserted/erased and operands are replaced
/// (via updateInputOperands) directly rather than inside `modifyOpInPlace`, so
/// the rewriter is not notified of those in-place changes -- confirm this is
/// acceptable for the greedy driver.
LogicalResult
MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
                                   PatternRewriter &rewriter) const {
  auto &currentBlock = vecOp.getBody().front();
  IRMapping argMapping;
  SmallVector<Value> newOperands;
  SmallVector<VectorizeOp> vecOpsToRemove;
  bool canBeMerged = false;
  // Used to calculate the new positions of args after insertions and removals
  unsigned paddedBy = 0;

  for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
       argIdx < numArgs; ++argIdx) {
    auto inputVec = vecOp.getInputs()[argIdx];
    // Make sure that the input comes from a `VectorizeOp`
    // Ensure that the input vector matches the output of the `otherVecOp`
    // Make sure that the results of the otherVecOp have only one use
    auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
    if (!otherVecOp || otherVecOp == vecOp ||
        !llvm::all_of(otherVecOp.getResults(),
                      [](auto result) { return result.hasOneUse(); }) ||
        !llvm::all_of(inputVec, [&](auto result) {
          return result.template getDefiningOp<VectorizeOp>() == otherVecOp;
        })) {
      // Not mergeable: keep this input vector as-is.
      newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
      continue;
    }

    // Here, all elements are from the same `VectorizeOp`.
    // If all elements of the input vector come from the same `VectorizeOp`
    // sort the vectors by their indices
    DenseMap<Value, size_t> resultIdxMap;
    for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults()))
      resultIdxMap[result] = resultIdx;

    SmallVector<Value> tempVec(inputVec.begin(), inputVec.end());
    llvm::sort(tempVec, [&](Value a, Value b) {
      return resultIdxMap[a] < resultIdxMap[b];
    });

    // Check if inputVec matches the result after sorting.
    if (tempVec != SmallVector<Value>(otherVecOp.getResults().begin(),
                                      otherVecOp.getResults().end())) {
      newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
      continue;
    }

    // Lane `i` of this input vector is result `fromRealIdxToSortedIdx[i]` of
    // the producer.
    DenseMap<size_t, size_t> fromRealIdxToSortedIdx;
    for (auto [inIdx, in] : llvm::enumerate(inputVec))
      fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in];

    // If this flag is set that means we changed the IR so we cannot return
    // failure
    canBeMerged = true;

    // If the results got shuffled, then shuffle the operands before merging.
    if (inputVec != otherVecOp.getResults()) {
      for (auto otherVecOpInputVec : otherVecOp.getInputs()) {
        // use the tempVec again instead of creating another one.
        tempVec = SmallVector<Value>(inputVec.size());
        for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec))
          tempVec[realIdx] =
              otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]];

        newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end());
      }

    } else
      newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
                         otherVecOp.getOperands().end());

    // Insert the producer's block arguments in front of the argument being
    // replaced, keeping their order.
    auto &otherBlock = otherVecOp.getBody().front();
    for (auto &otherArg : otherBlock.getArguments()) {
      auto newArg = currentBlock.insertArgument(
          argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
      argMapping.map(otherArg, newArg);
      ++paddedBy;
    }

    rewriter.setInsertionPointToStart(&currentBlock);
    for (auto &op : otherBlock.without_terminator())
      rewriter.clone(op, argMapping);

    // Current position of the original argument for input `argIdx`, after
    // the insertions above.
    unsigned argNewPos = paddedBy + argIdx;
    // Get the result of the return value and use it in all places where the
    // `otherVecOp` results were used
    auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
    rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
                                argMapping.lookupOrDefault(retOp.getValue()));
    currentBlock.eraseArgument(argNewPos);
    vecOpsToRemove.push_back(otherVecOp);
    // We erased an arg so the padding decreased by 1
    paddedBy--;
  }

  // We didn't change the IR as there were no vectors to merge
  if (!canBeMerged)
    return failure();

  (void)updateInputOperands(vecOp, newOperands);

  // Erase dead VectorizeOps
  for (auto deadOp : vecOpsToRemove)
    rewriter.eraseOp(deadOp);

  return success();
}
714
715namespace llvm {
716static unsigned hashValue(const SmallVector<Value> &inputs) {
717 unsigned hash = hash_value(inputs.size());
718 for (auto input : inputs)
719 hash = hash_combine(hash, input);
720 return hash;
721}
722
723template <>
724struct DenseMapInfo<SmallVector<Value>> {
725 static inline SmallVector<Value> getEmptyKey() {
726 return SmallVector<Value>();
727 }
728
729 static inline SmallVector<Value> getTombstoneKey() {
730 return SmallVector<Value>();
731 }
732
733 static unsigned getHashValue(const SmallVector<Value> &inputs) {
734 return hashValue(inputs);
735 }
736
737 static bool isEqual(const SmallVector<Value> &lhs,
738 const SmallVector<Value> &rhs) {
739 return lhs == rhs;
740 }
741};
742} // namespace llvm
743
744LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
745 PatternRewriter &rewriter) const {
746 DenseMap<SmallVector<Value>, unsigned> inExists;
747 auto &currentBlock = vecOp.getBody().front();
748 SmallVector<Value> newOperands;
749 BitVector argsToRemove(vecOp.getInputs().size(), false);
750 for (size_t argIdx = 0; argIdx < vecOp.getInputs().size(); ++argIdx) {
751 auto input = SmallVector<Value>(vecOp.getInputs()[argIdx].begin(),
752 vecOp.getInputs()[argIdx].end());
753 if (auto in = inExists.find(input); in != inExists.end()) {
754 rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx),
755 currentBlock.getArgument(in->second));
756 argsToRemove.set(argIdx);
757 continue;
758 }
759 inExists[input] = argIdx;
760 newOperands.insert(newOperands.end(), input.begin(), input.end());
761 }
762
763 if (argsToRemove.none())
764 return failure();
765
766 currentBlock.eraseArguments(argsToRemove);
767 return updateInputOperands(vecOp, newOperands);
768}
769
770//===----------------------------------------------------------------------===//
771// ArcCanonicalizerPass implementation
772//===----------------------------------------------------------------------===//
773
774namespace {
775struct ArcCanonicalizerPass
776 : public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
777 void runOnOperation() override;
778};
779} // namespace
780
781void ArcCanonicalizerPass::runOnOperation() {
782 MLIRContext &ctxt = getContext();
783 SymbolTableCollection symbolTable;
784 SymbolHandler cache;
785 cache.addDefinitions(getOperation());
786 cache.collectAllSymbolUses(getOperation(), symbolTable);
787 Namespace names;
788 names.add(cache);
789 DenseMap<StringAttr, StringAttr> arcMapping;
790
791 mlir::GreedyRewriteConfig config;
792 config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
793 config.maxIterations = 10;
794 config.useTopDownTraversal = true;
795 ArcListener listener(&cache);
796 config.listener = &listener;
797
798 PatternStatistics statistics;
799 RewritePatternSet symbolPatterns(&getContext());
800 symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
801 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
802 &getContext(), cache, names, statistics);
803 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
804 &getContext(), cache, names, statistics, arcMapping);
805
806 if (failed(mlir::applyPatternsGreedily(getOperation(),
807 std::move(symbolPatterns), config)))
808 return signalPassFailure();
809
810 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
811
812 RewritePatternSet patterns(&ctxt);
813 for (auto *dialect : ctxt.getLoadedDialects())
814 dialect->getCanonicalizationPatterns(patterns);
815 for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
816 op.getCanonicalizationPatterns(patterns, &ctxt);
817 patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
818 KeepOneVecOp>(&getContext());
819
820 // Don't test for convergence since it is often not reached.
821 (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns),
822 config);
823}
824
825std::unique_ptr<mlir::Pass> arc::createArcCanonicalizerPass() {
826 return std::make_unique<ArcCanonicalizerPass>();
827}
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
LogicalResult updateInputOperands(VectorizeOp &vecOp, const SmallVector< Value > &newOperands)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
SymbolCacheBase::Iterator end() override
Definition SymCache.h:125
create(data_type, value)
Definition hw.py:433
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
static llvm::hash_code hash_value(const ModulePort &port)
Definition HWTypes.h:38
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static unsigned hashValue(const SmallVector< Value > &inputs)
static SmallVector< Value > getTombstoneKey()
static bool isEqual(const SmallVector< Value > &lhs, const SmallVector< Value > &rhs)
static unsigned getHashValue(const SmallVector< Value > &inputs)