Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
KanagawaPortrefLowering.cpp
Go to the documentation of this file.
1//===- KanagawaPortrefLowering.cpp - Implementation of PortrefLowering ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Pass/Pass.h"
12
17
18#include "mlir/IR/Builders.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "kanagawa-lower-portrefs"
24
25namespace circt {
26namespace kanagawa {
27#define GEN_PASS_DEF_KANAGAWAPORTREFLOWERING
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29} // namespace kanagawa
30} // namespace circt
31
32using namespace mlir;
33using namespace circt;
34using namespace kanagawa;
35
36namespace {
37
38class InputPortConversionPattern : public OpConversionPattern<InputPortOp> {
39public:
40 using OpConversionPattern::OpConversionPattern;
41 using OpAdaptor = typename OpConversionPattern<InputPortOp>::OpAdaptor;
42
43 LogicalResult
44 matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter) const override {
46 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47 Type innerType = innerPortRefType.getPortType();
48 Direction d = innerPortRefType.getDirection();
49
50 // CSE check - CSE should have ensured that only a single port unwrapper was
51 // present, so if this is not the case, the user should run
52 // CSE. This goes for other assumptions in the following code -
53 // we require a CSEd form to avoid having to deal with a bunch of edge
54 // cases.
55 auto portrefUsers = op.getResult().getUsers();
56 size_t nPortrefUsers =
57 std::distance(portrefUsers.begin(), portrefUsers.end());
58 if (nPortrefUsers != 1)
59 return rewriter.notifyMatchFailure(
60 op,
61 "expected a single kanagawa.port.read as the only user of the input "
62 "port reference, but found multiple readers - please run CSE "
63 "prior to this pass");
64
65 // A single PortReadOp should be present, which unwraps the portref<portref>
66 // into a portref.
67 PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
68 if (!portUnwrapper)
69 return rewriter.notifyMatchFailure(
70 op,
71 "expected a single kanagawa.port.read as the only user of the input "
72 "port reference");
73
74 // Replace the inner portref + port access with a "raw" port.
75 OpBuilder::InsertionGuard g(rewriter);
76 rewriter.setInsertionPoint(op);
77 if (d == Direction::Input) {
78 // references to inputs becomes outputs (write from this container)
79 auto rawOutput = OutputPortOp::create(
80 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
81
82 // Replace writes to the unwrapped port with writes to the new port.
83 for (auto *unwrappedPortUser :
84 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
85 PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
86 if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
87 continue;
88
89 // Replace the source port of the write op with the new port.
90 rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
91 portWriter.getValue());
92 }
93 } else {
94 // References to outputs becomes inputs (read from this container)
95 auto rawInput = InputPortOp::create(
96 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
97 // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
98 // DialectConversion. Using it may lead to assertions about mutating
99 // replaced/erased ops. For now, do this RAUW directly, until
100 // ConversionPatternRewriter properly supports RAUW.
101 // See https://github.com/llvm/circt/issues/6795.
102 portUnwrapper.getResult().replaceAllUsesWith(rawInput);
103
104 // Replace all kanagawa.port.read ops with a read of the new input.
105 for (auto *portUser :
106 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
107 PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
108 if (!portReader || portReader.getPort() != portUnwrapper.getResult())
109 continue;
110
111 rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
112 }
113 }
114
115 // Finally, remove the port unwrapper and the original input port.
116 rewriter.eraseOp(portUnwrapper);
117 rewriter.eraseOp(op);
118
119 return success();
120 }
121};
122
123class OutputPortConversionPattern : public OpConversionPattern<OutputPortOp> {
124public:
125 using OpConversionPattern::OpConversionPattern;
126 using OpAdaptor = typename OpConversionPattern<OutputPortOp>::OpAdaptor;
127
128 LogicalResult
129 matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
132 Type innerType = innerPortRefType.getPortType();
133 Direction d = innerPortRefType.getDirection();
134
135 // Locate the portwrapper - this is a writeOp with the output portref as
136 // the portref value.
137 PortWriteOp portWrapper;
138 for (auto *user : op.getResult().getUsers()) {
139 auto writeOp = dyn_cast<PortWriteOp>(user);
140 if (writeOp && writeOp.getPort() == op.getResult()) {
141 if (portWrapper)
142 return rewriter.notifyMatchFailure(
143 op, "expected a single kanagawa.port.write to wrap the output "
144 "portref, but found multiple");
145 portWrapper = writeOp;
146 break;
147 }
148 }
149
150 if (!portWrapper)
151 return rewriter.notifyMatchFailure(
152 op, "expected an kanagawa.port.write to wrap the output portref");
153
154 OpBuilder::InsertionGuard g(rewriter);
155 rewriter.setInsertionPoint(op);
156 if (d == Direction::Input) {
157 // Outputs of inputs are inputs (external driver into this container).
158 // Create the raw input port and write the input port reference with a
159 // read of the raw input port.
160 auto rawInput = InputPortOp::create(
161 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
162 PortWriteOp::create(rewriter, op.getLoc(), portWrapper.getValue(),
163 PortReadOp::create(rewriter, op.getLoc(), rawInput));
164 } else {
165 // Outputs of outputs are outputs (external driver out of this container).
166 // Create the raw output port and do a read of the input port reference.
167 auto rawOutput = OutputPortOp::create(
168 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
169 PortWriteOp::create(
170 rewriter, op.getLoc(), rawOutput,
171 PortReadOp::create(rewriter, op.getLoc(), portWrapper.getValue()));
172 }
173
174 // Finally, remove the port wrapper and the original output port.
175 rewriter.eraseOp(portWrapper);
176 rewriter.eraseOp(op);
177
178 return success();
179 }
180};
181
182class GetPortConversionPattern : public OpConversionPattern<GetPortOp> {
183 using OpConversionPattern::OpConversionPattern;
184 using OpAdaptor = typename OpConversionPattern<GetPortOp>::OpAdaptor;
185
186 LogicalResult
187 matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
188 ConversionPatternRewriter &rewriter) const override {
189 PortRefType outerPortRefType = cast<PortRefType>(op.getType());
190 PortRefType innerPortRefType =
191 cast<PortRefType>(outerPortRefType.getPortType());
192 Type innerType = innerPortRefType.getPortType();
193
194 Direction outerDirection = outerPortRefType.getDirection();
195 Direction innerDirection = innerPortRefType.getDirection();
196
197 StringAttr portName = op.getPortSymbolAttr().getAttr();
198
199 OpBuilder::InsertionGuard g(rewriter);
200 rewriter.setInsertionPoint(op);
201 Operation *wrapper;
202 if (outerDirection == Direction::Input) {
203 // Locate the get_port wrapper - this is a WriteOp with the get_port
204 // result as the portref value.
205 PortWriteOp getPortWrapper;
206 for (auto *user : op.getResult().getUsers()) {
207 auto writeOp = dyn_cast<PortWriteOp>(user);
208 if (!writeOp || writeOp.getPort() != op.getResult())
209 continue;
210
211 getPortWrapper = writeOp;
212 break;
213 }
214
215 if (!getPortWrapper)
216 return rewriter.notifyMatchFailure(
217 op, "expected an kanagawa.port.write to wrap the get_port result");
218 wrapper = getPortWrapper;
219 LLVM_DEBUG(llvm::dbgs() << "Found wrapper: " << *wrapper);
220 if (innerDirection == Direction::Input) {
221 // The portref<in portref<in T>> is now an output port.
222 auto newGetPort =
223 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
224 innerType, Direction::Output);
225 auto newGetPortVal =
226 PortReadOp::create(rewriter, op.getLoc(), newGetPort);
227 PortWriteOp::create(rewriter, op.getLoc(), getPortWrapper.getValue(),
228 newGetPortVal);
229 } else {
230 // The portref<in portref<out T>> is now an input port.
231 auto newGetPort =
232 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
233 innerType, Direction::Input);
234 auto writeValue = PortReadOp::create(rewriter, op.getLoc(),
235 getPortWrapper.getValue());
236 PortWriteOp::create(rewriter, op.getLoc(), newGetPort, writeValue);
237 }
238 } else {
239 PortReadOp getPortUnwrapper;
240 for (auto *user : op.getResult().getUsers()) {
241 auto readOp = dyn_cast<PortReadOp>(user);
242 if (!readOp || readOp.getPort() != op.getResult())
243 continue;
244
245 getPortUnwrapper = readOp;
246 break;
247 }
248
249 if (!getPortUnwrapper)
250 return rewriter.notifyMatchFailure(
251 op, "expected an kanagawa.port.read to unwrap the get_port result");
252 wrapper = getPortUnwrapper;
253
254 LLVM_DEBUG(llvm::dbgs() << "Found unwrapper: " << *wrapper);
255 if (innerDirection == Direction::Input) {
256 // In this situation, we're retrieving an input port that is sent as an
257 // output of the container: %rr = kanagawa.get_port %c %c_in :
258 // !kanagawa.scoperef<...> -> !kanagawa.portref<out !kanagawa.portref<in
259 // T>>
260 //
261 // Thus we expect one of these cases:
262 // (always). a read op which unwraps the portref<out portref<in T>> into
263 // a portref<in T>
264 // %r = kanagawa.port.read %rr : !kanagawa.portref<out
265 // !kanagawa.portref<in T>>
266 // either:
267 // 1. A write to %r which drives the target input port
268 // kanagawa.port.write %r, %someValue : !kanagawa.portref<in T>
269 // 2. A write using %r which forwards the input port reference
270 // kanagawa.port.write %r_fw, %r : !kanagawa.portref<out
271 // !kanagawa.portref<in T>>
272 //
273 PortWriteOp portDriver;
274 PortWriteOp portForwardingDriver;
275 for (auto *user : getPortUnwrapper.getResult().getUsers()) {
276 auto writeOp = dyn_cast<PortWriteOp>(user);
277 if (!writeOp)
278 continue;
279
280 bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
281 if (isForwarding) {
282 if (portForwardingDriver)
283 return rewriter.notifyMatchFailure(
284 op,
285 "expected a single kanagawa.port.write to use the unwrapped "
286 "get_port result, but found multiple");
287 portForwardingDriver = writeOp;
288 LLVM_DEBUG(llvm::dbgs()
289 << "Found forwarding driver: " << *portForwardingDriver);
290 } else {
291 if (portDriver)
292 return rewriter.notifyMatchFailure(
293 op,
294 "expected a single kanagawa.port.write to use the unwrapped "
295 "get_port result, but found multiple");
296 portDriver = writeOp;
297 LLVM_DEBUG(llvm::dbgs() << "Found driver: " << *portDriver);
298 }
299 }
300
301 if (!portDriver && !portForwardingDriver)
302 return rewriter.notifyMatchFailure(
303 op,
304 "expected an kanagawa.port.write to drive the unwrapped get_port "
305 "result");
306
307 Value portDriverValue;
308 if (portForwardingDriver) {
309 // In the case of forwarding, it is simplest to just create a new
310 // input port, and write the forwarded value to it. This will allow
311 // this pattern to recurse and eventually reach the case where the
312 // forwarding is resolved through reading/writing the intermediate
313 // inputs.
314 auto fwPortName = rewriter.getStringAttr(portName.strref() + "_fw");
315 auto forwardedInputPort = InputPortOp::create(
316 rewriter, op.getLoc(), hw::InnerSymAttr::get(fwPortName),
317 innerType, fwPortName);
318
319 // TODO: RewriterBase::replaceAllUsesWith is not currently supported
320 // by DialectConversion. Using it may lead to assertions about
321 // mutating replaced/erased ops. For now, do this RAUW directly, until
322 // ConversionPatternRewriter properly supports RAUW.
323 // See https://github.com/llvm/circt/issues/6795.
324 getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
325 portDriverValue = PortReadOp::create(rewriter, op.getLoc(),
326 forwardedInputPort.getPort());
327 } else {
328 // Direct assignmenet - the driver value will be the value of
329 // the driver.
330 portDriverValue = portDriver.getValue();
331 rewriter.eraseOp(portDriver);
332 }
333
334 // Perform assignment to the input port of the target instance using
335 // the driver value.
336 auto rawPort =
337 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
338 innerType, Direction::Input);
339 PortWriteOp::create(rewriter, op.getLoc(), rawPort, portDriverValue);
340 } else {
341 // In this situation, we're retrieving an output port that is sent as an
342 // output of the container: %rr = kanagawa.get_port %c %c_in :
343 // !kanagawa.scoperef<...> -> !kanagawa.portref<out
344 // !kanagawa.portref<out T>>
345 //
346 // Thus we expect two ops to be present:
347 // 1. a read op which unwraps the portref<out portref<in T>> into a
348 // portref<in T>
349 // %r = kanagawa.port.read %rr : !kanagawa.portref<out
350 // !kanagawa.portref<in T>>
351 // 2. one (or multiple, if not CSEd)
352 //
353 // We then replace the read op with the actual output port of the
354 // container.
355 auto rawPort =
356 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
357 innerType, Direction::Output);
358
359 // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
360 // DialectConversion. Using it may lead to assertions about mutating
361 // replaced/erased ops. For now, do this RAUW directly, until
362 // ConversionPatternRewriter properly supports RAUW.
363 // See https://github.com/llvm/circt/issues/6795.
364 getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
365 }
366 }
367
368 // Finally, remove the get_port op.
369 rewriter.eraseOp(wrapper);
370 rewriter.eraseOp(op);
371
372 return success();
373 }
374};
375
376struct PortrefLoweringPass
377 : public circt::kanagawa::impl::KanagawaPortrefLoweringBase<
378 PortrefLoweringPass> {
379 void runOnOperation() override;
380};
381
382} // anonymous namespace
383
384void PortrefLoweringPass::runOnOperation() {
385 auto *ctx = &getContext();
386 ConversionTarget target(*ctx);
387 target.addIllegalOp<InputPortOp, OutputPortOp>();
388 target.addLegalDialect<KanagawaDialect>();
389
390 // Ports are legal when they do not have portref types anymore.
391 target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](auto op) {
392 PortRefType portType =
393 cast<PortRefType>(cast<PortOpInterface>(op).getPort().getType());
394 return !isa<PortRefType>(portType.getPortType());
395 });
396
397 PortReadOp op;
398
399 // get_port's are legal when they do not have portref types anymore.
400 target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
401 PortRefType portType = cast<PortRefType>(op.getPort().getType());
402 return !isa<PortRefType>(portType.getPortType());
403 });
404
405 RewritePatternSet patterns(ctx);
406 patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
407 GetPortConversionPattern>(ctx);
408
409 if (failed(
410 applyPartialConversion(getOperation(), target, std::move(patterns))))
411 signalPassFailure();
412}
413
415 return std::make_unique<PortrefLoweringPass>();
416}
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1448
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.