// CIRCT 20.0.0git
// IbisPortrefLowering.cpp
// (Doxygen source listing - "Go to the documentation of this file.")
1 //===- IbisPortrefLowering.cpp - Implementation of PortrefLowering --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Pass/Pass.h"
12 
17 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "ibis-lower-portrefs"
24 
25 namespace circt {
26 namespace ibis {
27 #define GEN_PASS_DEF_IBISPORTREFLOWERING
28 #include "circt/Dialect/Ibis/IbisPasses.h.inc"
29 } // namespace ibis
30 } // namespace circt
31 
32 using namespace mlir;
33 using namespace circt;
34 using namespace ibis;
35 
36 namespace {
37 
38 class InputPortConversionPattern : public OpConversionPattern<InputPortOp> {
39 public:
40  using OpConversionPattern::OpConversionPattern;
41  using OpAdaptor = typename OpConversionPattern<InputPortOp>::OpAdaptor;
42 
43  LogicalResult
44  matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45  ConversionPatternRewriter &rewriter) const override {
46  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47  Type innerType = innerPortRefType.getPortType();
48  Direction d = innerPortRefType.getDirection();
49 
50  // CSE check - CSE should have ensured that only a single port unwrapper was
51  // present, so if this is not the case, the user should run
52  // CSE. This goes for other assumptions in the following code -
53  // we require a CSEd form to avoid having to deal with a bunch of edge
54  // cases.
55  auto portrefUsers = op.getResult().getUsers();
56  size_t nPortrefUsers =
57  std::distance(portrefUsers.begin(), portrefUsers.end());
58  if (nPortrefUsers != 1)
59  return rewriter.notifyMatchFailure(
60  op, "expected a single ibis.port.read as the only user of the input "
61  "port reference, but found multiple readers - please run CSE "
62  "prior to this pass");
63 
64  // A single PortReadOp should be present, which unwraps the portref<portref>
65  // into a portref.
66  PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
67  if (!portUnwrapper)
68  return rewriter.notifyMatchFailure(
69  op, "expected a single ibis.port.read as the only user of the input "
70  "port reference");
71 
72  // Replace the inner portref + port access with a "raw" port.
73  OpBuilder::InsertionGuard g(rewriter);
74  rewriter.setInsertionPoint(op);
75  if (d == Direction::Input) {
76  // references to inputs becomes outputs (write from this container)
77  auto rawOutput = rewriter.create<OutputPortOp>(
78  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
79 
80  // Replace writes to the unwrapped port with writes to the new port.
81  for (auto *unwrappedPortUser :
82  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
83  PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
84  if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
85  continue;
86 
87  // Replace the source port of the write op with the new port.
88  rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
89  portWriter.getValue());
90  }
91  } else {
92  // References to outputs becomes inputs (read from this container)
93  auto rawInput = rewriter.create<InputPortOp>(
94  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
95  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
96  // DialectConversion. Using it may lead to assertions about mutating
97  // replaced/erased ops. For now, do this RAUW directly, until
98  // ConversionPatternRewriter properly supports RAUW.
99  // See https://github.com/llvm/circt/issues/6795.
100  portUnwrapper.getResult().replaceAllUsesWith(rawInput);
101 
102  // Replace all ibis.port.read ops with a read of the new input.
103  for (auto *portUser :
104  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
105  PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
106  if (!portReader || portReader.getPort() != portUnwrapper.getResult())
107  continue;
108 
109  rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
110  }
111  }
112 
113  // Finally, remove the port unwrapper and the original input port.
114  rewriter.eraseOp(portUnwrapper);
115  rewriter.eraseOp(op);
116 
117  return success();
118  }
119 };
120 
121 class OutputPortConversionPattern : public OpConversionPattern<OutputPortOp> {
122 public:
123  using OpConversionPattern::OpConversionPattern;
124  using OpAdaptor = typename OpConversionPattern<OutputPortOp>::OpAdaptor;
125 
126  LogicalResult
127  matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
128  ConversionPatternRewriter &rewriter) const override {
129  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
130  Type innerType = innerPortRefType.getPortType();
131  Direction d = innerPortRefType.getDirection();
132 
133  // Locate the portwrapper - this is a writeOp with the output portref as
134  // the portref value.
135  PortWriteOp portWrapper;
136  for (auto *user : op.getResult().getUsers()) {
137  auto writeOp = dyn_cast<PortWriteOp>(user);
138  if (writeOp && writeOp.getPort() == op.getResult()) {
139  if (portWrapper)
140  return rewriter.notifyMatchFailure(
141  op, "expected a single ibis.port.write to wrap the output "
142  "portref, but found multiple");
143  portWrapper = writeOp;
144  break;
145  }
146  }
147 
148  if (!portWrapper)
149  return rewriter.notifyMatchFailure(
150  op, "expected an ibis.port.write to wrap the output portref");
151 
152  OpBuilder::InsertionGuard g(rewriter);
153  rewriter.setInsertionPoint(op);
154  if (d == Direction::Input) {
155  // Outputs of inputs are inputs (external driver into this container).
156  // Create the raw input port and write the input port reference with a
157  // read of the raw input port.
158  auto rawInput = rewriter.create<InputPortOp>(
159  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
160  rewriter.create<PortWriteOp>(
161  op.getLoc(), portWrapper.getValue(),
162  rewriter.create<PortReadOp>(op.getLoc(), rawInput));
163  } else {
164  // Outputs of outputs are outputs (external driver out of this container).
165  // Create the raw output port and do a read of the input port reference.
166  auto rawOutput = rewriter.create<OutputPortOp>(
167  op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
168  rewriter.create<PortWriteOp>(
169  op.getLoc(), rawOutput,
170  rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
171  }
172 
173  // Finally, remove the port wrapper and the original output port.
174  rewriter.eraseOp(portWrapper);
175  rewriter.eraseOp(op);
176 
177  return success();
178  }
179 };
180 
/// Lowers an `ibis.get_port` whose result is a reference to a port reference
/// (`!ibis.portref<outer !ibis.portref<inner T>>`) into an `ibis.get_port` of
/// the raw `T` port on the target instance. The single user that wraps the
/// result (an `ibis.port.write` when `outer` is `in`) or unwraps it (an
/// `ibis.port.read` when `outer` is `out`) is rewired and then erased together
/// with the original get_port.
///
/// NOTE(review): this listing is missing some source lines (Doxygen line
/// numbers 223, 232, 332 and 349 are absent). Each appears to be the
/// continuation of a `rewriter.create<GetPortOp>(...)` call, whose remaining
/// arguments are therefore not visible here.
/// NOTE(review): unlike the sibling patterns, this class has no `public:`
/// label, so its members are private.
181 class GetPortConversionPattern : public OpConversionPattern<GetPortOp> {
182  using OpConversionPattern::OpConversionPattern;
183  using OpAdaptor = typename OpConversionPattern<GetPortOp>::OpAdaptor;
184 
185  LogicalResult
186  matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
187  ConversionPatternRewriter &rewriter) const override {
188  PortRefType outerPortRefType = cast<PortRefType>(op.getType());
189  PortRefType innerPortRefType =
190  cast<PortRefType>(outerPortRefType.getPortType());
191  Type innerType = innerPortRefType.getPortType();
192 
193  Direction outerDirection = outerPortRefType.getDirection();
194  Direction innerDirection = innerPortRefType.getDirection();
195 
196  StringAttr portName = op.getPortSymbolAttr().getAttr();
197 
198  OpBuilder::InsertionGuard g(rewriter);
199  rewriter.setInsertionPoint(op);
 // The wrapping write / unwrapping read located below; it is erased at the
 // end together with `op`. Every path that reaches the end assigns it.
200  Operation *wrapper;
201  if (outerDirection == Direction::Input) {
202  // Locate the get_port wrapper - this is a WriteOp with the get_port
203  // result as the portref value. Only the first matching write is used;
 // additional wrappers, if any, are not diagnosed here.
204  PortWriteOp getPortWrapper;
205  for (auto *user : op.getResult().getUsers()) {
206  auto writeOp = dyn_cast<PortWriteOp>(user);
207  if (!writeOp || writeOp.getPort() != op.getResult())
208  continue;
209 
210  getPortWrapper = writeOp;
211  break;
212  }
213 
214  if (!getPortWrapper)
215  return rewriter.notifyMatchFailure(
216  op, "expected an ibis.port.write to wrap the get_port result");
217  wrapper = getPortWrapper;
218  LLVM_DEBUG(llvm::dbgs() << "Found wrapper: " << *wrapper);
219  if (innerDirection == Direction::Input) {
220  // The portref<in portref<in T>> is now an output port.
 // Read the raw port and write that value into the wrapped reference.
 // NOTE(review): Doxygen line 223 (rest of this create call) is missing.
221  auto newGetPort =
222  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
224  auto newGetPortVal =
225  rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
226  rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
227  newGetPortVal);
228  } else {
229  // The portref<in portref<out T>> is now an input port.
 // Read the wrapped reference and drive the raw port with that value.
 // NOTE(review): Doxygen line 232 (rest of this create call) is missing.
230  auto newGetPort =
231  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
233  auto writeValue =
234  rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
235  rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
236  }
237  } else {
 // Outer direction is `out`: locate the read op that unwraps the get_port
 // result. Only the first matching read is used.
238  PortReadOp getPortUnwrapper;
239  for (auto *user : op.getResult().getUsers()) {
240  auto readOp = dyn_cast<PortReadOp>(user);
241  if (!readOp || readOp.getPort() != op.getResult())
242  continue;
243 
244  getPortUnwrapper = readOp;
245  break;
246  }
247 
248  if (!getPortUnwrapper)
249  return rewriter.notifyMatchFailure(
250  op, "expected an ibis.port.read to unwrap the get_port result");
251  wrapper = getPortUnwrapper;
252 
253  LLVM_DEBUG(llvm::dbgs() << "Found unwrapper: " << *wrapper);
254  if (innerDirection == Direction::Input) {
255  // In this situation, we're retrieving an input port that is sent as an
256  // output of the container: %rr = ibis.get_port %c %c_in :
257  // !ibis.scoperef<...> -> !ibis.portref<out !ibis.portref<in T>>
258  //
259  // Thus we expect one of these cases:
260  // (always). a read op which unwraps the portref<out portref<in T>> into
261  // a portref<in T>
262  // %r = ibis.port.read %rr : !ibis.portref<out !ibis.portref<in T>>
263  // either:
264  // 1. A write to %r which drives the target input port
265  // ibis.port.write %r, %someValue : !ibis.portref<in T>
266  // 2. A write using %r which forwards the input port reference
267  // ibis.port.write %r_fw, %r : !ibis.portref<out !ibis.portref<in
268  // T>>
269  //
 // At most one write of each kind is accepted.
270  PortWriteOp portDriver;
271  PortWriteOp portForwardingDriver;
272  for (auto *user : getPortUnwrapper.getResult().getUsers()) {
273  auto writeOp = dyn_cast<PortWriteOp>(user);
274  if (!writeOp)
275  continue;
276 
 // A write whose port operand is not %r uses %r as the written value,
 // i.e. it forwards the reference (case 2 above).
277  bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
278  if (isForwarding) {
279  if (portForwardingDriver)
280  return rewriter.notifyMatchFailure(
281  op, "expected a single ibis.port.write to use the unwrapped "
282  "get_port result, but found multiple");
283  portForwardingDriver = writeOp;
284  LLVM_DEBUG(llvm::dbgs()
285  << "Found forwarding driver: " << *portForwardingDriver);
286  } else {
287  if (portDriver)
288  return rewriter.notifyMatchFailure(
289  op, "expected a single ibis.port.write to use the unwrapped "
290  "get_port result, but found multiple");
291  portDriver = writeOp;
292  LLVM_DEBUG(llvm::dbgs() << "Found driver: " << *portDriver);
293  }
294  }
295 
296  if (!portDriver && !portForwardingDriver)
297  return rewriter.notifyMatchFailure(
298  op, "expected an ibis.port.write to drive the unwrapped get_port "
299  "result");
300 
 // NOTE(review): if both a direct driver and a forwarding driver are
 // present, only the forwarding path below runs. The direct driver is
 // not erased; the RAUW redirects its port operand to the new `_fw`
 // input port. Confirm this is intended.
301  Value portDriverValue;
302  if (portForwardingDriver) {
303  // In the case of forwarding, it is simplest to just create a new
304  // input port, and write the forwarded value to it. This will allow
305  // this pattern to recurse and eventually reach the case where the
306  // forwarding is resolved through reading/writing the intermediate
307  // inputs.
308  auto fwPortName = rewriter.getStringAttr(portName.strref() + "_fw");
309  auto forwardedInputPort = rewriter.create<InputPortOp>(
310  op.getLoc(), hw::InnerSymAttr::get(fwPortName), innerType,
311  fwPortName);
312 
313  // TODO: RewriterBase::replaceAllUsesWith is not currently supported
314  // by DialectConversion. Using it may lead to assertions about
315  // mutating replaced/erased ops. For now, do this RAUW directly, until
316  // ConversionPatternRewriter properly supports RAUW.
317  // See https://github.com/llvm/circt/issues/6795.
318  getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
319  portDriverValue = rewriter.create<PortReadOp>(
320  op.getLoc(), forwardedInputPort.getPort());
321  } else {
322  // Direct assignment - the driver value will be the value of
323  // the driver.
324  portDriverValue = portDriver.getValue();
325  rewriter.eraseOp(portDriver);
326  }
327 
328  // Perform assignment to the input port of the target instance using
329  // the driver value.
 // NOTE(review): Doxygen line 332 (rest of this create call) is missing.
330  auto rawPort =
331  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
333  rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
334  } else {
335  // In this situation, we're retrieving an output port that is sent as an
336  // output of the container: %rr = ibis.get_port %c %c_out :
337  // !ibis.scoperef<...> -> !ibis.portref<out !ibis.portref<out T>>
338  //
339  // Thus we expect two ops to be present:
340  // 1. a read op which unwraps the portref<out portref<out T>> into a
341  // portref<out T>
342  // %r = ibis.port.read %rr : !ibis.portref<out !ibis.portref<out T>>
343  // 2. one (or multiple, if not CSEd) users of the unwrapped %r.
344  //
345  // We then redirect all uses of the read op's result to a get_port of
346  // the instance's raw output port.
 // NOTE(review): Doxygen line 349 (rest of this create call) is missing.
347  auto rawPort =
348  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
350 
351  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
352  // DialectConversion. Using it may lead to assertions about mutating
353  // replaced/erased ops. For now, do this RAUW directly, until
354  // ConversionPatternRewriter properly supports RAUW.
355  // See https://github.com/llvm/circt/issues/6795.
356  getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
357  }
358  }
359 
360  // Finally, remove the wrapper/unwrapper and the original get_port op.
361  rewriter.eraseOp(wrapper);
362  rewriter.eraseOp(op);
363 
364  return success();
365  }
366 };
367 
368 struct PortrefLoweringPass
369  : public circt::ibis::impl::IbisPortrefLoweringBase<PortrefLoweringPass> {
370  void runOnOperation() override;
371 };
372 
373 } // anonymous namespace
374 
375 void PortrefLoweringPass::runOnOperation() {
376  auto *ctx = &getContext();
377  ConversionTarget target(*ctx);
378  target.addIllegalOp<InputPortOp, OutputPortOp>();
379  target.addLegalDialect<IbisDialect>();
380 
381  // Ports are legal when they do not have portref types anymore.
382  target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](auto op) {
383  PortRefType portType =
384  cast<PortRefType>(cast<PortOpInterface>(op).getPort().getType());
385  return !isa<PortRefType>(portType.getPortType());
386  });
387 
388  PortReadOp op;
389 
390  // get_port's are legal when they do not have portref types anymore.
391  target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
392  PortRefType portType = cast<PortRefType>(op.getPort().getType());
393  return !isa<PortRefType>(portType.getPortType());
394  });
395 
396  RewritePatternSet patterns(ctx);
397  patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
398  GetPortConversionPattern>(ctx);
399 
400  if (failed(
401  applyPartialConversion(getOperation(), target, std::move(patterns))))
402  signalPassFailure();
403 }
404 
405 std::unique_ptr<Pass> circt::ibis::createPortrefLoweringPass() {
406  return std::make_unique<PortrefLoweringPass>();
407 }
/* Doxygen cross-reference tooltips captured together with this listing:

static PortInfo getPort(ModuleTy &mod, size_t idx)
  Definition: HWOps.cpp:1434
@ Input
  Definition: HW.h:35
@ Output
  Definition: HW.h:35
Direction get(bool isOutput)
  Returns an output direction if isOutput is true, otherwise returns an input direction.
  Definition: CalyxOps.cpp:55
Direction
  The direction of a Component or Cell port.
  Definition: CalyxOps.h:72
mlir::Type innerType(mlir::Type type)
  Definition: ESITypes.cpp:184
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
  Definition: DebugAnalysis.h:21
*/