10 #include "../PassDetail.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Transforms/DialectConversion.h"
16 using namespace circt;
20 using namespace arith;
// Lowers a ReplicateOp. When the input is a 1-bit integer, the op is replaced
// by an ExtSIOp that sign-extends the converted input to an integer whose
// width equals the replication multiple. The remaining statements build a
// ConcatOp from `multiple` copies of the converted input.
// NOTE(review): this excerpt omits lines (the enclosing pattern declaration,
// the end of the `if` and the return statements), so the control flow between
// the two replacements is presumed rather than shown.
32 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
33 ConversionPatternRewriter &rewriter)
const override {
35 Type inputType = op.getInput().getType();
// Special case: repeating a single bit N times equals sign-extending it to N bits.
36 if (inputType.isa<IntegerType>() &&
37 inputType.getIntOrFloatBitWidth() == 1) {
38 Type outType = rewriter.getIntegerType(op.getMultiple());
39 rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
// General case: a vector holding getMultiple() copies of the converted input
// becomes the operand list of a ConcatOp.
43 SmallVector<Value>
inputs(op.getMultiple(), adaptor.getInput());
44 rewriter.replaceOpWithNewOp<ConcatOp>(op,
inputs);
// Lowers hw::ConstantOp by replacing it with an arith::ConstantOp that carries
// the same value attribute.
// NOTE(review): the return statement and closing braces are not part of this
// excerpt.
54 matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
55 ConversionPatternRewriter &rewriter)
const override {
57 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
// Lowers ICmpOp to CmpIOp by translating the ICmpPredicate into the matching
// CmpIPredicate and reusing the converted operands.
// NOTE(review): the declaration of `pred`, the statements between the case
// groups (presumably `break`s) and the tail of the final call are missing from
// this excerpt; confirm against the full file.
67 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const override {
71 switch (adaptor.getPredicate()) {
// All three inequality flavours (cne, wne, ne) share the single `ne` predicate.
72 case ICmpPredicate::cne:
73 case ICmpPredicate::wne:
74 case ICmpPredicate::ne:
75 pred = CmpIPredicate::ne;
// Likewise all three equality flavours (ceq, weq, eq) share `eq`.
77 case ICmpPredicate::ceq:
78 case ICmpPredicate::weq:
79 case ICmpPredicate::eq:
80 pred = CmpIPredicate::eq;
// Signed orderings map one-to-one onto the identically named predicates.
82 case ICmpPredicate::sge:
83 pred = CmpIPredicate::sge;
85 case ICmpPredicate::sgt:
86 pred = CmpIPredicate::sgt;
88 case ICmpPredicate::sle:
89 pred = CmpIPredicate::sle;
91 case ICmpPredicate::slt:
92 pred = CmpIPredicate::slt;
// Unsigned orderings map one-to-one as well.
94 case ICmpPredicate::uge:
95 pred = CmpIPredicate::uge;
97 case ICmpPredicate::ugt:
98 pred = CmpIPredicate::ugt;
100 case ICmpPredicate::ule:
101 pred = CmpIPredicate::ule;
103 case ICmpPredicate::ult:
104 pred = CmpIPredicate::ult;
// Emit the comparison with the translated predicate (remaining arguments are
// cut off in this excerpt).
108 rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
// Lowers ExtractOp as a shift followed by a truncation: the converted input is
// shifted right (unsigned, ShRUIOp) by the constant `lowBit`, and the op is then
// replaced by a TruncIOp producing the original result type.
// NOTE(review): the arguments of the constant, the name bound to the shift
// result and the final argument of the TruncIOp are missing from this excerpt;
// presumably `lowBit` holds the extract's low bit index and the TruncIOp
// consumes the shifted value - confirm against the full file.
119 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
122 Value lowBit = rewriter.create<arith::ConstantOp>(
126 rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
127 rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
// Lowers ConcatOp into extend/shift/or arithmetic: every converted operand is
// zero-extended (ExtUIOp) to the result type, shifted left by a constant
// amount, and OR-ed into a running `aggregate` value that finally replaces the
// op.
// NOTE(review): the initialisation of `aggregate`, the statement that updates
// `nextInsertion` per operand and the arguments of the shift-amount constant
// are missing from this excerpt; presumably `nextInsertion` is decremented by
// each operand's width before it is used as the shift amount - confirm.
138 matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
139 ConversionPatternRewriter &rewriter)
const override {
140 Type type = op.getResult().getType();
141 Location loc = op.getLoc();
// Starts at the full result width.
142 unsigned nextInsertion = type.getIntOrFloatBitWidth();
// Visit the operands in order, index 0 first.
147 for (
unsigned i = 0, e = op.getNumOperands(); i < e; i++) {
149 adaptor.getOperands()[i].getType().getIntOrFloatBitWidth();
151 Value nextInsValue = rewriter.create<arith::ConstantOp>(
154 rewriter.create<ExtUIOp>(loc, type, adaptor.getOperands()[i]);
155 Value shifted = rewriter.create<ShLIOp>(loc, extended, nextInsValue);
156 aggregate = rewriter.create<OrIOp>(loc, aggregate, shifted);
159 rewriter.replaceOp(op, aggregate);
// Generic one-to-one lowering parameterised on the op being matched (SourceOp)
// and the op to create (TargetOp): the source op is replaced by a TargetOp
// with the same result type, and all converted operands are forwarded
// unchanged.
// NOTE(review): the struct declaration that follows the template header is
// missing from this excerpt.
165 template <
typename SourceOp,
typename TargetOp>
168 using OpAdaptor =
typename SourceOp::Adaptor;
171 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
172 ConversionPatternRewriter &rewriter)
const override {
174 rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
175 adaptor.getOperands());
// Generic lowering of an N-operand SourceOp into a chain of two-operand
// TargetOps: `runner` starts as the first converted operand and is combined
// with each remaining operand in turn; the final `runner` replaces the op.
// NOTE(review): the struct declaration and the opening of the loop are missing
// from this excerpt. `operands[0]` is read unconditionally, so this presumably
// relies on SourceOp always having at least one operand - confirm.
181 template <
typename SourceOp,
typename TargetOp>
184 using OpAdaptor =
typename SourceOp::Adaptor;
187 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
191 ValueRange operands = adaptor.getOperands();
192 Value runner = operands[0];
// Iterate over every operand after the first.
194 llvm::make_range(operands.begin() + 1, operands.end())) {
195 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
197 rewriter.replaceOp(op, runner);
// Pass class deriving from ConvertCombToArithBase; it overrides
// runOnOperation(), whose definition appears out of line.
// NOTE(review): the closing of the struct is not part of this excerpt.
208 struct ConvertCombToArithPass
209 :
public ConvertCombToArithBase<ConvertCombToArithPass> {
210 void runOnOperation()
override;
// Parameter list and pattern list of the function that registers the lowering
// patterns into `patterns`: the dedicated replicate/constant/icmp/extract/
// concat conversions, plus instantiations of the generic binary and variadic
// templates pairing each source op with its target op.
// NOTE(review): the line holding the function name, the call that receives
// this template argument list and its arguments are missing from this excerpt.
215 TypeConverter &converter, mlir::RewritePatternSet &
patterns) {
217 CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
218 ExtractOpConversion, ConcatOpConversion,
219 BinaryOpConversion<ShlOp, ShLIOp>, BinaryOpConversion<ShrSOp, ShRSIOp>,
220 BinaryOpConversion<ShrUOp, ShRUIOp>, BinaryOpConversion<SubOp, SubIOp>,
221 BinaryOpConversion<DivSOp, DivSIOp>, BinaryOpConversion<DivUOp, DivUIOp>,
222 BinaryOpConversion<ModSOp, RemSIOp>, BinaryOpConversion<ModUOp, RemUIOp>,
223 BinaryOpConversion<MuxOp, SelectOp>, VariadicOpConversion<AddOp, AddIOp>,
224 VariadicOpConversion<MulOp, MulIOp>, VariadicOpConversion<AndOp, AndIOp>,
225 VariadicOpConversion<OrOp, OrIOp>, VariadicOpConversion<XorOp, XOrIOp>>(
// Pass entry point: configures the conversion target, builds the pattern set
// and type converter, and runs a partial dialect conversion on the operation
// the pass is applied to.
// NOTE(review): the pattern-population call, the remaining arguments of
// applyPartialConversion and the failure handling are missing from this
// excerpt.
229 void ConvertCombToArithPass::runOnOperation() {
230 ConversionTarget target(getContext());
// Everything in the comb dialect, plus hw::ConstantOp, must be rewritten;
// arith ops are the allowed output.
231 target.addIllegalDialect<comb::CombDialect>();
232 target.addIllegalOp<hw::ConstantOp>();
233 target.addLegalDialect<ArithDialect>();
// comb::ParityOp is explicitly allowed to remain, as an exception to the
// dialect-wide illegality above.
237 target.addLegalOp<comb::ParityOp>();
239 RewritePatternSet
patterns(&getContext());
// Identity type conversion: every type maps to itself.
240 TypeConverter converter;
241 converter.addConversion([](Type type) {
return type; });
245 if (failed(mlir::applyPartialConversion(getOperation(), target,
// Returns a newly allocated ConvertCombToArithPass instance.
// NOTE(review): the enclosing function's signature line is missing from this
// excerpt.
251 return std::make_unique<ConvertCombToArithPass>();
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass< ModuleOp > > createConvertCombToArithPass()