CIRCT 23.0.0git
Loading...
Searching...
No Matches
LinkCircuits.cpp
Go to the documentation of this file.
1//===- LinkCircuits.cpp - Merge FIRRTL circuits --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass links multiple circuits into a single one.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/MapVector.h"
26#include "llvm/ADT/STLExtras.h"
27#include <iterator>
28
29namespace circt {
30namespace firrtl {
31#define GEN_PASS_DEF_LINKCIRCUITS
32#include "circt/Dialect/FIRRTL/Passes.h.inc"
33} // namespace firrtl
34} // namespace circt
35
36using namespace mlir;
37using namespace circt;
38using namespace firrtl;
39
40namespace {
41struct LinkCircuitsPass
42 : public circt::firrtl::impl::LinkCircuitsBase<LinkCircuitsPass> {
43 using Base::Base;
44
45 void runOnOperation() override;
46 LogicalResult mergeCircuits();
47 LinkCircuitsPass(StringRef baseCircuitNameOption, bool noMangleOption) {
48 baseCircuitName = std::string(baseCircuitNameOption);
49 noMangle = noMangleOption;
50 }
51};
52} // namespace
53
54template <typename CallableT>
55static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno,
56 CallableT transformTokensFn) {
57 return DictionaryAttr::getWithSorted(
58 anno.getContext(),
59 to_vector(map_range(
60 anno.getValue(), [&](NamedAttribute namedAttr) -> NamedAttribute {
61 if (namedAttr.getName() == "target")
62 if (auto target = dyn_cast<StringAttr>(namedAttr.getValue()))
63 if (auto tokens = tokenizePath(target.getValue()))
64 return {
65 namedAttr.getName(),
66 StringAttr::get(target.getContext(),
67 transformTokensFn(tokens.value()).str())};
68 return namedAttr;
69 })));
70}
71
72static LogicalResult mangleCircuitSymbols(CircuitOp circuit) {
73 auto circuitName = circuit.getNameAttr();
74
75 llvm::MapVector<StringRef, Operation *> renameTable;
76 auto symbolTable = SymbolTable(circuit.getOperation());
77 auto manglePrivateSymbol = [&](SymbolOpInterface symbolOp) {
78 auto symbolName = symbolOp.getNameAttr();
79 auto newSymbolName =
80 StringAttr::get(symbolOp->getContext(),
81 circuitName.getValue() + "_" + symbolName.getValue());
82 renameTable.insert(std::pair(symbolName.getValue(), symbolOp));
83 return symbolTable.rename(symbolOp, newSymbolName);
84 };
85
86 for (auto &op : circuit.getOps()) {
87 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
88 if (!symbolOp)
89 continue;
90
91 // Skip mangling for extmodules: they are declarations of external entities
92 // whose names must be preserved for cross-circuit linking resolution.
93 if (symbolOp.isPrivate() && !isa<FExtModuleOp>(symbolOp))
94 if (failed(manglePrivateSymbol(symbolOp)))
95 return failure();
96 }
97
98 circuit.walk([&](Operation *op) {
99 auto updateType = [&](Type type) -> Type {
100 if (auto cls = dyn_cast<ClassType>(type))
101 if (auto *newOp = renameTable.lookup(cls.getName()))
102 return ClassType::get(FlatSymbolRefAttr::get(newOp),
103 cls.getElements());
104 return type;
105 };
106 auto updateTypeAttr = [&](Attribute attr) -> Attribute {
107 if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
108 auto newType = updateType(typeAttr.getValue());
109 if (newType != typeAttr.getValue())
110 return TypeAttr::get(newType);
111 }
112 return attr;
113 };
114 auto updateResults = [&](auto &&results) {
115 for (auto result : results)
116 if (auto newType = updateType(result.getType());
117 newType != result.getType())
118 result.setType(newType);
119 };
120
121 TypeSwitch<Operation *>(op)
122 .Case<CircuitOp>([&](CircuitOp circuit) {
123 SmallVector<Attribute> newAnnotations;
124 llvm::transform(
125 circuit.getAnnotationsAttr(), std::back_inserter(newAnnotations),
126 [&](Attribute attr) {
127 return transformAnnotationTarget(
128 cast<DictionaryAttr>(attr), [&](TokenAnnoTarget &tokens) {
129 if (auto *newModule = renameTable.lookup(tokens.module))
130 tokens.module =
131 cast<SymbolOpInterface>(newModule).getName();
132 return tokens;
133 });
134 });
135 circuit.setAnnotationsAttr(
136 ArrayAttr::get(circuit.getContext(), newAnnotations));
137 })
138 .Case<ObjectOp>([&](ObjectOp obj) {
139 auto resultTypeName = obj.getResult().getType().getName();
140 if (auto *newOp = renameTable.lookup(resultTypeName))
141 obj.getResult().setType(dyn_cast<ClassOp>(newOp).getInstanceType());
142 })
143 .Case<FModuleOp>([&](FModuleOp module) {
144 SmallVector<Attribute> newPortTypes;
145 llvm::transform(module.getPortTypesAttr().getValue(),
146 std::back_inserter(newPortTypes), updateTypeAttr);
147 module.setPortTypesAttr(
148 ArrayAttr::get(module->getContext(), newPortTypes));
149 updateResults(module.getBodyBlock()->getArguments());
150 })
151 .Case<InstanceOp>(
152 [&](InstanceOp instance) { updateResults(instance->getResults()); })
153 .Case<WireOp>([&](WireOp wire) { updateResults(wire->getResults()); })
154 .Default([](Operation *op) {});
155 });
156 return success();
157}
158
159static void collectLayerSymbols(LayerOp layer,
160 SmallVectorImpl<FlatSymbolRefAttr> &stack,
161 llvm::DenseSet<Attribute> &layers) {
162 stack.push_back(FlatSymbolRefAttr::get(layer.getSymNameAttr()));
163 layers.insert(SymbolRefAttr::get(stack.front().getAttr(),
164 ArrayRef(stack).drop_front()));
165 for (auto child : layer.getOps<LayerOp>())
166 collectLayerSymbols(child, stack, layers);
167 stack.pop_back();
168}
169
170static void collectLayerSymbols(CircuitOp circuit,
171 llvm::DenseSet<Attribute> &layers) {
172 SmallVector<FlatSymbolRefAttr> stack;
173 for (auto layer : circuit.getOps<LayerOp>())
174 collectLayerSymbols(layer, stack, layers);
175}
176
177static LogicalResult
178verifyKnownLayers(FExtModuleOp extModule,
179 const llvm::DenseSet<Attribute> &availableLayers) {
180 auto knownLayersAttr = extModule.getKnownLayersAttr();
181 if (!knownLayersAttr)
182 return success();
183
184 SmallVector<Attribute> missingLayers;
185 for (auto attr : knownLayersAttr)
186 if (!availableLayers.contains(attr))
187 missingLayers.push_back(attr);
188
189 if (missingLayers.empty())
190 return success();
191
192 auto diag = extModule.emitOpError()
193 << "declares known layers that are not defined in the linked "
194 "circuit: ";
195 llvm::interleaveComma(missingLayers, diag,
196 [&](Attribute attr) { diag << attr; });
197 return failure();
198}
199
200static LogicalResult mergeLayer(LayerOp dst, LayerOp src) {
201 if (dst.getConvention() != src.getConvention())
202 return src.emitOpError("layer convention mismatch with existing layer");
203
204 SymbolTable dstSymbolTable(dst);
205
206 for (auto &op : llvm::make_early_inc_range(src.getBody().front())) {
207 if (auto srcChildLayer = dyn_cast<LayerOp>(op))
208 if (auto dstChildLayer = cast_if_present<LayerOp>(
209 dstSymbolTable.lookup(srcChildLayer.getNameAttr()))) {
210 if (failed(mergeLayer(dstChildLayer, srcChildLayer)))
211 return failure();
212 continue;
213 }
214 op.moveBefore(&dst.getBody().front(), dst.getBody().front().end());
215 }
216 return success();
217}
218
219/// Resolves symbol collisions during circuit merging. Handles:
220///
221/// 1. Extmodule + module: declaration is removed in favor of the definition
222/// if their port attributes match. The definition must be public.
223/// 2. Identical extmodules: duplicates are removed.
224/// 3. Extmodule with empty parameters: the placeholder (without parameters)
225/// is removed in favor of the fully-parameterized one.
226/// 4. Layers: recursively merged.
227///
228/// \param collidingOp The operation already present in the merged circuit
229/// \param incomingOp The operation being added from another circuit
230/// \return FailureOr<bool> Returns success with true if incomingOp was erased,
231/// success with false if collidingOp was erased, or failure if the
232/// collision cannot be resolved
233///
234/// \note The empty parameters workaround (case 3) should be removed once ODS
235/// is updated to properly support placeholder declarations.
236static FailureOr<bool>
237handleCollidingOps(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp,
238 const llvm::DenseSet<Attribute> &mergedLayers,
239 const llvm::DenseSet<Attribute> &incomingLayers) {
240 if ((isa<FExtModuleOp>(collidingOp) && isa<FModuleOp>(incomingOp)) ||
241 (isa<FExtModuleOp>(incomingOp) && isa<FModuleOp>(collidingOp))) {
242 auto definition = collidingOp;
243 auto declaration = incomingOp;
244 if (!isa<FModuleOp>(collidingOp)) {
245 definition = incomingOp;
246 declaration = collidingOp;
247 }
248
249 if (!definition.isPublic())
250 return definition->emitOpError("should be a public symbol");
251
252 auto extModule = cast<FExtModuleOp>(declaration);
253 const auto &layersToCheck =
254 (definition == incomingOp) ? incomingLayers : mergedLayers;
255 if (failed(verifyKnownLayers(extModule, layersToCheck)))
256 return failure();
257
258 constexpr const StringRef attrsToCompare[] = {
259 "portDirections", "portSymbols", "portNames", "portTypes", "layers"};
260 if (!all_of(attrsToCompare, [&](StringRef attr) {
261 return definition->getAttr(attr) == declaration->getAttr(attr);
262 }))
263 return failure();
264
265 declaration->erase();
266 return declaration == incomingOp;
267 }
268
269 if (isa<FExtModuleOp>(collidingOp) && isa<FExtModuleOp>(incomingOp)) {
270 constexpr const StringRef attrsToCompare[] = {
271 "portDirections", "portSymbols", "portNames",
272 "portTypes", "knownLayers", "layers",
273 };
274 if (!all_of(attrsToCompare, [&](StringRef attr) {
275 return collidingOp->getAttr(attr) == incomingOp->getAttr(attr);
276 }))
277 return failure();
278
279 auto collidingParams = collidingOp->getAttrOfType<ArrayAttr>("parameters");
280 auto incomingParams = incomingOp->getAttrOfType<ArrayAttr>("parameters");
281 if (collidingParams == incomingParams) {
282 if (collidingOp->getAttr("defname") != incomingOp->getAttr("defname"))
283 return failure();
284 incomingOp->erase();
285 return true;
286 }
287
288 // FIXME: definition and declaration may have different defname and
289 // decalration has no parameters
290 if (collidingParams.empty() || incomingParams.empty()) {
291 auto declaration = collidingParams.empty() ? collidingOp : incomingOp;
292 declaration->erase();
293 return declaration == incomingOp;
294 }
295 }
296
297 if (isa<LayerOp>(collidingOp) && isa<LayerOp>(incomingOp)) {
298 if (failed(
299 mergeLayer(cast<LayerOp>(collidingOp), cast<LayerOp>(incomingOp))))
300 return failure();
301 incomingOp->erase();
302 return true;
303 }
304
305 return failure();
306}
307
308LogicalResult LinkCircuitsPass::mergeCircuits() {
309 auto module = getOperation();
310
311 SmallVector<CircuitOp> circuits;
312 for (CircuitOp circuitOp : module.getOps<CircuitOp>())
313 circuits.push_back(circuitOp);
314
315 auto builder = OpBuilder(module);
316 builder.setInsertionPointToEnd(module.getBody());
317 auto mergedCircuit =
318 CircuitOp::create(builder, module.getLoc(),
319 StringAttr::get(&getContext(), baseCircuitName));
320 SmallVector<Attribute> mergedAnnotations;
321
322 llvm::DenseSet<Attribute> mergedLayers;
323
324 for (auto circuit : circuits) {
325 if (!noMangle)
326 if (failed(mangleCircuitSymbols(circuit)))
327 return circuit->emitError("failed to mangle private symbol");
328
329 llvm::DenseSet<Attribute> incomingLayers;
330 collectLayerSymbols(circuit, incomingLayers);
331
332 // TODO: other circuit attributes (such as enable_layers...)
333 llvm::transform(circuit.getAnnotations().getValue(),
334 std::back_inserter(mergedAnnotations), [&](Attribute attr) {
335 return transformAnnotationTarget(
336 cast<DictionaryAttr>(attr),
337 [&](TokenAnnoTarget &tokens) {
338 tokens.circuit = mergedCircuit.getName();
339 return tokens;
340 });
341 });
342
343 // reconstruct symbol table after each merge
344 auto mergedSymbolTable = SymbolTable(mergedCircuit.getOperation());
345
346 SmallVector<Operation *> opsToMove;
347 for (auto &op : circuit.getOps())
348 opsToMove.push_back(&op);
349 for (auto *op : opsToMove) {
350 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op))
351 if (auto collidingOp = cast_if_present<SymbolOpInterface>(
352 mergedSymbolTable.lookup(symbolOp.getNameAttr()))) {
353 auto opErased = handleCollidingOps(collidingOp, symbolOp,
354 mergedLayers, incomingLayers);
355 if (failed(opErased))
356 return mergedCircuit->emitError("has colliding symbol " +
357 symbolOp.getName() +
358 " which cannot be merged");
359 if (opErased.value())
360 continue;
361 }
362
363 op->moveBefore(mergedCircuit.getBodyBlock(),
364 mergedCircuit.getBodyBlock()->end());
365 }
366
367 mergedLayers.insert(incomingLayers.begin(), incomingLayers.end());
368 circuit->erase();
369 }
370
371 mergedCircuit.setAnnotationsAttr(
372 ArrayAttr::get(mergedCircuit.getContext(), mergedAnnotations));
373
374 return mlir::detail::verifySymbolTable(mergedCircuit);
375}
376
377void LinkCircuitsPass::runOnOperation() {
378 if (failed(mergeCircuits()))
379 signalPassFailure();
380}
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static LogicalResult verifyKnownLayers(FExtModuleOp extModule, const llvm::DenseSet< Attribute > &availableLayers)
static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno, CallableT transformTokensFn)
static LogicalResult mergeLayer(LayerOp dst, LayerOp src)
static void collectLayerSymbols(LayerOp layer, SmallVectorImpl< FlatSymbolRefAttr > &stack, llvm::DenseSet< Attribute > &layers)
static LogicalResult mangleCircuitSymbols(CircuitOp circuit)
static FailureOr< bool > handleCollidingOps(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp, const llvm::DenseSet< Attribute > &mergedLayers, const llvm::DenseSet< Attribute > &incomingLayers)
Resolves symbol collisions during circuit merging.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.