50#include "llvm/ADT/MapVector.h"
51#include "llvm/ADT/TypeSwitch.h"
52#include "llvm/Support/Debug.h"
53#include "llvm/Support/Threading.h"
57#define GEN_PASS_DEF_LOWERDOMAINS
58#include "circt/Dialect/FIRRTL/Passes.h.inc"
63using namespace firrtl;
64using mlir::UnrealizedConversionCastOp;
72 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
76struct AssociationInfo {
78 DistinctAttr distinctAttr;
95 std::optional<unsigned> inputPort;
106 UnrealizedConversionCastOp temp;
110 SmallVector<AssociationInfo> associations{};
115 static DomainInfo input(
unsigned portIndex) {
116 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
122 static DomainInfo output(
unsigned portIndex) {
123 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
144 llvm::once_flag flag;
150 llvm::once_flag flag;
151 StringAttr domainInfoIn;
152 StringAttr domainInfoOut;
153 StringAttr associationsIn;
154 StringAttr associationsOut;
158 Constants(MLIRContext *context) : context(context) {}
161 ArrayAttr getEmptyArrayAttr() {
162 llvm::call_once(emptyArray.flag,
163 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
164 return emptyArray.attr;
169 void initClassOut() {
170 llvm::call_once(classOut.flag, [&] {
171 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
172 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
173 classOut.associationsIn = StringAttr::get(context,
"associations_in");
174 classOut.associationsOut = StringAttr::get(context,
"associations_out");
180 StringAttr getDomainInfoIn() {
182 return classOut.domainInfoIn;
186 StringAttr getDomainInfoOut() {
188 return classOut.domainInfoOut;
192 StringAttr getAssociationsIn() {
194 return classOut.associationsIn;
198 StringAttr getAssociationsOut() {
200 return classOut.associationsOut;
208 EmptyArray emptyArray;
218static UnrealizedConversionCastOp stubOut(Value value) {
219 if (!value.hasNUsesOrMore(1))
222 OpBuilder builder(value.getContext());
223 builder.setInsertionPointAfterValue(value);
224 auto temp = UnrealizedConversionCastOp::create(
225 builder, builder.getUnknownLoc(), {value.getType()}, {});
226 value.replaceAllUsesWith(temp.getResult(0));
240static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
246 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
251 auto oldValue = temp.getResult(0);
253 OpBuilder builder(temp);
254 auto splice = UnrealizedConversionCastOp::create(
255 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
256 oldValue.replaceAllUsesWith(splice.getResult(0));
270 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
272 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
273 constants(constants), instanceGraph(instanceGraph) {}
277 LogicalResult lowerModule();
281 LogicalResult lowerInstances();
288 BitVector eraseVector;
292 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
295 SmallVector<std::pair<unsigned, unsigned>> resultMap;
298 const DenseMap<Attribute, Classes> &domainToClasses;
301 Constants &constants;
314LogicalResult LowerModule::lowerModule() {
324 TypeSwitch<Operation *, std::optional<Block *>>(op)
325 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
326 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
328 .Default([](
auto) {
return std::nullopt; });
331 Block *body = *shouldProcess;
333 auto *
context = op.getContext();
337 SmallVector<Attribute> portAnnotations;
346 OpBuilder::InsertPoint insertPoint;
348 insertPoint = {body, body->begin()};
349 auto ports = op.getPorts();
350 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
351 auto port = cast<PortInfo>(ports[i]);
355 if (
auto domainType = dyn_cast<DomainType>(port.type)) {
359 auto domain = domainType.getName();
360 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
362 indexToDomain[i] = port.direction == Direction::In
363 ? DomainInfo::input(iIns)
364 : DomainInfo::output(iIns);
369 ImplicitLocOpBuilder builder(port.loc,
context);
370 builder.restoreInsertionPoint(insertPoint);
373 auto object = ObjectOp::create(
375 StringAttr::get(
context, Twine(port.name) +
"_object"));
376 instanceGraph.lookup(op)->addInstance(
object,
377 instanceGraph.lookup(classOut));
378 indexToDomain[i].op = object;
379 indexToDomain[i].temp = stubOut(body->getArgument(i));
384 insertPoint = builder.saveInsertionPoint();
391 if (port.direction == Direction::In) {
392 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
394 portAnnotations.push_back(constants.getEmptyArrayAttr());
399 classOut.getInstanceType(), Direction::Out)});
403 portAnnotations.push_back(constants.getEmptyArrayAttr());
419 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
420 if (!domainAttr || domainAttr.empty()) {
421 portAnnotations.push_back(port.annotations.getArrayAttr());
428 if (
auto firrtlType = type_dyn_cast<FIRRTLType>(port.type))
430 portAnnotations.push_back(port.annotations.getArrayAttr());
434 SmallVector<Annotation> newAnnotations;
436 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
438 id = DistinctAttr::create(UnitAttr::get(
context));
439 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
443 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
445 if (!newAnnotations.empty())
446 port.annotations.addAnnotations(newAnnotations);
447 portAnnotations.push_back(port.annotations.getArrayAttr());
451 op.erasePorts(eraseVector);
452 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
456 op.insertPorts(newPorts);
459 for (
auto const &[_, info] : indexToDomain) {
460 auto [object, inputPort, outputPort, temp, associations] =
info;
461 OpBuilder builder(
object);
462 builder.setInsertionPointAfter(
object);
469 auto subDomainInfoIn =
470 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
472 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
473 body->getArgument(*inputPort));
474 splice(temp, body->getArgument(*inputPort));
476 splice(temp, subDomainInfoIn);
480 auto subAssociations =
481 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
482 SmallVector<Value> paths;
483 for (
auto [
id, loc] : associations) {
484 paths.push_back(PathOp::create(
485 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
488 auto list = ListCreateOp::create(
490 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
492 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
495 PropAssignOp::create(builder,
object.
getLoc(),
496 body->getArgument(outputPort),
object);
510 DenseSet<Operation *> conversionsToErase;
511 DenseSet<Operation *> operationsToErase;
512 auto walkResult = op.walk([&](Operation *walkOp) {
516 if (operationsToErase.contains(walkOp)) {
518 return WalkResult::advance();
522 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
523 for (
auto value : castOp.getDomains()) {
524 auto *conversion = value.getDefiningOp();
525 assert(isa<UnrealizedConversionCastOp>(conversion));
526 conversionsToErase.insert(conversion);
529 castOp.getResult().replaceAllUsesWith(castOp.getInput());
531 return WalkResult::advance();
537 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
539 auto noUser = llvm::all_of(anonDomain->getUsers(), [&](
auto *user) {
540 return operationsToErase.contains(user) ||
541 conversionsToErase.contains(user);
544 conversionsToErase.insert(anonDomain);
545 return WalkResult::advance();
550 OpBuilder builder(anonDomain);
552 domainToClasses.at(anonDomain.getDomainAttr().getAttr()).input;
553 anonDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
554 builder, anonDomain.getLoc(), {anonDomain.getType()},
555 {UnknownValueOp::create(builder, anonDomain.getLoc(),
556 classIn.getInstanceType())
559 return WalkResult::advance();
563 if (
auto createDomain = dyn_cast<DomainCreateOp>(walkOp)) {
564 auto noUser = llvm::all_of(createDomain->getUsers(), [&](
auto *user) {
565 return operationsToErase.contains(user) ||
566 conversionsToErase.contains(user);
569 conversionsToErase.insert(createDomain);
570 return WalkResult::advance();
573 OpBuilder builder(createDomain);
575 domainToClasses.at(createDomain.getDomainAttr().getAttr()).input;
576 auto object = ObjectOp::create(builder, createDomain.getLoc(), classIn,
577 createDomain.getNameAttr());
578 instanceGraph.lookup(op)->addInstance(
object,
579 instanceGraph.lookup(classIn));
582 auto fieldValues = createDomain.getFieldValues();
586 for (
auto [fieldIdx, fieldValue] :
llvm::enumerate(fieldValues)) {
587 auto inputPortIdx = fieldIdx * 2;
588 auto subfield = ObjectSubfieldOp::create(
589 builder, createDomain.getLoc(),
object, inputPortIdx);
590 PropAssignOp::create(builder, createDomain.getLoc(), subfield,
594 createDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
595 builder, createDomain.getLoc(), {createDomain.getType()},
596 {object.getResult()}));
597 createDomain.erase();
598 return WalkResult::advance();
602 if (
auto subfieldOp = dyn_cast<DomainSubfieldOp>(walkOp)) {
605 auto *inputOp = subfieldOp.getInput().getDefiningOp();
607 subfieldOp.emitOpError(
608 "has an input that is not defined by an operation");
609 return WalkResult::interrupt();
612 auto conversionCast = dyn_cast<UnrealizedConversionCastOp>(inputOp);
613 if (!conversionCast || conversionCast.getNumOperands() != 1) {
614 subfieldOp.emitOpError(
615 "has an input that is not a conversion cast with one operand");
616 return WalkResult::interrupt();
621 auto fieldIndex = subfieldOp.getFieldIndex();
622 auto outputPortIndex = fieldIndex * 2 + 1;
625 OpBuilder builder(subfieldOp);
626 auto objectSubfield = ObjectSubfieldOp::create(
627 builder, subfieldOp.getLoc(), conversionCast.getOperand(0),
631 conversionsToErase.insert(conversionCast);
633 subfieldOp.replaceAllUsesWith(objectSubfield.getResult());
635 return WalkResult::advance();
657 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
658 if (type_isa<DomainType>(wireOp.getResult().getType())) {
660 SmallVector<Value> dsts;
661 DomainDefineOp lastDefineOp;
662 for (
auto *user :
llvm::make_early_inc_range(wireOp->getUsers())) {
663 auto domainDefineOp = dyn_cast<DomainDefineOp>(user);
664 if (operationsToErase.contains(domainDefineOp))
666 if (!domainDefineOp) {
667 auto diag = wireOp.emitOpError()
668 <<
"cannot be lowered by `LowerDomains` because it "
669 "has a user that is not a domain define op";
670 diag.attachNote(user->getLoc()) <<
"is one such user";
671 return WalkResult::interrupt();
673 if (!lastDefineOp || lastDefineOp->isBeforeInBlock(domainDefineOp))
674 lastDefineOp = domainDefineOp;
675 if (wireOp == domainDefineOp.getSrc().getDefiningOp())
676 dsts.push_back(domainDefineOp.getDest());
678 src = domainDefineOp.getSrc();
679 operationsToErase.insert(domainDefineOp);
681 conversionsToErase.insert(wireOp);
684 if (!src || dsts.empty())
685 return WalkResult::advance();
689 OpBuilder builder(lastDefineOp);
690 for (
auto dst :
llvm::reverse(dsts))
691 DomainDefineOp::create(builder, builder.getUnknownLoc(), dst, src);
697 if (!wireOp.getDomains().empty()) {
698 for (
auto domain : wireOp.getDomains())
699 if (auto *defOp = domain.getDefiningOp())
700 conversionsToErase.insert(defOp);
701 wireOp->eraseOperands(0, wireOp.getNumOperands());
704 return WalkResult::advance();
708 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
710 return WalkResult::advance();
718 auto *src = defineOp.getSrc().getDefiningOp();
719 auto dest = dyn_cast<UnrealizedConversionCastOp>(
720 defineOp.getDest().getDefiningOp());
722 return WalkResult::advance();
724 conversionsToErase.insert(src);
725 conversionsToErase.insert(dest);
727 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
728 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
729 OpBuilder builder(defineOp);
730 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
731 srcCast.getOperand(0));
732 }
else if (!isa<DomainCreateAnonOp, DomainCreateOp>(src)) {
733 auto diag = defineOp.emitOpError()
734 <<
"has a source which cannot be lowered by 'LowerDomains'";
735 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
736 return WalkResult::interrupt();
740 return WalkResult::advance();
743 if (walkResult.wasInterrupted())
747 for (
auto *op : conversionsToErase)
752 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
757LogicalResult LowerModule::lowerInstances() {
759 if (eraseVector.none() && newPorts.empty())
766 if (!isa<FModuleOp, FExtModuleOp>(op))
769 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
770 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
771 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
773 use->getInstance().emitOpError()
774 <<
"has an unimplemented lowering in LowerDomains";
777 LLVM_DEBUG(llvm::dbgs()
778 <<
" - " << instanceOp.getInstanceName() <<
"\n");
780 for (
auto i : eraseVector.set_bits())
781 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
783 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
784 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
785 instanceGraph.replaceInstance(instanceOp, inserted);
787 for (
auto &[i, info] : indexToDomain) {
789 if (
info.inputPort) {
791 splicedValue = inserted->getResult(*
info.inputPort);
795 OpBuilder builder(inserted);
796 builder.setInsertionPointAfter(inserted);
798 ObjectSubfieldOp::create(builder, inserted.getLoc(),
799 inserted->getResult(
info.outputPort), 1);
802 splice(
info.temp, splicedValue);
818 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
819 llvm::Statistic &numDomains)
820 : circuit(circuit), instanceGraph(instanceGraph),
821 constants(circuit.getContext()), numDomains(numDomains) {}
824 LogicalResult lowerCircuit();
828 LogicalResult lowerDomain(DomainOp);
840 llvm::Statistic &numDomains;
844 DenseMap<Attribute, Classes> classes;
847LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
848 ImplicitLocOpBuilder builder(op.getLoc(), op);
849 auto *
context = op.getContext();
850 auto name = op.getNameAttr();
851 SmallVector<PortInfo> classInPorts;
852 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
853 classInPorts.
append({{builder.getStringAttr(
854 Twine(field.getName().getValue()) +
"_in"),
855 field.getType(), Direction::In},
856 {builder.getStringAttr(
857 Twine(field.getName().getValue()) +
"_out"),
858 field.getType(), Direction::Out}});
859 auto classIn = ClassOp::create(builder, name, classInPorts);
860 auto classInType = classIn.getInstanceType();
862 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
864 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
865 {{constants.getDomainInfoIn(),
868 {constants.getDomainInfoOut(),
871 {constants.getAssociationsIn(),
874 {constants.getAssociationsOut(),
878 auto connectPairWise = [&builder](ClassOp &classOp) {
879 builder.setInsertionPointToStart(classOp.getBodyBlock());
880 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
881 PropAssignOp::create(builder, classOp.getArgument(i + 1),
882 classOp.getArgument(i));
884 connectPairWise(classIn);
885 connectPairWise(classOut);
887 classes.insert({name, {classIn, classOut}});
888 instanceGraph.addModule(classIn);
889 instanceGraph.addModule(classOut);
895LogicalResult LowerCircuit::lowerCircuit() {
896 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
897 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
898 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
899 if (failed(lowerDomain(domain)))
903 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
905 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
908 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
909 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
910 if (failed(lowerModule.lowerModule()))
912 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
913 return lowerModule.lowerInstances();
923 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
925 if (failed(lowerCircuit.lowerCircuit()))
926 return signalPassFailure();
928 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.