CIRCT  20.0.0git
KanagawaContainersToHW.cpp
Go to the documentation of this file.
1 //===- KanagawaContainersToHW.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 
13 #include "circt/Dialect/HW/HWOps.h"
18 
20 #include "mlir/IR/OperationSupport.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 namespace circt {
25 namespace kanagawa {
26 #define GEN_PASS_DEF_KANAGAWACONTAINERSTOHW
27 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
28 } // namespace kanagawa
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace kanagawa;
33 
34 namespace {
35 
36 // Analysis result for generating the port interface of a container + a bit of
37 // port op caching.
38 struct ContainerPortInfo {
39  std::unique_ptr<hw::ModulePortInfo> hwPorts;
40 
41  // A mapping between the port name and the port op within the container.
42  llvm::DenseMap<StringAttr, InputPortOp> opInputs;
43 
44  // A mapping between the port name and the port op within the container.
45  llvm::DenseMap<StringAttr, OutputPortOp> opOutputs;
46 
47  // A mapping between port symbols and their corresponding port name.
48  llvm::DenseMap<StringAttr, StringAttr> portSymbolsToPortName;
49 
50  ContainerPortInfo() = default;
51  ContainerPortInfo(ContainerOp container) {
52  SmallVector<hw::PortInfo, 4> inputs, outputs;
53  auto *ctx = container.getContext();
54 
55  // Copies all attributes from a port, except for the port symbol, name, and
56  // type.
57  auto copyPortAttrs = [ctx](auto port) {
58  llvm::DenseSet<StringAttr> elidedAttrs;
59  elidedAttrs.insert(port.getInnerSymAttrName());
60  elidedAttrs.insert(port.getTypeAttrName());
61  elidedAttrs.insert(port.getNameAttrName());
62  llvm::SmallVector<NamedAttribute> attrs;
63  for (NamedAttribute namedAttr : port->getAttrs()) {
64  if (elidedAttrs.contains(namedAttr.getName()))
65  continue;
66  attrs.push_back(namedAttr);
67  }
68  return DictionaryAttr::get(ctx, attrs);
69  };
70 
71  // Gather in and output port ops to define the hw.module interface. Here, we
72  // also perform uniqueing of the port names.
73  Namespace portNs;
74  for (auto input : container.getBodyBlock()->getOps<InputPortOp>()) {
75  auto uniquePortName =
76  StringAttr::get(ctx, portNs.newName(input.getNameHint()));
77  opInputs[uniquePortName] = input;
78  hw::PortInfo portInfo;
79  portInfo.name = uniquePortName;
80  portSymbolsToPortName[input.getInnerSym().getSymName()] = uniquePortName;
81  portInfo.type = cast<PortOpInterface>(input.getOperation()).getPortType();
82  portInfo.dir = hw::ModulePort::Direction::Input;
83  portInfo.attrs = copyPortAttrs(input);
84  inputs.push_back(portInfo);
85  }
86 
87  for (auto output : container.getBodyBlock()->getOps<OutputPortOp>()) {
88  auto uniquePortName =
89  StringAttr::get(ctx, portNs.newName(output.getNameAttr().getValue()));
90  opOutputs[uniquePortName] = output;
91 
92  hw::PortInfo portInfo;
93  portInfo.name = uniquePortName;
94  portSymbolsToPortName[output.getInnerSym().getSymName()] = uniquePortName;
95  portInfo.type =
96  cast<PortOpInterface>(output.getOperation()).getPortType();
97  portInfo.dir = hw::ModulePort::Direction::Output;
98  portInfo.attrs = copyPortAttrs(output);
99  outputs.push_back(portInfo);
100  }
101  hwPorts = std::make_unique<hw::ModulePortInfo>(inputs, outputs);
102  }
103 };
104 
105 using ContainerPortInfoMap =
106  llvm::DenseMap<hw::InnerRefAttr, ContainerPortInfo>;
107 using ContainerHWModSymbolMap = llvm::DenseMap<hw::InnerRefAttr, StringAttr>;
108 
109 static StringAttr concatNames(mlir::StringAttr lhs, mlir::StringAttr rhs) {
110  return StringAttr::get(lhs.getContext(), lhs.strref() + "_" + rhs.strref());
111 }
112 
113 struct ContainerOpConversionPattern : public OpConversionPattern<ContainerOp> {
114  ContainerOpConversionPattern(MLIRContext *ctx, Namespace &modNamespace,
115  ContainerPortInfoMap &portOrder,
116  ContainerHWModSymbolMap &modSymMap)
117  : OpConversionPattern<ContainerOp>(ctx), modNamespace(modNamespace),
118  portOrder(portOrder), modSymMap(modSymMap) {}
119 
120  LogicalResult
121  matchAndRewrite(ContainerOp op, OpAdaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123  auto design = op->getParentOfType<DesignOp>();
124  rewriter.setInsertionPoint(design);
125 
126  // Generate and de-alias the hw.module name.
127  // If the container is a top level container, ignore the design name.
128  StringAttr hwmodName;
129  if (op.getIsTopLevel())
130  hwmodName = op.getNameHintAttr();
131  else
132  hwmodName =
133  concatNames(op.getInnerRef().getModule(), op.getNameHintAttr());
134 
135  hwmodName = StringAttr::get(op.getContext(),
136  modNamespace.newName(hwmodName.getValue()));
137 
138  const ContainerPortInfo &cpi = portOrder.at(op.getInnerRef());
139  auto hwMod =
140  rewriter.create<hw::HWModuleOp>(op.getLoc(), hwmodName, *cpi.hwPorts);
141  modSymMap[op.getInnerRef()] = hwMod.getSymNameAttr();
142 
143  hw::OutputOp outputOp =
144  cast<hw::OutputOp>(hwMod.getBodyBlock()->getTerminator());
145 
146  // Replace all of the reads of the inputs to use the input block arguments.
147  for (auto [idx, input] : llvm::enumerate(cpi.hwPorts->getInputs())) {
148  Value barg = hwMod.getBodyBlock()->getArgument(idx);
149  InputPortOp inputPort = cpi.opInputs.at(input.name);
150  // Replace all reads of the input port with the input block argument.
151  for (auto *user : inputPort.getOperation()->getUsers()) {
152  auto reader = dyn_cast<PortReadOp>(user);
153  if (!reader)
154  return rewriter.notifyMatchFailure(
155  user, "expected only kanagawa.port.read ops of the input port");
156 
157  rewriter.replaceOp(reader, barg);
158  }
159 
160  rewriter.eraseOp(inputPort);
161  }
162 
163  // Adjust the hw.output op to use kanagawa.port.write values
164  llvm::SmallVector<Value> outputValues;
165  for (auto [idx, output] : llvm::enumerate(cpi.hwPorts->getOutputs())) {
166  auto outputPort = cpi.opOutputs.at(output.name);
167  // Locate the write to the output op.
168  auto users = outputPort->getUsers();
169  size_t nUsers = std::distance(users.begin(), users.end());
170  if (nUsers != 1)
171  return outputPort->emitOpError()
172  << "expected exactly one kanagawa.port.write op of the output "
173  "port: "
174  << output.name.str() << " found: " << nUsers;
175  auto writer = cast<PortWriteOp>(*users.begin());
176  outputValues.push_back(writer.getValue());
177  rewriter.eraseOp(outputPort);
178  rewriter.eraseOp(writer);
179  }
180 
181  rewriter.mergeBlocks(&op.getBodyRegion().front(), hwMod.getBodyBlock());
182 
183  // Rewrite the hw.output op.
184  rewriter.eraseOp(outputOp);
185  rewriter.setInsertionPointToEnd(hwMod.getBodyBlock());
186  outputOp = rewriter.create<hw::OutputOp>(op.getLoc(), outputValues);
187  rewriter.eraseOp(op);
188  return success();
189  }
190 
191  Namespace &modNamespace;
192  ContainerPortInfoMap &portOrder;
193  ContainerHWModSymbolMap &modSymMap;
194 };
195 
196 struct ThisOpConversionPattern : public OpConversionPattern<ThisOp> {
197  ThisOpConversionPattern(MLIRContext *ctx)
198  : OpConversionPattern<ThisOp>(ctx) {}
199 
200  LogicalResult
201  matchAndRewrite(ThisOp op, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  // TODO: remove this op from the dialect - not needed anymore.
204  rewriter.eraseOp(op);
205  return success();
206  }
207 };
208 
209 struct ContainerInstanceOpConversionPattern
210  : public OpConversionPattern<ContainerInstanceOp> {
211 
212  ContainerInstanceOpConversionPattern(MLIRContext *ctx,
213  ContainerPortInfoMap &portOrder,
214  ContainerHWModSymbolMap &modSymMap)
215  : OpConversionPattern<ContainerInstanceOp>(ctx), portOrder(portOrder),
216  modSymMap(modSymMap) {}
217 
218  LogicalResult
219  matchAndRewrite(ContainerInstanceOp op, OpAdaptor adaptor,
220  ConversionPatternRewriter &rewriter) const override {
221  rewriter.setInsertionPoint(op);
222  llvm::SmallVector<Value> operands;
223 
224  const ContainerPortInfo &cpi =
225  portOrder.at(op.getResult().getType().getScopeRef());
226 
227  // Gather the get_port ops that target the instance
228  llvm::DenseMap<StringAttr, PortReadOp> outputReadsToReplace;
229  llvm::DenseMap<StringAttr, PortWriteOp> inputWritesToUse;
230  llvm::SmallVector<Operation *> getPortsToErase;
231  for (auto *user : op->getUsers()) {
232  auto getPort = dyn_cast<GetPortOp>(user);
233  if (!getPort)
234  return rewriter.notifyMatchFailure(
235  user, "expected only kanagawa.get_port op usage of the instance");
236 
237  for (auto *user : getPort->getUsers()) {
238  auto res =
239  llvm::TypeSwitch<Operation *, LogicalResult>(user)
240  .Case<PortReadOp>([&](auto read) {
241  auto [it, inserted] = outputReadsToReplace.insert(
242  {cpi.portSymbolsToPortName.at(
243  getPort.getPortSymbolAttr().getAttr()),
244  read});
245  if (!inserted)
246  return rewriter.notifyMatchFailure(
247  read, "expected only one kanagawa.port.read op of the "
248  "output port");
249  return success();
250  })
251  .Case<PortWriteOp>([&](auto write) {
252  auto [it, inserted] = inputWritesToUse.insert(
253  {cpi.portSymbolsToPortName.at(
254  getPort.getPortSymbolAttr().getAttr()),
255  write});
256  if (!inserted)
257  return rewriter.notifyMatchFailure(
258  write,
259  "expected only one kanagawa.port.write op of the input "
260  "port");
261  return success();
262  })
263  .Default([&](auto op) {
264  return rewriter.notifyMatchFailure(
265  op, "expected only kanagawa.port.read or "
266  "kanagawa.port.write ops "
267  "of the "
268  "instance");
269  });
270  if (failed(res))
271  return failure();
272  }
273  getPortsToErase.push_back(getPort);
274  }
275 
276  // Grab the operands in the order of the hw.module ports.
277  size_t nInputPorts = std::distance(cpi.hwPorts->getInputs().begin(),
278  cpi.hwPorts->getInputs().end());
279  if (nInputPorts != inputWritesToUse.size()) {
280  std::string errMsg;
281  llvm::raw_string_ostream ers(errMsg);
282  ers << "Error when lowering instance ";
283  op.print(ers, mlir::OpPrintingFlags().printGenericOpForm());
284 
285  ers << "\nexpected exactly one kanagawa.port.write op of each input "
286  "port. "
287  "Mising port assignments were:\n";
288  for (auto input : cpi.hwPorts->getInputs()) {
289  if (inputWritesToUse.find(input.name) == inputWritesToUse.end())
290  ers << "\t" << input.name << "\n";
291  }
292  return rewriter.notifyMatchFailure(op, errMsg);
293  }
294  for (auto input : cpi.hwPorts->getInputs()) {
295  auto writeOp = inputWritesToUse.at(input.name);
296  operands.push_back(writeOp.getValue());
297  rewriter.eraseOp(writeOp);
298  }
299 
300  // Determine the result types.
301  llvm::SmallVector<Type> retTypes;
302  for (auto output : cpi.hwPorts->getOutputs())
303  retTypes.push_back(output.type);
304 
305  // Gather arg and res names
306  // TODO: @mortbopet - this should be part of ModulePortInfo
307  llvm::SmallVector<Attribute> argNames, resNames;
308  llvm::transform(cpi.hwPorts->getInputs(), std::back_inserter(argNames),
309  [](auto port) { return port.name; });
310  llvm::transform(cpi.hwPorts->getOutputs(), std::back_inserter(resNames),
311  [](auto port) { return port.name; });
312 
313  // Create the hw.instance op.
314  StringRef moduleName = modSymMap[op.getTargetNameAttr()];
315  auto hwInst = rewriter.create<hw::InstanceOp>(
316  op.getLoc(), retTypes, op.getInnerSym().getSymName(), moduleName,
317  operands, rewriter.getArrayAttr(argNames),
318  rewriter.getArrayAttr(resNames),
319  /*parameters*/ rewriter.getArrayAttr({}), /*innerSym*/ nullptr);
320 
321  // Replace the reads of the output ports with the hw.instance results.
322  for (auto [output, value] :
323  llvm::zip(cpi.hwPorts->getOutputs(), hwInst.getResults())) {
324  auto outputReadIt = outputReadsToReplace.find(output.name);
325  if (outputReadIt == outputReadsToReplace.end())
326  continue;
327  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
328  // DialectConversion. Using it may lead to assertions about mutating
329  // replaced/erased ops. For now, do this RAUW directly, until
330  // ConversionPatternRewriter properly supports RAUW.
331  // See https://github.com/llvm/circt/issues/6795.
332  outputReadIt->second.getResult().replaceAllUsesWith(value);
333  rewriter.eraseOp(outputReadIt->second);
334  }
335 
336  // Erase the get_port ops.
337  for (auto *getPort : getPortsToErase)
338  rewriter.eraseOp(getPort);
339 
340  // And finally erase the instance op.
341  rewriter.eraseOp(op);
342  return success();
343  }
344 
345  ContainerPortInfoMap &portOrder;
346  ContainerHWModSymbolMap &modSymMap;
347 }; // namespace
348 
349 struct ContainersToHWPass
350  : public circt::kanagawa::impl::KanagawaContainersToHWBase<
351  ContainersToHWPass> {
352  void runOnOperation() override;
353 };
354 } // anonymous namespace
355 
356 void ContainersToHWPass::runOnOperation() {
357  auto *ctx = &getContext();
358 
359  // Generate module signatures.
360  ContainerPortInfoMap portOrder;
361  for (auto design : getOperation().getOps<DesignOp>())
362  for (auto container : design.getOps<ContainerOp>())
363  portOrder.try_emplace(container.getInnerRef(),
364  ContainerPortInfo(container));
365 
366  ConversionTarget target(*ctx);
367  ContainerHWModSymbolMap modSymMap;
368  SymbolCache modSymCache;
369  modSymCache.addDefinitions(getOperation());
370  Namespace modNamespace;
371  modNamespace.add(modSymCache);
372  target.addIllegalOp<ContainerOp, ContainerInstanceOp, ThisOp>();
373  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
374 
375  // Remove the name of the kanagawa.design's from the namespace - The
376  // kanagawa.design op will be removed after this pass, and there may be
377  // kanagawa.component's inside the design that have the same name as the
378  // design; we want that name to persist, and not be falsely considered a
379  // duplicate.
380  for (auto designOp : getOperation().getOps<DesignOp>())
381  modNamespace.erase(designOp.getSymName());
382 
383  // Parts of the conversion patterns will update operations in place, which in
384  // turn requires the updated operations to be legalizeable. These in-place ops
385  // also include kanagawa ops that eventually will get replaced once all of the
386  // patterns apply.
387  target.addLegalDialect<KanagawaDialect>();
388 
389  RewritePatternSet patterns(ctx);
390  patterns.add<ContainerOpConversionPattern>(ctx, modNamespace, portOrder,
391  modSymMap);
392  patterns.add<ContainerInstanceOpConversionPattern>(ctx, portOrder, modSymMap);
393  patterns.add<ThisOpConversionPattern>(ctx);
394 
395  if (failed(
396  applyPartialConversion(getOperation(), target, std::move(patterns))))
397  signalPassFailure();
398 
399  // Delete empty design ops.
400  for (auto design :
401  llvm::make_early_inc_range(getOperation().getOps<DesignOp>()))
402  if (design.getBody().front().empty())
403  design.erase();
404 }
405 
407  return std::make_unique<ContainersToHWPass>();
408 }
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1440
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:48
bool erase(llvm::StringRef symbol)
Removes a symbol from the namespace.
Definition: Namespace.h:67
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createContainersToHWPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21