21#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (
auto *sourceOp = oldValue.getDefiningOp()) {
48 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr(
"sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
57static Value
extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit,
unsigned bitWidth) {
59 Value extractedValue =
61 Operation *definingOp = extractedValue.getDefiningOp();
62 if (extractedValue != value && definingOp) {
65 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
66 std::to_string(startBit + bitWidth))
70 return extractedValue;
57static Value
extractBits(OpBuilder &builder, Location loc, Value value, {
…}
76 unsigned targetWidth,
bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78 unsigned extensionLength = targetWidth - sourceWidth;
80 if (extensionLength == 0)
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
93 extensionBits = builder
95 loc, builder.getIntegerType(extensionLength), 0)
99 auto extOp = builder.create<
comb::ConcatOp>(loc, extensionBits, value);
101 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
102 std::to_string(targetWidth))
106 return extOp->getOpResult(0);
111 llvm::TypeSwitch<Type, bool>(type)
112 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
113 .Case<hw::ArrayType>(
115 .Case<hw::UnpackedArrayType>(
117 .Case<hw::StructType>([](
auto type) {
118 return llvm::any_of(type.getElements(), [](
auto element) {
119 return isSignednessType(element.type);
122 .Case<hw::InOutType>(
124 .Case<hw::TypeAliasType>(
126 .Default([](
auto type) {
return false; });
132 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
140 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
143 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
147 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
149 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
153 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
154 return attr.getValue();
160 return operandsOK && resultsOK && attrsOK;
172 matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter)
const override {
175 constOp.getConstantValue());
183 matchAndRewrite(
DivOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter)
const override {
185 auto loc = op.getLoc();
186 auto isLhsTypeSigned =
187 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
188 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
189 auto targetType = cast<IntegerType>(op.getResult().getType());
200 bool signedDivision = targetType.isSigned();
201 unsigned extendSize = std::max(
202 targetType.getWidth(),
203 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
206 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
207 extendSize, isLhsTypeSigned);
208 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
209 extendSize, rhsType.isSigned());
213 divResult = rewriter.create<
comb::DivSOp>(loc, lhsValue, rhsValue,
false)
216 divResult = rewriter.create<
comb::DivUOp>(loc, lhsValue, rhsValue,
false)
220 auto *divOp = divResult.getDefiningOp();
221 rewriter.modifyOpInPlace(
222 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
225 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
226 targetType.getWidth());
227 rewriter.replaceOp(op, truncateResult);
239 matchAndRewrite(
CastOp op, OpAdaptor adaptor,
240 ConversionPatternRewriter &rewriter)
const override {
241 auto sourceType = cast<IntegerType>(op.getIn().getType());
242 auto sourceWidth = sourceType.getWidth();
243 bool isSourceTypeSigned = sourceType.isSigned();
244 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
247 if (sourceWidth == targetWidth) {
250 replaceValue = adaptor.getIn();
251 }
else if (sourceWidth < targetWidth) {
255 targetWidth, isSourceTypeSigned);
258 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
261 rewriter.replaceOp(op, replaceValue);
273static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
275 case ICmpPredicate::eq:
276 return comb::ICmpPredicate::eq;
277 case ICmpPredicate::ne:
278 return comb::ICmpPredicate::ne;
279 case ICmpPredicate::lt:
280 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
281 case ICmpPredicate::ge:
282 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
283 case ICmpPredicate::le:
284 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
285 case ICmpPredicate::gt:
286 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
290 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
291 return comb::ICmpPredicate::eq;
298 matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override {
300 auto lhsType = cast<IntegerType>(op.getLhs().getType());
301 auto rhsType = cast<IntegerType>(op.getRhs().getType());
302 IntegerType::SignednessSemantics cmpSignedness;
303 const unsigned cmpWidth =
306 ICmpPredicate pred = op.getPredicate();
307 comb::ICmpPredicate combPred = lowerPredicate(
308 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
310 const auto loc = op.getLoc();
311 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
313 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
316 auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
318 rewriter.modifyOpInPlace(
319 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
320 rewriter.replaceOp(op, newOp);
326template <
class BinOp,
class ReplaceOp>
332 matchAndRewrite(BinOp op, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter)
const override {
334 auto loc = op.getLoc();
335 auto isLhsTypeSigned =
336 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
337 auto isRhsTypeSigned =
338 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
339 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
341 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
342 targetWidth, isLhsTypeSigned);
343 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
344 targetWidth, isRhsTypeSigned);
346 rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue,
false);
347 rewriter.modifyOpInPlace(
348 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
349 rewriter.replaceOp(op, newOp);
360 return it->second.type;
363 llvm::TypeSwitch<Type, Type>(type)
364 .Case<IntegerType>([](
auto type) {
365 if (type.isSignless())
367 return IntegerType::get(type.getContext(), type.getWidth());
369 .Case<hw::ArrayType>([
this](
auto type) {
371 type.getNumElements());
373 .Case<hw::UnpackedArrayType>([
this](
auto type) {
374 return hw::UnpackedArrayType::get(
377 .Case<hw::StructType>([
this](
auto type) {
379 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
380 for (
auto element : type.getElements()) {
381 convertedElements.push_back(
384 return hw::StructType::get(type.getContext(), convertedElements);
386 .Case<hw::InOutType>([
this](
auto type) {
389 .Case<hw::TypeAliasType>([
this](
auto type) {
390 return hw::TypeAliasType::get(
393 .Default([](
auto type) {
return type; });
395 return convertedType;
402 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
403 mlir::ValueRange inputs,
404 mlir::Location loc) -> mlir::Value {
405 if (inputs.size() != 1)
408 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
412 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
413 mlir::ValueRange inputs,
414 mlir::Location loc) -> mlir::Value {
415 if (inputs.size() != 1)
418 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
429 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
430 BinaryOpLowering<AddOp, comb::AddOp>,
431 BinaryOpLowering<SubOp, comb::SubOp>,
432 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
433 typeConverter,
patterns.getContext());
438class HWArithToHWPass :
public circt::impl::HWArithToHWBase<HWArithToHWPass> {
440 void runOnOperation()
override {
441 ModuleOp
module = getOperation();
443 ConversionTarget target(getContext());
444 target.markUnknownOpDynamicallyLegal(
isLegalOp);
445 RewritePatternSet
patterns(&getContext());
447 target.addIllegalDialect<HWArithDialect>();
460 if (failed(applyFullConversion(module, target, std::move(
patterns))))
461 return signalPassFailure();
471 return std::make_unique<HWArithToHWPass>();
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
HWArithToHWTypeConverter()
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...