20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
29 #define GEN_PASS_DEF_HWSPECIALIZE
30 #include "circt/Dialect/HW/Passes.h.inc"
36 using namespace circt;
44 ArrayAttr parameters) {
45 assert(parameters.size() != 0);
46 std::string name = moduleOp.getName().str();
47 for (
auto param : parameters) {
48 auto paramAttr = cast<ParamDeclAttr>(param);
49 int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
50 name +=
"_" + paramAttr.getName().str() +
"_" + std::to_string(paramValue);
58 static bool isParametricOp(Operation *op) {
65 static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
67 OpBuilder::InsertionGuard g(builder);
68 builder.setInsertionPointAfterValue(value);
69 auto arrayType = cast<hw::ArrayType>(array.getType());
70 unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
75 APInt(arrayType.getNumElements(), 0))
78 .create<
comb::ExtractOp>(value.getLoc(), value,
85 auto *targetOp = sc.
getDefinition(instanceOp.getModuleNameAttr());
86 auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
90 if (targetHWModule.getParameters().size() == 0)
93 return targetHWModule;
97 struct ParameterSpecializationRegistry {
98 llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
99 uniqueModuleParameters;
101 bool isRegistered(
hw::HWModuleOp moduleOp, ArrayAttr parameters)
const {
102 auto it = uniqueModuleParameters.find(moduleOp);
103 return it != uniqueModuleParameters.end() &&
104 it->second.contains(parameters);
107 void registerModuleOp(
hw::HWModuleOp moduleOp, ArrayAttr parameters) {
108 uniqueModuleParameters[moduleOp].insert(parameters);
112 struct EliminateParamValueOpPattern :
public OpRewritePattern<ParamValueOp> {
113 EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
116 LogicalResult matchAndRewrite(ParamValueOp op,
117 PatternRewriter &rewriter)
const override {
119 FailureOr<Attribute> evaluated =
121 if (failed(evaluated))
125 cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
129 ArrayAttr parameters;
140 matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter)
const override {
142 auto inputType = type_cast<ArrayType>(op.getInput().getType());
144 getContext(), inputType.getNumElements() == 1
146 : llvm::Log2_64_Ceil(inputType.getNumElements()));
148 if (op.getIndex().getType().getIntOrFloatBitWidth() ==
149 targetIndexType.getIntOrFloatBitWidth())
153 FailureOr<Value> narrowedIndex =
154 narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
155 if (failed(narrowedIndex))
157 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, op.getInput(), *narrowedIndex);
163 struct ParametricTypeConversionPattern :
public ConversionPattern {
164 ParametricTypeConversionPattern(MLIRContext *ctx,
165 TypeConverter &typeConverter,
166 ArrayAttr parameters)
167 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), 1,
169 parameters(parameters) {}
172 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
173 ConversionPatternRewriter &rewriter)
const override {
174 llvm::SmallVector<Value, 4> convertedOperands;
177 rewriter.modifyOpInPlace(op, [&]() {
179 for (
auto it : llvm::enumerate(op->getResultTypes())) {
180 FailureOr<Type> res =
181 evaluateParametricType(op->getLoc(), parameters, it.value());
182 ok &= succeeded(res);
185 op->getResult(it.index()).setType(*res);
191 op->setOperands(operands);
196 ArrayAttr parameters;
199 struct HWSpecializePass
200 :
public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
201 void runOnOperation()
override;
205 ArrayAttr parameters) {
207 typeConverter.addConversion([=](hw::IntType type) {
210 typeConverter.addConversion([=](hw::ArrayType type) {
215 typeConverter.addConversion([](mlir::IntegerType type) {
return type; });
220 static LogicalResult registerNestedParametricInstanceOps(
222 const ParameterSpecializationRegistry ¤tRegistry,
223 ParameterSpecializationRegistry &nextRegistry,
225 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
228 auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
229 auto instanceParameters = instanceOp.getParameters();
231 if (instanceParameters.empty())
232 return WalkResult::advance();
235 llvm::SmallVector<Attribute> evaluatedInstanceParameters;
236 evaluatedInstanceParameters.reserve(instanceParameters.size());
237 for (
auto instanceParameter : instanceParameters) {
238 auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
239 auto instanceParameterValue = instanceParameterDecl.getValue();
241 instanceParameterValue);
242 if (failed(evaluated))
243 return WalkResult::interrupt();
244 evaluatedInstanceParameters.push_back(
248 auto evaluatedInstanceParametersAttr =
251 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
252 if (!currentRegistry.isRegistered(targetHWModule,
253 evaluatedInstanceParametersAttr))
254 nextRegistry.registerModuleOp(targetHWModule,
255 evaluatedInstanceParametersAttr);
256 parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
257 .push_back(instanceOp);
260 return WalkResult::advance();
263 return failure(walkResult.wasInterrupted());
274 static LogicalResult specializeModule(
277 const ParameterSpecializationRegistry ¤tRegistry,
278 ParameterSpecializationRegistry &nextRegistry,
280 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
282 auto *ctx = builder.getContext();
286 for (
auto in : llvm::enumerate(source.getInputTypes())) {
287 FailureOr<Type> resType =
291 ports.atInput(in.index()).type = *resType;
293 for (
auto out : llvm::enumerate(source.getOutputTypes())) {
294 FailureOr<Type> resolvedType =
296 if (failed(resolvedType))
298 ports.atOutput(out.index()).type = *resolvedType;
304 StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
308 (*target.getOps<hw::OutputOp>().begin()).erase();
314 for (
auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
315 target.getBodyBlock()->getArguments()))
316 mapper.set(src, dst);
317 builder.setInsertionPointToStart(target.getBodyBlock());
319 for (
auto &op : source.getOps()) {
321 for (
auto operand : op.getOperands())
322 bvMapper.map(operand, mapper.get(operand));
323 auto *newOp = builder.clone(op, bvMapper);
324 for (
auto &&[oldRes, newRes] :
325 llvm::zip(op.getResults(), newOp->getResults()))
326 mapper.set(oldRes, newRes);
330 auto nestedRegistrationResult = registerNestedParametricInstanceOps(
331 target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
332 if (failed(nestedRegistrationResult))
341 patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
342 patterns.add<NarrowArrayGetIndexPattern>(ctx);
343 patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
344 ConversionTarget convTarget(*ctx);
346 convTarget.addIllegalOp<hw::ParamValueOp>();
349 convTarget.markUnknownOpDynamicallyLegal(
350 [](Operation *op) {
return !isParametricOp(op); });
352 return applyPartialConversion(target, convTarget, std::move(
patterns));
355 void HWSpecializePass::runOnOperation() {
356 ModuleOp module = getOperation();
360 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
362 ParameterSpecializationRegistry registry;
373 if (!hwModule.getParameters().empty())
375 for (
auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
376 if (
auto targetHWModule = targetModuleOp(instanceOp, sc)) {
377 auto parameters = instanceOp.getParameters();
378 registry.registerModuleOp(targetHWModule, parameters);
380 parametersUsers[targetHWModule][parameters].push_back(instanceOp);
386 OpBuilder builder = OpBuilder(&getContext());
387 builder.setInsertionPointToStart(module.getBody());
388 llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
394 while (!registry.uniqueModuleParameters.empty()) {
396 ParameterSpecializationRegistry nextRegistry;
397 for (
auto it : registry.uniqueModuleParameters) {
398 for (
auto parameters : it.second) {
400 if (failed(specializeModule(builder, parameters, sc, ns, it.first,
401 specializedModule, registry, nextRegistry,
408 sc.
addDefinition(specializedModule.getNameAttr(), specializedModule);
411 specializations[it.first][parameters] = specializedModule;
416 registry.uniqueModuleParameters =
417 std::move(nextRegistry.uniqueModuleParameters);
421 for (
auto it : specializations) {
422 auto unspecialized = it.getFirst();
423 auto &users = parametersUsers[unspecialized];
424 for (
auto specialization : it.getSecond()) {
425 auto parameters = specialization.getFirst();
426 auto specializedModule = specialization.getSecond();
427 for (
auto instanceOp : users[parameters]) {
428 instanceOp->setAttr(
"moduleName",
430 instanceOp->setAttr(
"parameters",
ArrayAttr::get(&getContext(), {}));
439 return std::make_unique<HWSpecializePass>();
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
std::unique_ptr< mlir::Pass > createHWSpecializePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.