CIRCT  19.0.0git
IbisPortrefLowering.cpp
Go to the documentation of this file.
1 //===- IbisPortrefLowering.cpp - Implementation of PortrefLowering --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "ibis-lower-portrefs"
22 
23 using namespace mlir;
24 using namespace circt;
25 using namespace ibis;
26 
27 namespace {
28 
29 class InputPortConversionPattern : public OpConversionPattern<InputPortOp> {
30 public:
31  using OpConversionPattern::OpConversionPattern;
32  using OpAdaptor = typename OpConversionPattern<InputPortOp>::OpAdaptor;
33 
34  LogicalResult
35  matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
36  ConversionPatternRewriter &rewriter) const override {
37  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
38  Type innerType = innerPortRefType.getPortType();
39  Direction d = innerPortRefType.getDirection();
40 
41  // CSE check - CSE should have ensured that only a single port unwrapper was
42  // present, so if this is not the case, the user should run
43  // CSE. This goes for other assumptions in the following code -
44  // we require a CSEd form to avoid having to deal with a bunch of edge
45  // cases.
46  auto portrefUsers = op.getResult().getUsers();
47  size_t nPortrefUsers =
48  std::distance(portrefUsers.begin(), portrefUsers.end());
49  if (nPortrefUsers != 1)
50  return rewriter.notifyMatchFailure(
51  op, "expected a single ibis.port.read as the only user of the input "
52  "port reference, but found multiple readers - please run CSE "
53  "prior to this pass");
54 
55  // A single PortReadOp should be present, which unwraps the portref<portref>
56  // into a portref.
57  PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
58  if (!portUnwrapper)
59  return rewriter.notifyMatchFailure(
60  op, "expected a single ibis.port.read as the only user of the input "
61  "port reference");
62 
63  // Replace the inner portref + port access with a "raw" port.
64  OpBuilder::InsertionGuard g(rewriter);
65  rewriter.setInsertionPoint(op);
66  if (d == Direction::Input) {
67  // references to inputs becomes outputs (write from this container)
68  auto rawOutput = rewriter.create<OutputPortOp>(
69  op.getLoc(), op.getNameAttr(), op.getInnerSym(), innerType);
70 
71  // Replace writes to the unwrapped port with writes to the new port.
72  for (auto *unwrappedPortUser :
73  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
74  PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
75  if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
76  continue;
77 
78  // Replace the source port of the write op with the new port.
79  rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
80  portWriter.getValue());
81  }
82  } else {
83  // References to outputs becomes inputs (read from this container)
84  auto rawInput = rewriter.create<InputPortOp>(
85  op.getLoc(), op.getNameAttr(), op.getInnerSym(), innerType);
86  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
87  // DialectConversion. Using it may lead to assertions about mutating
88  // replaced/erased ops. For now, do this RAUW directly, until
89  // ConversionPatternRewriter properly supports RAUW.
90  // See https://github.com/llvm/circt/issues/6795.
91  portUnwrapper.getResult().replaceAllUsesWith(rawInput);
92 
93  // Replace all ibis.port.read ops with a read of the new input.
94  for (auto *portUser :
95  llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
96  PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
97  if (!portReader || portReader.getPort() != portUnwrapper.getResult())
98  continue;
99 
100  rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
101  }
102  }
103 
104  // Finally, remove the port unwrapper and the original input port.
105  rewriter.eraseOp(portUnwrapper);
106  rewriter.eraseOp(op);
107 
108  return success();
109  }
110 };
111 
112 class OutputPortConversionPattern : public OpConversionPattern<OutputPortOp> {
113 public:
114  using OpConversionPattern::OpConversionPattern;
115  using OpAdaptor = typename OpConversionPattern<OutputPortOp>::OpAdaptor;
116 
117  LogicalResult
118  matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
119  ConversionPatternRewriter &rewriter) const override {
120  PortRefType innerPortRefType = cast<PortRefType>(op.getType());
121  Type innerType = innerPortRefType.getPortType();
122  Direction d = innerPortRefType.getDirection();
123 
124  // Locate the portwrapper - this is a writeOp with the output portref as
125  // the portref value.
126  PortWriteOp portWrapper;
127  for (auto *user : op.getResult().getUsers()) {
128  auto writeOp = dyn_cast<PortWriteOp>(user);
129  if (writeOp && writeOp.getPort() == op.getResult()) {
130  if (portWrapper)
131  return rewriter.notifyMatchFailure(
132  op, "expected a single ibis.port.write to wrap the output "
133  "portref, but found multiple");
134  portWrapper = writeOp;
135  break;
136  }
137  }
138 
139  if (!portWrapper)
140  return rewriter.notifyMatchFailure(
141  op, "expected an ibis.port.write to wrap the output portref");
142 
143  OpBuilder::InsertionGuard g(rewriter);
144  rewriter.setInsertionPoint(op);
145  if (d == Direction::Input) {
146  // Outputs of inputs are inputs (external driver into this container).
147  // Create the raw input port and write the input port reference with a
148  // read of the raw input port.
149  auto rawInput = rewriter.create<InputPortOp>(
150  op.getLoc(), op.getNameAttr(), op.getInnerSym(), innerType);
151  rewriter.create<PortWriteOp>(
152  op.getLoc(), portWrapper.getValue(),
153  rewriter.create<PortReadOp>(op.getLoc(), rawInput));
154  } else {
155  // Outputs of outputs are outputs (external driver out of this container).
156  // Create the raw output port and do a read of the input port reference.
157  auto rawOutput = rewriter.create<OutputPortOp>(
158  op.getLoc(), op.getNameAttr(), op.getInnerSym(), innerType);
159  rewriter.create<PortWriteOp>(
160  op.getLoc(), rawOutput,
161  rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
162  }
163 
164  // Finally, remove the port wrapper and the original output port.
165  rewriter.eraseOp(portWrapper);
166  rewriter.eraseOp(op);
167 
168  return success();
169  }
170 };
171 
172 class GetPortConversionPattern : public OpConversionPattern<GetPortOp> {
173  using OpConversionPattern::OpConversionPattern;
174  using OpAdaptor = typename OpConversionPattern<GetPortOp>::OpAdaptor;
175 
176  LogicalResult
177  matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
178  ConversionPatternRewriter &rewriter) const override {
179  PortRefType outerPortRefType = cast<PortRefType>(op.getType());
180  PortRefType innerPortRefType =
181  cast<PortRefType>(outerPortRefType.getPortType());
182  Type innerType = innerPortRefType.getPortType();
183 
184  Direction outerDirection = outerPortRefType.getDirection();
185  Direction innerDirection = innerPortRefType.getDirection();
186 
187  StringAttr portName = op.getPortSymbolAttr().getAttr();
188 
189  OpBuilder::InsertionGuard g(rewriter);
190  rewriter.setInsertionPoint(op);
191  Operation *wrapper;
192  if (outerDirection == Direction::Input) {
193  // Locate the get_port wrapper - this is a WriteOp with the get_port
194  // result as the portref value.
195  PortWriteOp getPortWrapper;
196  for (auto *user : op.getResult().getUsers()) {
197  auto writeOp = dyn_cast<PortWriteOp>(user);
198  if (!writeOp || writeOp.getPort() != op.getResult())
199  continue;
200 
201  getPortWrapper = writeOp;
202  break;
203  }
204 
205  if (!getPortWrapper)
206  return rewriter.notifyMatchFailure(
207  op, "expected an ibis.port.write to wrap the get_port result");
208  wrapper = getPortWrapper;
209  LLVM_DEBUG(llvm::dbgs() << "Found wrapper: " << *wrapper);
210  if (innerDirection == Direction::Input) {
211  // The portref<in portref<in T>> is now an output port.
212  auto newGetPort =
213  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
215  auto newGetPortVal =
216  rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
217  rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
218  newGetPortVal);
219  } else {
220  // The portref<in portref<out T>> is now an input port.
221  auto newGetPort =
222  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
224  auto writeValue =
225  rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
226  rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
227  }
228  } else {
229  PortReadOp getPortUnwrapper;
230  for (auto *user : op.getResult().getUsers()) {
231  auto readOp = dyn_cast<PortReadOp>(user);
232  if (!readOp || readOp.getPort() != op.getResult())
233  continue;
234 
235  getPortUnwrapper = readOp;
236  break;
237  }
238 
239  if (!getPortUnwrapper)
240  return rewriter.notifyMatchFailure(
241  op, "expected an ibis.port.read to unwrap the get_port result");
242  wrapper = getPortUnwrapper;
243 
244  LLVM_DEBUG(llvm::dbgs() << "Found unwrapper: " << *wrapper);
245  if (innerDirection == Direction::Input) {
246  // In this situation, we're retrieving an input port that is sent as an
247  // output of the container: %rr = ibis.get_port %c %c_in :
248  // !ibis.scoperef<...> -> !ibis.portref<out !ibis.portref<in T>>
249  //
250  // Thus we expect one of these cases:
251  // (always). a read op which unwraps the portref<out portref<in T>> into
252  // a portref<in T>
253  // %r = ibis.port.read %rr : !ibis.portref<out !ibis.portref<in T>>
254  // either:
255  // 1. A write to %r which drives the target input port
256  // ibis.port.write %r, %someValue : !ibis.portref<in T>
257  // 2. A write using %r which forwards the input port reference
258  // ibis.port.write %r_fw, %r : !ibis.portref<out !ibis.portref<in
259  // T>>
260  //
261  PortWriteOp portDriver;
262  PortWriteOp portForwardingDriver;
263  for (auto *user : getPortUnwrapper.getResult().getUsers()) {
264  auto writeOp = dyn_cast<PortWriteOp>(user);
265  if (!writeOp)
266  continue;
267 
268  bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
269  if (isForwarding) {
270  if (portForwardingDriver)
271  return rewriter.notifyMatchFailure(
272  op, "expected a single ibis.port.write to use the unwrapped "
273  "get_port result, but found multiple");
274  portForwardingDriver = writeOp;
275  LLVM_DEBUG(llvm::dbgs()
276  << "Found forwarding driver: " << *portForwardingDriver);
277  } else {
278  if (portDriver)
279  return rewriter.notifyMatchFailure(
280  op, "expected a single ibis.port.write to use the unwrapped "
281  "get_port result, but found multiple");
282  portDriver = writeOp;
283  LLVM_DEBUG(llvm::dbgs() << "Found driver: " << *portDriver);
284  }
285  }
286 
287  if (!portDriver && !portForwardingDriver)
288  return rewriter.notifyMatchFailure(
289  op, "expected an ibis.port.write to drive the unwrapped get_port "
290  "result");
291 
292  Value portDriverValue;
293  if (portForwardingDriver) {
294  // In the case of forwarding, it is simplest to just create a new
295  // input port, and write the forwarded value to it. This will allow
296  // this pattern to recurse and eventually reach the case where the
297  // forwarding is resolved through reading/writing the intermediate
298  // inputs.
299  auto fwPortName = rewriter.getStringAttr(portName.strref() + "_fw");
300  auto forwardedInputPort = rewriter.create<InputPortOp>(
301  op.getLoc(), fwPortName, hw::InnerSymAttr::get(fwPortName),
302  innerType);
303 
304  // TODO: RewriterBase::replaceAllUsesWith is not currently supported
305  // by DialectConversion. Using it may lead to assertions about
306  // mutating replaced/erased ops. For now, do this RAUW directly, until
307  // ConversionPatternRewriter properly supports RAUW.
308  // See https://github.com/llvm/circt/issues/6795.
309  getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
310  portDriverValue = rewriter.create<PortReadOp>(
311  op.getLoc(), forwardedInputPort.getPort());
312  } else {
313  // Direct assignmenet - the driver value will be the value of
314  // the driver.
315  portDriverValue = portDriver.getValue();
316  rewriter.eraseOp(portDriver);
317  }
318 
319  // Perform assignment to the input port of the target instance using
320  // the driver value.
321  auto rawPort =
322  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
324  rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
325  } else {
326  // In this situation, we're retrieving an output port that is sent as an
327  // output of the container: %rr = ibis.get_port %c %c_in :
328  // !ibis.scoperef<...> -> !ibis.portref<out !ibis.portref<out T>>
329  //
330  // Thus we expect two ops to be present:
331  // 1. a read op which unwraps the portref<out portref<in T>> into a
332  // portref<in T>
333  // %r = ibis.port.read %rr : !ibis.portref<out !ibis.portref<in T>>
334  // 2. one (or multiple, if not CSEd)
335  //
336  // We then replace the read op with the actual output port of the
337  // container.
338  auto rawPort =
339  rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
341 
342  // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
343  // DialectConversion. Using it may lead to assertions about mutating
344  // replaced/erased ops. For now, do this RAUW directly, until
345  // ConversionPatternRewriter properly supports RAUW.
346  // See https://github.com/llvm/circt/issues/6795.
347  getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
348  }
349  }
350 
351  // Finally, remove the get_port op.
352  rewriter.eraseOp(wrapper);
353  rewriter.eraseOp(op);
354 
355  return success();
356  }
357 };
358 
359 struct PortrefLoweringPass
360  : public IbisPortrefLoweringBase<PortrefLoweringPass> {
361  void runOnOperation() override;
362 };
363 
364 } // anonymous namespace
365 
366 void PortrefLoweringPass::runOnOperation() {
367  auto *ctx = &getContext();
368  ConversionTarget target(*ctx);
369  target.addIllegalOp<InputPortOp, OutputPortOp>();
370  target.addLegalDialect<IbisDialect>();
371 
372  // Ports are legal when they do not have portref types anymore.
373  target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](auto op) {
374  PortRefType portType =
375  cast<PortRefType>(cast<PortOpInterface>(op).getPort().getType());
376  return !isa<PortRefType>(portType.getPortType());
377  });
378 
379  PortReadOp op;
380 
381  // get_port's are legal when they do not have portref types anymore.
382  target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
383  PortRefType portType = cast<PortRefType>(op.getPort().getType());
384  return !isa<PortRefType>(portType.getPortType());
385  });
386 
387  RewritePatternSet patterns(ctx);
388  patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
389  GetPortConversionPattern>(ctx);
390 
391  if (failed(
392  applyPartialConversion(getOperation(), target, std::move(patterns))))
393  signalPassFailure();
394 }
395 
396 std::unique_ptr<Pass> circt::ibis::createPortrefLoweringPass() {
397  return std::make_unique<PortrefLoweringPass>();
398 }
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1434
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Direction
The direction of a Component or Cell port.
Definition: CalyxOps.h:72
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21