21 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
27 #define GEN_PASS_DEF_HWARITHTOHW
28 #include "circt/Conversion/Passes.h.inc"
32 using namespace circt;
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
/// Tail of improveNamehint(oldValue, newOp, namehintCallback). If the op that
/// defined `oldValue` carries an "sv.namehint" StringAttr, the callback is
/// applied to that hint and an "sv.namehint" attribute is set on `newOp`.
/// NOTE(review): the attribute value passed to setAttr is on a line missing
/// from this excerpt; presumably it is built from `newNamehint` — confirm.
46 if (
auto *sourceOp = oldValue.getDefiningOp()) {
48 sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
// Let the caller derive the new hint text from the old one.
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr(
"sv.namehint",
/// Returns a value holding `bitWidth` bits of `value` starting at bit
/// `startBit`. The extraction itself happens on a line not in this excerpt;
/// presumably it is a folding comb extract.
/// When a new defining op was produced (the result differs from `value`), its
/// namehint is extended with "_<startBit>_to_<startBit + bitWidth>". The upper
/// number is exclusive.
57 static Value
extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit,
unsigned bitWidth) {
59 Value extractedValue =
61 Operation *definingOp = extractedValue.getDefiningOp();
// Skip renaming when the extraction folded back to the input, or when the
// result has no defining op (e.g. a block argument).
62 if (extractedValue != value && definingOp) {
65 return (oldNamehint +
"_" + std::to_string(startBit) +
"_to_" +
66 std::to_string(startBit + bitWidth))
70 return extractedValue;
/// Tail of extendTypeWidth(builder, loc, value, targetWidth, signExtension).
/// The leading parameters are on a line missing from this excerpt.
/// Widens `value` to `targetWidth` bits. Sign extension replicates `highBit`;
/// zero extension prepends a constant 0. The new bits are concatenated in
/// front of `value`.
76 unsigned targetWidth,
bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
// NOTE(review): unsigned subtraction. This wraps if targetWidth < sourceWidth,
// and no guard against that is visible in this excerpt. Callers must only
// request widening or same-width conversion.
78 unsigned extensionLength = targetWidth - sourceWidth;
// Nothing to extend; the early-return body is on a line not in this excerpt.
80 if (extensionLength == 0)
// Sign extension: replicate `highBit` extensionLength times. `highBit` is
// computed on lines not shown; presumably it is the MSB of `value`.
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
// Zero extension: an extensionLength-bit constant 0.
93 extensionBits = builder
95 loc, builder.getIntegerType(extensionLength), 0)
// The extension bits are the first concat operand (presumably the
// most-significant position in comb.concat).
99 auto extOp = builder.create<
comb::ConcatOp>(loc, extensionBits, value);
// Namehint suffix: "_sext_<targetWidth>" or "_zext_<targetWidth>".
101 return (oldNamehint +
"_" + (signExtension ?
"sext_" :
"zext_") +
102 std::to_string(targetWidth))
106 return extOp->getOpResult(0);
/// Body of isSignednessType(Type). Returns true if `type` is, or contains, a
/// signed or unsigned (i.e. non-signless) integer type.
111 llvm::TypeSwitch<Type, bool>(type)
// Integers count only when they carry explicit signedness.
112 .Case<IntegerType>([](
auto type) {
return !type.isSignless(); })
// Array handler body is on a line not in this excerpt (presumably it
// recurses into the element type).
113 .Case<hw::ArrayType>(
// Structs: true if any field type is a signedness type.
115 .Case<hw::StructType>([](
auto type) {
116 return llvm::any_of(type.getElements(), [](
auto element) {
117 return isSignednessType(element.type);
// InOut handler body is on a line not in this excerpt (presumably it
// recurses into the element type).
120 .Case<hw::InOutType>(
// All other types are treated as free of signedness.
122 .Default([](
auto type) {
return false; });
// Part of isSignednessAttr(Attribute): a TypeAttr is handled specially. The
// check on the wrapped type is on lines not in this excerpt (presumably via
// isSignednessType).
128 if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
/// Part of isLegalOp(Operation *): returns true if the op is considered legal
/// for the HWArith conversion.
// Function-like ops: the body's block-argument types are checked with
// none_of. The predicate is on a line not in this excerpt.
136 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
139 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
// HW module-like ops: the module body's block-argument types are checked
// with none_of in the same way.
143 if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
145 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
// Collect the attribute values. How they are checked is on lines not in this
// excerpt (presumably isSignednessAttr).
149 auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
150 return attr.getValue();
// The op is legal only if its operands, results and attributes all pass.
156 return operandsOK && resultsOK && attrsOK;
/// Replaces an hwarith constant with a constant built from
/// constOp.getConstantValue(). The replacement op's creation is on a line not
/// in this excerpt.
168 matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
171 constOp.getConstantValue());
/// Lowers an hwarith division to comb.divs / comb.divu:
/// 1. Extend both operands (each with its own signedness) to a common width.
/// 2. Divide at that width.
/// 3. Truncate the quotient back to the result width.
179 matchAndRewrite(
DivOp op, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter)
const override {
181 auto loc = op.getLoc();
182 auto isLhsTypeSigned =
183 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
184 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
185 auto targetType = cast<IntegerType>(op.getResult().getType());
// The signedness of the result type selects signed vs. unsigned division.
// An unsigned divisor in a signed division gets one extra bit, presumably so
// that its MSB is not read as a sign bit after extension.
196 bool signedDivision = targetType.isSigned();
197 unsigned extendSize = std::max(
198 targetType.getWidth(),
199 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
202 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
203 extendSize, isLhsTypeSigned);
204 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
205 extendSize, rhsType.isSigned());
// The branch choosing between the two ops is on lines not in this excerpt
// (presumably on signedDivision). NOTE(review): the trailing `false` is
// likely comb's two-state flag — confirm.
209 divResult = rewriter.create<
comb::DivSOp>(loc, lhsValue, rhsValue,
false)
212 divResult = rewriter.create<
comb::DivUOp>(loc, lhsValue, rhsValue,
false)
// Carry the source op's dialect attributes over to the new division op.
216 auto *divOp = divResult.getDefiningOp();
217 rewriter.modifyOpInPlace(
218 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
// Keep the low targetType.getWidth() bits of the widened quotient.
221 Value truncateResult =
extractBits(rewriter, loc, divResult, 0,
222 targetType.getWidth());
223 rewriter.replaceOp(op, truncateResult);
/// Lowers an hwarith cast between integer widths:
/// - same width: the converted input is passed through;
/// - widening: extended according to the source signedness;
/// - narrowing: bits are extracted from the input.
235 matchAndRewrite(
CastOp op, OpAdaptor adaptor,
236 ConversionPatternRewriter &rewriter)
const override {
237 auto sourceType = cast<IntegerType>(op.getIn().getType());
238 auto sourceWidth = sourceType.getWidth();
239 bool isSourceTypeSigned = sourceType.isSigned();
240 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
243 if (sourceWidth == targetWidth) {
246 replaceValue = adaptor.getIn();
247 }
else if (sourceWidth < targetWidth) {
// The extendTypeWidth call head is on lines not in this excerpt. The source
// type's signedness decides sign vs. zero extension.
251 targetWidth, isSourceTypeSigned);
// Narrowing. The remaining extractBits arguments are on lines not in this
// excerpt (presumably the low targetWidth bits).
254 replaceValue =
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
257 rewriter.replaceOp(op, replaceValue);
/// Maps an hwarith comparison predicate to the matching comb predicate.
/// eq/ne are signedness-agnostic. Ordering predicates select the signed
/// (s*) or unsigned (u*) comb variant according to `isSigned`.
269 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred,
bool isSigned) {
271 case ICmpPredicate::eq:
272 return comb::ICmpPredicate::eq;
273 case ICmpPredicate::ne:
274 return comb::ICmpPredicate::ne;
275 case ICmpPredicate::lt:
276 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
277 case ICmpPredicate::ge:
278 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
279 case ICmpPredicate::le:
280 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
281 case ICmpPredicate::gt:
282 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
// Fallback for an unhandled predicate. The diagnostic's call head is on a
// line not in this excerpt. `eq` is returned so that every path yields a
// value.
286 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
287 return comb::ICmpPredicate::eq;
/// Lowers an hwarith comparison to comb.icmp. Both operands are extended to a
/// common comparison width, and the predicate is lowered using the inferred
/// comparison signedness.
294 matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter)
const override {
296 auto lhsType = cast<IntegerType>(op.getLhs().getType());
297 auto rhsType = cast<IntegerType>(op.getRhs().getType());
// NOTE(review): cmpSignedness is declared uninitialized. It must be written
// by the width-inference call on lines not in this excerpt (presumably
// computed from lhsType/rhsType) before it is read below.
298 IntegerType::SignednessSemantics cmpSignedness;
299 const unsigned cmpWidth =
302 ICmpPredicate pred = op.getPredicate();
303 comb::ICmpPredicate combPred = lowerPredicate(
304 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
// Extend both operands to cmpWidth. The extension-kind arguments are on lines
// not in this excerpt.
306 const auto loc = op.getLoc();
307 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
309 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
312 auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
// Carry the source op's dialect attributes over to the new comb.icmp.
314 rewriter.modifyOpInPlace(
315 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
316 rewriter.replaceOp(op, newOp);
/// Generic lowering of a binary hwarith op `BinOp` to the comb op `ReplaceOp`.
/// Each operand is extended to the result width using its own signedness.
/// The comb op is then built at that width.
322 template <
class BinOp,
class ReplaceOp>
328 matchAndRewrite(BinOp op, OpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter)
const override {
330 auto loc = op.getLoc();
331 auto isLhsTypeSigned =
332 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
333 auto isRhsTypeSigned =
334 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
335 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
337 Value lhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
338 targetWidth, isLhsTypeSigned);
339 Value rhsValue =
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
340 targetWidth, isRhsTypeSigned);
// The `auto newOp =` binding is on a line not in this excerpt.
// NOTE(review): the trailing `false` is likely comb's two-state flag — confirm.
342 rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue,
false);
// Carry the source op's dialect attributes over to the new comb op.
343 rewriter.modifyOpInPlace(
344 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
345 rewriter.replaceOp(op, newOp);
/// Returns `type` with integer signedness stripped. The conversion recurses
/// through array, struct and inout aggregates.
353 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
// Reuse a previously computed conversion. Insertion into conversionCache is on
// a line not in this excerpt.
354 auto it = conversionCache.find(type);
355 if (it != conversionCache.end())
356 return it->second.type;
359 llvm::TypeSwitch<Type, Type>(type)
// Signless integers are handled first. The signed/unsigned case is on lines
// not in this excerpt (presumably a signless integer of the same width).
360 .Case<IntegerType>([](
auto type) {
361 if (type.isSignless())
// Arrays keep their element count. The element-type conversion is on a line
// not in this excerpt.
365 .Case<hw::ArrayType>([
this](
auto type) {
367 type.getNumElements());
// Structs keep their field names; each field type is converted recursively.
369 .Case<hw::StructType>([
this](
auto type) {
371 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
372 for (
auto element : type.getElements()) {
373 convertedElements.push_back(
374 {element.name, removeSignedness(element.type)});
// InOut handler body is on lines not in this excerpt.
378 .Case<hw::InOutType>([
this](
auto type) {
// Any other type is returned unchanged.
381 .Default([](
auto type) {
return type; });
383 return convertedType;
/// Sets up the type converter:
/// - every type is converted through removeSignedness();
/// - source and target materializations bridge a single value across the type
///   change with an UnrealizedConversionCastOp.
386 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
388 addConversion([
this](Type type) {
return removeSignedness(type); });
// Only single-input materializations are handled. The early-exit body for
// other arities is on lines not in this excerpt.
390 addTargetMaterialization(
391 [&](mlir::OpBuilder &builder, mlir::Type resultType,
392 mlir::ValueRange inputs,
393 mlir::Location loc) -> std::optional<mlir::Value> {
394 if (inputs.size() != 1)
397 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Same single-input cast for source materializations.
401 addSourceMaterialization(
402 [&](mlir::OpBuilder &builder, mlir::Type resultType,
403 mlir::ValueRange inputs,
404 mlir::Location loc) -> std::optional<mlir::Value> {
405 if (inputs.size() != 1)
408 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
// Register every HWArith lowering pattern. All patterns share the given type
// converter. Add/Sub/Mul go through the generic BinaryOpLowering template;
// constants, casts, comparisons and divisions each have their own pattern.
419 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
420 BinaryOpLowering<AddOp, comb::AddOp>,
421 BinaryOpLowering<SubOp, comb::SubOp>,
422 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
423 typeConverter,
patterns.getContext());
/// Pass that lowers the HWArith dialect to HW/Comb using a full dialect
/// conversion over the module.
428 class HWArithToHWPass :
public circt::impl::HWArithToHWBase<HWArithToHWPass> {
430 void runOnOperation()
override {
431 ModuleOp module = getOperation();
// Ops not explicitly marked are legal only if isLegalOp accepts them.
433 ConversionTarget target(getContext());
434 target.markUnknownOpDynamicallyLegal(
isLegalOp);
435 RewritePatternSet
patterns(&getContext());
// No HWArith op may survive the conversion.
437 target.addIllegalDialect<HWArithDialect>();
// Type-converter and pattern setup are on lines not in this excerpt.
// applyFullConversion fails if any illegal op remains; that failure is
// reported through signalPassFailure().
450 if (failed(applyFullConversion(module, target, std::move(
patterns))))
451 return signalPassFailure();
// Body of createHWArithToHWPass(): returns a freshly allocated pass instance,
// owned by the caller through std::unique_ptr.
461 return std::make_unique<HWArithToHWPass>();
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...