11 #include "mlir/Pass/Pass.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "ibis-lower-portrefs"
27 #define GEN_PASS_DEF_IBISPORTREFLOWERING
28 #include "circt/Dialect/Ibis/IbisPasses.h.inc"
33 using namespace circt;
40 using OpConversionPattern::OpConversionPattern;
44 matchAndRewrite(InputPortOp op, OpAdaptor adaptor,
45 ConversionPatternRewriter &rewriter)
const override {
46 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
47 Type
innerType = innerPortRefType.getPortType();
48 Direction d = innerPortRefType.getDirection();
55 auto portrefUsers = op.getResult().getUsers();
56 size_t nPortrefUsers =
57 std::distance(portrefUsers.begin(), portrefUsers.end());
58 if (nPortrefUsers != 1)
59 return rewriter.notifyMatchFailure(
60 op,
"expected a single ibis.port.read as the only user of the input "
61 "port reference, but found multiple readers - please run CSE "
62 "prior to this pass");
66 PortReadOp portUnwrapper = dyn_cast<PortReadOp>(*portrefUsers.begin());
68 return rewriter.notifyMatchFailure(
69 op,
"expected a single ibis.port.read as the only user of the input "
73 OpBuilder::InsertionGuard g(rewriter);
74 rewriter.setInsertionPoint(op);
77 auto rawOutput = rewriter.create<OutputPortOp>(
78 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
81 for (
auto *unwrappedPortUser :
82 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
83 PortWriteOp portWriter = dyn_cast<PortWriteOp>(unwrappedPortUser);
84 if (!portWriter || portWriter.getPort() != portUnwrapper.getResult())
88 rewriter.replaceOpWithNewOp<PortWriteOp>(portWriter, rawOutput,
89 portWriter.getValue());
93 auto rawInput = rewriter.create<InputPortOp>(
94 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
100 portUnwrapper.getResult().replaceAllUsesWith(rawInput);
103 for (
auto *portUser :
104 llvm::make_early_inc_range(portUnwrapper.getResult().getUsers())) {
105 PortReadOp portReader = dyn_cast<PortReadOp>(portUser);
106 if (!portReader || portReader.getPort() != portUnwrapper.getResult())
109 rewriter.replaceOpWithNewOp<PortReadOp>(portReader, rawInput);
114 rewriter.eraseOp(portUnwrapper);
115 rewriter.eraseOp(op);
123 using OpConversionPattern::OpConversionPattern;
127 matchAndRewrite(OutputPortOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter)
const override {
129 PortRefType innerPortRefType = cast<PortRefType>(op.getType());
130 Type
innerType = innerPortRefType.getPortType();
131 Direction d = innerPortRefType.getDirection();
135 PortWriteOp portWrapper;
136 for (
auto *user : op.getResult().getUsers()) {
137 auto writeOp = dyn_cast<PortWriteOp>(user);
138 if (writeOp && writeOp.getPort() == op.getResult()) {
140 return rewriter.notifyMatchFailure(
141 op,
"expected a single ibis.port.write to wrap the output "
142 "portref, but found multiple");
143 portWrapper = writeOp;
149 return rewriter.notifyMatchFailure(
150 op,
"expected an ibis.port.write to wrap the output portref");
152 OpBuilder::InsertionGuard g(rewriter);
153 rewriter.setInsertionPoint(op);
158 auto rawInput = rewriter.create<InputPortOp>(
159 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
160 rewriter.create<PortWriteOp>(
161 op.getLoc(), portWrapper.getValue(),
162 rewriter.create<PortReadOp>(op.getLoc(), rawInput));
166 auto rawOutput = rewriter.create<OutputPortOp>(
167 op.getLoc(), op.getInnerSym(),
innerType, op.getNameAttr());
168 rewriter.create<PortWriteOp>(
169 op.getLoc(), rawOutput,
170 rewriter.create<PortReadOp>(op.getLoc(), portWrapper.getValue()));
174 rewriter.eraseOp(portWrapper);
175 rewriter.eraseOp(op);
182 using OpConversionPattern::OpConversionPattern;
186 matchAndRewrite(GetPortOp op, OpAdaptor adaptor,
187 ConversionPatternRewriter &rewriter)
const override {
188 PortRefType outerPortRefType = cast<PortRefType>(op.getType());
189 PortRefType innerPortRefType =
190 cast<PortRefType>(outerPortRefType.getPortType());
191 Type
innerType = innerPortRefType.getPortType();
193 Direction outerDirection = outerPortRefType.getDirection();
194 Direction innerDirection = innerPortRefType.getDirection();
196 StringAttr portName = op.getPortSymbolAttr().getAttr();
198 OpBuilder::InsertionGuard g(rewriter);
199 rewriter.setInsertionPoint(op);
204 PortWriteOp getPortWrapper;
205 for (
auto *user : op.getResult().getUsers()) {
206 auto writeOp = dyn_cast<PortWriteOp>(user);
207 if (!writeOp || writeOp.getPort() != op.getResult())
210 getPortWrapper = writeOp;
215 return rewriter.notifyMatchFailure(
216 op,
"expected an ibis.port.write to wrap the get_port result");
217 wrapper = getPortWrapper;
218 LLVM_DEBUG(llvm::dbgs() <<
"Found wrapper: " << *wrapper);
222 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
225 rewriter.create<PortReadOp>(op.getLoc(), newGetPort);
226 rewriter.create<PortWriteOp>(op.getLoc(), getPortWrapper.getValue(),
231 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
234 rewriter.create<PortReadOp>(op.getLoc(), getPortWrapper.getValue());
235 rewriter.create<PortWriteOp>(op.getLoc(), newGetPort, writeValue);
238 PortReadOp getPortUnwrapper;
239 for (
auto *user : op.getResult().getUsers()) {
240 auto readOp = dyn_cast<PortReadOp>(user);
241 if (!readOp || readOp.getPort() != op.getResult())
244 getPortUnwrapper = readOp;
248 if (!getPortUnwrapper)
249 return rewriter.notifyMatchFailure(
250 op,
"expected an ibis.port.read to unwrap the get_port result");
251 wrapper = getPortUnwrapper;
253 LLVM_DEBUG(llvm::dbgs() <<
"Found unwrapper: " << *wrapper);
270 PortWriteOp portDriver;
271 PortWriteOp portForwardingDriver;
272 for (
auto *user : getPortUnwrapper.getResult().getUsers()) {
273 auto writeOp = dyn_cast<PortWriteOp>(user);
277 bool isForwarding = writeOp.getPort() != getPortUnwrapper.getResult();
279 if (portForwardingDriver)
280 return rewriter.notifyMatchFailure(
281 op,
"expected a single ibis.port.write to use the unwrapped "
282 "get_port result, but found multiple");
283 portForwardingDriver = writeOp;
284 LLVM_DEBUG(llvm::dbgs()
285 <<
"Found forwarding driver: " << *portForwardingDriver);
288 return rewriter.notifyMatchFailure(
289 op,
"expected a single ibis.port.write to use the unwrapped "
290 "get_port result, but found multiple");
291 portDriver = writeOp;
292 LLVM_DEBUG(llvm::dbgs() <<
"Found driver: " << *portDriver);
296 if (!portDriver && !portForwardingDriver)
297 return rewriter.notifyMatchFailure(
298 op,
"expected an ibis.port.write to drive the unwrapped get_port "
301 Value portDriverValue;
302 if (portForwardingDriver) {
308 auto fwPortName = rewriter.getStringAttr(portName.strref() +
"_fw");
309 auto forwardedInputPort = rewriter.create<InputPortOp>(
318 getPortUnwrapper.getResult().replaceAllUsesWith(forwardedInputPort);
319 portDriverValue = rewriter.create<PortReadOp>(
320 op.getLoc(), forwardedInputPort.getPort());
324 portDriverValue = portDriver.getValue();
325 rewriter.eraseOp(portDriver);
331 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
333 rewriter.create<PortWriteOp>(op.getLoc(), rawPort, portDriverValue);
348 rewriter.create<GetPortOp>(op.getLoc(), op.getInstance(), portName,
356 getPortUnwrapper.getResult().replaceAllUsesWith(rawPort);
361 rewriter.eraseOp(wrapper);
362 rewriter.eraseOp(op);
368 struct PortrefLoweringPass
369 :
public circt::ibis::impl::IbisPortrefLoweringBase<PortrefLoweringPass> {
370 void runOnOperation()
override;
375 void PortrefLoweringPass::runOnOperation() {
376 auto *ctx = &getContext();
377 ConversionTarget target(*ctx);
378 target.addIllegalOp<InputPortOp, OutputPortOp>();
379 target.addLegalDialect<IbisDialect>();
382 target.addDynamicallyLegalOp<InputPortOp, OutputPortOp>([&](
auto op) {
383 PortRefType portType =
384 cast<PortRefType>(cast<PortOpInterface>(op).
getPort().getType());
385 return !isa<PortRefType>(portType.getPortType());
391 target.addDynamicallyLegalOp<GetPortOp>([&](GetPortOp op) {
392 PortRefType portType = cast<PortRefType>(op.getPort().getType());
393 return !isa<PortRefType>(portType.getPortType());
397 patterns.add<InputPortConversionPattern, OutputPortConversionPattern,
398 GetPortConversionPattern>(ctx);
401 applyPartialConversion(getOperation(), target, std::move(
patterns))))
406 return std::make_unique<PortrefLoweringPass>();
static PortInfo getPort(ModuleTy &mod, size_t idx)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
mlir::Type innerType(mlir::Type type)
std::unique_ptr< mlir::Pass > createPortrefLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.