20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_HWSPECIALIZE
30#include "circt/Dialect/HW/Passes.h.inc"
44 ArrayAttr parameters) {
45 assert(parameters.size() != 0);
46 std::string name = moduleOp.getName().str();
47 for (
auto param : parameters) {
48 auto paramAttr = cast<ParamDeclAttr>(param);
49 int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
50 name +=
"_" + paramAttr.getName().str() +
"_" + std::to_string(paramValue);
58static bool isParametricOp(Operation *op) {
65static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
67 OpBuilder::InsertionGuard g(builder);
68 builder.setInsertionPointAfterValue(value);
69 auto arrayType = cast<hw::ArrayType>(array.getType());
70 unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
74 APInt(arrayType.getNumElements(), 0))
83 auto *targetOp = sc.
getDefinition(instanceOp.getModuleNameAttr());
84 auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
88 if (targetHWModule.getParameters().size() == 0)
91 return targetHWModule;
95struct ParameterSpecializationRegistry {
96 llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
97 uniqueModuleParameters;
99 bool isRegistered(
hw::HWModuleOp moduleOp, ArrayAttr parameters)
const {
100 auto it = uniqueModuleParameters.find(moduleOp);
101 return it != uniqueModuleParameters.end() &&
102 it->second.contains(parameters);
105 void registerModuleOp(
hw::HWModuleOp moduleOp, ArrayAttr parameters) {
106 uniqueModuleParameters[moduleOp].insert(parameters);
110struct EliminateParamValueOpPattern :
public OpRewritePattern<ParamValueOp> {
111 EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
114 LogicalResult matchAndRewrite(ParamValueOp op,
115 PatternRewriter &rewriter)
const override {
117 FailureOr<Attribute> evaluated =
119 if (failed(evaluated))
123 cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
127 ArrayAttr parameters;
138 matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
139 ConversionPatternRewriter &rewriter)
const override {
140 auto inputType = type_cast<ArrayType>(op.getInput().getType());
141 Type targetIndexType = IntegerType::get(
142 getContext(), inputType.getNumElements() == 1
144 :
llvm::Log2_64_Ceil(inputType.getNumElements()));
146 if (op.getIndex().getType().getIntOrFloatBitWidth() ==
147 targetIndexType.getIntOrFloatBitWidth())
151 FailureOr<Value> narrowedIndex =
152 narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
153 if (failed(narrowedIndex))
155 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, op.getInput(), *narrowedIndex);
161struct ParametricTypeConversionPattern :
public ConversionPattern {
162 ParametricTypeConversionPattern(MLIRContext *ctx,
163 TypeConverter &typeConverter,
164 ArrayAttr parameters)
165 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1,
167 parameters(parameters) {}
170 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171 ConversionPatternRewriter &rewriter)
const override {
172 llvm::SmallVector<Value, 4> convertedOperands;
175 rewriter.modifyOpInPlace(op, [&]() {
177 for (
auto it :
llvm::enumerate(op->getResultTypes())) {
178 FailureOr<Type> res =
180 ok &= succeeded(res);
183 op->getResult(it.index()).setType(*res);
189 op->setOperands(operands);
194 ArrayAttr parameters;
197struct HWSpecializePass
198 :
public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
199 void runOnOperation()
override;
203 ArrayAttr parameters) {
205 typeConverter.addConversion([=](hw::IntType type) {
208 typeConverter.addConversion([=](hw::ArrayType type) {
213 typeConverter.addConversion([](mlir::IntegerType type) {
return type; });
218static LogicalResult registerNestedParametricInstanceOps(
220 const ParameterSpecializationRegistry ¤tRegistry,
221 ParameterSpecializationRegistry &nextRegistry,
223 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
226 auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
227 auto instanceParameters = instanceOp.getParameters();
229 if (instanceParameters.empty())
230 return WalkResult::advance();
233 llvm::SmallVector<Attribute> evaluatedInstanceParameters;
234 evaluatedInstanceParameters.reserve(instanceParameters.size());
235 for (
auto instanceParameter : instanceParameters) {
236 auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
237 auto instanceParameterValue = instanceParameterDecl.getValue();
239 instanceParameterValue);
240 if (failed(evaluated))
241 return WalkResult::interrupt();
242 evaluatedInstanceParameters.push_back(
243 hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
246 auto evaluatedInstanceParametersAttr =
247 ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
249 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
250 if (!currentRegistry.isRegistered(targetHWModule,
251 evaluatedInstanceParametersAttr))
252 nextRegistry.registerModuleOp(targetHWModule,
253 evaluatedInstanceParametersAttr);
254 parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
255 .push_back(instanceOp);
258 return WalkResult::advance();
261 return failure(walkResult.wasInterrupted());
272static LogicalResult specializeModule(
275 const ParameterSpecializationRegistry ¤tRegistry,
276 ParameterSpecializationRegistry &nextRegistry,
278 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
280 auto *ctx = builder.getContext();
284 for (
auto in :
llvm::enumerate(source.getInputTypes())) {
285 FailureOr<Type> resType =
289 ports.atInput(in.index()).type = *resType;
291 for (
auto out :
llvm::enumerate(source.getOutputTypes())) {
292 FailureOr<Type> resolvedType =
294 if (failed(resolvedType))
296 ports.atOutput(out.index()).type = *resolvedType;
300 target = HWModuleOp::create(
301 builder, source.getLoc(),
302 StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
306 (*target.getOps<hw::OutputOp>().begin()).erase();
314 mapper.set(src, dst);
315 builder.setInsertionPointToStart(target.getBodyBlock());
317 for (
auto &op : source.getOps()) {
319 for (
auto operand : op.getOperands())
320 bvMapper.map(operand, mapper.
get(operand));
321 auto *newOp = builder.clone(op, bvMapper);
322 for (
auto &&[oldRes, newRes] :
323 llvm::zip(op.getResults(), newOp->getResults()))
324 mapper.set(oldRes, newRes);
328 auto nestedRegistrationResult = registerNestedParametricInstanceOps(
329 target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
330 if (failed(nestedRegistrationResult))
339 patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
340 patterns.add<NarrowArrayGetIndexPattern>(ctx);
341 patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
342 ConversionTarget convTarget(*ctx);
344 convTarget.addIllegalOp<hw::ParamValueOp>();
347 convTarget.markUnknownOpDynamicallyLegal(
348 [](Operation *op) {
return !isParametricOp(op); });
350 return applyPartialConversion(target, convTarget, std::move(
patterns));
353void HWSpecializePass::runOnOperation() {
354 ModuleOp
module = getOperation();
358 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
360 ParameterSpecializationRegistry registry;
371 if (!hwModule.getParameters().empty())
373 for (
auto instanceOp : hwModule.getOps<
hw::InstanceOp>()) {
374 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
375 auto parameters = instanceOp.getParameters();
376 registry.registerModuleOp(targetHWModule, parameters);
378 parametersUsers[targetHWModule][parameters].push_back(instanceOp);
384 OpBuilder builder = OpBuilder(&getContext());
385 builder.setInsertionPointToStart(module.getBody());
386 llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
392 while (!registry.uniqueModuleParameters.empty()) {
394 ParameterSpecializationRegistry nextRegistry;
395 for (
auto it : registry.uniqueModuleParameters) {
396 for (
auto parameters : it.second) {
398 if (failed(specializeModule(builder, parameters, sc, ns, it.first,
399 specializedModule, registry, nextRegistry,
406 sc.
addDefinition(specializedModule.getNameAttr(), specializedModule);
409 specializations[it.first][parameters] = specializedModule;
414 registry.uniqueModuleParameters =
415 std::move(nextRegistry.uniqueModuleParameters);
419 for (
auto it : specializations) {
420 auto unspecialized = it.getFirst();
421 auto &users = parametersUsers[unspecialized];
422 for (
auto specialization : it.getSecond()) {
423 auto parameters = specialization.getFirst();
424 auto specializedModule = specialization.getSecond();
425 for (
auto instanceOp : users[parameters]) {
426 instanceOp->setAttr(
"moduleName",
427 FlatSymbolRefAttr::get(specializedModule));
428 instanceOp->setAttr(
"parameters", ArrayAttr::get(&getContext(), {}));
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
static Block * getBodyBlock(FModuleLike mod)
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.