14 #include "../PassDetail.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
40 llvm::function_ref<std::string(StringRef)> namehintCallback) {
41 if (
auto *sourceOp = oldValue.getDefiningOp()) {
43 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
44 auto newNamehint = namehintCallback(namehint.strref());
45 newOp->setAttr(
"sv.namehint",
53 unsigned startBit,
unsigned bitWidth) {
54 Value extractedValue =
56 Operation *definingOp = extractedValue.getDefiningOp();
57 if (extractedValue != value && definingOp) {
60 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
61 std::to_string(startBit + bitWidth))
65 return extractedValue;
71 unsigned targetWidth,
bool signExtension) {
72 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
73 unsigned extensionLength = targetWidth - sourceWidth;
75 if (extensionLength == 0)
85 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
90 loc,
builder.getIntegerType(extensionLength), 0)
96 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
97 std::to_string(targetWidth))
101 return extOp->getOpResult(0);
106 llvm::TypeSwitch<Type, bool>(type)
107 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
108 .Case<hw::ArrayType>(
110 .Case<hw::StructType>([](
auto type) {
111 return llvm::any_of(type.getElements(), [](
auto element) {
112 return isSignednessType(element.type);
115 .Case<hw::InOutType>(
117 .Default([](
auto type) {
return false; });
123 if (
auto typeAttr = attr.dyn_cast<TypeAttr>())
131 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
134 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
138 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
140 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
144 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
145 return attr.getValue();
151 return operandsOK && resultsOK && attrsOK;
163 matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter)
const override {
166 constOp.getConstantValue());
174 matchAndRewrite(
DivOp op, OpAdaptor adaptor,
175 ConversionPatternRewriter &rewriter)
const override {
176 auto loc = op.getLoc();
177 auto isLhsTypeSigned =
178 op.getOperand(0).getType().template cast<IntegerType>().isSigned();
179 auto rhsType = op.getOperand(1).getType().template cast<IntegerType>();
180 auto targetType = op.getResult().getType().template cast<IntegerType>();
191 bool signedDivision = targetType.isSigned();
192 unsigned extendSize = std::max(
193 targetType.getWidth(),
194 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
197 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
198 extendSize, isLhsTypeSigned);
199 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
200 extendSize, rhsType.isSigned());
204 divResult = rewriter.create<
comb::DivSOp>(loc, lhsValue, rhsValue,
false)
207 divResult = rewriter.create<
comb::DivUOp>(loc, lhsValue, rhsValue,
false)
211 divResult.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
214 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
215 targetType.getWidth());
216 rewriter.replaceOp(op, truncateResult);
228 matchAndRewrite(
CastOp op, OpAdaptor adaptor,
229 ConversionPatternRewriter &rewriter)
const override {
230 auto sourceType = op.getIn().getType().cast<IntegerType>();
231 auto sourceWidth = sourceType.getWidth();
232 bool isSourceTypeSigned = sourceType.isSigned();
233 auto targetWidth = op.getOut().getType().cast<IntegerType>().
getWidth();
236 if (sourceWidth == targetWidth) {
239 replaceValue = adaptor.getIn();
240 }
else if (sourceWidth < targetWidth) {
244 targetWidth, isSourceTypeSigned);
247 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
250 rewriter.replaceOp(op, replaceValue);
262 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
264 case ICmpPredicate::eq:
265 return comb::ICmpPredicate::eq;
266 case ICmpPredicate::ne:
267 return comb::ICmpPredicate::ne;
268 case ICmpPredicate::lt:
269 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
270 case ICmpPredicate::ge:
271 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
272 case ICmpPredicate::le:
273 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
274 case ICmpPredicate::gt:
275 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
279 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
280 return comb::ICmpPredicate::eq;
287 matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
289 auto lhsType = op.getLhs().getType().cast<IntegerType>();
290 auto rhsType = op.getRhs().getType().cast<IntegerType>();
291 IntegerType::SignednessSemantics cmpSignedness;
292 const unsigned cmpWidth =
295 ICmpPredicate pred = op.getPredicate();
296 comb::ICmpPredicate combPred = lowerPredicate(
297 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
299 const auto loc = op.getLoc();
300 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
302 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
305 auto newOp = rewriter.replaceOpWithNewOp<comb::ICmpOp>(
306 op, combPred, lhsValue, rhsValue,
false);
307 newOp->setDialectAttrs(op->getDialectAttrs());
313 template <
class BinOp,
class ReplaceOp>
319 matchAndRewrite(BinOp op, OpAdaptor adaptor,
320 ConversionPatternRewriter &rewriter)
const override {
321 auto loc = op.getLoc();
322 auto isLhsTypeSigned =
323 op.getOperand(0).getType().template cast<IntegerType>().isSigned();
324 auto isRhsTypeSigned =
325 op.getOperand(1).getType().template cast<IntegerType>().isSigned();
327 op.getResult().getType().template cast<IntegerType>().getWidth();
329 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
330 targetWidth, isLhsTypeSigned);
331 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
332 targetWidth, isRhsTypeSigned);
334 rewriter.replaceOpWithNewOp<ReplaceOp>(op, lhsValue, rhsValue,
false);
335 newOp->setDialectAttrs(op->getDialectAttrs());
342 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
343 auto it = conversionCache.find(type);
344 if (it != conversionCache.end())
345 return it->second.type;
348 llvm::TypeSwitch<Type, Type>(type)
349 .Case<IntegerType>([](
auto type) {
350 if (type.isSignless())
354 .Case<hw::ArrayType>([
this](
auto type) {
356 type.getNumElements());
358 .Case<hw::StructType>([
this](
auto type) {
360 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
361 for (
auto element : type.getElements()) {
362 convertedElements.push_back(
363 {element.name, removeSignedness(element.type)});
367 .Case<hw::InOutType>([
this](
auto type) {
370 .Default([](
auto type) {
return type; });
372 return convertedType;
375 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
377 addConversion([
this](Type type) {
return removeSignedness(type); });
379 addTargetMaterialization(
380 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
382 mlir::Location loc) -> std::optional<mlir::Value> {
388 addSourceMaterialization(
389 [&](mlir::OpBuilder &
builder, mlir::Type resultType,
391 mlir::Location loc) -> std::optional<mlir::Value> {
404 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
405 BinaryOpLowering<AddOp, comb::AddOp>,
406 BinaryOpLowering<SubOp, comb::SubOp>,
407 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
408 typeConverter,
patterns.getContext());
413 class HWArithToHWPass :
public HWArithToHWBase<HWArithToHWPass> {
415 void runOnOperation()
override {
416 ModuleOp module = getOperation();
418 ConversionTarget target(getContext());
419 target.markUnknownOpDynamicallyLegal(
isLegalOp);
420 RewritePatternSet
patterns(&getContext());
422 target.addIllegalDialect<HWArithDialect>();
435 if (failed(applyFullConversion(module, target, std::move(
patterns))))
436 return signalPassFailure();
446 return std::make_unique<HWArithToHWPass>();
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
llvm::SmallVector< StringAttr > inputs
A helper type converter class that automatically populates the relevant materializations and type con...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...