11#include "mlir/Pass/Pass.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "kanagawa-lower-portrefs"
27#define GEN_PASS_DEF_KANAGAWAPORTREFLOWERING
28#include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
34using namespace kanagawa;
// Lowers a kanagawa.port.input whose carried type (op.getType()) is itself a
// PortRefType, i.e. an input port that transports a port reference. The op's
// single kanagawa.port.read user (the "unwrapper") is consumed. The port is
// replaced by a raw port of the inner type; which kind of raw port depends on
// the direction of the carried portref.
40 using OpConversionPattern::OpConversionPattern;
44 matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter)
const override {
// Split the carried portref into its value type and its direction.
46 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47 Type
innerType = innerPortRefType.getPortType();
48 Direction d = innerPortRefType.getDirection();
// The port result must have exactly one user. Counting walks the whole use
// list; the failure message asks the user to run CSE so duplicate reads merge.
// NOTE(review): this branch also triggers when there are zero users, where the
// "found multiple readers" wording is misleading.
55 auto portrefUsers = op.getResult().getUsers();
56 size_t nPortrefUsers =
57 std::distance(portrefUsers.begin(), portrefUsers.end());
58 if (nPortrefUsers != 1)
59 return rewriter.notifyMatchFailure(
61 "expected a single kanagawa.port.read as the only user of the input "
62 "port reference, but found multiple readers - please run CSE "
63 "prior to this pass");
// That single user must be a kanagawa.port.read; its result is the unwrapped
// (inner) portref.
67 PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
69 return rewriter.notifyMatchFailure(
71 "expected a single kanagawa.port.read as the only user of the input "
75 OpBuilder::InsertionGuard g(rewriter);
76 rewriter.setInsertionPoint(op);
// Replacement ops are inserted at the original op's position; the guard
// restores the rewriter's previous insertion point when the scope exits.
// Carried portref is an input: create a raw output port of the inner type
// that reuses the original inner symbol and name.
77 if (d == Direction::Input) {
79 auto rawOutput = OutputPortOp::create(
80 rewriter, op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
// Redirect every kanagawa.port.write that targets the unwrapped portref to
// rawOutput, keeping the written value. make_early_inc_range keeps the
// iteration valid while users are replaced. Only writes whose port operand
// is the unwrapped portref are rewritten.
83 for (
auto *unwrappedPortUser :
84 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
85 PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
86 if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
90 rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
91 portWriter.getValue());
// Otherwise: create a raw input port of the inner type (same inner symbol
// and name) and redirect all uses of the unwrapped portref to it.
95 auto rawInput = InputPortOp::create(
96 rewriter, op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
102 portUnwrapper.getResult().replaceAllUsesWith(rawInput);
// Rewrite kanagawa.port.read users of the unwrapped portref into reads of
// rawInput.
// NOTE(review): in the visible code this loop runs right after
// replaceAllUsesWith on the same value, which should leave it with no users,
// so the loop appears to do nothing. Confirm against the full source.
105 for (
auto *portUser :
106 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
107 PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
108 if (!portReader || portReader.getPort() != portUnwrapper.getResult())
111 rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
// The unwrapping read and the original port op are erased.
116 rewriter.eraseOp(portUnwrapper);
117 rewriter.eraseOp(op);
// Lowers a kanagawa.port.output whose carried type (op.getType()) is itself a
// PortRefType. The single kanagawa.port.write that stores a portref into this
// port (the "wrapper") is consumed, and the port is replaced by a raw port of
// the inner type.
125 using OpConversionPattern::OpConversionPattern;
129 matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter)
const override {
// Split the carried portref into its value type and its direction.
131 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
132 Type
innerType = innerPortRefType.getPortType();
133 Direction d = innerPortRefType.getDirection();
// Find the kanagawa.port.write whose target port is this op's result. Finding
// more than one such write makes the match fail.
137 PortWriteOp portWrapper;
138 for (
auto *user : op.getResult().getUsers()) {
139 auto writeOp = dyn_cast<PortWriteOp>(user);
140 if (writeOp && writeOp.getPort() == op.getResult()) {
142 return rewriter.notifyMatchFailure(
143 op,
"expected a single kanagawa.port.write to wrap the output "
144 "portref, but found multiple");
145 portWrapper = writeOp;
// Fail if no wrapping write was found.
151 return rewriter.notifyMatchFailure(
152 op,
"expected an kanagawa.port.write to wrap the output portref");
// Replacement ops are inserted at the original op's position; the guard
// restores the previous insertion point when the scope exits.
154 OpBuilder::InsertionGuard g(rewriter);
155 rewriter.setInsertionPoint(op);
156 if (d == Direction::Input) {
// Carried portref is an input: create a raw input port of the inner type
// (same inner symbol and name). Write the value read from it into the
// portref that the wrapper stored (portWrapper.getValue()).
160 auto rawInput = InputPortOp::create(
161 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
162 PortWriteOp::create(rewriter, op.getLoc(), portWrapper.getValue(),
163 PortReadOp::create(rewriter, op.getLoc(), rawInput));
// Otherwise: create a raw output port of the inner type and drive it with a
// value read from the portref that the wrapper stored.
167 auto rawOutput = OutputPortOp::create(
168 rewriter, op.getLoc(), op.getInnerSym(), innerType, op.getNameAttr());
170 rewriter, op.getLoc(), rawOutput,
171 PortReadOp::create(rewriter, op.getLoc(), portWrapper.getValue()));
// The wrapping write and the original port op are erased.
175 rewriter.eraseOp(portWrapper);
176 rewriter.eraseOp(op);
// Lowers a kanagawa.get_port whose result type is a PortRefType wrapping
// another PortRefType over a value type T. The op is replaced by
// kanagawa.get_port ops that access the instance's port of type T directly.
// Depending on the outer direction, this consumes the wrapping write or the
// unwrapping read of the original result.
183 using OpConversionPattern::OpConversionPattern;
187 matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
// Decompose the nested portref type into the value type T plus the outer and
// inner directions.
189 PortRefType outerPortRefType = cast<PortRefType>(op.getType());
190 PortRefType innerPortRefType =
191 cast<PortRefType>(outerPortRefType.getPortType());
192 Type
innerType = innerPortRefType.getPortType();
194 Direction outerDirection = outerPortRefType.getDirection();
195 Direction innerDirection = innerPortRefType.getDirection();
// Symbol of the instance port being accessed; reused by the new get_port ops.
197 StringAttr portName = op.getPortSymbolAttr().getAttr();
// Replacement ops are inserted at the original op's position; the guard
// restores the previous insertion point when the scope exits.
199 OpBuilder::InsertionGuard g(rewriter);
200 rewriter.setInsertionPoint(op);
202 if (outerDirection == Direction::Input) {
// Outer input: a portref is written into the get_port result. Locate the
// kanagawa.port.write whose target port is the result (the "wrapper").
205 PortWriteOp getPortWrapper;
206 for (
auto *user : op.getResult().getUsers()) {
207 auto writeOp = dyn_cast<PortWriteOp>(user);
208 if (!writeOp || writeOp.getPort() != op.getResult())
211 getPortWrapper = writeOp;
// Fail if no wrapping write exists.
216 return rewriter.notifyMatchFailure(
217 op,
"expected an kanagawa.port.write to wrap the get_port result");
// Remember the wrapper so it is erased together with the original op.
218 wrapper = getPortWrapper;
219 LLVM_DEBUG(llvm::dbgs() <<
"Found wrapper: " << *wrapper);
// Inner input: access the instance port with Direction::Output, read it,
// and write the read value into the portref the wrapper stored.
// Inner output: access the instance port with Direction::Input and write
// into it the value read from the portref the wrapper stored.
220 if (innerDirection == Direction::Input) {
223 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
224 innerType, Direction::Output);
226 PortReadOp::create(rewriter, op.getLoc(), newGetPort);
227 PortWriteOp::create(rewriter, op.getLoc(), getPortWrapper.getValue(),
232 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
233 innerType, Direction::Input);
234 auto writeValue = PortReadOp::create(rewriter, op.getLoc(),
235 getPortWrapper.getValue());
236 PortWriteOp::create(rewriter, op.getLoc(), newGetPort, writeValue);
// Outer output: a portref is read from the get_port result. Locate the
// kanagawa.port.read whose port is the result (the "unwrapper").
239 PortReadOp getPortUnwrapper;
240 for (
auto *user : op.getResult().getUsers()) {
241 auto readOp = dyn_cast<PortReadOp>(user);
242 if (!readOp || readOp.getPort() != op.getResult())
245 getPortUnwrapper = readOp;
// Fail if no unwrapping read exists.
249 if (!getPortUnwrapper)
250 return rewriter.notifyMatchFailure(
251 op,
"expected an kanagawa.port.read to unwrap the get_port result");
252 wrapper = getPortUnwrapper;
254 LLVM_DEBUG(llvm::dbgs() <<
"Found unwrapper: " << *wrapper);
// Inner input: the unwrapped portref is driven by a kanagawa.port.write.
// That driving value is routed into the instance port accessed with
// Direction::Input.
// Inner output: the instance port is accessed with Direction::Output, and
// all uses of the unwrapped portref are redirected to it.
255 if (innerDirection == Direction::Input) {
// Classify the kanagawa.port.write users of the unwrapped portref:
//  - a write whose target port is the unwrapped value is a direct driver;
//  - any other write (the unwrapped value is presumably its written value)
//    is a forwarding driver.
// At most one of each kind is accepted.
273 PortWriteOp portDriver;
274 PortWriteOp portForwardingDriver;
275 for (
auto *user : getPortUnwrapper.getResult().getUsers()) {
276 auto writeOp = dyn_cast<PortWriteOp>(user);
280 bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
282 if (portForwardingDriver)
283 return rewriter.notifyMatchFailure(
285 "expected a single kanagawa.port.write to use the unwrapped "
286 "get_port result, but found multiple");
287 portForwardingDriver = writeOp;
288 LLVM_DEBUG(llvm::dbgs()
289 <<
"Found forwarding driver: " << *portForwardingDriver);
292 return rewriter.notifyMatchFailure(
294 "expected a single kanagawa.port.write to use the unwrapped "
295 "get_port result, but found multiple");
296 portDriver = writeOp;
297 LLVM_DEBUG(llvm::dbgs() <<
"Found driver: " << *portDriver);
// At least one driver is required. The driving value comes from a new
// forwarding input port (forwarding case) or from the direct driver's
// written value. It is then written into the instance port, which is
// accessed with Direction::Input.
301 if (!portDriver && !portForwardingDriver)
302 return rewriter.notifyMatchFailure(
304 "expected an kanagawa.port.write to drive the unwrapped get_port "
307 Value portDriverValue;
308 if (portForwardingDriver) {
// Forwarding case: create an input port named "<portName>_fw", with the
// same string used as its inner symbol. Redirect all uses of the unwrapped
// portref to that port, and drive the instance port with a value read
// from it.
314 auto fwPortName = rewriter.getStringAttr(portName.strref() +
"_fw");
315 auto forwardedInputPort = InputPortOp::create(
316 rewriter, op.getLoc(), hw::InnerSymAttr::get(fwPortName),
317 innerType, fwPortName);
324 getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
325 portDriverValue = PortReadOp::create(rewriter, op.getLoc(),
326 forwardedInputPort.getPort());
// Direct case: take the value the driver writes, then erase that write.
330 portDriverValue = portDriver.getValue();
331 rewriter.eraseOp(portDriver);
337 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
338 innerType, Direction::Input);
339 PortWriteOp::create(rewriter, op.getLoc(), rawPort, portDriverValue);
356 GetPortOp::create(rewriter, op.getLoc(), op.getInstance(), portName,
357 innerType, Direction::Output);
364 getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
// Erase the consumed wrapper or unwrapper, then the original get_port.
369 rewriter.eraseOp(wrapper);
370 rewriter.eraseOp(op);
// Pass that lowers nested portrefs in the Kanagawa dialect. It derives from the
// pass base class generated via GEN_PASS_DEF_KANAGAWAPORTREFLOWERING.
376struct PortrefLoweringPass
377 :
public circt::kanagawa::impl::KanagawaPortrefLoweringBase<
378 PortrefLoweringPass> {
// Runs the portref conversion on the pass's root operation.
379 void runOnOperation()
override;
// Configures a ConversionTarget and applies the three portref lowering
// patterns as a partial conversion. Kanagawa ops are legal, except
// InputPortOp, OutputPortOp and GetPortOp whose port PortRefType still wraps
// another PortRefType.
384void PortrefLoweringPass::runOnOperation() {
385 auto *ctx = &getContext();
386 ConversionTarget target(*ctx);
// NOTE(review): the addDynamicallyLegalOp call below registers InputPortOp and
// OutputPortOp again, which appears to supersede this illegal marking. This
// line is likely redundant — confirm.
387 target.addIllegalOp<InputPortOp, OutputPortOp>();
388 target.addLegalDialect<KanagawaDialect>();
// Input/output port ops become legal once their port's PortRefType no longer
// wraps another PortRefType.
391 target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](
auto op) {
392 PortRefType portType =
393 cast<PortRefType>(cast<PortOpInterface>(op).
getPort().getType());
394 return !isa<PortRefType>(portType.getPortType());
// GetPortOp follows the same rule: legal once its port type is no longer a
// portref-of-portref.
400 target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
401 PortRefType portType = cast<PortRefType>(op.getPort().getType());
402 return !isa<PortRefType>(portType.getPortType());
// Register the input-port, output-port and get_port lowering patterns.
406 patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
407 GetPortConversionPattern>(ctx);
// Apply the patterns as a partial conversion over the pass's root operation.
410 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Creates a new instance of the Kanagawa portref lowering pass.
415 return std::make_unique<PortrefLoweringPass>();
static PortInfo getPort(ModuleTy &mod, size_t idx)
Direction
The direction of a Component or Cell port.
mlir::Type innerType(mlir::Type type)
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.