21#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (
auto *sourceOp = oldValue.getDefiningOp()) {
48 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr(
"sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
57static Value
extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit,
unsigned bitWidth) {
59 Value extractedValue =
61 Operation *definingOp = extractedValue.getDefiningOp();
62 if (extractedValue != value && definingOp) {
65 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
66 std::to_string(startBit + bitWidth))
70 return extractedValue;
76 unsigned targetWidth,
bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78 unsigned extensionLength = targetWidth - sourceWidth;
80 if (extensionLength == 0)
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
95 builder.getIntegerType(extensionLength), 0)
99 auto extOp = comb::ConcatOp::create(builder, loc, extensionBits, value);
101 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
102 std::to_string(targetWidth))
106 return extOp->getOpResult(0);
111 llvm::TypeSwitch<Type, bool>(type)
112 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
113 .Case<hw::ArrayType>(
115 .Case<hw::UnpackedArrayType>(
117 .Case<hw::StructType>([](
auto type) {
118 return llvm::any_of(type.getElements(), [](
auto element) {
119 return isSignednessType(element.type);
122 .Case<hw::UnionType>([](
auto type) {
123 return llvm::any_of(type.getElements(), [](
auto element) {
124 return isSignednessType(element.type);
127 .Case<hw::InOutType>(
129 .Case<hw::TypeAliasType>(
131 .Default([](
auto type) {
return false; });
137 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
145 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
148 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
152 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
154 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
158 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
159 return attr.getValue();
165 return operandsOK && resultsOK && attrsOK;
177 matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter)
const override {
180 constOp.getConstantValue());
188 matchAndRewrite(
DivOp op, OpAdaptor adaptor,
189 ConversionPatternRewriter &rewriter)
const override {
190 auto loc = op.getLoc();
191 auto isLhsTypeSigned =
192 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
193 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
194 auto targetType = cast<IntegerType>(op.getResult().getType());
205 bool signedDivision = targetType.isSigned();
206 unsigned extendSize = std::max(
207 targetType.getWidth(),
208 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
211 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
212 extendSize, isLhsTypeSigned);
213 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
214 extendSize, rhsType.isSigned());
218 divResult = comb::DivSOp::create(rewriter, loc, lhsValue, rhsValue,
false)
221 divResult = comb::DivUOp::create(rewriter, loc, lhsValue, rhsValue,
false)
225 auto *divOp = divResult.getDefiningOp();
226 rewriter.modifyOpInPlace(
227 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
230 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
231 targetType.getWidth());
232 rewriter.replaceOp(op, truncateResult);
244 matchAndRewrite(
CastOp op, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter)
const override {
246 auto sourceType = cast<IntegerType>(op.getIn().getType());
247 auto sourceWidth = sourceType.getWidth();
248 bool isSourceTypeSigned = sourceType.isSigned();
249 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
252 if (sourceWidth == targetWidth) {
255 replaceValue = adaptor.getIn();
256 }
else if (sourceWidth < targetWidth) {
260 targetWidth, isSourceTypeSigned);
263 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
266 rewriter.replaceOp(op, replaceValue);
278static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
280 case ICmpPredicate::eq:
281 return comb::ICmpPredicate::eq;
282 case ICmpPredicate::ne:
283 return comb::ICmpPredicate::ne;
284 case ICmpPredicate::lt:
285 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
286 case ICmpPredicate::ge:
287 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
288 case ICmpPredicate::le:
289 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
290 case ICmpPredicate::gt:
291 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
295 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
296 return comb::ICmpPredicate::eq;
303 matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter)
const override {
305 auto lhsType = cast<IntegerType>(op.getLhs().getType());
306 auto rhsType = cast<IntegerType>(op.getRhs().getType());
307 IntegerType::SignednessSemantics cmpSignedness;
308 const unsigned cmpWidth =
311 ICmpPredicate pred = op.getPredicate();
312 comb::ICmpPredicate combPred = lowerPredicate(
313 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
315 const auto loc = op.getLoc();
316 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
318 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
321 auto newOp = comb::ICmpOp::create(rewriter, op->getLoc(), combPred,
322 lhsValue, rhsValue,
false);
323 rewriter.modifyOpInPlace(
324 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
325 rewriter.replaceOp(op, newOp);
331template <
class BinOp,
class ReplaceOp>
337 matchAndRewrite(BinOp op, OpAdaptor adaptor,
338 ConversionPatternRewriter &rewriter)
const override {
339 auto loc = op.getLoc();
340 auto isLhsTypeSigned =
341 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
342 auto isRhsTypeSigned =
343 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
344 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
346 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
347 targetWidth, isLhsTypeSigned);
348 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
349 targetWidth, isRhsTypeSigned);
351 ReplaceOp::create(rewriter, op.getLoc(), lhsValue, rhsValue,
false);
352 rewriter.modifyOpInPlace(
353 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
354 rewriter.replaceOp(op, newOp);
365 return it->second.type;
368 llvm::TypeSwitch<Type, Type>(type)
369 .Case<IntegerType>([](
auto type) {
370 if (type.isSignless())
372 return IntegerType::get(type.getContext(), type.getWidth());
374 .Case<hw::ArrayType>([
this](
auto type) {
376 type.getNumElements());
378 .Case<hw::UnpackedArrayType>([
this](
auto type) {
379 return hw::UnpackedArrayType::get(
382 .Case<hw::StructType>([
this](
auto type) {
384 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
385 for (
auto element : type.getElements()) {
386 convertedElements.push_back(
389 return hw::StructType::get(type.getContext(), convertedElements);
391 .Case<hw::UnionType>([
this](
auto type) {
392 llvm::SmallVector<hw::UnionType::FieldInfo> convertedElements;
393 for (
auto element : type.getElements()) {
394 convertedElements.push_back({element.name,
398 return hw::UnionType::get(type.getContext(), convertedElements);
400 .Case<hw::InOutType>([
this](
auto type) {
403 .Case<hw::TypeAliasType>([
this](
auto type) {
404 return hw::TypeAliasType::get(
407 .Default([](
auto type) {
return type; });
409 return convertedType;
416 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
417 mlir::ValueRange inputs,
418 mlir::Location loc) -> mlir::Value {
419 if (inputs.size() != 1)
421 return UnrealizedConversionCastOp::create(builder, loc, resultType,
426 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
427 mlir::ValueRange inputs,
428 mlir::Location loc) -> mlir::Value {
429 if (inputs.size() != 1)
431 return UnrealizedConversionCastOp::create(builder, loc, resultType,
443 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
444 BinaryOpLowering<AddOp, comb::AddOp>,
445 BinaryOpLowering<SubOp, comb::SubOp>,
446 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
447 typeConverter,
patterns.getContext());
452class HWArithToHWPass :
public circt::impl::HWArithToHWBase<HWArithToHWPass> {
454 void runOnOperation()
override {
455 ModuleOp
module = getOperation();
457 ConversionTarget target(getContext());
458 target.markUnknownOpDynamicallyLegal(
isLegalOp);
459 RewritePatternSet
patterns(&getContext());
461 target.addIllegalDialect<HWArithDialect>();
474 if (failed(applyFullConversion(module, target, std::move(
patterns))))
475 return signalPassFailure();
485 return std::make_unique<HWArithToHWPass>();
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
HWArithToHWTypeConverter()
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...