CIRCT 22.0.0git
Loading...
Searching...
No Matches
LinkCircuits.cpp
Go to the documentation of this file.
1//===- LinkCircuits.cpp - Merge FIRRTL circuits --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass links multiple circuits into a single one.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Support/LLVM.h"
26#include "llvm/ADT/MapVector.h"
27#include "llvm/ADT/STLExtras.h"
28#include <iterator>
29
30namespace circt {
31namespace firrtl {
32#define GEN_PASS_DEF_LINKCIRCUITS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
34} // namespace firrtl
35} // namespace circt
36
37using namespace mlir;
38using namespace circt;
39using namespace firrtl;
40
41namespace {
42struct LinkCircuitsPass
43 : public circt::firrtl::impl::LinkCircuitsBase<LinkCircuitsPass> {
44 using Base::Base;
45
46 void runOnOperation() override;
47 LogicalResult mergeCircuits();
48 LinkCircuitsPass(StringRef baseCircuitNameOption, bool noMangleOption) {
49 baseCircuitName = std::string(baseCircuitNameOption);
50 noMangle = noMangleOption;
51 }
52};
53} // namespace
54
55template <typename CallableT>
56static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno,
57 CallableT transformTokensFn) {
58 return DictionaryAttr::getWithSorted(
59 anno.getContext(),
60 to_vector(map_range(
61 anno.getValue(), [&](NamedAttribute namedAttr) -> NamedAttribute {
62 if (namedAttr.getName() == "target")
63 if (auto target = dyn_cast<StringAttr>(namedAttr.getValue()))
64 if (auto tokens = tokenizePath(target.getValue()))
65 return {
66 namedAttr.getName(),
67 StringAttr::get(target.getContext(),
68 transformTokensFn(tokens.value()).str())};
69 return namedAttr;
70 })));
71}
72
73static LogicalResult mangleCircuitSymbols(CircuitOp circuit) {
74 auto circuitName = circuit.getNameAttr();
75
76 llvm::MapVector<StringRef, Operation *> renameTable;
77 auto symbolTable = SymbolTable(circuit.getOperation());
78 auto manglePrivateSymbol = [&](SymbolOpInterface symbolOp) {
79 auto symbolName = symbolOp.getNameAttr();
80 auto newSymbolName =
81 StringAttr::get(symbolOp->getContext(),
82 circuitName.getValue() + "_" + symbolName.getValue());
83 renameTable.insert(std::pair(symbolName.getValue(), symbolOp));
84 return symbolTable.rename(symbolOp, newSymbolName);
85 };
86
87 for (auto &op : circuit.getOps()) {
88 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
89 if (!symbolOp)
90 continue;
91
92 if (symbolOp.isPrivate())
93 if (failed(manglePrivateSymbol(symbolOp)))
94 return failure();
95 }
96
97 circuit.walk([&](Operation *op) {
98 auto updateType = [&](Type type) -> Type {
99 if (auto cls = dyn_cast<ClassType>(type))
100 if (auto *newOp = renameTable.lookup(cls.getName()))
101 return ClassType::get(FlatSymbolRefAttr::get(newOp),
102 cls.getElements());
103 return type;
104 };
105 auto updateTypeAttr = [&](Attribute attr) -> Attribute {
106 if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
107 auto newType = updateType(typeAttr.getValue());
108 if (newType != typeAttr.getValue())
109 return TypeAttr::get(newType);
110 }
111 return attr;
112 };
113 auto updateResults = [&](auto &&results) {
114 for (auto result : results)
115 if (auto newType = updateType(result.getType());
116 newType != result.getType())
117 result.setType(newType);
118 };
119
120 TypeSwitch<Operation *>(op)
121 .Case<CircuitOp>([&](CircuitOp circuit) {
122 SmallVector<Attribute> newAnnotations;
123 llvm::transform(
124 circuit.getAnnotationsAttr(), std::back_inserter(newAnnotations),
125 [&](Attribute attr) {
126 return transformAnnotationTarget(
127 cast<DictionaryAttr>(attr), [&](TokenAnnoTarget &tokens) {
128 if (auto *newModule = renameTable.lookup(tokens.module))
129 tokens.module =
130 cast<SymbolOpInterface>(newModule).getName();
131 return tokens;
132 });
133 });
134 circuit.setAnnotationsAttr(
135 ArrayAttr::get(circuit.getContext(), newAnnotations));
136 })
137 .Case<ObjectOp>([&](ObjectOp obj) {
138 auto resultTypeName = obj.getResult().getType().getName();
139 if (auto *newOp = renameTable.lookup(resultTypeName))
140 obj.getResult().setType(dyn_cast<ClassOp>(newOp).getInstanceType());
141 })
142 .Case<FModuleOp>([&](FModuleOp module) {
143 SmallVector<Attribute> newPortTypes;
144 llvm::transform(module.getPortTypesAttr().getValue(),
145 std::back_inserter(newPortTypes), updateTypeAttr);
146 module.setPortTypesAttr(
147 ArrayAttr::get(module->getContext(), newPortTypes));
148 updateResults(module.getBodyBlock()->getArguments());
149 })
150 .Case<InstanceOp>(
151 [&](InstanceOp instance) { updateResults(instance->getResults()); })
152 .Case<WireOp>([&](WireOp wire) { updateResults(wire->getResults()); })
153 .Default([](Operation *op) {});
154 });
155 return success();
156}
157
158/// return if the incomingOp has been erased
159static FailureOr<bool> linkExtmodule(SymbolOpInterface collidingOp,
160 SymbolOpInterface incomingOp) {
161 if (!((isa<FExtModuleOp>(collidingOp) && isa<FModuleOp>(incomingOp)) ||
162 (isa<FExtModuleOp>(incomingOp) && isa<FModuleOp>(collidingOp))))
163 return failure();
164 auto definition = collidingOp;
165 auto declaration = incomingOp;
166 if (!isa<FModuleOp>(collidingOp)) {
167 definition = incomingOp;
168 declaration = collidingOp;
169 }
170 if (!definition.isPublic())
171 return definition->emitOpError("should be a public symbol");
172 if (!declaration.isPublic())
173 return declaration->emitOpError("should be a public symbol");
174
175 constexpr const StringRef attrsToCompare[] = {
176 "portDirections", "portSymbols", "portNames", "portTypes", "layers"};
177 auto allAttrsMatch = all_of(attrsToCompare, [&](StringRef attr) {
178 return definition->getAttr(attr) == declaration->getAttr(attr);
179 });
180
181 if (!allAttrsMatch)
182 return false;
183
184 declaration->erase();
185 return declaration == incomingOp;
186}
187
188LogicalResult LinkCircuitsPass::mergeCircuits() {
189 auto module = getOperation();
190
191 SmallVector<CircuitOp> circuits;
192 for (CircuitOp circuitOp : module.getOps<CircuitOp>())
193 circuits.push_back(circuitOp);
194
195 auto builder = OpBuilder(module);
196 builder.setInsertionPointToEnd(module.getBody());
197 auto mergedCircuit =
198 CircuitOp::create(builder, module.getLoc(),
199 StringAttr::get(&getContext(), baseCircuitName));
200 SmallVector<Attribute> mergedAnnotations;
201
202 for (auto circuit : circuits) {
203 if (!noMangle)
204 if (failed(mangleCircuitSymbols(circuit)))
205 return circuit->emitError("failed to mangle private symbol");
206
207 // TODO: other circuit attributes (such as enable_layers...)
208 llvm::transform(circuit.getAnnotations().getValue(),
209 std::back_inserter(mergedAnnotations), [&](Attribute attr) {
210 return transformAnnotationTarget(
211 cast<DictionaryAttr>(attr),
212 [&](TokenAnnoTarget &tokens) {
213 tokens.circuit = mergedCircuit.getName();
214 return tokens;
215 });
216 });
217
218 // reconstruct symbol table after each merge
219 auto mergedSymbolTable = SymbolTable(mergedCircuit.getOperation());
220
221 SmallVector<Operation *> opsToMove;
222 for (auto &op : circuit.getOps())
223 opsToMove.push_back(&op);
224 for (auto *op : opsToMove) {
225 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op))
226 if (auto collidingOp = cast_if_present<SymbolOpInterface>(
227 mergedSymbolTable.lookup(symbolOp.getNameAttr()))) {
228 auto opErased = linkExtmodule(collidingOp, symbolOp);
229 if (failed(opErased))
230 return mergedCircuit->emitError("has colliding symbol " +
231 symbolOp.getName() +
232 " which cannot be merged");
233 if (opErased.value())
234 continue;
235 }
236
237 op->moveBefore(mergedCircuit.getBodyBlock(),
238 mergedCircuit.getBodyBlock()->end());
239 }
240
241 circuit->erase();
242 }
243
244 mergedCircuit.setAnnotationsAttr(
245 ArrayAttr::get(mergedCircuit.getContext(), mergedAnnotations));
246
247 return mlir::detail::verifySymbolTable(mergedCircuit);
248}
249
250void LinkCircuitsPass::runOnOperation() {
251 if (failed(mergeCircuits()))
252 signalPassFailure();
253}
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno, CallableT transformTokensFn)
static FailureOr< bool > linkExtmodule(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp)
return if the incomingOp has been erased
static LogicalResult mangleCircuitSymbols(CircuitOp circuit)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.