23#include "mlir/IR/Threading.h"
24#include "mlir/Pass/Pass.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/Support/Debug.h"
28#define DEBUG_TYPE "firrtl-lower-signatures"
32#define GEN_PASS_DEF_LOWERSIGNATURES
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
38using namespace firrtl;
// Caches the attribute-name StringAttrs (and a shared empty ArrayAttr) that
// this pass uses when it rebuilds a module's port attributes; each is created
// once here from `context` and then read as cache.sPortNames, cache.aEmpty,
// etc. by lowerModuleSignature.
46 AttrCache(MLIRContext *
context) {
47 nameAttr = StringAttr::get(
context,
"name");
48 sPortDirections = StringAttr::get(
context,
"portDirections");
49 sPortNames = StringAttr::get(
context,
"portNames");
50 sPortTypes = StringAttr::get(
context,
"portTypes");
51 sPortLocations = StringAttr::get(
context,
"portLocations");
52 sPortAnnotations = StringAttr::get(
context,
"portAnnotations");
// Port domain associations are stored under the "domainInfo" key.
53 sPortDomains = StringAttr::get(
context,
"domainInfo");
// Shared empty array; lowerModuleSignature uses it as the domain list of
// ports that carry no domain associations.
54 aEmpty = ArrayAttr::get(
context, {});
56 AttrCache(
const AttrCache &) =
default;
// Cached attribute names. (`aEmpty` is assigned above; its declaration is not
// visible in this view.)
58 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
59 sPortAnnotations, sPortDomains;
// One port produced by signature lowering; extends PortInfo. From how entries
// are used later (p.portID, p.resultID, p.fieldID) it presumably also records
// the original port index, the new result index and the field ID the new port
// starts at — its member declarations are not visible in this view.
63struct FieldMapEntry :
public PortInfo {
// Element type of the searched collection (T::ElementType); the elements are
// kept in a SmallVector sorted by field ID.
73 using E =
typename T::ElementType;
74 using V = SmallVector<E>;
77 using const_iterator =
typename V::const_iterator;
// Builds the search structure from `src` (an Attribute or another container;
// the element-copying lines are not visible here), then sorts the elements by
// field ID so find() can binary-search them.
// NOTE(review): std::sort is not stable, so elements that share a field ID
// (e.g. several annotations on the same field) may come out in a different
// relative order than in `src`; std::stable_sort would preserve it — confirm
// whether downstream consumers care about that order.
79 template <
typename Container>
80 FieldIDSearch(
const Container &src) {
81 if constexpr (std::is_convertible_v<Container, Attribute>)
86 std::sort(vals.begin(), vals.end(), fieldComp);
// Returns the iterator range of elements whose field ID lies in the closed
// interval [low, high]: lower_bound yields the first element with ID >= low,
// upper_bound the first element with ID > high. Relies on `vals` having been
// sorted by the constructor.
89 std::pair<const_iterator, const_iterator> find(uint64_t low,
90 uint64_t high)
const {
91 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
92 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
// Presumably reports whether no element has a field ID in [low, high]; the
// remainder of the body is not visible in this view.
95 bool empty(uint64_t low, uint64_t high)
const {
96 auto [b, e] = find(low, high);
// Field-ID comparators: element/element for sorting, element < value for
// lower_bound, and value < element for upper_bound.
101 static constexpr auto fieldComp = [](
const E &lhs,
const E &rhs) {
102 return lhs.getFieldID() < rhs.getFieldID();
104 static constexpr auto fieldCompInt2 = [](
const E &lhs, uint64_t rhs) {
105 return lhs.getFieldID() < rhs;
107 static constexpr auto fieldCompInt1 = [](uint64_t lhs,
const E &rhs) {
108 return lhs < rhs.getFieldID();
// symbolsForFieldIDRange: collects the inner symbols in `syms` whose field IDs
// lie in [low, high] and re-bases each one's field ID to be relative to `low`
// (name and visibility are kept), returning them as a new hw::InnerSymAttr.
// Some lines of this function (including the name line) are not visible in
// this view.
116static hw::InnerSymAttr
118 const FieldIDSearch<hw::InnerSymAttr> &syms,
119 uint64_t low, uint64_t high) {
120 auto [b, e] = syms.find(low, high);
121 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
// Rewrite each symbol in place with its field ID shifted down by `low`.
124 for (
auto &sym : newSyms)
125 sym = hw::InnerSymPropertiesAttr::get(
126 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
127 return hw::InnerSymAttr::get(ctx, newSyms);
// annosForFieldIDRange: starts by looking up the annotations whose field IDs
// lie in [low, high]; the rest of the signature and body is not visible in
// this view.
132 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
135 auto [b, e] = annos.find(low, high);
// computeLoweringImpl: recursively decides how one original port (index
// `portID`) is split into new ports and appends them to `newPorts`.
//  - A bundle that is passive, or a vector whose element type is passive, is
//    kept whole as a single port covering field IDs
//    [fieldID, fieldID + getMaxFieldID()] unless the convention is
//    Convention::Scalarized.
//  - Otherwise every element is lowered recursively as "<name>_<elem>" (or
//    "<name>_<index>" for vectors) at its own field-ID offset. A bundle
//    element's flip toggles the direction via `isFlip ^ elem.isFlip`.
//  - A new port is Direction::Out when `isFlip` is set, otherwise In.
// Subdividing an aggregate is reported as an error when an inner symbol or an
// annotation targets the aggregate itself (field ID == `fieldID`).
// Some lines (signature head, type-switch head, ground-type case) are not
// fully visible in this view.
143 size_t portID,
const PortInfo &port,
bool isFlip,
144 Twine name,
FIRRTLType type, uint64_t fieldID,
145 const FieldIDSearch<hw::InnerSymAttr> &syms,
146 const FieldIDSearch<AnnotationSet> &annos) {
147 auto *ctx = type.getContext();
149 .
Case<BundleType>([&](BundleType bundle) -> LogicalResult {
// Passive bundle and not scalarizing: keep the whole bundle as one port.
152 if (conv != Convention::Scalarized && bundle.isPassive()) {
153 auto lastId = fieldID + bundle.getMaxFieldID();
155 {{StringAttr::get(ctx, name), type,
156 isFlip ? Direction::Out : Direction::In,
// Otherwise lower each element with a "_<field name>" suffix, its field-ID
// offset and the element's flip folded into the direction.
164 for (
auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
166 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
167 name +
"_" + elem.
name.getValue(), elem.type,
168 fieldID + bundle.getFieldID(idx), syms, annos)))
// An inner symbol on the bundle's own field ID cannot be split across the
// element ports, so it is reported as an error.
// NOTE(review): these checks depend only on `fieldID`, not on the element;
// if they sit inside the element loop (closing braces are not visible here)
// they could be hoisted ahead of it so nothing is lowered before failing.
170 if (!syms.empty(fieldID, fieldID))
171 return mod.emitError(
"Port [")
173 <<
"] should be subdivided, but cannot be because of "
175 << port.
sym.getSymIfExists(fieldID) <<
"] on a bundle";
// Same for annotations: list each one as "Class(fieldID)", comma-separated.
176 if (!annos.empty(fieldID, fieldID)) {
177 auto err = mod.emitError(
"Port [")
179 <<
"] should be subdivided, but cannot be because of "
181 auto [b, e] = annos.find(fieldID, fieldID);
182 err << b->getClass() <<
"(" << b->getFieldID() <<
")";
185 err <<
", " << b->getClass() <<
"(" << b->getFieldID() <<
")";
186 err <<
"] on a bundle";
// Vectors: same strategy as bundles; all elements share the vector's flip.
193 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
194 if (conv != Convention::Scalarized &&
195 vector.getElementType().isPassive()) {
196 auto lastId = fieldID + vector.getMaxFieldID();
198 {{StringAttr::get(ctx, name), type,
199 isFlip ? Direction::Out : Direction::In,
// Element i becomes "<name>_<i>" at field ID fieldID + getFieldID(i).
207 for (
size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
209 mod, newPorts, conv, portID, port, isFlip,
210 name +
"_" + Twine(i), vector.getElementType(),
211 fieldID + vector.getFieldID(i), syms, annos)))
213 if (!syms.empty(fieldID, fieldID))
214 return mod.emitError(
"Port [")
216 <<
"] should be subdivided, but cannot be because of "
218 << port.
sym.getSymIfExists(fieldID) <<
"] on a vector";
// NOTE(review): unlike the bundle case, this diagnostic prints only
// b->getClass() and omits "(fieldID)" for each annotation; consider making
// the two messages consistent.
219 if (!annos.empty(fieldID, fieldID)) {
220 auto err = mod.emitError(
"Port [")
222 <<
"] should be subdivided, but cannot be because of "
224 auto [b, e] = annos.find(fieldID, fieldID);
225 err << b->getClass();
228 err <<
", " << b->getClass();
229 err <<
"] on a vector";
// Remaining case: presumably a non-aggregate type emitted as a single port at
// this field ID; the rest of this case is not visible here.
239 {{StringAttr::get(ctx, name), type,
240 isFlip ? Direction::Out : Direction::In,
// computeLowering: lowers every port of `mod` through computeLoweringImpl,
// starting at field ID 0. `isFlip` is set for output ports, and the port's
// own inner symbols and annotations are wrapped in FieldIDSearch so they can
// be looked up by field-ID range.
254 for (
auto [idx, port] : llvm::enumerate(mod.getPorts())) {
256 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
257 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
258 FieldIDSearch<hw::InnerSymAttr>(port.sym),
259 FieldIDSearch<AnnotationSet>(port.annotations))))
// lowerModuleSignature (body; the signature lines are not visible here):
// applies the computed port conversion `newPorts` to `module`.
//  1. Remaps domain associations from old port indices to new result indices.
//  2. For an FModuleOp, rewrites the body block arguments and uses bounce
//     wires to reassemble split ports.
//  3. Rebuilds the port attribute arrays.
268 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
// Map: original port index -> new result index, for domain-typed ports.
273 DenseMap<size_t, size_t> domainMap;
274 for (
auto &newPort : newPorts) {
275 if (!type_isa<DomainType>(newPort.type))
277 domainMap[newPort.portID] = newPort.resultID;
// For non-domain ports carrying an ArrayAttr of domain associations, rewrite
// every old port index to the corresponding new result index (as ui32).
// NOTE(review): domainMap[...] default-inserts 0 when an index is absent,
// silently associating the port with result 0; domainMap.lookup()/find()
// plus an assertion would make the "must be a domain port" invariant
// explicit.
279 for (
auto &newPort : newPorts) {
280 if (type_isa<DomainType>(newPort.type))
282 auto oldAssociations = dyn_cast_or_null<ArrayAttr>(newPort.domains);
283 if (!oldAssociations)
285 SmallVector<Attribute> newAssociations;
286 for (
auto oldAttr : oldAssociations)
287 newAssociations.push_back(theBuilder.getUI32IntegerAttr(
288 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
289 newPort.domains = theBuilder.getArrayAttr(newAssociations);
// Modules with a body: append the new block arguments, then route the old
// arguments' uses through "bounce" values indexed by original port.
292 if (
auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
293 Block *body = mod.getBodyBlock();
294 theBuilder.setInsertionPointToStart(body);
295 auto oldNumArgs = body->getNumArguments();
300 SmallVector<Value> bounceWires(oldNumArgs);
301 for (
auto &p : newPorts) {
302 auto newArg = body->addArgument(p.type, p.loc);
// A non-zero field ID means this new port is one piece of a split original
// port, so the original port is stood in for by a wire of its old type and
// name; a port kept whole (field ID 0) uses the new argument directly.
305 if (p.fieldID != 0) {
306 auto &wire = bounceWires[p.portID];
308 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
309 module.getPortNameAttr(p.portID),
310 NameKindEnum::InterestingName)
313 bounceWires[p.portID] = newArg;
// Original ports that produced no new port still get a plain wire, so every
// old argument can be replaced before removal.
318 for (
auto idx = 0U; idx < oldNumArgs; ++idx) {
319 if (!bounceWires[idx]) {
320 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
321 module.getPortNameAttr(idx))
324 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
// Drop the old arguments; the new ones are then at index p.resultID.
328 body->eraseArguments(0, oldNumArgs);
// Split ports only (bounce value is a wire, not a block argument): connect
// the new argument with its piece of the wire. The direction-selecting lines
// are not visible in this view.
331 for (
auto &p : newPorts) {
332 if (isa<BlockArgument>(bounceWires[p.portID]))
336 theBuilder, body->getArgument(p.resultID),
342 body->getArgument(p.resultID));
// Keep every attribute that is not one of the per-port arrays rebuilt below.
// NOTE(review): "domainInfo" is not filtered here, yet a fresh "domainInfo"
// entry (cache.sPortDomains) is appended further down. If the module already
// carries that attribute, the name appears twice, and MLIR's DictionaryAttr
// requires unique names. Comparing against the cached cache.s* StringAttrs
// instead of string literals would also keep this list in sync with the
// pushes below.
346 SmallVector<NamedAttribute, 8> newModuleAttrs;
349 for (
auto attr :
module->getAttrDictionary())
352 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
353 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
354 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
355 newModuleAttrs.push_back(attr);
// Gather one entry per new port. A port without domain associations gets
// the shared empty array.
// NOTE(review): `auto p` copies each entry per iteration; `const auto &p`
// would avoid that.
357 SmallVector<Direction> newPortDirections;
358 SmallVector<Attribute> newPortNames;
359 SmallVector<Attribute> newPortTypes;
360 SmallVector<Attribute> newPortSyms;
361 SmallVector<Attribute> newPortLocations;
362 SmallVector<Attribute, 8> newPortAnnotations;
363 SmallVector<Attribute> newPortDomains;
365 for (
auto p : newPorts) {
366 newPortTypes.push_back(TypeAttr::get(p.type));
367 newPortNames.push_back(p.name);
368 newPortDirections.push_back(p.direction);
369 newPortSyms.push_back(p.sym);
370 newPortLocations.push_back(p.loc);
371 newPortAnnotations.push_back(p.annotations.getArrayAttr());
372 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
// Append the rebuilt per-port arrays under their cached attribute names.
375 newModuleAttrs.push_back(NamedAttribute(
376 cache.sPortDirections,
379 newModuleAttrs.push_back(
380 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
382 newModuleAttrs.push_back(
383 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
385 newModuleAttrs.push_back(NamedAttribute(
386 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
388 newModuleAttrs.push_back(NamedAttribute(
389 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
// Stored under "domainInfo" (see the filter note above).
391 newModuleAttrs.push_back(NamedAttribute(
392 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
// Replace the whole attribute list, then normalize the collected port symbols
// with FModuleLike::fixupPortSymsArray and set them separately.
395 module->setAttrs(newModuleAttrs);
396 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
397 module.setPortSymbols(newPortSyms);
// lowerModuleBody (the leading signature line is not visible here): rewrites
// every InstanceOp / InstanceChoiceOp in `mod` so that it uses the lowered
// ports recorded in `ports` (keyed by module name).
402 const DenseMap<StringAttr, PortConversion> &ports) {
403 auto fixupInstance = [&](
auto inst,
auto clone) ->
void {
404 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
// Only the first referenced module's conversion is used. For an instance
// choice this assumes every alternative lowers identically; runOnOperation
// forces Convention::Scalarized for modules used by instance choices.
// ports.at() asserts that the module has an entry.
408 StringAttr moduleName =
409 cast<StringAttr>(inst.getReferencedModuleNamesAttr()[0]);
411 const auto &modPorts = ports.at(moduleName);
// The instance's ports are built from the module's lowered ports (any
// per-port adjustment lines are not visible in this view).
414 SmallVector<PortInfo> instPorts;
415 for (
auto p : modPorts) {
419 instPorts.push_back(p);
422 auto newOp = clone(theBuilder, inst, instPorts);
// Carry over discardable attributes the new op does not already have.
424 auto oldDict = inst->getDiscardableAttrDictionary();
425 auto newDict = newOp->getDiscardableAttrDictionary();
426 auto oldNames = inst.getPortNamesAttr();
// NOTE(review): `newAttrs` is never used in the visible code.
427 SmallVector<NamedAttribute> newAttrs;
428 for (
auto na : oldDict)
429 if (!newDict.contains(na.getName()))
430 newOp->setDiscardableAttr(na.getName(), na.getValue());
// Ports kept whole (field ID 0) are replaced directly by the new result.
// A split port gets one bounce wire per original result, named
// "<instance>.<port>" and of the old result type, which takes over the old
// result's uses.
433 SmallVector<WireOp> bounce(inst.getNumResults());
434 for (
auto p : modPorts) {
436 if (p.fieldID == 0) {
437 inst.getResult(p.portID).replaceAllUsesWith(
438 newOp.getResult(p.resultID));
441 if (!bounce[p.portID]) {
442 bounce[p.portID] = WireOp::create(
443 theBuilder, inst.getResult(p.portID).getType(),
444 theBuilder.getStringAttr(
445 inst.getName() +
"." +
446 cast<StringAttr>(oldNames[p.portID]).getValue()));
447 inst.getResult(p.portID).replaceAllUsesWith(
448 bounce[p.portID].getResult());
// Connect the new result with its piece of the bounce wire; the
// direction-selecting lines are not visible in this view.
452 emitConnect(theBuilder, newOp.getResult(p.resultID),
459 newOp.getResult(p.resultID));
// Any users left on the old instance are expected to be connects.
// NOTE(review): `isa<MatchingConnectOp, ConnectOp>(use)` is the equivalent
// variadic form.
463 for (
auto *use : llvm::make_early_inc_range(inst->getUsers())) {
464 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
// Visit every operation in the module and rebuild each instance-like op with
// its lowered port list, preserving the op's other properties.
471 mod->walk([&](Operation *op) ->
void {
472 TypeSwitch<Operation *>(op)
473 .Case<InstanceOp>([&](
auto inst) {
474 fixupInstance(inst, [&](ImplicitLocOpBuilder &theBuilder,
476 ArrayRef<PortInfo> newPorts) {
477 return InstanceOp::create(
478 theBuilder, newPorts, inst.getModuleName(), inst.getName(),
479 inst.getNameKind(), inst.getAnnotations().getValue(),
480 inst.getLayers(), inst.getLowerToBind(), inst.getDoNotPrint(),
481 inst.getInnerSymAttr());
// The clone lambda's `inst` parameter shadows the outer `inst`; both refer
// to the same op here.
484 .Case<InstanceChoiceOp>([&](
auto inst) {
485 fixupInstance(inst, [&](ImplicitLocOpBuilder &theBuilder,
486 InstanceChoiceOp inst,
487 ArrayRef<PortInfo> newPorts) {
488 return InstanceChoiceOp::create(
489 theBuilder, newPorts, inst.getModuleNamesAttr(),
490 inst.getCaseNamesAttr(), inst.getName(), inst.getNameKind(),
491 inst.getAnnotationsAttr(), inst.getLayersAttr(),
492 inst.getInnerSymAttr());
// Pass driver built on the generated LowerSignaturesBase; all work happens in
// runOnOperation.
503struct LowerSignaturesPass
504 :
public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
505 void runOnOperation()
override;
// Runs the pass on a circuit in two phases:
//  1. Sequentially handles every FModuleLike, determining its convention
//     (the calls that compute and record each module's port conversion are
//     not fully visible here) and signalling pass failure on error.
//  2. Rewrites every FModuleOp body in parallel.
510void LowerSignaturesPass::runOnOperation() {
512 auto &instanceGraph = getAnalysis<InstanceGraph>();
515 AttrCache cache(&getContext());
517 DenseMap<StringAttr, PortConversion> portMap;
518 auto circuit = getOperation();
520 for (
auto mod : circuit.getOps<FModuleLike>()) {
521 auto convention = mod.getConvention();
// Modules instantiated by any InstanceChoiceOp are always scalarized,
// presumably so all alternatives of a choice lower their ports the same way
// (lowerModuleBody only consults the first referenced module).
524 if (llvm::any_of(instanceGraph.lookup(mod)->uses(),
526 return use->getInstance<InstanceChoiceOp>();
528 convention = Convention::Scalarized;
531 return signalPassFailure();
// Body phase: each worker only reads the shared portMap (lowerModuleBody
// takes it by const reference).
533 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
534 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch, except that it uses firrtl::type_dyn_cast for its dynamic casts.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
This is an edge in the InstanceGraph.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.