CIRCT 22.0.0git
Loading...
Searching...
No Matches
LinkCircuits.cpp
Go to the documentation of this file.
1//===- LinkCircuits.cpp - Merge FIRRTL circuits --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass links multiple circuits into a single one.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Support/LLVM.h"
26#include "llvm/ADT/MapVector.h"
27#include "llvm/ADT/STLExtras.h"
28#include <iterator>
29
30namespace circt {
31namespace firrtl {
32#define GEN_PASS_DEF_LINKCIRCUITS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
34} // namespace firrtl
35} // namespace circt
36
37using namespace mlir;
38using namespace circt;
39using namespace firrtl;
40
41namespace {
42struct LinkCircuitsPass
43 : public circt::firrtl::impl::LinkCircuitsBase<LinkCircuitsPass> {
44 using Base::Base;
45
46 void runOnOperation() override;
47 LogicalResult mergeCircuits();
48 LinkCircuitsPass(StringRef baseCircuitNameOption, bool noMangleOption) {
49 baseCircuitName = std::string(baseCircuitNameOption);
50 noMangle = noMangleOption;
51 }
52};
53} // namespace
54
55template <typename CallableT>
56static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno,
57 CallableT transformTokensFn) {
58 return DictionaryAttr::getWithSorted(
59 anno.getContext(),
60 to_vector(map_range(
61 anno.getValue(), [&](NamedAttribute namedAttr) -> NamedAttribute {
62 if (namedAttr.getName() == "target")
63 if (auto target = dyn_cast<StringAttr>(namedAttr.getValue()))
64 if (auto tokens = tokenizePath(target.getValue()))
65 return {
66 namedAttr.getName(),
67 StringAttr::get(target.getContext(),
68 transformTokensFn(tokens.value()).str())};
69 return namedAttr;
70 })));
71}
72
73static LogicalResult mangleCircuitSymbols(CircuitOp circuit) {
74 auto circuitName = circuit.getNameAttr();
75
76 llvm::MapVector<StringRef, Operation *> renameTable;
77 auto symbolTable = SymbolTable(circuit.getOperation());
78 auto manglePrivateSymbol = [&](SymbolOpInterface symbolOp) {
79 auto symbolName = symbolOp.getNameAttr();
80 auto newSymbolName =
81 StringAttr::get(symbolOp->getContext(),
82 circuitName.getValue() + "_" + symbolName.getValue());
83 renameTable.insert(std::pair(symbolName.getValue(), symbolOp));
84 return symbolTable.rename(symbolOp, newSymbolName);
85 };
86
87 for (auto &op : circuit.getOps()) {
88 auto symbolOp = dyn_cast<SymbolOpInterface>(op);
89 if (!symbolOp)
90 continue;
91
92 if (symbolOp.isPrivate())
93 if (failed(manglePrivateSymbol(symbolOp)))
94 return failure();
95 }
96
97 circuit.walk([&](Operation *op) {
98 auto updateType = [&](Type type) -> Type {
99 if (auto cls = dyn_cast<ClassType>(type))
100 if (auto *newOp = renameTable.lookup(cls.getName()))
101 return ClassType::get(FlatSymbolRefAttr::get(newOp),
102 cls.getElements());
103 return type;
104 };
105 auto updateTypeAttr = [&](Attribute attr) -> Attribute {
106 if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
107 auto newType = updateType(typeAttr.getValue());
108 if (newType != typeAttr.getValue())
109 return TypeAttr::get(newType);
110 }
111 return attr;
112 };
113 auto updateResults = [&](auto &&results) {
114 for (auto result : results)
115 if (auto newType = updateType(result.getType());
116 newType != result.getType())
117 result.setType(newType);
118 };
119
120 TypeSwitch<Operation *>(op)
121 .Case<CircuitOp>([&](CircuitOp circuit) {
122 SmallVector<Attribute> newAnnotations;
123 llvm::transform(
124 circuit.getAnnotationsAttr(), std::back_inserter(newAnnotations),
125 [&](Attribute attr) {
126 return transformAnnotationTarget(
127 cast<DictionaryAttr>(attr), [&](TokenAnnoTarget &tokens) {
128 if (auto *newModule = renameTable.lookup(tokens.module))
129 tokens.module =
130 cast<SymbolOpInterface>(newModule).getName();
131 return tokens;
132 });
133 });
134 circuit.setAnnotationsAttr(
135 ArrayAttr::get(circuit.getContext(), newAnnotations));
136 })
137 .Case<ObjectOp>([&](ObjectOp obj) {
138 auto resultTypeName = obj.getResult().getType().getName();
139 if (auto *newOp = renameTable.lookup(resultTypeName))
140 obj.getResult().setType(dyn_cast<ClassOp>(newOp).getInstanceType());
141 })
142 .Case<FModuleOp>([&](FModuleOp module) {
143 SmallVector<Attribute> newPortTypes;
144 llvm::transform(module.getPortTypesAttr().getValue(),
145 std::back_inserter(newPortTypes), updateTypeAttr);
146 module.setPortTypesAttr(
147 ArrayAttr::get(module->getContext(), newPortTypes));
148 updateResults(module.getBodyBlock()->getArguments());
149 })
150 .Case<InstanceOp>(
151 [&](InstanceOp instance) { updateResults(instance->getResults()); })
152 .Case<WireOp>([&](WireOp wire) { updateResults(wire->getResults()); })
153 .Default([](Operation *op) {});
154 });
155 return success();
156}
157
158/// Handles colliding symbols when merging circuits.
159///
160/// This function resolves symbol collisions between operations in different
161/// circuits during the linking process. It handles three specific cases:
162///
163/// 1. Identical extmodules: When two extmodules have identical attributes, the
164/// incoming one is removed as they are duplicates.
165///
166/// 2. Extmodule declaration + module definition: When an extmodule
167/// (declaration) collides with a module (definition), the declaration is
168/// removed in favor of the definition if their attributes match.
169///
170/// 3. Extmodule with empty parameters (Zaozi workaround): When two extmodules
171/// collide and one has empty parameters, the one without parameters
172/// (placeholder declaration) is removed. This handles a limitation in
173/// Zaozi's module generation where placeholder extmodule declarations are
174/// created from instance ops without knowing the actual parameters or
175/// defname.
176///
177/// \param collidingOp The operation already present in the merged circuit
178/// \param incomingOp The operation being added from another circuit
179/// \return FailureOr<bool> Returns success with true if incomingOp was erased,
180/// success with false if collidingOp was erased, or failure if the
181/// collision cannot be resolved
182///
183/// \note This workaround for empty parameters should ultimately be
184/// removed once ODS is updated to properly support placeholder
185/// declarations.
186static FailureOr<bool> handleCollidingOps(SymbolOpInterface collidingOp,
187 SymbolOpInterface incomingOp) {
188 if (!collidingOp.isPublic())
189 return collidingOp->emitOpError("should be a public symbol");
190 if (!incomingOp.isPublic())
191 return incomingOp->emitOpError("should be a public symbol");
192
193 if ((isa<FExtModuleOp>(collidingOp) && isa<FModuleOp>(incomingOp)) ||
194 (isa<FExtModuleOp>(incomingOp) && isa<FModuleOp>(collidingOp))) {
195 auto definition = collidingOp;
196 auto declaration = incomingOp;
197 if (!isa<FModuleOp>(collidingOp)) {
198 definition = incomingOp;
199 declaration = collidingOp;
200 }
201
202 constexpr const StringRef attrsToCompare[] = {
203 "portDirections", "portSymbols", "portNames", "portTypes", "layers"};
204 if (!all_of(attrsToCompare, [&](StringRef attr) {
205 return definition->getAttr(attr) == declaration->getAttr(attr);
206 }))
207 return failure();
208
209 declaration->erase();
210 return declaration == incomingOp;
211 }
212
213 if (isa<FExtModuleOp>(collidingOp) && isa<FExtModuleOp>(incomingOp)) {
214 constexpr const StringRef attrsToCompare[] = {
215 "portDirections", "portSymbols", "portNames",
216 "portTypes", "knownLayers", "layers",
217 };
218 if (!all_of(attrsToCompare, [&](StringRef attr) {
219 return collidingOp->getAttr(attr) == incomingOp->getAttr(attr);
220 }))
221 return failure();
222
223 auto collidingParams = collidingOp->getAttrOfType<ArrayAttr>("parameters");
224 auto incomingParams = incomingOp->getAttrOfType<ArrayAttr>("parameters");
225 if (collidingParams == incomingParams) {
226 if (collidingOp->getAttr("defname") != incomingOp->getAttr("defname"))
227 return failure();
228 incomingOp->erase();
229 return true;
230 }
231
232 // FIXME: definition and declaration may have different defname and
233 // decalration has no parameters
234 if (collidingParams.empty() || incomingParams.empty()) {
235 auto declaration = collidingParams.empty() ? collidingOp : incomingOp;
236 declaration->erase();
237 return declaration == incomingOp;
238 }
239 }
240
241 return failure();
242}
243
244LogicalResult LinkCircuitsPass::mergeCircuits() {
245 auto module = getOperation();
246
247 SmallVector<CircuitOp> circuits;
248 for (CircuitOp circuitOp : module.getOps<CircuitOp>())
249 circuits.push_back(circuitOp);
250
251 auto builder = OpBuilder(module);
252 builder.setInsertionPointToEnd(module.getBody());
253 auto mergedCircuit =
254 CircuitOp::create(builder, module.getLoc(),
255 StringAttr::get(&getContext(), baseCircuitName));
256 SmallVector<Attribute> mergedAnnotations;
257
258 for (auto circuit : circuits) {
259 if (!noMangle)
260 if (failed(mangleCircuitSymbols(circuit)))
261 return circuit->emitError("failed to mangle private symbol");
262
263 // TODO: other circuit attributes (such as enable_layers...)
264 llvm::transform(circuit.getAnnotations().getValue(),
265 std::back_inserter(mergedAnnotations), [&](Attribute attr) {
266 return transformAnnotationTarget(
267 cast<DictionaryAttr>(attr),
268 [&](TokenAnnoTarget &tokens) {
269 tokens.circuit = mergedCircuit.getName();
270 return tokens;
271 });
272 });
273
274 // reconstruct symbol table after each merge
275 auto mergedSymbolTable = SymbolTable(mergedCircuit.getOperation());
276
277 SmallVector<Operation *> opsToMove;
278 for (auto &op : circuit.getOps())
279 opsToMove.push_back(&op);
280 for (auto *op : opsToMove) {
281 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op))
282 if (auto collidingOp = cast_if_present<SymbolOpInterface>(
283 mergedSymbolTable.lookup(symbolOp.getNameAttr()))) {
284 auto opErased = handleCollidingOps(collidingOp, symbolOp);
285 if (failed(opErased))
286 return mergedCircuit->emitError("has colliding symbol " +
287 symbolOp.getName() +
288 " which cannot be merged");
289 if (opErased.value())
290 continue;
291 }
292
293 op->moveBefore(mergedCircuit.getBodyBlock(),
294 mergedCircuit.getBodyBlock()->end());
295 }
296
297 circuit->erase();
298 }
299
300 mergedCircuit.setAnnotationsAttr(
301 ArrayAttr::get(mergedCircuit.getContext(), mergedAnnotations));
302
303 return mlir::detail::verifySymbolTable(mergedCircuit);
304}
305
306void LinkCircuitsPass::runOnOperation() {
307 if (failed(mergeCircuits()))
308 signalPassFailure();
309}
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static DictionaryAttr transformAnnotationTarget(DictionaryAttr anno, CallableT transformTokensFn)
static FailureOr< bool > handleCollidingOps(SymbolOpInterface collidingOp, SymbolOpInterface incomingOp)
Handles colliding symbols when merging circuits.
static LogicalResult mangleCircuitSymbols(CircuitOp circuit)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.