17 #include "mlir/Pass/Pass.h"
32 #include "mlir/IR/ImplicitLocOpBuilder.h"
33 #include "mlir/IR/Threading.h"
34 #include "llvm/ADT/APSInt.h"
35 #include "llvm/ADT/BitVector.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/Parallel.h"
39 #define DEBUG_TYPE "firrtl-lower-signatures"
43 #define GEN_PASS_DEF_LOWERSIGNATURES
44 #include "circt/Dialect/FIRRTL/Passes.h.inc"
48 using namespace circt;
49 using namespace firrtl;
57 AttrCache(MLIRContext *context) {
66 AttrCache(
const AttrCache &) =
default;
68 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
69 sPortAnnotations, sInternalPaths;
72 struct FieldMapEntry :
public PortInfo {
82 using E =
typename T::ElementType;
83 using V = SmallVector<E>;
86 using const_iterator =
typename V::const_iterator;
88 template <
typename Container>
89 FieldIDSearch(
const Container &src) {
90 if constexpr (std::is_convertible_v<Container, Attribute>)
95 std::sort(vals.begin(), vals.end(), fieldComp);
98 std::pair<const_iterator, const_iterator> find(uint64_t low,
99 uint64_t high)
const {
100 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
101 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
104 bool empty(uint64_t low, uint64_t high)
const {
105 auto [b, e] = find(low, high);
110 static constexpr
auto fieldComp = [](
const E &lhs,
const E &rhs) {
111 return lhs.getFieldID() < rhs.getFieldID();
113 static constexpr
auto fieldCompInt2 = [](
const E &lhs, uint64_t rhs) {
114 return lhs.getFieldID() < rhs;
116 static constexpr
auto fieldCompInt1 = [](uint64_t lhs,
const E &rhs) {
117 return lhs < rhs.getFieldID();
125 static hw::InnerSymAttr
127 const FieldIDSearch<hw::InnerSymAttr> &syms,
128 uint64_t low, uint64_t high) {
129 auto [b, e] = syms.find(low, high);
130 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
133 for (
auto &sym : newSyms)
135 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
141 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
144 auto [b, e] = annos.find(low, high);
152 size_t portID,
const PortInfo &port,
bool isFlip,
153 Twine name,
FIRRTLType type, uint64_t fieldID,
154 const FieldIDSearch<hw::InnerSymAttr> &syms,
155 const FieldIDSearch<AnnotationSet> &annos) {
156 auto *ctx = type.getContext();
158 .
Case<BundleType>([&](BundleType bundle) -> LogicalResult {
161 if (conv != Convention::Scalarized && bundle.isPassive()) {
162 auto lastId = fieldID + bundle.getMaxFieldID();
172 for (
auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
174 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
175 name +
"_" + elem.
name.getValue(), elem.type,
176 fieldID + bundle.getFieldID(idx), syms, annos)))
178 if (!syms.empty(fieldID, fieldID))
179 return mod.emitError(
"Port [")
181 <<
"] should be subdivided, but cannot be because of "
183 << port.
sym.getSymIfExists(fieldID) <<
"] on a bundle";
184 if (!annos.empty(fieldID, fieldID)) {
185 auto err = mod.emitError(
"Port [")
187 <<
"] should be subdivided, but cannot be because of "
189 auto [b, e] = annos.find(fieldID, fieldID);
190 err << b->getClass() <<
"(" << b->getFieldID() <<
")";
193 err <<
", " << b->getClass() <<
"(" << b->getFieldID() <<
")";
194 err <<
"] on a bundle";
201 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
202 if (conv != Convention::Scalarized &&
203 vector.getElementType().isPassive()) {
204 auto lastId = fieldID + vector.getMaxFieldID();
214 for (
size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
216 mod, newPorts, conv, portID, port, isFlip,
217 name +
"_" + Twine(i), vector.getElementType(),
218 fieldID + vector.getFieldID(i), syms, annos)))
220 if (!syms.empty(fieldID, fieldID))
221 return mod.emitError(
"Port [")
223 <<
"] should be subdivided, but cannot be because of "
225 << port.
sym.getSymIfExists(fieldID) <<
"] on a vector";
226 if (!annos.empty(fieldID, fieldID)) {
227 auto err = mod.emitError(
"Port [")
229 <<
"] should be subdivided, but cannot be because of "
231 auto [b, e] = annos.find(fieldID, fieldID);
232 err << b->getClass();
235 err <<
", " << b->getClass();
236 err <<
"] on a vector";
243 .Case<FEnumType>([&](FEnumType fenum) {
return failure(); })
262 for (
auto [idx, port] : llvm::enumerate(mod.getPorts())) {
265 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
266 FieldIDSearch<hw::InnerSymAttr>(port.sym),
267 FieldIDSearch<AnnotationSet>(port.annotations))))
276 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
279 if (
auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
280 Block *body = mod.getBodyBlock();
281 theBuilder.setInsertionPointToStart(body);
282 auto oldNumArgs = body->getNumArguments();
287 SmallVector<Value> bounceWires(oldNumArgs);
288 for (
auto &p : newPorts) {
289 auto newArg = body->addArgument(p.type, p.loc);
292 if (p.fieldID != 0) {
293 auto &wire = bounceWires[p.portID];
296 .create<WireOp>(module.getPortType(p.portID),
297 module.getPortNameAttr(p.portID),
298 NameKindEnum::InterestingName)
301 bounceWires[p.portID] = newArg;
306 for (
auto idx = 0U; idx < oldNumArgs; ++idx) {
307 if (!bounceWires[idx]) {
308 bounceWires[idx] = theBuilder
309 .create<WireOp>(module.getPortType(idx),
310 module.getPortNameAttr(idx))
313 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
317 body->eraseArguments(0, oldNumArgs);
320 for (
auto &p : newPorts) {
321 if (isa<BlockArgument>(bounceWires[p.portID]))
325 theBuilder, body->getArgument(p.resultID),
331 body->getArgument(p.resultID));
335 SmallVector<NamedAttribute, 8> newModuleAttrs;
338 for (
auto attr : module->getAttrDictionary())
341 if (attr.getName() !=
"portNames" && attr.getName() !=
"portDirections" &&
342 attr.getName() !=
"portTypes" && attr.getName() !=
"portAnnotations" &&
343 attr.getName() !=
"portSymbols" && attr.getName() !=
"portLocations" &&
344 attr.getName() !=
"internalPaths")
345 newModuleAttrs.push_back(attr);
347 SmallVector<Direction> newPortDirections;
348 SmallVector<Attribute> newPortNames;
349 SmallVector<Attribute> newPortTypes;
350 SmallVector<Attribute> newPortSyms;
351 SmallVector<Attribute> newPortLocations;
352 SmallVector<Attribute, 8> newPortAnnotations;
353 SmallVector<Attribute> newInternalPaths;
355 bool hasInternalPaths =
false;
356 auto internalPaths = module->getAttrOfType<ArrayAttr>(
"internalPaths");
357 for (
auto p : newPorts) {
359 newPortNames.push_back(p.name);
360 newPortDirections.push_back(p.direction);
361 newPortSyms.push_back(p.sym);
362 newPortLocations.push_back(p.loc);
363 newPortAnnotations.push_back(p.annotations.getArrayAttr());
365 auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
366 newInternalPaths.push_back(internalPath);
367 if (internalPath.getPath())
368 hasInternalPaths =
true;
372 newModuleAttrs.push_back(NamedAttribute(
373 cache.sPortDirections,
376 newModuleAttrs.push_back(
377 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
379 newModuleAttrs.push_back(
380 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
382 newModuleAttrs.push_back(NamedAttribute(
383 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
385 newModuleAttrs.push_back(NamedAttribute(
386 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
388 assert(newInternalPaths.empty() ||
389 newInternalPaths.size() == newPorts.size());
390 if (hasInternalPaths) {
391 newModuleAttrs.emplace_back(cache.sInternalPaths,
392 theBuilder.getArrayAttr(newInternalPaths));
396 module->setAttrs(newModuleAttrs);
397 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
398 module.setPortSymbols(newPortSyms);
403 const DenseMap<StringAttr, PortConversion> &ports) {
404 mod->walk([&](InstanceOp inst) ->
void {
405 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
406 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
409 SmallVector<PortInfo> instPorts;
410 for (
auto p : modPorts) {
414 instPorts.push_back(p);
416 auto annos = inst.getAnnotations();
417 auto newOp = theBuilder.create<InstanceOp>(
418 instPorts, inst.getModuleName(), inst.getName(), inst.getNameKind(),
419 annos.getValue(), inst.getLayers(), inst.getLowerToBind(),
420 inst.getInnerSymAttr());
422 auto oldDict = inst->getDiscardableAttrDictionary();
423 auto newDict = newOp->getDiscardableAttrDictionary();
424 auto oldNames = inst.getPortNamesAttr();
425 SmallVector<NamedAttribute> newAttrs;
426 for (
auto na : oldDict)
427 if (!newDict.contains(na.getName()))
428 newOp->setDiscardableAttr(na.getName(), na.getValue());
431 SmallVector<WireOp> bounce(inst.getNumResults());
432 for (
auto p : modPorts) {
434 if (p.fieldID == 0) {
435 inst.getResult(p.portID).replaceAllUsesWith(
436 newOp.getResult(p.resultID));
439 if (!bounce[p.portID]) {
440 bounce[p.portID] = theBuilder.create<WireOp>(
441 inst.getResult(p.portID).getType(),
442 theBuilder.getStringAttr(
443 inst.getName() +
"." +
444 cast<StringAttr>(oldNames[p.portID]).getValue()));
445 inst.getResult(p.portID).replaceAllUsesWith(
446 bounce[p.portID].getResult());
450 emitConnect(theBuilder, newOp.getResult(p.resultID),
457 newOp.getResult(p.resultID));
461 for (
auto *use : llvm::make_early_inc_range(inst->getUsers())) {
462 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
475 struct LowerSignaturesPass
476 :
public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
477 void runOnOperation()
override;
482 void LowerSignaturesPass::runOnOperation() {
485 AttrCache cache(&getContext());
487 DenseMap<StringAttr, PortConversion> portMap;
488 auto circuit = getOperation();
490 for (
auto mod : circuit.getOps<FModuleLike>()) {
492 portMap[mod.getNameAttr()])
494 return signalPassFailure();
496 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
497 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
502 return std::make_unique<LowerSignaturesPass>();
assert(baseType &&"element must be base type")
static InstancePath empty
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
std::unique_ptr< mlir::Pass > createLowerSignaturesPass()
This is the pass constructor.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
This holds the name and type that describes the module's ports.