CIRCT  19.0.0git
FlattenIO.cpp
Go to the documentation of this file.
1 //===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 
15 namespace circt {
16 namespace hw {
17 #define GEN_PASS_DEF_FLATTENIO
18 #include "circt/Dialect/HW/Passes.h.inc"
19 } // namespace hw
20 } // namespace circt
21 
22 using namespace mlir;
23 using namespace circt;
24 
25 static bool isStructType(Type type) {
26  return isa<hw::StructType>(hw::getCanonicalType(type));
27 }
28 
29 static hw::StructType getStructType(Type type) {
30  return dyn_cast<hw::StructType>(hw::getCanonicalType(type));
31 }
32 
33 // Legal if no in- or output type is a struct.
34 static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) {
35  return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
36  isStructType);
37 }
38 
39 static llvm::SmallVector<Type> getInnerTypes(hw::StructType t) {
40  llvm::SmallVector<Type> inner;
41  t.getInnerTypes(inner);
42  for (auto [index, innerType] : llvm::enumerate(inner))
43  inner[index] = hw::getCanonicalType(innerType);
44  return inner;
45 }
46 
47 namespace {
48 
49 // Replaces an output op with a new output with flattened (exploded) structs.
50 struct OutputOpConversion : public OpConversionPattern<hw::OutputOp> {
51  OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
52  DenseSet<Operation *> *opVisited)
53  : OpConversionPattern(typeConverter, context), opVisited(opVisited) {}
54 
55  LogicalResult
56  matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
57  ConversionPatternRewriter &rewriter) const override {
58  llvm::SmallVector<Value> convOperands;
59 
60  // Flatten the operands.
61  for (auto operand : adaptor.getOperands()) {
62  if (auto structType = getStructType(operand.getType())) {
63  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
64  op.getLoc(), getInnerTypes(structType), operand);
65  llvm::copy(explodedStruct.getResults(),
66  std::back_inserter(convOperands));
67  } else {
68  convOperands.push_back(operand);
69  }
70  }
71 
72  // And replace.
73  opVisited->insert(op->getParentOp());
74  rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
75  return success();
76  }
77  DenseSet<Operation *> *opVisited;
78 };
79 
80 struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
81  InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
82  DenseSet<hw::InstanceOp> *convertedOps,
83  const StringSet<> *externModules)
84  : OpConversionPattern(typeConverter, context), convertedOps(convertedOps),
85  externModules(externModules) {}
86 
87  LogicalResult
88  matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  auto referencedMod = op.getReferencedModuleNameAttr();
91  // If externModules is populated and this is an extern module instance,
92  // donot flatten it.
93  if (externModules->contains(referencedMod.getValue()))
94  return success();
95 
96  auto loc = op.getLoc();
97  // Flatten the operands.
98  llvm::SmallVector<Value> convOperands;
99  for (auto operand : adaptor.getOperands()) {
100  if (auto structType = getStructType(operand.getType())) {
101  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
102  loc, getInnerTypes(structType), operand);
103  llvm::copy(explodedStruct.getResults(),
104  std::back_inserter(convOperands));
105  } else {
106  convOperands.push_back(operand);
107  }
108  }
109 
110  // Get the new module return type.
111  llvm::SmallVector<Type> newResultTypes;
112  for (auto oldResultType : op.getResultTypes()) {
113  if (auto structType = getStructType(oldResultType))
114  for (auto t : structType.getElements())
115  newResultTypes.push_back(t.type);
116  else
117  newResultTypes.push_back(oldResultType);
118  }
119 
120  // Create the new instance with the flattened module, attributes will be
121  // adjusted later.
122  auto newInstance = rewriter.create<hw::InstanceOp>(
123  loc, newResultTypes, op.getInstanceNameAttr(),
124  FlatSymbolRefAttr::get(referencedMod), convOperands,
125  op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
126  op.getInnerSymAttr());
127 
128  // re-create any structs in the result.
129  llvm::SmallVector<Value> convResults;
130  size_t oldResultCntr = 0;
131  for (size_t resIndex = 0; resIndex < newInstance.getNumResults();
132  ++resIndex) {
133  Type oldResultType = op.getResultTypes()[oldResultCntr];
134  if (auto structType = getStructType(oldResultType)) {
135  size_t nElements = structType.getElements().size();
136  auto implodedStruct = rewriter.create<hw::StructCreateOp>(
137  loc, structType,
138  newInstance.getResults().slice(resIndex, nElements));
139  convResults.push_back(implodedStruct.getResult());
140  resIndex += nElements - 1;
141  } else
142  convResults.push_back(newInstance.getResult(resIndex));
143 
144  ++oldResultCntr;
145  }
146  rewriter.replaceOp(op, convResults);
147  convertedOps->insert(newInstance);
148  return success();
149  }
150 
151  DenseSet<hw::InstanceOp> *convertedOps;
152  const StringSet<> *externModules;
153 };
154 
// A pair of type ranges; populateIOMap stores {argument types, result types}
// of a module in it.
using IOTypes = std::pair<TypeRange, TypeRange>;
156 
157 struct IOInfo {
158  // A mapping between an arg/res index and the struct type of the given field.
159  DenseMap<unsigned, hw::StructType> argStructs, resStructs;
160 
161  // Records of the original arg/res types.
162  SmallVector<Type> argTypes, resTypes;
163 };
164 
165 class FlattenIOTypeConverter : public TypeConverter {
166 public:
167  FlattenIOTypeConverter() {
168  addConversion([](Type type, SmallVectorImpl<Type> &results) {
169  auto structType = getStructType(type);
170  if (!structType)
171  results.push_back(type);
172  else {
173  for (auto field : structType.getElements())
174 
175  results.push_back(field.type);
176  }
177  return success();
178  });
179 
180  addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
181  ValueRange inputs, Location loc) {
182  auto result = builder.create<hw::StructCreateOp>(loc, type, inputs);
183  return result.getResult();
184  });
185 
186  addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
187  ValueRange inputs, Location loc) {
188  auto structType = getStructType(type);
189  assert(structType && "expected struct type");
190  auto result = builder.create<hw::StructCreateOp>(loc, structType, inputs);
191  return result.getResult();
192  });
193  }
194 };
195 
196 } // namespace
197 
198 template <typename... TOp>
199 static void addSignatureConversion(DenseMap<Operation *, IOInfo> &ioMap,
200  ConversionTarget &target,
201  RewritePatternSet &patterns,
202  FlattenIOTypeConverter &typeConverter) {
203  (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(),
204  patterns, typeConverter),
205  ...);
206 
207  // Legality is defined by a module having been processed once. This is due to
208  // that a pattern cannot be applied multiple times (a 'pattern was already
209  // applied' error - a case that would occur for nested structs). Additionally,
210  // if a pattern could be applied multiple times, this would complicate
211  // updating arg/res names.
212 
213  // Instead, we define legality as when a module has had a modification to its
214  // top-level i/o. This ensures that only a single level of structs are
215  // processed during signature conversion, which then allows us to use the
216  // signature conversion in a recursive manner.
217  target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
218  if (isLegalModLikeOp(moduleLikeOp))
219  return true;
220 
221  // This op is involved in conversion. Check if the signature has changed.
222  auto ioInfoIt = ioMap.find(moduleLikeOp);
223  if (ioInfoIt == ioMap.end()) {
224  // Op wasn't primed in the map. Do the safe thing, assume
225  // that it's not considered in this pass, and mark it as legal
226  return true;
227  }
228  auto ioInfo = ioInfoIt->second;
229 
230  auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
231  return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) {
232  auto oldType = std::get<0>(typePair);
233  auto newType = std::get<1>(typePair);
234  return oldType != newType;
235  });
236  };
237  auto mtype = moduleLikeOp.getHWModuleType();
238  if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
239  compareTypes(mtype.getInputTypes(), ioInfo.argTypes))
240  return true;
241 
242  // We're pre-conversion for an op that was primed in the map - it will
243  // always be illegal since it has to-be-converted struct types at its I/O.
244  return false;
245  });
246 }
247 
248 template <typename T>
249 static bool hasUnconvertedOps(mlir::ModuleOp module) {
250  return llvm::any_of(module.getBody()->getOps<T>(),
251  [](T op) { return !isLegalModLikeOp(op); });
252 }
253 
254 template <typename T>
255 static DenseMap<Operation *, IOTypes> populateIOMap(mlir::ModuleOp module) {
256  DenseMap<Operation *, IOTypes> ioMap;
257  for (auto op : module.getOps<T>())
258  ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
259  return ioMap;
260 }
261 
262 template <typename ModTy, typename T>
263 static llvm::SmallVector<Attribute>
264 updateNameAttribute(ModTy op, StringRef attrName,
265  DenseMap<unsigned, hw::StructType> &structMap, T oldNames,
266  char joinChar) {
267  llvm::SmallVector<Attribute> newNames;
268  for (auto [i, oldName] : llvm::enumerate(oldNames)) {
269  // Was this arg/res index a struct?
270  auto it = structMap.find(i);
271  if (it == structMap.end()) {
272  // No, keep old name.
273  newNames.push_back(StringAttr::get(op->getContext(), oldName));
274  continue;
275  }
276 
277  // Yes - create new names from the struct fields and the old name at the
278  // index.
279  auto structType = it->second;
280  for (auto field : structType.getElements())
281  newNames.push_back(StringAttr::get(
282  op->getContext(), oldName + Twine(joinChar) + field.name.str()));
283  }
284  return newNames;
285 }
286 
287 template <typename ModTy>
288 static void updateModulePortNames(ModTy op, hw::ModuleType oldModType,
289  char joinChar) {
290  // Module arg and result port names may not be ordered. So we cannot reuse
291  // updateNameAttribute. The arg and result order must be preserved.
292  SmallVector<Attribute> newNames;
293  SmallVector<hw::ModulePort> oldPorts(oldModType.getPorts().begin(),
294  oldModType.getPorts().end());
295  for (auto oldPort : oldPorts) {
296  auto oldName = oldPort.name;
297  if (auto structType = getStructType(oldPort.type)) {
298  for (auto field : structType.getElements()) {
299  newNames.push_back(StringAttr::get(
300  op->getContext(),
301  oldName.getValue() + Twine(joinChar) + field.name.str()));
302  }
303  } else
304  newNames.push_back(oldName);
305  }
306  op.setAllPortNames(newNames);
307 }
308 
309 static llvm::SmallVector<Location>
310 updateLocAttribute(DenseMap<unsigned, hw::StructType> &structMap,
311  SmallVectorImpl<Location> &oldLocs) {
312  llvm::SmallVector<Location> newLocs;
313  for (auto [i, oldLoc] : llvm::enumerate(oldLocs)) {
314  // Was this arg/res index a struct?
315  auto it = structMap.find(i);
316  if (it == structMap.end()) {
317  // No, keep old name.
318  newLocs.push_back(oldLoc);
319  continue;
320  }
321 
322  auto structType = it->second;
323  for (size_t i = 0, e = structType.getElements().size(); i < e; ++i)
324  newLocs.push_back(oldLoc);
325  }
326  return newLocs;
327 }
328 
329 /// The conversion framework seems to throw away block argument locations. We
330 /// use this function to copy the location from the original argument to the
331 /// set of flattened arguments.
332 static void
333 updateBlockLocations(hw::HWModuleLike op,
334  DenseMap<unsigned, hw::StructType> &structMap) {
335  auto locs = op.getInputLocs();
336  if (locs.empty() || op.getModuleBody().empty())
337  return;
338  for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
339  arg.setLoc(loc);
340 }
341 
342 static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
343  ioInfo.argTypes = op.getInputTypes();
344  ioInfo.resTypes = op.getOutputTypes();
345  for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
346  if (auto structType = getStructType(arg))
347  ioInfo.argStructs[i] = structType;
348  }
349  for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
350  if (auto structType = getStructType(res))
351  ioInfo.resStructs[i] = structType;
352  }
353 }
354 
355 template <typename T>
356 static DenseMap<Operation *, IOInfo> populateIOInfoMap(mlir::ModuleOp module) {
357  DenseMap<Operation *, IOInfo> ioInfoMap;
358  for (auto op : module.getOps<T>()) {
359  IOInfo ioInfo;
360  setIOInfo(op, ioInfo);
361  ioInfoMap[op] = ioInfo;
362  }
363  return ioInfoMap;
364 }
365 
366 template <typename T>
367 static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive,
368  StringSet<> &externModules,
369  char joinChar) {
370  auto *ctx = module.getContext();
371  FlattenIOTypeConverter typeConverter;
372 
373  // Recursively (in case of nested structs) lower the module. We do this one
374  // conversion at a time to allow for updating the arg/res names of the
375  // module in between flattening each level of structs.
376  while (hasUnconvertedOps<T>(module)) {
377  ConversionTarget target(*ctx);
378  RewritePatternSet patterns(ctx);
379  target.addLegalDialect<hw::HWDialect>();
380 
381  // Record any struct types at the module signature. This will be used
382  // post-conversion to update the argument and result names.
383  auto ioInfoMap = populateIOInfoMap<T>(module);
384 
385  // Record the instances that were converted. We keep these around since we
386  // need to update their arg/res attribute names after the modules themselves
387  // have been updated.
388  llvm::DenseSet<hw::InstanceOp> convertedInstances;
389 
390  // Argument conversion for output ops. Similarly to the signature
391  // conversion, legality is based on the op having been visited once, due to
392  // the possibility of nested structs.
393  DenseSet<Operation *> opVisited;
394  patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
395 
396  patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
397  &externModules);
398  target.addDynamicallyLegalOp<hw::OutputOp>(
399  [&](auto op) { return opVisited.contains(op->getParentOp()); });
400  target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
401  auto refName = op.getReferencedModuleName();
402  return externModules.contains(refName) ||
403  llvm::none_of(op->getOperands(), [](auto operand) {
404  return isStructType(operand.getType());
405  });
406  });
407 
408  DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames;
409  DenseMap<Operation *, SmallVector<Location>> oldArgLocs, oldResLocs;
410  DenseMap<Operation *, hw::ModuleType> oldModTypes;
411 
412  for (auto op : module.getOps<T>()) {
413  oldModTypes[op] = op.getHWModuleType();
414  oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames());
415  oldResNames[op] =
416  ArrayAttr::get(module.getContext(), op.getOutputNames());
417  oldArgLocs[op] = op.getInputLocs();
418  oldResLocs[op] = op.getOutputLocs();
419  }
420 
421  // Signature conversion and legalization patterns.
422  addSignatureConversion<T>(ioInfoMap, target, patterns, typeConverter);
423 
424  if (failed(applyPartialConversion(module, target, std::move(patterns))))
425  return failure();
426 
427  // Update the arg/res names of the module.
428  for (auto op : module.getOps<T>()) {
429  auto ioInfo = ioInfoMap[op];
430  updateModulePortNames(op, oldModTypes[op], joinChar);
431  auto newArgLocs = updateLocAttribute(ioInfo.argStructs, oldArgLocs[op]);
432  auto newResLocs = updateLocAttribute(ioInfo.resStructs, oldResLocs[op]);
433  newArgLocs.append(newResLocs.begin(), newResLocs.end());
434  op.setAllPortLocs(newArgLocs);
435  updateBlockLocations(op, ioInfo.argStructs);
436  }
437 
438  // And likewise with the converted instance ops.
439  for (auto instanceOp : convertedInstances) {
440  auto targetModule =
441  cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
442  instanceOp, instanceOp.getReferencedModuleNameAttr()));
443 
444  IOInfo ioInfo;
445  if (!ioInfoMap.contains(targetModule)) {
446  // If an extern module, then not yet processed, populate the maps.
447  setIOInfo(targetModule, ioInfo);
448  ioInfoMap[targetModule] = ioInfo;
449  oldArgNames[targetModule] =
450  ArrayAttr::get(module.getContext(), targetModule.getInputNames());
451  oldResNames[targetModule] =
452  ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
453  oldArgLocs[targetModule] = targetModule.getInputLocs();
454  oldResLocs[targetModule] = targetModule.getOutputLocs();
455  } else
456  ioInfo = ioInfoMap[targetModule];
457 
458  instanceOp.setInputNames(ArrayAttr::get(
459  instanceOp.getContext(),
461  instanceOp, "argNames", ioInfo.argStructs,
462  oldArgNames[targetModule].template getAsValueRange<StringAttr>(),
463  joinChar)));
464  instanceOp.setOutputNames(ArrayAttr::get(
465  instanceOp.getContext(),
467  instanceOp, "resultNames", ioInfo.resStructs,
468  oldResNames[targetModule].template getAsValueRange<StringAttr>(),
469  joinChar)));
470  }
471 
472  // Break if we've only lowering a single level of structs.
473  if (!recursive)
474  break;
475  }
476  return success();
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // Pass driver
481 //===----------------------------------------------------------------------===//
482 
483 template <typename... TOps>
484 static bool flattenIO(ModuleOp module, bool recursive,
485  StringSet<> &externModules, char joinChar) {
486  return (failed(flattenOpsOfType<TOps>(module, recursive, externModules,
487  joinChar)) ||
488  ...);
489 }
490 
491 namespace {
492 
493 class FlattenIOPass : public circt::hw::impl::FlattenIOBase<FlattenIOPass> {
494 public:
495  FlattenIOPass(bool recursiveFlag, bool flattenExternFlag, char join) {
496  recursive = recursiveFlag;
497  flattenExtern = flattenExternFlag;
498  joinChar = join;
499  }
500 
501  void runOnOperation() override {
502  ModuleOp module = getOperation();
503  if (!flattenExtern) {
504  // Record the extern modules, donot flatten them.
505  for (auto m : module.getOps<hw::HWModuleExternOp>())
506  externModules.insert(m.getModuleName());
507  if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(
508  module, recursive, externModules, joinChar))
509  signalPassFailure();
510  return;
511  }
512 
514  hw::HWModuleGeneratedOp>(module, recursive, externModules,
515  joinChar))
516  signalPassFailure();
517  };
518 
519 private:
520  StringSet<> externModules;
521 };
522 } // namespace
523 
524 //===----------------------------------------------------------------------===//
525 // Pass initialization
526 //===----------------------------------------------------------------------===//
527 
528 std::unique_ptr<Pass> circt::hw::createFlattenIOPass(bool recursiveFlag,
529  bool flattenExternFlag,
530  char joinChar) {
531  return std::make_unique<FlattenIOPass>(recursiveFlag, flattenExternFlag,
532  joinChar);
533 }
assert(baseType &&"element must be base type")
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
Definition: FlattenIO.cpp:39
static llvm::SmallVector< Location > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, SmallVectorImpl< Location > &oldLocs)
Definition: FlattenIO.cpp:310
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
Definition: FlattenIO.cpp:34
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
Definition: FlattenIO.cpp:333
static bool hasUnconvertedOps(mlir::ModuleOp module)
Definition: FlattenIO.cpp:249
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames, char joinChar)
Definition: FlattenIO.cpp:264
static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo)
Definition: FlattenIO.cpp:342
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:367
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:255
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
Definition: FlattenIO.cpp:199
static bool isStructType(Type type)
Definition: FlattenIO.cpp:25
static void updateModulePortNames(ModTy op, hw::ModuleType oldModType, char joinChar)
Definition: FlattenIO.cpp:288
static hw::StructType getStructType(Type type)
Definition: FlattenIO.cpp:29
static bool flattenIO(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:484
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:356
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:48
std::unique_ptr< mlir::Pass > createFlattenIOPass(bool recursiveFlag=true, bool flattenExternFlag=false, char joinChar='.')
Definition: FlattenIO.cpp:528
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1