11 #include "mlir/Pass/Pass.h"
20 #include "mlir/IR/OperationSupport.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include "llvm/ADT/TypeSwitch.h"
26 #define GEN_PASS_DEF_KANAGAWACONTAINERSTOHW
27 #include "circt/Dialect/Kanagawa/KanagawaPasses.h.inc"
31 using namespace circt;
32 using namespace kanagawa;
// Port bookkeeping for a single kanagawa container.
//
// Built from the container's kanagawa.port.input / kanagawa.port.output ops,
// in body-block order. It stores:
//  - the hw::ModulePortInfo used to create the hw.module and its instances;
//  - lookup tables from each port's unique name back to the defining port op;
//  - a table from each port op's inner symbol name to that unique port name.
38 struct ContainerPortInfo {
// Input ports followed by output ports, in container body order.
39 std::unique_ptr<hw::ModulePortInfo> hwPorts;
// Unique port name -> defining kanagawa input port op.
42 llvm::DenseMap<StringAttr, InputPortOp> opInputs;
// Unique port name -> defining kanagawa output port op.
45 llvm::DenseMap<StringAttr, OutputPortOp> opOutputs;
// Port op inner symbol name -> unique port name. This lets instance-side
// get_port ops (which reference ports by symbol) find the hw port name.
48 llvm::DenseMap<StringAttr, StringAttr> portSymbolsToPortName;
50 ContainerPortInfo() =
default;
// Scans `container`'s body block for input/output port ops and fills in the
// maps above plus `hwPorts`.
// NOTE(review): single-argument constructor is not `explicit`; consider
// marking it so a ContainerOp does not implicitly convert.
51 ContainerPortInfo(ContainerOp container) {
52 SmallVector<hw::PortInfo, 4> inputs, outputs;
53 auto *ctx = container.getContext();
// Gathers a port op's attributes. The inner-symbol, type and name
// attributes are placed in an elision set because they are carried
// separately in hw::PortInfo. `ctx` is captured presumably to build the
// resulting attribute; TODO confirm against the full source.
57 auto copyPortAttrs = [ctx](
auto port) {
58 llvm::DenseSet<StringAttr> elidedAttrs;
59 elidedAttrs.insert(port.getInnerSymAttrName());
60 elidedAttrs.insert(port.getTypeAttrName());
61 elidedAttrs.insert(port.getNameAttrName());
62 llvm::SmallVector<NamedAttribute> attrs;
63 for (NamedAttribute namedAttr : port->getAttrs()) {
64 if (elidedAttrs.contains(namedAttr.getName()))
66 attrs.push_back(namedAttr);
// Input ports: record op, symbol mapping, type and extra attributes.
74 for (
auto input : container.getBodyBlock()->getOps<InputPortOp>()) {
77 opInputs[uniquePortName] = input;
78 hw::PortInfo portInfo;
79 portInfo.name = uniquePortName;
80 portSymbolsToPortName[input.getInnerSym().getSymName()] = uniquePortName;
81 portInfo.type = cast<PortOpInterface>(input.getOperation()).getPortType();
83 portInfo.attrs = copyPortAttrs(input);
84 inputs.push_back(portInfo);
// Output ports: same bookkeeping as for inputs.
87 for (
auto output : container.getBodyBlock()->getOps<OutputPortOp>()) {
90 opOutputs[uniquePortName] = output;
92 hw::PortInfo portInfo;
93 portInfo.name = uniquePortName;
94 portSymbolsToPortName[output.getInnerSym().getSymName()] = uniquePortName;
96 cast<PortOpInterface>(output.getOperation()).getPortType();
98 portInfo.attrs = copyPortAttrs(output);
99 outputs.push_back(portInfo);
// Freeze the collected port lists into the module port description.
101 hwPorts = std::make_unique<hw::ModulePortInfo>(inputs, outputs);
// Container inner reference -> port bookkeeping for that container.
105 using ContainerPortInfoMap =
106 llvm::DenseMap<hw::InnerRefAttr, ContainerPortInfo>;
// Container inner reference -> symbol name of the hw.module it lowered to.
107 using ContainerHWModSymbolMap = llvm::DenseMap<hw::InnerRefAttr, StringAttr>;
// Returns a new StringAttr "<lhs>_<rhs>", created in lhs's MLIRContext.
109 static StringAttr concatNames(mlir::StringAttr lhs, mlir::StringAttr rhs) {
110 return StringAttr::get(lhs.getContext(), lhs.strref() +
"_" + rhs.strref());
// Lowers a kanagawa container to an hw.module placed just before its
// enclosing kanagawa design. The pattern:
//  - replaces port reads of input ports with the module's block arguments;
//  - feeds the single port write of each output port into hw.output;
//  - moves the container body into the module.
// The new module's symbol is recorded in `modSymMap` so that the instance
// lowering can reference it.
114 ContainerOpConversionPattern(MLIRContext *ctx,
Namespace &modNamespace,
115 ContainerPortInfoMap &portOrder,
116 ContainerHWModSymbolMap &modSymMap)
118 portOrder(portOrder), modSymMap(modSymMap) {}
121 matchAndRewrite(ContainerOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter)
const override {
// hw.modules are emitted at the same level as, and before, the design op.
123 auto design = op->getParentOfType<DesignOp>();
124 rewriter.setInsertionPoint(design);
// Module name:
//  - top-level containers keep their name hint;
//  - other containers are prefixed with the design name ("<design>_<hint>").
// The result is then uniqued through modNamespace.newName.
128 StringAttr hwmodName;
129 if (op.getIsTopLevel())
130 hwmodName = op.getNameHintAttr();
133 concatNames(op.getInnerRef().getModule(), op.getNameHintAttr());
136 modNamespace.newName(hwmodName.getValue()));
// Port info was precomputed in runOnOperation for every container.
138 const ContainerPortInfo &cpi = portOrder.at(op.getInnerRef());
140 rewriter.create<
hw::HWModuleOp>(op.getLoc(), hwmodName, *cpi.hwPorts);
141 modSymMap[op.getInnerRef()] = hwMod.getSymNameAttr();
// Keep a handle on the default terminator; it is replaced once the
// container body has been merged in.
143 hw::OutputOp outputOp =
144 cast<hw::OutputOp>(hwMod.getBodyBlock()->getTerminator());
// Input ports: every user must be a kanagawa.port.read. Each read is
// replaced by the block argument at the same position as the port.
// NOTE(review): this failure path runs after hw.module was already created;
// it relies on the conversion driver rolling back those changes.
147 for (
auto [idx, input] : llvm::enumerate(cpi.hwPorts->getInputs())) {
148 Value barg = hwMod.getBodyBlock()->getArgument(idx);
149 InputPortOp inputPort = cpi.opInputs.at(input.name);
151 for (
auto *user : inputPort.getOperation()->getUsers()) {
152 auto reader = dyn_cast<PortReadOp>(user);
154 return rewriter.notifyMatchFailure(
155 user,
"expected only kanagawa.port.read ops of the input port");
157 rewriter.replaceOp(reader, barg);
160 rewriter.eraseOp(inputPort);
// Output ports: collect the value written by the (single) port write.
// Both the port op and the write are then erased.
164 llvm::SmallVector<Value> outputValues;
165 for (
auto [idx, output] : llvm::enumerate(cpi.hwPorts->getOutputs())) {
166 auto outputPort = cpi.opOutputs.at(output.name);
168 auto users = outputPort->getUsers();
169 size_t nUsers = std::distance(users.begin(), users.end());
// NOTE(review): `cast` below asserts, rather than emitting a diagnostic, if
// the single user is not a kanagawa.port.write; consider dyn_cast +
// emitOpError.
171 return outputPort->emitOpError()
172 <<
"expected exactly one kanagawa.port.write op of the output "
174 << output.name.str() <<
" found: " << nUsers;
175 auto writer = cast<PortWriteOp>(*users.begin());
176 outputValues.push_back(writer.getValue());
177 rewriter.eraseOp(outputPort);
178 rewriter.eraseOp(writer);
// Splice the container body into the module. The original hw.output then
// sits mid-block, so it is erased and a new terminator carrying the
// collected output values is built at the end.
181 rewriter.mergeBlocks(&op.getBodyRegion().front(), hwMod.getBodyBlock());
184 rewriter.eraseOp(outputOp);
185 rewriter.setInsertionPointToEnd(hwMod.getBodyBlock());
186 outputOp = rewriter.create<hw::OutputOp>(op.getLoc(), outputValues);
187 rewriter.eraseOp(op);
// Pass-owned state shared across pattern applications.
192 ContainerPortInfoMap &portOrder;
193 ContainerHWModSymbolMap &modSymMap;
// Removes kanagawa.this ops; ThisOp is marked illegal for this conversion.
197 ThisOpConversionPattern(MLIRContext *ctx)
201 matchAndRewrite(ThisOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter)
const override {
204 rewriter.eraseOp(op);
// Lowers a kanagawa container instance to an hw.instance of the hw.module
// its target container was lowered to. The instance's get_port ops are
// resolved as follows:
//  - kanagawa.port.write ops on input ports become instance operands;
//  - kanagawa.port.read ops on output ports are replaced by instance results.
209 struct ContainerInstanceOpConversionPattern
212 ContainerInstanceOpConversionPattern(MLIRContext *ctx,
213 ContainerPortInfoMap &portOrder,
214 ContainerHWModSymbolMap &modSymMap)
216 modSymMap(modSymMap) {}
219 matchAndRewrite(ContainerInstanceOp op, OpAdaptor adaptor,
220 ConversionPatternRewriter &rewriter)
const override {
221 rewriter.setInsertionPoint(op);
222 llvm::SmallVector<Value> operands;
// Port info of the instantiated container, keyed by the scope reference
// carried in the instance's result type.
224 const ContainerPortInfo &cpi =
225 portOrder.at(op.getResult().getType().getScopeRef());
// Port name -> the single read/write op of that port on this instance.
228 llvm::DenseMap<StringAttr, PortReadOp> outputReadsToReplace;
229 llvm::DenseMap<StringAttr, PortWriteOp> inputWritesToUse;
230 llvm::SmallVector<Operation *> getPortsToErase;
// The instance may only be used by get_port ops, and each get_port only by
// port.read / port.write ops. At most one read or write is allowed per port.
231 for (
auto *user : op->getUsers()) {
232 auto getPort = dyn_cast<GetPortOp>(user);
234 return rewriter.notifyMatchFailure(
235 user,
"expected only kanagawa.get_port op usage of the instance");
237 for (
auto *user :
getPort->getUsers()) {
239 llvm::TypeSwitch<Operation *, LogicalResult>(user)
240 .Case<PortReadOp>([&](
auto read) {
241 auto [it, inserted] = outputReadsToReplace.insert(
242 {cpi.portSymbolsToPortName.at(
243 getPort.getPortSymbolAttr().getAttr()),
246 return rewriter.notifyMatchFailure(
247 read,
"expected only one kanagawa.port.read op of the "
251 .Case<PortWriteOp>([&](
auto write) {
252 auto [it, inserted] = inputWritesToUse.insert(
253 {cpi.portSymbolsToPortName.at(
254 getPort.getPortSymbolAttr().getAttr()),
257 return rewriter.notifyMatchFailure(
259 "expected only one kanagawa.port.write op of the input "
// NOTE(review): the lambda parameter `op` shadows the instance `op`.
263 .Default([&](
auto op) {
264 return rewriter.notifyMatchFailure(
265 op,
"expected only kanagawa.port.read or "
266 "kanagawa.port.write ops "
273 getPortsToErase.push_back(
getPort);
// Every input port must be driven. On mismatch, build a diagnostic listing
// the unassigned input ports.
// NOTE(review): the runtime message below misspells "Missing" as "Mising".
277 size_t nInputPorts = std::distance(cpi.hwPorts->getInputs().begin(),
278 cpi.hwPorts->getInputs().end());
279 if (nInputPorts != inputWritesToUse.size()) {
281 llvm::raw_string_ostream ers(errMsg);
282 ers <<
"Error when lowering instance ";
283 op.print(ers, mlir::OpPrintingFlags().printGenericOpForm());
285 ers <<
"\nexpected exactly one kanagawa.port.write op of each input "
287 "Mising port assignments were:\n";
288 for (
auto input : cpi.hwPorts->getInputs()) {
289 if (inputWritesToUse.find(input.name) == inputWritesToUse.end())
290 ers <<
"\t" << input.name <<
"\n";
292 return rewriter.notifyMatchFailure(op, errMsg);
// Operands follow the module's input port order; the writes are consumed.
294 for (
auto input : cpi.hwPorts->getInputs()) {
295 auto writeOp = inputWritesToUse.at(input.name);
296 operands.push_back(writeOp.getValue());
297 rewriter.eraseOp(writeOp);
// Result types and port names mirror the module's port description.
301 llvm::SmallVector<Type> retTypes;
302 for (
auto output : cpi.hwPorts->getOutputs())
303 retTypes.push_back(output.type);
307 llvm::SmallVector<Attribute> argNames, resNames;
308 llvm::transform(cpi.hwPorts->getInputs(), std::back_inserter(argNames),
309 [](
auto port) { return port.name; });
310 llvm::transform(cpi.hwPorts->getOutputs(), std::back_inserter(resNames),
311 [](
auto port) { return port.name; });
// NOTE(review): operator[] default-inserts a null StringAttr if the target
// container has not been lowered yet, and converting it to StringRef is then
// invalid. This presumably relies on ContainerOpConversionPattern having run
// first; a find() + notifyMatchFailure would be safer.
// The trailing empty ArrayAttr and nullptr are presumably the instance
// parameters and inner symbol; confirm against hw::InstanceOp's builder.
314 StringRef moduleName = modSymMap[op.getTargetNameAttr()];
315 auto hwInst = rewriter.create<hw::InstanceOp>(
316 op.getLoc(), retTypes, op.getInnerSym().getSymName(), moduleName,
317 operands, rewriter.getArrayAttr(argNames),
318 rewriter.getArrayAttr(resNames),
319 rewriter.getArrayAttr({}),
nullptr);
// Wire each read output port to the matching instance result.
// NOTE(review): Value::replaceAllUsesWith bypasses the
// ConversionPatternRewriter, so the driver cannot track or roll it back.
// rewriter.replaceOp(outputReadIt->second, value) would do both steps.
322 for (
auto [output, value] :
323 llvm::zip(cpi.hwPorts->getOutputs(), hwInst.getResults())) {
324 auto outputReadIt = outputReadsToReplace.find(output.name);
325 if (outputReadIt == outputReadsToReplace.end())
332 outputReadIt->second.getResult().replaceAllUsesWith(value);
333 rewriter.eraseOp(outputReadIt->second);
// The get_port ops are now unused and are erased along with the instance.
337 for (
auto *
getPort : getPortsToErase)
341 rewriter.eraseOp(op);
// Pass-owned state shared with ContainerOpConversionPattern.
345 ContainerPortInfoMap &portOrder;
346 ContainerHWModSymbolMap &modSymMap;
// Pass that lowers kanagawa containers (and their instances) to HW modules.
349 struct ContainersToHWPass
350 :
public circt::kanagawa::impl::KanagawaContainersToHWBase<
351 ContainersToHWPass> {
352 void runOnOperation()
override;
// Pass driver. It runs in these steps:
//  1. Precompute port info for every container of every design.
//  2. Run a partial conversion that makes ContainerOp, ContainerInstanceOp
//     and ThisOp illegal.
//  3. Erase design ops whose body block ended up empty.
356 void ContainersToHWPass::runOnOperation() {
357 auto *ctx = &getContext();
// Computed up front, so instances can be lowered with their target's port
// order.
360 ContainerPortInfoMap portOrder;
361 for (
auto design : getOperation().getOps<DesignOp>())
362 for (
auto container : design.getOps<ContainerOp>())
363 portOrder.try_emplace(container.getInnerRef(),
364 ContainerPortInfo(container));
366 ConversionTarget target(*ctx);
367 ContainerHWModSymbolMap modSymMap;
// Seed the module namespace with the existing symbols so that new hw.module
// names do not collide.
371 modNamespace.
add(modSymCache);
// Only these three ops must be converted; any op without registered
// legality is treated as legal.
372 target.addIllegalOp<ContainerOp, ContainerInstanceOp, ThisOp>();
373 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// Free up design names in the namespace, presumably so the lowered modules
// may reuse them; confirm against the full source.
380 for (
auto designOp : getOperation().getOps<DesignOp>())
381 modNamespace.
erase(designOp.getSymName());
387 target.addLegalDialect<KanagawaDialect>();
// NOTE(review): the instance pattern reads modSymMap entries written by the
// container pattern, so correctness depends on containers being lowered
// before their instances.
390 patterns.add<ContainerOpConversionPattern>(ctx, modNamespace, portOrder,
392 patterns.add<ContainerInstanceOpConversionPattern>(ctx, portOrder, modSymMap);
393 patterns.add<ThisOpConversionPattern>(ctx);
396 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Designs whose body block is now empty (all containers moved out) are
// candidates for removal.
401 llvm::make_early_inc_range(getOperation().getOps<DesignOp>()))
402 if (design.getBody().front().empty())
// Factory: returns a new instance of the containers-to-HW pass.
407 return std::make_unique<ContainersToHWPass>();
static PortInfo getPort(ModuleTy &mod, size_t idx)
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
bool erase(llvm::StringRef symbol)
Removes a symbol from the namespace.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createContainersToHWPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.