CIRCT  20.0.0git
KanagawaPortrefLowering.cpp
Go to the documentation of this file.
1 //===- KanagawaPortrefLowering.cpp - Implementation of PortrefLowering ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 
17 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "kanagawa-lower-portrefs"
24 
25 namespace circt {
26 namespace kanagawa {
27 #define GEN_PASS_DEF_KANAGAWAPORTREFLOWERING
28 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29 } // namespace kanagawa
30 } // namespace circt
31 
32 using namespace mlir;
33 using namespace circt;
34 using namespace kanagawa;
35 
36 namespace {
37 
38 class InputPortConversionPattern : public OpConversionPattern<InputPortOp> {
39 public:
40  using OpConversionPattern::OpConversionPattern;
41  using OpAdaptor = typename OpConversionPattern<InputPortOp>::OpAdaptor;
42 
43  LogicalResult
44  matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45  ConversionPatternRewriter &rewriter) const override {
46  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47  Type innerType = innerPortRefType.getPortType();
48  Direction d = innerPortRefType.getDirection();
49 
50  // CSE check - CSE should have ensured that only a single port unwrapper was
51  // present, so if this is not the case, the user should run
52  // CSE. This goes for other assumptions in the following code -
53  // we require a CSEd form to avoid having to deal with a bunch of edge
54  // cases.
55  auto portrefUsers = op.getResult().getUsers();
56  size_t nPortrefUsers =
57  std::distance(portrefUsers.begin(), portrefUsers.end());
58  if (nPortrefUsers != 1)
59  return rewriter.notifyMatchFailure(
60  op,
61  "expected a single kanagawa.port.read as the only user of the input "
62  "port reference, but found multiple readers - please run CSE "
63  "prior to this pass");
64 
65  // A single PortReadOp should be present, which unwraps the portref<portref>
66  // into a portref.
67  PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
68  if (!portUnwrapper)
69  return rewriter.notifyMatchFailure(
70  op,
71  "expected a single kanagawa.port.read as the only user of the input "
72  "port reference");
73 
74  // Replace the inner portref + port access with a "raw" port.
75  OpBuilder::InsertionGuard g(rewriter);
76  rewriter.setInsertionPoint(op);
77  if (d == Direction::Input) {
78  // references to inputs becomes outputs (write from this container)
79  auto rawOutput = rewriter.create<OutputPortOp>(
80  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
81 
82  // Replace writes to the unwrapped port with writes to the new port.
83  for (auto *unwrappedPortUser :
84  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
85  PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
86  if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
87  continue;
88 
89  // Replace the source port of the write op with the new port.
90  rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
91  portWriter.getValue());
92  }
93  } else {
94  // References to outputs becomes inputs (read from this container)
95  auto rawInput = rewriter.create<InputPortOp>(
96  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
97  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
98  // DialectConversion. Using it may lead to assertions about mutating
99  // replaced/erased ops. For now, do this RAUW directly, until
100  // ConversionPatternRewriter properly supports RAUW.
101  // See https://github.com/llvm/circt/issues/6795.
102  portUnwrapper.getResult().replaceAllUsesWith(rawInput);
103 
104  // Replace all kanagawa.port.read ops with a read of the new input.
105  for (auto *portUser :
106  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
107  PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
108  if (!portReader || portReader.getPort() != portUnwrapper.getResult())
109  continue;
110 
111  rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
112  }
113  }
114 
115  // Finally, remove the port unwrapper and the original input port.
116  rewriter.eraseOp(portUnwrapper);
117  rewriter.eraseOp(op);
118 
119  return success();
120  }
121 };
122 
123 class OutputPortConversionPattern : public OpConversionPattern<OutputPortOp> {
124 public:
125  using OpConversionPattern::OpConversionPattern;
126  using OpAdaptor = typename OpConversionPattern<OutputPortOp>::OpAdaptor;
127 
128  LogicalResult
129  matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
131  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
132  Type innerType = innerPortRefType.getPortType();
133  Direction d = innerPortRefType.getDirection();
134 
135  // Locate the portwrapper - this is a writeOp with the output portref as
136  // the portref value.
137  PortWriteOp portWrapper;
138  for (auto *user : op.getResult().getUsers()) {
139  auto writeOp = dyn_cast<PortWriteOp>(user);
140  if (writeOp && writeOp.getPort() == op.getResult()) {
141  if (portWrapper)
142  return rewriter.notifyMatchFailure(
143  op, "expected a single kanagawa.port.write to wrap the output "
144  "portref, but found multiple");
145  portWrapper = writeOp;
146  break;
147  }
148  }
149 
150  if (!portWrapper)
151  return rewriter.notifyMatchFailure(
152  op, "expected an kanagawa.port.write to wrap the output portref");
153 
154  OpBuilder::InsertionGuard g(rewriter);
155  rewriter.setInsertionPoint(op);
156  if (d == Direction::Input) {
157  // Outputs of inputs are inputs (external driver into this container).
158  // Create the raw input port and write the input port reference with a
159  // read of the raw input port.
160  auto rawInput = rewriter.create<InputPortOp>(
161  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
162  rewriter.create<PortWriteOp>(
163  op.getLoc(), portWrapper.getValue(),
164  rewriter.create<PortReadOp>(op.getLoc(), rawInput));
165  } else {
166  // Outputs of outputs are outputs (external driver out of this container).
167  // Create the raw output port and do a read of the input port reference.
168  auto rawOutput = rewriter.create<OutputPortOp>(
169  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
170  rewriter.create<PortWriteOp>(
171  op.getLoc(), rawOutput,
172  rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
173  }
174 
175  // Finally, remove the port wrapper and the original output port.
176  rewriter.eraseOp(portWrapper);
177  rewriter.eraseOp(op);
178 
179  return success();
180  }
181 };
182 
183 class GetPortConversionPattern : public OpConversionPattern<GetPortOp> {
184  using OpConversionPattern::OpConversionPattern;
185  using OpAdaptor = typename OpConversionPattern<GetPortOp>::OpAdaptor;
186 
187  LogicalResult
188  matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
189  ConversionPatternRewriter &rewriter) const override {
190  PortRefType outerPortRefType = cast<PortRefType>(op.getType());
191  PortRefType innerPortRefType =
192  cast<PortRefType>(outerPortRefType.getPortType());
193  Type innerType = innerPortRefType.getPortType();
194 
195  Direction outerDirection = outerPortRefType.getDirection();
196  Direction innerDirection = innerPortRefType.getDirection();
197 
198  StringAttr portName = op.getPortSymbolAttr().getAttr();
199 
200  OpBuilder::InsertionGuard g(rewriter);
201  rewriter.setInsertionPoint(op);
202  Operation *wrapper;
203  if (outerDirection == Direction::Input) {
204  // Locate the get_port wrapper - this is a WriteOp with the get_port
205  // result as the portref value.
206  PortWriteOp getPortWrapper;
207  for (auto *user : op.getResult().getUsers()) {
208  auto writeOp = dyn_cast<PortWriteOp>(user);
209  if (!writeOp || writeOp.getPort() != op.getResult())
210  continue;
211 
212  getPortWrapper = writeOp;
213  break;
214  }
215 
216  if (!getPortWrapper)
217  return rewriter.notifyMatchFailure(
218  op, "expected an kanagawa.port.write to wrap the get_port result");
219  wrapper = getPortWrapper;
220  LLVM_DEBUG(llvm::dbgs() << "Found wrapper: " << *wrapper);
221  if (innerDirection == Direction::Input) {
222  // The portref<in portref<in T>> is now an output port.
223  auto newGetPort =
224  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
226  auto newGetPortVal =
227  rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
228  rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
229  newGetPortVal);
230  } else {
231  // The portref<in portref<out T>> is now an input port.
232  auto newGetPort =
233  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
235  auto writeValue =
236  rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
237  rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
238  }
239  } else {
240  PortReadOp getPortUnwrapper;
241  for (auto *user : op.getResult().getUsers()) {
242  auto readOp = dyn_cast<PortReadOp>(user);
243  if (!readOp || readOp.getPort() != op.getResult())
244  continue;
245 
246  getPortUnwrapper = readOp;
247  break;
248  }
249 
250  if (!getPortUnwrapper)
251  return rewriter.notifyMatchFailure(
252  op, "expected an kanagawa.port.read to unwrap the get_port result");
253  wrapper = getPortUnwrapper;
254 
255  LLVM_DEBUG(llvm::dbgs() << "Found unwrapper: " << *wrapper);
256  if (innerDirection == Direction::Input) {
257  // In this situation, we're retrieving an input port that is sent as an
258  // output of the container: %rr = kanagawa.get_port %c %c_in :
259  // !kanagawa.scoperef<...> -> !kanagawa.portref<out !kanagawa.portref<in
260  // T>>
261  //
262  // Thus we expect one of these cases:
263  // (always). a read op which unwraps the portref<out portref<in T>> into
264  // a portref<in T>
265  // %r = kanagawa.port.read %rr : !kanagawa.portref<out
266  // !kanagawa.portref<in T>>
267  // either:
268  // 1. A write to %r which drives the target input port
269  // kanagawa.port.write %r, %someValue : !kanagawa.portref<in T>
270  // 2. A write using %r which forwards the input port reference
271  // kanagawa.port.write %r_fw, %r : !kanagawa.portref<out
272  // !kanagawa.portref<in T>>
273  //
274  PortWriteOp portDriver;
275  PortWriteOp portForwardingDriver;
276  for (auto *user : getPortUnwrapper.getResult().getUsers()) {
277  auto writeOp = dyn_cast<PortWriteOp>(user);
278  if (!writeOp)
279  continue;
280 
281  bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
282  if (isForwarding) {
283  if (portForwardingDriver)
284  return rewriter.notifyMatchFailure(
285  op,
286  "expected a single kanagawa.port.write to use the unwrapped "
287  "get_port result, but found multiple");
288  portForwardingDriver = writeOp;
289  LLVM_DEBUG(llvm::dbgs()
290  << "Found forwarding driver: " << *portForwardingDriver);
291  } else {
292  if (portDriver)
293  return rewriter.notifyMatchFailure(
294  op,
295  "expected a single kanagawa.port.write to use the unwrapped "
296  "get_port result, but found multiple");
297  portDriver = writeOp;
298  LLVM_DEBUG(llvm::dbgs() << "Found driver: " << *portDriver);
299  }
300  }
301 
302  if (!portDriver && !portForwardingDriver)
303  return rewriter.notifyMatchFailure(
304  op,
305  "expected an kanagawa.port.write to drive the unwrapped get_port "
306  "result");
307 
308  Value portDriverValue;
309  if (portForwardingDriver) {
310  // In the case of forwarding, it is simplest to just create a new
311  // input port, and write the forwarded value to it. This will allow
312  // this pattern to recurse and eventually reach the case where the
313  // forwarding is resolved through reading/writing the intermediate
314  // inputs.
315  auto fwPortName = rewriter.getStringAttr(portName.strref() + "_fw");
316  auto forwardedInputPort = rewriter.create<InputPortOp>(
317  op.getLoc(), hw::InnerSymAttr::get(fwPortName), innerType,
318  fwPortName);
319 
320  // TODO: RewriterBase::replaceAllUsesWith is not currently supported
321  // by DialectConversion. Using it may lead to assertions about
322  // mutating replaced/erased ops. For now, do this RAUW directly, until
323  // ConversionPatternRewriter properly supports RAUW.
324  // See https://github.com/llvm/circt/issues/6795.
325  getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
326  portDriverValue = rewriter.create<PortReadOp>(
327  op.getLoc(), forwardedInputPort.getPort());
328  } else {
329  // Direct assignmenet - the driver value will be the value of
330  // the driver.
331  portDriverValue = portDriver.getValue();
332  rewriter.eraseOp(portDriver);
333  }
334 
335  // Perform assignment to the input port of the target instance using
336  // the driver value.
337  auto rawPort =
338  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
340  rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
341  } else {
342  // In this situation, we're retrieving an output port that is sent as an
343  // output of the container: %rr = kanagawa.get_port %c %c_in :
344  // !kanagawa.scoperef<...> -> !kanagawa.portref<out
345  // !kanagawa.portref<out T>>
346  //
347  // Thus we expect two ops to be present:
348  // 1. a read op which unwraps the portref<out portref<in T>> into a
349  // portref<in T>
350  // %r = kanagawa.port.read %rr : !kanagawa.portref<out
351  // !kanagawa.portref<in T>>
352  // 2. one (or multiple, if not CSEd)
353  //
354  // We then replace the read op with the actual output port of the
355  // container.
356  auto rawPort =
357  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
359 
360  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
361  // DialectConversion. Using it may lead to assertions about mutating
362  // replaced/erased ops. For now, do this RAUW directly, until
363  // ConversionPatternRewriter properly supports RAUW.
364  // See https://github.com/llvm/circt/issues/6795.
365  getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
366  }
367  }
368 
369  // Finally, remove the get_port op.
370  rewriter.eraseOp(wrapper);
371  rewriter.eraseOp(op);
372 
373  return success();
374  }
375 };
376 
377 struct PortrefLoweringPass
378  : public circt::kanagawa::impl::KanagawaPortrefLoweringBase<
379  PortrefLoweringPass> {
380  void runOnOperation() override;
381 };
382 
383 } // anonymous namespace
384 
385 void PortrefLoweringPass::runOnOperation() {
386  auto *ctx = &getContext();
387  ConversionTarget target(*ctx);
388  target.addIllegalOp<InputPortOp, OutputPortOp>();
389  target.addLegalDialect<KanagawaDialect>();
390 
391  // Ports are legal when they do not have portref types anymore.
392  target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](auto op) {
393  PortRefType portType =
394  cast<PortRefType>(cast<PortOpInterface>(op).getPort().getType());
395  return !isa<PortRefType>(portType.getPortType());
396  });
397 
398  PortReadOp op;
399 
400  // get_port's are legal when they do not have portref types anymore.
401  target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
402  PortRefType portType = cast<PortRefType>(op.getPort().getType());
403  return !isa<PortRefType>(portType.getPortType());
404  });
405 
406  RewritePatternSet patterns(ctx);
407  patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
408  GetPortConversionPattern>(ctx);
409 
410  if (failed(
411  applyPartialConversion(getOperation(), target, std::move(patterns))))
412  signalPassFailure();
413 }
414 
416  return std::make_unique<PortrefLoweringPass>();
417 }
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1440
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition: CalyxOps.h:76
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21