18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/MapVector.h"
26#include "llvm/ADT/STLExtras.h"
31#define GEN_PASS_DEF_LINKCIRCUITS
32#include "circt/Dialect/FIRRTL/Passes.h.inc"
38using namespace firrtl;
/// Links all FIRRTL circuits found in the top-level module into one newly
/// created circuit whose name comes from the `baseCircuitName` option.
41struct LinkCircuitsPass
42 :
public circt::firrtl::impl::LinkCircuitsBase<LinkCircuitsPass> {
45 void runOnOperation()
override;
/// Merges every circuit of the module into the new circuit. Returns failure
/// if symbol mangling, collision resolution, or the final symbol-table
/// verification fails.
46 LogicalResult mergeCircuits();
/// Creates the pass with explicit values for its two options.
/// NOTE(review): `noMangle` presumably disables the mangling of private
/// symbols before merging (its use is not visible here) — confirm.
47 LinkCircuitsPass(StringRef baseCircuitNameOption,
bool noMangleOption) {
48 baseCircuitName = std::string(baseCircuitNameOption);
49 noMangle = noMangleOption;
/// Returns a copy of the annotation dictionary `anno` in which a string
/// "target" entry is tokenized with tokenizePath, handed to
/// `transformTokensFn` (which receives the tokens by reference and may edit
/// them), and serialized back to a string via `.str()`.
/// NOTE(review): entries other than a tokenizable "target" string are
/// presumably returned unchanged — confirm against the full body.
54template <
typename CallableT>
56 CallableT transformTokensFn) {
// getWithSorted requires the mapped entries to stay in sorted name order.
// NOTE(review): this presumes the callback preserves every entry's name.
57 return DictionaryAttr::getWithSorted(
60 anno.getValue(), [&](NamedAttribute namedAttr) -> NamedAttribute {
// Only a string-valued "target" that tokenizes successfully is rewritten.
61 if (namedAttr.getName() ==
"target")
62 if (auto target = dyn_cast<StringAttr>(namedAttr.getValue()))
63 if (auto tokens = tokenizePath(target.getValue()))
66 StringAttr::get(target.getContext(),
67 transformTokensFn(tokens.value()).str())};
// Prefixes every private top-level symbol of `circuit` with
// "<circuit name>_", then walks the circuit to rewrite ClassType references
// (port types, block arguments, instance/wire/object results) and annotation
// targets that name a renamed symbol.
73 auto circuitName = circuit.getNameAttr();
// Original symbol name -> the (now renamed) operation that defined it.
75 llvm::MapVector<StringRef, Operation *> renameTable;
76 auto symbolTable = SymbolTable(circuit.getOperation());
// Renames one symbol to "<circuit>_<name>" and records the old name.
77 auto manglePrivateSymbol = [&](SymbolOpInterface symbolOp) {
78 auto symbolName = symbolOp.getNameAttr();
80 StringAttr::get(symbolOp->getContext(),
81 circuitName.getValue() +
"_" + symbolName.getValue());
82 renameTable.insert(std::pair(symbolName.getValue(), symbolOp));
83 return symbolTable.rename(symbolOp, newSymbolName);
// Only private top-level symbols are mangled; public symbols keep their names
// so they can be matched across circuits when linking.
86 for (
auto &op : circuit.getOps()) {
87 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
91 if (symbolOp.isPrivate())
92 if (failed(manglePrivateSymbol(symbolOp)))
// Rewrite references to renamed symbols everywhere in the circuit, including
// the circuit op itself (its annotations).
96 circuit.walk([&](Operation *op) {
// A ClassType naming a renamed class is rebuilt to reference the new name.
98 if (
auto cls = dyn_cast<ClassType>(type))
99 if (
auto *newOp = renameTable.lookup(cls.getName()))
100 return ClassType::get(FlatSymbolRefAttr::get(newOp),
// Applies updateType to the type held by a TypeAttr; a new attribute is built
// only when the type actually changed.
104 auto updateTypeAttr = [&](Attribute attr) -> Attribute {
105 if (
auto typeAttr = dyn_cast<TypeAttr>(attr)) {
106 auto newType =
updateType(typeAttr.getValue());
107 if (newType != typeAttr.getValue())
108 return TypeAttr::get(newType);
// Retypes each value in `results` whose type changes under updateType.
112 auto updateResults = [&](
auto &&results) {
113 for (
auto result : results)
114 if (
auto newType =
updateType(result.getType());
115 newType != result.getType())
116 result.setType(newType);
119 TypeSwitch<Operation *>(op)
// Annotation targets whose module token names a renamed symbol are
// presumably retargeted to the new name.
120 .Case<CircuitOp>([&](CircuitOp circuit) {
121 SmallVector<Attribute> newAnnotations;
123 circuit.getAnnotationsAttr(), std::back_inserter(newAnnotations),
124 [&](Attribute attr) {
125 return transformAnnotationTarget(
126 cast<DictionaryAttr>(attr), [&](TokenAnnoTarget &tokens) {
127 if (auto *newModule = renameTable.lookup(tokens.module))
129 cast<SymbolOpInterface>(newModule).getName();
133 circuit.setAnnotationsAttr(
134 ArrayAttr::get(circuit.getContext(), newAnnotations));
// An object of a renamed class takes that class's instance type.
// NOTE(review): the dyn_cast<ClassOp> result is dereferenced without a null
// check; if the renamed op were ever not a ClassOp this would dereference
// null. cast<ClassOp> would state the invariant and assert instead.
136 .Case<ObjectOp>([&](ObjectOp obj) {
137 auto resultTypeName = obj.getResult().getType().getName();
138 if (
auto *newOp = renameTable.lookup(resultTypeName))
139 obj.getResult().setType(dyn_cast<ClassOp>(newOp).getInstanceType());
// Module ports are described twice: in the port-type attribute and by the
// body block's arguments; both are updated.
141 .Case<FModuleOp>([&](FModuleOp module) {
142 SmallVector<Attribute> newPortTypes;
143 llvm::transform(module.getPortTypesAttr().getValue(),
144 std::back_inserter(newPortTypes), updateTypeAttr);
145 module.setPortTypesAttr(
146 ArrayAttr::get(module->getContext(), newPortTypes));
147 updateResults(module.getBodyBlock()->getArguments());
150 [&](InstanceOp instance) { updateResults(instance->getResults()); })
151 .Case<WireOp>([&](WireOp wire) { updateResults(wire->getResults()); })
// All other operations are left untouched.
152 .Default([](Operation *op) {});
158 SmallVectorImpl<FlatSymbolRefAttr> &stack,
159 llvm::DenseSet<Attribute> &layers) {
// Extend the current path with this layer and record the full nested
// reference: the outermost layer is the root symbol, the rest are nested.
160 stack.push_back(FlatSymbolRefAttr::get(layer.getSymNameAttr()));
161 layers.insert(SymbolRefAttr::get(stack.front().getAttr(),
162 ArrayRef(stack).drop_front()));
// Presumably recurses into each directly nested layer.
// NOTE(review): a matching stack.pop_back() is expected after this loop so
// sibling layers do not inherit this layer's path — confirm.
163 for (
auto child : layer.getOps<LayerOp>())
169 llvm::DenseSet<Attribute> &layers) {
// Path of layer symbols from the root to the layer currently being visited;
// shared by the walks over every top-level layer.
170 SmallVector<FlatSymbolRefAttr> stack;
// Presumably collects every top-level layer (and its nested layers) into
// `layers`.
171 for (
auto layer : circuit.getOps<LayerOp>())
177 const llvm::DenseSet<Attribute> &availableLayers) {
// An external module without a knownLayers attribute presumably has nothing
// to check.
178 auto knownLayersAttr = extModule.getKnownLayersAttr();
179 if (!knownLayersAttr)
// Gather every known layer that the linked circuit does not define.
182 SmallVector<Attribute> missingLayers;
183 for (
auto attr : knownLayersAttr)
184 if (!availableLayers.contains(attr))
185 missingLayers.push_back(attr);
// Nothing to report when every known layer is available.
187 if (missingLayers.empty())
// Report all missing layers in one diagnostic, as a comma-separated list.
190 auto diag = extModule.emitOpError()
191 <<
"declares known layers that are not defined in the linked "
193 llvm::interleaveComma(missingLayers, diag,
194 [&](Attribute attr) { diag << attr; });
// Two same-named layers can only be merged if they use the same convention.
199 if (dst.getConvention() != src.getConvention())
200 return src.emitOpError(
"layer convention mismatch with existing layer");
// Used to find, for each child layer of `src`, a same-named child in `dst`.
202 SymbolTable dstSymbolTable(dst);
// Early-increment iteration because ops are moved out of `src` below.
204 for (
auto &op : llvm::make_early_inc_range(src.getBody().front())) {
// A child layer that also exists in `dst` is merged recursively.
205 if (
auto srcChildLayer = dyn_cast<LayerOp>(op))
206 if (
auto dstChildLayer = cast_if_present<LayerOp>(
207 dstSymbolTable.lookup(srcChildLayer.getNameAttr()))) {
208 if (failed(
mergeLayer(dstChildLayer, srcChildLayer)))
// Ops that were not merged above are appended to the end of dst's body.
212 op.moveBefore(&dst.getBody().front(), dst.getBody().front().end());
/// Resolves a name collision between `collidingOp` (already in the merged
/// circuit) and `incomingOp` (from the circuit being linked). When one side
/// is erased, the result is `true` exactly when the erased op is
/// `incomingOp`. Returns failure when the two ops cannot be reconciled.
/// `mergedLayers` / `incomingLayers` are the layers available on each side,
/// used to validate the known layers of an external declaration.
245static FailureOr<bool>
247 const llvm::DenseSet<Attribute> &mergedLayers,
248 const llvm::DenseSet<Attribute> &incomingLayers) {
// Only public symbols may be linked across circuits.
249 if (!collidingOp.isPublic())
250 return collidingOp->emitOpError(
"should be a public symbol");
251 if (!incomingOp.isPublic())
252 return incomingOp->emitOpError(
"should be a public symbol");
// Definition (FModuleOp) vs. external declaration (FExtModuleOp): keep the
// definition and erase the declaration.
254 if ((isa<FExtModuleOp>(collidingOp) && isa<FModuleOp>(incomingOp)) ||
255 (isa<FExtModuleOp>(incomingOp) && isa<FModuleOp>(collidingOp))) {
256 auto definition = collidingOp;
257 auto declaration = incomingOp;
258 if (!isa<FModuleOp>(collidingOp)) {
259 definition = incomingOp;
260 declaration = collidingOp;
// The declaration's known layers are checked against the layers of the side
// that supplies the definition.
263 auto extModule = cast<FExtModuleOp>(declaration);
264 const auto &layersToCheck =
265 (definition == incomingOp) ? incomingLayers : mergedLayers;
// The declaration must agree with the definition on these attributes.
269 constexpr const StringRef attrsToCompare[] = {
270 "portDirections",
"portSymbols",
"portNames",
"portTypes",
"layers"};
271 if (!all_of(attrsToCompare, [&](StringRef attr) {
272 return definition->getAttr(attr) == declaration->getAttr(attr);
276 declaration->erase();
277 return declaration == incomingOp;
// Two external declarations of the same symbol.
280 if (isa<FExtModuleOp>(collidingOp) && isa<FExtModuleOp>(incomingOp)) {
281 constexpr const StringRef attrsToCompare[] = {
282 "portDirections",
"portSymbols",
"portNames",
283 "portTypes",
"knownLayers",
"layers",
285 if (!all_of(attrsToCompare, [&](StringRef attr) {
286 return collidingOp->getAttr(attr) == incomingOp->getAttr(attr);
// With identical parameters, the two declarations must also agree on defname.
290 auto collidingParams = collidingOp->getAttrOfType<ArrayAttr>(
"parameters");
291 auto incomingParams = incomingOp->getAttrOfType<ArrayAttr>(
"parameters");
292 if (collidingParams == incomingParams) {
293 if (collidingOp->getAttr(
"defname") != incomingOp->getAttr(
"defname"))
// If only one declaration has parameters, keep it and erase the one without.
// NOTE(review): assumes both ops carry a `parameters` ArrayAttr; if only one
// lacked it, `.empty()` would be called on a null attribute — confirm.
301 if (collidingParams.empty() || incomingParams.empty()) {
302 auto declaration = collidingParams.empty() ? collidingOp : incomingOp;
303 declaration->erase();
304 return declaration == incomingOp;
// Same-named layers are merged recursively into the existing layer.
308 if (isa<LayerOp>(collidingOp) && isa<LayerOp>(incomingOp)) {
310 mergeLayer(cast<LayerOp>(collidingOp), cast<LayerOp>(incomingOp))))
/// Moves the contents of every circuit in the module into a newly created
/// circuit named `baseCircuitName`, resolving symbol collisions along the way
/// and retargeting all annotations at the merged circuit.
319LogicalResult LinkCircuitsPass::mergeCircuits() {
320 auto module = getOperation();
// Snapshot the existing circuits before the merged circuit is added to the
// same module body.
322 SmallVector<CircuitOp> circuits;
323 for (CircuitOp circuitOp : module.getOps<CircuitOp>())
324 circuits.push_back(circuitOp);
// The merged circuit is created at the end of the module.
326 auto builder = OpBuilder(module);
327 builder.setInsertionPointToEnd(module.getBody());
329 CircuitOp::create(builder, module.getLoc(),
330 StringAttr::get(&getContext(), baseCircuitName));
// Annotations from all linked circuits, collected in circuit order.
331 SmallVector<Attribute> mergedAnnotations;
// Layers already provided by circuits linked so far.
333 llvm::DenseSet<Attribute> mergedLayers;
// Link the circuits one at a time.
335 for (
auto circuit : circuits) {
338 return circuit->emitError(
"failed to mangle private symbol");
// Layers declared by the circuit currently being linked (presumably filled by
// collectLayerSymbols).
340 llvm::DenseSet<Attribute> incomingLayers;
// Every annotation target is rewritten to name the merged circuit.
344 llvm::transform(circuit.getAnnotations().getValue(),
345 std::back_inserter(mergedAnnotations), [&](Attribute attr) {
346 return transformAnnotationTarget(
347 cast<DictionaryAttr>(attr),
348 [&](TokenAnnoTarget &tokens) {
349 tokens.circuit = mergedCircuit.getName();
// Rebuilt for every incoming circuit so lookups see the ops moved in by
// earlier iterations.
355 auto mergedSymbolTable = SymbolTable(mergedCircuit.getOperation());
// Snapshot the ops, since they are moved out of `circuit` below.
357 SmallVector<Operation *> opsToMove;
358 for (
auto &op : circuit.getOps())
359 opsToMove.push_back(&op);
360 for (
auto *op : opsToMove) {
// A symbol whose name already exists in the merged circuit must be
// reconciled with the existing op.
361 if (
auto symbolOp = dyn_cast<SymbolOpInterface>(op))
362 if (
auto collidingOp = cast_if_present<SymbolOpInterface>(
363 mergedSymbolTable.lookup(symbolOp.getNameAttr()))) {
365 mergedLayers, incomingLayers);
366 if (failed(opErased))
367 return mergedCircuit->emitError(
"has colliding symbol " +
369 " which cannot be merged");
// If the incoming op was erased, nothing remains to move; surviving ops are
// appended to the merged circuit's body.
370 if (opErased.value())
374 op->moveBefore(mergedCircuit.getBodyBlock(),
375 mergedCircuit.getBodyBlock()->end());
// This circuit's layers are now available to circuits linked after it.
378 mergedLayers.insert(incomingLayers.begin(), incomingLayers.end());
382 mergedCircuit.setAnnotationsAttr(
383 ArrayAttr::get(mergedCircuit.getContext(), mergedAnnotations));
// Final check that the merged circuit is a well-formed symbol table.
385 return mlir::detail::verifySymbolTable(mergedCircuit);
/// Pass entry point: links all circuits via mergeCircuits(); the failure
/// branch presumably signals pass failure.
388void LinkCircuitsPass::runOnOperation() {
389 if (failed(mergeCircuits()))
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static LogicalResult verifyKnownLayers(FExtModuleOp extModule, const llvm::DenseSet< Attribute > &availableLayers)
static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno, CallableT transformTokensFn)
static LogicalResult mergeLayer(LayerOp dst, LayerOp src)
static void collectLayerSymbols(LayerOp layer, SmallVectorImpl< FlatSymbolRefAttr > &stack, llvm::DenseSet< Attribute > &layers)
static LogicalResult mangleCircuitSymbols(CircuitOp circuit)
static FailureOr< bool > handleCollidingOps(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp, const llvm::DenseSet< Attribute > &mergedLayers, const llvm::DenseSet< Attribute > &incomingLayers)
Handles colliding symbols when merging circuits.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.