48#include "llvm/ADT/MapVector.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/Threading.h"
55#define GEN_PASS_DEF_LOWERDOMAINS
56#include "circt/Dialect/FIRRTL/Passes.h.inc"
61using namespace firrtl;
62using mlir::UnrealizedConversionCastOp;
70 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
74struct AssociationInfo {
76 DistinctAttr distinctAttr;
93 std::optional<unsigned> inputPort;
104 UnrealizedConversionCastOp temp;
108 SmallVector<AssociationInfo> associations{};
113 static DomainInfo input(
unsigned portIndex) {
114 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
120 static DomainInfo output(
unsigned portIndex) {
121 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
142 llvm::once_flag flag;
148 llvm::once_flag flag;
149 StringAttr domainInfoIn;
150 StringAttr domainInfoOut;
151 StringAttr associationsIn;
152 StringAttr associationsOut;
156 Constants(MLIRContext *context) : context(context) {}
159 ArrayAttr getEmptyArrayAttr() {
160 llvm::call_once(emptyArray.flag,
161 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
162 return emptyArray.attr;
167 void initClassOut() {
168 llvm::call_once(classOut.flag, [&] {
169 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
170 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
171 classOut.associationsIn = StringAttr::get(context,
"associations_in");
172 classOut.associationsOut = StringAttr::get(context,
"associations_out");
178 StringAttr getDomainInfoIn() {
180 return classOut.domainInfoIn;
184 StringAttr getDomainInfoOut() {
186 return classOut.domainInfoOut;
190 StringAttr getAssociationsIn() {
192 return classOut.associationsIn;
196 StringAttr getAssociationsOut() {
198 return classOut.associationsOut;
206 EmptyArray emptyArray;
216static UnrealizedConversionCastOp stubOut(Value value) {
217 if (!value.hasNUsesOrMore(1))
220 OpBuilder builder(value.getContext());
221 builder.setInsertionPointAfterValue(value);
222 auto temp = UnrealizedConversionCastOp::create(
223 builder, builder.getUnknownLoc(), {value.getType()}, {});
224 value.replaceAllUsesWith(temp.getResult(0));
238static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
244 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
249 auto oldValue = temp.getResult(0);
251 OpBuilder builder(temp);
252 auto splice = UnrealizedConversionCastOp::create(
253 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
254 oldValue.replaceAllUsesWith(splice.getResult(0));
268 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
270 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
271 constants(constants), instanceGraph(instanceGraph) {}
275 LogicalResult lowerModule();
279 LogicalResult lowerInstances();
286 BitVector eraseVector;
290 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
293 SmallVector<std::pair<unsigned, unsigned>> resultMap;
296 const DenseMap<Attribute, Classes> &domainToClasses;
299 Constants &constants;
309 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
312LogicalResult LowerModule::lowerModule() {
322 TypeSwitch<Operation *, std::optional<Block *>>(op)
323 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
324 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
326 .Default([](
auto) {
return std::nullopt; });
329 Block *body = *shouldProcess;
331 auto *
context = op.getContext();
335 SmallVector<Attribute> portAnnotations;
344 OpBuilder::InsertPoint insertPoint;
346 insertPoint = {body, body->begin()};
347 auto ports = op.getPorts();
348 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
349 auto port = cast<PortInfo>(ports[i]);
352 if (
auto domain = dyn_cast_or_null<FlatSymbolRefAttr>(port.domains)) {
356 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
358 indexToDomain[i] = port.direction == Direction::In
359 ? DomainInfo::input(iIns)
360 : DomainInfo::output(iIns);
365 ImplicitLocOpBuilder builder(port.loc,
context);
366 builder.restoreInsertionPoint(insertPoint);
369 auto object = ObjectOp::create(
371 StringAttr::get(
context, Twine(port.name) +
"_object"));
372 instanceGraph.lookup(op)->addInstance(
object,
373 instanceGraph.lookup(classOut));
374 indexToDomain[i].op = object;
375 indexToDomain[i].temp = stubOut(body->getArgument(i));
380 insertPoint = builder.saveInsertionPoint();
387 if (port.direction == Direction::In) {
388 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
390 portAnnotations.push_back(constants.getEmptyArrayAttr());
395 classOut.getInstanceType(), Direction::Out)});
399 portAnnotations.push_back(constants.getEmptyArrayAttr());
415 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
416 if (!domainAttr || domainAttr.empty()) {
417 portAnnotations.push_back(port.annotations.getArrayAttr());
421 SmallVector<Annotation> newAnnotations;
423 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
425 id = DistinctAttr::create(UnitAttr::get(
context));
426 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
430 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
432 if (!newAnnotations.empty())
433 port.annotations.addAnnotations(newAnnotations);
434 portAnnotations.push_back(port.annotations.getArrayAttr());
438 op.erasePorts(eraseVector);
439 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
443 op.insertPorts(newPorts);
446 for (
auto const &[_, info] : indexToDomain) {
447 auto [object, inputPort, outputPort, temp, associations] =
info;
448 OpBuilder builder(
object);
449 builder.setInsertionPointAfter(
object);
456 auto subDomainInfoIn =
457 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
459 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
460 body->getArgument(*inputPort));
461 splice(temp, body->getArgument(*inputPort));
463 splice(temp, subDomainInfoIn);
467 auto subAssociations =
468 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
469 SmallVector<Value> paths;
470 for (
auto [
id, loc] : associations) {
471 paths.push_back(PathOp::create(
472 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
475 auto list = ListCreateOp::create(
477 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
479 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
482 PropAssignOp::create(builder,
object.
getLoc(),
483 body->getArgument(outputPort),
object);
497 DenseSet<Operation *> conversionsToErase;
498 DenseSet<Operation *> operationsToErase;
499 auto walkResult = op.walk([&](Operation *walkOp) {
503 if (operationsToErase.contains(walkOp)) {
505 return WalkResult::advance();
509 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
510 for (
auto value : castOp.getDomains()) {
511 auto *conversion = value.getDefiningOp();
512 assert(isa<UnrealizedConversionCastOp>(conversion));
513 conversionsToErase.insert(conversion);
516 castOp.getResult().replaceAllUsesWith(castOp.getInput());
518 return WalkResult::advance();
522 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
523 conversionsToErase.insert(anonDomain);
524 return WalkResult::advance();
544 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
545 if (type_isa<DomainType>(wireOp.getResult().getType())) {
547 DomainDefineOp lastDefineOp;
548 for (
auto *user :
llvm::make_early_inc_range(wireOp->getUsers())) {
551 auto domainDefineOp = dyn_cast<DomainDefineOp>(user);
552 if (!domainDefineOp) {
553 auto diag = wireOp.emitOpError()
554 <<
"cannot be lowered by `LowerDomains` because it "
555 "has a user that is not a domain define op";
556 diag.attachNote(user->getLoc()) <<
"is one such user";
557 return WalkResult::interrupt();
559 if (!lastDefineOp || lastDefineOp->isBeforeInBlock(domainDefineOp))
560 lastDefineOp = domainDefineOp;
561 if (wireOp == domainDefineOp.getSrc().getDefiningOp())
562 dst = domainDefineOp.getDest();
564 src = domainDefineOp.getSrc();
565 operationsToErase.insert(domainDefineOp);
567 conversionsToErase.insert(wireOp);
570 return WalkResult::advance();
574 OpBuilder builder(lastDefineOp);
575 DomainDefineOp::create(builder, builder.getUnknownLoc(), dst, src);
577 return WalkResult::advance();
581 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
583 return WalkResult::advance();
591 auto *src = defineOp.getSrc().getDefiningOp();
592 auto dest = dyn_cast<UnrealizedConversionCastOp>(
593 defineOp.getDest().getDefiningOp());
595 return WalkResult::advance();
597 conversionsToErase.insert(src);
598 conversionsToErase.insert(dest);
600 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
601 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
602 OpBuilder builder(defineOp);
603 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
604 srcCast.getOperand(0));
605 }
else if (!isa<DomainCreateAnonOp>(src)) {
606 auto diag = defineOp.emitOpError()
607 <<
"has a source which cannot be lowered by 'LowerDomains'";
608 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
609 return WalkResult::interrupt();
613 return WalkResult::advance();
616 if (walkResult.wasInterrupted())
620 for (
auto *op : conversionsToErase)
625 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
630LogicalResult LowerModule::lowerInstances() {
632 if (eraseVector.none() && newPorts.empty())
639 if (!isa<FModuleOp, FExtModuleOp>(op))
642 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
643 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
644 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
646 use->getInstance().emitOpError()
647 <<
"has an unimplemented lowering in LowerDomains";
650 LLVM_DEBUG(llvm::dbgs()
651 <<
" - " << instanceOp.getInstanceName() <<
"\n");
653 for (
auto i : eraseVector.set_bits())
654 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
656 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
657 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
658 instanceGraph.replaceInstance(instanceOp, inserted);
660 for (
auto &[i, info] : indexToDomain) {
662 if (
info.inputPort) {
664 splicedValue = inserted.getResult(*
info.inputPort);
668 OpBuilder builder(inserted);
669 builder.setInsertionPointAfter(inserted);
670 splicedValue = ObjectSubfieldOp::create(
671 builder, inserted.getLoc(), inserted.getResult(
info.outputPort), 1);
674 splice(
info.temp, splicedValue);
690 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
691 llvm::Statistic &numDomains)
692 : circuit(circuit), instanceGraph(instanceGraph),
693 constants(circuit.getContext()), numDomains(numDomains) {}
696 LogicalResult lowerCircuit();
700 LogicalResult lowerDomain(DomainOp);
712 llvm::Statistic &numDomains;
716 DenseMap<Attribute, Classes> classes;
719LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
720 ImplicitLocOpBuilder builder(op.getLoc(), op);
721 auto *
context = op.getContext();
722 auto name = op.getNameAttr();
723 SmallVector<PortInfo> classInPorts;
724 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
725 classInPorts.
append({{builder.getStringAttr(
726 Twine(field.getName().getValue()) +
"_in"),
727 field.getType(), Direction::In},
728 {builder.getStringAttr(
729 Twine(field.getName().getValue()) +
"_out"),
730 field.getType(), Direction::Out}});
731 auto classIn = ClassOp::create(builder, name, classInPorts);
732 auto classInType = classIn.getInstanceType();
734 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
736 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
737 {{constants.getDomainInfoIn(),
740 {constants.getDomainInfoOut(),
743 {constants.getAssociationsIn(),
746 {constants.getAssociationsOut(),
750 auto connectPairWise = [&builder](ClassOp &classOp) {
751 builder.setInsertionPointToStart(classOp.getBodyBlock());
752 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
753 PropAssignOp::create(builder, classOp.getArgument(i + 1),
754 classOp.getArgument(i));
756 connectPairWise(classIn);
757 connectPairWise(classOut);
759 classes.insert({name, {classIn, classOut}});
760 instanceGraph.addModule(classIn);
761 instanceGraph.addModule(classOut);
767LogicalResult LowerCircuit::lowerCircuit() {
768 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
769 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
770 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
771 if (failed(lowerDomain(domain)))
775 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
777 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
780 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
781 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
782 if (failed(lowerModule.lowerModule()))
784 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
785 return lowerModule.lowerInstances();
795 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
797 if (failed(lowerCircuit.lowerCircuit()))
798 return signalPassFailure();
800 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describe the module's ports.