CIRCT  20.0.0git
HWSpecialize.cpp
Go to the documentation of this file.
1 //===- HWSpecialize.cpp - hw module specialization ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transform performs specialization of parametric hw.module's.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "circt/Dialect/HW/HWOps.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 namespace circt {
28 namespace hw {
29 #define GEN_PASS_DEF_HWSPECIALIZE
30 #include "circt/Dialect/HW/Passes.h.inc"
31 } // namespace hw
32 } // namespace circt
33 
34 using namespace llvm;
35 using namespace mlir;
36 using namespace circt;
37 using namespace hw;
38 
39 namespace {
40 
41 // Generates a module name by composing the name of 'moduleOp' and the set of
42 // provided 'parameters'.
43 static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp,
44  ArrayAttr parameters) {
45  assert(parameters.size() != 0);
46  std::string name = moduleOp.getName().str();
47  for (auto param : parameters) {
48  auto paramAttr = cast<ParamDeclAttr>(param);
49  int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
50  name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue);
51  }
52 
53  // Query the namespace to generate a unique name.
54  return ns.newName(name).str();
55 }
56 
57 // Returns true if any operand or result of 'op' is parametric.
58 static bool isParametricOp(Operation *op) {
59  return llvm::any_of(op->getOperandTypes(), isParametricType) ||
60  llvm::any_of(op->getResultTypes(), isParametricType);
61 }
62 
63 // Narrows 'value' using a comb.extract operation to the width of the
64 // hw.array-typed 'array'.
65 static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
66  Value value) {
67  OpBuilder::InsertionGuard g(builder);
68  builder.setInsertionPointAfterValue(value);
69  auto arrayType = cast<hw::ArrayType>(array.getType());
70  unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
71 
72  return hiBit == 0
73  ? builder
74  .create<hw::ConstantOp>(value.getLoc(),
75  APInt(arrayType.getNumElements(), 0))
76  .getResult()
77  : builder
78  .create<comb::ExtractOp>(value.getLoc(), value,
79  /*lowBit=*/0, hiBit)
80  .getResult();
81 }
82 
83 static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp,
84  const SymbolCache &sc) {
85  auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr());
86  auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
87  if (!targetHWModule)
88  return {}; // Won't specialize external modules.
89 
90  if (targetHWModule.getParameters().size() == 0)
91  return {}; // nothing to record or specialize
92 
93  return targetHWModule;
94 }
95 
96 // Stores unique module parameters and references to them
97 struct ParameterSpecializationRegistry {
98  llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
99  uniqueModuleParameters;
100 
101  bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const {
102  auto it = uniqueModuleParameters.find(moduleOp);
103  return it != uniqueModuleParameters.end() &&
104  it->second.contains(parameters);
105  }
106 
107  void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) {
108  uniqueModuleParameters[moduleOp].insert(parameters);
109  }
110 };
111 
112 struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
113  EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
114  : OpRewritePattern<ParamValueOp>(context), parameters(parameters) {}
115 
116  LogicalResult matchAndRewrite(ParamValueOp op,
117  PatternRewriter &rewriter) const override {
118  // Substitute the param value op with an evaluated constant operation.
119  FailureOr<Attribute> evaluated =
120  evaluateParametricAttr(op.getLoc(), parameters, op.getValue());
121  if (failed(evaluated))
122  return failure();
123  rewriter.replaceOpWithNewOp<hw::ConstantOp>(
124  op, op.getType(),
125  cast<IntegerAttr>(*evaluated).getValue().getSExtValue());
126  return success();
127  }
128 
129  ArrayAttr parameters;
130 };
131 
132 // hw.array_get operations require indexes to be of equal width of the
133 // array itself. Since indexes may originate from constants or parameters,
134 // emit comb.extract operations to fulfill this invariant.
135 struct NarrowArrayGetIndexPattern : public OpConversionPattern<ArrayGetOp> {
136 public:
138 
139  LogicalResult
140  matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
141  ConversionPatternRewriter &rewriter) const override {
142  auto inputType = type_cast<ArrayType>(op.getInput().getType());
143  Type targetIndexType = IntegerType::get(
144  getContext(), inputType.getNumElements() == 1
145  ? 1
146  : llvm::Log2_64_Ceil(inputType.getNumElements()));
147 
148  if (op.getIndex().getType().getIntOrFloatBitWidth() ==
149  targetIndexType.getIntOrFloatBitWidth())
150  return failure(); // nothing to do
151 
152  // Narrow the index value.
153  FailureOr<Value> narrowedIndex =
154  narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
155  if (failed(narrowedIndex))
156  return failure();
157  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, op.getInput(), *narrowedIndex);
158  return success();
159  }
160 };
161 
162 // Generic pattern to convert parametric result types.
163 struct ParametricTypeConversionPattern : public ConversionPattern {
164  ParametricTypeConversionPattern(MLIRContext *ctx,
165  TypeConverter &typeConverter,
166  ArrayAttr parameters)
167  : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
168  ctx),
169  parameters(parameters) {}
170 
171  LogicalResult
172  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
173  ConversionPatternRewriter &rewriter) const override {
174  llvm::SmallVector<Value, 4> convertedOperands;
175  // Update the result types of the operation
176  bool ok = true;
177  rewriter.modifyOpInPlace(op, [&]() {
178  // Mutate result types
179  for (auto it : llvm::enumerate(op->getResultTypes())) {
180  FailureOr<Type> res =
181  evaluateParametricType(op->getLoc(), parameters, it.value());
182  ok &= succeeded(res);
183  if (!ok)
184  return;
185  op->getResult(it.index()).setType(*res);
186  }
187 
188  // Note: 'operands' have already been converted with the supplied type
189  // converter to this pattern. Make sure that we materialize this
190  // conversion by updating the operands to op.
191  op->setOperands(operands);
192  });
193 
194  return success(ok);
195  };
196  ArrayAttr parameters;
197 };
198 
199 struct HWSpecializePass
200  : public circt::hw::impl::HWSpecializeBase<HWSpecializePass> {
201  void runOnOperation() override;
202 };
203 
204 static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
205  ArrayAttr parameters) {
206  // Possibly parametric types
207  typeConverter.addConversion([=](hw::IntType type) {
208  return evaluateParametricType(loc, parameters, type);
209  });
210  typeConverter.addConversion([=](hw::ArrayType type) {
211  return evaluateParametricType(loc, parameters, type);
212  });
213 
214  // Valid target types.
215  typeConverter.addConversion([](mlir::IntegerType type) { return type; });
216 }
217 
218 // Registers any nested parametric instance ops of `target` for the next
219 // specialization loop
220 static LogicalResult registerNestedParametricInstanceOps(
221  HWModuleOp target, ArrayAttr parameters, SymbolCache &sc,
222  const ParameterSpecializationRegistry &currentRegistry,
223  ParameterSpecializationRegistry &nextRegistry,
224  llvm::DenseMap<hw::HWModuleOp,
225  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
226  &parametersUsers) {
227  // Register any nested parametric instance ops for the next loop
228  auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
229  auto instanceParameters = instanceOp.getParameters();
230  // We can ignore non-parametric instances
231  if (instanceParameters.empty())
232  return WalkResult::advance();
233 
234  // Replace instance parameters with evaluated versions
235  llvm::SmallVector<Attribute> evaluatedInstanceParameters;
236  evaluatedInstanceParameters.reserve(instanceParameters.size());
237  for (auto instanceParameter : instanceParameters) {
238  auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
239  auto instanceParameterValue = instanceParameterDecl.getValue();
240  auto evaluated = evaluateParametricAttr(target.getLoc(), parameters,
241  instanceParameterValue);
242  if (failed(evaluated))
243  return WalkResult::interrupt();
244  evaluatedInstanceParameters.push_back(
245  hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
246  }
247 
248  auto evaluatedInstanceParametersAttr =
249  ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
250 
251  if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
252  if (!currentRegistry.isRegistered(targetHWModule,
253  evaluatedInstanceParametersAttr))
254  nextRegistry.registerModuleOp(targetHWModule,
255  evaluatedInstanceParametersAttr);
256  parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
257  .push_back(instanceOp);
258  }
259 
260  return WalkResult::advance();
261  });
262 
263  return failure(walkResult.wasInterrupted());
264 }
265 
266 // Specializes the provided 'base' module into the 'target' module. By doing
267 // so, we create a new module which
268 // 1. has no parameters
269 // 2. has a name composing the name of 'base' as well as the 'parameters'
270 // parameters.
271 // 3. Has a top-level interface with any parametric types resolved.
272 // 4. Any references to module parameters have been replaced with the
273 // parameter value.
274 static LogicalResult specializeModule(
275  OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns,
276  HWModuleOp source, HWModuleOp &target,
277  const ParameterSpecializationRegistry &currentRegistry,
278  ParameterSpecializationRegistry &nextRegistry,
279  llvm::DenseMap<hw::HWModuleOp,
280  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
281  &parametersUsers) {
282  auto *ctx = builder.getContext();
283  // Update the types of the source module ports based on evaluating any
284  // parametric in/output ports.
285  ModulePortInfo ports(source.getPortList());
286  for (auto in : llvm::enumerate(source.getInputTypes())) {
287  FailureOr<Type> resType =
288  evaluateParametricType(source.getLoc(), parameters, in.value());
289  if (failed(resType))
290  return failure();
291  ports.atInput(in.index()).type = *resType;
292  }
293  for (auto out : llvm::enumerate(source.getOutputTypes())) {
294  FailureOr<Type> resolvedType =
295  evaluateParametricType(source.getLoc(), parameters, out.value());
296  if (failed(resolvedType))
297  return failure();
298  ports.atOutput(out.index()).type = *resolvedType;
299  }
300 
301  // Create the specialized module using the evaluated port info.
302  target = builder.create<HWModuleOp>(
303  source.getLoc(),
304  StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
305 
306  // Erase the default created hw.output op - we'll copy the correct operation
307  // during body elaboration.
308  (*target.getOps<hw::OutputOp>().begin()).erase();
309 
310  // Clone body of the source into the target. Use ValueMapper to ensure safe
311  // cloning in the presence of backedges.
312  BackedgeBuilder bb(builder, source.getLoc());
313  ValueMapper mapper(&bb);
314  for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
315  target.getBodyBlock()->getArguments()))
316  mapper.set(src, dst);
317  builder.setInsertionPointToStart(target.getBodyBlock());
318 
319  for (auto &op : source.getOps()) {
320  IRMapping bvMapper;
321  for (auto operand : op.getOperands())
322  bvMapper.map(operand, mapper.get(operand));
323  auto *newOp = builder.clone(op, bvMapper);
324  for (auto &&[oldRes, newRes] :
325  llvm::zip(op.getResults(), newOp->getResults()))
326  mapper.set(oldRes, newRes);
327  }
328 
329  // Register any nested parametric instance ops for the next loop
330  auto nestedRegistrationResult = registerNestedParametricInstanceOps(
331  target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
332  if (failed(nestedRegistrationResult))
333  return failure();
334 
335  // We've now created a separate copy of the source module with a rewritten
336  // top-level interface. Next, we enter the module to convert parametric
337  // types within operations.
338  RewritePatternSet patterns(ctx);
339  TypeConverter t;
340  populateTypeConversion(target.getLoc(), t, parameters);
341  patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
342  patterns.add<NarrowArrayGetIndexPattern>(ctx);
343  patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
344  ConversionTarget convTarget(*ctx);
345  convTarget.addLegalOp<hw::HWModuleOp>();
346  convTarget.addIllegalOp<hw::ParamValueOp>();
347 
348  // Generic legalization of converted operations.
349  convTarget.markUnknownOpDynamicallyLegal(
350  [](Operation *op) { return !isParametricOp(op); });
351 
352  return applyPartialConversion(target, convTarget, std::move(patterns));
353 }
354 
355 void HWSpecializePass::runOnOperation() {
356  ModuleOp module = getOperation();
357 
358  // Record unique module parameters and references to these.
359  llvm::DenseMap<hw::HWModuleOp,
360  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
361  parametersUsers;
362  ParameterSpecializationRegistry registry;
363 
364  // Maintain a symbol cache for fast lookup during module specialization.
365  SymbolCache sc;
366  sc.addDefinitions(module);
367  Namespace ns;
368  ns.add(sc);
369 
370  for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
371  // If this module is parametric, defer registering its parametric
372  // instantiations until this module is specialized
373  if (!hwModule.getParameters().empty())
374  continue;
375  for (auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
376  if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
377  auto parameters = instanceOp.getParameters();
378  registry.registerModuleOp(targetHWModule, parameters);
379 
380  parametersUsers[targetHWModule][parameters].push_back(instanceOp);
381  }
382  }
383  }
384 
385  // Create specialized modules.
386  OpBuilder builder = OpBuilder(&getContext());
387  builder.setInsertionPointToStart(module.getBody());
388  llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
389  specializations;
390 
391  // For every module specialization, any nested parametric modules will be
392  // registered for the next loop. We loop until no new nested modules have been
393  // registered.
394  while (!registry.uniqueModuleParameters.empty()) {
395  // The registry for the next specialization loop
396  ParameterSpecializationRegistry nextRegistry;
397  for (auto it : registry.uniqueModuleParameters) {
398  for (auto parameters : it.second) {
399  HWModuleOp specializedModule;
400  if (failed(specializeModule(builder, parameters, sc, ns, it.first,
401  specializedModule, registry, nextRegistry,
402  parametersUsers))) {
403  signalPassFailure();
404  return;
405  }
406 
407  // Extend the symbol cache with the newly created module.
408  sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
409 
410  // Add the specialization
411  specializations[it.first][parameters] = specializedModule;
412  }
413  }
414 
415  // Transfer newly registered specializations to iterate over
416  registry.uniqueModuleParameters =
417  std::move(nextRegistry.uniqueModuleParameters);
418  }
419 
420  // Rewrite instances of specialized modules to the specialized module.
421  for (auto it : specializations) {
422  auto unspecialized = it.getFirst();
423  auto &users = parametersUsers[unspecialized];
424  for (auto specialization : it.getSecond()) {
425  auto parameters = specialization.getFirst();
426  auto specializedModule = specialization.getSecond();
427  for (auto instanceOp : users[parameters]) {
428  instanceOp->setAttr("moduleName",
429  FlatSymbolRefAttr::get(specializedModule));
430  instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {}));
431  }
432  }
433  }
434 }
435 
436 } // namespace
437 
438 std::unique_ptr<Pass> circt::hw::createHWSpecializePass() {
439  return std::make_unique<HWSpecializePass>();
440 }
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
Definition: SymCache.h:88
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: SymCache.h:94
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Definition: ValueMapper.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
std::unique_ptr< mlir::Pass > createHWSpecializePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: comb.py:1
Definition: hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.