20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/Debug.h"
27 #define DEBUG_TYPE "arc-canonicalizer"
31 #define GEN_PASS_DEF_ARCCANONICALIZER
32 #include "circt/Dialect/Arc/ArcPasses.h.inc"
36 using namespace circt;
50 ArrayRef<Operation *> getUsers(Operation *symbol)
const {
51 auto it = userMap.find(symbol);
52 return it != userMap.
end() ? it->second.getArrayRef() : std::nullopt;
56 bool useEmpty(Operation *symbol) {
57 return !userMap.count(symbol) || userMap[symbol].empty();
60 void addUser(Operation *def, Operation *user) {
61 assert(isa<mlir::SymbolOpInterface>(def));
62 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
64 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
65 userMap[def].insert(user);
68 void removeUser(Operation *def, Operation *user) {
69 assert(isa<mlir::SymbolOpInterface>(def));
70 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
71 userMap[def].remove(user);
72 if (userMap[def].
empty())
76 void removeDefinitionAndAllUsers(Operation *def) {
77 assert(isa<mlir::SymbolOpInterface>(def));
78 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
82 void collectAllSymbolUses(Operation *symbolTableOp,
83 SymbolTableCollection &symbolTable) {
87 SmallVector<Operation *> symbols;
88 auto walkFn = [&](Operation *symbolTableOp,
bool allUsesVisible) {
89 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
90 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
91 assert(symbolUses &&
"expected uses to be valid");
93 for (
const SymbolTable::SymbolUse &use : *symbolUses) {
95 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
97 for (Operation *symbolOp : symbols)
98 userMap[symbolOp].insert(use.getUser());
104 SymbolTable::walkSymbolTables(symbolTableOp,
false,
109 DenseMap<Operation *, SetVector<Operation *>> userMap;
115 class ArcListener :
public mlir::RewriterBase::Listener {
117 explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}
119 void notifyOperationReplaced(Operation *op, Operation *replacement)
override {
122 auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
123 auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
124 if (symOp && symReplacement &&
125 symOp.getNameAttr() == symReplacement.getNameAttr())
134 void notifyOperationReplaced(Operation *op, ValueRange replacement)
override {
138 void notifyOperationErased(Operation *op)
override { remove(op); }
140 void notifyOperationInserted(Operation *op,
141 mlir::IRRewriter::InsertPoint)
override {
148 FailureOr<Operation *> maybeGetDefinition(Operation *op) {
149 if (
auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
151 dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
154 if (
auto *def = handler->getDefinition(symAttr.getLeafReference()))
160 void remove(Operation *op) {
161 auto maybeDef = maybeGetDefinition(op);
162 if (!failed(maybeDef))
163 handler->removeUser(*maybeDef, op);
165 if (isa<mlir::SymbolOpInterface>(op))
166 handler->removeDefinitionAndAllUsers(op);
169 void add(Operation *op) {
170 auto maybeDef = maybeGetDefinition(op);
171 if (!failed(maybeDef))
172 handler->addUser(*maybeDef, op);
174 if (
auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
175 handler->addDefinition(defOp.getNameAttr(), op);
178 SymbolHandler *handler;
181 struct PatternStatistics {
182 unsigned removeUnusedArcArgumentsPatternNumArgsRemoved = 0;
196 template <
typename SourceOp>
199 SymOpRewritePattern(MLIRContext *
ctxt, SymbolHandler &symbolCache,
200 Namespace &names, PatternStatistics &stats,
201 mlir::PatternBenefit benefit = 1,
202 ArrayRef<StringRef> generatedNames = {})
204 symbolCache(symbolCache), statistics(stats) {}
208 SymbolHandler &symbolCache;
209 PatternStatistics &statistics;
212 class MemWritePortEnableAndMaskCanonicalizer
213 :
public SymOpRewritePattern<MemoryWritePortOp> {
215 MemWritePortEnableAndMaskCanonicalizer(
216 MLIRContext *
ctxt, SymbolHandler &symbolCache,
Namespace &names,
217 PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
218 : SymOpRewritePattern<MemoryWritePortOp>(
ctxt, symbolCache, names, stats),
219 arcMapping(arcMapping) {}
220 LogicalResult matchAndRewrite(MemoryWritePortOp op,
221 PatternRewriter &rewriter)
const final;
224 DenseMap<StringAttr, StringAttr> &arcMapping;
227 struct CallPassthroughArc :
public SymOpRewritePattern<CallOp> {
228 using SymOpRewritePattern::SymOpRewritePattern;
229 LogicalResult matchAndRewrite(CallOp op,
230 PatternRewriter &rewriter)
const final;
233 struct RemoveUnusedArcs :
public SymOpRewritePattern<DefineOp> {
234 using SymOpRewritePattern::SymOpRewritePattern;
235 LogicalResult matchAndRewrite(DefineOp op,
236 PatternRewriter &rewriter)
const final;
240 using OpRewritePattern::OpRewritePattern;
241 LogicalResult matchAndRewrite(comb::ICmpOp op,
242 PatternRewriter &rewriter)
const final;
246 using OpRewritePattern::OpRewritePattern;
248 PatternRewriter &rewriter)
const final;
251 struct RemoveUnusedArcArgumentsPattern :
public SymOpRewritePattern<DefineOp> {
252 using SymOpRewritePattern::SymOpRewritePattern;
253 LogicalResult matchAndRewrite(DefineOp op,
254 PatternRewriter &rewriter)
const final;
257 struct SinkArcInputsPattern :
public SymOpRewritePattern<DefineOp> {
258 using SymOpRewritePattern::SymOpRewritePattern;
259 LogicalResult matchAndRewrite(DefineOp op,
260 PatternRewriter &rewriter)
const final;
264 using OpRewritePattern::OpRewritePattern;
265 LogicalResult matchAndRewrite(VectorizeOp op,
266 PatternRewriter &rewriter)
const final;
270 using OpRewritePattern::OpRewritePattern;
271 LogicalResult matchAndRewrite(VectorizeOp op,
272 PatternRewriter &rewriter)
const final;
282 SymbolHandler &symbolCache,
283 PatternRewriter &rewriter) {
284 auto defOp = cast<DefineOp>(symbolCache.getDefinition(
285 callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference()));
286 if (defOp.isPassthrough()) {
287 symbolCache.removeUser(defOp, callOp);
288 rewriter.replaceOp(callOp, callOp.getArgOperands());
295 const SmallVector<Value> &newOperands) {
297 unsigned groupSize = vecOp.getResults().size();
298 unsigned numOfGroups = newOperands.size() / groupSize;
299 SmallVector<int32_t> newAttr(numOfGroups, groupSize);
300 vecOp.setInputOperandSegments(newAttr);
301 vecOp.getOperation()->setOperands(ValueRange(newOperands));
309 LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
310 MemoryWritePortOp op, PatternRewriter &rewriter)
const {
311 auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
314 if (op.getEnable() &&
316 defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
317 mlir::m_ConstantInt(&enable))) {
318 if (enable.isZero()) {
319 symbolCache.removeUser(defOp, op);
320 rewriter.eraseOp(op);
321 if (symbolCache.useEmpty(defOp)) {
322 symbolCache.removeDefinitionAndAllUsers(defOp);
323 rewriter.eraseOp(defOp);
327 if (enable.isAllOnes()) {
328 if (arcMapping.count(defOp.getNameAttr())) {
329 auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
331 rewriter.modifyOpInPlace(op, [&]() {
333 op.setArc(arcWithoutEnable.getValue());
335 symbolCache.removeUser(defOp, op);
336 symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
340 auto newName = names.newName(defOp.getName());
341 auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
342 symbolCache.removeDefinitionAndAllUsers(defOp);
345 rewriter.modifyOpInPlace(op, [&]() {
350 auto newResultTypes = op.getArcResultTypes();
353 rewriter.setInsertionPoint(defOp);
354 auto newDefOp = rewriter.cloneWithoutRegions(defOp);
355 auto *block = rewriter.createBlock(
356 &newDefOp.getBody(), newDefOp.getBody().end(),
357 newDefOp.getArgumentTypes(),
358 SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
359 auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
360 newName, block->getArguments());
361 SmallVector<Value> results(callOp->getResults());
363 newDefOp.getLoc(), rewriter.getI1Type(), 1);
364 results.insert(results.begin() + op.getEnableIdx(), constTrue);
365 rewriter.
create<OutputOp>(newDefOp.getLoc(), results);
368 auto *terminator = defOp.getBodyBlock().getTerminator();
369 rewriter.modifyOpInPlace(
370 terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
371 rewriter.modifyOpInPlace(defOp, [&]() {
372 defOp.setName(newName);
373 defOp.setFunctionType(
374 rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
378 symbolCache.addDefinition(defOp.getNameAttr(), defOp);
379 symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
380 symbolCache.addUser(defOp, callOp);
381 for (
auto *user : users)
382 symbolCache.addUser(user == op ? defOp : newDefOp, user);
384 arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
392 CallPassthroughArc::matchAndRewrite(CallOp op,
393 PatternRewriter &rewriter)
const {
398 RemoveUnusedArcs::matchAndRewrite(DefineOp op,
399 PatternRewriter &rewriter)
const {
400 if (symbolCache.useEmpty(op)) {
401 op.getBody().walk([&](mlir::CallOpInterface user) {
402 if (
auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
403 if (
auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
404 symbolCache.removeUser(defOp, user);
406 symbolCache.removeDefinitionAndAllUsers(op);
407 rewriter.eraseOp(op);
414 ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
415 PatternRewriter &rewriter)
const {
416 auto getConstant = [&](
const APInt &constant) -> Value {
419 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
420 if (llvm::all_equal(types) && !types.empty())
421 if (
auto intType = dyn_cast<IntegerType>(*types.begin()))
422 return intType.getWidth();
425 auto negate = [&](Value input) -> Value {
432 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
433 if (
auto concatOp = op.getLhs().getDefiningOp<
comb::ConcatOp>()) {
434 if (
auto optionalWidth =
435 sameWidthIntegers(concatOp->getOperands().getTypes())) {
436 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
437 op.getPredicate() == comb::ICmpPredicate::ne) &&
440 op.getLoc(), concatOp.getInputs(), op.getTwoState());
441 if (*optionalWidth == 1) {
442 if (op.getPredicate() == comb::ICmpPredicate::ne)
443 andOp = negate(andOp);
444 rewriter.replaceOp(op, andOp);
447 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
448 op, op.getPredicate(), andOp,
449 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
454 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
455 op.getPredicate() == comb::ICmpPredicate::eq) &&
458 op.getLoc(), concatOp.getInputs(), op.getTwoState());
459 if (*optionalWidth == 1) {
460 if (op.getPredicate() == comb::ICmpPredicate::eq)
462 rewriter.replaceOp(op, orOp);
465 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
466 op, op.getPredicate(), orOp,
467 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
477 LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
478 DefineOp op, PatternRewriter &rewriter)
const {
479 BitVector toDelete(op.getNumArguments());
480 for (
auto [i, arg] : llvm::enumerate(op.getArguments()))
490 SmallVector<mlir::CallOpInterface> mutableUsers;
491 for (
auto *user : symbolCache.getUsers(op)) {
492 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
495 mutableUsers.push_back(callOpMutable);
499 for (
auto user : mutableUsers)
500 for (
int i = toDelete.size() - 1; i >= 0; --i)
502 user.getArgOperandsMutable().erase(i);
504 op.eraseArguments(toDelete);
506 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
508 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
513 SinkArcInputsPattern::matchAndRewrite(DefineOp op,
514 PatternRewriter &rewriter)
const {
517 auto users = symbolCache.getUsers(op);
519 users, [](
auto *user) {
return !isa<mlir::CallOpInterface>(user); }))
523 SmallVector<Operation *> stateConsts(op.getNumArguments());
525 for (
auto *user : users) {
526 auto callOp = cast<mlir::CallOpInterface>(user);
527 for (
auto [constArg, input] :
528 llvm::zip(stateConsts, callOp.getArgOperands())) {
529 if (
auto *constOp = input.getDefiningOp();
530 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
536 constArg->getName() == input.getDefiningOp()->getName() &&
537 constArg->getAttrDictionary() ==
538 input.getDefiningOp()->getAttrDictionary())
547 rewriter.setInsertionPointToStart(&op.getBodyBlock());
548 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
549 for (
auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
552 auto *inlinedConst = rewriter.clone(*constArg);
553 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
554 toDelete.set(arg.getArgNumber());
556 op.getBodyBlock().eraseArguments(toDelete);
557 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
558 op.getResultTypes()));
561 for (
auto *user : users) {
562 auto callOp = cast<mlir::CallOpInterface>(user);
563 SmallPtrSet<Value, 4> maybeUnusedValues;
564 SmallVector<Value> newInputs;
565 for (
auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
567 maybeUnusedValues.insert(value);
569 newInputs.push_back(value);
571 rewriter.modifyOpInPlace(
572 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
573 for (
auto value : maybeUnusedValues)
574 if (value.use_empty())
575 rewriter.eraseOp(value.getDefiningOp());
578 return success(toDelete.any());
583 PatternRewriter &rewriter)
const {
589 if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
590 if (constant.isZero())
594 op->getLoc(), op.getReset(), op.getResetValue(), op.getInput());
595 rewriter.modifyOpInPlace(op, [&]() {
596 op.getInputMutable().set(newInput);
597 op.getResetMutable().clear();
598 op.getResetValueMutable().clear();
605 MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
606 PatternRewriter &rewriter)
const {
607 auto ¤tBlock = vecOp.getBody().front();
608 IRMapping argMapping;
609 SmallVector<Value> newOperands;
610 SmallVector<VectorizeOp> vecOpsToRemove;
611 bool canBeMerged =
false;
613 unsigned paddedBy = 0;
615 for (
unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
616 argIdx < numArgs; ++argIdx) {
617 auto inputVec = vecOp.getInputs()[argIdx];
621 auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
622 if (!otherVecOp || inputVec != otherVecOp.getResults() ||
623 otherVecOp == vecOp ||
624 !llvm::all_of(otherVecOp.getResults(),
625 [](
auto result) { return result.hasOneUse(); })) {
626 newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
632 newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
633 otherVecOp.getOperands().end());
635 auto &otherBlock = otherVecOp.getBody().front();
636 for (
auto &otherArg : otherBlock.getArguments()) {
637 auto newArg = currentBlock.insertArgument(
638 argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
639 argMapping.map(otherArg, newArg);
643 rewriter.setInsertionPointToStart(¤tBlock);
644 for (
auto &op : otherBlock.without_terminator())
645 rewriter.clone(op, argMapping);
647 unsigned argNewPos = paddedBy + argIdx;
650 auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
651 rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
652 argMapping.lookupOrDefault(retOp.getValue()));
653 currentBlock.eraseArgument(argNewPos);
654 vecOpsToRemove.push_back(otherVecOp);
666 for (
auto deadOp : vecOpsToRemove)
667 rewriter.eraseOp(deadOp);
673 static unsigned hashValue(
const SmallVector<Value> &inputs) {
675 for (
auto input : inputs)
676 hash = hash_combine(hash, input);
681 struct DenseMapInfo<SmallVector<Value>> {
683 return SmallVector<Value>();
687 return SmallVector<Value>();
694 static bool isEqual(
const SmallVector<Value> &lhs,
695 const SmallVector<Value> &rhs) {
701 LogicalResult KeepOneVecOp::matchAndRewrite(VectorizeOp vecOp,
702 PatternRewriter &rewriter)
const {
703 BitVector argsToRemove(vecOp.getInputs().size(),
false);
704 DenseMap<SmallVector<Value>,
unsigned> inExists;
705 auto ¤tBlock = vecOp.getBody().front();
706 SmallVector<Value> newOperands;
707 unsigned shuffledBy = 0;
708 bool changed =
false;
709 for (
auto [argIdx, inputVec] : llvm::enumerate(vecOp.getInputs())) {
710 auto input = SmallVector<Value>(inputVec.begin(), inputVec.end());
711 if (
auto in = inExists.find(input); in != inExists.end()) {
712 argsToRemove.set(argIdx);
713 rewriter.replaceAllUsesWith(currentBlock.getArgument(argIdx - shuffledBy),
714 currentBlock.getArgument(in->second));
715 currentBlock.eraseArgument(argIdx - shuffledBy);
720 inExists[input] = argIdx;
721 newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
734 struct ArcCanonicalizerPass
735 :
public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
736 void runOnOperation()
override;
740 void ArcCanonicalizerPass::runOnOperation() {
741 MLIRContext &
ctxt = getContext();
742 SymbolTableCollection symbolTable;
744 cache.addDefinitions(getOperation());
745 cache.collectAllSymbolUses(getOperation(), symbolTable);
748 DenseMap<StringAttr, StringAttr> arcMapping;
750 mlir::GreedyRewriteConfig config;
751 config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
752 config.maxIterations = 10;
753 config.useTopDownTraversal =
true;
754 ArcListener listener(&cache);
755 config.listener = &listener;
757 PatternStatistics statistics;
758 RewritePatternSet symbolPatterns(&getContext());
759 symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
760 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
761 &getContext(), cache, names, statistics);
762 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
763 &getContext(), cache, names, statistics, arcMapping);
765 if (failed(mlir::applyPatternsAndFoldGreedily(
766 getOperation(), std::move(symbolPatterns), config)))
767 return signalPassFailure();
769 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
772 for (
auto *dialect :
ctxt.getLoadedDialects())
773 dialect->getCanonicalizationPatterns(
patterns);
774 for (mlir::RegisteredOperationName op :
ctxt.getRegisteredOperations())
776 patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps,
777 KeepOneVecOp>(&getContext());
780 (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns),
785 return std::make_unique<ArcCanonicalizerPass>();
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
LogicalResult updateInputOperands(VectorizeOp &vecOp, const SmallVector< Value > &newOperands)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
SymbolCacheBase::Iterator end() override
def create(data_type, value)
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
static unsigned hashValue(const SmallVector< Value > &inputs)
static SmallVector< Value > getTombstoneKey()
static SmallVector< Value > getEmptyKey()
static bool isEqual(const SmallVector< Value > &lhs, const SmallVector< Value > &rhs)
static unsigned getHashValue(const SmallVector< Value > &inputs)