Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWSpecialize.cpp
Go to the documentation of this file.
1//===- HWSpecialize.cpp - hw module specialization ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform performs specialization of parametric hw.module's.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27namespace circt {
28namespace hw {
29#define GEN_PASS_DEF_HWSPECIALIZE
30#include "circt/Dialect/HW/Passes.h.inc"
31} // namespace hw
32} // namespace circt
33
34using namespace llvm;
35using namespace mlir;
36using namespace circt;
37using namespace hw;
38
39namespace {
40
41// Generates a module name by composing the name of 'moduleOp' and the set of
42// provided 'parameters'.
43static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp,
44 ArrayAttr parameters) {
45 assert(parameters.size() != 0);
46 std::string name = moduleOp.getName().str();
47 for (auto param : parameters) {
48 auto paramAttr = cast<ParamDeclAttr>(param);
49 int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
50 name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue);
51 }
52
53 // Query the namespace to generate a unique name.
54 return ns.newName(name).str();
55}
56
57// Returns true if any operand or result of 'op' is parametric.
58static bool isParametricOp(Operation *op) {
59 return llvm::any_of(op->getOperandTypes(), isParametricType) ||
60 llvm::any_of(op->getResultTypes(), isParametricType);
61}
62
63// Narrows 'value' using a comb.extract operation to the width of the
64// hw.array-typed 'array'.
65static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
66 Value value) {
67 OpBuilder::InsertionGuard g(builder);
68 builder.setInsertionPointAfterValue(value);
69 auto arrayType = cast<hw::ArrayType>(array.getType());
70 unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
71
72 return hiBit == 0
73 ? hw::ConstantOp::create(builder, value.getLoc(),
74 APInt(arrayType.getNumElements(), 0))
75 .getResult()
76 : comb::ExtractOp::create(builder, value.getLoc(), value,
77 /*lowBit=*/0, hiBit)
78 .getResult();
79}
80
81static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp,
82 const SymbolCache &sc) {
83 auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr());
84 auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
85 if (!targetHWModule)
86 return {}; // Won't specialize external modules.
87
88 if (targetHWModule.getParameters().size() == 0)
89 return {}; // nothing to record or specialize
90
91 return targetHWModule;
92}
93
94// Stores unique module parameters and references to them
95struct ParameterSpecializationRegistry {
96 llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
97 uniqueModuleParameters;
98
99 bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const {
100 auto it = uniqueModuleParameters.find(moduleOp);
101 return it != uniqueModuleParameters.end() &&
102 it->second.contains(parameters);
103 }
104
105 void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) {
106 uniqueModuleParameters[moduleOp].insert(parameters);
107 }
108};
109
110struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
111 EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
112 : OpRewritePattern<ParamValueOp>(context), parameters(parameters) {}
113
114 LogicalResult matchAndRewrite(ParamValueOp op,
115 PatternRewriter &rewriter) const override {
116 // Substitute the param value op with an evaluated constant operation.
117 FailureOr<Attribute> evaluated =
118 evaluateParametricAttr(op.getLoc(), parameters, op.getValue());
119 if (failed(evaluated))
120 return failure();
121 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
122 op, op.getType(),
123 cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
124 return success();
125 }
126
127 ArrayAttr parameters;
128};
129
130// hw.array_get operations require indexes to be of equal width of the
131// array itself. Since indexes may originate from constants or parameters,
132// emit comb.extract operations to fulfill this invariant.
133struct NarrowArrayGetIndexPattern : public OpConversionPattern<ArrayGetOp> {
134public:
136
137 LogicalResult
138 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
139 ConversionPatternRewriter &rewriter) const override {
140 auto inputType = type_cast<ArrayType>(op.getInput().getType());
141 Type targetIndexType = IntegerType::get(
142 getContext(), inputType.getNumElements() == 1
143 ? 1
144 : llvm::Log2_64_Ceil(inputType.getNumElements()));
145
146 if (op.getIndex().getType().getIntOrFloatBitWidth() ==
147 targetIndexType.getIntOrFloatBitWidth())
148 return failure(); // nothing to do
149
150 // Narrow the index value.
151 FailureOr<Value> narrowedIndex =
152 narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
153 if (failed(narrowedIndex))
154 return failure();
155 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, op.getInput(), *narrowedIndex);
156 return success();
157 }
158};
159
160// Generic pattern to convert parametric result types.
161struct ParametricTypeConversionPattern : public ConversionPattern {
162 ParametricTypeConversionPattern(MLIRContext *ctx,
163 TypeConverter &typeConverter,
164 ArrayAttr parameters)
165 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
166 ctx),
167 parameters(parameters) {}
168
169 LogicalResult
170 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171 ConversionPatternRewriter &rewriter) const override {
172 llvm::SmallVector<Value, 4> convertedOperands;
173 // Update the result types of the operation
174 bool ok = true;
175 rewriter.modifyOpInPlace(op, [&]() {
176 // Mutate result types
177 for (auto it : llvm::enumerate(op->getResultTypes())) {
178 FailureOr<Type> res =
179 evaluateParametricType(op->getLoc(), parameters, it.value());
180 ok &= succeeded(res);
181 if (!ok)
182 return;
183 op->getResult(it.index()).setType(*res);
184 }
185
186 // Note: 'operands' have already been converted with the supplied type
187 // converter to this pattern. Make sure that we materialize this
188 // conversion by updating the operands to op.
189 op->setOperands(operands);
190 });
191
192 return success(ok);
193 };
194 ArrayAttr parameters;
195};
196
197struct HWSpecializePass
198 : public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
199 void runOnOperation() override;
200};
201
202static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
203 ArrayAttr parameters) {
204 // Possibly parametric types
205 typeConverter.addConversion([=](hw::IntType type) {
206 return evaluateParametricType(loc, parameters, type);
207 });
208 typeConverter.addConversion([=](hw::ArrayType type) {
209 return evaluateParametricType(loc, parameters, type);
210 });
211
212 // Valid target types.
213 typeConverter.addConversion([](mlir::IntegerType type) { return type; });
214}
215
216// Registers any nested parametric instance ops of `target` for the next
217// specialization loop
218static LogicalResult registerNestedParametricInstanceOps(
219 HWModuleOp target, ArrayAttr parameters, SymbolCache &sc,
220 const ParameterSpecializationRegistry &currentRegistry,
221 ParameterSpecializationRegistry &nextRegistry,
222 llvm::DenseMap<hw::HWModuleOp,
223 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
224 &parametersUsers) {
225 // Register any nested parametric instance ops for the next loop
226 auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
227 auto instanceParameters = instanceOp.getParameters();
228 // We can ignore non-parametric instances
229 if (instanceParameters.empty())
230 return WalkResult::advance();
231
232 // Replace instance parameters with evaluated versions
233 llvm::SmallVector<Attribute> evaluatedInstanceParameters;
234 evaluatedInstanceParameters.reserve(instanceParameters.size());
235 for (auto instanceParameter : instanceParameters) {
236 auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
237 auto instanceParameterValue = instanceParameterDecl.getValue();
238 auto evaluated = evaluateParametricAttr(target.getLoc(), parameters,
239 instanceParameterValue);
240 if (failed(evaluated))
241 return WalkResult::interrupt();
242 evaluatedInstanceParameters.push_back(
243 hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
244 }
245
246 auto evaluatedInstanceParametersAttr =
247 ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
248
249 if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
250 if (!currentRegistry.isRegistered(targetHWModule,
251 evaluatedInstanceParametersAttr))
252 nextRegistry.registerModuleOp(targetHWModule,
253 evaluatedInstanceParametersAttr);
254 parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
255 .push_back(instanceOp);
256 }
257
258 return WalkResult::advance();
259 });
260
261 return failure(walkResult.wasInterrupted());
262}
263
264// Specializes the provided 'base' module into the 'target' module. By doing
265// so, we create a new module which
266// 1. has no parameters
267// 2. has a name composing the name of 'base' as well as the 'parameters'
268// parameters.
269// 3. Has a top-level interface with any parametric types resolved.
270// 4. Any references to module parameters have been replaced with the
271// parameter value.
272static LogicalResult specializeModule(
273 OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns,
274 HWModuleOp source, HWModuleOp &target,
275 const ParameterSpecializationRegistry &currentRegistry,
276 ParameterSpecializationRegistry &nextRegistry,
277 llvm::DenseMap<hw::HWModuleOp,
278 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
279 &parametersUsers) {
280 auto *ctx = builder.getContext();
281 // Update the types of the source module ports based on evaluating any
282 // parametric in/output ports.
283 ModulePortInfo ports(source.getPortList());
284 for (auto in : llvm::enumerate(source.getInputTypes())) {
285 FailureOr<Type> resType =
286 evaluateParametricType(source.getLoc(), parameters, in.value());
287 if (failed(resType))
288 return failure();
289 ports.atInput(in.index()).type = *resType;
290 }
291 for (auto out : llvm::enumerate(source.getOutputTypes())) {
292 FailureOr<Type> resolvedType =
293 evaluateParametricType(source.getLoc(), parameters, out.value());
294 if (failed(resolvedType))
295 return failure();
296 ports.atOutput(out.index()).type = *resolvedType;
297 }
298
299 // Create the specialized module using the evaluated port info.
300 target = HWModuleOp::create(
301 builder, source.getLoc(),
302 StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
303
304 // Erase the default created hw.output op - we'll copy the correct operation
305 // during body elaboration.
306 (*target.getOps<hw::OutputOp>().begin()).erase();
307
308 // Clone body of the source into the target. Use ValueMapper to ensure safe
309 // cloning in the presence of backedges.
310 BackedgeBuilder bb(builder, source.getLoc());
311 ValueMapper mapper(&bb);
312 for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
313 target.getBodyBlock()->getArguments()))
314 mapper.set(src, dst);
315 builder.setInsertionPointToStart(target.getBodyBlock());
316
317 for (auto &op : source.getOps()) {
318 IRMapping bvMapper;
319 for (auto operand : op.getOperands())
320 bvMapper.map(operand, mapper.get(operand));
321 auto *newOp = builder.clone(op, bvMapper);
322 for (auto &&[oldRes, newRes] :
323 llvm::zip(op.getResults(), newOp->getResults()))
324 mapper.set(oldRes, newRes);
325 }
326
327 // Register any nested parametric instance ops for the next loop
328 auto nestedRegistrationResult = registerNestedParametricInstanceOps(
329 target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
330 if (failed(nestedRegistrationResult))
331 return failure();
332
333 // We've now created a separate copy of the source module with a rewritten
334 // top-level interface. Next, we enter the module to convert parametric
335 // types within operations.
336 RewritePatternSet patterns(ctx);
337 TypeConverter t;
338 populateTypeConversion(target.getLoc(), t, parameters);
339 patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
340 patterns.add<NarrowArrayGetIndexPattern>(ctx);
341 patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
342 ConversionTarget convTarget(*ctx);
343 convTarget.addLegalOp<hw::HWModuleOp>();
344 convTarget.addIllegalOp<hw::ParamValueOp>();
345
346 // Generic legalization of converted operations.
347 convTarget.markUnknownOpDynamicallyLegal(
348 [](Operation *op) { return !isParametricOp(op); });
349
350 return applyPartialConversion(target, convTarget, std::move(patterns));
351}
352
353void HWSpecializePass::runOnOperation() {
354 ModuleOp module = getOperation();
355
356 // Record unique module parameters and references to these.
357 llvm::DenseMap<hw::HWModuleOp,
358 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
359 parametersUsers;
360 ParameterSpecializationRegistry registry;
361
362 // Maintain a symbol cache for fast lookup during module specialization.
363 SymbolCache sc;
364 sc.addDefinitions(module);
365 Namespace ns;
366 ns.add(sc);
367
368 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
369 // If this module is parametric, defer registering its parametric
370 // instantiations until this module is specialized
371 if (!hwModule.getParameters().empty())
372 continue;
373 for (auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
374 if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
375 auto parameters = instanceOp.getParameters();
376 registry.registerModuleOp(targetHWModule, parameters);
377
378 parametersUsers[targetHWModule][parameters].push_back(instanceOp);
379 }
380 }
381 }
382
383 // Create specialized modules.
384 OpBuilder builder = OpBuilder(&getContext());
385 builder.setInsertionPointToStart(module.getBody());
386 llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
387 specializations;
388
389 // For every module specialization, any nested parametric modules will be
390 // registered for the next loop. We loop until no new nested modules have been
391 // registered.
392 while (!registry.uniqueModuleParameters.empty()) {
393 // The registry for the next specialization loop
394 ParameterSpecializationRegistry nextRegistry;
395 for (auto it : registry.uniqueModuleParameters) {
396 for (auto parameters : it.second) {
397 HWModuleOp specializedModule;
398 if (failed(specializeModule(builder, parameters, sc, ns, it.first,
399 specializedModule, registry, nextRegistry,
400 parametersUsers))) {
401 signalPassFailure();
402 return;
403 }
404
405 // Extend the symbol cache with the newly created module.
406 sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
407
408 // Add the specialization
409 specializations[it.first][parameters] = specializedModule;
410 }
411 }
412
413 // Transfer newly registered specializations to iterate over
414 registry.uniqueModuleParameters =
415 std::move(nextRegistry.uniqueModuleParameters);
416 }
417
418 // Rewrite instances of specialized modules to the specialized module.
419 for (auto it : specializations) {
420 auto unspecialized = it.getFirst();
421 auto &users = parametersUsers[unspecialized];
422 for (auto specialization : it.getSecond()) {
423 auto parameters = specialization.getFirst();
424 auto specializedModule = specialization.getSecond();
425 for (auto instanceOp : users[parameters]) {
426 instanceOp->setAttr("moduleName",
427 FlatSymbolRefAttr::get(specializedModule));
428 instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {}));
429 }
430 }
431 }
432}
433
434} // namespace
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
static Block * getBodyBlock(FModuleLike mod)
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
Definition SymCache.h:88
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition SymCache.h:94
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Definition ValueMapper.h:35
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.