48#include "llvm/ADT/MapVector.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/Threading.h"
55#define GEN_PASS_DEF_LOWERDOMAINS
56#include "circt/Dialect/FIRRTL/Passes.h.inc"
61using namespace firrtl;
62using mlir::UnrealizedConversionCastOp;
70 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
// Per-association record for a domain. In lowerModule each associated port
// gets a fresh DistinctAttr id (also attached to that port as an annotation),
// and the record is stored as `{id, port.loc}`; the location member is elided
// in this listing.
74struct AssociationInfo {
76 DistinctAttr distinctAttr;
// Lowering state for one domain port (struct header elided in this listing).
// Fields, in aggregate order: the ObjectOp created for the port (`op`), the
// optional new input port index, the new output port index, the stubOut()
// placeholder for users of the old domain value, and the associations.
// Index of the new classIn-typed input port; std::nullopt for output
// domain ports.
93 std::optional<unsigned> inputPort;
// Zero-operand placeholder cast created by stubOut(); reconnected by splice().
104 UnrealizedConversionCastOp temp;
108 SmallVector<AssociationInfo> associations{};
// Input domain port: an input port at portIndex followed by an output port
// at portIndex + 1.
113 static DomainInfo input(
unsigned portIndex) {
114 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
// Output domain port: only an output port at portIndex.
120 static DomainInfo output(
unsigned portIndex) {
121 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
// Attributes cached for the whole pass (class header elided in this listing).
// Each group is created lazily through llvm::call_once on its own
// llvm::once_flag, so it is built at most once.
// Guards creation of the cached empty ArrayAttr.
142 llvm::once_flag flag;
// Guards creation of the port names of the "<domain>_out" class.
148 llvm::once_flag flag;
149 StringAttr domainInfoIn;
150 StringAttr domainInfoOut;
151 StringAttr associationsIn;
152 StringAttr associationsOut;
156 Constants(MLIRContext *context) : context(context) {}
// Return a cached empty ArrayAttr, creating it on first use.
159 ArrayAttr getEmptyArrayAttr() {
160 llvm::call_once(emptyArray.flag,
161 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
162 return emptyArray.attr;
// Create the four "<domain>_out" port-name StringAttrs exactly once.
167 void initClassOut() {
168 llvm::call_once(classOut.flag, [&] {
169 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
170 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
171 classOut.associationsIn = StringAttr::get(context,
"associations_in");
172 classOut.associationsOut = StringAttr::get(context,
"associations_out");
// Port-name accessors. NOTE(review): each getter's first body line is elided;
// it presumably calls initClassOut() before reading the cached value. Confirm.
178 StringAttr getDomainInfoIn() {
180 return classOut.domainInfoIn;
184 StringAttr getDomainInfoOut() {
186 return classOut.domainInfoOut;
190 StringAttr getAssociationsIn() {
192 return classOut.associationsIn;
196 StringAttr getAssociationsOut() {
198 return classOut.associationsOut;
206 EmptyArray emptyArray;
// Detach the users of `value` by doing the following:
//  - create a zero-operand UnrealizedConversionCastOp of the same type
//    immediately after `value`;
//  - redirect every use of `value` to that cast's result.
// The original value can then be replaced while its users wait for splice()
// to provide the new source. If `value` has no uses, the function exits early;
// that return is on an elided line.
// NOTE(review): the final return (presumably `temp`) is elided here.
216static UnrealizedConversionCastOp stubOut(Value value) {
217 if (!value.hasNUsesOrMore(1))
220 OpBuilder builder(value.getContext());
221 builder.setInsertionPointAfterValue(value);
222 auto temp = UnrealizedConversionCastOp::create(
223 builder, builder.getUnknownLoc(), {value.getType()}, {});
224 value.replaceAllUsesWith(temp.getResult(0));
// Reconnect the users of a stubOut() placeholder. At `temp`'s position, this
// creates a one-operand UnrealizedConversionCastOp of `temp`'s result type that
// wraps `value` (a parameter declared on an elided line), then moves every use
// of `temp` onto it. `temp` must have exactly one result and no operands
// (asserted).
// The wrapper keeps the users' original type even when `value` has a different
// type. The domain-define lowering in lowerModule reads the wrapped value back
// through getOperand(0). The rest of the body is elided (presumably erasing
// `temp` and returning the new cast).
238static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
244 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
249 auto oldValue = temp.getResult(0);
251 OpBuilder builder(temp);
252 auto splice = UnrealizedConversionCastOp::create(
253 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
254 oldValue.replaceAllUsesWith(splice.getResult(0));
// Lowers the domain ports of one module (lowerModule) and then updates every
// instance of it (lowerInstances). The constructor sizes eraseVector to one
// clear bit per original port.
268 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
270 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
271 constants(constants), instanceGraph(instanceGraph) {}
// Rewrite this module's ports and body.
275 LogicalResult lowerModule();
// Update all instances of this module to the rewritten port list.
279 LogicalResult lowerInstances();
// Original port indices to erase (used by erasePorts and
// cloneWithErasedPortsAndReplaceUses).
286 BitVector eraseVector;
// (insertion index, port) pairs used by insertPorts and
// cloneWithInsertedPortsAndReplaceUses.
290 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
// NOTE(review): not referenced in the visible portion of this listing.
293 SmallVector<std::pair<unsigned, unsigned>> resultMap;
// Domain name -> classes created by LowerCircuit::lowerDomain (owned by
// LowerCircuit).
296 const DenseMap<Attribute, Classes> &domainToClasses;
299 Constants &constants;
// Original port index of each domain port -> its lowering state. Because this
// is a MapVector, iteration follows insertion order, which is port order.
309 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
// Lower the domain ports of this module:
//  - Each domain port is replaced by class-typed port(s).
//  - In the body, each domain port gets an ObjectOp of the domain's "_out"
//    class.
//  - Domain associations on other ports become annotations carrying a fresh
//    DistinctAttr id.
//  - Domain ops in the body are then removed or rewritten as property
//    assignments.
// Instances of the module are updated separately by lowerInstances().
312LogicalResult LowerModule::lowerModule() {
// Decide what to process:
//  - FModuleOp -> its body block;
//  - FExtModuleOp -> nullptr (ports only, no body);
//  - any other module-like -> std::nullopt.
// The `shouldProcess` declaration and its nullopt check are elided.
322 TypeSwitch<Operation *, std::optional<Block *>>(op)
323 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
324 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
326 .Default([](
auto) {
return std::nullopt; });
// NOTE(review): `body` is nullptr for an FExtModuleOp. The uses of `body`
// below are presumably guarded on elided lines; confirm.
329 Block *body = *shouldProcess;
331 auto *
context = op.getContext();
// Annotations for the final port list, collected in new-port order and
// installed at the end of this function.
335 SmallVector<Attribute> portAnnotations;
// Where the next ObjectOp is created: the start of the body, advanced past
// each new object so that objects appear in port order.
344 OpBuilder::InsertPoint insertPoint;
346 insertPoint = {body, body->begin()};
347 auto ports = op.getPorts();
// `i` indexes the original ports. `iDel` is used as the newPorts insertion
// index and `iIns` as the DomainInfo port index; their updates are on elided
// lines.
348 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
349 auto port = cast<PortInfo>(ports[i]);
// A port whose `domains` is a FlatSymbolRefAttr is itself a domain port.
352 if (
auto domain = dyn_cast_or_null<FlatSymbolRefAttr>(port.domains)) {
// DenseMap::at requires that the domain was already lowered by
// LowerCircuit::lowerDomain.
356 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
358 indexToDomain[i] = port.direction == Direction::In
359 ? DomainInfo::input(iIns)
360 : DomainInfo::output(iIns);
// Create "<port>_object" and record it in the instance graph as an
// instance of the "_out" class.
365 ImplicitLocOpBuilder builder(port.loc,
context);
366 builder.restoreInsertionPoint(insertPoint);
369 auto object = ObjectOp::create(
371 StringAttr::get(
context, Twine(port.name) +
"_object"));
372 instanceGraph.lookup(op)->addInstance(
object,
373 instanceGraph.lookup(classOut));
374 indexToDomain[i].op = object;
// Detach the users of the old domain argument so the port can be erased;
// splice() reconnects them once the new ports exist.
375 indexToDomain[i].temp = stubOut(body->getArgument(i));
380 insertPoint = builder.saveInsertionPoint();
// An input domain port becomes a classIn-typed input port followed by a
// classOut-typed output port. Otherwise (elided branch) it becomes only the
// classOut-typed output port. New ports get empty annotations.
387 if (port.direction == Direction::In) {
388 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
390 portAnnotations.push_back(constants.getEmptyArrayAttr());
395 classOut.getInstanceType(), Direction::Out)});
399 portAnnotations.push_back(constants.getEmptyArrayAttr());
// Non-domain port: `domains`, if present, lists the original indices of the
// domain ports this port is associated with.
415 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
416 if (!domainAttr || domainAttr.empty()) {
417 portAnnotations.push_back(port.annotations.getArrayAttr());
// NOTE(review): part of this condition is elided. It presumably leaves
// certain ports (e.g. zero-width ones) with their annotations unchanged.
// Confirm.
424 if (
auto firrtlType = type_dyn_cast<FIRRTLType>(port.type))
426 portAnnotations.push_back(port.annotations.getArrayAttr());
// For each associated domain:
//  - mint a DistinctAttr id;
//  - attach it to this port as an annotation;
//  - record (id, loc) on that domain so a path to this port can be built.
430 SmallVector<Annotation> newAnnotations;
432 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
434 id = DistinctAttr::create(UnitAttr::get(
context));
435 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
439 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
441 if (!newAnnotations.empty())
442 port.annotations.addAnnotations(newAnnotations);
443 portAnnotations.push_back(port.annotations.getArrayAttr());
// Rewrite the port list in three steps: drop the marked ports, clear the
// domain info, then insert the class-typed replacement ports.
447 op.erasePorts(eraseVector);
448 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
452 op.insertPorts(newPorts);
// Connect each domain's ObjectOp. Subfield indices follow the "_out" class
// port order: 0 = domainInfo_in, 2 = associations_in.
455 for (
auto const &[_, info] : indexToDomain) {
456 auto [object, inputPort, outputPort, temp, associations] =
info;
457 OpBuilder builder(
object);
458 builder.setInsertionPointAfter(
object);
465 auto subDomainInfoIn =
466 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
// With an input port: drive domainInfo_in from the new input argument and
// give the old users that argument. Otherwise (elided `else`), the old users
// get the domainInfo_in subfield.
468 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
469 body->getArgument(*inputPort));
470 splice(temp, body->getArgument(*inputPort));
472 splice(temp, subDomainInfoIn);
// Build one reference path per associated port, collect them into a list,
// and assign the list to associations_in.
476 auto subAssociations =
477 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
478 SmallVector<Value> paths;
479 for (
auto [
id, loc] : associations) {
480 paths.push_back(PathOp::create(
481 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
484 auto list = ListCreateOp::create(
486 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
488 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
// Drive the new output port with the object itself.
491 PropAssignOp::create(builder,
object.
getLoc(),
492 body->getArgument(outputPort),
object);
// Remove or rewrite domain ops in the body.
//  - conversionsToErase collects ops to delete after the walk.
//  - operationsToErase holds domain defines already handled through a
//    domain wire; the walk skips them.
506 DenseSet<Operation *> conversionsToErase;
507 DenseSet<Operation *> operationsToErase;
508 auto walkResult = op.walk([&](Operation *walkOp) {
512 if (operationsToErase.contains(walkOp)) {
514 return WalkResult::advance();
// Unsafe domain cast: forward its input to its users and schedule the
// casts feeding its domain operands for erasure.
518 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
519 for (
auto value : castOp.getDomains()) {
520 auto *conversion = value.getDefiningOp();
521 assert(isa<UnrealizedConversionCastOp>(conversion));
522 conversionsToErase.insert(conversion);
525 castOp.getResult().replaceAllUsesWith(castOp.getInput());
527 return WalkResult::advance();
// Anonymous domains are scheduled for erasure.
531 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
532 conversionsToErase.insert(anonDomain);
533 return WalkResult::advance();
// Domain-typed wire: fold it away. Every user must be a domain define.
//  - A define reading from the wire contributes a destination.
//  - Otherwise (elided `else`) the define's source becomes `src`.
// The defines are then re-created as direct `dst <- src` defines just
// before the last original define.
553 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
554 if (type_isa<DomainType>(wireOp.getResult().getType())) {
556 SmallVector<Value> dsts;
557 DomainDefineOp lastDefineOp;
558 for (
auto *user :
llvm::make_early_inc_range(wireOp->getUsers())) {
559 auto domainDefineOp = dyn_cast<DomainDefineOp>(user);
560 if (operationsToErase.contains(domainDefineOp))
562 if (!domainDefineOp) {
563 auto diag = wireOp.emitOpError()
564 <<
"cannot be lowered by `LowerDomains` because it "
565 "has a user that is not a domain define op";
566 diag.attachNote(user->getLoc()) <<
"is one such user";
567 return WalkResult::interrupt();
// NOTE(review): Operation::isBeforeInBlock requires both ops to be in the
// same block. A define nested in another region would violate this; confirm
// that cannot happen.
569 if (!lastDefineOp || lastDefineOp->isBeforeInBlock(domainDefineOp))
570 lastDefineOp = domainDefineOp;
571 if (wireOp == domainDefineOp.getSrc().getDefiningOp())
572 dsts.push_back(domainDefineOp.getDest());
574 src = domainDefineOp.getSrc();
575 operationsToErase.insert(domainDefineOp);
577 conversionsToErase.insert(wireOp);
580 if (!src || dsts.empty())
581 return WalkResult::advance();
585 OpBuilder builder(lastDefineOp);
586 for (
auto dst :
llvm::reverse(dsts))
587 DomainDefineOp::create(builder, builder.getUnknownLoc(), dst, src);
589 return WalkResult::advance();
// Any other domain define. The destination is expected to be a splice()
// cast, and the source either a cast or an anonymous domain. A cast source
// becomes a PropAssignOp between the wrapped values. Both sides are
// scheduled for erasure.
593 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
595 return WalkResult::advance();
// NOTE(review): getDefiningOp() returns null for a block argument, and no
// null check on `src` is visible before the insert/isa below. Confirm that
// the elided lines guarantee `src` is non-null.
603 auto *src = defineOp.getSrc().getDefiningOp();
604 auto dest = dyn_cast<UnrealizedConversionCastOp>(
605 defineOp.getDest().getDefiningOp());
607 return WalkResult::advance();
609 conversionsToErase.insert(src);
610 conversionsToErase.insert(dest);
612 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
613 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
614 OpBuilder builder(defineOp);
615 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
616 srcCast.getOperand(0));
617 }
else if (!isa<DomainCreateAnonOp>(src)) {
618 auto diag = defineOp.emitOpError()
619 <<
"has a source which cannot be lowered by 'LowerDomains'";
620 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
621 return WalkResult::interrupt();
625 return WalkResult::advance();
628 if (walkResult.wasInterrupted())
// Dispose of the collected ops (the loop body is elided).
632 for (
auto *op : conversionsToErase)
// Install the annotations collected for the new port list.
637 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
// Propagate the port rewrite done by lowerModule() to every instance of this
// module, then reconnect the users of the old domain results.
642LogicalResult LowerModule::lowerInstances() {
// No ports were erased or added, so there is nothing to update (the early
// return is elided).
644 if (eraseVector.none() && newPorts.empty())
// Only FModuleOp and FExtModuleOp are handled (the early return is elided).
651 if (!isa<FModuleOp, FExtModuleOp>(op))
654 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
// Use an early-inc range because the current use may presumably be removed
// while iterating (replaceInstance below).
655 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
// Only InstanceOp users are supported. Other users are diagnosed; the null
// check and failure return are elided.
656 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
658 use->getInstance().emitOpError()
659 <<
"has an unimplemented lowering in LowerDomains";
662 LLVM_DEBUG(llvm::dbgs()
663 <<
" - " << instanceOp.getInstanceName() <<
"\n");
// Detach the users of every erased result before the instance is rebuilt.
// NOTE(review): operator[] default-inserts, so this assumes every set bit in
// eraseVector is a domain port already in indexToDomain. Confirm.
665 for (
auto i : eraseVector.set_bits())
666 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
// Rebuild the instance with the new port list and update the instance graph.
668 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
669 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
670 instanceGraph.replaceInstance(instanceOp, inserted);
// Reconnect the stubbed users:
//  - input domains -> the new input-port result;
//  - output domains (elided `else`) -> subfield 1 (domainInfo_out) of the
//    object-typed output-port result.
672 for (
auto &[i, info] : indexToDomain) {
674 if (
info.inputPort) {
676 splicedValue = inserted.getResult(*
info.inputPort);
680 OpBuilder builder(inserted);
681 builder.setInsertionPointAfter(inserted);
682 splicedValue = ObjectSubfieldOp::create(
683 builder, inserted.getLoc(), inserted.getResult(
info.outputPort), 1);
686 splice(
info.temp, splicedValue);
// Drives the lowering of a whole circuit: domains first, then modules and
// their instances.
702 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
703 llvm::Statistic &numDomains)
704 : circuit(circuit), instanceGraph(instanceGraph),
705 constants(circuit.getContext()), numDomains(numDomains) {}
// Lower all domains, then all modules and their instances.
708 LogicalResult lowerCircuit();
// Create the two classes that represent one DomainOp.
712 LogicalResult lowerDomain(DomainOp);
// Pass statistic. NOTE(review): no update is visible in this listing.
724 llvm::Statistic &numDomains;
// Domain name -> the classes created for it by lowerDomain().
728 DenseMap<Attribute, Classes> classes;
// Lower one DomainOp into two classes and register both in `classes` and in
// the instance graph:
//  - "<name>": for each domain field, a "<field>_in" input port followed by a
//    "<field>_out" output port of the field's type;
//  - "<name>_out": domainInfo_in, domainInfo_out, associations_in and
//    associations_out ports (their types are set on elided lines).
// In both classes, each input port is assigned to the output port after it.
731LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
732 ImplicitLocOpBuilder builder(op.getLoc(), op);
733 auto *
context = op.getContext();
734 auto name = op.getNameAttr();
735 SmallVector<PortInfo> classInPorts;
736 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
737 classInPorts.
append({{builder.getStringAttr(
738 Twine(field.getName().getValue()) +
"_in"),
739 field.getType(), Direction::In},
740 {builder.getStringAttr(
741 Twine(field.getName().getValue()) +
"_out"),
742 field.getType(), Direction::Out}});
743 auto classIn = ClassOp::create(builder, name, classInPorts);
744 auto classInType = classIn.getInstanceType();
// A list-of-paths property type (its variable declaration is elided).
746 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
// Create the "<name>_out" class.
748 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
749 {{constants.getDomainInfoIn(),
752 {constants.getDomainInfoOut(),
755 {constants.getAssociationsIn(),
758 {constants.getAssociationsOut(),
// Assign port i to port i + 1 for each even i. This relies on the ports
// alternating input/output, as both classes above are built.
762 auto connectPairWise = [&builder](ClassOp &classOp) {
763 builder.setInsertionPointToStart(classOp.getBodyBlock());
764 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
765 PropAssignOp::create(builder, classOp.getArgument(i + 1),
766 classOp.getArgument(i));
768 connectPairWise(classIn);
769 connectPairWise(classOut);
// Key by the domain's name attribute; lowerModule looks classes up with the
// port's FlatSymbolRefAttr::getAttr().
771 classes.insert({name, {classIn, classOut}});
772 instanceGraph.addModule(classIn);
773 instanceGraph.addModule(classOut);
// Lower every DomainOp before any module is processed, because
// LowerModule::lowerModule looks up the created classes with DenseMap::at.
// Then lower each module and its instances.
779LogicalResult LowerCircuit::lowerCircuit() {
780 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
// make_early_inc_range keeps iteration valid if the current DomainOp is
// erased (any erasure is on elided lines).
781 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
782 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
783 if (failed(lowerDomain(domain)))
787 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
// Per instance-graph node (the enclosing iteration and the null check on
// `moduleOp` are elided): lower the module, then its instances.
789 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
792 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
793 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
794 if (failed(lowerModule.lowerModule()))
796 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
797 return lowerModule.lowerInstances();
// Pass entry point (signature elided): run the circuit lowering against the
// InstanceGraph analysis. LowerCircuit/LowerModule keep the graph up to date
// (addModule, addInstance, replaceInstance), so it is marked preserved.
807 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
809 if (failed(lowerCircuit.lowerCircuit()))
810 return signalPassFailure();
812 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.