19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "arc-canonicalizer"
29 #define GEN_PASS_DEF_ARCCANONICALIZER
30 #include "circt/Dialect/Arc/ArcPasses.h.inc"
34 using namespace circt;
46 ArrayRef<Operation *>
getUsers(Operation *symbol)
const {
47 auto it = userMap.find(symbol);
48 return it != userMap.end() ? it->second.getArrayRef() : std::nullopt;
53 return !userMap.count(symbol) || userMap[symbol].empty();
56 void addUser(Operation *def, Operation *user) {
57 assert(isa<mlir::SymbolOpInterface>(def));
58 if (!symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
60 {cast<mlir::SymbolOpInterface>(def).getNameAttr(), def});
61 userMap[def].insert(user);
65 assert(isa<mlir::SymbolOpInterface>(def));
66 if (symbolCache.contains(cast<mlir::SymbolOpInterface>(def).getNameAttr()))
67 userMap[def].remove(user);
68 if (userMap[def].
empty())
73 assert(isa<mlir::SymbolOpInterface>(def));
74 symbolCache.erase(cast<mlir::SymbolOpInterface>(def).getNameAttr());
79 SymbolTableCollection &symbolTable) {
83 SmallVector<Operation *> symbols;
84 auto walkFn = [&](Operation *symbolTableOp,
bool allUsesVisible) {
85 for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
86 auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
87 assert(symbolUses &&
"expected uses to be valid");
89 for (
const SymbolTable::SymbolUse &use : *symbolUses) {
91 (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
93 for (Operation *symbolOp : symbols)
94 userMap[symbolOp].insert(use.getUser());
100 SymbolTable::walkSymbolTables(symbolTableOp,
false,
105 DenseMap<Operation *, SetVector<Operation *>>
userMap;
109 unsigned removeUnusedArcArgumentsPatternNumArgsRemoved = 0;
121 template <
typename SourceOp>
124 SymOpRewritePattern(MLIRContext *ctxt,
SymbolHandler &symbolCache,
126 mlir::PatternBenefit benefit = 1,
127 ArrayRef<StringRef> generatedNames = {})
129 symbolCache(symbolCache), statistics(stats) {}
137 class MemWritePortEnableAndMaskCanonicalizer
138 :
public SymOpRewritePattern<MemoryWritePortOp> {
140 MemWritePortEnableAndMaskCanonicalizer(
143 : SymOpRewritePattern<MemoryWritePortOp>(ctxt, symbolCache, names, stats),
144 arcMapping(arcMapping) {}
145 LogicalResult matchAndRewrite(MemoryWritePortOp op,
146 PatternRewriter &rewriter)
const final;
149 DenseMap<StringAttr, StringAttr> &arcMapping;
152 struct CallPassthroughArc :
public SymOpRewritePattern<CallOp> {
153 using SymOpRewritePattern::SymOpRewritePattern;
154 LogicalResult matchAndRewrite(CallOp op,
155 PatternRewriter &rewriter)
const final;
158 struct StatePassthroughArc :
public SymOpRewritePattern<StateOp> {
159 using SymOpRewritePattern::SymOpRewritePattern;
160 LogicalResult matchAndRewrite(StateOp op,
161 PatternRewriter &rewriter)
const final;
164 struct RemoveUnusedArcs :
public SymOpRewritePattern<DefineOp> {
165 using SymOpRewritePattern::SymOpRewritePattern;
166 LogicalResult matchAndRewrite(DefineOp op,
167 PatternRewriter &rewriter)
const final;
171 using OpRewritePattern::OpRewritePattern;
172 LogicalResult matchAndRewrite(comb::ICmpOp op,
173 PatternRewriter &rewriter)
const final;
176 struct RemoveUnusedArcArgumentsPattern :
public SymOpRewritePattern<DefineOp> {
177 using SymOpRewritePattern::SymOpRewritePattern;
178 LogicalResult matchAndRewrite(DefineOp op,
179 PatternRewriter &rewriter)
const final;
182 struct SinkArcInputsPattern :
public SymOpRewritePattern<DefineOp> {
183 using SymOpRewritePattern::SymOpRewritePattern;
184 LogicalResult matchAndRewrite(DefineOp op,
185 PatternRewriter &rewriter)
const final;
196 PatternRewriter &rewriter) {
198 callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference()));
199 if (defOp.isPassthrough()) {
201 rewriter.replaceOp(callOp, callOp.getArgOperands());
211 LogicalResult MemWritePortEnableAndMaskCanonicalizer::matchAndRewrite(
212 MemoryWritePortOp op, PatternRewriter &rewriter)
const {
213 auto defOp = cast<DefineOp>(symbolCache.getDefinition(op.getArcAttr()));
216 if (op.getEnable() &&
218 defOp.getBodyBlock().getTerminator()->getOperand(op.getEnableIdx()),
219 mlir::m_ConstantInt(&enable))) {
220 if (enable.isZero()) {
221 symbolCache.removeUser(defOp, op);
222 rewriter.eraseOp(op);
223 if (symbolCache.useEmpty(defOp)) {
224 symbolCache.removeDefinitionAndAllUsers(defOp);
225 rewriter.eraseOp(defOp);
229 if (enable.isAllOnes()) {
230 if (arcMapping.count(defOp.getNameAttr())) {
231 auto arcWithoutEnable = arcMapping[defOp.getNameAttr()];
233 rewriter.updateRootInPlace(op, [&]() {
235 op.setArc(arcWithoutEnable.getValue());
237 symbolCache.removeUser(defOp, op);
238 symbolCache.addUser(symbolCache.getDefinition(arcWithoutEnable), op);
242 auto newName = names.newName(defOp.getName());
243 auto users = SmallVector<Operation *>(symbolCache.getUsers(defOp));
244 symbolCache.removeDefinitionAndAllUsers(defOp);
247 rewriter.updateRootInPlace(op, [&]() {
252 auto newResultTypes = op.getArcResultTypes();
255 rewriter.setInsertionPoint(defOp);
256 auto newDefOp = rewriter.cloneWithoutRegions(defOp);
257 auto *block = rewriter.createBlock(
258 &newDefOp.getBody(), newDefOp.getBody().end(),
259 newDefOp.getArgumentTypes(),
260 SmallVector<Location>(newDefOp.getNumArguments(), defOp.getLoc()));
261 auto callOp = rewriter.create<CallOp>(newDefOp.getLoc(), newResultTypes,
262 newName, block->getArguments());
263 SmallVector<Value> results(callOp->getResults());
265 newDefOp.getLoc(), rewriter.getI1Type(), 1);
266 results.insert(results.begin() + op.getEnableIdx(), constTrue);
267 rewriter.
create<OutputOp>(newDefOp.getLoc(), results);
270 auto *terminator = defOp.getBodyBlock().getTerminator();
271 rewriter.updateRootInPlace(
272 terminator, [&]() { terminator->eraseOperand(op.getEnableIdx()); });
273 rewriter.updateRootInPlace(defOp, [&]() {
274 defOp.setName(newName);
275 defOp.setFunctionType(
276 rewriter.getFunctionType(defOp.getArgumentTypes(), newResultTypes));
280 symbolCache.addDefinition(defOp.getNameAttr(), defOp);
281 symbolCache.addDefinition(newDefOp.getNameAttr(), newDefOp);
282 symbolCache.addUser(defOp, callOp);
283 for (
auto *user : users)
284 symbolCache.addUser(user == op ? defOp : newDefOp, user);
286 arcMapping[newDefOp.getNameAttr()] = defOp.getNameAttr();
294 CallPassthroughArc::matchAndRewrite(CallOp op,
295 PatternRewriter &rewriter)
const {
300 StatePassthroughArc::matchAndRewrite(StateOp op,
301 PatternRewriter &rewriter)
const {
302 if (op.getLatency() == 0)
308 RemoveUnusedArcs::matchAndRewrite(DefineOp op,
309 PatternRewriter &rewriter)
const {
310 if (symbolCache.useEmpty(op)) {
311 op.getBody().walk([&](mlir::CallOpInterface user) {
312 if (
auto symbol = user.getCallableForCallee().dyn_cast<SymbolRefAttr>())
313 if (
auto *defOp = symbolCache.getDefinition(symbol.getLeafReference()))
314 symbolCache.removeUser(defOp, user);
316 symbolCache.removeDefinitionAndAllUsers(op);
317 rewriter.eraseOp(op);
324 ICMPCanonicalizer::matchAndRewrite(comb::ICmpOp op,
325 PatternRewriter &rewriter)
const {
326 auto getConstant = [&](
const APInt &constant) -> Value {
329 auto sameWidthIntegers = [](TypeRange types) -> std::optional<unsigned> {
330 if (llvm::all_equal(types) && !types.empty())
331 if (
auto intType = dyn_cast<IntegerType>(*types.begin()))
332 return intType.getWidth();
335 auto negate = [&](Value input) -> Value {
342 if (matchPattern(op.getRhs(), mlir::m_ConstantInt(&rhs))) {
343 if (
auto concatOp = op.getLhs().getDefiningOp<
comb::ConcatOp>()) {
344 if (
auto optionalWidth =
345 sameWidthIntegers(concatOp->getOperands().getTypes())) {
346 if ((op.getPredicate() == comb::ICmpPredicate::eq ||
347 op.getPredicate() == comb::ICmpPredicate::ne) &&
350 op.getLoc(), concatOp.getInputs(), op.getTwoState());
351 if (*optionalWidth == 1) {
352 if (op.getPredicate() == comb::ICmpPredicate::ne)
353 andOp = negate(andOp);
354 rewriter.replaceOp(op, andOp);
357 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
358 op, op.getPredicate(), andOp,
359 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
364 if ((op.getPredicate() == comb::ICmpPredicate::ne ||
365 op.getPredicate() == comb::ICmpPredicate::eq) &&
368 op.getLoc(), concatOp.getInputs(), op.getTwoState());
369 if (*optionalWidth == 1) {
370 if (op.getPredicate() == comb::ICmpPredicate::eq)
372 rewriter.replaceOp(op, orOp);
375 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
376 op, op.getPredicate(), orOp,
377 getConstant(APInt(*optionalWidth, rhs.getZExtValue())),
387 LogicalResult RemoveUnusedArcArgumentsPattern::matchAndRewrite(
388 DefineOp op, PatternRewriter &rewriter)
const {
389 BitVector toDelete(op.getNumArguments());
390 for (
auto [i, arg] : llvm::enumerate(op.getArguments()))
400 SmallVector<mlir::CallOpInterface> mutableUsers;
401 for (
auto *user : symbolCache.getUsers(op)) {
402 auto callOpMutable = dyn_cast<mlir::CallOpInterface>(user);
405 mutableUsers.push_back(callOpMutable);
409 for (
auto user : mutableUsers)
410 for (
int i = toDelete.size() - 1; i >= 0; --i)
412 user.getArgOperandsMutable().erase(i);
414 op.eraseArguments(toDelete);
416 rewriter.getFunctionType(op.getArgumentTypes(), op.getResultTypes()));
418 statistics.removeUnusedArcArgumentsPatternNumArgsRemoved += toDelete.count();
423 SinkArcInputsPattern::matchAndRewrite(DefineOp op,
424 PatternRewriter &rewriter)
const {
427 auto users = symbolCache.getUsers(op);
429 users, [](
auto *user) {
return !isa<mlir::CallOpInterface>(user); }))
433 SmallVector<Operation *> stateConsts(op.getNumArguments());
435 for (
auto *user : users) {
436 auto callOp = cast<mlir::CallOpInterface>(user);
437 for (
auto [constArg, input] :
438 llvm::zip(stateConsts, callOp.getArgOperands())) {
439 if (
auto *constOp = input.getDefiningOp();
440 constOp && constOp->template hasTrait<OpTrait::ConstantLike>()) {
446 constArg->getName() == input.getDefiningOp()->getName() &&
447 constArg->getAttrDictionary() ==
448 input.getDefiningOp()->getAttrDictionary())
457 rewriter.setInsertionPointToStart(&op.getBodyBlock());
458 llvm::BitVector toDelete(op.getBodyBlock().getNumArguments());
459 for (
auto [constArg, arg] : llvm::zip(stateConsts, op.getArguments())) {
462 auto *inlinedConst = rewriter.clone(*constArg);
463 rewriter.replaceAllUsesWith(arg, inlinedConst->getResult(0));
464 toDelete.set(arg.getArgNumber());
466 op.getBodyBlock().eraseArguments(toDelete);
467 op.setType(rewriter.getFunctionType(op.getBodyBlock().getArgumentTypes(),
468 op.getResultTypes()));
471 for (
auto *user : users) {
472 auto callOp = cast<mlir::CallOpInterface>(user);
473 SmallPtrSet<Value, 4> maybeUnusedValues;
474 SmallVector<Value> newInputs;
475 for (
auto [index,
value] : llvm::enumerate(callOp.getArgOperands())) {
477 maybeUnusedValues.insert(
value);
479 newInputs.push_back(
value);
481 rewriter.updateRootInPlace(
482 callOp, [&]() { callOp.getArgOperandsMutable().assign(newInputs); });
483 for (
auto value : maybeUnusedValues)
484 if (
value.use_empty())
485 rewriter.eraseOp(
value.getDefiningOp());
488 return success(toDelete.any());
496 struct ArcCanonicalizerPass
497 :
public arc::impl::ArcCanonicalizerBase<ArcCanonicalizerPass> {
498 void runOnOperation()
override;
502 void ArcCanonicalizerPass::runOnOperation() {
503 MLIRContext &ctxt = getContext();
504 SymbolTableCollection symbolTable;
510 DenseMap<StringAttr, StringAttr> arcMapping;
512 mlir::GreedyRewriteConfig config;
513 config.enableRegionSimplification =
false;
514 config.maxIterations = 10;
515 config.useTopDownTraversal =
true;
518 RewritePatternSet symbolPatterns(&getContext());
519 symbolPatterns.add<CallPassthroughArc, StatePassthroughArc, RemoveUnusedArcs,
520 RemoveUnusedArcArgumentsPattern, SinkArcInputsPattern>(
521 &getContext(), cache, names, statistics);
522 symbolPatterns.add<MemWritePortEnableAndMaskCanonicalizer>(
523 &getContext(), cache, names, statistics, arcMapping);
525 if (failed(mlir::applyPatternsAndFoldGreedily(
526 getOperation(), std::move(symbolPatterns), config)))
527 return signalPassFailure();
529 numArcArgsRemoved = statistics.removeUnusedArcArgumentsPatternNumArgsRemoved;
532 for (
auto *dialect : ctxt.getLoadedDialects())
533 dialect->getCanonicalizationPatterns(
patterns);
534 for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
535 op.getCanonicalizationPatterns(
patterns, &ctxt);
536 patterns.add<ICMPCanonicalizer>(&getContext());
539 (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns),
544 return std::make_unique<ArcCanonicalizerPass>();
LogicalResult canonicalizePassthoughCall(mlir::CallOpInterface callOp, SymbolHandler &symbolCache, PatternRewriter &rewriter)
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static InstancePath empty
A combination of SymbolCache and SymbolUserMap that also allows adding users and removing symbols on demand.
ArrayRef< Operation * > getUsers(Operation *symbol) const
Return the users of the provided symbol operation.
void addUser(Operation *def, Operation *user)
bool useEmpty(Operation *symbol)
Return true if the given symbol has no uses.
void removeUser(Operation *def, Operation *user)
DenseMap< Operation *, SetVector< Operation * > > userMap
void removeDefinitionAndAllUsers(Operation *def)
void collectAllSymbolUses(Operation *symbolTableOp, SymbolTableCollection &symbolTable)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initializes from every key in the SymbolCache that is convertible to a StringAttr.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operations.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Look up the definition for the symbol named by 'attr' in the cache.
def create(data_type, value)
std::unique_ptr< mlir::Pass > createArcCanonicalizerPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...