50#include "llvm/ADT/MapVector.h"
51#include "llvm/ADT/TypeSwitch.h"
52#include "llvm/Support/Debug.h"
53#include "llvm/Support/Threading.h"
57#define GEN_PASS_DEF_LOWERDOMAINS
58#include "circt/Dialect/FIRRTL/Passes.h.inc"
63using namespace firrtl;
64using mlir::UnrealizedConversionCastOp;
72 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
/// One port associated with a domain. lowerModule appends entries as
/// `{id, port.loc}`, i.e. a fresh DistinctAttr identifier plus the location of
/// the associated port.
76struct AssociationInfo {
78 DistinctAttr distinctAttr;
// --- Fields of DomainInfo (its header is elided from this view). The field
// order, as destructured in lowerModule, is:
//   object, inputPort, outputPort, temp, associations.
// Index of the lowered input-class port; std::nullopt for output domains.
95 std::optional<unsigned> inputPort;
// Placeholder cast created by stubOut() that stands in for the original domain
// value until splice() connects it to the lowered value.
106 UnrealizedConversionCastOp temp;
// Ports whose `domains` list referenced this domain port.
110 SmallVector<AssociationInfo> associations{};
/// Info for an input domain port whose lowered ports start at `portIndex`:
/// the input-class port is at portIndex and the output-class port at
/// portIndex + 1.
115 static DomainInfo input(
unsigned portIndex) {
116 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
/// Info for an output domain port: there is no input-class port, and the
/// output-class port is at `portIndex`.
122 static DomainInfo output(
unsigned portIndex) {
123 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
// --- Storage for the lazily created empty ArrayAttr (EmptyArray; the struct
// header is elided from this view).
144 llvm::once_flag flag;
// --- Storage for the port names of a domain's output class (the struct header
// is elided). lowerDomain uses these as the names of that class's ports.
150 llvm::once_flag flag;
151 StringAttr domainInfoIn;
152 StringAttr domainInfoOut;
153 StringAttr associationsIn;
154 StringAttr associationsOut;
/// Cache of attributes shared by the lowering. Each attribute is created at
/// most once via llvm::call_once.
158 Constants(MLIRContext *context) : context(context) {}
/// Return the empty ArrayAttr, creating it on first use.
161 ArrayAttr getEmptyArrayAttr() {
162 llvm::call_once(emptyArray.flag,
163 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
164 return emptyArray.attr;
/// Create the output-class port names ("domainInfo_in", "domainInfo_out",
/// "associations_in", "associations_out") exactly once.
169 void initClassOut() {
170 llvm::call_once(classOut.flag, [&] {
171 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
172 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
173 classOut.associationsIn = StringAttr::get(context,
"associations_in");
174 classOut.associationsOut = StringAttr::get(context,
"associations_out");
// The getters below return the cached output-class port names.
// NOTE(review): presumably each getter calls initClassOut() on an elided line
// before returning; confirm.
180 StringAttr getDomainInfoIn() {
182 return classOut.domainInfoIn;
186 StringAttr getDomainInfoOut() {
188 return classOut.domainInfoOut;
192 StringAttr getAssociationsIn() {
194 return classOut.associationsIn;
198 StringAttr getAssociationsOut() {
200 return classOut.associationsOut;
208 EmptyArray emptyArray;
/// Redirect every use of `value` to a fresh placeholder: a zero-operand
/// UnrealizedConversionCastOp with the same result type, inserted right after
/// `value`. The placeholder is later replaced via splice().
/// Values with no uses are special-cased first. The returned value for that
/// case is on elided lines (presumably a null op — confirm before relying on
/// it).
218static UnrealizedConversionCastOp stubOut(Value value) {
219 if (!value.hasNUsesOrMore(1))
222 OpBuilder builder(value.getContext());
223 builder.setInsertionPointAfterValue(value);
224 auto temp = UnrealizedConversionCastOp::create(
225 builder, builder.getUnknownLoc(), {value.getType()}, {});
226 value.replaceAllUsesWith(temp.getResult(0));
/// Connect a stubOut() placeholder to its real producer. At the placeholder's
/// position, create a cast whose single operand is `value`, and reroute all
/// uses of the placeholder's result to the new cast's result.
/// Requires `temp` to be a single-result, zero-operand cast.
/// NOTE(review): stubOut() may create no placeholder for unused values.
/// Confirm that the elided lines before this assert handle a null `temp`.
240static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
246 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
251 auto oldValue = temp.getResult(0);
253 OpBuilder builder(temp);
254 auto splice = UnrealizedConversionCastOp::create(
255 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
256 oldValue.replaceAllUsesWith(splice.getResult(0));
/// Lowers the domain ports and domain operations of one module-like op, then
/// updates its instances to the new port list.
/// `classes` maps a domain name to its lowered (input, output) classes.
270 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
272 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
273 constants(constants), instanceGraph(instanceGraph) {}
/// Rewrite the module's ports and body.
277 LogicalResult lowerModule();
/// Rewrite every instance of the module to match the new ports.
281 LogicalResult lowerInstances();
// One bit per original port; set bits are passed to erasePorts().
288 BitVector eraseVector;
// (insertion index, port) pairs passed to insertPorts().
292 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
// Not referenced in the visible code.
295 SmallVector<std::pair<unsigned, unsigned>> resultMap;
298 const DenseMap<Attribute, Classes> &domainToClasses;
301 Constants &constants;
// Original domain port index -> how that domain was lowered. MapVector keeps
// iteration in insertion (port) order.
311 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
/// Lower the domain ports and domain operations of this module-like op.
///
/// Visible steps:
///  1. For each domain-typed port:
///     - create an object of the domain's output class;
///     - stub out the port's block argument;
///     - queue replacement ports: an input domain gets an input-class port
///       followed by an output-class port, an output domain gets only an
///       output-class port (see DomainInfo::input/output).
///  2. For other ports with a non-empty `domains` array:
///     - annotate the port with a fresh DistinctAttr per referenced domain;
///     - record the association on that domain.
///  3. Rewrite the port list, then populate each domain object:
///     - field 0 (domainInfo_in);
///     - field 2 (associations_in, a list of paths);
///     - drive the output-class port with the object.
///  4. Walk the body lowering unsafe domain casts, domain creation, domain
///     wires and domain defines, collecting dead conversion ops for erasure.
///  5. Store the accumulated per-port annotations on the op.
314LogicalResult LowerModule::lowerModule() {
// The TypeSwitch yields:
//   - the body block for FModuleOp;
//   - nullptr for FExtModuleOp;
//   - std::nullopt for any other op.
// The handling of std::nullopt is on elided lines.
324 TypeSwitch<Operation *, std::optional<Block *>>(op)
325 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
326 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
328 .Default([](
auto) {
return std::nullopt; });
331 Block *body = *shouldProcess;
333 auto *
context = op.getContext();
// Annotations for the rewritten port list, accumulated in port order.
337 SmallVector<Attribute> portAnnotations;
// Domain objects go at the start of the body, each after the previous one.
346 OpBuilder::InsertPoint insertPoint;
// NOTE(review): this dereferences `body`, which is nullptr for FExtModuleOp.
// Presumably it is guarded on an elided line; confirm.
348 insertPoint = {body, body->begin()};
349 auto ports = op.getPorts();
// i indexes the original ports. iDel and iIns are presumably running indices
// into the port list after deletion / insertion; their updates are elided.
350 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
351 auto port = cast<PortInfo>(ports[i]);
// Domain port: lowered to object-typed ports.
355 if (
auto domainType = dyn_cast<DomainType>(port.type)) {
359 auto domain = domainType.getName();
360 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
// Record where the lowered port(s) will land in the new port list.
362 indexToDomain[i] = port.direction == Direction::In
363 ? DomainInfo::input(iIns)
364 : DomainInfo::output(iIns);
// Instantiate the domain's output class as "<port>_object" and register the
// instance in the instance graph.
369 ImplicitLocOpBuilder builder(port.loc,
context);
370 builder.restoreInsertionPoint(insertPoint);
373 auto object = ObjectOp::create(
375 StringAttr::get(
context, Twine(port.name) +
"_object"));
376 instanceGraph.lookup(op)->addInstance(
object,
377 instanceGraph.lookup(classOut));
378 indexToDomain[i].op = object;
// Redirect uses of the domain argument to a placeholder. It is spliced once
// the new ports exist.
379 indexToDomain[i].temp = stubOut(body->getArgument(i));
384 insertPoint = builder.saveInsertionPoint();
// Queue the replacement ports; object-typed ports get empty annotations.
391 if (port.direction == Direction::In) {
392 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
394 portAnnotations.push_back(constants.getEmptyArrayAttr());
399 classOut.getInstanceType(), Direction::Out)});
403 portAnnotations.push_back(constants.getEmptyArrayAttr());
// Non-domain port: turn its list of associated domain port indices into
// annotations. A port with no associated domains keeps its annotations.
419 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
420 if (!domainAttr || domainAttr.empty()) {
421 portAnnotations.push_back(port.annotations.getArrayAttr());
// Some FIRRTL-typed ports also keep their annotations unchanged. The exact
// condition is refined on elided lines; confirm.
428 if (
auto firrtlType = type_dyn_cast<FIRRTLType>(port.type))
430 portAnnotations.push_back(port.annotations.getArrayAttr());
// Add one annotation per associated domain. Each is keyed by a fresh
// DistinctAttr that is also recorded in that domain's associations list.
434 SmallVector<Annotation> newAnnotations;
436 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
438 id = DistinctAttr::create(UnitAttr::get(
context));
439 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
// NOTE(review): MapVector::operator[] default-inserts an entry if this index
// is not a domain port. Presumably the verifier rules that out.
443 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
445 if (!newAnnotations.empty())
446 port.annotations.addAnnotations(newAnnotations);
447 portAnnotations.push_back(port.annotations.getArrayAttr());
// Apply the port rewrite: drop the erased ports, clear the domain info, then
// insert the new object-typed ports.
451 op.erasePorts(eraseVector);
452 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
456 op.insertPorts(newPorts);
// Populate each domain object and connect it to the new ports.
459 for (
auto const &[_, info] : indexToDomain) {
460 auto [object, inputPort, outputPort, temp, associations] =
info;
461 OpBuilder builder(
object);
462 builder.setInsertionPointAfter(
object);
// Field 0 is domainInfo_in. For an input domain:
//   - it is driven from the new input-class port;
//   - the placeholder is spliced with that argument.
// The other splice (with the subfield itself) presumably covers output
// domains; the branch structure is elided.
469 auto subDomainInfoIn =
470 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
472 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
473 body->getArgument(*inputPort));
474 splice(temp, body->getArgument(*inputPort));
476 splice(temp, subDomainInfoIn);
// Field 2 is associations_in: a list of reference paths, one per associated
// port.
480 auto subAssociations =
481 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
482 SmallVector<Value> paths;
483 for (
auto [
id, loc] : associations) {
484 paths.push_back(PathOp::create(
485 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
488 auto list = ListCreateOp::create(
490 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
492 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
// Drive the output-class port with the object itself.
495 PropAssignOp::create(builder,
object.
getLoc(),
496 body->getArgument(outputPort),
object);
// Deferred removals:
//   - conversionsToErase is processed after the walk (loop body elided);
//   - ops in operationsToErase are handled when the walk reaches them (line
//     517 is elided; presumably it erases them).
510 DenseSet<Operation *> conversionsToErase;
511 DenseSet<Operation *> operationsToErase;
512 auto walkResult = op.walk([&](Operation *walkOp) {
516 if (operationsToErase.contains(walkOp)) {
518 return WalkResult::advance();
// Unsafe domain cast: forward the cast's input to its users and mark the
// conversion ops that produce its domain operands as dead.
522 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
523 for (
auto value : castOp.getDomains()) {
524 auto *conversion = value.getDefiningOp();
525 assert(isa<UnrealizedConversionCastOp>(conversion));
526 conversionsToErase.insert(conversion);
529 castOp.getResult().replaceAllUsesWith(castOp.getInput());
531 return WalkResult::advance();
// Anonymous domain: drop it if all users are already dead. Otherwise replace
// it with a cast of an unknown value of the domain's input-class type.
537 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
539 auto noUser = llvm::all_of(anonDomain->getUsers(), [&](
auto *user) {
540 return operationsToErase.contains(user) ||
541 conversionsToErase.contains(user);
544 conversionsToErase.insert(anonDomain);
545 return WalkResult::advance();
550 OpBuilder builder(anonDomain);
552 domainToClasses.at(anonDomain.getDomainAttr().getAttr()).input;
553 anonDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
554 builder, anonDomain.getLoc(), {anonDomain.getType()},
555 {UnknownValueOp::create(builder, anonDomain.getLoc(),
556 classIn.getInstanceType())
559 return WalkResult::advance();
// Explicit domain creation: drop it if dead. Otherwise:
//   - instantiate the domain's input class;
//   - drive its non-output ports from the field values, in order;
//   - replace the op with a cast of the new object and erase it.
563 if (
auto createDomain = dyn_cast<DomainCreateOp>(walkOp)) {
564 auto noUser = llvm::all_of(createDomain->getUsers(), [&](
auto *user) {
565 return operationsToErase.contains(user) ||
566 conversionsToErase.contains(user);
569 conversionsToErase.insert(createDomain);
570 return WalkResult::advance();
573 OpBuilder builder(createDomain);
575 domainToClasses.at(createDomain.getDomainAttr().getAttr()).input;
576 auto object = ObjectOp::create(builder, createDomain.getLoc(), classIn,
577 createDomain.getNameAttr());
578 instanceGraph.lookup(op)->addInstance(
object,
579 instanceGraph.lookup(classIn));
582 auto fieldValues = createDomain.getFieldValues();
585 for (
auto [idx, port] :
llvm::enumerate(classIn.getPorts())) {
// Output ports are skipped. The statement on elided line 587 is presumably
// `continue`.
586 if (port.direction == Direction::Out)
588 auto subfield = ObjectSubfieldOp::create(
589 builder, createDomain.getLoc(),
object, idx);
590 PropAssignOp::create(builder, createDomain.getLoc(), subfield,
591 fieldValues[fieldIdx++]);
594 createDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
595 builder, createDomain.getLoc(), {createDomain.getType()},
596 {object.getResult()}));
597 createDomain.erase();
598 return WalkResult::advance();
// Domain wire:
//   - every user must be a domain define;
//   - remember the value written into the wire (src) and the destinations
//     the wire is read into (dsts);
//   - mark those defines and the wire dead;
//   - reconnect each dst directly to src, before the latest of the defines.
618 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
619 if (type_isa<DomainType>(wireOp.getResult().getType())) {
621 SmallVector<Value> dsts;
622 DomainDefineOp lastDefineOp;
623 for (
auto *user :
llvm::make_early_inc_range(wireOp->getUsers())) {
624 auto domainDefineOp = dyn_cast<DomainDefineOp>(user);
625 if (operationsToErase.contains(domainDefineOp))
627 if (!domainDefineOp) {
628 auto diag = wireOp.emitOpError()
629 <<
"cannot be lowered by `LowerDomains` because it "
630 "has a user that is not a domain define op";
631 diag.attachNote(user->getLoc()) <<
"is one such user";
632 return WalkResult::interrupt();
// Track the define that appears latest in the block.
634 if (!lastDefineOp || lastDefineOp->isBeforeInBlock(domainDefineOp))
635 lastDefineOp = domainDefineOp;
// A define reading from the wire contributes a destination. Presumably
// (elided `else`) a define writing the wire contributes the source.
636 if (wireOp == domainDefineOp.getSrc().getDefiningOp())
637 dsts.push_back(domainDefineOp.getDest());
639 src = domainDefineOp.getSrc();
640 operationsToErase.insert(domainDefineOp);
642 conversionsToErase.insert(wireOp);
645 if (!src || dsts.empty())
646 return WalkResult::advance();
// OpBuilder(op) inserts before `lastDefineOp`. Presumably the new defines are
// then lowered by the DomainDefineOp case below later in this walk.
650 OpBuilder builder(lastDefineOp);
651 for (
auto dst :
llvm::reverse(dsts))
652 DomainDefineOp::create(builder, builder.getUnknownLoc(), dst, src);
654 return WalkResult::advance();
// Domain define: lowered when its destination is produced by a conversion
// cast. The guarding checks on elided lines 659 and 671 are not visible.
658 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
660 return WalkResult::advance();
668 auto *src = defineOp.getSrc().getDefiningOp();
669 auto dest = dyn_cast<UnrealizedConversionCastOp>(
670 defineOp.getDest().getDefiningOp());
672 return WalkResult::advance();
674 conversionsToErase.insert(src);
675 conversionsToErase.insert(dest);
// How the source is handled:
//   - a cast source becomes a property assignment between the underlying
//     values;
//   - domain-creation sources need no assignment;
//   - any other source is an error.
677 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
678 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
679 OpBuilder builder(defineOp);
680 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
681 srcCast.getOperand(0));
682 }
else if (!isa<DomainCreateAnonOp, DomainCreateOp>(src)) {
683 auto diag = defineOp.emitOpError()
684 <<
"has a source which cannot be lowered by 'LowerDomains'";
685 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
686 return WalkResult::interrupt();
690 return WalkResult::advance();
// An interrupted walk means an error was emitted. The elided lines presumably
// return failure.
693 if (walkResult.wasInterrupted())
// Erase the deferred conversion ops (loop body elided).
697 for (
auto *op : conversionsToErase)
// Install the annotations accumulated for the rewritten port list.
702 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
/// Update every instance of this module to match the rewritten port list.
/// For each InstanceOp:
///   - stub out the results at the erased port indices;
///   - clone the instance with those ports erased and the new ports inserted
///     (replacing uses);
///   - update the instance graph;
///   - splice each placeholder with its lowered value (see the loop below).
707LogicalResult LowerModule::lowerInstances() {
// Nothing to do when no ports were erased or inserted.
709 if (eraseVector.none() && newPorts.empty())
// Only FModuleOp / FExtModuleOp are handled (the early-return value is
// elided).
716 if (!isa<FModuleOp, FExtModuleOp>(op))
719 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
720 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
// Instance-like users that are not InstanceOps are reported as unimplemented.
// Presumably this happens when the dyn_cast fails; the guard is elided.
721 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
723 use->getInstance().emitOpError()
724 <<
"has an unimplemented lowering in LowerDomains";
727 LLVM_DEBUG(llvm::dbgs()
728 <<
" - " << instanceOp.getInstanceName() <<
"\n");
// Stub out the results of erased ports (presumably exactly the domain ports)
// so their users survive the clone below.
730 for (
auto i : eraseVector.set_bits())
731 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
733 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
734 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
735 instanceGraph.replaceInstance(instanceOp, inserted);
// Splice each placeholder with its lowered value:
//   - input domains: the instance result at the input-class port;
//   - otherwise: field 1 (domainInfo_out) of the result at the output-class
//     port. The else structure is elided.
737 for (
auto &[i, info] : indexToDomain) {
739 if (
info.inputPort) {
741 splicedValue = inserted.getResult(*
info.inputPort);
745 OpBuilder builder(inserted);
746 builder.setInsertionPointAfter(inserted);
747 splicedValue = ObjectSubfieldOp::create(
748 builder, inserted.getLoc(), inserted.getResult(
info.outputPort), 1);
751 splice(
info.temp, splicedValue);
/// Drives the lowering of a whole circuit: first domains to classes, then
/// modules and their instances.
767 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
768 llvm::Statistic &numDomains)
769 : circuit(circuit), instanceGraph(instanceGraph),
770 constants(circuit.getContext()), numDomains(numDomains) {}
/// Lower every domain, then every module.
773 LogicalResult lowerCircuit();
/// Lower one DomainOp into its input and output classes.
777 LogicalResult lowerDomain(DomainOp);
789 llvm::Statistic &numDomains;
// Domain name -> (input class, output class), filled by lowerDomain().
793 DenseMap<Attribute, Classes> classes;
/// Lower a DomainOp into two classes:
///  - an input class named after the domain, with an `<field>_in` /
///    `<field>_out` port pair per domain field;
///  - an output class `<domain>_out` with domainInfo_in/out and
///    associations_in/out port pairs.
/// Each class body assigns every out-port from the in-port preceding it.
/// Both classes are recorded in `classes` under the domain name and added to
/// the instance graph.
796LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
797 ImplicitLocOpBuilder builder(op.getLoc(), op);
798 auto *
context = op.getContext();
799 auto name = op.getNameAttr();
// Input-class ports: one (in, out) pair per domain field, same type.
800 SmallVector<PortInfo> classInPorts;
801 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
802 classInPorts.
append({{builder.getStringAttr(
803 Twine(field.getName().getValue()) +
"_in"),
804 field.getType(), Direction::In},
805 {builder.getStringAttr(
806 Twine(field.getName().getValue()) +
"_out"),
807 field.getType(), Direction::Out}});
808 auto classIn = ClassOp::create(builder, name, classInPorts);
809 auto classInType = classIn.getInstanceType();
// List-of-paths type, presumably used for the associations ports (its
// declaration line is elided).
811 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
// Output class, with ports in this order:
//   0 domainInfo_in, 1 domainInfo_out, 2 associations_in, 3 associations_out.
813 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
814 {{constants.getDomainInfoIn(),
817 {constants.getDomainInfoOut(),
820 {constants.getAssociationsIn(),
823 {constants.getAssociationsOut(),
// Assign port i+1 from port i for every even i. This relies on the ports being
// laid out as (in, out) pairs. NOTE: it leaves `builder` positioned inside the
// last class body it processed.
827 auto connectPairWise = [&builder](ClassOp &classOp) {
828 builder.setInsertionPointToStart(classOp.getBodyBlock());
829 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
830 PropAssignOp::create(builder, classOp.getArgument(i + 1),
831 classOp.getArgument(i));
833 connectPairWise(classIn);
834 connectPairWise(classOut);
836 classes.insert({name, {classIn, classOut}});
837 instanceGraph.addModule(classIn);
838 instanceGraph.addModule(classOut);
/// Lower all domains first, so that their classes exist in `classes`. Then
/// lower each module and its instances. The loop over modules is on elided
/// lines; the per-node code returning lowerInstances() suggests a callback
/// invoked per instance-graph node (confirm).
844LogicalResult LowerCircuit::lowerCircuit() {
845 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
// Early-increment iteration tolerates lowerDomain() removing the current
// DomainOp (the erasure itself is not visible here).
846 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
847 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
848 if (failed(lowerDomain(domain)))
852 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
// Non-FModuleLike nodes are presumably skipped on the elided lines after the
// dyn_cast.
854 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
857 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
858 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
859 if (failed(lowerModule.lowerModule()))
861 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
862 return lowerModule.lowerInstances();
// Run the circuit lowering against the cached InstanceGraph analysis.
872 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
874 if (failed(lowerCircuit.lowerCircuit()))
875 return signalPassFailure();
// The lowering updates the InstanceGraph in place (addModule, addInstance,
// replaceInstance), so the analysis remains valid.
877 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.