CIRCT  20.0.0git
FlattenIO.cpp
Go to the documentation of this file.
1 //===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 
15 namespace circt {
16 namespace hw {
17 #define GEN_PASS_DEF_FLATTENIO
18 #include "circt/Dialect/HW/Passes.h.inc"
19 } // namespace hw
20 } // namespace circt
21 
22 using namespace mlir;
23 using namespace circt;
24 
25 static bool isStructType(Type type) {
26  return isa<hw::StructType>(hw::getCanonicalType(type));
27 }
28 
29 static hw::StructType getStructType(Type type) {
30  return dyn_cast<hw::StructType>(hw::getCanonicalType(type));
31 }
32 
33 // Legal if no in- or output type is a struct.
34 static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) {
35  return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
36  isStructType);
37 }
38 
39 static llvm::SmallVector<Type> getInnerTypes(hw::StructType t) {
40  llvm::SmallVector<Type> inner;
41  t.getInnerTypes(inner);
42  for (auto [index, innerType] : llvm::enumerate(inner))
43  inner[index] = hw::getCanonicalType(innerType);
44  return inner;
45 }
46 
47 namespace {
48 
49 // Replaces an output op with a new output with flattened (exploded) structs.
50 struct OutputOpConversion : public OpConversionPattern<hw::OutputOp> {
51  OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
52  DenseSet<Operation *> *opVisited)
53  : OpConversionPattern(typeConverter, context), opVisited(opVisited) {}
54 
55  LogicalResult
56  matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
57  ConversionPatternRewriter &rewriter) const override {
58  llvm::SmallVector<Value> convOperands;
59 
60  // Flatten the operands.
61  for (auto operand : adaptor.getOperands()) {
62  if (auto structType = getStructType(operand.getType())) {
63  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
64  op.getLoc(), getInnerTypes(structType), operand);
65  llvm::copy(explodedStruct.getResults(),
66  std::back_inserter(convOperands));
67  } else {
68  convOperands.push_back(operand);
69  }
70  }
71 
72  // And replace.
73  opVisited->insert(op->getParentOp());
74  rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
75  return success();
76  }
77  DenseSet<Operation *> *opVisited;
78 };
79 
80 struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
81  InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
82  DenseSet<hw::InstanceOp> *convertedOps,
83  const StringSet<> *externModules)
84  : OpConversionPattern(typeConverter, context), convertedOps(convertedOps),
85  externModules(externModules) {}
86 
87  LogicalResult
88  matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const override {
90  auto referencedMod = op.getReferencedModuleNameAttr();
91  // If externModules is populated and this is an extern module instance,
92  // donot flatten it.
93  if (externModules->contains(referencedMod.getValue()))
94  return success();
95 
96  auto loc = op.getLoc();
97  // Flatten the operands.
98  llvm::SmallVector<Value> convOperands;
99  for (auto operand : adaptor.getOperands()) {
100  if (auto structType = getStructType(operand.getType())) {
101  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
102  loc, getInnerTypes(structType), operand);
103  llvm::copy(explodedStruct.getResults(),
104  std::back_inserter(convOperands));
105  } else {
106  convOperands.push_back(operand);
107  }
108  }
109 
110  // Get the new module return type.
111  llvm::SmallVector<Type> newResultTypes;
112  for (auto oldResultType : op.getResultTypes()) {
113  if (auto structType = getStructType(oldResultType))
114  for (auto t : structType.getElements())
115  newResultTypes.push_back(t.type);
116  else
117  newResultTypes.push_back(oldResultType);
118  }
119 
120  // Create the new instance with the flattened module, attributes will be
121  // adjusted later.
122  auto newInstance = rewriter.create<hw::InstanceOp>(
123  loc, newResultTypes, op.getInstanceNameAttr(),
124  FlatSymbolRefAttr::get(referencedMod), convOperands,
125  op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
126  op.getInnerSymAttr(), op.getDoNotPrintAttr());
127 
128  // re-create any structs in the result.
129  llvm::SmallVector<Value> convResults;
130  size_t oldResultCntr = 0;
131  for (size_t resIndex = 0; resIndex < newInstance.getNumResults();
132  ++resIndex) {
133  Type oldResultType = op.getResultTypes()[oldResultCntr];
134  if (auto structType = getStructType(oldResultType)) {
135  size_t nElements = structType.getElements().size();
136  auto implodedStruct = rewriter.create<hw::StructCreateOp>(
137  loc, structType,
138  newInstance.getResults().slice(resIndex, nElements));
139  convResults.push_back(implodedStruct.getResult());
140  resIndex += nElements - 1;
141  } else
142  convResults.push_back(newInstance.getResult(resIndex));
143 
144  ++oldResultCntr;
145  }
146  rewriter.replaceOp(op, convResults);
147  convertedOps->insert(newInstance);
148  return success();
149  }
150 
151  DenseSet<hw::InstanceOp> *convertedOps;
152  const StringSet<> *externModules;
153 };
154 
155 using IOTypes = std::pair<TypeRange, TypeRange>;
156 
157 struct IOInfo {
158  // A mapping between an arg/res index and the struct type of the given field.
159  DenseMap<unsigned, hw::StructType> argStructs, resStructs;
160 
161  // Records of the original arg/res types.
162  SmallVector<Type> argTypes, resTypes;
163 };
164 
165 class FlattenIOTypeConverter : public TypeConverter {
166 public:
167  FlattenIOTypeConverter() {
168  addConversion([](Type type, SmallVectorImpl<Type> &results) {
169  auto structType = getStructType(type);
170  if (!structType)
171  results.push_back(type);
172  else {
173  for (auto field : structType.getElements())
174  results.push_back(field.type);
175  }
176  return success();
177  });
178 
179  addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
180  ValueRange inputs, Location loc) {
181  auto result = builder.create<hw::StructCreateOp>(loc, type, inputs);
182  return result.getResult();
183  });
184 
185  addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
186  ValueRange inputs, Location loc) {
187  auto result = builder.create<hw::StructCreateOp>(loc, type, inputs);
188  return result.getResult();
189  });
190  }
191 };
192 
193 } // namespace
194 
195 template <typename... TOp>
196 static void addSignatureConversion(DenseMap<Operation *, IOInfo> &ioMap,
197  ConversionTarget &target,
198  RewritePatternSet &patterns,
199  FlattenIOTypeConverter &typeConverter) {
200  (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(),
201  patterns, typeConverter),
202  ...);
203 
204  // Legality is defined by a module having been processed once. This is due to
205  // that a pattern cannot be applied multiple times (a 'pattern was already
206  // applied' error - a case that would occur for nested structs). Additionally,
207  // if a pattern could be applied multiple times, this would complicate
208  // updating arg/res names.
209 
210  // Instead, we define legality as when a module has had a modification to its
211  // top-level i/o. This ensures that only a single level of structs are
212  // processed during signature conversion, which then allows us to use the
213  // signature conversion in a recursive manner.
214  target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
215  if (isLegalModLikeOp(moduleLikeOp))
216  return true;
217 
218  // This op is involved in conversion. Check if the signature has changed.
219  auto ioInfoIt = ioMap.find(moduleLikeOp);
220  if (ioInfoIt == ioMap.end()) {
221  // Op wasn't primed in the map. Do the safe thing, assume
222  // that it's not considered in this pass, and mark it as legal
223  return true;
224  }
225  auto ioInfo = ioInfoIt->second;
226 
227  auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
228  return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) {
229  auto oldType = std::get<0>(typePair);
230  auto newType = std::get<1>(typePair);
231  return oldType != newType;
232  });
233  };
234  auto mtype = moduleLikeOp.getHWModuleType();
235  if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
236  compareTypes(mtype.getInputTypes(), ioInfo.argTypes))
237  return true;
238 
239  // We're pre-conversion for an op that was primed in the map - it will
240  // always be illegal since it has to-be-converted struct types at its I/O.
241  return false;
242  });
243 }
244 
245 template <typename T>
246 static bool hasUnconvertedOps(mlir::ModuleOp module) {
247  return llvm::any_of(module.getBody()->getOps<T>(),
248  [](T op) { return !isLegalModLikeOp(op); });
249 }
250 
251 template <typename T>
252 static DenseMap<Operation *, IOTypes> populateIOMap(mlir::ModuleOp module) {
253  DenseMap<Operation *, IOTypes> ioMap;
254  for (auto op : module.getOps<T>())
255  ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
256  return ioMap;
257 }
258 
259 template <typename ModTy, typename T>
260 static llvm::SmallVector<Attribute>
261 updateNameAttribute(ModTy op, StringRef attrName,
262  DenseMap<unsigned, hw::StructType> &structMap, T oldNames,
263  char joinChar) {
264  llvm::SmallVector<Attribute> newNames;
265  for (auto [i, oldName] : llvm::enumerate(oldNames)) {
266  // Was this arg/res index a struct?
267  auto it = structMap.find(i);
268  if (it == structMap.end()) {
269  // No, keep old name.
270  newNames.push_back(StringAttr::get(op->getContext(), oldName));
271  continue;
272  }
273 
274  // Yes - create new names from the struct fields and the old name at the
275  // index.
276  auto structType = it->second;
277  for (auto field : structType.getElements())
278  newNames.push_back(StringAttr::get(
279  op->getContext(), oldName + Twine(joinChar) + field.name.str()));
280  }
281  return newNames;
282 }
283 
284 template <typename ModTy>
285 static void updateModulePortNames(ModTy op, hw::ModuleType oldModType,
286  char joinChar) {
287  // Module arg and result port names may not be ordered. So we cannot reuse
288  // updateNameAttribute. The arg and result order must be preserved.
289  SmallVector<Attribute> newNames;
290  SmallVector<hw::ModulePort> oldPorts(oldModType.getPorts().begin(),
291  oldModType.getPorts().end());
292  for (auto oldPort : oldPorts) {
293  auto oldName = oldPort.name;
294  if (auto structType = getStructType(oldPort.type)) {
295  for (auto field : structType.getElements()) {
296  newNames.push_back(StringAttr::get(
297  op->getContext(),
298  oldName.getValue() + Twine(joinChar) + field.name.str()));
299  }
300  } else
301  newNames.push_back(oldName);
302  }
303  op.setAllPortNames(newNames);
304 }
305 
306 static llvm::SmallVector<Location>
307 updateLocAttribute(DenseMap<unsigned, hw::StructType> &structMap,
308  SmallVectorImpl<Location> &oldLocs) {
309  llvm::SmallVector<Location> newLocs;
310  for (auto [i, oldLoc] : llvm::enumerate(oldLocs)) {
311  // Was this arg/res index a struct?
312  auto it = structMap.find(i);
313  if (it == structMap.end()) {
314  // No, keep old name.
315  newLocs.push_back(oldLoc);
316  continue;
317  }
318 
319  auto structType = it->second;
320  for (size_t i = 0, e = structType.getElements().size(); i < e; ++i)
321  newLocs.push_back(oldLoc);
322  }
323  return newLocs;
324 }
325 
326 /// The conversion framework seems to throw away block argument locations. We
327 /// use this function to copy the location from the original argument to the
328 /// set of flattened arguments.
329 static void
330 updateBlockLocations(hw::HWModuleLike op,
331  DenseMap<unsigned, hw::StructType> &structMap) {
332  auto locs = op.getInputLocs();
333  if (locs.empty() || op.getModuleBody().empty())
334  return;
335  for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
336  arg.setLoc(loc);
337 }
338 
339 static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
340  ioInfo.argTypes = op.getInputTypes();
341  ioInfo.resTypes = op.getOutputTypes();
342  for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
343  if (auto structType = getStructType(arg))
344  ioInfo.argStructs[i] = structType;
345  }
346  for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
347  if (auto structType = getStructType(res))
348  ioInfo.resStructs[i] = structType;
349  }
350 }
351 
352 template <typename T>
353 static DenseMap<Operation *, IOInfo> populateIOInfoMap(mlir::ModuleOp module) {
354  DenseMap<Operation *, IOInfo> ioInfoMap;
355  for (auto op : module.getOps<T>()) {
356  IOInfo ioInfo;
357  setIOInfo(op, ioInfo);
358  ioInfoMap[op] = ioInfo;
359  }
360  return ioInfoMap;
361 }
362 
363 template <typename T>
364 static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive,
365  StringSet<> &externModules,
366  char joinChar) {
367  auto *ctx = module.getContext();
368  FlattenIOTypeConverter typeConverter;
369 
370  // Recursively (in case of nested structs) lower the module. We do this one
371  // conversion at a time to allow for updating the arg/res names of the
372  // module in between flattening each level of structs.
373  while (hasUnconvertedOps<T>(module)) {
374  ConversionTarget target(*ctx);
375  RewritePatternSet patterns(ctx);
376  target.addLegalDialect<hw::HWDialect>();
377 
378  // Record any struct types at the module signature. This will be used
379  // post-conversion to update the argument and result names.
380  auto ioInfoMap = populateIOInfoMap<T>(module);
381 
382  // Record the instances that were converted. We keep these around since we
383  // need to update their arg/res attribute names after the modules themselves
384  // have been updated.
385  llvm::DenseSet<hw::InstanceOp> convertedInstances;
386 
387  // Argument conversion for output ops. Similarly to the signature
388  // conversion, legality is based on the op having been visited once, due to
389  // the possibility of nested structs.
390  DenseSet<Operation *> opVisited;
391  patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
392 
393  patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
394  &externModules);
395  target.addDynamicallyLegalOp<hw::OutputOp>(
396  [&](auto op) { return opVisited.contains(op->getParentOp()); });
397  target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
398  auto refName = op.getReferencedModuleName();
399  return externModules.contains(refName) ||
400  llvm::none_of(op->getOperands(), [](auto operand) {
401  return isStructType(operand.getType());
402  });
403  });
404 
405  DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames;
406  DenseMap<Operation *, SmallVector<Location>> oldArgLocs, oldResLocs;
407  DenseMap<Operation *, hw::ModuleType> oldModTypes;
408 
409  for (auto op : module.getOps<T>()) {
410  oldModTypes[op] = op.getHWModuleType();
411  oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames());
412  oldResNames[op] =
413  ArrayAttr::get(module.getContext(), op.getOutputNames());
414  oldArgLocs[op] = op.getInputLocs();
415  oldResLocs[op] = op.getOutputLocs();
416  }
417 
418  // Signature conversion and legalization patterns.
419  addSignatureConversion<T>(ioInfoMap, target, patterns, typeConverter);
420 
421  if (failed(applyPartialConversion(module, target, std::move(patterns))))
422  return failure();
423 
424  // Update the arg/res names of the module.
425  for (auto op : module.getOps<T>()) {
426  auto ioInfo = ioInfoMap[op];
427  updateModulePortNames(op, oldModTypes[op], joinChar);
428  auto newArgLocs = updateLocAttribute(ioInfo.argStructs, oldArgLocs[op]);
429  auto newResLocs = updateLocAttribute(ioInfo.resStructs, oldResLocs[op]);
430  newArgLocs.append(newResLocs.begin(), newResLocs.end());
431  op.setAllPortLocs(newArgLocs);
432  updateBlockLocations(op, ioInfo.argStructs);
433  }
434 
435  // And likewise with the converted instance ops.
436  for (auto instanceOp : convertedInstances) {
437  auto targetModule =
438  cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
439  instanceOp, instanceOp.getReferencedModuleNameAttr()));
440 
441  IOInfo ioInfo;
442  if (!ioInfoMap.contains(targetModule)) {
443  // If an extern module, then not yet processed, populate the maps.
444  setIOInfo(targetModule, ioInfo);
445  ioInfoMap[targetModule] = ioInfo;
446  oldArgNames[targetModule] =
447  ArrayAttr::get(module.getContext(), targetModule.getInputNames());
448  oldResNames[targetModule] =
449  ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
450  oldArgLocs[targetModule] = targetModule.getInputLocs();
451  oldResLocs[targetModule] = targetModule.getOutputLocs();
452  } else
453  ioInfo = ioInfoMap[targetModule];
454 
455  instanceOp.setInputNames(ArrayAttr::get(
456  instanceOp.getContext(),
458  instanceOp, "argNames", ioInfo.argStructs,
459  oldArgNames[targetModule].template getAsValueRange<StringAttr>(),
460  joinChar)));
461  instanceOp.setOutputNames(ArrayAttr::get(
462  instanceOp.getContext(),
464  instanceOp, "resultNames", ioInfo.resStructs,
465  oldResNames[targetModule].template getAsValueRange<StringAttr>(),
466  joinChar)));
467  }
468 
469  // Break if we've only lowering a single level of structs.
470  if (!recursive)
471  break;
472  }
473  return success();
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // Pass driver
478 //===----------------------------------------------------------------------===//
479 
480 template <typename... TOps>
481 static bool flattenIO(ModuleOp module, bool recursive,
482  StringSet<> &externModules, char joinChar) {
483  return (failed(flattenOpsOfType<TOps>(module, recursive, externModules,
484  joinChar)) ||
485  ...);
486 }
487 
488 namespace {
489 
490 class FlattenIOPass : public circt::hw::impl::FlattenIOBase<FlattenIOPass> {
491 public:
492  FlattenIOPass(bool recursiveFlag, bool flattenExternFlag, char join) {
493  recursive = recursiveFlag;
494  flattenExtern = flattenExternFlag;
495  joinChar = join;
496  }
497 
498  void runOnOperation() override {
499  ModuleOp module = getOperation();
500  if (!flattenExtern) {
501  // Record the extern modules, donot flatten them.
502  for (auto m : module.getOps<hw::HWModuleExternOp>())
503  externModules.insert(m.getModuleName());
504  if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(
505  module, recursive, externModules, joinChar))
506  signalPassFailure();
507  return;
508  }
509 
511  hw::HWModuleGeneratedOp>(module, recursive, externModules,
512  joinChar))
513  signalPassFailure();
514  };
515 
516 private:
517  StringSet<> externModules;
518 };
519 } // namespace
520 
521 //===----------------------------------------------------------------------===//
522 // Pass initialization
523 //===----------------------------------------------------------------------===//
524 
525 std::unique_ptr<Pass> circt::hw::createFlattenIOPass(bool recursiveFlag,
526  bool flattenExternFlag,
527  char joinChar) {
528  return std::make_unique<FlattenIOPass>(recursiveFlag, flattenExternFlag,
529  joinChar);
530 }
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
Definition: FlattenIO.cpp:39
static llvm::SmallVector< Location > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, SmallVectorImpl< Location > &oldLocs)
Definition: FlattenIO.cpp:307
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
Definition: FlattenIO.cpp:34
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
Definition: FlattenIO.cpp:330
static bool hasUnconvertedOps(mlir::ModuleOp module)
Definition: FlattenIO.cpp:246
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames, char joinChar)
Definition: FlattenIO.cpp:261
static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo)
Definition: FlattenIO.cpp:339
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:364
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:252
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
Definition: FlattenIO.cpp:196
static bool isStructType(Type type)
Definition: FlattenIO.cpp:25
static void updateModulePortNames(ModTy op, hw::ModuleType oldModType, char joinChar)
Definition: FlattenIO.cpp:285
static hw::StructType getStructType(Type type)
Definition: FlattenIO.cpp:29
static bool flattenIO(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:481
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:353
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:49
std::unique_ptr< mlir::Pass > createFlattenIOPass(bool recursiveFlag=true, bool flattenExternFlag=false, char joinChar='.')
Definition: FlattenIO.cpp:525
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1