20#include "mlir/IR/Builders.h" 
   21#include "mlir/IR/IRMapping.h" 
   22#include "mlir/IR/PatternMatch.h" 
   23#include "mlir/Pass/Pass.h" 
   24#include "mlir/Transforms/DialectConversion.h" 
   25#include "llvm/ADT/TypeSwitch.h" 
   29#define GEN_PASS_DEF_HWSPECIALIZE 
   30#include "circt/Dialect/HW/Passes.h.inc" 
   44                                      ArrayAttr parameters) {
 
   45  assert(parameters.size() != 0);
 
   46  std::string name = moduleOp.getName().str();
 
   47  for (
auto param : parameters) {
 
   48    auto paramAttr = cast<ParamDeclAttr>(param);
 
   49    int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
 
   50    name += 
"_" + paramAttr.getName().str() + 
"_" + std::to_string(paramValue);
 
   58static bool isParametricOp(Operation *op) {
 
   65static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
 
   67  OpBuilder::InsertionGuard g(builder);
 
   68  builder.setInsertionPointAfterValue(value);
 
   69  auto arrayType = cast<hw::ArrayType>(array.getType());
 
   70  unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
 
   74                                      APInt(arrayType.getNumElements(), 0))
 
   83  auto *targetOp = sc.
getDefinition(instanceOp.getModuleNameAttr());
 
   84  auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
 
   88  if (targetHWModule.getParameters().size() == 0)
 
   91  return targetHWModule;
 
   95struct ParameterSpecializationRegistry {
 
   96  llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
 
   97      uniqueModuleParameters;
 
   99  bool isRegistered(
hw::HWModuleOp moduleOp, ArrayAttr parameters)
 const {
 
  100    auto it = uniqueModuleParameters.find(moduleOp);
 
  101    return it != uniqueModuleParameters.end() &&
 
  102           it->second.contains(parameters);
 
  105  void registerModuleOp(
hw::HWModuleOp moduleOp, ArrayAttr parameters) {
 
  106    uniqueModuleParameters[moduleOp].insert(parameters);
 
  110struct EliminateParamValueOpPattern : 
public OpRewritePattern<ParamValueOp> {
 
  111  EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
 
  114  LogicalResult matchAndRewrite(ParamValueOp op,
 
  115                                PatternRewriter &rewriter)
 const override {
 
  117    FailureOr<Attribute> evaluated =
 
  119    if (failed(evaluated))
 
  123        cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
 
  127  ArrayAttr parameters;
 
  138  matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
 
  139                  ConversionPatternRewriter &rewriter)
 const override {
 
  140    auto inputType = type_cast<ArrayType>(op.getInput().getType());
 
  141    Type targetIndexType = IntegerType::get(
 
  142        getContext(), inputType.getNumElements() == 1
 
  144                          : 
llvm::Log2_64_Ceil(inputType.getNumElements()));
 
  146    if (op.getIndex().getType().getIntOrFloatBitWidth() ==
 
  147        targetIndexType.getIntOrFloatBitWidth())
 
  151    FailureOr<Value> narrowedIndex =
 
  152        narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
 
  153    if (failed(narrowedIndex))
 
  155    rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, op.getInput(), *narrowedIndex);
 
  161struct ParametricTypeConversionPattern : 
public ConversionPattern {
 
  162  ParametricTypeConversionPattern(MLIRContext *ctx,
 
  163                                  TypeConverter &typeConverter,
 
  164                                  ArrayAttr parameters)
 
  165      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1,
 
  167        parameters(parameters) {}
 
  170  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 
  171                  ConversionPatternRewriter &rewriter)
 const override {
 
  172    llvm::SmallVector<Value, 4> convertedOperands;
 
  175    rewriter.modifyOpInPlace(op, [&]() {
 
  177      for (
auto it : 
llvm::enumerate(op->getResultTypes())) {
 
  178        FailureOr<Type> res =
 
  180        ok &= succeeded(res);
 
  183        op->getResult(it.index()).setType(*res);
 
  189      op->setOperands(operands);
 
  194  ArrayAttr parameters;
 
  197struct HWSpecializePass
 
  198    : 
public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
 
  199  void runOnOperation() 
override;
 
  203                                   ArrayAttr parameters) {
 
  205  typeConverter.addConversion([=](hw::IntType type) {
 
  208  typeConverter.addConversion([=](hw::ArrayType type) {
 
  213  typeConverter.addConversion([](mlir::IntegerType type) { 
return type; });
 
  218static LogicalResult registerNestedParametricInstanceOps(
 
  220    const ParameterSpecializationRegistry ¤tRegistry,
 
  221    ParameterSpecializationRegistry &nextRegistry,
 
  223                   llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
 
  226  auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
 
  227    auto instanceParameters = instanceOp.getParameters();
 
  229    if (instanceParameters.empty())
 
  230      return WalkResult::advance();
 
  233    llvm::SmallVector<Attribute> evaluatedInstanceParameters;
 
  234    evaluatedInstanceParameters.reserve(instanceParameters.size());
 
  235    for (
auto instanceParameter : instanceParameters) {
 
  236      auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
 
  237      auto instanceParameterValue = instanceParameterDecl.getValue();
 
  239                                              instanceParameterValue);
 
  240      if (failed(evaluated))
 
  241        return WalkResult::interrupt();
 
  242      evaluatedInstanceParameters.push_back(
 
  243          hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
 
  246    auto evaluatedInstanceParametersAttr =
 
  247        ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
 
  249    if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
 
  250      if (!currentRegistry.isRegistered(targetHWModule,
 
  251                                        evaluatedInstanceParametersAttr))
 
  252        nextRegistry.registerModuleOp(targetHWModule,
 
  253                                      evaluatedInstanceParametersAttr);
 
  254      parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
 
  255          .push_back(instanceOp);
 
  258    return WalkResult::advance();
 
  261  return failure(walkResult.wasInterrupted());
 
  272static LogicalResult specializeModule(
 
  275    const ParameterSpecializationRegistry ¤tRegistry,
 
  276    ParameterSpecializationRegistry &nextRegistry,
 
  278                   llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
 
  280  auto *ctx = builder.getContext();
 
  284  for (
auto in : 
llvm::enumerate(source.getInputTypes())) {
 
  285    FailureOr<Type> resType =
 
  289    ports.atInput(in.index()).type = *resType;
 
  291  for (
auto out : 
llvm::enumerate(source.getOutputTypes())) {
 
  292    FailureOr<Type> resolvedType =
 
  294    if (failed(resolvedType))
 
  296    ports.atOutput(out.index()).type = *resolvedType;
 
  300  target = HWModuleOp::create(
 
  301      builder, source.getLoc(),
 
  302      StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
 
  306  (*target.getOps<hw::OutputOp>().begin()).erase();
 
  314    mapper.set(src, dst);
 
  315  builder.setInsertionPointToStart(target.getBodyBlock());
 
  317  for (
auto &op : source.getOps()) {
 
  319    for (
auto operand : op.getOperands())
 
  320      bvMapper.map(operand, mapper.
get(operand));
 
  321    auto *newOp = builder.clone(op, bvMapper);
 
  322    for (
auto &&[oldRes, newRes] :
 
  323         llvm::zip(op.getResults(), newOp->getResults()))
 
  324      mapper.set(oldRes, newRes);
 
  328  auto nestedRegistrationResult = registerNestedParametricInstanceOps(
 
  329      target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
 
  330  if (failed(nestedRegistrationResult))
 
  339  patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
 
  340  patterns.add<NarrowArrayGetIndexPattern>(ctx);
 
  341  patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
 
  342  ConversionTarget convTarget(*ctx);
 
  344  convTarget.addIllegalOp<hw::ParamValueOp>();
 
  347  convTarget.markUnknownOpDynamicallyLegal(
 
  348      [](Operation *op) { 
return !isParametricOp(op); });
 
  350  return applyPartialConversion(target, convTarget, std::move(
patterns));
 
  353void HWSpecializePass::runOnOperation() {
 
  354  ModuleOp 
module = getOperation();
 
  358                 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
 
  360  ParameterSpecializationRegistry registry;
 
  371    if (!hwModule.getParameters().empty())
 
  373    for (
auto instanceOp : hwModule.getOps<
hw::InstanceOp>()) {
 
  374      if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
 
  375        auto parameters = instanceOp.getParameters();
 
  376        registry.registerModuleOp(targetHWModule, parameters);
 
  378        parametersUsers[targetHWModule][parameters].push_back(instanceOp);
 
  384  OpBuilder builder = OpBuilder(&getContext());
 
  385  builder.setInsertionPointToStart(module.getBody());
 
  386  llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
 
  392  while (!registry.uniqueModuleParameters.empty()) {
 
  394    ParameterSpecializationRegistry nextRegistry;
 
  395    for (
auto it : registry.uniqueModuleParameters) {
 
  396      for (
auto parameters : it.second) {
 
  398        if (failed(specializeModule(builder, parameters, sc, ns, it.first,
 
  399                                    specializedModule, registry, nextRegistry,
 
  406        sc.
addDefinition(specializedModule.getNameAttr(), specializedModule);
 
  409        specializations[it.first][parameters] = specializedModule;
 
  414    registry.uniqueModuleParameters =
 
  415        std::move(nextRegistry.uniqueModuleParameters);
 
  419  for (
auto it : specializations) {
 
  420    auto unspecialized = it.getFirst();
 
  421    auto &users = parametersUsers[unspecialized];
 
  422    for (
auto specialization : it.getSecond()) {
 
  423      auto parameters = specialization.getFirst();
 
  424      auto specializedModule = specialization.getSecond();
 
  425      for (
auto instanceOp : users[parameters]) {
 
  426        instanceOp->setAttr(
"moduleName",
 
  427                            FlatSymbolRefAttr::get(specializedModule));
 
  428        instanceOp->setAttr(
"parameters", ArrayAttr::get(&getContext(), {}));
 
assert(baseType &&"element must be base type")
 
static void populateTypeConversion(TypeConverter &converter)
 
static Block * getBodyBlock(FModuleLike mod)
 
Instantiate one of these and use it to build typed backedges.
 
A namespace that is used to store existing names and generate new names in some scope within the IR.
 
void add(mlir::ModuleOp module)
 
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
 
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
 
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
 
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
 
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
 
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
 
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
 
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
 
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
 
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
 
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
 
This holds a decoded list of input/inout and output ports for a module or instance.