21#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
// Fragment of the name-hint helper (its first signature line and closing
// braces are not part of this listing).
// Visible behavior: if oldValue has a defining op and that op carries a
// string attribute "sv.namehint", the callback is invoked with the old hint
// text and its result is stored as the "sv.namehint" attribute of newOp.
// NOTE(review): the declaration of `namehint` sits on an elided line;
// presumably it binds the attribute looked up below - confirm.
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (
auto *sourceOp = oldValue.getDefiningOp()) {
48 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr(
"sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
// Fragment of extractBits: returns a value derived from `value` using
// `startBit` and `bitWidth`.
// NOTE(review): the expression that produces `extractedValue` is on an elided
// line; presumably it creates (or folds) a bit-extract - confirm.
57static Value
extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit,
unsigned bitWidth) {
59 Value extractedValue =
61 Operation *definingOp = extractedValue.getDefiningOp();
// Only a genuinely new value that has a defining op is considered for a
// refreshed name hint.
62 if (extractedValue != value && definingOp) {
// The new hint is "<old>_<startBit>_to_<startBit + bitWidth>".
65 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
66 std::to_string(startBit + bitWidth))
70 return extractedValue;
// Fragment of extendTypeWidth: widens `value` to `targetWidth` bits by
// concatenating extension bits in front of it.
76 unsigned targetWidth,
bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
// NOTE(review): unsigned subtraction - this wraps around if targetWidth is
// smaller than sourceWidth; a guard may exist on an elided line - confirm.
78 unsigned extensionLength = targetWidth - sourceWidth;
// Nothing to extend when the widths already match (the statement executed in
// that case is elided from this listing).
80 if (extensionLength == 0)
// One visible path replicates `highBit` extensionLength times ...
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
// ... the other builds a value of integer type with extensionLength bits from
// the constant 0. Which path belongs to which `signExtension` value is decided
// on elided lines.
93 extensionBits = builder
95 loc, builder.getIntegerType(extensionLength), 0)
// The extension bits are passed first, followed by the original value.
99 auto extOp = builder.create<
comb::ConcatOp>(loc, extensionBits, value);
// Name hint suffix: "_sext_<targetWidth>" or "_zext_<targetWidth>".
101 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
102 std::to_string(targetWidth))
106 return extOp->getOpResult(0);
// Fragment of isSignednessType: dispatches on the kind of `type` and yields a
// bool. The bodies of the array, inout and type-alias cases are elided from
// this listing.
111 llvm::TypeSwitch<Type, bool>(type)
// An integer type matches when it is not signless.
112 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
113 .Case<hw::ArrayType>(
// A struct matches when any of its element types matches (recursive check).
115 .Case<hw::StructType>([](
auto type) {
116 return llvm::any_of(type.getElements(), [](
auto element) {
117 return isSignednessType(element.type);
120 .Case<hw::InOutType>(
122 .Case<hw::TypeAliasType>(
// Every other type yields false.
124 .Default([](
auto type) {
return false; });
// Fragment (presumably of isSignednessAttr - its header is not visible):
// tests whether `attr` is a TypeAttr; the statement guarded by this check is
// elided from this listing.
130 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
// Fragment of the legality check for an operation (header not visible).
// Function-like ops: a none_of check over the argument types of the function
// body (the predicate is on an elided line).
138 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
141 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
// HW module-like ops: the same kind of check over the module body arguments.
145 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
147 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
// Lazily maps the op's named attributes to their attribute values.
151 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
152 return attr.getValue();
// The result is true only when operands, results and attributes all passed
// their individual checks (those checks are computed on elided lines).
158 return operandsOK && resultsOK && attrsOK;
// Fragment of the ConstantOp lowering. Only the use of the op's constant value
// is visible; the operation created from it is on an elided line.
170 matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter)
const override {
173 constOp.getConstantValue());
// Fragment of the DivOp lowering (several lines are elided in this listing).
// Visible flow: both operands are widened to a common width, a comb division
// is created on the widened values, and the result is cut back to the width of
// the original result type before replacing the op.
181 matchAndRewrite(
DivOp op, OpAdaptor adaptor,
182 ConversionPatternRewriter &rewriter)
const override {
183 auto loc = op.getLoc();
184 auto isLhsTypeSigned =
185 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
186 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
187 auto targetType = cast<IntegerType>(op.getResult().getType());
// The division is treated as signed when the result type is signed.
198 bool signedDivision = targetType.isSigned();
// Working width: the larger of the result width and the rhs width, where the
// rhs width is increased by one bit for a signed division with an unsigned rhs.
199 unsigned extendSize = std::max(
200 targetType.getWidth(),
201 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
// Each operand is extended according to the signedness of its own type.
204 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
205 extendSize, isLhsTypeSigned);
206 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
207 extendSize, rhsType.isSigned());
// Signed or unsigned comb division; the condition selecting between the two is
// on an elided line (presumably `signedDivision` - confirm).
211 divResult = rewriter.create<
comb::DivSOp>(loc, lhsValue, rhsValue,
false)
214 divResult = rewriter.create<
comb::DivUOp>(loc, lhsValue, rhsValue,
false)
// Copy the dialect attributes of the original op onto the new division op.
218 auto *divOp = divResult.getDefiningOp();
219 rewriter.modifyOpInPlace(
220 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
// Take targetType.getWidth() bits starting at bit 0 of the division result.
223 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
224 targetType.getWidth());
225 rewriter.replaceOp(op, truncateResult);
// Fragment of the CastOp lowering. Compares the source and target widths and
// picks the replacement value accordingly; some call heads and arguments are
// elided from this listing.
237 matchAndRewrite(
CastOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter)
const override {
239 auto sourceType = cast<IntegerType>(op.getIn().getType());
240 auto sourceWidth = sourceType.getWidth();
241 bool isSourceTypeSigned = sourceType.isSigned();
242 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
// Same width: the already-converted input is used directly.
245 if (sourceWidth == targetWidth) {
248 replaceValue = adaptor.getIn();
249 }
// Narrower source: the visible arguments (target width, source signedness)
// presumably belong to an extendTypeWidth call whose head is elided - confirm.
else if (sourceWidth < targetWidth) {
253 targetWidth, isSourceTypeSigned);
// Remaining case (source wider than target): bits are extracted from the
// input; the start bit and width arguments are on elided lines.
256 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
259 rewriter.replaceOp(op, replaceValue);
// Fragment of lowerPredicate: maps an hwarith ICmpPredicate to the matching
// comb::ICmpPredicate. eq and ne map directly; lt, ge, le and gt map to the
// signed comb variant (slt/sge/sle/sgt) when `isSigned` is true and to the
// unsigned variant (ult/uge/ule/ugt) otherwise.
// The `switch` head and the call that takes the message string below are on
// elided lines.
271static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
273 case ICmpPredicate::eq:
274 return comb::ICmpPredicate::eq;
275 case ICmpPredicate::ne:
276 return comb::ICmpPredicate::ne;
277 case ICmpPredicate::lt:
278 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
279 case ICmpPredicate::ge:
280 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
281 case ICmpPredicate::le:
282 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
283 case ICmpPredicate::gt:
284 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
// Reached only when no case above returned.
288 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
289 return comb::ICmpPredicate::eq;
// Fragment of the ICmpOp lowering. Both operands are widened to a common
// comparison width and compared with a comb::ICmpOp that replaces the original
// op.
296 matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter)
const override {
298 auto lhsType = cast<IntegerType>(op.getLhs().getType());
299 auto rhsType = cast<IntegerType>(op.getRhs().getType());
// `cmpSignedness` and `cmpWidth` are filled in by an expression on elided
// lines.
300 IntegerType::SignednessSemantics cmpSignedness;
301 const unsigned cmpWidth =
// The comb predicate is the signed variant exactly when the comparison
// signedness is Signed.
304 ICmpPredicate pred = op.getPredicate();
305 comb::ICmpPredicate combPred = lowerPredicate(
306 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
// Widen both converted operands to cmpWidth (the sign-extension flag arguments
// are on elided lines).
308 const auto loc = op.getLoc();
309 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
311 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
314 auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
// Copy the dialect attributes of the original op onto the new comparison.
316 rewriter.modifyOpInPlace(
317 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
318 rewriter.replaceOp(op, newOp);
// Fragment of a generic binary-op lowering, parameterized over the source op
// (`BinOp`) and the op created in its place (`ReplaceOp`). The struct header
// is elided from this listing.
324template <
class BinOp,
class ReplaceOp>
330 matchAndRewrite(BinOp op, OpAdaptor adaptor,
331 ConversionPatternRewriter &rewriter)
const override {
332 auto loc = op.getLoc();
333 auto isLhsTypeSigned =
334 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
335 auto isRhsTypeSigned =
336 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
337 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
// Widen each converted operand to the result width, using the signedness of
// its own original type.
339 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
340 targetWidth, isLhsTypeSigned);
341 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
342 targetWidth, isRhsTypeSigned);
// Create the replacement op on the widened operands (the `newOp` declaration
// head is on an elided line).
344 rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue,
false);
// Copy the dialect attributes of the original op, then replace it.
345 rewriter.modifyOpInPlace(
346 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
347 rewriter.replaceOp(op, newOp);
// Fragment (presumably of the type converter's removeSignedness - its header
// is not visible). The first visible statement returns the `type` field of an
// existing map entry; the lookup producing `it` is on an elided line.
358 return it->second.type;
// Otherwise the type is converted by kind.
361 llvm::TypeSwitch<Type, Type>(type)
// Integers: a signless check comes first (its branch is elided); the visible
// return builds an IntegerType of the same width with the two-argument
// IntegerType::get, i.e. without an explicit signedness.
362 .Case<IntegerType>([](
auto type) {
363 if (type.isSignless())
365 return IntegerType::get(type.getContext(), type.getWidth());
// Arrays: rebuilt with the same number of elements (the element type argument
// is on an elided line).
367 .Case<hw::ArrayType>([
this](
auto type) {
369 type.getNumElements());
// Structs: a new field list is collected element by element and a struct type
// is created from it.
371 .Case<hw::StructType>([
this](
auto type) {
373 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
374 for (
auto element : type.getElements()) {
375 convertedElements.push_back(
378 return hw::StructType::get(type.getContext(), convertedElements);
380 .Case<hw::InOutType>([
this](
auto type) {
383 .Case<hw::TypeAliasType>([
this](
auto type) {
384 return hw::TypeAliasType::get(
// Any other type is returned unchanged.
387 .Default([](
auto type) {
return type; });
389 return convertedType;
// Fragment registering a target and a source materialization callback. Both
// lambdas have the same visible shape: they first test whether exactly one
// input was supplied (the statement for the other case is elided) and create
// an UnrealizedConversionCastOp from that single input to `resultType`.
396 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
397 mlir::ValueRange inputs,
398 mlir::Location loc) -> mlir::Value {
399 if (inputs.size() != 1)
402 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
406 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
407 mlir::ValueRange inputs,
408 mlir::Location loc) -> mlir::Value {
409 if (inputs.size() != 1)
412 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Fragment registering the lowering patterns: constant, cast, icmp, the
// generic binary lowering instantiated for add/sub/mul (each paired with the
// comb op of the same name), and the division lowering. All are constructed
// with the type converter and the pattern set's context.
423 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
424 BinaryOpLowering<AddOp, comb::AddOp>,
425 BinaryOpLowering<SubOp, comb::SubOp>,
426 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
427 typeConverter,
patterns.getContext());
// Fragment of the pass class. runOnOperation sets up a conversion target in
// which ops without an explicit legality setting are judged dynamically by
// isLegalOp and the whole HWArith dialect is illegal, then runs a full
// conversion over the module and signals pass failure if that conversion
// fails. Pattern population happens on lines elided from this listing.
432class HWArithToHWPass :
public circt::impl::HWArithToHWBase<HWArithToHWPass> {
434 void runOnOperation()
override {
435 ModuleOp
module = getOperation();
437 ConversionTarget target(getContext());
438 target.markUnknownOpDynamicallyLegal(
isLegalOp);
439 RewritePatternSet
patterns(&getContext());
441 target.addIllegalDialect<HWArithDialect>();
454 if (failed(applyFullConversion(module, target, std::move(
patterns))))
455 return signalPassFailure();
// Fragment (presumably the body of the pass factory function - its header is
// not visible): returns a newly allocated HWArithToHWPass.
465 return std::make_unique<HWArithToHWPass>();
static SmallVector< Value > extractBits(ConversionPatternRewriter &rewriter, Value val)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
HWArithToHWTypeConverter()
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...