18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/MapVector.h"
26#include "llvm/ADT/STLExtras.h"
31#define GEN_PASS_DEF_LINKCIRCUITS
32#include "circt/Dialect/FIRRTL/Passes.h.inc"
38using namespace firrtl;
41struct LinkCircuitsPass
42 :
public circt::firrtl::impl::LinkCircuitsBase<LinkCircuitsPass> {
45 void runOnOperation()
override;
46 LogicalResult mergeCircuits();
47 LinkCircuitsPass(StringRef baseCircuitNameOption,
bool noMangleOption) {
48 baseCircuitName = std::string(baseCircuitNameOption);
49 noMangle = noMangleOption;
54template <
typename CallableT>
56 CallableT transformTokensFn) {
57 return DictionaryAttr::getWithSorted(
60 anno.getValue(), [&](NamedAttribute namedAttr) -> NamedAttribute {
61 if (namedAttr.getName() ==
"target")
62 if (auto target = dyn_cast<StringAttr>(namedAttr.getValue()))
63 if (auto tokens = tokenizePath(target.getValue()))
66 StringAttr::get(target.getContext(),
67 transformTokensFn(tokens.value()).str())};
73 auto circuitName = circuit.getNameAttr();
75 llvm::MapVector<StringRef, Operation *> renameTable;
76 auto symbolTable = SymbolTable(circuit.getOperation());
77 auto manglePrivateSymbol = [&](SymbolOpInterface symbolOp) {
78 auto symbolName = symbolOp.getNameAttr();
80 StringAttr::get(symbolOp->getContext(),
81 circuitName.getValue() +
"_" + symbolName.getValue());
82 renameTable.insert(std::pair(symbolName.getValue(), symbolOp));
83 return symbolTable.rename(symbolOp, newSymbolName);
86 for (
auto &op : circuit.getOps()) {
87 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
93 if (symbolOp.isPrivate() && !isa<FExtModuleOp>(symbolOp))
94 if (failed(manglePrivateSymbol(symbolOp)))
98 circuit.walk([&](Operation *op) {
100 if (
auto cls = dyn_cast<ClassType>(type))
101 if (
auto *newOp = renameTable.lookup(cls.getName()))
102 return ClassType::get(FlatSymbolRefAttr::get(newOp),
106 auto updateTypeAttr = [&](Attribute attr) -> Attribute {
107 if (
auto typeAttr = dyn_cast<TypeAttr>(attr)) {
108 auto newType =
updateType(typeAttr.getValue());
109 if (newType != typeAttr.getValue())
110 return TypeAttr::get(newType);
114 auto updateResults = [&](
auto &&results) {
115 for (
auto result : results)
116 if (
auto newType =
updateType(result.getType());
117 newType != result.getType())
118 result.setType(newType);
121 TypeSwitch<Operation *>(op)
122 .Case<CircuitOp>([&](CircuitOp circuit) {
123 SmallVector<Attribute> newAnnotations;
125 circuit.getAnnotationsAttr(), std::back_inserter(newAnnotations),
126 [&](Attribute attr) {
127 return transformAnnotationTarget(
128 cast<DictionaryAttr>(attr), [&](TokenAnnoTarget &tokens) {
129 if (auto *newModule = renameTable.lookup(tokens.module))
131 cast<SymbolOpInterface>(newModule).getName();
135 circuit.setAnnotationsAttr(
136 ArrayAttr::get(circuit.getContext(), newAnnotations));
138 .Case<ObjectOp>([&](ObjectOp obj) {
139 auto resultTypeName = obj.getResult().getType().getName();
140 if (
auto *newOp = renameTable.lookup(resultTypeName))
141 obj.getResult().setType(dyn_cast<ClassOp>(newOp).getInstanceType());
143 .Case<FModuleOp>([&](FModuleOp module) {
144 SmallVector<Attribute> newPortTypes;
145 llvm::transform(module.getPortTypesAttr().getValue(),
146 std::back_inserter(newPortTypes), updateTypeAttr);
147 module.setPortTypesAttr(
148 ArrayAttr::get(module->getContext(), newPortTypes));
149 updateResults(module.getBodyBlock()->getArguments());
152 [&](InstanceOp instance) { updateResults(instance->getResults()); })
153 .Case<WireOp>([&](WireOp wire) { updateResults(wire->getResults()); })
154 .Default([](Operation *op) {});
160 SmallVectorImpl<FlatSymbolRefAttr> &stack,
161 llvm::DenseSet<Attribute> &layers) {
162 stack.push_back(FlatSymbolRefAttr::get(layer.getSymNameAttr()));
163 layers.insert(SymbolRefAttr::get(stack.front().getAttr(),
164 ArrayRef(stack).drop_front()));
165 for (
auto child : layer.getOps<LayerOp>())
171 llvm::DenseSet<Attribute> &layers) {
172 SmallVector<FlatSymbolRefAttr> stack;
173 for (
auto layer : circuit.getOps<LayerOp>())
179 const llvm::DenseSet<Attribute> &availableLayers) {
180 auto knownLayersAttr = extModule.getKnownLayersAttr();
181 if (!knownLayersAttr)
184 SmallVector<Attribute> missingLayers;
185 for (
auto attr : knownLayersAttr)
186 if (!availableLayers.contains(attr))
187 missingLayers.push_back(attr);
189 if (missingLayers.empty())
192 auto diag = extModule.emitOpError()
193 <<
"declares known layers that are not defined in the linked "
195 llvm::interleaveComma(missingLayers, diag,
196 [&](Attribute attr) { diag << attr; });
201 if (dst.getConvention() != src.getConvention())
202 return src.emitOpError(
"layer convention mismatch with existing layer");
204 SymbolTable dstSymbolTable(dst);
206 for (
auto &op : llvm::make_early_inc_range(src.getBody().front())) {
207 if (
auto srcChildLayer = dyn_cast<LayerOp>(op))
208 if (
auto dstChildLayer = cast_if_present<LayerOp>(
209 dstSymbolTable.lookup(srcChildLayer.getNameAttr()))) {
210 if (failed(
mergeLayer(dstChildLayer, srcChildLayer)))
214 op.moveBefore(&dst.getBody().front(), dst.getBody().front().end());
236static FailureOr<bool>
238 const llvm::DenseSet<Attribute> &mergedLayers,
239 const llvm::DenseSet<Attribute> &incomingLayers) {
240 if ((isa<FExtModuleOp>(collidingOp) && isa<FModuleOp>(incomingOp)) ||
241 (isa<FExtModuleOp>(incomingOp) && isa<FModuleOp>(collidingOp))) {
242 auto definition = collidingOp;
243 auto declaration = incomingOp;
244 if (!isa<FModuleOp>(collidingOp)) {
245 definition = incomingOp;
246 declaration = collidingOp;
249 if (!definition.isPublic())
250 return definition->emitOpError(
"should be a public symbol");
252 auto extModule = cast<FExtModuleOp>(declaration);
253 const auto &layersToCheck =
254 (definition == incomingOp) ? incomingLayers : mergedLayers;
258 constexpr const StringRef attrsToCompare[] = {
259 "portDirections",
"portSymbols",
"portNames",
"portTypes",
"layers"};
260 if (!all_of(attrsToCompare, [&](StringRef attr) {
261 return definition->getAttr(attr) == declaration->getAttr(attr);
265 declaration->erase();
266 return declaration == incomingOp;
269 if (isa<FExtModuleOp>(collidingOp) && isa<FExtModuleOp>(incomingOp)) {
270 constexpr const StringRef attrsToCompare[] = {
271 "portDirections",
"portSymbols",
"portNames",
272 "portTypes",
"knownLayers",
"layers",
274 if (!all_of(attrsToCompare, [&](StringRef attr) {
275 return collidingOp->getAttr(attr) == incomingOp->getAttr(attr);
279 auto collidingParams = collidingOp->getAttrOfType<ArrayAttr>(
"parameters");
280 auto incomingParams = incomingOp->getAttrOfType<ArrayAttr>(
"parameters");
281 if (collidingParams == incomingParams) {
282 if (collidingOp->getAttr(
"defname") != incomingOp->getAttr(
"defname"))
290 if (collidingParams.empty() || incomingParams.empty()) {
291 auto declaration = collidingParams.empty() ? collidingOp : incomingOp;
292 declaration->erase();
293 return declaration == incomingOp;
297 if (isa<LayerOp>(collidingOp) && isa<LayerOp>(incomingOp)) {
299 mergeLayer(cast<LayerOp>(collidingOp), cast<LayerOp>(incomingOp))))
308LogicalResult LinkCircuitsPass::mergeCircuits() {
309 auto module = getOperation();
311 SmallVector<CircuitOp> circuits;
312 for (CircuitOp circuitOp : module.getOps<CircuitOp>())
313 circuits.push_back(circuitOp);
315 auto builder = OpBuilder(module);
316 builder.setInsertionPointToEnd(module.getBody());
318 CircuitOp::create(builder, module.getLoc(),
319 StringAttr::get(&getContext(), baseCircuitName));
320 SmallVector<Attribute> mergedAnnotations;
322 llvm::DenseSet<Attribute> mergedLayers;
324 for (
auto circuit : circuits) {
327 return circuit->emitError(
"failed to mangle private symbol");
329 llvm::DenseSet<Attribute> incomingLayers;
333 llvm::transform(circuit.getAnnotations().getValue(),
334 std::back_inserter(mergedAnnotations), [&](Attribute attr) {
335 return transformAnnotationTarget(
336 cast<DictionaryAttr>(attr),
337 [&](TokenAnnoTarget &tokens) {
338 tokens.circuit = mergedCircuit.getName();
344 auto mergedSymbolTable = SymbolTable(mergedCircuit.getOperation());
346 SmallVector<Operation *> opsToMove;
347 for (
auto &op : circuit.getOps())
348 opsToMove.push_back(&op);
349 for (
auto *op : opsToMove) {
350 if (
auto symbolOp = dyn_cast<SymbolOpInterface>(op))
351 if (
auto collidingOp = cast_if_present<SymbolOpInterface>(
352 mergedSymbolTable.lookup(symbolOp.getNameAttr()))) {
354 mergedLayers, incomingLayers);
355 if (failed(opErased))
356 return mergedCircuit->emitError(
"has colliding symbol " +
358 " which cannot be merged");
359 if (opErased.value())
363 op->moveBefore(mergedCircuit.getBodyBlock(),
364 mergedCircuit.getBodyBlock()->end());
367 mergedLayers.insert(incomingLayers.begin(), incomingLayers.end());
371 mergedCircuit.setAnnotationsAttr(
372 ArrayAttr::get(mergedCircuit.getContext(), mergedAnnotations));
374 return mlir::detail::verifySymbolTable(mergedCircuit);
377void LinkCircuitsPass::runOnOperation() {
378 if (failed(mergeCircuits()))
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static LogicalResult verifyKnownLayers(FExtModuleOp extModule, const llvm::DenseSet< Attribute > &availableLayers)
static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno, CallableT transformTokensFn)
static LogicalResult mergeLayer(LayerOp dst, LayerOp src)
static void collectLayerSymbols(LayerOp layer, SmallVectorImpl< FlatSymbolRefAttr > &stack, llvm::DenseSet< Attribute > &layers)
static LogicalResult mangleCircuitSymbols(CircuitOp circuit)
static FailureOr< bool > handleCollidingOps(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp, const llvm::DenseSet< Attribute > &mergedLayers, const llvm::DenseSet< Attribute > &incomingLayers)
Resolves symbol collisions during circuit merging.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.