12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
17 #define GEN_PASS_DEF_CONVERTCOMBTOARITH
18 #include "circt/Conversion/Passes.h.inc"
21 using namespace circt;
25 using namespace arith;
/// Lowers a comb ReplicateOp, which repeats its input `op.getMultiple()` times.
///
/// - A 1-bit integer input is replaced by an arith sign extension (ExtSIOp) to
///   an integer of width `op.getMultiple()`. Sign-extending a single bit copies
///   that bit into every result bit, which is exactly the replication.
/// - Any other input is replaced by a ConcatOp of `op.getMultiple()` copies of
///   the converted input. NOTE(review): `ConcatOp` here is presumably comb's
///   concat, which the concat pattern in this file then lowers further — confirm.
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter)
const override {
// The width check uses the original (pre-conversion) input type.
40 Type inputType = op.getInput().getType();
41 if (isa<IntegerType>(inputType) && inputType.getIntOrFloatBitWidth() == 1) {
42 Type outType = rewriter.getIntegerType(op.getMultiple());
43 rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
// General case: N identical copies of the input, concatenated.
47 SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
48 rewriter.replaceOpWithNewOp<
ConcatOp>(op, inputs);
59 ConversionPatternRewriter &rewriter)
const override {
// Re-materialize the constant as an arith constant with the same value
// attribute.
// NOTE(review): the source op type is not visible in this listing; judging by
// the pattern name (HWConstantOpConversion) it is presumably hw::ConstantOp.
61 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
/// Lowers comb ICmpOp to arith CmpIOp by translating the comparison predicate.
///
/// The signed and unsigned ordering predicates (sge/sgt/sle/slt and
/// uge/ugt/ule/ult) map one to one. The `c*`/`w*` equality variants
/// (presumably case- and wildcard-equality) collapse onto plain eq/ne.
/// NOTE(review): that collapse is only exact for 2-state values (no X/Z bits).
/// Confirm this is the intended semantics.
/// NOTE(review): the declaration of `pred` and the `break;` after each
/// assignment do not appear in this listing (the embedded line numbers skip).
/// As shown, every case would fall through to `ult`. Verify against the full
/// source.
71 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter)
const override {
75 switch (adaptor.getPredicate()) {
// Inequality, including the case/wildcard variants.
76 case ICmpPredicate::cne:
77 case ICmpPredicate::wne:
78 case ICmpPredicate::ne:
79 pred = CmpIPredicate::ne;
// Equality, including the case/wildcard variants.
81 case ICmpPredicate::ceq:
82 case ICmpPredicate::weq:
83 case ICmpPredicate::eq:
84 pred = CmpIPredicate::eq;
// Signed orderings.
86 case ICmpPredicate::sge:
87 pred = CmpIPredicate::sge;
89 case ICmpPredicate::sgt:
90 pred = CmpIPredicate::sgt;
92 case ICmpPredicate::sle:
93 pred = CmpIPredicate::sle;
95 case ICmpPredicate::slt:
96 pred = CmpIPredicate::slt;
// Unsigned orderings.
98 case ICmpPredicate::uge:
99 pred = CmpIPredicate::uge;
101 case ICmpPredicate::ugt:
102 pred = CmpIPredicate::ugt;
104 case ICmpPredicate::ule:
105 pred = CmpIPredicate::ule;
107 case ICmpPredicate::ult:
108 pred = CmpIPredicate::ult;
// Build the arith comparison on the converted operands.
112 rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
/// Lowers comb ExtractOp (a bit slice of its input) in two steps:
///   1. logically shift the converted input right (ShRUIOp) by a constant,
///   2. truncate (TruncIOp) to the result type, keeping the low bits.
/// NOTE(review): the constant's value sits on lines not shown here. It is
/// presumably the op's low-bit index, built in the input's type. Confirm.
123 matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter)
const override {
126 Value lowBit = rewriter.create<arith::ConstantOp>(
// Move the requested slice down to bit 0 (zero-filling from the top).
130 rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
// Drop everything above the slice by truncating to the result width.
131 rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
/// Lowers comb ConcatOp to arith ops. Each operand is zero-extended to the
/// result width, shifted left into position, and ORed into an accumulator.
///
/// Bit layout: the last operand occupies the least-significant bits and the
/// first operand the most-significant bits. `offset` starts at the result width
/// and is decreased by each operand's width before that operand is placed, so
/// operands are packed from the top down. This relies on the result width
/// being the sum of the operand widths; otherwise the unsigned `offset` wraps.
142 matchAndRewrite(
ConcatOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter)
const override {
144 Type type = op.getResult().getType();
145 Location loc = op.getLoc();
// A single-operand concat is the identity: forward the converted operand.
149 if (op.getNumOperands() == 1) {
150 rewriter.replaceOp(op, adaptor.getOperands().back());
// Seed the accumulator with the last operand, zero-extended to the result
// type. It sits at bit offset 0, so it needs no shift.
158 rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
162 unsigned offset = type.getIntOrFloatBitWidth();
// Place every remaining operand (first to second-to-last) below the ones
// already placed: extend, shift left by `offset`, then OR in.
163 for (
auto operand : adaptor.getOperands().drop_back()) {
164 offset -= operand.getType().getIntOrFloatBitWidth();
165 auto offsetConst = rewriter.create<arith::ConstantOp>(
167 auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
168 auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
169 aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
172 rewriter.replaceOp(op, aggregate);
/// Generic one-to-one lowering: replaces a SourceOp with a TargetOp. The
/// original result type is kept and all converted operands are forwarded in
/// order. In this file it is instantiated for, e.g., SubOp -> SubIOp and
/// MuxOp -> SelectOp.
178 template <
typename SourceOp,
typename TargetOp>
181 using OpAdaptor =
typename SourceOp::Adaptor;
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
187 rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
188 adaptor.getOperands());
/// Lowers comb division/remainder ops (SourceOp) to an arith op (TargetOp).
/// It is instantiated for DivS/DivU/ModS/ModU, and guards against a zero
/// divisor: the right-hand side is compared with 0 and, if it is zero,
/// replaced by 1 before the arith op is built.
/// NOTE(review): presumably this exists because arith division by zero is
/// undefined behavior, while the comb result for a zero divisor is merely
/// unspecified. Confirm against the comb dialect rationale.
195 template <
typename SourceOp,
typename TargetOp>
198 using OpAdaptor =
typename SourceOp::Adaptor;
201 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter)
const override {
203 Location loc = op.getLoc();
// Constants 0 and 1 in the divisor's type.
204 Value zero = rewriter.create<arith::ConstantOp>(
205 loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
206 Value one = rewriter.create<arith::ConstantOp>(
207 loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
// divisor = (rhs == 0) ? 1 : rhs
208 Value isZero = rewriter.create<arith::CmpIOp>(loc, CmpIPredicate::eq,
209 adaptor.getRhs(), zero);
211 rewriter.create<arith::SelectOp>(loc, isZero, one, adaptor.getRhs());
212 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
/// Lowers a variadic comb op (SourceOp, e.g. AddOp/MulOp/AndOp/OrOp/XorOp) to a
/// left-leaning chain of binary arith ops (TargetOp):
///   ((x0 op x1) op x2) op ... op xN
/// With a single operand the loop body never runs, so the op is replaced by
/// that operand directly.
/// NOTE(review): `operands[0]` is read unconditionally, so this assumes the op
/// has at least one operand.
218 template <
typename SourceOp,
typename TargetOp>
221 using OpAdaptor =
typename SourceOp::Adaptor;
224 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
225 ConversionPatternRewriter &rewriter)
const override {
228 ValueRange operands = adaptor.getOperands();
// Accumulator for the fold; starts at the first operand.
229 Value runner = operands[0];
231 llvm::make_range(operands.begin() + 1, operands.end())) {
232 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
234 rewriter.replaceOp(op, runner);
/// Lowers a comb logical shift (SourceOp: ShlOp or ShrUOp) to the matching
/// arith shift (TargetOp: ShLIOp or ShRUIOp).
///
/// The arith shift is always emitted. A select then substitutes `zeroConstOp`
/// whenever the shift amount (rhs) is unsigned-greater-or-equal to
/// `maxShamtConstOp`.
/// NOTE(review): the constant values are on lines not shown here. Presumably
/// they are 0 and `shifteeWidth`, so over-wide shifts yield 0 rather than the
/// arith shift's poison result. Confirm.
246 template <
typename SourceOp,
typename TargetOp>
249 using OpAdaptor =
typename SourceOp::Adaptor;
252 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter)
const override {
// Bit width of the value being shifted.
254 unsigned shifteeWidth =
255 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
256 .getIntOrFloatBitWidth();
257 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
259 auto maxShamtConstOp = rewriter.create<arith::ConstantOp>(
262 auto shiftOp = rewriter.createOrFold<TargetOp>(
263 op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
// True when the shift amount is out of range (>= the threshold constant).
264 auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
265 op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
266 maxShamtConstOp.getResult());
// Out-of-range -> zero constant; otherwise (presumably) the shift result.
267 rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
/// Lowers comb ShrSOp (arithmetic right shift) to arith ShRSIOp. The shift
/// amount is first clamped with an unsigned minimum (MinUIOp) against
/// `maxShamtMinusOneConstOp`.
/// NOTE(review): that constant's value is not visible here; presumably it is
/// `shifteeWidth - 1`. An arithmetic right shift by width-1 already fills every
/// bit with the sign bit, so clamping keeps the result for larger amounts while
/// keeping the arith shift amount in range. Confirm.
278 matchAndRewrite(
ShrSOp op, OpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter)
const override {
// Bit width of the value being shifted.
280 unsigned shifteeWidth =
281 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
282 .getIntOrFloatBitWidth();
284 auto maxShamtMinusOneConstOp = rewriter.create<arith::ConstantOp>(
// shamt = umin(rhs, maxShamtMinusOne)
287 auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
288 maxShamtMinusOneConstOp);
289 rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
/// The Comb -> Arith conversion pass. Its base class, ConvertCombToArithBase,
/// comes from the generated pass definitions included under
/// GEN_PASS_DEF_CONVERTCOMBTOARITH. The pass only overrides runOnOperation().
301 struct ConvertCombToArithPass
302 :
public circt::impl::ConvertCombToArithBase<ConvertCombToArithPass> {
303 void runOnOperation()
override;
308 TypeConverter &converter, mlir::RewritePatternSet &
patterns) {
// Register every Comb -> Arith lowering pattern, each constructed with the
// given type converter and the pattern set's context:
//  - dedicated patterns: replicate, (hw) constant, icmp, extract, concat, shrs;
//  - logical shifts (shl, shru) guarded against over-wide shift amounts;
//  - sub and mux as one-to-one rewrites;
//  - div/mod (signed and unsigned) with a zero-divisor guard;
//  - variadic add/mul/and/or/xor folded into chains of binary arith ops.
310 CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
311 ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
312 LogicalShiftConversion<ShlOp, ShLIOp>,
313 LogicalShiftConversion<ShrUOp, ShRUIOp>,
314 BinaryOpConversion<SubOp, SubIOp>, DivOpConversion<DivSOp, DivSIOp>,
315 DivOpConversion<DivUOp, DivUIOp>, DivOpConversion<ModSOp, RemSIOp>,
316 DivOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
317 VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
318 VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
319 VariadicOpConversion<XorOp, XOrIOp>>(converter,
patterns.getContext());
/// Runs the conversion over the pass's operation. Every comb dialect op is
/// illegal and must be rewritten; arith dialect ops are legal.
/// applyPartialConversion leaves operations that are not marked illegal in
/// place.
/// NOTE(review): the handling of a failed conversion is not visible here;
/// presumably it calls signalPassFailure(). Confirm.
322 void ConvertCombToArithPass::runOnOperation() {
323 ConversionTarget target(getContext());
324 target.addIllegalDialect<comb::CombDialect>();
326 target.addLegalDialect<ArithDialect>();
332 RewritePatternSet
patterns(&getContext());
// Identity type converter: every type maps to itself, so only operations
// are rewritten, never types.
333 TypeConverter converter;
334 converter.addConversion([](Type type) {
return type; });
338 if (failed(mlir::applyPartialConversion(getOperation(), target,
// Factory: each call returns a new, uniquely owned ConvertCombToArithPass.
344 return std::make_unique<ConvertCombToArithPass>();
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass< ModuleOp > > createConvertCombToArithPass()