CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWSpecialize.cpp
Go to the documentation of this file.
1//===- HWSpecialize.cpp - hw module specialization ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transform performs specialization of parametric hw.module's.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27namespace circt {
28namespace hw {
29#define GEN_PASS_DEF_HWSPECIALIZE
30#include "circt/Dialect/HW/Passes.h.inc"
31} // namespace hw
32} // namespace circt
33
34using namespace llvm;
35using namespace mlir;
36using namespace circt;
37using namespace hw;
38
39namespace {
40
41// Generates a module name by composing the name of 'moduleOp' and the set of
42// provided 'parameters'.
43static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp,
44 ArrayAttr parameters) {
45 assert(parameters.size() != 0);
46 std::string name = moduleOp.getName().str();
47 for (auto param : parameters) {
48 auto paramAttr = cast<ParamDeclAttr>(param);
49 int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
50 name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue);
51 }
52
53 // Query the namespace to generate a unique name.
54 return ns.newName(name).str();
55}
56
57// Returns true if any operand or result of 'op' is parametric.
58static bool isParametricOp(Operation *op) {
59 return llvm::any_of(op->getOperandTypes(), isParametricType) ||
60 llvm::any_of(op->getResultTypes(), isParametricType);
61}
62
63// Narrows 'value' using a comb.extract operation to the width of the
64// hw.array-typed 'array'.
65static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
66 Value value) {
67 OpBuilder::InsertionGuard g(builder);
68 builder.setInsertionPointAfterValue(value);
69 auto arrayType = cast<hw::ArrayType>(array.getType());
70 unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
71
72 return hiBit == 0
73 ? builder
74 .create<hw::ConstantOp>(value.getLoc(),
75 APInt(arrayType.getNumElements(), 0))
76 .getResult()
77 : builder
78 .create<comb::ExtractOp>(value.getLoc(), value,
79 /*lowBit=*/0, hiBit)
80 .getResult();
81}
82
83static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp,
84 const SymbolCache &sc) {
85 auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr());
86 auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
87 if (!targetHWModule)
88 return {}; // Won't specialize external modules.
89
90 if (targetHWModule.getParameters().size() == 0)
91 return {}; // nothing to record or specialize
92
93 return targetHWModule;
94}
95
96// Stores unique module parameters and references to them
97struct ParameterSpecializationRegistry {
98 llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
99 uniqueModuleParameters;
100
101 bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const {
102 auto it = uniqueModuleParameters.find(moduleOp);
103 return it != uniqueModuleParameters.end() &&
104 it->second.contains(parameters);
105 }
106
107 void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) {
108 uniqueModuleParameters[moduleOp].insert(parameters);
109 }
110};
111
112struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
113 EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
114 : OpRewritePattern<ParamValueOp>(context), parameters(parameters) {}
115
116 LogicalResult matchAndRewrite(ParamValueOp op,
117 PatternRewriter &rewriter) const override {
118 // Substitute the param value op with an evaluated constant operation.
119 FailureOr<Attribute> evaluated =
120 evaluateParametricAttr(op.getLoc(), parameters, op.getValue());
121 if (failed(evaluated))
122 return failure();
123 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
124 op, op.getType(),
125 cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
126 return success();
127 }
128
129 ArrayAttr parameters;
130};
131
132// hw.array_get operations require indexes to be of equal width of the
133// array itself. Since indexes may originate from constants or parameters,
134// emit comb.extract operations to fulfill this invariant.
135struct NarrowArrayGetIndexPattern : public OpConversionPattern<ArrayGetOp> {
136public:
138
139 LogicalResult
140 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
141 ConversionPatternRewriter &rewriter) const override {
142 auto inputType = type_cast<ArrayType>(op.getInput().getType());
143 Type targetIndexType = IntegerType::get(
144 getContext(), inputType.getNumElements() == 1
145 ? 1
146 : llvm::Log2_64_Ceil(inputType.getNumElements()));
147
148 if (op.getIndex().getType().getIntOrFloatBitWidth() ==
149 targetIndexType.getIntOrFloatBitWidth())
150 return failure(); // nothing to do
151
152 // Narrow the index value.
153 FailureOr<Value> narrowedIndex =
154 narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
155 if (failed(narrowedIndex))
156 return failure();
157 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, op.getInput(), *narrowedIndex);
158 return success();
159 }
160};
161
162// Generic pattern to convert parametric result types.
163struct ParametricTypeConversionPattern : public ConversionPattern {
164 ParametricTypeConversionPattern(MLIRContext *ctx,
165 TypeConverter &typeConverter,
166 ArrayAttr parameters)
167 : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
168 ctx),
169 parameters(parameters) {}
170
171 LogicalResult
172 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
173 ConversionPatternRewriter &rewriter) const override {
174 llvm::SmallVector<Value, 4> convertedOperands;
175 // Update the result types of the operation
176 bool ok = true;
177 rewriter.modifyOpInPlace(op, [&]() {
178 // Mutate result types
179 for (auto it : llvm::enumerate(op->getResultTypes())) {
180 FailureOr<Type> res =
181 evaluateParametricType(op->getLoc(), parameters, it.value());
182 ok &= succeeded(res);
183 if (!ok)
184 return;
185 op->getResult(it.index()).setType(*res);
186 }
187
188 // Note: 'operands' have already been converted with the supplied type
189 // converter to this pattern. Make sure that we materialize this
190 // conversion by updating the operands to op.
191 op->setOperands(operands);
192 });
193
194 return success(ok);
195 };
196 ArrayAttr parameters;
197};
198
199struct HWSpecializePass
200 : public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
201 void runOnOperation() override;
202};
203
204static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
205 ArrayAttr parameters) {
206 // Possibly parametric types
207 typeConverter.addConversion([=](hw::IntType type) {
208 return evaluateParametricType(loc, parameters, type);
209 });
210 typeConverter.addConversion([=](hw::ArrayType type) {
211 return evaluateParametricType(loc, parameters, type);
212 });
213
214 // Valid target types.
215 typeConverter.addConversion([](mlir::IntegerType type) { return type; });
216}
217
218// Registers any nested parametric instance ops of `target` for the next
219// specialization loop
220static LogicalResult registerNestedParametricInstanceOps(
221 HWModuleOp target, ArrayAttr parameters, SymbolCache &sc,
222 const ParameterSpecializationRegistry &currentRegistry,
223 ParameterSpecializationRegistry &nextRegistry,
224 llvm::DenseMap<hw::HWModuleOp,
225 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
226 &parametersUsers) {
227 // Register any nested parametric instance ops for the next loop
228 auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
229 auto instanceParameters = instanceOp.getParameters();
230 // We can ignore non-parametric instances
231 if (instanceParameters.empty())
232 return WalkResult::advance();
233
234 // Replace instance parameters with evaluated versions
235 llvm::SmallVector<Attribute> evaluatedInstanceParameters;
236 evaluatedInstanceParameters.reserve(instanceParameters.size());
237 for (auto instanceParameter : instanceParameters) {
238 auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
239 auto instanceParameterValue = instanceParameterDecl.getValue();
240 auto evaluated = evaluateParametricAttr(target.getLoc(), parameters,
241 instanceParameterValue);
242 if (failed(evaluated))
243 return WalkResult::interrupt();
244 evaluatedInstanceParameters.push_back(
245 hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
246 }
247
248 auto evaluatedInstanceParametersAttr =
249 ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
250
251 if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
252 if (!currentRegistry.isRegistered(targetHWModule,
253 evaluatedInstanceParametersAttr))
254 nextRegistry.registerModuleOp(targetHWModule,
255 evaluatedInstanceParametersAttr);
256 parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
257 .push_back(instanceOp);
258 }
259
260 return WalkResult::advance();
261 });
262
263 return failure(walkResult.wasInterrupted());
264}
265
266// Specializes the provided 'base' module into the 'target' module. By doing
267// so, we create a new module which
268// 1. has no parameters
269// 2. has a name composing the name of 'base' as well as the 'parameters'
270// parameters.
271// 3. Has a top-level interface with any parametric types resolved.
272// 4. Any references to module parameters have been replaced with the
273// parameter value.
274static LogicalResult specializeModule(
275 OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns,
276 HWModuleOp source, HWModuleOp &target,
277 const ParameterSpecializationRegistry &currentRegistry,
278 ParameterSpecializationRegistry &nextRegistry,
279 llvm::DenseMap<hw::HWModuleOp,
280 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
281 &parametersUsers) {
282 auto *ctx = builder.getContext();
283 // Update the types of the source module ports based on evaluating any
284 // parametric in/output ports.
285 ModulePortInfo ports(source.getPortList());
286 for (auto in : llvm::enumerate(source.getInputTypes())) {
287 FailureOr<Type> resType =
288 evaluateParametricType(source.getLoc(), parameters, in.value());
289 if (failed(resType))
290 return failure();
291 ports.atInput(in.index()).type = *resType;
292 }
293 for (auto out : llvm::enumerate(source.getOutputTypes())) {
294 FailureOr<Type> resolvedType =
295 evaluateParametricType(source.getLoc(), parameters, out.value());
296 if (failed(resolvedType))
297 return failure();
298 ports.atOutput(out.index()).type = *resolvedType;
299 }
300
301 // Create the specialized module using the evaluated port info.
302 target = builder.create<HWModuleOp>(
303 source.getLoc(),
304 StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
305
306 // Erase the default created hw.output op - we'll copy the correct operation
307 // during body elaboration.
308 (*target.getOps<hw::OutputOp>().begin()).erase();
309
310 // Clone body of the source into the target. Use ValueMapper to ensure safe
311 // cloning in the presence of backedges.
312 BackedgeBuilder bb(builder, source.getLoc());
313 ValueMapper mapper(&bb);
314 for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
315 target.getBodyBlock()->getArguments()))
316 mapper.set(src, dst);
317 builder.setInsertionPointToStart(target.getBodyBlock());
318
319 for (auto &op : source.getOps()) {
320 IRMapping bvMapper;
321 for (auto operand : op.getOperands())
322 bvMapper.map(operand, mapper.get(operand));
323 auto *newOp = builder.clone(op, bvMapper);
324 for (auto &&[oldRes, newRes] :
325 llvm::zip(op.getResults(), newOp->getResults()))
326 mapper.set(oldRes, newRes);
327 }
328
329 // Register any nested parametric instance ops for the next loop
330 auto nestedRegistrationResult = registerNestedParametricInstanceOps(
331 target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
332 if (failed(nestedRegistrationResult))
333 return failure();
334
335 // We've now created a separate copy of the source module with a rewritten
336 // top-level interface. Next, we enter the module to convert parametric
337 // types within operations.
338 RewritePatternSet patterns(ctx);
339 TypeConverter t;
340 populateTypeConversion(target.getLoc(), t, parameters);
341 patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
342 patterns.add<NarrowArrayGetIndexPattern>(ctx);
343 patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
344 ConversionTarget convTarget(*ctx);
345 convTarget.addLegalOp<hw::HWModuleOp>();
346 convTarget.addIllegalOp<hw::ParamValueOp>();
347
348 // Generic legalization of converted operations.
349 convTarget.markUnknownOpDynamicallyLegal(
350 [](Operation *op) { return !isParametricOp(op); });
351
352 return applyPartialConversion(target, convTarget, std::move(patterns));
353}
354
355void HWSpecializePass::runOnOperation() {
356 ModuleOp module = getOperation();
357
358 // Record unique module parameters and references to these.
359 llvm::DenseMap<hw::HWModuleOp,
360 llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
361 parametersUsers;
362 ParameterSpecializationRegistry registry;
363
364 // Maintain a symbol cache for fast lookup during module specialization.
365 SymbolCache sc;
366 sc.addDefinitions(module);
367 Namespace ns;
368 ns.add(sc);
369
370 for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
371 // If this module is parametric, defer registering its parametric
372 // instantiations until this module is specialized
373 if (!hwModule.getParameters().empty())
374 continue;
375 for (auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
376 if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
377 auto parameters = instanceOp.getParameters();
378 registry.registerModuleOp(targetHWModule, parameters);
379
380 parametersUsers[targetHWModule][parameters].push_back(instanceOp);
381 }
382 }
383 }
384
385 // Create specialized modules.
386 OpBuilder builder = OpBuilder(&getContext());
387 builder.setInsertionPointToStart(module.getBody());
388 llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
389 specializations;
390
391 // For every module specialization, any nested parametric modules will be
392 // registered for the next loop. We loop until no new nested modules have been
393 // registered.
394 while (!registry.uniqueModuleParameters.empty()) {
395 // The registry for the next specialization loop
396 ParameterSpecializationRegistry nextRegistry;
397 for (auto it : registry.uniqueModuleParameters) {
398 for (auto parameters : it.second) {
399 HWModuleOp specializedModule;
400 if (failed(specializeModule(builder, parameters, sc, ns, it.first,
401 specializedModule, registry, nextRegistry,
402 parametersUsers))) {
403 signalPassFailure();
404 return;
405 }
406
407 // Extend the symbol cache with the newly created module.
408 sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
409
410 // Add the specialization
411 specializations[it.first][parameters] = specializedModule;
412 }
413 }
414
415 // Transfer newly registered specializations to iterate over
416 registry.uniqueModuleParameters =
417 std::move(nextRegistry.uniqueModuleParameters);
418 }
419
420 // Rewrite instances of specialized modules to the specialized module.
421 for (auto it : specializations) {
422 auto unspecialized = it.getFirst();
423 auto &users = parametersUsers[unspecialized];
424 for (auto specialization : it.getSecond()) {
425 auto parameters = specialization.getFirst();
426 auto specializedModule = specialization.getSecond();
427 for (auto instanceOp : users[parameters]) {
428 instanceOp->setAttr("moduleName",
429 FlatSymbolRefAttr::get(specializedModule));
430 instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {}));
431 }
432 }
433 }
434}
435
436} // namespace
437
438std::unique_ptr<Pass> circt::hw::createHWSpecializePass() {
439 return std::make_unique<HWSpecializePass>();
440}
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
static Block * getBodyBlock(FModuleLike mod)
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
Definition SymCache.h:88
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition SymCache.h:94
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Definition ValueMapper.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
std::unique_ptr< mlir::Pass > createHWSpecializePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1
Definition hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.