CIRCT  19.0.0git
HWSpecialize.cpp
Go to the documentation of this file.
1 //===- HWSpecialize.cpp - hw module specialization ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transform performs specialization of parametric hw.module's.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
16 #include "circt/Dialect/HW/HWOps.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 using namespace llvm;
28 using namespace mlir;
29 using namespace circt;
30 using namespace hw;
31 
32 namespace {
33 
34 // Generates a module name by composing the name of 'moduleOp' and the set of
35 // provided 'parameters'.
36 static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp,
37  ArrayAttr parameters) {
38  assert(parameters.size() != 0);
39  std::string name = moduleOp.getName().str();
40  for (auto param : parameters) {
41  auto paramAttr = cast<ParamDeclAttr>(param);
42  int64_t paramValue = cast<IntegerAttr>(paramAttr.getValue()).getInt();
43  name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue);
44  }
45 
46  // Query the namespace to generate a unique name.
47  return ns.newName(name).str();
48 }
49 
50 // Returns true if any operand or result of 'op' is parametric.
51 static bool isParametricOp(Operation *op) {
52  return llvm::any_of(op->getOperandTypes(), isParametricType) ||
53  llvm::any_of(op->getResultTypes(), isParametricType);
54 }
55 
56 // Narrows 'value' using a comb.extract operation to the width of the
57 // hw.array-typed 'array'.
58 static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
59  Value value) {
60  OpBuilder::InsertionGuard g(builder);
61  builder.setInsertionPointAfterValue(value);
62  auto arrayType = cast<hw::ArrayType>(array.getType());
63  unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements());
64 
65  return hiBit == 0
66  ? builder
67  .create<hw::ConstantOp>(value.getLoc(),
68  APInt(arrayType.getNumElements(), 0))
69  .getResult()
70  : builder
71  .create<comb::ExtractOp>(value.getLoc(), value,
72  /*lowBit=*/0, hiBit)
73  .getResult();
74 }
75 
76 static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp,
77  const SymbolCache &sc) {
78  auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr());
79  auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
80  if (!targetHWModule)
81  return {}; // Won't specialize external modules.
82 
83  if (targetHWModule.getParameters().size() == 0)
84  return {}; // nothing to record or specialize
85 
86  return targetHWModule;
87 }
88 
89 // Stores unique module parameters and references to them
90 struct ParameterSpecializationRegistry {
91  llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
92  uniqueModuleParameters;
93 
94  bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const {
95  auto it = uniqueModuleParameters.find(moduleOp);
96  return it != uniqueModuleParameters.end() &&
97  it->second.contains(parameters);
98  }
99 
100  void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) {
101  uniqueModuleParameters[moduleOp].insert(parameters);
102  }
103 };
104 
105 struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
106  EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
107  : OpRewritePattern<ParamValueOp>(context), parameters(parameters) {}
108 
109  LogicalResult matchAndRewrite(ParamValueOp op,
110  PatternRewriter &rewriter) const override {
111  // Substitute the param value op with an evaluated constant operation.
112  FailureOr<Attribute> evaluated =
113  evaluateParametricAttr(op.getLoc(), parameters, op.getValue());
114  if (failed(evaluated))
115  return failure();
116  rewriter.replaceOpWithNewOp<hw::ConstantOp>(
117  op, op.getType(),
118  evaluated->cast<IntegerAttr>().getValue().getSExtValue());
119  return success();
120  }
121 
122  ArrayAttr parameters;
123 };
124 
125 // hw.array_get operations require indexes to be of equal width of the
126 // array itself. Since indexes may originate from constants or parameters,
127 // emit comb.extract operations to fulfill this invariant.
128 struct NarrowArrayGetIndexPattern : public OpConversionPattern<ArrayGetOp> {
129 public:
131 
132  LogicalResult
133  matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
134  ConversionPatternRewriter &rewriter) const override {
135  auto inputType = type_cast<ArrayType>(op.getInput().getType());
136  Type targetIndexType = IntegerType::get(
137  getContext(), inputType.getNumElements() == 1
138  ? 1
139  : llvm::Log2_64_Ceil(inputType.getNumElements()));
140 
141  if (op.getIndex().getType().getIntOrFloatBitWidth() ==
142  targetIndexType.getIntOrFloatBitWidth())
143  return failure(); // nothing to do
144 
145  // Narrow the index value.
146  FailureOr<Value> narrowedIndex =
147  narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex());
148  if (failed(narrowedIndex))
149  return failure();
150  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, op.getInput(), *narrowedIndex);
151  return success();
152  }
153 };
154 
155 // Generic pattern to convert parametric result types.
156 struct ParametricTypeConversionPattern : public ConversionPattern {
157  ParametricTypeConversionPattern(MLIRContext *ctx,
158  TypeConverter &typeConverter,
159  ArrayAttr parameters)
160  : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
161  ctx),
162  parameters(parameters) {}
163 
164  LogicalResult
165  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
166  ConversionPatternRewriter &rewriter) const override {
167  llvm::SmallVector<Value, 4> convertedOperands;
168  // Update the result types of the operation
169  bool ok = true;
170  rewriter.modifyOpInPlace(op, [&]() {
171  // Mutate result types
172  for (auto it : llvm::enumerate(op->getResultTypes())) {
173  FailureOr<Type> res =
174  evaluateParametricType(op->getLoc(), parameters, it.value());
175  ok &= succeeded(res);
176  if (!ok)
177  return;
178  op->getResult(it.index()).setType(*res);
179  }
180 
181  // Note: 'operands' have already been converted with the supplied type
182  // converter to this pattern. Make sure that we materialize this
183  // conversion by updating the operands to op.
184  op->setOperands(operands);
185  });
186 
187  return success(ok);
188  };
189  ArrayAttr parameters;
190 };
191 
192 struct HWSpecializePass : public hw::HWSpecializeBase<HWSpecializePass> {
193  void runOnOperation() override;
194 };
195 
196 static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
197  ArrayAttr parameters) {
198  // Possibly parametric types
199  typeConverter.addConversion([=](hw::IntType type) {
200  return evaluateParametricType(loc, parameters, type);
201  });
202  typeConverter.addConversion([=](hw::ArrayType type) {
203  return evaluateParametricType(loc, parameters, type);
204  });
205 
206  // Valid target types.
207  typeConverter.addConversion([](mlir::IntegerType type) { return type; });
208 }
209 
210 // Registers any nested parametric instance ops of `target` for the next
211 // specialization loop
212 static LogicalResult registerNestedParametricInstanceOps(
213  HWModuleOp target, ArrayAttr parameters, SymbolCache &sc,
214  const ParameterSpecializationRegistry &currentRegistry,
215  ParameterSpecializationRegistry &nextRegistry,
216  llvm::DenseMap<hw::HWModuleOp,
217  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
218  &parametersUsers) {
219  // Register any nested parametric instance ops for the next loop
220  auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
221  auto instanceParameters = instanceOp.getParameters();
222  // We can ignore non-parametric instances
223  if (instanceParameters.empty())
224  return WalkResult::advance();
225 
226  // Replace instance parameters with evaluated versions
227  llvm::SmallVector<Attribute> evaluatedInstanceParameters;
228  evaluatedInstanceParameters.reserve(instanceParameters.size());
229  for (auto instanceParameter : instanceParameters) {
230  auto instanceParameterDecl = cast<hw::ParamDeclAttr>(instanceParameter);
231  auto instanceParameterValue = instanceParameterDecl.getValue();
232  auto evaluated = evaluateParametricAttr(target.getLoc(), parameters,
233  instanceParameterValue);
234  if (failed(evaluated))
235  return WalkResult::interrupt();
236  evaluatedInstanceParameters.push_back(
237  hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated));
238  }
239 
240  auto evaluatedInstanceParametersAttr =
241  ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
242 
243  if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
244  if (!currentRegistry.isRegistered(targetHWModule,
245  evaluatedInstanceParametersAttr))
246  nextRegistry.registerModuleOp(targetHWModule,
247  evaluatedInstanceParametersAttr);
248  parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
249  .push_back(instanceOp);
250  }
251 
252  return WalkResult::advance();
253  });
254 
255  return failure(walkResult.wasInterrupted());
256 }
257 
258 // Specializes the provided 'base' module into the 'target' module. By doing
259 // so, we create a new module which
260 // 1. has no parameters
261 // 2. has a name composing the name of 'base' as well as the 'parameters'
262 // parameters.
263 // 3. Has a top-level interface with any parametric types resolved.
264 // 4. Any references to module parameters have been replaced with the
265 // parameter value.
266 static LogicalResult specializeModule(
267  OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns,
268  HWModuleOp source, HWModuleOp &target,
269  const ParameterSpecializationRegistry &currentRegistry,
270  ParameterSpecializationRegistry &nextRegistry,
271  llvm::DenseMap<hw::HWModuleOp,
272  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
273  &parametersUsers) {
274  auto *ctx = builder.getContext();
275  // Update the types of the source module ports based on evaluating any
276  // parametric in/output ports.
277  ModulePortInfo ports(source.getPortList());
278  for (auto in : llvm::enumerate(source.getInputTypes())) {
279  FailureOr<Type> resType =
280  evaluateParametricType(source.getLoc(), parameters, in.value());
281  if (failed(resType))
282  return failure();
283  ports.atInput(in.index()).type = *resType;
284  }
285  for (auto out : llvm::enumerate(source.getOutputTypes())) {
286  FailureOr<Type> resolvedType =
287  evaluateParametricType(source.getLoc(), parameters, out.value());
288  if (failed(resolvedType))
289  return failure();
290  ports.atOutput(out.index()).type = *resolvedType;
291  }
292 
293  // Create the specialized module using the evaluated port info.
294  target = builder.create<HWModuleOp>(
295  source.getLoc(),
296  StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports);
297 
298  // Erase the default created hw.output op - we'll copy the correct operation
299  // during body elaboration.
300  (*target.getOps<hw::OutputOp>().begin()).erase();
301 
302  // Clone body of the source into the target. Use ValueMapper to ensure safe
303  // cloning in the presence of backedges.
304  BackedgeBuilder bb(builder, source.getLoc());
305  ValueMapper mapper(&bb);
306  for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(),
307  target.getBodyBlock()->getArguments()))
308  mapper.set(src, dst);
309  builder.setInsertionPointToStart(target.getBodyBlock());
310 
311  for (auto &op : source.getOps()) {
312  IRMapping bvMapper;
313  for (auto operand : op.getOperands())
314  bvMapper.map(operand, mapper.get(operand));
315  auto *newOp = builder.clone(op, bvMapper);
316  for (auto &&[oldRes, newRes] :
317  llvm::zip(op.getResults(), newOp->getResults()))
318  mapper.set(oldRes, newRes);
319  }
320 
321  // Register any nested parametric instance ops for the next loop
322  auto nestedRegistrationResult = registerNestedParametricInstanceOps(
323  target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
324  if (failed(nestedRegistrationResult))
325  return failure();
326 
327  // We've now created a separate copy of the source module with a rewritten
328  // top-level interface. Next, we enter the module to convert parametric
329  // types within operations.
330  RewritePatternSet patterns(ctx);
331  TypeConverter t;
332  populateTypeConversion(target.getLoc(), t, parameters);
333  patterns.add<EliminateParamValueOpPattern>(ctx, parameters);
334  patterns.add<NarrowArrayGetIndexPattern>(ctx);
335  patterns.add<ParametricTypeConversionPattern>(ctx, t, parameters);
336  ConversionTarget convTarget(*ctx);
337  convTarget.addLegalOp<hw::HWModuleOp>();
338  convTarget.addIllegalOp<hw::ParamValueOp>();
339 
340  // Generic legalization of converted operations.
341  convTarget.markUnknownOpDynamicallyLegal(
342  [](Operation *op) { return !isParametricOp(op); });
343 
344  return applyPartialConversion(target, convTarget, std::move(patterns));
345 }
346 
347 void HWSpecializePass::runOnOperation() {
348  ModuleOp module = getOperation();
349 
350  // Record unique module parameters and references to these.
351  llvm::DenseMap<hw::HWModuleOp,
352  llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
353  parametersUsers;
354  ParameterSpecializationRegistry registry;
355 
356  // Maintain a symbol cache for fast lookup during module specialization.
357  SymbolCache sc;
358  sc.addDefinitions(module);
359  Namespace ns;
360  ns.add(sc);
361 
362  for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
363  // If this module is parametric, defer registering its parametric
364  // instantiations until this module is specialized
365  if (!hwModule.getParameters().empty())
366  continue;
367  for (auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
368  if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
369  auto parameters = instanceOp.getParameters();
370  registry.registerModuleOp(targetHWModule, parameters);
371 
372  parametersUsers[targetHWModule][parameters].push_back(instanceOp);
373  }
374  }
375  }
376 
377  // Create specialized modules.
378  OpBuilder builder = OpBuilder(&getContext());
379  builder.setInsertionPointToStart(module.getBody());
380  llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
381  specializations;
382 
383  // For every module specialization, any nested parametric modules will be
384  // registered for the next loop. We loop until no new nested modules have been
385  // registered.
386  while (!registry.uniqueModuleParameters.empty()) {
387  // The registry for the next specialization loop
388  ParameterSpecializationRegistry nextRegistry;
389  for (auto it : registry.uniqueModuleParameters) {
390  for (auto parameters : it.second) {
391  HWModuleOp specializedModule;
392  if (failed(specializeModule(builder, parameters, sc, ns, it.first,
393  specializedModule, registry, nextRegistry,
394  parametersUsers))) {
395  signalPassFailure();
396  return;
397  }
398 
399  // Extend the symbol cache with the newly created module.
400  sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
401 
402  // Add the specialization
403  specializations[it.first][parameters] = specializedModule;
404  }
405  }
406 
407  // Transfer newly registered specializations to iterate over
408  registry.uniqueModuleParameters =
409  std::move(nextRegistry.uniqueModuleParameters);
410  }
411 
412  // Rewrite instances of specialized modules to the specialized module.
413  for (auto it : specializations) {
414  auto unspecialized = it.getFirst();
415  auto &users = parametersUsers[unspecialized];
416  for (auto specialization : it.getSecond()) {
417  auto parameters = specialization.getFirst();
418  auto specializedModule = specialization.getSecond();
419  for (auto instanceOp : users[parameters]) {
420  instanceOp->setAttr("moduleName",
421  FlatSymbolRefAttr::get(specializedModule));
422  instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {}));
423  }
424  }
425  }
426 }
427 
428 } // namespace
429 
430 std::unique_ptr<Pass> circt::hw::createHWSpecializePass() {
431  return std::make_unique<HWSpecializePass>();
432 }
assert(baseType &&"element must be base type")
static void populateTypeConversion(TypeConverter &converter)
Builder builder
Instantiate one of these and use it to build typed backedges.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
void add(SymbolCache &symCache)
SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
Definition: Namespace.h:47
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:63
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
void addDefinition(mlir::Attribute key, mlir::Operation *op) override
In the building phase, add symbols.
Definition: SymCache.h:88
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: SymCache.h:94
The ValueMapper class facilitates the definition and connection of SSA def-use chains between two loc...
Definition: ValueMapper.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
std::unique_ptr< mlir::Pass > createHWSpecializePass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: comb.py:1
Definition: hw.py:1
This holds a decoded list of input/inout and output ports for a module or instance.