CIRCT  19.0.0git
FlattenIO.cpp
Go to the documentation of this file.
1 //===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 
15 using namespace mlir;
16 using namespace circt;
17 
18 static bool isStructType(Type type) {
19  return isa<hw::StructType>(hw::getCanonicalType(type));
20 }
21 
22 static hw::StructType getStructType(Type type) {
23  return dyn_cast<hw::StructType>(hw::getCanonicalType(type));
24 }
25 
26 // Legal if no in- or output type is a struct.
27 static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) {
28  return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
29  isStructType);
30 }
31 
32 static llvm::SmallVector<Type> getInnerTypes(hw::StructType t) {
33  llvm::SmallVector<Type> inner;
34  t.getInnerTypes(inner);
35  for (auto [index, innerType] : llvm::enumerate(inner))
36  inner[index] = hw::getCanonicalType(innerType);
37  return inner;
38 }
39 
40 namespace {
41 
42 // Replaces an output op with a new output with flattened (exploded) structs.
43 struct OutputOpConversion : public OpConversionPattern<hw::OutputOp> {
44  OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
45  DenseSet<Operation *> *opVisited)
46  : OpConversionPattern(typeConverter, context), opVisited(opVisited) {}
47 
48  LogicalResult
49  matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  llvm::SmallVector<Value> convOperands;
52 
53  // Flatten the operands.
54  for (auto operand : adaptor.getOperands()) {
55  if (auto structType = getStructType(operand.getType())) {
56  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
57  op.getLoc(), getInnerTypes(structType), operand);
58  llvm::copy(explodedStruct.getResults(),
59  std::back_inserter(convOperands));
60  } else {
61  convOperands.push_back(operand);
62  }
63  }
64 
65  // And replace.
66  rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
67  opVisited->insert(op->getParentOp());
68  return success();
69  }
70  DenseSet<Operation *> *opVisited;
71 };
72 
73 struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
74  InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
75  DenseSet<hw::InstanceOp> *convertedOps,
76  const StringSet<> *externModules)
77  : OpConversionPattern(typeConverter, context), convertedOps(convertedOps),
78  externModules(externModules) {}
79 
80  LogicalResult
81  matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
82  ConversionPatternRewriter &rewriter) const override {
83  auto referencedMod = op.getReferencedModuleNameAttr();
84  // If externModules is populated and this is an extern module instance,
85  // donot flatten it.
86  if (externModules->contains(referencedMod.getValue()))
87  return success();
88 
89  auto loc = op.getLoc();
90  // Flatten the operands.
91  llvm::SmallVector<Value> convOperands;
92  for (auto operand : adaptor.getOperands()) {
93  if (auto structType = getStructType(operand.getType())) {
94  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
95  loc, getInnerTypes(structType), operand);
96  llvm::copy(explodedStruct.getResults(),
97  std::back_inserter(convOperands));
98  } else {
99  convOperands.push_back(operand);
100  }
101  }
102 
103  // Get the new module return type.
104  llvm::SmallVector<Type> newResultTypes;
105  for (auto oldResultType : op.getResultTypes()) {
106  if (auto structType = getStructType(oldResultType))
107  for (auto t : structType.getElements())
108  newResultTypes.push_back(t.type);
109  else
110  newResultTypes.push_back(oldResultType);
111  }
112 
113  // Create the new instance with the flattened module, attributes will be
114  // adjusted later.
115  auto newInstance = rewriter.create<hw::InstanceOp>(
116  loc, newResultTypes, op.getInstanceNameAttr(),
117  FlatSymbolRefAttr::get(referencedMod), convOperands,
118  op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
119  op.getInnerSymAttr());
120 
121  // re-create any structs in the result.
122  llvm::SmallVector<Value> convResults;
123  size_t oldResultCntr = 0;
124  for (size_t resIndex = 0; resIndex < newInstance.getNumResults();
125  ++resIndex) {
126  Type oldResultType = op.getResultTypes()[oldResultCntr];
127  if (auto structType = getStructType(oldResultType)) {
128  size_t nElements = structType.getElements().size();
129  auto implodedStruct = rewriter.create<hw::StructCreateOp>(
130  loc, structType,
131  newInstance.getResults().slice(resIndex, nElements));
132  convResults.push_back(implodedStruct.getResult());
133  resIndex += nElements - 1;
134  } else
135  convResults.push_back(newInstance.getResult(resIndex));
136 
137  ++oldResultCntr;
138  }
139  rewriter.replaceOp(op, convResults);
140  convertedOps->insert(newInstance);
141  return success();
142  }
143 
144  DenseSet<hw::InstanceOp> *convertedOps;
145  const StringSet<> *externModules;
146 };
147 
148 using IOTypes = std::pair<TypeRange, TypeRange>;
149 
150 struct IOInfo {
151  // A mapping between an arg/res index and the struct type of the given field.
152  DenseMap<unsigned, hw::StructType> argStructs, resStructs;
153 
154  // Records of the original arg/res types.
155  SmallVector<Type> argTypes, resTypes;
156 };
157 
158 class FlattenIOTypeConverter : public TypeConverter {
159 public:
160  FlattenIOTypeConverter() {
161  addConversion([](Type type, SmallVectorImpl<Type> &results) {
162  auto structType = getStructType(type);
163  if (!structType)
164  results.push_back(type);
165  else {
166  for (auto field : structType.getElements())
167 
168  results.push_back(field.type);
169  }
170  return success();
171  });
172 
173  addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
174  ValueRange inputs, Location loc) {
175  auto result = builder.create<hw::StructCreateOp>(loc, type, inputs);
176  return result.getResult();
177  });
178 
179  addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
180  ValueRange inputs, Location loc) {
181  auto structType = getStructType(type);
182  assert(structType && "expected struct type");
183  auto result = builder.create<hw::StructCreateOp>(loc, structType, inputs);
184  return result.getResult();
185  });
186  }
187 };
188 
189 } // namespace
190 
191 template <typename... TOp>
192 static void addSignatureConversion(DenseMap<Operation *, IOInfo> &ioMap,
193  ConversionTarget &target,
194  RewritePatternSet &patterns,
195  FlattenIOTypeConverter &typeConverter) {
196  (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(),
197  patterns, typeConverter),
198  ...);
199 
200  // Legality is defined by a module having been processed once. This is due to
201  // that a pattern cannot be applied multiple times (a 'pattern was already
202  // applied' error - a case that would occur for nested structs). Additionally,
203  // if a pattern could be applied multiple times, this would complicate
204  // updating arg/res names.
205 
206  // Instead, we define legality as when a module has had a modification to its
207  // top-level i/o. This ensures that only a single level of structs are
208  // processed during signature conversion, which then allows us to use the
209  // signature conversion in a recursive manner.
210  target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
211  if (isLegalModLikeOp(moduleLikeOp))
212  return true;
213 
214  // This op is involved in conversion. Check if the signature has changed.
215  auto ioInfoIt = ioMap.find(moduleLikeOp);
216  if (ioInfoIt == ioMap.end()) {
217  // Op wasn't primed in the map. Do the safe thing, assume
218  // that it's not considered in this pass, and mark it as legal
219  return true;
220  }
221  auto ioInfo = ioInfoIt->second;
222 
223  auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
224  return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) {
225  auto oldType = std::get<0>(typePair);
226  auto newType = std::get<1>(typePair);
227  return oldType != newType;
228  });
229  };
230  auto mtype = moduleLikeOp.getHWModuleType();
231  if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
232  compareTypes(mtype.getInputTypes(), ioInfo.argTypes))
233  return true;
234 
235  // We're pre-conversion for an op that was primed in the map - it will
236  // always be illegal since it has to-be-converted struct types at its I/O.
237  return false;
238  });
239 }
240 
241 template <typename T>
242 static bool hasUnconvertedOps(mlir::ModuleOp module) {
243  return llvm::any_of(module.getBody()->getOps<T>(),
244  [](T op) { return !isLegalModLikeOp(op); });
245 }
246 
247 template <typename T>
248 static DenseMap<Operation *, IOTypes> populateIOMap(mlir::ModuleOp module) {
249  DenseMap<Operation *, IOTypes> ioMap;
250  for (auto op : module.getOps<T>())
251  ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
252  return ioMap;
253 }
254 
255 template <typename ModTy, typename T>
256 static llvm::SmallVector<Attribute>
257 updateNameAttribute(ModTy op, StringRef attrName,
258  DenseMap<unsigned, hw::StructType> &structMap, T oldNames,
259  char joinChar) {
260  llvm::SmallVector<Attribute> newNames;
261  for (auto [i, oldName] : llvm::enumerate(oldNames)) {
262  // Was this arg/res index a struct?
263  auto it = structMap.find(i);
264  if (it == structMap.end()) {
265  // No, keep old name.
266  newNames.push_back(StringAttr::get(op->getContext(), oldName));
267  continue;
268  }
269 
270  // Yes - create new names from the struct fields and the old name at the
271  // index.
272  auto structType = it->second;
273  for (auto field : structType.getElements())
274  newNames.push_back(StringAttr::get(
275  op->getContext(), oldName + Twine(joinChar) + field.name.str()));
276  }
277  return newNames;
278 }
279 
280 template <typename ModTy>
281 static void updateModulePortNames(ModTy op, hw::ModuleType oldModType,
282  char joinChar) {
283  // Module arg and result port names may not be ordered. So we cannot reuse
284  // updateNameAttribute. The arg and result order must be preserved.
285  SmallVector<Attribute> newNames;
286  SmallVector<hw::ModulePort> oldPorts(oldModType.getPorts().begin(),
287  oldModType.getPorts().end());
288  for (auto oldPort : oldPorts) {
289  auto oldName = oldPort.name;
290  if (auto structType = getStructType(oldPort.type)) {
291  for (auto field : structType.getElements()) {
292  newNames.push_back(StringAttr::get(
293  op->getContext(),
294  oldName.getValue() + Twine(joinChar) + field.name.str()));
295  }
296  } else
297  newNames.push_back(oldName);
298  }
299  op.setAllPortNames(newNames);
300 }
301 
302 static llvm::SmallVector<Location>
303 updateLocAttribute(DenseMap<unsigned, hw::StructType> &structMap,
304  SmallVectorImpl<Location> &oldLocs) {
305  llvm::SmallVector<Location> newLocs;
306  for (auto [i, oldLoc] : llvm::enumerate(oldLocs)) {
307  // Was this arg/res index a struct?
308  auto it = structMap.find(i);
309  if (it == structMap.end()) {
310  // No, keep old name.
311  newLocs.push_back(oldLoc);
312  continue;
313  }
314 
315  auto structType = it->second;
316  for (size_t i = 0, e = structType.getElements().size(); i < e; ++i)
317  newLocs.push_back(oldLoc);
318  }
319  return newLocs;
320 }
321 
322 /// The conversion framework seems to throw away block argument locations. We
323 /// use this function to copy the location from the original argument to the
324 /// set of flattened arguments.
325 static void
326 updateBlockLocations(hw::HWModuleLike op,
327  DenseMap<unsigned, hw::StructType> &structMap) {
328  auto locs = op.getInputLocs();
329  if (locs.empty() || op.getModuleBody().empty())
330  return;
331  for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
332  arg.setLoc(loc);
333 }
334 
335 static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
336  ioInfo.argTypes = op.getInputTypes();
337  ioInfo.resTypes = op.getOutputTypes();
338  for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
339  if (auto structType = getStructType(arg))
340  ioInfo.argStructs[i] = structType;
341  }
342  for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
343  if (auto structType = getStructType(res))
344  ioInfo.resStructs[i] = structType;
345  }
346 }
347 
348 template <typename T>
349 static DenseMap<Operation *, IOInfo> populateIOInfoMap(mlir::ModuleOp module) {
350  DenseMap<Operation *, IOInfo> ioInfoMap;
351  for (auto op : module.getOps<T>()) {
352  IOInfo ioInfo;
353  setIOInfo(op, ioInfo);
354  ioInfoMap[op] = ioInfo;
355  }
356  return ioInfoMap;
357 }
358 
359 template <typename T>
360 static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive,
361  StringSet<> &externModules,
362  char joinChar) {
363  auto *ctx = module.getContext();
364  FlattenIOTypeConverter typeConverter;
365 
366  // Recursively (in case of nested structs) lower the module. We do this one
367  // conversion at a time to allow for updating the arg/res names of the
368  // module in between flattening each level of structs.
369  while (hasUnconvertedOps<T>(module)) {
370  ConversionTarget target(*ctx);
371  RewritePatternSet patterns(ctx);
372  target.addLegalDialect<hw::HWDialect>();
373 
374  // Record any struct types at the module signature. This will be used
375  // post-conversion to update the argument and result names.
376  auto ioInfoMap = populateIOInfoMap<T>(module);
377 
378  // Record the instances that were converted. We keep these around since we
379  // need to update their arg/res attribute names after the modules themselves
380  // have been updated.
381  llvm::DenseSet<hw::InstanceOp> convertedInstances;
382 
383  // Argument conversion for output ops. Similarly to the signature
384  // conversion, legality is based on the op having been visited once, due to
385  // the possibility of nested structs.
386  DenseSet<Operation *> opVisited;
387  patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
388 
389  patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
390  &externModules);
391  target.addDynamicallyLegalOp<hw::OutputOp>(
392  [&](auto op) { return opVisited.contains(op->getParentOp()); });
393  target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
394  auto refName = op.getReferencedModuleName();
395  return externModules.contains(refName) ||
396  llvm::none_of(op->getOperands(), [](auto operand) {
397  return isStructType(operand.getType());
398  });
399  });
400 
401  DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames;
402  DenseMap<Operation *, SmallVector<Location>> oldArgLocs, oldResLocs;
403  DenseMap<Operation *, hw::ModuleType> oldModTypes;
404 
405  for (auto op : module.getOps<T>()) {
406  oldModTypes[op] = op.getHWModuleType();
407  oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames());
408  oldResNames[op] =
409  ArrayAttr::get(module.getContext(), op.getOutputNames());
410  oldArgLocs[op] = op.getInputLocs();
411  oldResLocs[op] = op.getOutputLocs();
412  }
413 
414  // Signature conversion and legalization patterns.
415  addSignatureConversion<T>(ioInfoMap, target, patterns, typeConverter);
416 
417  if (failed(applyPartialConversion(module, target, std::move(patterns))))
418  return failure();
419 
420  // Update the arg/res names of the module.
421  for (auto op : module.getOps<T>()) {
422  auto ioInfo = ioInfoMap[op];
423  updateModulePortNames(op, oldModTypes[op], joinChar);
424  auto newArgLocs = updateLocAttribute(ioInfo.argStructs, oldArgLocs[op]);
425  auto newResLocs = updateLocAttribute(ioInfo.resStructs, oldResLocs[op]);
426  newArgLocs.append(newResLocs.begin(), newResLocs.end());
427  op.setAllPortLocs(newArgLocs);
428  updateBlockLocations(op, ioInfo.argStructs);
429  }
430 
431  // And likewise with the converted instance ops.
432  for (auto instanceOp : convertedInstances) {
433  auto targetModule =
434  cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
435  instanceOp, instanceOp.getReferencedModuleNameAttr()));
436 
437  IOInfo ioInfo;
438  if (!ioInfoMap.contains(targetModule)) {
439  // If an extern module, then not yet processed, populate the maps.
440  setIOInfo(targetModule, ioInfo);
441  ioInfoMap[targetModule] = ioInfo;
442  oldArgNames[targetModule] =
443  ArrayAttr::get(module.getContext(), targetModule.getInputNames());
444  oldResNames[targetModule] =
445  ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
446  oldArgLocs[targetModule] = targetModule.getInputLocs();
447  oldResLocs[targetModule] = targetModule.getOutputLocs();
448  } else
449  ioInfo = ioInfoMap[targetModule];
450 
451  instanceOp.setInputNames(ArrayAttr::get(
452  instanceOp.getContext(),
454  instanceOp, "argNames", ioInfo.argStructs,
455  oldArgNames[targetModule].template getAsValueRange<StringAttr>(),
456  joinChar)));
457  instanceOp.setOutputNames(ArrayAttr::get(
458  instanceOp.getContext(),
460  instanceOp, "resultNames", ioInfo.resStructs,
461  oldResNames[targetModule].template getAsValueRange<StringAttr>(),
462  joinChar)));
463  }
464 
465  // Break if we've only lowering a single level of structs.
466  if (!recursive)
467  break;
468  }
469  return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // Pass driver
474 //===----------------------------------------------------------------------===//
475 
476 template <typename... TOps>
477 static bool flattenIO(ModuleOp module, bool recursive,
478  StringSet<> &externModules, char joinChar) {
479  return (failed(flattenOpsOfType<TOps>(module, recursive, externModules,
480  joinChar)) ||
481  ...);
482 }
483 
484 namespace {
485 
486 class FlattenIOPass : public circt::hw::FlattenIOBase<FlattenIOPass> {
487 public:
488  FlattenIOPass(bool recursiveFlag, bool flattenExternFlag, char join) {
489  recursive = recursiveFlag;
490  flattenExtern = flattenExternFlag;
491  joinChar = join;
492  }
493 
494  void runOnOperation() override {
495  ModuleOp module = getOperation();
496  if (!flattenExtern) {
497  // Record the extern modules, donot flatten them.
498  for (auto m : module.getOps<hw::HWModuleExternOp>())
499  externModules.insert(m.getModuleName());
500  if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(
501  module, recursive, externModules, joinChar))
502  signalPassFailure();
503  return;
504  }
505 
507  hw::HWModuleGeneratedOp>(module, recursive, externModules,
508  joinChar))
509  signalPassFailure();
510  };
511 
512 private:
513  StringSet<> externModules;
514 };
515 } // namespace
516 
517 //===----------------------------------------------------------------------===//
518 // Pass initialization
519 //===----------------------------------------------------------------------===//
520 
521 std::unique_ptr<Pass> circt::hw::createFlattenIOPass(bool recursiveFlag,
522  bool flattenExternFlag,
523  char joinChar) {
524  return std::make_unique<FlattenIOPass>(recursiveFlag, flattenExternFlag,
525  joinChar);
526 }
assert(baseType &&"element must be base type")
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
Definition: FlattenIO.cpp:32
static llvm::SmallVector< Location > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, SmallVectorImpl< Location > &oldLocs)
Definition: FlattenIO.cpp:303
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
Definition: FlattenIO.cpp:27
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
Definition: FlattenIO.cpp:326
static bool hasUnconvertedOps(mlir::ModuleOp module)
Definition: FlattenIO.cpp:242
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames, char joinChar)
Definition: FlattenIO.cpp:257
static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo)
Definition: FlattenIO.cpp:335
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:360
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:248
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
Definition: FlattenIO.cpp:192
static bool isStructType(Type type)
Definition: FlattenIO.cpp:18
static void updateModulePortNames(ModTy op, hw::ModuleType oldModType, char joinChar)
Definition: FlattenIO.cpp:281
static hw::StructType getStructType(Type type)
Definition: FlattenIO.cpp:22
static bool flattenIO(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
Definition: FlattenIO.cpp:477
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:349
llvm::SmallVector< StringAttr > inputs
Builder builder
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
std::unique_ptr< mlir::Pass > createFlattenIOPass(bool recursiveFlag=true, bool flattenExternFlag=false, char joinChar='.')
Definition: FlattenIO.cpp:521
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21