14 #include "../PassDetail.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
40 llvm::function_ref<std::string(StringRef)> namehintCallback) {
41 if (
auto *sourceOp = oldValue.getDefiningOp()) {
43 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
44 auto newNamehint = namehintCallback(namehint.strref());
45 newOp->setAttr(
"sv.namehint",
53 unsigned startBit,
unsigned bitWidth) {
54 SmallVector<Value, 1> result;
55 builder.createOrFold<comb::ExtractOp>(result, loc, value, startBit, bitWidth);
56 Value extractedValue = result[0];
57 if (extractedValue != value) {
59 auto *newOp = extractedValue.getDefiningOp();
61 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
62 std::to_string(startBit + bitWidth))
66 return extractedValue;
72 unsigned targetWidth,
bool signExtension) {
73 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
74 unsigned extensionLength = targetWidth - sourceWidth;
76 if (extensionLength == 0)
85 SmallVector<Value, 1> result;
86 builder.createOrFold<comb::ReplicateOp>(result, loc, highBit,
88 extensionBits = result[0];
92 .create<hw::ConstantOp>(
93 loc,
builder.getIntegerType(extensionLength), 0)
97 auto extOp =
builder.create<comb::ConcatOp>(loc, extensionBits, value);
99 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
100 std::to_string(targetWidth))
104 return extOp->getOpResult(0);
109 llvm::TypeSwitch<Type, bool>(type)
110 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
111 .Case<hw::ArrayType>(
113 .Case<hw::StructType>([](
auto type) {
114 return llvm::any_of(type.getElements(), [](
auto element) {
115 return isSignednessType(element.type);
118 .Case<hw::InOutType>(
120 .Default([](
auto type) {
return false; });
126 if (
auto typeAttr = attr.dyn_cast<TypeAttr>())
134 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
137 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
141 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
143 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
147 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
148 return attr.getValue();
154 return operandsOK && resultsOK && attrsOK;
166 matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
167 ConversionPatternRewriter &rewriter)
const override {
168 rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
169 constOp.getConstantValue());
177 matchAndRewrite(DivOp op, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter)
const override {
179 auto loc = op.getLoc();
180 auto isLhsTypeSigned =
181 op.getOperand(0).getType().template cast<IntegerType>().isSigned();
182 auto rhsType = op.getOperand(1).getType().template cast<IntegerType>();
183 auto targetType = op.getResult().getType().template cast<IntegerType>();
194 bool signedDivision = targetType.isSigned();
195 unsigned extendSize = std::max(
196 targetType.getWidth(),
197 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
200 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
201 extendSize, isLhsTypeSigned);
202 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
203 extendSize, rhsType.isSigned());
207 divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue,
false)
210 divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue,
false)
214 divResult.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
217 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
218 targetType.getWidth());
219 rewriter.replaceOp(op, truncateResult);
231 matchAndRewrite(CastOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter)
const override {
233 auto sourceType = op.getIn().getType().cast<IntegerType>();
234 auto sourceWidth = sourceType.getWidth();
235 bool isSourceTypeSigned = sourceType.isSigned();
236 auto targetWidth = op.getOut().getType().cast<IntegerType>().getWidth();
239 if (sourceWidth == targetWidth) {
242 replaceValue = adaptor.getIn();
243 }
else if (sourceWidth < targetWidth) {
247 targetWidth, isSourceTypeSigned);
250 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
253 rewriter.replaceOp(op, replaceValue);
265 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
266 #define _CREATE_HWARITH_ICMP_CASE(x) \
267 case ICmpPredicate::x: \
268 return isSigned ? comb::ICmpPredicate::s##x : comb::ICmpPredicate::u##x
271 case ICmpPredicate::eq:
272 return comb::ICmpPredicate::eq;
274 case ICmpPredicate::ne:
275 return comb::ICmpPredicate::ne;
283 #undef _CREATE_HWARITH_ICMP_CASE
286 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
287 return comb::ICmpPredicate::eq;
294 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter)
const override {
296 auto lhsType = op.getLhs().getType().cast<IntegerType>();
297 auto rhsType = op.getRhs().getType().cast<IntegerType>();
298 IntegerType::SignednessSemantics cmpSignedness;
299 const unsigned cmpWidth =
302 ICmpPredicate pred = op.getPredicate();
303 comb::ICmpPredicate combPred = lowerPredicate(
304 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
306 const auto loc = op.getLoc();
307 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
309 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
312 auto newOp = rewriter.replaceOpWithNewOp<comb::ICmpOp>(
313 op, combPred, lhsValue, rhsValue,
false);
314 newOp->setDialectAttrs(op->getDialectAttrs());
320 template <
class BinOp,
class ReplaceOp>
326 matchAndRewrite(BinOp op, OpAdaptor adaptor,
327 ConversionPatternRewriter &rewriter)
const override {
328 auto loc = op.getLoc();
329 auto isLhsTypeSigned =
330 op.getOperand(0).getType().template cast<IntegerType>().isSigned();
331 auto isRhsTypeSigned =
332 op.getOperand(1).getType().template cast<IntegerType>().isSigned();
334 op.getResult().getType().template cast<IntegerType>().getWidth();
336 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
337 targetWidth, isLhsTypeSigned);
338 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
339 targetWidth, isRhsTypeSigned);
341 rewriter.replaceOpWithNewOp<ReplaceOp>(op, lhsValue, rhsValue,
false);
342 newOp->setDialectAttrs(op->getDialectAttrs());
349 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
350 auto it = conversionCache.find(type);
351 if (it != conversionCache.end())
352 return it->second.type;
355 llvm::TypeSwitch<Type, Type>(type)
356 .Case<IntegerType>([](
auto type) {
357 if (type.isSignless())
361 .Case<hw::ArrayType>([
this](
auto type) {
363 type.getNumElements());
365 .Case<hw::StructType>([
this](
auto type) {
367 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
368 for (
auto element : type.getElements()) {
369 convertedElements.push_back(
370 {element.name, removeSignedness(element.type)});
374 .Case<hw::InOutType>([
this](
auto type) {
377 .Default([](
auto type) {
return type; });
379 return convertedType;
382 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
384 addConversion([
this](Type type) {
return removeSignedness(type); });
386 addTargetMaterialization(
387 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
389 mlir::Location loc) -> std::optional<mlir::Value> {
395 addSourceMaterialization(
396 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
398 mlir::Location loc) -> std::optional<mlir::Value> {
411 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
412 BinaryOpLowering<AddOp, comb::AddOp>,
413 BinaryOpLowering<SubOp, comb::SubOp>,
414 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
415 typeConverter,
patterns.getContext());
420 class HWArithToHWPass :
public HWArithToHWBase<HWArithToHWPass> {
422 void runOnOperation()
override {
423 ModuleOp module = getOperation();
425 ConversionTarget target(getContext());
426 target.markUnknownOpDynamicallyLegal(
isLegalOp);
427 RewritePatternSet
patterns(&getContext());
429 target.addIllegalDialect<HWArithDialect>();
442 if (failed(applyFullConversion(module, target, std::move(
patterns))))
443 return signalPassFailure();
453 return std::make_unique<HWArithToHWPass>();
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
#define _CREATE_HWARITH_ICMP_CASE(x)
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
llvm::SmallVector< StringAttr > inputs
A helper type converter class that automatically populates the relevant materializations and type con...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...