CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerSignatures.cpp
Go to the documentation of this file.
1//===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerSignatures pass. This pass replaces aggregate
10// types with expanded values in module arguments as specified by the ABI
11// information.
12//
13//===----------------------------------------------------------------------===//
14
21#include "circt/Support/Debug.h"
23#include "mlir/IR/Threading.h"
24#include "mlir/Pass/Pass.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/Support/Debug.h"
27
28#define DEBUG_TYPE "firrtl-lower-signatures"
29
30namespace circt {
31namespace firrtl {
32#define GEN_PASS_DEF_LOWERSIGNATURES
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
34} // namespace firrtl
35} // namespace circt
36
37using namespace circt;
38using namespace firrtl;
39
40//===----------------------------------------------------------------------===//
41// Module Type Lowering
42//===----------------------------------------------------------------------===//
43namespace {
44
45struct AttrCache {
46 AttrCache(MLIRContext *context) {
47 nameAttr = StringAttr::get(context, "name");
48 sPortDirections = StringAttr::get(context, "portDirections");
49 sPortNames = StringAttr::get(context, "portNames");
50 sPortTypes = StringAttr::get(context, "portTypes");
51 sPortLocations = StringAttr::get(context, "portLocations");
52 sPortAnnotations = StringAttr::get(context, "portAnnotations");
53 sPortDomains = StringAttr::get(context, "domainInfo");
54 aEmpty = ArrayAttr::get(context, {});
55 }
56 AttrCache(const AttrCache &) = default;
57
58 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
59 sPortAnnotations, sPortDomains;
60 ArrayAttr aEmpty;
61};
62
63struct FieldMapEntry : public PortInfo {
64 size_t portID;
65 size_t resultID;
66 size_t fieldID;
67};
68
69using PortConversion = SmallVector<FieldMapEntry>;
70
71template <typename T>
72class FieldIDSearch {
73 using E = typename T::ElementType;
74 using V = SmallVector<E>;
75
76public:
77 using const_iterator = typename V::const_iterator;
78
79 template <typename Container>
80 FieldIDSearch(const Container &src) {
81 if constexpr (std::is_convertible_v<Container, Attribute>)
82 if (!src)
83 return;
84 for (auto attr : src)
85 vals.push_back(attr);
86 std::sort(vals.begin(), vals.end(), fieldComp);
87 }
88
89 std::pair<const_iterator, const_iterator> find(uint64_t low,
90 uint64_t high) const {
91 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
92 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
93 }
94
95 bool empty(uint64_t low, uint64_t high) const {
96 auto [b, e] = find(low, high);
97 return b == e;
98 }
99
100private:
101 static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
102 return lhs.getFieldID() < rhs.getFieldID();
103 };
104 static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
105 return lhs.getFieldID() < rhs;
106 };
107 static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
108 return lhs < rhs.getFieldID();
109 };
110
111 V vals;
112};
113
114} // namespace
115
116static hw::InnerSymAttr
117symbolsForFieldIDRange(MLIRContext *ctx,
118 const FieldIDSearch<hw::InnerSymAttr> &syms,
119 uint64_t low, uint64_t high) {
120 auto [b, e] = syms.find(low, high);
121 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
122 if (newSyms.empty())
123 return {};
124 for (auto &sym : newSyms)
125 sym = hw::InnerSymPropertiesAttr::get(
126 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
127 return hw::InnerSymAttr::get(ctx, newSyms);
128}
129
130static AnnotationSet
131annosForFieldIDRange(MLIRContext *ctx,
132 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
133 uint64_t high) {
134 AnnotationSet newAnnos(ctx);
135 auto [b, e] = annos.find(low, high);
136 for (; b != e; ++b)
137 newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
138 return newAnnos;
139}
140
141static LogicalResult
142computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
143 size_t portID, const PortInfo &port, bool isFlip,
144 Twine name, FIRRTLType type, uint64_t fieldID,
145 const FieldIDSearch<hw::InnerSymAttr> &syms,
146 const FieldIDSearch<AnnotationSet> &annos) {
147 auto *ctx = type.getContext();
149 .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
150 // This should be enhanced to be able to handle bundle<all flips of
151 // passive>, or this should be a canonicalizer
152 if (conv != Convention::Scalarized && bundle.isPassive()) {
153 auto lastId = fieldID + bundle.getMaxFieldID();
154 newPorts.push_back(
155 {{StringAttr::get(ctx, name), type,
156 isFlip ? Direction::Out : Direction::In,
157 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
158 annosForFieldIDRange(ctx, annos, fieldID, lastId),
159 port.domains},
160 portID,
161 newPorts.size(),
162 fieldID});
163 } else {
164 for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
165 if (failed(computeLoweringImpl(
166 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
167 name + "_" + elem.name.getValue(), elem.type,
168 fieldID + bundle.getFieldID(idx), syms, annos)))
169 return failure();
170 if (!syms.empty(fieldID, fieldID))
171 return mod.emitError("Port [")
172 << port.name
173 << "] should be subdivided, but cannot be because of "
174 "symbol ["
175 << port.sym.getSymIfExists(fieldID) << "] on a bundle";
176 if (!annos.empty(fieldID, fieldID)) {
177 auto err = mod.emitError("Port [")
178 << port.name
179 << "] should be subdivided, but cannot be because of "
180 "annotations [";
181 auto [b, e] = annos.find(fieldID, fieldID);
182 err << b->getClass() << "(" << b->getFieldID() << ")";
183 b++;
184 for (; b != e; ++b)
185 err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
186 err << "] on a bundle";
187 return err;
188 }
189 }
190 }
191 return success();
192 })
193 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
194 if (conv != Convention::Scalarized &&
195 vector.getElementType().isPassive()) {
196 auto lastId = fieldID + vector.getMaxFieldID();
197 newPorts.push_back(
198 {{StringAttr::get(ctx, name), type,
199 isFlip ? Direction::Out : Direction::In,
200 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
201 annosForFieldIDRange(ctx, annos, fieldID, lastId),
202 port.domains},
203 portID,
204 newPorts.size(),
205 fieldID});
206 } else {
207 for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
208 if (failed(computeLoweringImpl(
209 mod, newPorts, conv, portID, port, isFlip,
210 name + "_" + Twine(i), vector.getElementType(),
211 fieldID + vector.getFieldID(i), syms, annos)))
212 return failure();
213 if (!syms.empty(fieldID, fieldID))
214 return mod.emitError("Port [")
215 << port.name
216 << "] should be subdivided, but cannot be because of "
217 "symbol ["
218 << port.sym.getSymIfExists(fieldID) << "] on a vector";
219 if (!annos.empty(fieldID, fieldID)) {
220 auto err = mod.emitError("Port [")
221 << port.name
222 << "] should be subdivided, but cannot be because of "
223 "annotations [";
224 auto [b, e] = annos.find(fieldID, fieldID);
225 err << b->getClass();
226 ++b;
227 for (; b != e; ++b)
228 err << ", " << b->getClass();
229 err << "] on a vector";
230 return err;
231 }
232 }
233 }
234 return success();
235 })
236 .Default([&](FIRRTLType type) {
237 // Properties and other types wind up here.
238 newPorts.push_back(
239 {{StringAttr::get(ctx, name), type,
240 isFlip ? Direction::Out : Direction::In,
241 symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
242 annosForFieldIDRange(ctx, annos, fieldID, fieldID), port.domains},
243 portID,
244 newPorts.size(),
245 fieldID});
246 return success();
247 });
248}
249
250// compute a new moduletype from an old module type and lowering convention.
251// Also compute a fieldID map from port, fieldID -> port
252static LogicalResult computeLowering(FModuleLike mod, Convention conv,
253 PortConversion &newPorts) {
254 for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
255 if (failed(computeLoweringImpl(
256 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
257 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
258 FieldIDSearch<hw::InnerSymAttr>(port.sym),
259 FieldIDSearch<AnnotationSet>(port.annotations))))
260 return failure();
261 }
262 return success();
263}
264
265static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
266 AttrCache &cache,
267 PortConversion &newPorts) {
268 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
269 if (computeLowering(module, conv, newPorts).failed())
270 return failure();
271
272 // Update domain information now that all port expansions are fixed.
273 DenseMap<size_t, size_t> domainMap;
274 for (auto &newPort : newPorts) {
275 if (!type_isa<DomainType>(newPort.type))
276 continue;
277 domainMap[newPort.portID] = newPort.resultID;
278 }
279 for (auto &newPort : newPorts) {
280 if (type_isa<DomainType>(newPort.type))
281 continue;
282 auto oldAssociations = dyn_cast_or_null<ArrayAttr>(newPort.domains);
283 if (!oldAssociations)
284 continue;
285 SmallVector<Attribute> newAssociations;
286 for (auto oldAttr : oldAssociations)
287 newAssociations.push_back(theBuilder.getUI32IntegerAttr(
288 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
289 newPort.domains = theBuilder.getArrayAttr(newAssociations);
290 }
291
292 if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
293 Block *body = mod.getBodyBlock();
294 theBuilder.setInsertionPointToStart(body);
295 auto oldNumArgs = body->getNumArguments();
296
297 // Compute the replacement value for old arguments
298 // This creates all the new arguments and produces bounce wires when
299 // necessary
300 SmallVector<Value> bounceWires(oldNumArgs);
301 for (auto &p : newPorts) {
302 auto newArg = body->addArgument(p.type, p.loc);
303 // Get or create a bounce wire for changed ports
304 // For unmodified ports, move the uses to the replacement port
305 if (p.fieldID != 0) {
306 auto &wire = bounceWires[p.portID];
307 if (!wire)
308 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
309 module.getPortNameAttr(p.portID),
310 NameKindEnum::InterestingName)
311 .getResult();
312 } else {
313 bounceWires[p.portID] = newArg;
314 }
315 }
316 // replace old arguments. Somethings get dropped completely, like
317 // zero-length vectors.
318 for (auto idx = 0U; idx < oldNumArgs; ++idx) {
319 if (!bounceWires[idx]) {
320 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
321 module.getPortNameAttr(idx))
322 .getResult();
323 }
324 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
325 }
326
327 // Goodbye old ports, now ResultID in the PortInfo is correct.
328 body->eraseArguments(0, oldNumArgs);
329
330 // Connect the bounce wires to the new arguments
331 for (auto &p : newPorts) {
332 if (isa<BlockArgument>(bounceWires[p.portID]))
333 continue;
334 if (p.isOutput())
336 theBuilder, body->getArgument(p.resultID),
337 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
338 else
340 theBuilder,
341 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
342 body->getArgument(p.resultID));
343 }
344 }
345
346 SmallVector<NamedAttribute, 8> newModuleAttrs;
347
348 // Copy over any attributes that weren't original argument attributes.
349 for (auto attr : module->getAttrDictionary())
350 // Drop old "portNames", directions, and argument attributes. These are
351 // handled differently below.
352 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
353 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
354 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
355 newModuleAttrs.push_back(attr);
356
357 SmallVector<Direction> newPortDirections;
358 SmallVector<Attribute> newPortNames;
359 SmallVector<Attribute> newPortTypes;
360 SmallVector<Attribute> newPortSyms;
361 SmallVector<Attribute> newPortLocations;
362 SmallVector<Attribute, 8> newPortAnnotations;
363 SmallVector<Attribute> newPortDomains;
364
365 for (auto p : newPorts) {
366 newPortTypes.push_back(TypeAttr::get(p.type));
367 newPortNames.push_back(p.name);
368 newPortDirections.push_back(p.direction);
369 newPortSyms.push_back(p.sym);
370 newPortLocations.push_back(p.loc);
371 newPortAnnotations.push_back(p.annotations.getArrayAttr());
372 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
373 }
374
375 newModuleAttrs.push_back(NamedAttribute(
376 cache.sPortDirections,
377 direction::packAttribute(module.getContext(), newPortDirections)));
378
379 newModuleAttrs.push_back(
380 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
381
382 newModuleAttrs.push_back(
383 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
384
385 newModuleAttrs.push_back(NamedAttribute(
386 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
387
388 newModuleAttrs.push_back(NamedAttribute(
389 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
390
391 newModuleAttrs.push_back(NamedAttribute(
392 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
393
394 // Update the module's attributes.
395 module->setAttrs(newModuleAttrs);
396 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
397 module.setPortSymbols(newPortSyms);
398 return success();
399}
400
401static void lowerModuleBody(FModuleOp mod,
402 const DenseMap<StringAttr, PortConversion> &ports) {
403 auto fixupInstance = [&](auto inst, auto clone) -> void {
404 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
405
406 // Get the module name. The first element works for both InstanceOp and
407 // InstanceChoiceOp.
408 StringAttr moduleName =
409 cast<StringAttr>(inst.getReferencedModuleNamesAttr()[0]);
410
411 const auto &modPorts = ports.at(moduleName);
412
413 // Fix up the Instance
414 SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
415 for (auto p : modPorts) {
416 p.sym = {};
417 // Might need to partially copy stuff from the old instance.
418 p.annotations = AnnotationSet{mod.getContext()};
419 instPorts.push_back(p);
420 }
421
422 auto newOp = clone(theBuilder, inst, instPorts);
423
424 auto oldDict = inst->getDiscardableAttrDictionary();
425 auto newDict = newOp->getDiscardableAttrDictionary();
426 auto oldNames = inst.getPortNamesAttr();
427 SmallVector<NamedAttribute> newAttrs;
428 for (auto na : oldDict)
429 if (!newDict.contains(na.getName()))
430 newOp->setDiscardableAttr(na.getName(), na.getValue());
431
432 // Connect up the old instance users to the new instance
433 SmallVector<WireOp> bounce(inst.getNumResults());
434 for (auto p : modPorts) {
435 // No change? No bounce wire.
436 if (p.fieldID == 0) {
437 inst.getResult(p.portID).replaceAllUsesWith(
438 newOp.getResult(p.resultID));
439 continue;
440 }
441 if (!bounce[p.portID]) {
442 bounce[p.portID] = WireOp::create(
443 theBuilder, inst.getResult(p.portID).getType(),
444 theBuilder.getStringAttr(
445 inst.getName() + "." +
446 cast<StringAttr>(oldNames[p.portID]).getValue()));
447 inst.getResult(p.portID).replaceAllUsesWith(
448 bounce[p.portID].getResult());
449 }
450 // Connect up the Instance to the bounce wires
451 if (p.isInput())
452 emitConnect(theBuilder, newOp.getResult(p.resultID),
453 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
454 p.fieldID));
455 else
456 emitConnect(theBuilder,
457 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
458 p.fieldID),
459 newOp.getResult(p.resultID));
460 }
461 // Zero Width ports may have dangling connects since they are not preserved
462 // and do not have bounce wires.
463 for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
464 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
465 use->erase();
466 }
467 inst->erase();
468 return;
469 };
470
471 mod->walk([&](Operation *op) -> void {
472 TypeSwitch<Operation *>(op)
473 .Case<InstanceOp>([&](auto inst) {
474 fixupInstance(inst, [&](ImplicitLocOpBuilder &theBuilder,
475 InstanceOp inst,
476 ArrayRef<PortInfo> newPorts) {
477 return InstanceOp::create(
478 theBuilder, newPorts, inst.getModuleName(), inst.getName(),
479 inst.getNameKind(), inst.getAnnotations().getValue(),
480 inst.getLayers(), inst.getLowerToBind(), inst.getDoNotPrint(),
481 inst.getInnerSymAttr());
482 });
483 })
484 .Case<InstanceChoiceOp>([&](auto inst) {
485 fixupInstance(inst, [&](ImplicitLocOpBuilder &theBuilder,
486 InstanceChoiceOp inst,
487 ArrayRef<PortInfo> newPorts) {
488 return InstanceChoiceOp::create(
489 theBuilder, newPorts, inst.getModuleNamesAttr(),
490 inst.getCaseNamesAttr(), inst.getName(), inst.getNameKind(),
491 inst.getAnnotationsAttr(), inst.getLayersAttr(),
492 inst.getInnerSymAttr());
493 });
494 });
495 });
496}
497
498//===----------------------------------------------------------------------===//
499// Pass Infrastructure
500//===----------------------------------------------------------------------===//
501
502namespace {
503struct LowerSignaturesPass
504 : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
505 void runOnOperation() override;
506};
507} // end anonymous namespace
508
509// This is the main entrypoint for the lowering pass.
510void LowerSignaturesPass::runOnOperation() {
512 auto &instanceGraph = getAnalysis<InstanceGraph>();
513
514 // Cached attr
515 AttrCache cache(&getContext());
516
517 DenseMap<StringAttr, PortConversion> portMap;
518 auto circuit = getOperation();
519
520 for (auto mod : circuit.getOps<FModuleLike>()) {
521 auto convention = mod.getConvention();
522 // Instance choices select between modules with a shared port shape, so
523 // any module instantiated by one must use the scalarized convention.
524 if (llvm::any_of(instanceGraph.lookup(mod)->uses(),
525 [](InstanceRecord *use) {
526 return use->getInstance<InstanceChoiceOp>();
527 }))
528 convention = Convention::Scalarized;
529 if (lowerModuleSignature(mod, convention, cache, portMap[mod.getNameAttr()])
530 .failed())
531 return signalPassFailure();
532 }
533 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
534 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
535}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Term * find(Term *x)
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
This is an edge in the InstanceGraph.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.