20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/Support/Debug.h"
26 #define DEBUG_TYPE "arc-canonicalizer"
30 #define GEN_PASS_DEF_ARCCANONICALIZER
31 #include "circt/Dialect/Arc/ArcPasses.h.inc"
35 using namespace circt;
49 ArrayRef<Operation *> getUsers(Operation *symbol)
const {
50 auto it = userMap.find(symbol);
51 return it != userMap.
end() ? it->second.getArrayRef() : std::nullopt;
55 bool useEmpty(Operation *symbol) {
56 return !userMap.count(symbol) || userMap[symbol].empty();
59 void addUser(Operation *def, Operation *user) {
60 assert(isa<mlir::SymbolOpInterface>(def));
61 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
63 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
64 userMap[def].insert(user);
67 void removeUser(Operation *def, Operation *user) {
68 assert(isa<mlir::SymbolOpInterface>(def));
69 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
70 userMap[def].remove(user);
71 if (userMap[def].
empty())
75 void removeDefinitionAndAllUsers(Operation *def) {
76 assert(isa<mlir::SymbolOpInterface>(def));
77 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
81 void collectAllSymbolUses(Operation *symbolTableOp,
82 SymbolTableCollection &symbolTable) {
86 SmallVector<Operation *> symbols;
87 auto walkFn = [&](Operation *symbolTableOp,
bool allUsesVisible) {
88 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
89 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
90 assert(symbolUses &&
"expected uses to be valid");
92 for (
const SymbolTable::SymbolUse &use : *symbolUses) {
94 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
96 for (Operation *symbolOp : symbols)
97 userMap[symbolOp].insert(use.getUser());
103 SymbolTable::walkSymbolTables(symbolTableOp,
false,
108 DenseMap<Operation *, SetVector<Operation *>> userMap;
114 class ArcListener :
public mlir::RewriterBase::Listener {
116 explicit ArcListener(SymbolHandler *handler) : Listener(), handler(handler) {}
118 void notifyOperationReplaced(Operation *op, Operation *replacement)
override {
121 auto symOp = dyn_cast<mlir::SymbolOpInterface>(op);
122 auto symReplacement = dyn_cast<mlir::SymbolOpInterface>(replacement);
123 if (symOp && symReplacement &&
124 symOp.getNameAttr() == symReplacement.getNameAttr())
133 void notifyOperationReplaced(Operation *op, ValueRange replacement)
override {
137 void notifyOperationErased(Operation *op)
override { remove(op); }
139 void notifyOperationInserted(Operation *op,
140 mlir::IRRewriter::InsertPoint)
override {
148 if (
auto callOp = dyn_cast<mlir::CallOpInterface>(op)) {
150 dyn_cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee());
153 if (
auto *def = handler->getDefinition(symAttr.getLeafReference()))
159 void remove(Operation *op) {
160 auto maybeDef = maybeGetDefinition(op);
161 if (!failed(maybeDef))
162 handler->removeUser(*maybeDef, op);
164 if (isa<mlir::SymbolOpInterface>(op))
165 handler->removeDefinitionAndAllUsers(op);
168 void add(Operation *op) {
169 auto maybeDef = maybeGetDefinition(op);
170 if (!failed(maybeDef))
171 handler->addUser(*maybeDef, op);
173 if (
auto defOp = dyn_cast<mlir::SymbolOpInterface>(op))
174 handler->addDefinition(defOp.getNameAttr(), op);
177 SymbolHandler *handler;
180 struct PatternStatistics {
181 unsigned removeUnusedArcArgumentsPatternNumArgsRemoved = 0;
195 template <
typename SourceOp>
198 SymOpRewritePattern(MLIRContext *
ctxt, SymbolHandler &symbolCache,
199 Namespace &names, PatternStatistics &stats,
200 mlir::PatternBenefit benefit = 1,
201 ArrayRef<StringRef> generatedNames = {})
203 symbolCache(symbolCache), statistics(stats) {}
207 SymbolHandler &symbolCache;
208 PatternStatistics &statistics;
211 class MemWritePortEnableAndMaskCanonicalizer
212 :
public SymOpRewritePattern<MemoryWritePortOp> {
214 MemWritePortEnableAndMaskCanonicalizer(
215 MLIRContext *
ctxt, SymbolHandler &symbolCache,
Namespace &names,
216 PatternStatistics &stats, DenseMap<StringAttr, StringAttr> &arcMapping)
217 : SymOpRewritePattern<MemoryWritePortOp>(
ctxt, symbolCache, names, stats),
218 arcMapping(arcMapping) {}
219 LogicalResult matchAndRewrite(MemoryWritePortOp op,
220 PatternRewriter &rewriter)
const final;
223 DenseMap<StringAttr, StringAttr> &arcMapping;
226 struct CallPassthroughArc :
public SymOpRewritePattern<CallOp> {
227 using SymOpRewritePattern::SymOpRewritePattern;
228 LogicalResult matchAndRewrite(CallOp op,
229 PatternRewriter &rewriter)
const final;
232 struct RemoveUnusedArcs :
public SymOpRewritePattern<DefineOp> {
233 using SymOpRewritePattern::SymOpRewritePattern;
234 LogicalResult matchAndRewrite(DefineOp op,
235 PatternRewriter &rewriter)
const final;
239 using OpRewritePattern::OpRewritePattern;
240 LogicalResult matchAndRewrite(comb::ICmpOp op,
241 PatternRewriter &rewriter)
const final;
245 using OpRewritePattern::OpRewritePattern;
247 PatternRewriter &rewriter)
const final;
250 struct RemoveUnusedArcArgumentsPattern :
public SymOpRewritePattern<DefineOp> {
251 using SymOpRewritePattern::SymOpRewritePattern;
252 LogicalResult matchAndRewrite(DefineOp op,
253 PatternRewriter &rewriter)
const final;
256 struct SinkArcInputsPattern :
public SymOpRewritePattern<DefineOp> {
257 using SymOpRewritePattern::SymOpRewritePattern;
258 LogicalResult matchAndRewrite(DefineOp op,
259 PatternRewriter &rewriter)
const final;
269 SymbolHandler &symbolCache,
270 PatternRewriter &rewriter) {
271 auto defOp = cast<DefineOp>(symbolCache.getDefinition(
272 callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference()));
273 if (defOp.isPassthrough()) {
274 symbolCache.removeUser(defOp, callOp);
275 rewriter.replaceOp(callOp, callOp.getArgOperands());
285 LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
286 MemoryWritePortOp op, PatternRewriter &rewriter)
const {
287 auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
290 if (op.getEnable() &&
292 defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
293 mlir::m_ConstantInt(&enable))) {
294 if (enable.isZero()) {
295 symbolCache.removeUser(defOp, op);
296 rewriter.eraseOp(op);
297 if (symbolCache.useEmpty(defOp)) {
298 symbolCache.removeDefinitionAndAllUsers(defOp);
299 rewriter.eraseOp(defOp);
303 if (enable.isAllOnes()) {
304 if (arcMapping.count(defOp.getNameAttr())) {
305 auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
307 rewriter.modifyOpInPlace(op, [&]() {
309 op.setArc(arcWithoutEnable.getValue());
311 symbolCache.removeUser(defOp, op);
312 symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
316 auto newName = names.newName(defOp.getName());
317 auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
318 symbolCache.removeDefinitionAndAllUsers(defOp);
321 rewriter.modifyOpInPlace(op, [&]() {
326 auto newResultTypes = op.getArcResultTypes();
329 rewriter.setInsertionPoint(defOp);
330 auto newDefOp = rewriter.cloneWithoutRegions(defOp);
331 auto *block = rewriter.createBlock(
332 &newDefOp.getBody(), newDefOp.getBody().end(),
333 newDefOp.getArgumentTypes(),
334 SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
335 auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
336 newName, block->getArguments());
337 SmallVector<Value> results(callOp->getResults());
339 newDefOp.getLoc(), rewriter.getI1Type(), 1);
340 results.insert(results.begin() + op.getEnableIdx(), constTrue);
341 rewriter.
create<OutputOp>(newDefOp.getLoc(), results);
344 auto *terminator = defOp.getBodyBlock().getTerminator();
345 rewriter.modifyOpInPlace(
346 terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
347 rewriter.modifyOpInPlace(defOp, [&]() {
348 defOp.setName(newName);
349 defOp.setFunctionType(
350 rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
354 symbolCache.addDefinition(defOp.getNameAttr(), defOp);
355 symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
356 symbolCache.addUser(defOp, callOp);
357 for (
auto *user : users)
358 symbolCache.addUser(user == op ? defOp : newDefOp, user);
360 arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
368 CallPassthroughArc::matchAndRewrite(CallOp op,
369 PatternRewriter &rewriter)
const {
374 RemoveUnusedArcs::matchAndRewrite(DefineOp op,
375 PatternRewriter &rewriter)
const {
376 if (symbolCache.useEmpty(op)) {
377 op.getBody().walk([&](mlir::CallOpInterface user) {
378 if (
auto symbol = dyn_cast<SymbolRefAttr>(user.getCallableForCallee()))
379 if (
auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
380 symbolCache.removeUser(defOp, user);
382 symbolCache.removeDefinitionAndAllUsers(op);
383 rewriter.eraseOp(op);
390 ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
391 PatternRewriter &rewriter)
const {
392 auto getConstant = [&](
const APInt &constant) -> Value {
395 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
396 if (llvm::all_equal(types) && !types.empty())
397 if (
auto intType = dyn_cast<IntegerType>(*types.begin()))
398 return intType.getWidth();
401 auto negate = [&](Value input) -> Value {
408 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
409 if (
auto concatOp = op.getLhs().getDefiningOp<
comb::ConcatOp>()) {
410 if (
auto optionalWidth =
411 sameWidthIntegers(concatOp->getOperands().getTypes())) {
412 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
413 op.getPredicate() == comb::ICmpPredicate::ne) &&
416 op.getLoc(), concatOp.getInputs(), op.getTwoState());
417 if (*optionalWidth == 1) {
418 if (op.getPredicate() == comb::ICmpPredicate::ne)
419 andOp = negate(andOp);
420 rewriter.replaceOp(op, andOp);
423 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
424 op, op.getPredicate(), andOp,
425 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
430 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
431 op.getPredicate() == comb::ICmpPredicate::eq) &&
434 op.getLoc(), concatOp.getInputs(), op.getTwoState());
435 if (*optionalWidth == 1) {
436 if (op.getPredicate() == comb::ICmpPredicate::eq)
438 rewriter.replaceOp(op, orOp);
441 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
442 op, op.getPredicate(), orOp,
443 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
453 LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
454 DefineOp op, PatternRewriter &rewriter)
const {
455 BitVector toDelete(op.getNumArguments());
456 for (
auto [i, arg] : llvm::enumerate(op.getArguments()))
466 SmallVector<mlir::CallOpInterface> mutableUsers;
467 for (
auto *user : symbolCache.getUsers(op)) {
468 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
471 mutableUsers.push_back(callOpMutable);
475 for (
auto user : mutableUsers)
476 for (
int i = toDelete.size() - 1; i >= 0; --i)
478 user.getArgOperandsMutable().erase(i);
480 op.eraseArguments(toDelete);
482 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
484 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
489 SinkArcInputsPattern::matchAndRewrite(DefineOp op,
490 PatternRewriter &rewriter)
const {
493 auto users = symbolCache.getUsers(op);
495 users, [](
auto *user) {
return !isa<mlir::CallOpInterface>(user); }))
499 SmallVector<Operation *> stateConsts(op.getNumArguments());
501 for (
auto *user : users) {
502 auto callOp = cast<mlir::CallOpInterface>(user);
503 for (
auto [constArg, input] :
504 llvm::zip(stateConsts, callOp.getArgOperands())) {
505 if (
auto *constOp = input.getDefiningOp();
506 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
512 constArg->getName() == input.getDefiningOp()->getName() &&
513 constArg->getAttrDictionary() ==
514 input.getDefiningOp()->getAttrDictionary())
523 rewriter.setInsertionPointToStart(&op.getBodyBlock());
524 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
525 for (
auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
528 auto *inlinedConst = rewriter.clone(*constArg);
529 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
530 toDelete.set(arg.getArgNumber());
532 op.getBodyBlock().eraseArguments(toDelete);
533 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
534 op.getResultTypes()));
537 for (
auto *user : users) {
538 auto callOp = cast<mlir::CallOpInterface>(user);
539 SmallPtrSet<Value, 4> maybeUnusedValues;
540 SmallVector<Value> newInputs;
541 for (
auto [index, value] : llvm::enumerate(callOp.getArgOperands())) {
543 maybeUnusedValues.insert(value);
545 newInputs.push_back(value);
547 rewriter.modifyOpInPlace(
548 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
549 for (
auto value : maybeUnusedValues)
550 if (value.use_empty())
551 rewriter.eraseOp(value.getDefiningOp());
554 return success(toDelete.any());
559 PatternRewriter &rewriter)
const {
565 if (mlir::matchPattern(op.getResetValue(), mlir::m_ConstantInt(&constant)))
566 if (constant.isZero())
570 op->getLoc(), op.getReset(), op.getResetValue(), op.getInput());
571 rewriter.modifyOpInPlace(op, [&]() {
572 op.getInputMutable().set(newInput);
573 op.getResetMutable().clear();
574 op.getResetValueMutable().clear();
585 struct ArcCanonicalizerPass
586 :
public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
587 void runOnOperation()
override;
591 void ArcCanonicalizerPass::runOnOperation() {
592 MLIRContext &
ctxt = getContext();
593 SymbolTableCollection symbolTable;
595 cache.addDefinitions(getOperation());
596 cache.collectAllSymbolUses(getOperation(), symbolTable);
599 DenseMap<StringAttr, StringAttr> arcMapping;
601 mlir::GreedyRewriteConfig config;
602 config.enableRegionSimplification =
false;
603 config.maxIterations = 10;
604 config.useTopDownTraversal =
true;
605 ArcListener listener(&cache);
606 config.listener = &listener;
608 PatternStatistics statistics;
609 RewritePatternSet symbolPatterns(&getContext());
610 symbolPatterns.add<CallPassthroughArc, RemoveUnusedArcs,
611 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
612 &getContext(), cache, names, statistics);
613 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
614 &getContext(), cache, names, statistics, arcMapping);
616 if (failed(mlir::applyPatternsAndFoldGreedily(
617 getOperation(), std::move(symbolPatterns), config)))
618 return signalPassFailure();
620 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
623 for (
auto *dialect :
ctxt.getLoadedDialects())
624 dialect->getCanonicalizationPatterns(
patterns);
625 for (mlir::RegisteredOperationName op :
ctxt.getRegisteredOperations())
627 patterns.add<ICMPCanonicalizer, CompRegCanonicalizer>(&getContext());
630 (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns),
635 return std::make_unique<ArcCanonicalizerPass>();
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
SymbolCacheBase::Iterator end() override
def create(data_type, value)
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.