17#include "mlir/Interfaces/InferIntRangeInterface.h"
18#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
21using namespace mlir::intrange;
29void comb::AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
30 SetIntRangeFn setResultRange) {
31 auto resultRange = argRanges[0];
32 for (
auto argRange : argRanges.drop_front())
34 inferAdd({resultRange, argRange}, intrange::OverflowFlags::None);
36 setResultRange(getResult(), resultRange);
43void comb::SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
44 SetIntRangeFn setResultRange) {
45 setResultRange(getResult(),
46 inferSub(argRanges, intrange::OverflowFlags::None));
53void comb::MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
54 SetIntRangeFn setResultRange) {
55 auto resultRange = argRanges[0];
56 for (
auto argRange : argRanges.drop_front())
58 inferMul({resultRange, argRange}, intrange::OverflowFlags::None);
60 setResultRange(getResult(), resultRange);
67void comb::DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
68 SetIntRangeFn setResultRange) {
69 setResultRange(getResult(), inferDivU(argRanges));
76void comb::DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
77 SetIntRangeFn setResultRange) {
78 setResultRange(getResult(), inferDivS(argRanges));
85void comb::ModUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
86 SetIntRangeFn setResultRange) {
87 setResultRange(getResult(), inferRemU(argRanges));
94void comb::ModSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
95 SetIntRangeFn setResultRange) {
96 setResultRange(getResult(), inferRemS(argRanges));
102void comb::AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
103 SetIntRangeFn setResultRange) {
104 auto resultRange = argRanges[0];
105 for (
auto argRange : argRanges.drop_front())
106 resultRange = inferAnd({resultRange, argRange});
108 setResultRange(getResult(), resultRange);
115void comb::OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
116 SetIntRangeFn setResultRange) {
117 auto resultRange = argRanges[0];
118 for (
auto argRange : argRanges.drop_front())
119 resultRange = inferOr({resultRange, argRange});
121 setResultRange(getResult(), resultRange);
128void comb::XorOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129 SetIntRangeFn setResultRange) {
130 auto resultRange = argRanges[0];
131 for (
auto argRange : argRanges.drop_front())
132 resultRange = inferXor({resultRange, argRange});
134 setResultRange(getResult(), resultRange);
141void comb::ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
142 SetIntRangeFn setResultRange) {
143 setResultRange(getResult(),
144 inferShl(argRanges, intrange::OverflowFlags::None));
151void comb::ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
152 SetIntRangeFn setResultRange) {
153 setResultRange(getResult(), inferShrU(argRanges));
160void comb::ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
161 SetIntRangeFn setResultRange) {
162 setResultRange(getResult(), inferShrS(argRanges));
169void comb::ConcatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
170 SetIntRangeFn setResultRange) {
172 const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
173 auto totalWidth = resWidth;
174 APInt umin = APInt::getZero(resWidth);
175 APInt umax = APInt::getZero(resWidth);
176 for (
auto [operand, arg] :
llvm::zip(getOperands(), argRanges)) {
177 assert(totalWidth >= operand.getType().getIntOrFloatBitWidth() &&
178 "ConcatOp: total width in interval range calculation is negative");
179 totalWidth -= operand.getType().getIntOrFloatBitWidth();
180 auto uminUpd = arg.umin().zext(resWidth).ushl_sat(totalWidth);
181 auto umaxUpd = arg.umax().zext(resWidth).ushl_sat(totalWidth);
182 umin = umin.uadd_sat(uminUpd);
183 umax = umax.uadd_sat(umaxUpd);
185 auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
186 setResultRange(getResult(), urange);
193void comb::ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
194 SetIntRangeFn setResultRange) {
196 const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
197 const auto lowBit = getLowBit();
198 auto umin = argRanges[0].umin().lshr(lowBit).trunc(resWidth);
199 auto umax = argRanges[0].umax().lshr(lowBit).trunc(resWidth);
200 auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
201 setResultRange(getResult(), urange);
208void comb::ReplicateOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
209 SetIntRangeFn setResultRange) {
211 const auto operandWidth = getOperand().getType().getIntOrFloatBitWidth();
212 const auto resWidth = getResult().getType().getIntOrFloatBitWidth();
213 APInt umin = APInt::getZero(resWidth);
214 APInt umax = APInt::getZero(resWidth);
215 auto uminIn = argRanges[0].umin().zext(resWidth);
216 auto umaxIn = argRanges[0].umax().zext(resWidth);
217 for (
unsigned int totalWidth = 0; totalWidth < resWidth;
218 totalWidth += operandWidth) {
219 auto uminUpd = uminIn.ushl_sat(totalWidth);
220 auto umaxUpd = umaxIn.ushl_sat(totalWidth);
221 umin = umin.uadd_sat(uminUpd);
222 umax = umax.uadd_sat(umaxUpd);
224 auto urange = ConstantIntRanges::fromUnsigned(umin, umax);
225 setResultRange(getResult(), urange);
232void comb::MuxOp::inferResultRangesFromOptional(
233 ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
234 std::optional<APInt> mbCondVal =
235 argRanges[0].isUninitialized()
237 : argRanges[0].getValue().getConstantValue();
239 const IntegerValueRange &trueCase = argRanges[1];
240 const IntegerValueRange &falseCase = argRanges[2];
243 if (mbCondVal->isZero())
244 setResultRange(getResult(), falseCase);
246 setResultRange(getResult(), trueCase);
249 setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase));
256void comb::ICmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
257 SetIntRangeFn setResultRange) {
258 comb::ICmpPredicate combPred = getPredicate();
260 APInt min = APInt::getZero(1);
261 APInt max = APInt::getAllOnes(1);
263 intrange::CmpPredicate pred;
265 case comb::ICmpPredicate::eq:
266 pred = intrange::CmpPredicate::eq;
268 case comb::ICmpPredicate::ne:
269 pred = intrange::CmpPredicate::ne;
271 case comb::ICmpPredicate::slt:
272 pred = intrange::CmpPredicate::slt;
274 case comb::ICmpPredicate::sle:
275 pred = intrange::CmpPredicate::sle;
277 case comb::ICmpPredicate::sgt:
278 pred = intrange::CmpPredicate::sgt;
280 case comb::ICmpPredicate::sge:
281 pred = intrange::CmpPredicate::sge;
283 case comb::ICmpPredicate::ult:
284 pred = intrange::CmpPredicate::ult;
286 case comb::ICmpPredicate::ule:
287 pred = intrange::CmpPredicate::ule;
289 case comb::ICmpPredicate::ugt:
290 pred = intrange::CmpPredicate::ugt;
292 case comb::ICmpPredicate::uge:
293 pred = intrange::CmpPredicate::uge;
297 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
301 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
303 std::optional<bool> truthValue = intrange::evaluatePred(pred, lhs, rhs);
304 if (truthValue.has_value() && *truthValue)
306 else if (truthValue.has_value() && !(*truthValue))
309 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
// NOTE(review): a stray top-level fragment
//   assert(baseType && "element must be base type")
// stood here. An expression statement is not valid at namespace scope, and
// `baseType` is not declared anywhere in this section, so the fragment has
// been turned into this comment. Restore it inside the function it was cut
// from if it is still needed there.
// The InstanceGraph op interface; see InstanceGraphInterface.td for more details.