22#include "mlir/IR/Threading.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Debug.h"
26#define DEBUG_TYPE "firrtl-lower-signatures"
30#define GEN_PASS_DEF_LOWERSIGNATURES
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
36using namespace firrtl;
/// Build the cache: every member below is created once from `context` so the
/// rest of the pass can reuse these attribute handles (e.g. as the keys of the
/// rebuilt port attributes) instead of constructing them again.
44 AttrCache(MLIRContext *context) {
45 nameAttr = StringAttr::get(context,
"name");
46 sPortDirections = StringAttr::get(context,
"portDirections");
47 sPortNames = StringAttr::get(context,
"portNames");
48 sPortTypes = StringAttr::get(context,
"portTypes");
49 sPortLocations = StringAttr::get(context,
"portLocations");
50 sPortAnnotations = StringAttr::get(context,
"portAnnotations");
// Note: unlike the other port-related keys, this one is spelled "domainInfo".
51 sPortDomains = StringAttr::get(context,
"domainInfo");
// Array attribute with no elements; used as the fallback when a port has no
// domain information.
52 aEmpty = ArrayAttr::get(context, {});
54 AttrCache(
const AttrCache &) =
default;
56 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
57 sPortAnnotations, sPortDomains;
61struct FieldMapEntry :
public PortInfo {
71 using E =
typename T::ElementType;
72 using V = SmallVector<E>;
75 using const_iterator =
typename V::const_iterator;
77 template <
typename Container>
78 FieldIDSearch(
const Container &src) {
79 if constexpr (std::is_convertible_v<Container, Attribute>)
84 std::sort(vals.begin(), vals.end(), fieldComp);
/// Return the iterator range [first, second) over `vals` whose entries have a
/// field ID inside the inclusive interval [low, high]:
///  - first:  the first entry whose field ID is not less than `low`;
///  - second: the first entry whose field ID is greater than `high`.
/// Both searches are binary searches, so they rely on `vals` having been
/// sorted by field ID (done with `fieldComp` in the constructor).
87 std::pair<const_iterator, const_iterator> find(uint64_t low,
88 uint64_t high)
const {
89 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
90 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
93 bool empty(uint64_t low, uint64_t high)
const {
94 auto [b, e] = find(low, high);
/// Strict ordering of two elements by field ID; used to sort `vals`.
99 static constexpr auto fieldComp = [](
const E &lhs,
const E &rhs) {
100 return lhs.getFieldID() < rhs.getFieldID();
/// Element-vs-integer comparison (element on the left). This is the argument
/// order std::lower_bound expects for its comparator.
102 static constexpr auto fieldCompInt2 = [](
const E &lhs, uint64_t rhs) {
103 return lhs.getFieldID() < rhs;
/// Integer-vs-element comparison (integer on the left). This is the argument
/// order std::upper_bound expects for its comparator.
105 static constexpr auto fieldCompInt1 = [](uint64_t lhs,
const E &rhs) {
106 return lhs < rhs.getFieldID();
/// Build an InnerSymAttr holding only the inner symbols whose field ID lies in
/// [low, high], with each field ID rebased so that `low` becomes 0.
114static hw::InnerSymAttr
116 const FieldIDSearch<hw::InnerSymAttr> &syms,
117 uint64_t low, uint64_t high) {
// Copy the matching symbol properties out of the sorted search structure.
118 auto [b, e] = syms.find(low, high);
119 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
// NOTE(review): statements between the copy above and the loop below are not
// visible in this view (possibly an early return for an empty result) —
// confirm against the full file.
// Re-create each entry with the same name and visibility but with its field
// ID made relative to `low`.
122 for (
auto &sym : newSyms)
123 sym = hw::InnerSymPropertiesAttr::get(
124 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
125 return hw::InnerSymAttr::get(ctx, newSyms);
130 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
133 auto [b, e] = annos.find(low, high);
141 size_t portID,
const PortInfo &port,
bool isFlip,
142 Twine name,
FIRRTLType type, uint64_t fieldID,
143 const FieldIDSearch<hw::InnerSymAttr> &syms,
144 const FieldIDSearch<AnnotationSet> &annos) {
145 auto *ctx = type.getContext();
147 .
Case<BundleType>([&](BundleType bundle) -> LogicalResult {
150 if (conv != Convention::Scalarized && bundle.isPassive()) {
151 auto lastId = fieldID + bundle.getMaxFieldID();
153 {{StringAttr::get(ctx, name), type,
154 isFlip ? Direction::Out : Direction::In,
162 for (
auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
164 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
165 name +
"_" + elem.
name.getValue(), elem.type,
166 fieldID + bundle.getFieldID(idx), syms, annos)))
168 if (!syms.empty(fieldID, fieldID))
169 return mod.emitError(
"Port [")
171 <<
"] should be subdivided, but cannot be because of "
173 << port.
sym.getSymIfExists(fieldID) <<
"] on a bundle";
174 if (!annos.empty(fieldID, fieldID)) {
175 auto err = mod.emitError(
"Port [")
177 <<
"] should be subdivided, but cannot be because of "
179 auto [b, e] = annos.find(fieldID, fieldID);
180 err << b->getClass() <<
"(" << b->getFieldID() <<
")";
183 err <<
", " << b->getClass() <<
"(" << b->getFieldID() <<
")";
184 err <<
"] on a bundle";
191 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
192 if (conv != Convention::Scalarized &&
193 vector.getElementType().isPassive()) {
194 auto lastId = fieldID + vector.getMaxFieldID();
196 {{StringAttr::get(ctx, name), type,
197 isFlip ? Direction::Out : Direction::In,
205 for (
size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
207 mod, newPorts, conv, portID, port, isFlip,
208 name +
"_" + Twine(i), vector.getElementType(),
209 fieldID + vector.getFieldID(i), syms, annos)))
211 if (!syms.empty(fieldID, fieldID))
212 return mod.emitError(
"Port [")
214 <<
"] should be subdivided, but cannot be because of "
216 << port.
sym.getSymIfExists(fieldID) <<
"] on a vector";
217 if (!annos.empty(fieldID, fieldID)) {
218 auto err = mod.emitError(
"Port [")
220 <<
"] should be subdivided, but cannot be because of "
222 auto [b, e] = annos.find(fieldID, fieldID);
223 err << b->getClass();
226 err <<
", " << b->getClass();
227 err <<
"] on a vector";
237 {{StringAttr::get(ctx, name), type,
238 isFlip ? Direction::Out : Direction::In,
252 for (
auto [idx, port] : llvm::enumerate(mod.getPorts())) {
254 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
255 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
256 FieldIDSearch<hw::InnerSymAttr>(port.sym),
257 FieldIDSearch<AnnotationSet>(port.annotations))))
266 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
269 if (
auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
270 Block *body = mod.getBodyBlock();
271 theBuilder.setInsertionPointToStart(body);
272 auto oldNumArgs = body->getNumArguments();
277 SmallVector<Value> bounceWires(oldNumArgs);
278 for (
auto &p : newPorts) {
279 auto newArg = body->addArgument(p.type, p.loc);
282 if (p.fieldID != 0) {
283 auto &wire = bounceWires[p.portID];
285 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
286 module.getPortNameAttr(p.portID),
287 NameKindEnum::InterestingName)
290 bounceWires[p.portID] = newArg;
295 for (
auto idx = 0U; idx < oldNumArgs; ++idx) {
296 if (!bounceWires[idx]) {
297 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
298 module.getPortNameAttr(idx))
301 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
305 body->eraseArguments(0, oldNumArgs);
308 for (
auto &p : newPorts) {
309 if (isa<BlockArgument>(bounceWires[p.portID]))
313 theBuilder, body->getArgument(p.resultID),
319 body->getArgument(p.resultID));
// Start the replacement attribute list with every existing module attribute
// except the port-related ones named in the condition below.
323 SmallVector<NamedAttribute, 8> newModuleAttrs;
326 for (
auto attr :
module->getAttrDictionary())
// "portSymbols" is dropped here and written back separately through
// setPortSymbols at the end of this function.
// NOTE(review): "domainInfo" (the key held in cache.sPortDomains) is NOT
// excluded here, yet a new entry under that key is appended to
// newModuleAttrs further down. Confirm that a module which already carries
// "domainInfo" cannot end up with the attribute listed twice.
329 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
330 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
331 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
332 newModuleAttrs.push_back(attr);
334 SmallVector<Direction> newPortDirections;
335 SmallVector<Attribute> newPortNames;
336 SmallVector<Attribute> newPortTypes;
337 SmallVector<Attribute> newPortSyms;
338 SmallVector<Attribute> newPortLocations;
339 SmallVector<Attribute, 8> newPortAnnotations;
340 SmallVector<Attribute> newPortDomains;
342 for (
auto p : newPorts) {
343 newPortTypes.push_back(TypeAttr::get(p.type));
344 newPortNames.push_back(p.name);
345 newPortDirections.push_back(p.direction);
346 newPortSyms.push_back(p.sym);
347 newPortLocations.push_back(p.loc);
348 newPortAnnotations.push_back(p.annotations.getArrayAttr());
349 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
352 newModuleAttrs.push_back(NamedAttribute(
353 cache.sPortDirections,
356 newModuleAttrs.push_back(
357 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
359 newModuleAttrs.push_back(
360 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
362 newModuleAttrs.push_back(NamedAttribute(
363 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
365 newModuleAttrs.push_back(NamedAttribute(
366 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
368 newModuleAttrs.push_back(NamedAttribute(
369 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
372 module->setAttrs(newModuleAttrs);
373 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
374 module.setPortSymbols(newPortSyms);
379 const DenseMap<StringAttr, PortConversion> &ports) {
380 mod->walk([&](InstanceOp inst) ->
void {
// For every instance found under `mod`: build a replacement instance whose
// ports follow the already-computed conversion of the instantiated module,
// then redirect the uses of the old instance's results.
381 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
// Converted port list of the instantiated module, looked up by the instance's
// module name attribute.
382 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
385 SmallVector<PortInfo> instPorts;
386 for (
auto p : modPorts) {
390 instPorts.push_back(p);
// Create the new instance with the converted ports; the remaining arguments
// are all read from the old instance.
392 auto annos = inst.getAnnotations();
393 auto newOp = InstanceOp::create(
394 theBuilder, instPorts, inst.getModuleName(), inst.getName(),
395 inst.getNameKind(), annos.getValue(), inst.getLayers(),
396 inst.getLowerToBind(), inst.getDoNotPrint(), inst.getInnerSymAttr());
// Copy over discardable attributes of the old instance that the new instance
// does not already have.
398 auto oldDict = inst->getDiscardableAttrDictionary();
399 auto newDict = newOp->getDiscardableAttrDictionary();
400 auto oldNames = inst.getPortNamesAttr();
// NOTE(review): no use of `newAttrs` appears in the lines visible here —
// confirm whether it is still needed.
401 SmallVector<NamedAttribute> newAttrs;
402 for (
auto na : oldDict)
403 if (!newDict.contains(na.getName()))
404 newOp->setDiscardableAttr(na.getName(), na.getValue());
// One optional bounce wire per ORIGINAL result, indexed by the old port ID.
407 SmallVector<WireOp> bounce(inst.getNumResults());
408 for (
auto p : modPorts) {
// fieldID == 0: the uses of the old result are redirected straight to the
// corresponding new result.
410 if (p.fieldID == 0) {
411 inst.getResult(p.portID).replaceAllUsesWith(
412 newOp.getResult(p.resultID));
// Presumably the else-branch (the `else` line is not visible here): the first
// time a given old port is seen, create a wire of the old result's type named
// "<instance name>.<old port name>" and redirect the old result's uses to it.
415 if (!bounce[p.portID]) {
416 bounce[p.portID] = WireOp::create(
417 theBuilder, inst.getResult(p.portID).getType(),
418 theBuilder.getStringAttr(
419 inst.getName() +
"." +
420 cast<StringAttr>(oldNames[p.portID]).getValue()));
421 inst.getResult(p.portID).replaceAllUsesWith(
422 bounce[p.portID].getResult());
// Connects involving the new result are emitted here; the other operands are
// on lines not visible in this view.
426 emitConnect(theBuilder, newOp.getResult(p.resultID),
433 newOp.getResult(p.resultID));
// Any remaining user of the old instance is asserted to be a connect op. The
// range is wrapped in make_early_inc_range, which permits removing the current
// user while iterating (the loop body is not visible here).
437 for (
auto *use : llvm::make_early_inc_range(inst->getUsers())) {
438 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
451struct LowerSignaturesPass
452 :
public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
453 void runOnOperation()
override;
/// Pass entry point. First loops over every FModuleLike in the circuit,
/// filling `portMap` with a PortConversion keyed by module name and failing
/// the pass on error; then rewrites the bodies of all FModuleOps in parallel.
458void LowerSignaturesPass::runOnOperation() {
462 AttrCache cache(&getContext());
// Per-module port conversion, keyed by the module's name attribute.
464 DenseMap<StringAttr, PortConversion> portMap;
465 auto circuit = getOperation();
467 for (
auto mod : circuit.getOps<FModuleLike>()) {
469 portMap[mod.getNameAttr()])
471 return signalPassFailure();
// `portMap` is captured by reference and handed to lowerModuleBody, which
// takes it as a const reference.
473 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
474 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
assert(baseType &&"element must be base type")
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.