11#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
12#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
13#include "mlir/Analysis/DataFlowFramework.h"
14#include "mlir/Transforms/WalkPatternRewriteDriver.h"
19using namespace mlir::dataflow;
23#define GEN_PASS_DEF_COMBOVERFLOWANNOTATING
24#include "circt/Dialect/Comb/Passes.h.inc"
30static LogicalResult
collectRanges(DataFlowSolver &solver, ValueRange values,
31 SmallVectorImpl<ConstantIntRanges> &ranges) {
32 for (Value val : values) {
33 auto *maybeInferredRange =
34 solver.lookupState<IntegerValueRangeLattice>(val);
35 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
38 const ConstantIntRanges &inferredRange =
39 maybeInferredRange->getValue().getValue();
40 ranges.push_back(inferredRange);
46template <
typename CombOpTy>
48 CombOpAnnotate(MLIRContext *context, DataFlowSolver &s)
51 LogicalResult matchAndRewrite(CombOpTy op,
52 PatternRewriter &rewriter)
const override {
54 if (op->hasAttr(
"comb.nuw"))
57 if (op->getNumOperands() != 2 || op->getNumResults() != 1)
58 return rewriter.notifyMatchFailure(
59 op,
"Only support binary operations with one result");
61 assert(isa<comb::AddOp>(op) || isa<comb::MulOp>(op) ||
62 isa<comb::SubOp>(op));
64 SmallVector<ConstantIntRanges> ranges;
66 return rewriter.notifyMatchFailure(op,
"input without specified range");
68 bool overflowed =
false;
70 auto a = ranges[0].umax();
71 auto b = ranges[1].umax();
76 if (isa<comb::AddOp>(op))
77 (void)a.uadd_ov(b, overflowed);
79 if (isa<comb::MulOp>(op))
80 (
void)a.umul_ov(b, overflowed);
83 op->setAttr(
"comb.nuw", UnitAttr::get(op->getContext()));
89 DataFlowSolver &solver;
92struct CombOverflowAnnotatingPass
93 : comb::impl::CombOverflowAnnotatingBase<CombOverflowAnnotatingPass> {
95 using CombOverflowAnnotatingBase::CombOverflowAnnotatingBase;
96 void runOnOperation()
override;
100void CombOverflowAnnotatingPass::runOnOperation() {
101 Operation *op = getOperation();
102 MLIRContext *ctx = op->getContext();
103 DataFlowSolver solver;
104 solver.load<DeadCodeAnalysis>();
105 solver.load<IntegerRangeAnalysis>();
106 if (failed(solver.initializeAndRun(op)))
107 return signalPassFailure();
111 patterns.add<CombOpAnnotate<comb::AddOp>, CombOpAnnotate<comb::MulOp>>(
114 walkAndApplyPatterns(op, std::move(
patterns));
assert(baseType &&"element must be base type")
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, SmallVectorImpl< ConstantIntRanges > &ranges)
Gather ranges for all the values in `values`.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.