21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
29 using namespace circt;
37 ArrayAttr parameters) {
38 assert(parameters.size() != 0);
39 std::string name = moduleOp.getName().str();
40 for (
auto param : parameters) {
41 auto paramAttr = param.cast<ParamDeclAttr>();
42 int64_t paramValue = paramAttr.getValue().cast<IntegerAttr>().
getInt();
43 name +=
"_" + paramAttr.getName().str() +
"_" + std::to_string(paramValue);
51 static bool isParametricOp(Operation *op) {
60 OpBuilder::InsertionGuard g(
builder);
62 auto arrayType = array.getType().cast<hw::ArrayType>();
63 unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
68 APInt(arrayType.getNumElements(), 0))
78 auto *targetOp = sc.
getDefinition(instanceOp.getModuleNameAttr());
79 auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
83 if (targetHWModule.getParameters().size() == 0)
86 return targetHWModule;
90 struct ParameterSpecializationRegistry {
91 llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
92 uniqueModuleParameters;
94 bool isRegistered(
hw::HWModuleOp moduleOp, ArrayAttr parameters)
const {
95 auto it = uniqueModuleParameters.find(moduleOp);
96 return it != uniqueModuleParameters.end() &&
97 it->second.contains(parameters);
100 void registerModuleOp(
hw::HWModuleOp moduleOp, ArrayAttr parameters) {
101 uniqueModuleParameters[moduleOp].insert(parameters);
105 struct EliminateParamValueOpPattern :
public OpRewritePattern<ParamValueOp> {
106 EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
109 LogicalResult matchAndRewrite(ParamValueOp op,
110 PatternRewriter &rewriter)
const override {
114 if (failed(evaluated))
118 evaluated->cast<IntegerAttr>().getValue().getSExtValue());
122 ArrayAttr parameters;
133 matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
134 ConversionPatternRewriter &rewriter)
const override {
135 auto inputType = type_cast<ArrayType>(op.getInput().getType());
137 getContext(), inputType.getNumElements() == 1
139 : llvm::Log2_64_Ceil(inputType.getNumElements()));
141 if (op.getIndex().getType().getIntOrFloatBitWidth() ==
142 targetIndexType.getIntOrFloatBitWidth())
147 narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
148 if (failed(narrowedIndex))
150 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, op.getInput(), *narrowedIndex);
156 struct ParametricTypeConversionPattern :
public ConversionPattern {
157 ParametricTypeConversionPattern(MLIRContext *ctx,
158 TypeConverter &typeConverter,
159 ArrayAttr parameters)
160 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1,
162 parameters(parameters) {}
165 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
166 ConversionPatternRewriter &rewriter)
const override {
167 llvm::SmallVector<Value, 4> convertedOperands;
170 rewriter.updateRootInPlace(op, [&]() {
172 for (
auto it : llvm::enumerate(op->getResultTypes())) {
173 FailureOr<Type> res =
174 evaluateParametricType(op->getLoc(), parameters, it.value());
175 ok &= succeeded(res);
178 op->getResult(it.index()).setType(*res);
184 op->setOperands(operands);
189 ArrayAttr parameters;
192 struct HWSpecializePass :
public hw::HWSpecializeBase<HWSpecializePass> {
193 void runOnOperation()
override;
197 ArrayAttr parameters) {
199 typeConverter.addConversion([=](hw::IntType type) {
202 typeConverter.addConversion([=](hw::ArrayType type) {
207 typeConverter.addConversion([](mlir::IntegerType type) {
return type; });
212 static LogicalResult registerNestedParametricInstanceOps(
214 const ParameterSpecializationRegistry ¤tRegistry,
215 ParameterSpecializationRegistry &nextRegistry,
217 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
220 auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
221 auto instanceParameters = instanceOp.getParameters();
223 if (instanceParameters.empty())
224 return WalkResult::advance();
227 llvm::SmallVector<Attribute> evaluatedInstanceParameters;
228 evaluatedInstanceParameters.reserve(instanceParameters.size());
229 for (
auto instanceParameter : instanceParameters) {
230 auto instanceParameterDecl = instanceParameter.cast<hw::ParamDeclAttr>();
231 auto instanceParameterValue = instanceParameterDecl.getValue();
233 instanceParameterValue);
234 if (failed(evaluated))
235 return WalkResult::interrupt();
236 evaluatedInstanceParameters.push_back(
240 auto evaluatedInstanceParametersAttr =
243 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
244 if (!currentRegistry.isRegistered(targetHWModule,
245 evaluatedInstanceParametersAttr))
246 nextRegistry.registerModuleOp(targetHWModule,
247 evaluatedInstanceParametersAttr);
248 parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
249 .push_back(instanceOp);
252 return WalkResult::advance();
255 return failure(walkResult.wasInterrupted());
266 static LogicalResult specializeModule(
269 const ParameterSpecializationRegistry ¤tRegistry,
270 ParameterSpecializationRegistry &nextRegistry,
272 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
274 auto *ctx =
builder.getContext();
277 ModulePortInfo ports(source.getPortList());
278 for (
auto in : llvm::enumerate(source.getInputTypes())) {
283 ports.atInput(in.index()).type = *resType;
285 for (
auto out : llvm::enumerate(source.getOutputTypes())) {
288 if (failed(resolvedType))
290 ports.atOutput(out.index()).type = *resolvedType;
296 StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
300 (*target.getOps<hw::OutputOp>().begin()).erase();
306 for (
auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
307 target.getBodyBlock()->getArguments()))
308 mapper.set(src, dst);
309 builder.setInsertionPointToStart(target.getBodyBlock());
311 for (
auto &op : source.getOps()) {
313 for (
auto operand : op.getOperands())
314 bvMapper.map(operand, mapper.get(operand));
315 auto *newOp =
builder.clone(op, bvMapper);
316 for (
auto &&[oldRes, newRes] :
317 llvm::zip(op.getResults(), newOp->getResults()))
318 mapper.set(oldRes, newRes);
322 auto nestedRegistrationResult = registerNestedParametricInstanceOps(
323 target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
324 if (failed(nestedRegistrationResult))
333 patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
334 patterns.add<NarrowArrayGetIndexPattern>(ctx);
335 patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
336 ConversionTarget convTarget(*ctx);
338 convTarget.addIllegalOp<hw::ParamValueOp>();
341 convTarget.markUnknownOpDynamicallyLegal(
342 [](Operation *op) {
return !isParametricOp(op); });
344 return applyPartialConversion(target, convTarget, std::move(
patterns));
347 void HWSpecializePass::runOnOperation() {
348 ModuleOp module = getOperation();
352 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
354 ParameterSpecializationRegistry registry;
365 if (!hwModule.getParameters().empty())
367 for (
auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
368 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
369 auto parameters = instanceOp.getParameters();
370 registry.registerModuleOp(targetHWModule, parameters);
372 parametersUsers[targetHWModule][parameters].push_back(instanceOp);
378 OpBuilder
builder = OpBuilder(&getContext());
379 builder.setInsertionPointToStart(module.getBody());
380 llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
386 while (!registry.uniqueModuleParameters.empty()) {
388 ParameterSpecializationRegistry nextRegistry;
389 for (
auto it : registry.uniqueModuleParameters) {
390 for (
auto parameters : it.second) {
392 if (failed(specializeModule(
builder, parameters, sc, ns, it.first,
393 specializedModule, registry, nextRegistry,
400 sc.
addDefinition(specializedModule.getNameAttr(), specializedModule);
403 specializations[it.first][parameters] = specializedModule;
408 registry.uniqueModuleParameters =
409 std::move(nextRegistry.uniqueModuleParameters);
413 for (
auto it : specializations) {
414 auto unspecialized = it.getFirst();
415 auto &users = parametersUsers[unspecialized];
416 for (
auto specialization : it.getSecond()) {
417 auto parameters = specialization.getFirst();
418 auto specializedModule = specialization.getSecond();
419 for (
auto instanceOp : users[parameters]) {
420 instanceOp->setAttr(
"moduleName",
422 instanceOp->setAttr(
"parameters",
ArrayAttr::get(&getContext(), {}));
431 return std::make_unique<HWSpecializePass>();
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &typeConverter)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
std::unique_ptr< mlir::Pass > createHWSpecializePass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...