11 #include "mlir/Pass/Pass.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "kanagawa-lower-portrefs"
27 #define GEN_PASS_DEF_KANAGAWAPORTREFLOWERING
28 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
33 using namespace circt;
34 using namespace kanagawa;
40 using OpConversionPattern::OpConversionPattern;
// Conversion pattern body for an InputPortOp whose type is a PortRefType.
// It replaces the port-of-a-port-reference with a "raw" port of the inner
// type (an OutputPortOp or an InputPortOp, both created below), rewires the
// users of the unwrapping kanagawa.port.read, and erases the original ops.
// NOTE(review): this listing omits several original lines (return type,
// branch conditions, closing braces, final return); comments below describe
// only the statements that are visible.
44 matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter)
const override {
// The op's type is a port reference; extract the referenced (inner) type and
// the direction of that reference.
46 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47 Type
innerType = innerPortRefType.getPortType();
48 Direction d = innerPortRefType.getDirection();
// Require exactly one user of the port reference; any other count is
// reported as a match failure asking for CSE to be run first.
55 auto portrefUsers = op.getResult().getUsers();
56 size_t nPortrefUsers =
57 std::distance(portrefUsers.begin(), portrefUsers.end());
58 if (nPortrefUsers != 1)
59 return rewriter.notifyMatchFailure(
61 "expected a single kanagawa.port.read as the only user of the input "
62 "port reference, but found multiple readers - please run CSE "
63 "prior to this pass");
// The single user is expected to be a PortReadOp that unwraps the reference.
// NOTE(review): the null check guarding the failure below is not visible in
// this excerpt — presumably `if (!portUnwrapper)`; confirm.
67 PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
69 return rewriter.notifyMatchFailure(
71 "expected a single kanagawa.port.read as the only user of the input "
// New ops are created immediately before `op`; the guard restores the
// rewriter's previous insertion point when it goes out of scope.
75 OpBuilder::InsertionGuard g(rewriter);
76 rewriter.setInsertionPoint(op);
// NOTE(review): the condition selecting between the OutputPortOp path and the
// InputPortOp path below is not visible — presumably keyed on `d`; confirm.
// Path 1: create a raw OutputPortOp reusing the original location, inner
// symbol and name, typed with the inner type.
79 auto rawOutput = rewriter.create<OutputPortOp>(
80 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
// Visit users of the unwrapped value; early-inc iteration is used because the
// loop body replaces the user currently being visited.
83 for (
auto *unwrappedPortUser :
84 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
85 PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
// Only PortWriteOps that write *to* the unwrapped port are of interest.
// NOTE(review): the statement controlled by this `if` is not visible —
// presumably `continue`; confirm.
86 if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
// Redirect the write to the new raw output port, keeping the written value.
90 rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
91 portWriter.getValue());
// Path 2: create a raw InputPortOp with the same location, symbol and name.
95 auto rawInput = rewriter.create<InputPortOp>(
96 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
// Replaces uses directly on the Value rather than through the rewriter.
102 portUnwrapper.getResult().replaceAllUsesWith(rawInput);
// Replace PortReadOps that read from the unwrapped port with reads of the new
// raw input.
// NOTE(review): this loop runs after the replaceAllUsesWith above on the same
// value, so it may see no remaining users — verify the intended order.
105 for (
auto *portUser :
106 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
107 PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
108 if (!portReader || portReader.getPort() != portUnwrapper.getResult())
111 rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
// Finally remove the unwrapping read and the original port op.
116 rewriter.eraseOp(portUnwrapper);
117 rewriter.eraseOp(op);
125 using OpConversionPattern::OpConversionPattern;
// Conversion pattern body for an OutputPortOp whose type is a PortRefType.
// It locates the single kanagawa.port.write that drives the port, creates a
// raw port of the inner type plus a read/write pair connecting it to the
// written value, and erases the original write and port op.
// NOTE(review): this listing omits several original lines (return type,
// branch conditions, closing braces, final return); comments below describe
// only the statements that are visible.
129 matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
130 ConversionPatternRewriter &rewriter)
const override {
// Extract the referenced (inner) type and the direction of the reference.
131 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
132 Type
innerType = innerPortRefType.getPortType();
133 Direction d = innerPortRefType.getDirection();
// Search the users for the PortWriteOp whose target port is this op's result.
137 PortWriteOp portWrapper;
138 for (
auto *user : op.getResult().getUsers()) {
139 auto writeOp = dyn_cast<PortWriteOp>(user);
140 if (writeOp && writeOp.getPort() == op.getResult()) {
// NOTE(review): the guard for this failure is not visible — presumably it
// fires when `portWrapper` was already set by an earlier iteration; confirm.
142 return rewriter.notifyMatchFailure(
143 op,
"expected a single kanagawa.port.write to wrap the output "
144 "portref, but found multiple");
145 portWrapper = writeOp;
// NOTE(review): the guard for this failure is not visible — presumably
// `if (!portWrapper)`; confirm.
151 return rewriter.notifyMatchFailure(
152 op,
"expected an kanagawa.port.write to wrap the output portref");
// New ops are created immediately before `op`; the guard restores the
// rewriter's previous insertion point when it goes out of scope.
154 OpBuilder::InsertionGuard g(rewriter);
155 rewriter.setInsertionPoint(op);
// NOTE(review): the condition selecting between the two paths below is not
// visible — presumably keyed on `d`; confirm.
// Path 1: create a raw InputPortOp, read it, and write that value to the port
// reference that the original write carried as its value (PortWriteOp operands
// appear in (port, value) order, matching other uses in this file).
160 auto rawInput = rewriter.create<InputPortOp>(
161 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
162 rewriter.create<PortWriteOp>(
163 op.getLoc(), portWrapper.getValue(),
164 rewriter.create<PortReadOp>(op.getLoc(), rawInput));
// Path 2: create a raw OutputPortOp and drive it with a read of the port
// reference that the original write carried as its value.
168 auto rawOutput = rewriter.create<OutputPortOp>(
169 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
170 rewriter.create<PortWriteOp>(
171 op.getLoc(), rawOutput,
172 rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
// Remove the original wrapping write and the original port op.
176 rewriter.eraseOp(portWrapper);
177 rewriter.eraseOp(op);
184 using OpConversionPattern::OpConversionPattern;
// Conversion pattern body for a GetPortOp whose result is a port reference to
// another port reference. It finds the op that wraps (PortWriteOp) or unwraps
// (PortReadOp) the get_port result, creates replacement GetPortOp /
// PortReadOp / PortWriteOp ops, and erases the wrapper and the original op.
// NOTE(review): this listing omits many original lines — the return type, the
// declarations of `wrapper`, `newGetPort`, `writeValue` and `rawPort`, most
// branch conditions, and closing braces. Comments below describe only the
// statements that are visible.
188 matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
189 ConversionPatternRewriter &rewriter)
const override {
// The result type is a port reference (outer) whose port type is itself a
// port reference (inner); `innerType` is the type the inner reference carries.
190 PortRefType outerPortRefType = cast<PortRefType>(op.getType());
191 PortRefType innerPortRefType =
192 cast<PortRefType>(outerPortRefType.getPortType());
193 Type
innerType = innerPortRefType.getPortType();
195 Direction outerDirection = outerPortRefType.getDirection();
196 Direction innerDirection = innerPortRefType.getDirection();
// Name of the port symbol referenced by this get_port.
198 StringAttr portName = op.getPortSymbolAttr().getAttr();
// New ops are created immediately before `op`; the guard restores the
// rewriter's previous insertion point when it goes out of scope.
200 OpBuilder::InsertionGuard g(rewriter);
201 rewriter.setInsertionPoint(op);
// --- Wrapper case: look for a PortWriteOp that writes to the get_port result.
// NOTE(review): the condition that selects this case is not visible —
// presumably based on `outerDirection`; confirm.
206 PortWriteOp getPortWrapper;
207 for (
auto *user : op.getResult().getUsers()) {
208 auto writeOp = dyn_cast<PortWriteOp>(user);
// NOTE(review): the statement controlled by this `if` is not visible —
// presumably `continue`; confirm.
209 if (!writeOp || writeOp.getPort() != op.getResult())
212 getPortWrapper = writeOp;
// NOTE(review): the guard for this failure is not visible — presumably
// `if (!getPortWrapper)`; confirm.
217 return rewriter.notifyMatchFailure(
218 op,
"expected an kanagawa.port.write to wrap the get_port result");
219 wrapper = getPortWrapper;
220 LLVM_DEBUG(llvm::dbgs() <<
"Found wrapper: " << *wrapper);
// Two alternative rewrites follow; each creates a new GetPortOp on the same
// instance and port name (remaining arguments are on lines not shown here).
// First alternative: read the new get_port and write that value to the port
// reference carried by the original wrapper write.
224 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
227 rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
228 rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
// Second alternative: read the port reference carried by the original wrapper
// write and write that value to the new get_port.
233 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
236 rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
237 rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
// --- Unwrapper case: look for a PortReadOp that reads from the get_port
// result.
240 PortReadOp getPortUnwrapper;
241 for (
auto *user : op.getResult().getUsers()) {
242 auto readOp = dyn_cast<PortReadOp>(user);
// NOTE(review): the statement controlled by this `if` is not visible —
// presumably `continue`; confirm.
243 if (!readOp || readOp.getPort() != op.getResult())
246 getPortUnwrapper = readOp;
250 if (!getPortUnwrapper)
251 return rewriter.notifyMatchFailure(
252 op,
"expected an kanagawa.port.read to unwrap the get_port result");
253 wrapper = getPortUnwrapper;
255 LLVM_DEBUG(llvm::dbgs() <<
"Found unwrapper: " << *wrapper);
// Classify PortWriteOp users of the unwrapped value: a write whose target port
// is NOT the unwrapped value is treated as "forwarding" (the unwrapped value
// is used in some other operand position); otherwise it is a direct driver.
// At most one of each kind is accepted.
274 PortWriteOp portDriver;
275 PortWriteOp portForwardingDriver;
276 for (
auto *user : getPortUnwrapper.getResult().getUsers()) {
277 auto writeOp = dyn_cast<PortWriteOp>(user);
// NOTE(review): the null check on `writeOp` before this dereference is not
// visible in this excerpt; confirm it exists in the full source.
281 bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
283 if (portForwardingDriver)
284 return rewriter.notifyMatchFailure(
286 "expected a single kanagawa.port.write to use the unwrapped "
287 "get_port result, but found multiple");
288 portForwardingDriver = writeOp;
289 LLVM_DEBUG(llvm::dbgs()
290 <<
"Found forwarding driver: " << *portForwardingDriver);
// NOTE(review): the guard for this failure is not visible — presumably
// `if (portDriver)`; confirm.
293 return rewriter.notifyMatchFailure(
295 "expected a single kanagawa.port.write to use the unwrapped "
296 "get_port result, but found multiple");
297 portDriver = writeOp;
298 LLVM_DEBUG(llvm::dbgs() <<
"Found driver: " << *portDriver);
// At least one driver (direct or forwarding) must have been found.
302 if (!portDriver && !portForwardingDriver)
303 return rewriter.notifyMatchFailure(
305 "expected an kanagawa.port.write to drive the unwrapped get_port "
// Determine the value that will drive the raw port.
308 Value portDriverValue;
309 if (portForwardingDriver) {
// Forwarding: create an additional InputPortOp named "<portName>_fw",
// redirect uses of the unwrapped value to it, and drive the raw port from a
// read of that new input.
315 auto fwPortName = rewriter.getStringAttr(portName.strref() +
"_fw");
316 auto forwardedInputPort = rewriter.create<InputPortOp>(
325 getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
326 portDriverValue = rewriter.create<PortReadOp>(
327 op.getLoc(), forwardedInputPort.getPort());
// Direct driver: reuse the value it wrote and erase the original write.
331 portDriverValue = portDriver.getValue();
332 rewriter.eraseOp(portDriver);
// Create the raw get_port and write the driver value to it.
338 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
340 rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
// Alternative path: create a raw get_port and replace uses of the unwrapped
// value with it.
// NOTE(review): the condition selecting this path is not visible; confirm.
357 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
365 getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
// Remove whichever wrapper/unwrapper was found, then the original get_port.
370 rewriter.eraseOp(wrapper);
371 rewriter.eraseOp(op);
// Pass type for the port-reference lowering. It derives from the
// KanagawaPortrefLoweringBase<> CRTP base in circt::kanagawa::impl and
// overrides runOnOperation(), which is defined out of line below.
// NOTE(review): the struct's closing brace is not visible in this excerpt.
377 struct PortrefLoweringPass
378 :
public circt::kanagawa::impl::KanagawaPortrefLoweringBase<
379 PortrefLoweringPass> {
380 void runOnOperation()
override;
// Configures a ConversionTarget and runs a partial dialect conversion over the
// pass's root operation using the three port conversion patterns.
// NOTE(review): the declaration of `patterns` and the handling of a failed
// conversion are on lines not visible in this excerpt.
385 void PortrefLoweringPass::runOnOperation() {
386 auto *ctx = &getContext();
387 ConversionTarget target(*ctx);
// Baseline legality: port ops are marked illegal and the Kanagawa dialect
// legal. The dynamic callbacks registered below are set on the same ops
// afterwards — presumably they take precedence over these entries; confirm
// against ConversionTarget semantics.
388 target.addIllegalOp<InputPortOp, OutputPortOp>();
389 target.addLegalDialect<KanagawaDialect>();
// An input/output port op is legal when the type carried by its port
// reference is not itself a PortRefType (i.e. no nested port reference).
392 target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](
auto op) {
393 PortRefType portType =
394 cast<PortRefType>(cast<PortOpInterface>(op).
getPort().getType());
395 return !isa<PortRefType>(portType.getPortType());
// Same rule for get_port: legal unless it yields a reference to a reference.
401 target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
402 PortRefType portType = cast<PortRefType>(op.getPort().getType());
403 return !isa<PortRefType>(portType.getPortType());
// Register the three lowering patterns and run the partial conversion.
407 patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
408 GetPortConversionPattern>(ctx);
411 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Returns a newly allocated PortrefLoweringPass.
// NOTE(review): the enclosing function's signature is not visible in this
// excerpt.
416 return std::make_unique<PortrefLoweringPass>();
static PortInfo getPort(ModuleTy &mod, size_t idx)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
mlir::Type innerType(mlir::Type type)
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.