CIRCT 20.0.0git
Loading...
Searching...
No Matches
KanagawaPortrefLowering.cpp
Go to the documentation of this file.
1//===- KanagawaPortrefLowering.cpp - Implementation of PortrefLowering ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Pass/Pass.h"
12
17
18#include "mlir/IR/Builders.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "kanagawa-lower-portrefs"
24
25namespace circt {
26namespace kanagawa {
27#define GEN_PASS_DEF_KANAGAWAPORTREFLOWERING
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
29} // namespace kanagawa
30} // namespace circt
31
32using namespace mlir;
33using namespace circt;
34using namespace kanagawa;
35
36namespace {
37
38class InputPortConversionPattern : public OpConversionPattern<InputPortOp> {
39public:
40 using OpConversionPattern::OpConversionPattern;
41 using OpAdaptor = typename OpConversionPattern<InputPortOp>::OpAdaptor;
42
43 LogicalResult
44 matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter) const override {
46 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47 Type innerType = innerPortRefType.getPortType();
48 Direction d = innerPortRefType.getDirection();
49
50 // CSE check - CSE should have ensured that only a single port unwrapper was
51 // present, so if this is not the case, the user should run
52 // CSE. This goes for other assumptions in the following code -
53 // we require a CSEd form to avoid having to deal with a bunch of edge
54 // cases.
55 auto portrefUsers = op.getResult().getUsers();
56 size_t nPortrefUsers =
57 std::distance(portrefUsers.begin(), portrefUsers.end());
58 if (nPortrefUsers != 1)
59 return rewriter.notifyMatchFailure(
60 op,
61 "expected a single kanagawa.port.read as the only user of the input "
62 "port reference, but found multiple readers - please run CSE "
63 "prior to this pass");
64
65 // A single PortReadOp should be present, which unwraps the portref<portref>
66 // into a portref.
67 PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
68 if (!portUnwrapper)
69 return rewriter.notifyMatchFailure(
70 op,
71 "expected a single kanagawa.port.read as the only user of the input "
72 "port reference");
73
74 // Replace the inner portref + port access with a "raw" port.
75 OpBuilder::InsertionGuard g(rewriter);
76 rewriter.setInsertionPoint(op);
77 if (d == Direction::Input) {
78 // references to inputs becomes outputs (write from this container)
79 auto rawOutput = rewriter.create<OutputPortOp>(
80 op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
81
82 // Replace writes to the unwrapped port with writes to the new port.
83 for (auto *unwrappedPortUser :
84 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
85 PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
86 if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
87 continue;
88
89 // Replace the source port of the write op with the new port.
90 rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
91 portWriter.getValue());
92 }
93 } else {
94 // References to outputs becomes inputs (read from this container)
95 auto rawInput = rewriter.create<InputPortOp>(
96 op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
97 // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
98 // DialectConversion. Using it may lead to assertions about mutating
99 // replaced/erased ops. For now, do this RAUW directly, until
100 // ConversionPatternRewriter properly supports RAUW.
101 // See https://github.com/llvm/circt/issues/6795.
102 portUnwrapper.getResult().replaceAllUsesWith(rawInput);
103
104 // Replace all kanagawa.port.read ops with a read of the new input.
105 for (auto *portUser :
106 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
107 PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
108 if (!portReader || portReader.getPort() != portUnwrapper.getResult())
109 continue;
110
111 rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
112 }
113 }
114
115 // Finally, remove the port unwrapper and the original input port.
116 rewriter.eraseOp(portUnwrapper);
117 rewriter.eraseOp(op);
118
119 return success();
120 }
121};
122
123class OutputPortConversionPattern : public OpConversionPattern<OutputPortOp> {
124public:
125 using OpConversionPattern::OpConversionPattern;
126 using OpAdaptor = typename OpConversionPattern<OutputPortOp>::OpAdaptor;
127
128 LogicalResult
129 matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter) const override {
131 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
132 Type innerType = innerPortRefType.getPortType();
133 Direction d = innerPortRefType.getDirection();
134
135 // Locate the portwrapper - this is a writeOp with the output portref as
136 // the portref value.
137 PortWriteOp portWrapper;
138 for (auto *user : op.getResult().getUsers()) {
139 auto writeOp = dyn_cast<PortWriteOp>(user);
140 if (writeOp && writeOp.getPort() == op.getResult()) {
141 if (portWrapper)
142 return rewriter.notifyMatchFailure(
143 op, "expected a single kanagawa.port.write to wrap the output "
144 "portref, but found multiple");
145 portWrapper = writeOp;
146 break;
147 }
148 }
149
150 if (!portWrapper)
151 return rewriter.notifyMatchFailure(
152 op, "expected an kanagawa.port.write to wrap the output portref");
153
154 OpBuilder::InsertionGuard g(rewriter);
155 rewriter.setInsertionPoint(op);
156 if (d == Direction::Input) {
157 // Outputs of inputs are inputs (external driver into this container).
158 // Create the raw input port and write the input port reference with a
159 // read of the raw input port.
160 auto rawInput = rewriter.create<InputPortOp>(
161 op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
162 rewriter.create<PortWriteOp>(
163 op.getLoc(), portWrapper.getValue(),
164 rewriter.create<PortReadOp>(op.getLoc(), rawInput));
165 } else {
166 // Outputs of outputs are outputs (external driver out of this container).
167 // Create the raw output port and do a read of the input port reference.
168 auto rawOutput = rewriter.create<OutputPortOp>(
169 op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
170 rewriter.create<PortWriteOp>(
171 op.getLoc(), rawOutput,
172 rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
173 }
174
175 // Finally, remove the port wrapper and the original output port.
176 rewriter.eraseOp(portWrapper);
177 rewriter.eraseOp(op);
178
179 return success();
180 }
181};
182
183class GetPortConversionPattern : public OpConversionPattern<GetPortOp> {
184 using OpConversionPattern::OpConversionPattern;
185 using OpAdaptor = typename OpConversionPattern<GetPortOp>::OpAdaptor;
186
187 LogicalResult
188 matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
189 ConversionPatternRewriter &rewriter) const override {
190 PortRefType outerPortRefType = cast<PortRefType>(op.getType());
191 PortRefType innerPortRefType =
192 cast<PortRefType>(outerPortRefType.getPortType());
193 Type innerType = innerPortRefType.getPortType();
194
195 Direction outerDirection = outerPortRefType.getDirection();
196 Direction innerDirection = innerPortRefType.getDirection();
197
198 StringAttr portName = op.getPortSymbolAttr().getAttr();
199
200 OpBuilder::InsertionGuard g(rewriter);
201 rewriter.setInsertionPoint(op);
202 Operation *wrapper;
203 if (outerDirection == Direction::Input) {
204 // Locate the get_port wrapper - this is a WriteOp with the get_port
205 // result as the portref value.
206 PortWriteOp getPortWrapper;
207 for (auto *user : op.getResult().getUsers()) {
208 auto writeOp = dyn_cast<PortWriteOp>(user);
209 if (!writeOp || writeOp.getPort() != op.getResult())
210 continue;
211
212 getPortWrapper = writeOp;
213 break;
214 }
215
216 if (!getPortWrapper)
217 return rewriter.notifyMatchFailure(
218 op, "expected an kanagawa.port.write to wrap the get_port result");
219 wrapper = getPortWrapper;
220 LLVM_DEBUG(llvm::dbgs() << "Found wrapper: " << *wrapper);
221 if (innerDirection == Direction::Input) {
222 // The portref<in portref<in T>> is now an output port.
223 auto newGetPort =
224 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
225 innerType, Direction::Output);
226 auto newGetPortVal =
227 rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
228 rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
229 newGetPortVal);
230 } else {
231 // The portref<in portref<out T>> is now an input port.
232 auto newGetPort =
233 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
234 innerType, Direction::Input);
235 auto writeValue =
236 rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
237 rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
238 }
239 } else {
240 PortReadOp getPortUnwrapper;
241 for (auto *user : op.getResult().getUsers()) {
242 auto readOp = dyn_cast<PortReadOp>(user);
243 if (!readOp || readOp.getPort() != op.getResult())
244 continue;
245
246 getPortUnwrapper = readOp;
247 break;
248 }
249
250 if (!getPortUnwrapper)
251 return rewriter.notifyMatchFailure(
252 op, "expected an kanagawa.port.read to unwrap the get_port result");
253 wrapper = getPortUnwrapper;
254
255 LLVM_DEBUG(llvm::dbgs() << "Found unwrapper: " << *wrapper);
256 if (innerDirection == Direction::Input) {
257 // In this situation, we're retrieving an input port that is sent as an
258 // output of the container: %rr = kanagawa.get_port %c %c_in :
259 // !kanagawa.scoperef<...> -> !kanagawa.portref<out !kanagawa.portref<in
260 // T>>
261 //
262 // Thus we expect one of these cases:
263 // (always). a read op which unwraps the portref<out portref<in T>> into
264 // a portref<in T>
265 // %r = kanagawa.port.read %rr : !kanagawa.portref<out
266 // !kanagawa.portref<in T>>
267 // either:
268 // 1. A write to %r which drives the target input port
269 // kanagawa.port.write %r, %someValue : !kanagawa.portref<in T>
270 // 2. A write using %r which forwards the input port reference
271 // kanagawa.port.write %r_fw, %r : !kanagawa.portref<out
272 // !kanagawa.portref<in T>>
273 //
274 PortWriteOp portDriver;
275 PortWriteOp portForwardingDriver;
276 for (auto *user : getPortUnwrapper.getResult().getUsers()) {
277 auto writeOp = dyn_cast<PortWriteOp>(user);
278 if (!writeOp)
279 continue;
280
281 bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
282 if (isForwarding) {
283 if (portForwardingDriver)
284 return rewriter.notifyMatchFailure(
285 op,
286 "expected a single kanagawa.port.write to use the unwrapped "
287 "get_port result, but found multiple");
288 portForwardingDriver = writeOp;
289 LLVM_DEBUG(llvm::dbgs()
290 << "Found forwarding driver: " << *portForwardingDriver);
291 } else {
292 if (portDriver)
293 return rewriter.notifyMatchFailure(
294 op,
295 "expected a single kanagawa.port.write to use the unwrapped "
296 "get_port result, but found multiple");
297 portDriver = writeOp;
298 LLVM_DEBUG(llvm::dbgs() << "Found driver: " << *portDriver);
299 }
300 }
301
302 if (!portDriver && !portForwardingDriver)
303 return rewriter.notifyMatchFailure(
304 op,
305 "expected an kanagawa.port.write to drive the unwrapped get_port "
306 "result");
307
308 Value portDriverValue;
309 if (portForwardingDriver) {
310 // In the case of forwarding, it is simplest to just create a new
311 // input port, and write the forwarded value to it. This will allow
312 // this pattern to recurse and eventually reach the case where the
313 // forwarding is resolved through reading/writing the intermediate
314 // inputs.
315 auto fwPortName = rewriter.getStringAttr(portName.strref() + "_fw");
316 auto forwardedInputPort = rewriter.create<InputPortOp>(
317 op.getLoc(), hw::InnerSymAttr::get(fwPortName), innerType,
318 fwPortName);
319
320 // TODO: RewriterBase::replaceAllUsesWith is not currently supported
321 // by DialectConversion. Using it may lead to assertions about
322 // mutating replaced/erased ops. For now, do this RAUW directly, until
323 // ConversionPatternRewriter properly supports RAUW.
324 // See https://github.com/llvm/circt/issues/6795.
325 getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
326 portDriverValue = rewriter.create<PortReadOp>(
327 op.getLoc(), forwardedInputPort.getPort());
328 } else {
329 // Direct assignmenet - the driver value will be the value of
330 // the driver.
331 portDriverValue = portDriver.getValue();
332 rewriter.eraseOp(portDriver);
333 }
334
335 // Perform assignment to the input port of the target instance using
336 // the driver value.
337 auto rawPort =
338 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
339 innerType, Direction::Input);
340 rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
341 } else {
342 // In this situation, we're retrieving an output port that is sent as an
343 // output of the container: %rr = kanagawa.get_port %c %c_in :
344 // !kanagawa.scoperef<...> -> !kanagawa.portref<out
345 // !kanagawa.portref<out T>>
346 //
347 // Thus we expect two ops to be present:
348 // 1. a read op which unwraps the portref<out portref<in T>> into a
349 // portref<in T>
350 // %r = kanagawa.port.read %rr : !kanagawa.portref<out
351 // !kanagawa.portref<in T>>
352 // 2. one (or multiple, if not CSEd)
353 //
354 // We then replace the read op with the actual output port of the
355 // container.
356 auto rawPort =
357 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
358 innerType, Direction::Output);
359
360 // TODO: RewriterBase::replaceAllUsesWith is not currently supported by
361 // DialectConversion. Using it may lead to assertions about mutating
362 // replaced/erased ops. For now, do this RAUW directly, until
363 // ConversionPatternRewriter properly supports RAUW.
364 // See https://github.com/llvm/circt/issues/6795.
365 getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
366 }
367 }
368
369 // Finally, remove the get_port op.
370 rewriter.eraseOp(wrapper);
371 rewriter.eraseOp(op);
372
373 return success();
374 }
375};
376
377struct PortrefLoweringPass
378 : public circt::kanagawa::impl::KanagawaPortrefLoweringBase<
379 PortrefLoweringPass> {
380 void runOnOperation() override;
381};
382
383} // anonymous namespace
384
385void PortrefLoweringPass::runOnOperation() {
386 auto *ctx = &getContext();
387 ConversionTarget target(*ctx);
388 target.addIllegalOp<InputPortOp, OutputPortOp>();
389 target.addLegalDialect<KanagawaDialect>();
390
391 // Ports are legal when they do not have portref types anymore.
392 target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](auto op) {
393 PortRefType portType =
394 cast<PortRefType>(cast<PortOpInterface>(op).getPort().getType());
395 return !isa<PortRefType>(portType.getPortType());
396 });
397
398 PortReadOp op;
399
400 // get_port's are legal when they do not have portref types anymore.
401 target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
402 PortRefType portType = cast<PortRefType>(op.getPort().getType());
403 return !isa<PortRefType>(portType.getPortType());
404 });
405
406 RewritePatternSet patterns(ctx);
407 patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
408 GetPortConversionPattern>(ctx);
409
410 if (failed(
411 applyPartialConversion(getOperation(), target, std::move(patterns))))
412 signalPassFailure();
413}
414
416 return std::make_unique<PortrefLoweringPass>();
417}
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1440
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.