50#include "llvm/ADT/MapVector.h"
51#include "llvm/ADT/TypeSwitch.h"
52#include "llvm/Support/Debug.h"
53#include "llvm/Support/Threading.h"
57#define GEN_PASS_DEF_LOWERDOMAINS
58#include "circt/Dialect/FIRRTL/Passes.h.inc"
63using namespace firrtl;
64using mlir::UnrealizedConversionCastOp;
72 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
76struct AssociationInfo {
78 DistinctAttr distinctAttr;
95 std::optional<unsigned> inputPort;
106 UnrealizedConversionCastOp temp;
110 SmallVector<AssociationInfo> associations{};
115 static DomainInfo input(
unsigned portIndex) {
116 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
122 static DomainInfo output(
unsigned portIndex) {
123 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
144 llvm::once_flag flag;
150 llvm::once_flag flag;
151 StringAttr domainInfoIn;
152 StringAttr domainInfoOut;
153 StringAttr associationsIn;
154 StringAttr associationsOut;
158 Constants(MLIRContext *context) : context(context) {}
161 ArrayAttr getEmptyArrayAttr() {
162 llvm::call_once(emptyArray.flag,
163 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
164 return emptyArray.attr;
169 void initClassOut() {
170 llvm::call_once(classOut.flag, [&] {
171 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
172 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
173 classOut.associationsIn = StringAttr::get(context,
"associations_in");
174 classOut.associationsOut = StringAttr::get(context,
"associations_out");
180 StringAttr getDomainInfoIn() {
182 return classOut.domainInfoIn;
186 StringAttr getDomainInfoOut() {
188 return classOut.domainInfoOut;
192 StringAttr getAssociationsIn() {
194 return classOut.associationsIn;
198 StringAttr getAssociationsOut() {
200 return classOut.associationsOut;
208 EmptyArray emptyArray;
218static UnrealizedConversionCastOp stubOut(Value value) {
219 if (!value.hasNUsesOrMore(1))
222 OpBuilder builder(value.getContext());
223 builder.setInsertionPointAfterValue(value);
224 auto temp = UnrealizedConversionCastOp::create(
225 builder, builder.getUnknownLoc(), {value.getType()}, {});
226 value.replaceAllUsesWith(temp.getResult(0));
240static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
246 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
251 auto oldValue = temp.getResult(0);
253 OpBuilder builder(temp);
254 auto splice = UnrealizedConversionCastOp::create(
255 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
256 oldValue.replaceAllUsesWith(splice.getResult(0));
270 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
272 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
273 constants(constants), instanceGraph(instanceGraph) {}
/// Rewrite this module's domain-typed ports, and the domain operations in its
/// body, into class/object-based property equivalents. Records the erased and
/// inserted ports so that lowerInstances() can update instantiations.
/// NOTE(review): presumably returns failure when the body walk is
/// interrupted by an unsupported domain operation — confirm.
277 LogicalResult lowerModule();
/// Update every instance of this module to match the port changes recorded
/// by lowerModule() (the erase vector and the list of new ports), splicing
/// the new domain object values into the users of the old domain results.
281 LogicalResult lowerInstances();
288 BitVector eraseVector;
292 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
295 SmallVector<std::pair<unsigned, unsigned>> resultMap;
298 const DenseMap<Attribute, Classes> &domainToClasses;
301 Constants &constants;
314LogicalResult LowerModule::lowerModule() {
324 TypeSwitch<Operation *, std::optional<Block *>>(op)
325 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
326 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
328 .Default([](
auto) {
return std::nullopt; });
331 Block *body = *shouldProcess;
333 auto *
context = op.getContext();
337 SmallVector<Attribute> portAnnotations;
346 OpBuilder::InsertPoint insertPoint;
348 insertPoint = {body, body->begin()};
349 auto ports = op.getPorts();
350 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
351 auto port = cast<PortInfo>(ports[i]);
355 if (
auto domainType = dyn_cast<DomainType>(port.type)) {
359 auto domain = domainType.getName();
360 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
362 indexToDomain[i] = port.direction == Direction::In
363 ? DomainInfo::input(iIns)
364 : DomainInfo::output(iIns);
369 ImplicitLocOpBuilder builder(port.loc,
context);
370 builder.restoreInsertionPoint(insertPoint);
373 auto object = ObjectOp::create(
375 StringAttr::get(
context, Twine(port.name) +
"_object"));
376 instanceGraph.lookup(op)->addInstance(
object,
377 instanceGraph.lookup(classOut));
378 indexToDomain[i].op = object;
379 indexToDomain[i].temp = stubOut(body->getArgument(i));
384 insertPoint = builder.saveInsertionPoint();
391 if (port.direction == Direction::In) {
392 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
394 portAnnotations.push_back(constants.getEmptyArrayAttr());
399 classOut.getInstanceType(), Direction::Out)});
403 portAnnotations.push_back(constants.getEmptyArrayAttr());
419 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
420 if (!domainAttr || domainAttr.empty()) {
421 portAnnotations.push_back(port.annotations.getArrayAttr());
428 if (
auto firrtlType = type_dyn_cast<FIRRTLType>(port.type))
430 portAnnotations.push_back(port.annotations.getArrayAttr());
434 SmallVector<Annotation> newAnnotations;
436 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
438 id = DistinctAttr::create(UnitAttr::get(
context));
439 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
443 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
445 if (!newAnnotations.empty())
446 port.annotations.addAnnotations(newAnnotations);
447 portAnnotations.push_back(port.annotations.getArrayAttr());
451 op.erasePorts(eraseVector);
452 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
456 op.insertPorts(newPorts);
459 for (
auto const &[_, info] : indexToDomain) {
460 auto [object, inputPort, outputPort, temp, associations] =
info;
461 OpBuilder builder(
object);
462 builder.setInsertionPointAfter(
object);
469 auto subDomainInfoIn =
470 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
472 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
473 body->getArgument(*inputPort));
474 splice(temp, body->getArgument(*inputPort));
476 splice(temp, subDomainInfoIn);
480 auto subAssociations =
481 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
482 SmallVector<Value> paths;
483 for (
auto [
id, loc] : associations) {
484 paths.push_back(PathOp::create(
485 builder, loc, TargetKindAttr::get(
context, TargetKind::Reference),
488 auto list = ListCreateOp::create(
490 ListType::get(
context, cast<PropertyType>(PathType::get(
context))),
492 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
495 PropAssignOp::create(builder,
object.
getLoc(),
496 body->getArgument(outputPort),
object);
510 DenseSet<Operation *> conversionsToErase;
511 DenseSet<Operation *> operationsToErase;
512 auto walkResult = op.walk([&](Operation *walkOp) {
516 if (operationsToErase.contains(walkOp)) {
518 return WalkResult::advance();
522 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
523 for (
auto value : castOp.getDomains()) {
524 auto *conversion = value.getDefiningOp();
525 assert(isa<UnrealizedConversionCastOp>(conversion));
526 conversionsToErase.insert(conversion);
529 castOp.getResult().replaceAllUsesWith(castOp.getInput());
531 return WalkResult::advance();
537 if (
auto anonDomain = dyn_cast<DomainCreateAnonOp>(walkOp)) {
539 auto noUser = llvm::all_of(anonDomain->getUsers(), [&](
auto *user) {
540 return operationsToErase.contains(user) ||
541 conversionsToErase.contains(user);
544 conversionsToErase.insert(anonDomain);
545 return WalkResult::advance();
550 OpBuilder builder(anonDomain);
552 domainToClasses.at(anonDomain.getDomainAttr().getAttr()).input;
553 anonDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
554 builder, anonDomain.getLoc(), {anonDomain.getType()},
555 {UnknownValueOp::create(builder, anonDomain.getLoc(),
556 classIn.getInstanceType())
559 return WalkResult::advance();
563 if (
auto createDomain = dyn_cast<DomainCreateOp>(walkOp)) {
564 auto noUser = llvm::all_of(createDomain->getUsers(), [&](
auto *user) {
565 return operationsToErase.contains(user) ||
566 conversionsToErase.contains(user);
569 conversionsToErase.insert(createDomain);
570 return WalkResult::advance();
573 OpBuilder builder(createDomain);
575 domainToClasses.at(createDomain.getDomainAttr().getAttr()).input;
576 auto object = ObjectOp::create(builder, createDomain.getLoc(), classIn,
577 createDomain.getNameAttr());
578 instanceGraph.lookup(op)->addInstance(
object,
579 instanceGraph.lookup(classIn));
582 auto fieldValues = createDomain.getFieldValues();
586 for (
auto [fieldIdx, fieldValue] :
llvm::enumerate(fieldValues)) {
587 auto inputPortIdx = fieldIdx * 2;
588 auto subfield = ObjectSubfieldOp::create(
589 builder, createDomain.getLoc(),
object, inputPortIdx);
590 PropAssignOp::create(builder, createDomain.getLoc(), subfield,
594 createDomain.replaceAllUsesWith(UnrealizedConversionCastOp::create(
595 builder, createDomain.getLoc(), {createDomain.getType()},
596 {object.getResult()}));
597 createDomain.erase();
598 return WalkResult::advance();
602 if (
auto subfieldOp = dyn_cast<DomainSubfieldOp>(walkOp)) {
605 auto *inputOp = subfieldOp.getInput().getDefiningOp();
607 subfieldOp.emitOpError(
608 "has an input that is not defined by an operation");
609 return WalkResult::interrupt();
612 auto conversionCast = dyn_cast<UnrealizedConversionCastOp>(inputOp);
613 if (!conversionCast || conversionCast.getNumOperands() != 1) {
614 subfieldOp.emitOpError(
615 "has an input that is not a conversion cast with one operand");
616 return WalkResult::interrupt();
621 auto fieldIndex = subfieldOp.getFieldIndex();
622 auto outputPortIndex = fieldIndex * 2 + 1;
625 OpBuilder builder(subfieldOp);
626 auto objectSubfield = ObjectSubfieldOp::create(
627 builder, subfieldOp.getLoc(), conversionCast.getOperand(0),
631 conversionsToErase.insert(conversionCast);
633 subfieldOp.replaceAllUsesWith(objectSubfield.getResult());
635 return WalkResult::advance();
651 if (
auto wireOp = dyn_cast<WireOp>(walkOp)) {
653 if (
auto domainType =
654 type_dyn_cast<DomainType>(wireOp.getResult().getType())) {
655 OpBuilder builder(wireOp);
657 domainToClasses.at(domainType.getName().getAttr()).input;
658 auto object = ObjectOp::create(builder, wireOp.getLoc(), classIn,
659 wireOp.getNameAttr());
660 instanceGraph.lookup(op)->addInstance(
object,
661 instanceGraph.lookup(classIn));
662 auto cast = UnrealizedConversionCastOp::create(
663 builder, wireOp.getLoc(), {wireOp.getResult().getType()},
664 {object.getResult()});
665 wireOp.getResult().replaceAllUsesWith(cast.getResult(0));
666 conversionsToErase.insert(cast);
667 conversionsToErase.insert(wireOp);
671 if (!wireOp.getDomains().empty()) {
672 for (
auto domain : wireOp.getDomains())
673 if (auto *defOp = domain.getDefiningOp())
674 conversionsToErase.insert(defOp);
675 wireOp->eraseOperands(0, wireOp.getNumOperands());
678 return WalkResult::advance();
682 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
684 return WalkResult::advance();
694 auto *src = defineOp.getSrc().getDefiningOp();
695 auto dest = dyn_cast<UnrealizedConversionCastOp>(
696 defineOp.getDest().getDefiningOp());
698 return WalkResult::advance();
700 conversionsToErase.insert(src);
701 conversionsToErase.insert(dest);
703 if (
auto srcCast = dyn_cast<UnrealizedConversionCastOp>(src)) {
704 assert(srcCast.getNumOperands() == 1 && srcCast.getNumResults() == 1);
705 OpBuilder builder(defineOp);
709 if (dest.getOperand(0).getDefiningOp<ObjectOp>()) {
711 firrtl::type_cast<DomainType>(defineOp.getDest().getType());
712 auto numFields = domainType.getNumFields();
713 for (
size_t i = 0; i < numFields; ++i) {
714 auto destIn = ObjectSubfieldOp::create(builder, defineOp.getLoc(),
715 dest.getOperand(0), i * 2);
716 auto srcOut = ObjectSubfieldOp::create(
717 builder, defineOp.getLoc(), srcCast.getOperand(0), i * 2 + 1);
718 PropAssignOp::create(builder, defineOp.getLoc(), destIn, srcOut);
721 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
722 srcCast.getOperand(0));
724 }
else if (!isa<DomainCreateAnonOp, DomainCreateOp>(src)) {
725 auto diag = defineOp.emitOpError()
726 <<
"has a source which cannot be lowered by 'LowerDomains'";
727 diag.attachNote(src->getLoc()) <<
"unsupported source is here";
728 return WalkResult::interrupt();
732 return WalkResult::advance();
735 if (walkResult.wasInterrupted())
739 for (
auto *op : conversionsToErase)
744 op.setPortAnnotationsAttr(ArrayAttr::get(
context, portAnnotations));
749LogicalResult LowerModule::lowerInstances() {
751 if (eraseVector.none() && newPorts.empty())
758 if (!isa<FModuleOp, FExtModuleOp>(op))
761 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
762 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
763 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
765 use->getInstance().emitOpError()
766 <<
"has an unimplemented lowering in LowerDomains";
769 LLVM_DEBUG(llvm::dbgs()
770 <<
" - " << instanceOp.getInstanceName() <<
"\n");
772 for (
auto i : eraseVector.set_bits())
773 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
775 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
776 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
777 instanceGraph.replaceInstance(instanceOp, inserted);
779 for (
auto &[i, info] : indexToDomain) {
781 if (
info.inputPort) {
783 splicedValue = inserted->getResult(*
info.inputPort);
787 OpBuilder builder(inserted);
788 builder.setInsertionPointAfter(inserted);
790 ObjectSubfieldOp::create(builder, inserted.getLoc(),
791 inserted->getResult(
info.outputPort), 1);
794 splice(
info.temp, splicedValue);
810 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
811 llvm::Statistic &numDomains)
812 : circuit(circuit), instanceGraph(instanceGraph),
813 constants(circuit.getContext()), numDomains(numDomains) {}
/// Lower every domain declaration in the circuit into classes first. Then
/// lower each module-like operation and its instances.
816 LogicalResult lowerCircuit();
/// Lower one domain declaration into two classes:
///   - an "input" class with an `<field>_in` / `<field>_out` port pair per
///     domain field;
///   - an "_out" class carrying the domain info and the list of association
///     paths.
/// Both classes are registered in `classes` and in the instance graph.
820 LogicalResult lowerDomain(DomainOp);
832 llvm::Statistic &numDomains;
836 DenseMap<Attribute, Classes> classes;
839LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
840 ImplicitLocOpBuilder builder(op.getLoc(), op);
841 auto *
context = op.getContext();
842 auto name = op.getNameAttr();
843 SmallVector<PortInfo> classInPorts;
844 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
845 classInPorts.
append({{builder.getStringAttr(
846 Twine(field.getName().getValue()) +
"_in"),
847 field.getType(), Direction::In},
848 {builder.getStringAttr(
849 Twine(field.getName().getValue()) +
"_out"),
850 field.getType(), Direction::Out}});
851 auto classIn = ClassOp::create(builder, name, classInPorts);
852 auto classInType = classIn.getInstanceType();
854 ListType::get(
context, cast<PropertyType>(PathType::get(
context)));
856 ClassOp::create(builder, StringAttr::get(
context, Twine(name) +
"_out"),
857 {{constants.getDomainInfoIn(),
860 {constants.getDomainInfoOut(),
863 {constants.getAssociationsIn(),
866 {constants.getAssociationsOut(),
870 auto connectPairWise = [&builder](ClassOp &classOp) {
871 builder.setInsertionPointToStart(classOp.getBodyBlock());
872 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
873 PropAssignOp::create(builder, classOp.getArgument(i + 1),
874 classOp.getArgument(i));
876 connectPairWise(classIn);
877 connectPairWise(classOut);
879 classes.insert({name, {classIn, classOut}});
880 instanceGraph.addModule(classIn);
881 instanceGraph.addModule(classOut);
887LogicalResult LowerCircuit::lowerCircuit() {
888 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
889 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
890 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
891 if (failed(lowerDomain(domain)))
895 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
897 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
900 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
901 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
902 if (failed(lowerModule.lowerModule()))
904 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
905 return lowerModule.lowerInstances();
915 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
917 if (failed(lowerCircuit.lowerCircuit()))
918 return signalPassFailure();
920 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describe each of the module's ports.