21#include "mlir/Pass/Pass.h" 
   23#include "mlir/Transforms/DialectConversion.h" 
   24#include "llvm/ADT/TypeSwitch.h" 
   27#define GEN_PASS_DEF_HWARITHTOHW 
   28#include "circt/Conversion/Passes.h.inc" 
   45                llvm::function_ref<std::string(StringRef)> namehintCallback) {
 
   46  if (
auto *sourceOp = oldValue.getDefiningOp()) {
 
   48            sourceOp->getAttrOfType<mlir::StringAttr>(
"sv.namehint")) {
 
   49      auto newNamehint = namehintCallback(namehint.strref());
 
   50      newOp->setAttr(
"sv.namehint",
 
   51                     StringAttr::get(oldValue.getContext(), newNamehint));
 
 
   57static Value 
extractBits(OpBuilder &builder, Location loc, Value value,
 
   58                         unsigned startBit, 
unsigned bitWidth) {
 
   59  Value extractedValue =
 
   61  Operation *definingOp = extractedValue.getDefiningOp();
 
   62  if (extractedValue != value && definingOp) {
 
   65      return (oldNamehint + 
"_" + std::to_string(startBit) + 
"_to_" +
 
   66              std::to_string(startBit + bitWidth))
 
   70  return extractedValue;
 
    57static Value 
extractBits(OpBuilder &builder, Location loc, Value value, {
…}
  
   76                             unsigned targetWidth, 
bool signExtension) {
 
   77  unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
 
   78  unsigned extensionLength = targetWidth - sourceWidth;
 
   80  if (extensionLength == 0)
 
   90        builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
 
   95                               builder.getIntegerType(extensionLength), 0)
 
   99  auto extOp = comb::ConcatOp::create(builder, loc, extensionBits, value);
 
  101    return (oldNamehint + 
"_" + (signExtension ? 
"sext_" : 
"zext_") +
 
  102            std::to_string(targetWidth))
 
  106  return extOp->getOpResult(0);
 
 
  111      llvm::TypeSwitch<Type, bool>(type)
 
  112          .Case<IntegerType>([](
auto type) { 
return !type.isSignless(); })
 
  113          .Case<hw::ArrayType>(
 
  115          .Case<hw::UnpackedArrayType>(
 
  117          .Case<hw::StructType>([](
auto type) {
 
  118            return llvm::any_of(type.getElements(), [](
auto element) {
 
  119              return isSignednessType(element.type);
 
  122          .Case<hw::InOutType>(
 
  124          .Case<hw::TypeAliasType>(
 
  126          .Default([](
auto type) { 
return false; });
 
 
  132  if (
auto typeAttr = dyn_cast<TypeAttr>(attr))
 
 
  140  if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
 
  143           llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
 
  147  if (
auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
 
  149           llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
 
  153  auto attrs = llvm::map_range(op->getAttrs(), [](
const NamedAttribute &attr) {
 
  154    return attr.getValue();
 
  160  return operandsOK && resultsOK && attrsOK;
 
 
  172  matchAndRewrite(
ConstantOp constOp, OpAdaptor adaptor,
 
  173                  ConversionPatternRewriter &rewriter)
 const override {
 
  175                                                constOp.getConstantValue());
 
  183  matchAndRewrite(
DivOp op, OpAdaptor adaptor,
 
  184                  ConversionPatternRewriter &rewriter)
 const override {
 
  185    auto loc = op.getLoc();
 
  186    auto isLhsTypeSigned =
 
  187        cast<IntegerType>(op.getOperand(0).getType()).isSigned();
 
  188    auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
 
  189    auto targetType = cast<IntegerType>(op.getResult().getType());
 
  200    bool signedDivision = targetType.isSigned();
 
  201    unsigned extendSize = std::max(
 
  202        targetType.getWidth(),
 
  203        rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
 
  206    Value lhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
 
  207                                     extendSize, isLhsTypeSigned);
 
  208    Value rhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
 
  209                                     extendSize, rhsType.isSigned());
 
  213      divResult = comb::DivSOp::create(rewriter, loc, lhsValue, rhsValue, 
false)
 
  216      divResult = comb::DivUOp::create(rewriter, loc, lhsValue, rhsValue, 
false)
 
  220    auto *divOp = divResult.getDefiningOp();
 
  221    rewriter.modifyOpInPlace(
 
  222        divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
 
  225    Value truncateResult = 
extractBits(rewriter, loc, divResult, 0,
 
  226                                       targetType.getWidth());
 
  227    rewriter.replaceOp(op, truncateResult);
 
  239  matchAndRewrite(
CastOp op, OpAdaptor adaptor,
 
  240                  ConversionPatternRewriter &rewriter)
 const override {
 
  241    auto sourceType = cast<IntegerType>(op.getIn().getType());
 
  242    auto sourceWidth = sourceType.getWidth();
 
  243    bool isSourceTypeSigned = sourceType.isSigned();
 
  244    auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
 
  247    if (sourceWidth == targetWidth) {
 
  250      replaceValue = adaptor.getIn();
 
  251    } 
else if (sourceWidth < targetWidth) {
 
  255                                     targetWidth, isSourceTypeSigned);
 
  258      replaceValue = 
extractBits(rewriter, op.getLoc(), adaptor.getIn(),
 
  261    rewriter.replaceOp(op, replaceValue);
 
  273static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, 
bool isSigned) {
 
  275  case ICmpPredicate::eq:
 
  276    return comb::ICmpPredicate::eq;
 
  277  case ICmpPredicate::ne:
 
  278    return comb::ICmpPredicate::ne;
 
  279  case ICmpPredicate::lt:
 
  280    return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
 
  281  case ICmpPredicate::ge:
 
  282    return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
 
  283  case ICmpPredicate::le:
 
  284    return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
 
  285  case ICmpPredicate::gt:
 
  286    return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
 
  290      "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
 
  291  return comb::ICmpPredicate::eq;
 
  298  matchAndRewrite(
ICmpOp op, OpAdaptor adaptor,
 
  299                  ConversionPatternRewriter &rewriter)
 const override {
 
  300    auto lhsType = cast<IntegerType>(op.getLhs().getType());
 
  301    auto rhsType = cast<IntegerType>(op.getRhs().getType());
 
  302    IntegerType::SignednessSemantics cmpSignedness;
 
  303    const unsigned cmpWidth =
 
  306    ICmpPredicate pred = op.getPredicate();
 
  307    comb::ICmpPredicate combPred = lowerPredicate(
 
  308        pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
 
  310    const auto loc = op.getLoc();
 
  311    Value lhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
 
  313    Value rhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
 
  316    auto newOp = comb::ICmpOp::create(rewriter, op->getLoc(), combPred,
 
  317                                      lhsValue, rhsValue, 
false);
 
  318    rewriter.modifyOpInPlace(
 
  319        newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
 
  320    rewriter.replaceOp(op, newOp);
 
  326template <
class BinOp, 
class ReplaceOp>
 
  332  matchAndRewrite(BinOp op, OpAdaptor adaptor,
 
  333                  ConversionPatternRewriter &rewriter)
 const override {
 
  334    auto loc = op.getLoc();
 
  335    auto isLhsTypeSigned =
 
  336        cast<IntegerType>(op.getOperand(0).getType()).isSigned();
 
  337    auto isRhsTypeSigned =
 
  338        cast<IntegerType>(op.getOperand(1).getType()).isSigned();
 
  339    auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
 
  341    Value lhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
 
  342                                     targetWidth, isLhsTypeSigned);
 
  343    Value rhsValue = 
extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
 
  344                                     targetWidth, isRhsTypeSigned);
 
  346        ReplaceOp::create(rewriter, op.getLoc(), lhsValue, rhsValue, 
false);
 
  347    rewriter.modifyOpInPlace(
 
  348        newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
 
  349    rewriter.replaceOp(op, newOp);
 
  360    return it->second.type;
 
  363      llvm::TypeSwitch<Type, Type>(type)
 
  364          .Case<IntegerType>([](
auto type) {
 
  365            if (type.isSignless())
 
  367            return IntegerType::get(type.getContext(), type.getWidth());
 
  369          .Case<hw::ArrayType>([
this](
auto type) {
 
  371                                      type.getNumElements());
 
  373          .Case<hw::UnpackedArrayType>([
this](
auto type) {
 
  374            return hw::UnpackedArrayType::get(
 
  377          .Case<hw::StructType>([
this](
auto type) {
 
  379            llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
 
  380            for (
auto element : type.getElements()) {
 
  381              convertedElements.push_back(
 
  384            return hw::StructType::get(type.getContext(), convertedElements);
 
  386          .Case<hw::InOutType>([
this](
auto type) {
 
  389          .Case<hw::TypeAliasType>([
this](
auto type) {
 
  390            return hw::TypeAliasType::get(
 
  393          .Default([](
auto type) { 
return type; });
 
  395  return convertedType;
 
 
  402  addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
 
  403                               mlir::ValueRange inputs,
 
  404                               mlir::Location loc) -> mlir::Value {
 
  405    if (inputs.size() != 1)
 
  407    return UnrealizedConversionCastOp::create(builder, loc, resultType,
 
  412  addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
 
  413                               mlir::ValueRange inputs,
 
  414                               mlir::Location loc) -> mlir::Value {
 
  415    if (inputs.size() != 1)
 
  417    return UnrealizedConversionCastOp::create(builder, loc, resultType,
 
 
  429  patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
 
  430               BinaryOpLowering<AddOp, comb::AddOp>,
 
  431               BinaryOpLowering<SubOp, comb::SubOp>,
 
  432               BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
 
  433      typeConverter, 
patterns.getContext());
 
 
  438class HWArithToHWPass : 
public circt::impl::HWArithToHWBase<HWArithToHWPass> {
 
  440  void runOnOperation()
 override {
 
  441    ModuleOp 
module = getOperation();
 
  443    ConversionTarget target(getContext());
 
  444    target.markUnknownOpDynamicallyLegal(
isLegalOp);
 
  445    RewritePatternSet 
patterns(&getContext());
 
  447    target.addIllegalDialect<HWArithDialect>();
 
  460    if (failed(applyFullConversion(module, target, std::move(
patterns))))
 
  461      return signalPassFailure();
 
  471  return std::make_unique<HWArithToHWPass>();
 
 
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
 
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
 
static bool isSignednessAttr(Attribute attr)
 
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
 
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
 
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
 
static bool isSignednessType(Type type)
 
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
 
A helper type converter class that automatically populates the relevant materializations and type con...
 
mlir::Type removeSignedness(mlir::Type type)
 
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
 
HWArithToHWTypeConverter()
 
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
 
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
 
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
 
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
 
Generic pattern which replaces an operation by one of the same operation name, but with converted att...