CIRCT  18.0.0git
FlattenIO.cpp
Go to the documentation of this file.
1 //===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
14 
15 using namespace mlir;
16 using namespace circt;
17 
18 static bool isStructType(Type type) {
19  return hw::getCanonicalType(type).isa<hw::StructType>();
20 }
21 
22 static hw::StructType getStructType(Type type) {
23  return hw::getCanonicalType(type).dyn_cast<hw::StructType>();
24 }
25 
26 // Legal if no in- or output type is a struct.
27 static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) {
28  return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
29  isStructType);
30 }
31 
32 static llvm::SmallVector<Type> getInnerTypes(hw::StructType t) {
33  llvm::SmallVector<Type> inner;
34  t.getInnerTypes(inner);
35  for (auto [index, innerType] : llvm::enumerate(inner))
36  inner[index] = hw::getCanonicalType(innerType);
37  return inner;
38 }
39 
40 namespace {
41 
42 // Replaces an output op with a new output with flattened (exploded) structs.
43 struct OutputOpConversion : public OpConversionPattern<hw::OutputOp> {
44  OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
45  DenseSet<Operation *> *opVisited)
46  : OpConversionPattern(typeConverter, context), opVisited(opVisited) {}
47 
48  LogicalResult
49  matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  llvm::SmallVector<Value> convOperands;
52 
53  // Flatten the operands.
54  for (auto operand : adaptor.getOperands()) {
55  if (auto structType = getStructType(operand.getType())) {
56  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
57  op.getLoc(), getInnerTypes(structType), operand);
58  llvm::copy(explodedStruct.getResults(),
59  std::back_inserter(convOperands));
60  } else {
61  convOperands.push_back(operand);
62  }
63  }
64 
65  // And replace.
66  rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
67  opVisited->insert(op->getParentOp());
68  return success();
69  }
70  DenseSet<Operation *> *opVisited;
71 };
72 
73 struct InstanceOpConversion : public OpConversionPattern<hw::InstanceOp> {
74  InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
75  DenseSet<hw::InstanceOp> *convertedOps)
76  : OpConversionPattern(typeConverter, context),
77  convertedOps(convertedOps) {}
78 
79  LogicalResult
80  matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
81  ConversionPatternRewriter &rewriter) const override {
82  auto loc = op.getLoc();
83  // Flatten the operands.
84  llvm::SmallVector<Value> convOperands;
85  for (auto operand : adaptor.getOperands()) {
86  if (auto structType = getStructType(operand.getType())) {
87  auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
88  loc, getInnerTypes(structType), operand);
89  llvm::copy(explodedStruct.getResults(),
90  std::back_inserter(convOperands));
91  } else {
92  convOperands.push_back(operand);
93  }
94  }
95 
96  // Create the new instance...
97  auto newInstance = rewriter.create<hw::InstanceOp>(
98  loc, op.getReferencedModuleSlow(), op.getInstanceName(), convOperands);
99 
100  // re-create any structs in the result.
101  llvm::SmallVector<Value> convResults;
102  size_t oldResultCntr = 0;
103  for (size_t resIndex = 0; resIndex < newInstance.getNumResults();
104  ++resIndex) {
105  Type oldResultType = op.getResultTypes()[oldResultCntr];
106  if (auto structType = getStructType(oldResultType)) {
107  size_t nElements = structType.getElements().size();
108  auto implodedStruct = rewriter.create<hw::StructCreateOp>(
109  loc, structType,
110  newInstance.getResults().slice(resIndex, nElements));
111  convResults.push_back(implodedStruct.getResult());
112  resIndex += nElements - 1;
113  } else
114  convResults.push_back(newInstance.getResult(resIndex));
115 
116  ++oldResultCntr;
117  }
118  rewriter.replaceOp(op, convResults);
119  convertedOps->insert(newInstance);
120  return success();
121  }
122 
123  DenseSet<hw::InstanceOp> *convertedOps;
124 };
125 
126 using IOTypes = std::pair<TypeRange, TypeRange>;
127 
128 struct IOInfo {
129  // A mapping between an arg/res index and the struct type of the given field.
130  DenseMap<unsigned, hw::StructType> argStructs, resStructs;
131 
132  // Records of the original arg/res types.
133  SmallVector<Type> argTypes, resTypes;
134 };
135 
136 class FlattenIOTypeConverter : public TypeConverter {
137 public:
138  FlattenIOTypeConverter() {
139  addConversion([](Type type, SmallVectorImpl<Type> &results) {
140  auto structType = getStructType(type);
141  if (!structType)
142  results.push_back(type);
143  else {
144  for (auto field : structType.getElements())
145  results.push_back(field.type);
146  }
147  return success();
148  });
149 
150  addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
151  ValueRange inputs, Location loc) {
152  auto result = builder.create<hw::StructCreateOp>(loc, type, inputs);
153  return result.getResult();
154  });
155 
156  addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
157  ValueRange inputs, Location loc) {
158  auto structType = getStructType(type);
159  assert(structType && "expected struct type");
160  auto result = builder.create<hw::StructCreateOp>(loc, structType, inputs);
161  return result.getResult();
162  });
163  }
164 };
165 
166 } // namespace
167 
168 template <typename... TOp>
169 static void addSignatureConversion(DenseMap<Operation *, IOInfo> &ioMap,
170  ConversionTarget &target,
171  RewritePatternSet &patterns,
172  FlattenIOTypeConverter &typeConverter) {
173  (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(),
174  patterns, typeConverter),
175  ...);
176 
177  // Legality is defined by a module having been processed once. This is due to
178  // that a pattern cannot be applied multiple times (a 'pattern was already
179  // applied' error - a case that would occur for nested structs). Additionally,
180  // if a pattern could be applied multiple times, this would complicate
181  // updating arg/res names.
182 
183  // Instead, we define legality as when a module has had a modification to its
184  // top-level i/o. This ensures that only a single level of structs are
185  // processed during signature conversion, which then allows us to use the
186  // signature conversion in a recursive manner.
187  target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
188  if (isLegalModLikeOp(moduleLikeOp))
189  return true;
190 
191  // This op is involved in conversion. Check if the signature has changed.
192  auto ioInfoIt = ioMap.find(moduleLikeOp);
193  if (ioInfoIt == ioMap.end()) {
194  // Op wasn't primed in the map. Do the safe thing, assume
195  // that it's not considered in this pass, and mark it as legal
196  return true;
197  }
198  auto ioInfo = ioInfoIt->second;
199 
200  auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
201  return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) {
202  auto oldType = std::get<0>(typePair);
203  auto newType = std::get<1>(typePair);
204  return oldType != newType;
205  });
206  };
207  auto mtype = moduleLikeOp.getHWModuleType();
208  if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
209  compareTypes(mtype.getInputTypes(), ioInfo.argTypes))
210  return true;
211 
212  // We're pre-conversion for an op that was primed in the map - it will
213  // always be illegal since it has to-be-converted struct types at its I/O.
214  return false;
215  });
216 }
217 
218 template <typename T>
219 static bool hasUnconvertedOps(mlir::ModuleOp module) {
220  return llvm::any_of(module.getBody()->getOps<T>(),
221  [](T op) { return !isLegalModLikeOp(op); });
222 }
223 
224 template <typename T>
225 static DenseMap<Operation *, IOTypes> populateIOMap(mlir::ModuleOp module) {
226  DenseMap<Operation *, IOTypes> ioMap;
227  for (auto op : module.getOps<T>())
228  ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
229  return ioMap;
230 }
231 
232 template <typename ModTy, typename T>
233 static llvm::SmallVector<Attribute>
234 updateNameAttribute(ModTy op, StringRef attrName,
235  DenseMap<unsigned, hw::StructType> &structMap, T oldNames) {
236  llvm::SmallVector<Attribute> newNames;
237  for (auto [i, oldName] : llvm::enumerate(oldNames)) {
238  // Was this arg/res index a struct?
239  auto it = structMap.find(i);
240  if (it == structMap.end()) {
241  // No, keep old name.
242  newNames.push_back(StringAttr::get(op->getContext(), oldName));
243  continue;
244  }
245 
246  // Yes - create new names from the struct fields and the old name at the
247  // index.
248  auto structType = it->second;
249  for (auto field : structType.getElements())
250  newNames.push_back(
251  StringAttr::get(op->getContext(), oldName + "." + field.name.str()));
252  }
253  return newNames;
254 }
255 
256 static llvm::SmallVector<Attribute>
257 updateLocAttribute(DenseMap<unsigned, hw::StructType> &structMap,
258  ArrayAttr oldLocs) {
259  llvm::SmallVector<Attribute> newLocs;
260  if (!oldLocs)
261  return newLocs;
262  for (auto [i, oldLoc] : llvm::enumerate(oldLocs.getAsRange<Location>())) {
263  // Was this arg/res index a struct?
264  auto it = structMap.find(i);
265  if (it == structMap.end()) {
266  // No, keep old name.
267  newLocs.push_back(oldLoc);
268  continue;
269  }
270 
271  auto structType = it->second;
272  for (size_t i = 0, e = structType.getElements().size(); i < e; ++i)
273  newLocs.push_back(oldLoc);
274  }
275  return newLocs;
276 }
277 
278 /// The conversion framework seems to throw away block argument locations. We
279 /// use this function to copy the location from the original argument to the
280 /// set of flattened arguments.
281 static void
282 updateBlockLocations(hw::HWModuleLike op,
283  DenseMap<unsigned, hw::StructType> &structMap) {
284  auto locs = op.getInputLocs();
285  if (locs.empty() || op.getModuleBody().empty())
286  return;
287  for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
288  arg.setLoc(loc);
289 }
290 
291 template <typename T>
292 static DenseMap<Operation *, IOInfo> populateIOInfoMap(mlir::ModuleOp module) {
293  DenseMap<Operation *, IOInfo> ioInfoMap;
294  for (auto op : module.getOps<T>()) {
295  IOInfo ioInfo;
296  ioInfo.argTypes = op.getInputTypes();
297  ioInfo.resTypes = op.getOutputTypes();
298  for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
299  if (auto structType = getStructType(arg))
300  ioInfo.argStructs[i] = structType;
301  }
302  for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
303  if (auto structType = getStructType(res))
304  ioInfo.resStructs[i] = structType;
305  }
306  ioInfoMap[op] = ioInfo;
307  }
308  return ioInfoMap;
309 }
310 
311 template <typename T>
312 static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) {
313  auto *ctx = module.getContext();
314  FlattenIOTypeConverter typeConverter;
315 
316  // Recursively (in case of nested structs) lower the module. We do this one
317  // conversion at a time to allow for updating the arg/res names of the
318  // module in between flattening each level of structs.
319  while (hasUnconvertedOps<T>(module)) {
320  ConversionTarget target(*ctx);
321  RewritePatternSet patterns(ctx);
322  target.addLegalDialect<hw::HWDialect>();
323 
324  // Record any struct types at the module signature. This will be used
325  // post-conversion to update the argument and result names.
326  auto ioInfoMap = populateIOInfoMap<T>(module);
327 
328  // Record the instances that were converted. We keep these around since we
329  // need to update their arg/res attribute names after the modules themselves
330  // have been updated.
331  llvm::DenseSet<hw::InstanceOp> convertedInstances;
332 
333  // Argument conversion for output ops. Similarly to the signature
334  // conversion, legality is based on the op having been visited once, due to
335  // the possibility of nested structs.
336  DenseSet<Operation *> opVisited;
337  patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
338 
339  patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances);
340  target.addDynamicallyLegalOp<hw::OutputOp>(
341  [&](auto op) { return opVisited.contains(op->getParentOp()); });
342  target.addDynamicallyLegalOp<hw::InstanceOp>([&](auto op) {
343  return llvm::none_of(op->getOperands(), [](auto operand) {
344  return isStructType(operand.getType());
345  });
346  });
347 
348  DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames, oldArgLocs,
349  oldResLocs;
350  for (auto op : module.getOps<T>()) {
351  oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames());
352  oldResNames[op] =
353  ArrayAttr::get(module.getContext(), op.getOutputNames());
354  oldArgLocs[op] = op.getInputLocsAttr();
355  oldResLocs[op] = op.getOutputLocsAttr();
356  }
357 
358  // Signature conversion and legalization patterns.
359  addSignatureConversion<T>(ioInfoMap, target, patterns, typeConverter);
360 
361  if (failed(applyPartialConversion(module, target, std::move(patterns))))
362  return failure();
363 
364  // Update the arg/res names of the module.
365  for (auto op : module.getOps<T>()) {
366  auto ioInfo = ioInfoMap[op];
367  auto newArgNames = updateNameAttribute(
368  op, "argNames", ioInfo.argStructs,
369  oldArgNames[op].template getAsValueRange<StringAttr>());
370  auto newResNames = updateNameAttribute(
371  op, "resultNames", ioInfo.resStructs,
372  oldResNames[op].template getAsValueRange<StringAttr>());
373  newArgNames.append(newResNames.begin(), newResNames.end());
374  op.setAllPortNames(newArgNames);
375  auto newArgLocs = updateLocAttribute(ioInfo.argStructs, oldArgLocs[op]);
376  auto newResLocs = updateLocAttribute(ioInfo.resStructs, oldResLocs[op]);
377  newArgLocs.append(newResLocs.begin(), newResLocs.end());
378  op.setPortLocsAttr(ArrayAttr::get(op.getContext(), newArgLocs));
379  updateBlockLocations(op, ioInfo.argStructs);
380  }
381 
382  // And likewise with the converted instance ops.
383  for (auto instanceOp : convertedInstances) {
384  Operation *targetModule = instanceOp.getReferencedModuleSlow();
385  auto ioInfo = ioInfoMap[targetModule];
386  instanceOp.setInputNames(ArrayAttr::get(
387  instanceOp.getContext(),
388  updateNameAttribute(instanceOp, "argNames", ioInfo.argStructs,
389  oldArgNames[targetModule]
390  .template getAsValueRange<StringAttr>())));
391  instanceOp.setOutputNames(ArrayAttr::get(
392  instanceOp.getContext(),
393  updateNameAttribute(instanceOp, "resultNames", ioInfo.resStructs,
394  oldResNames[targetModule]
395  .template getAsValueRange<StringAttr>())));
396  instanceOp.dump();
397  }
398 
399  // Break if we've only lowering a single level of structs.
400  if (!recursive)
401  break;
402  }
403  return success();
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // Pass driver
408 //===----------------------------------------------------------------------===//
409 
410 template <typename... TOps>
411 static bool flattenIO(ModuleOp module, bool recursive) {
412  return (failed(flattenOpsOfType<TOps>(module, recursive)) || ...);
413 }
414 
namespace {

/// Pass driver for HW I/O flattening. It operates on the top-level builtin
/// ModuleOp and signals pass failure when `flattenIO` reports a failure.
/// `recursive` is presumably the pass option declared in FlattenIOBase;
/// confirm against the pass definition.
class FlattenIOPass : public circt::hw::FlattenIOBase<FlattenIOPass> {
public:
  void runOnOperation() override {
    ModuleOp module = getOperation();
    // NOTE(review): the opening of the condition below (`if (flattenIO<...`,
    // including the leading op types) is not present in this copy of the file.
    // As written, the next line is not a complete statement. Restore it from
    // the original source.
    hw::HWModuleGeneratedOp>(module, recursive))
      signalPassFailure();
  };
};

} // namespace
428 
429 //===----------------------------------------------------------------------===//
430 // Pass initialization
431 //===----------------------------------------------------------------------===//
432 
433 std::unique_ptr<Pass> circt::hw::createFlattenIOPass() {
434  return std::make_unique<FlattenIOPass>();
435 }
assert(baseType &&"element must be base type")
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
Definition: FlattenIO.cpp:32
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
Definition: FlattenIO.cpp:27
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
Definition: FlattenIO.cpp:282
static bool hasUnconvertedOps(mlir::ModuleOp module)
Definition: FlattenIO.cpp:219
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive)
Definition: FlattenIO.cpp:312
static llvm::SmallVector< Attribute > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, ArrayAttr oldLocs)
Definition: FlattenIO.cpp:257
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames)
Definition: FlattenIO.cpp:234
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:225
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
Definition: FlattenIO.cpp:169
static bool flattenIO(ModuleOp module, bool recursive)
Definition: FlattenIO.cpp:411
static bool isStructType(Type type)
Definition: FlattenIO.cpp:18
static hw::StructType getStructType(Type type)
Definition: FlattenIO.cpp:22
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
Definition: FlattenIO.cpp:292
llvm::SmallVector< StringAttr > inputs
Builder builder
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
std::unique_ptr< mlir::Pass > createFlattenIOPass()
Definition: FlattenIO.cpp:433
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21