50#include "llvm/ADT/MapVector.h"
51#include "llvm/ADT/TypeSwitch.h"
52#include "llvm/Support/Debug.h"
53#include "llvm/Support/Threading.h"
57#define GEN_PASS_DEF_LOWERDOMAINS
58#include "circt/Dialect/FIRRTL/Passes.h.inc"
63using namespace firrtl;
64using mlir::UnrealizedConversionCastOp;
72 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
// Records one association between a domain and another port. lowerModule
// attaches the DistinctAttr to the associated port as an annotation and later
// builds a PathOp from it. The remaining member(s) are on lines not shown;
// lowerModule constructs this from `{id, port.loc}`, so presumably a Location
// follows.
76struct AssociationInfo {
// Unique id identifying the associated port.
78 DistinctAttr distinctAttr;
// Fragment of DomainInfo, the per-domain-port lowering state. Its header and
// first members are on lines not shown. lowerModule destructures it as
// {op, inputPort, outputPort, temp, associations}.
//
// Index of the new input-class port, set only when the original domain port
// was an input.
95 std::optional<unsigned> inputPort;
// Placeholder cast produced by stubOut. It stands in for the old domain value
// until splice() connects its users to the lowered value.
106 UnrealizedConversionCastOp temp;
// Ports associated with this domain, used to build the associations list.
110 SmallVector<AssociationInfo> associations{};
// Info for an input domain port: the input-class port goes at `portIndex` and
// the output-class port immediately after it.
115 static DomainInfo input(
unsigned portIndex) {
116 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
// Info for an output domain port: there is no input port, only the
// output-class port at `portIndex`.
122 static DomainInfo output(
unsigned portIndex) {
123 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
// Fragment of the `Constants` helper, which caches attributes shared by the
// whole circuit lowering. The struct header and some members are on lines not
// shown. Each cached group is created at most once via llvm::call_once.
//
// once_flag guarding the cached empty ArrayAttr.
144 llvm::once_flag flag;
// once_flag and cached port names of the lowered domain output class.
150 llvm::once_flag flag;
151 StringAttr domainInfoIn;
152 StringAttr domainInfoOut;
153 StringAttr associationsIn;
154 StringAttr associationsOut;
// All cached attributes are created in `context`.
158 Constants(MLIRContext *context) : context(context) {}
// Return an empty ArrayAttr, creating it on first use.
161 ArrayAttr getEmptyArrayAttr() {
162 llvm::call_once(emptyArray.flag,
163 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
164 return emptyArray.attr;
// Create the names of the output class's four ports on first use.
169 void initClassOut() {
170 llvm::call_once(classOut.flag, [&] {
171 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
172 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
173 classOut.associationsIn = StringAttr::get(context,
"associations_in");
174 classOut.associationsOut = StringAttr::get(context,
"associations_out");
// Accessors for the output-class port names.
// NOTE(review): the line that presumably calls initClassOut() before each
// read is not shown here.
180 StringAttr getDomainInfoIn() {
182 return classOut.domainInfoIn;
186 StringAttr getDomainInfoOut() {
188 return classOut.domainInfoOut;
192 StringAttr getAssociationsIn() {
194 return classOut.associationsIn;
198 StringAttr getAssociationsOut() {
200 return classOut.associationsOut;
// Storage for the cached empty array and its once_flag.
208 EmptyArray emptyArray;
// Detach the uses of `value`:
//  - create an operand-less, single-result UnrealizedConversionCastOp of the
//    same type right after `value` is defined;
//  - redirect every use of `value` to that placeholder.
// The placeholder keeps those uses alive while the original value is
// replaced; splice() later connects it to the replacement value.
// Values without uses take an early exit. The body of that branch is on a
// line not shown; presumably it returns a null op.
218static UnrealizedConversionCastOp stubOut(Value value) {
219 if (!value.hasNUsesOrMore(1))
222 OpBuilder builder(value.getContext());
223 builder.setInsertionPointAfterValue(value);
224 auto temp = UnrealizedConversionCastOp::create(
225 builder, builder.getUnknownLoc(), {value.getType()}, {});
// Every former user of `value` now reads the placeholder instead.
226 value.replaceAllUsesWith(temp.getResult(0));
// Replace the placeholder `temp` (created by stubOut) with a cast of `value`.
// A new single-result UnrealizedConversionCastOp taking `value` is built
// immediately before `temp`, and all uses of `temp` are redirected to it.
// `temp` must be a single-result cast with no operands.
// NOTE(review): stubOut can return a null op for unused values. Confirm that
// the lines not shown (between the signature and the assert) handle a null
// `temp`. The second parameter and the rest of the body are also not shown.
240static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
246 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
251 auto oldValue = temp.getResult(0);
253 OpBuilder builder(temp);
254 auto splice = UnrealizedConversionCastOp::create(
255 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
256 oldValue.replaceAllUsesWith(splice.getResult(0));
// Fragment of the per-module lowering helper `LowerModule`. Its header and
// some members are on lines not shown.
//
// `classes` maps a domain's name to its lowered input/output classes.
270 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
272 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
273 constants(constants), instanceGraph(instanceGraph) {}
// Rewrite the module's ports and body.
277 LogicalResult lowerModule();
// Update every instance of the module to its new port list, using the port
// changes recorded by lowerModule().
281 LogicalResult lowerInstances();
// One bit per original port; the set bits are the ports erased by the lowering.
288 BitVector eraseVector;
// New ports to insert, as (insertion index, port) pairs.
292 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
// NOTE(review): not referenced in the code visible here.
295 SmallVector<std::pair<unsigned, unsigned>> resultMap;
298 const DenseMap<Attribute, Classes> &domainToClasses;
// Cached attributes shared across the circuit lowering.
301 Constants &constants;
// Original domain-port index -> lowering state. MapVector iterates in
// insertion order, so the processing order of domains is deterministic.
311 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
// Lower the domain ports of this module in several steps:
//  - erase each domain port and replace it with property ports typed with
//    the domain's lowered classes;
//  - create an object of the output class in the body and wire it to those
//    ports;
//  - give ports associated with a domain a DistinctAttr annotation, which the
//    object's associations list refers to via paths;
//  - rewrite or erase the remaining domain operations in the body.
314LogicalResult LowerModule::lowerModule() {
// Modules yield their body block, external modules yield nullptr (ports
// only), and anything else yields std::nullopt (handled on a line not shown).
324 TypeSwitch<Operation *, std::optional<Block *>>(op)
325 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
326 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
328 .Default([](
auto) {
return std::nullopt; });
331 Block *body = *shouldProcess;
333 auto *
context = op.getContext();
// New port annotations, one entry per port of the rewritten port list.
337 SmallVector<Attribute> portAnnotations;
// Objects are created at the start of the body, each after the previous one.
346 OpBuilder::InsertPoint insertPoint;
348 insertPoint = {body, body->begin()};
349 auto ports = op.getPorts();
// `i` is the original port index.
// NOTE(review): iDel/iIns appear to track the insertion position for new
// ports and the index of the next inserted port; their updates are on lines
// not shown.
350 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
351 auto port = cast<PortInfo>(ports[i]);
// A port whose `domains` attribute is a symbol reference is a domain port of
// the referenced domain.
354 if (
auto domain = dyn_cast_or_null<FlatSymbolRefAttr>(port.domains)) {
358 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
360 indexToDomain[i] = port.direction == Direction::In
361 ? DomainInfo::input(iIns)
362 : DomainInfo::output(iIns);
367 ImplicitLocOpBuilder builder(port.loc,
context);
368 builder.restoreInsertionPoint(insertPoint);
// Instantiate the output class as "<port>_object" and record the new
// instance in the instance graph. Then stub out the uses of the old domain
// block argument.
371 auto object = ObjectOp::create(
373 StringAttr::get(
context, Twine(port.name) +
"_object"));
374 instanceGraph.lookup(op)->addInstance(
object,
375 instanceGraph.lookup(classOut));
376 indexToDomain[i].op = object;
377 indexToDomain[i].temp = stubOut(body->getArgument(i));
382 insertPoint = builder.saveInsertionPoint();
// Input domains get an input port typed with the input class. Per
// DomainInfo::input, an output port typed with the output class follows it.
// The new ports get empty annotations.
389 if (port.direction == Direction::In) {
390 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
392 portAnnotations.push_back(constants.getEmptyArrayAttr());
397 classOut.getInstanceType(), Direction::Out)});
401 portAnnotations.push_back(constants.getEmptyArrayAttr());
// Non-domain port. If set, `domains` holds the integer indices of the domain
// ports this port is associated with.
417 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
418 if (!domainAttr || domainAttr.empty()) {
419 portAnnotations.push_back(port.annotations.getArrayAttr());
// NOTE(review): FIRRTL-typed ports reach a further condition (on a line not
// shown) under which their annotations are kept unchanged. This may be a
// zero-width check.
426 if (
auto firrtlType = type_dyn_cast<FIRRTLType>(port.type))
428 portAnnotations.push_back(port.annotations.getArrayAttr());
// Tag the port with a DistinctAttr id via an annotation, and record the
// association (id, port location) on each referenced domain.
432 SmallVector<Annotation> newAnnotations;
434 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
436 id = DistinctAttr::create(UnitAttr::get(
context));
437 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
441 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
443 if (!newAnnotations.empty())
444 port.annotations.addAnnotations(newAnnotations);
445 portAnnotations.push_back(port.annotations.getArrayAttr());
// Apply the port changes: erase the domain ports, clear the domain info,
// then insert the new class-typed ports.
449 op.erasePorts(eraseVector);
450 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
454 op.insertPorts(newPorts);
// Wire each domain's object. Field indices follow the port order of the
// output class built in lowerDomain: 0 = domainInfo_in, 2 = associations_in.
//  - Field 0 is driven from the new input port, which also replaces the
//    stubbed users. Otherwise, presumably when there is no input port (the
//    branch condition is on a line not shown), the stubbed users are spliced
//    onto field 0.
//  - Field 2 is driven with a list of paths to the associated ports.
//  - The new output port is driven with the object itself.
// NOTE(review): `auto [...] = info` copies the DomainInfo, including its
// associations vector. Binding by const reference would avoid the copy.
457 for (
auto const &[_, info] : indexToDomain) {
458 auto [object, inputPort, outputPort, temp, associations] =
info;
459 OpBuilder builder(
object);
460 builder.setInsertionPointAfter(
object);
467 auto subDomainInfoIn =
468 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
470 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
471 body->getArgument(*inputPort));
472 splice(temp, body->getArgument(*inputPort));
474 splice(temp, subDomainInfoIn);
478 auto subAssociations =
479 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
480 SmallVector<Value> paths;
481 for (
auto [
id, loc] : associations) {
482 paths.push_back(PathOp::create(
483 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
486 auto list = ListCreateOp::create(
488 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
490 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
493 PropAssignOp::create(builder,
object.
getLoc(),
494 body->getArgument(outputPort),
object);
// Walk the body and lower the remaining operations on domain values. Ops
// made dead by the lowering are collected in these two sets rather than
// erased during the walk.
508 DenseSet<Operation *> conversionsToErase;
509 DenseSet<Operation *> operationsToErase;
510 auto walkResult = op.walk([&](Operation *walkOp) {
514 if (operationsToErase.contains(walkOp)) {
516 return WalkResult::advance();
// Unsafe domain casts: forward the input to all users, and schedule the casts
// that produced the domain operands for erasure.
520 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
521 for (
auto value : castOp.getDomains()) {
522 auto *conversion = value.getDefiningOp();
523 assert(isa<UnrealizedConversionCastOp>(conversion));
524 conversionsToErase.insert(conversion);
527 castOp.getResult().replaceAllUsesWith(castOp.getInput());
529 return WalkResult::advance();
// Anonymous domains. If every user is already scheduled for erasure,
// presumably (the check is on a line not shown) schedule this op too.
// Otherwise replace it with a cast of an UnknownValueOp of the domain's
// input-class type.
535 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
537 auto noUser = llvm::all_of(anonDomain->getUsers(), [&](
auto *user) {
538 return operationsToErase.contains(user) ||
539 conversionsToErase.contains(user);
542 conversionsToErase.insert(anonDomain);
543 return WalkResult::advance();
548 OpBuilder builder(anonDomain);
550 domainToClasses.at(anonDomain.getDomainAttr().getAttr()).input;
551 anonDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
552 builder, anonDomain.getLoc(), {anonDomain.getType()},
553 {UnknownValueOp::create(builder, anonDomain.getLoc(),
554 classIn.getInstanceType())
557 return WalkResult::advance();
// Domain-typed wires:
//  - every user must be a domain.define; otherwise emit an error and stop;
//  - defines whose source is the wire give the destinations, and the value
//    defined into the wire is the source;
//  - the defines and the wire are scheduled for erasure;
//  - new defines from the source straight to each destination are inserted
//    just before the latest of the old defines.
// NOTE(review): isBeforeInBlock requires all these defines to be in the same
// block.
577 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
578 if (type_isa<DomainType>(wireOp.getResult().getType())) {
580 SmallVector<Value> dsts;
581 DomainDefineOp lastDefineOp;
582 for (
auto *user :
llvm::make_early_inc_range(wireOp->getUsers())) {
583 auto domainDefineOp = dyn_cast<DomainDefineOp>(user);
584 if (operationsToErase.contains(domainDefineOp))
586 if (!domainDefineOp) {
587 auto diag = wireOp.emitOpError()
588 <<
"cannot be lowered by `LowerDomains` because it "
589 "has a user that is not a domain define op";
590 diag.attachNote(user->getLoc()) <<
"is one such user";
591 return WalkResult::interrupt();
593 if (!lastDefineOp || lastDefineOp->isBeforeInBlock(domainDefineOp))
594 lastDefineOp = domainDefineOp;
595 if (wireOp == domainDefineOp.getSrc().getDefiningOp())
596 dsts.push_back(domainDefineOp.getDest());
598 src = domainDefineOp.getSrc();
599 operationsToErase.insert(domainDefineOp);
601 conversionsToErase.insert(wireOp);
604 if (!src || dsts.empty())
605 return WalkResult::advance();
609 OpBuilder builder(lastDefineOp);
610 for (
auto dst :
llvm::reverse(dsts))
611 DomainDefineOp::create(builder, builder.getUnknownLoc(), dst, src);
613 return WalkResult::advance();
// Other domain.define ops whose destination is an UnrealizedConversionCastOp.
// Both ends are scheduled for erasure. Depending on the source:
//  - a cast source becomes a property assign between the underlying values;
//  - an anonymous-domain source needs nothing;
//  - any other source is an error.
617 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
619 return WalkResult::advance();
627 auto *src = defineOp.getSrc().getDefiningOp();
628 auto dest = dyn_cast<UnrealizedConversionCastOp>(
629 defineOp.getDest().getDefiningOp());
631 return WalkResult::advance();
633 conversionsToErase.insert(src);
634 conversionsToErase.insert(dest);
636 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
637 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
638 OpBuilder builder(defineOp);
639 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
640 srcCast.getOperand(0));
641 }
else if (!isa<DomainCreateAnonOp>(src)) {
642 auto diag = defineOp.emitOpError()
643 <<
"has a source which cannot be lowered by 'LowerDomains'";
644 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
645 return WalkResult::interrupt();
649 return WalkResult::advance();
// The walk is interrupted only after an error diagnostic was emitted above.
652 if (walkResult.wasInterrupted())
// Clean up the conversions collected by the walk (loop body not shown).
656 for (
auto *op : conversionsToErase)
// Install the rebuilt annotations for the new port list.
661 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
// Update every instance of this module to the port list produced by
// lowerModule():
//  - stub out the uses of the erased domain results;
//  - clone the instance with those ports erased and the new ports inserted;
//  - swap the clone into the instance graph;
//  - splice each stub onto its replacement value.
666LogicalResult LowerModule::lowerInstances() {
// The port list did not change, so the instances need no update.
668 if (eraseVector.none() && newPorts.empty())
// Only FModuleOp/FExtModuleOp modules have their instances rewritten here.
675 if (!isa<FModuleOp, FExtModuleOp>(op))
678 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
// Early-increment iteration, because each use's instance is replaced in the
// instance graph while iterating.
679 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
// Only InstanceOp users are supported; any other kind is reported as an
// error.
680 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
682 use->getInstance().emitOpError()
683 <<
"has an unimplemented lowering in LowerDomains";
686 LLVM_DEBUG(llvm::dbgs()
687 <<
" - " << instanceOp.getInstanceName() <<
"\n");
689 for (
auto i : eraseVector.set_bits())
690 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
692 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
693 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
694 instanceGraph.replaceInstance(instanceOp, inserted);
// Input domains are spliced onto the instance's new input-class port.
// Output domains are spliced onto field 1 (domainInfo_out, per the output
// class's port order) of the output-class object the instance produces.
696 for (
auto &[i, info] : indexToDomain) {
698 if (
info.inputPort) {
700 splicedValue = inserted.getResult(*
info.inputPort);
704 OpBuilder builder(inserted);
705 builder.setInsertionPointAfter(inserted);
706 splicedValue = ObjectSubfieldOp::create(
707 builder, inserted.getLoc(), inserted.getResult(
info.outputPort), 1);
710 splice(
info.temp, splicedValue);
// Fragment of the circuit-level driver `LowerCircuit`. Its header and some
// members are on lines not shown.
726 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
727 llvm::Statistic &numDomains)
728 : circuit(circuit), instanceGraph(instanceGraph),
729 constants(circuit.getContext()), numDomains(numDomains) {}
// Lower all domains, then all modules and their instances.
732 LogicalResult lowerCircuit();
// Lower one DomainOp into its input and output classes.
736 LogicalResult lowerDomain(DomainOp);
// Pass statistic. NOTE(review): where it is incremented is not shown here.
748 llvm::Statistic &numDomains;
// Domain name -> lowered {input class, output class}.
752 DenseMap<Attribute, Classes> classes;
// Lower `op` into two classes, created just before the domain op:
//  - an input class named after the domain, with a `<field>_in` input port
//    and a `<field>_out` output port per domain field;
//  - an output class `<domain>_out` with ports domainInfo_in, domainInfo_out,
//    associations_in and associations_out. Their types are on lines not
//    shown; presumably the input class's instance type and a list of paths.
// In both classes each port is forwarded to the port after it. Both classes
// are recorded in `classes` and added to the instance graph.
755LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
756 ImplicitLocOpBuilder builder(op.getLoc(), op);
757 auto *
context = op.getContext();
758 auto name = op.getNameAttr();
759 SmallVector<PortInfo> classInPorts;
760 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
761 classInPorts.
append({{builder.getStringAttr(
762 Twine(field.getName().getValue()) +
"_in"),
763 field.getType(), Direction::In},
764 {builder.getStringAttr(
765 Twine(field.getName().getValue()) +
"_out"),
766 field.getType(), Direction::Out}});
767 auto classIn = ClassOp::create(builder, name, classInPorts);
768 auto classInType = classIn.getInstanceType();
770 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
772 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
773 {{constants.getDomainInfoIn(),
776 {constants.getDomainInfoOut(),
779 {constants.getAssociationsIn(),
782 {constants.getAssociationsOut(),
// Drive port i+1 from port i for each (input, output) pair. Both classes have
// an even number of ports by construction.
786 auto connectPairWise = [&builder](ClassOp &classOp) {
787 builder.setInsertionPointToStart(classOp.getBodyBlock());
788 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
789 PropAssignOp::create(builder, classOp.getArgument(i + 1),
790 classOp.getArgument(i));
792 connectPairWise(classIn);
793 connectPairWise(classOut);
795 classes.insert({name, {classIn, classOut}});
796 instanceGraph.addModule(classIn);
797 instanceGraph.addModule(classOut);
// Lower every DomainOp first, so that `classes` is populated. Then lower each
// FModuleLike module's ports and body, followed by its instances.
803LogicalResult LowerCircuit::lowerCircuit() {
804 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
// make_early_inc_range keeps the iteration valid if lowerDomain removes the
// current op.
805 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
806 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
807 if (failed(lowerDomain(domain)))
811 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
// Each instance-graph node's module is lowered if it is FModuleLike.
// NOTE(review): the traversal call, its order, and the handling of other
// nodes are on lines not shown.
813 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
816 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
817 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
818 if (failed(lowerModule.lowerModule()))
820 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
821 return lowerModule.lowerInstances();
// Run the circuit lowering and fail the pass if it fails. The lowering keeps
// the InstanceGraph up to date (it adds classes, instances and replacements),
// so that analysis is marked preserved.
831 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
833 if (failed(lowerCircuit.lowerCircuit()))
834 return signalPassFailure();
836 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describe the module's ports.