12#include "mlir/Transforms/DialectConversion.h"
13#include "llvm/ADT/TypeSwitch.h"
15#include "mlir/Analysis/DataFlowFramework.h"
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24using namespace mlir::dataflow;
28#define GEN_PASS_DEF_COMBINTRANGENARROWING
29#include "circt/Dialect/Comb/Passes.h.inc"
35static LogicalResult
collectRanges(DataFlowSolver &solver, ValueRange values,
36 SmallVectorImpl<ConstantIntRanges> &ranges) {
37 for (Value val : values) {
38 auto *maybeInferredRange =
39 solver.lookupState<IntegerValueRangeLattice>(val);
40 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
43 const ConstantIntRanges &inferredRange =
44 maybeInferredRange->getValue().getValue();
45 ranges.push_back(inferredRange);
35static LogicalResult
collectRanges(DataFlowSolver &solver, ValueRange values, {
…}
51template <
typename CombOpTy>
53 CombOpNarrow(MLIRContext *context, DataFlowSolver &s)
56 LogicalResult matchAndRewrite(CombOpTy op,
57 PatternRewriter &rewriter)
const override {
59 auto opWidth = op.getType().getIntOrFloatBitWidth();
60 if (op->getNumOperands() != 2 || op->getNumResults() != 1)
61 return rewriter.notifyMatchFailure(
62 op,
"Only support binary operations with one result");
64 SmallVector<ConstantIntRanges> ranges;
66 return rewriter.notifyMatchFailure(op,
"input without specified range");
68 return rewriter.notifyMatchFailure(op,
"output without specified range");
70 auto removeWidth = ranges[0].umax().countLeadingZeros();
71 for (
const ConstantIntRanges &range : ranges) {
72 auto rangeCanRemove = range.umax().countLeadingZeros();
73 removeWidth = std::min(removeWidth, rangeCanRemove);
76 return rewriter.notifyMatchFailure(op,
"no bits to remove");
77 if (removeWidth == opWidth)
78 return rewriter.notifyMatchFailure(
79 op,
"all bits to remove - replace by zero");
82 Value lhs = op.getOperand(0);
83 Value rhs = op.getOperand(1);
85 Location loc = op.getLoc();
86 auto newWidth = opWidth - removeWidth;
88 auto replaceType = rewriter.getIntegerType(newWidth);
95 auto narrowOp = rewriter.
create<CombOpTy>(loc, extractLhsOp, extractRhsOp);
101 loc, op.getType(), ValueRange{zero, narrowOp});
103 rewriter.replaceOp(op, replaceOp);
108 DataFlowSolver &solver;
111struct CombIntRangeNarrowingPass
112 : comb::impl::CombIntRangeNarrowingBase<CombIntRangeNarrowingPass> {
114 using CombIntRangeNarrowingBase::CombIntRangeNarrowingBase;
115 void runOnOperation()
override;
119void CombIntRangeNarrowingPass::runOnOperation() {
120 Operation *op = getOperation();
121 MLIRContext *ctx = op->getContext();
122 DataFlowSolver solver;
123 solver.load<DeadCodeAnalysis>();
124 solver.load<IntegerRangeAnalysis>();
125 if (failed(solver.initializeAndRun(op)))
126 return signalPassFailure();
131 if (failed(applyPatternsGreedily(op, std::move(
patterns))))
135void comb::populateCombNarrowingPatterns(RewritePatternSet &
patterns,
136 DataFlowSolver &solver) {
137 patterns.add<CombOpNarrow<comb::AddOp>, CombOpNarrow<comb::MulOp>,
138 CombOpNarrow<comb::SubOp>>(
patterns.getContext(), solver);
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, SmallVectorImpl< ConstantIntRanges > &ranges)
Gather the inferred integer ranges for all values in `values`.
void populateCombNarrowingPatterns(mlir::RewritePatternSet &patterns, mlir::DataFlowSolver &solver)
Add patterns that narrow comb add, mul, and sub operations based on inferred integer ranges.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.