48#include "llvm/ADT/MapVector.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/Threading.h"
55#define GEN_PASS_DEF_LOWERDOMAINS
56#include "circt/Dialect/FIRRTL/Passes.h.inc"
61using namespace firrtl;
62using mlir::UnrealizedConversionCastOp;
70 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
74struct AssociationInfo {
76 DistinctAttr distinctAttr;
93 std::optional<unsigned> inputPort;
104 UnrealizedConversionCastOp temp;
108 SmallVector<AssociationInfo> associations{};
113 static DomainInfo input(
unsigned portIndex) {
114 return DomainInfo({{}, portIndex, portIndex + 1, {}, {}});
120 static DomainInfo output(
unsigned portIndex) {
121 return DomainInfo({{}, std::nullopt, portIndex, {}, {}});
142 llvm::once_flag flag;
148 llvm::once_flag flag;
149 StringAttr domainInfoIn;
150 StringAttr domainInfoOut;
151 StringAttr associationsIn;
152 StringAttr associationsOut;
156 Constants(MLIRContext *context) : context(context) {}
159 ArrayAttr getEmptyArrayAttr() {
160 llvm::call_once(emptyArray.flag,
161 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
162 return emptyArray.attr;
167 void initClassOut() {
168 llvm::call_once(classOut.flag, [&] {
169 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
170 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
171 classOut.associationsIn = StringAttr::get(context,
"associations_in");
172 classOut.associationsOut = StringAttr::get(context,
"associations_out");
178 StringAttr getDomainInfoIn() {
180 return classOut.domainInfoIn;
184 StringAttr getDomainInfoOut() {
186 return classOut.domainInfoOut;
190 StringAttr getAssociationsIn() {
192 return classOut.associationsIn;
196 StringAttr getAssociationsOut() {
198 return classOut.associationsOut;
203 MLIRContext *context;
206 EmptyArray emptyArray;
216static UnrealizedConversionCastOp stubOut(Value value) {
217 if (!value.hasNUsesOrMore(1))
220 OpBuilder builder(value.getContext());
221 builder.setInsertionPointAfterValue(value);
222 auto temp = UnrealizedConversionCastOp::create(
223 builder, builder.getUnknownLoc(), {value.getType()}, {});
224 value.replaceAllUsesWith(temp.getResult(0));
238static UnrealizedConversionCastOp splice(UnrealizedConversionCastOp temp,
244 assert(temp && temp.getNumResults() == 1 && temp.getNumOperands() == 0);
249 auto oldValue = temp.getResult(0);
251 OpBuilder builder(temp);
252 auto splice = UnrealizedConversionCastOp::create(
253 builder, builder.getUnknownLoc(), {oldValue.getType()}, {value});
254 oldValue.replaceAllUsesWith(splice.getResult(0));
268 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
270 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
271 constants(constants), instanceGraph(instanceGraph) {}
275 LogicalResult lowerModule();
279 LogicalResult lowerInstances();
286 BitVector eraseVector;
290 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
293 SmallVector<std::pair<unsigned, unsigned>> resultMap;
296 const DenseMap<Attribute, Classes> &domainToClasses;
299 Constants &constants;
309 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
312LogicalResult LowerModule::lowerModule() {
322 TypeSwitch<Operation *, std::optional<Block *>>(op)
323 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
324 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
326 .Default([](
auto) {
return std::nullopt; });
329 Block *body = *shouldProcess;
331 auto *context = op.getContext();
335 SmallVector<Attribute> portAnnotations;
344 OpBuilder::InsertPoint insertPoint;
346 insertPoint = {body, body->begin()};
347 auto ports = op.getPorts();
348 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
349 auto port = cast<PortInfo>(ports[i]);
352 if (
auto domain = dyn_cast_or_null<FlatSymbolRefAttr>(port.domains)) {
356 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
358 indexToDomain[i] = port.direction == Direction::In
359 ? DomainInfo::input(iIns)
360 : DomainInfo::output(iIns);
365 ImplicitLocOpBuilder builder(port.loc, context);
366 builder.restoreInsertionPoint(insertPoint);
369 auto object = ObjectOp::create(
371 StringAttr::get(context, Twine(port.name) +
"_object"));
372 instanceGraph.lookup(op)->addInstance(
object,
373 instanceGraph.lookup(classOut));
374 indexToDomain[i].op = object;
375 indexToDomain[i].temp = stubOut(body->getArgument(i));
380 insertPoint = builder.saveInsertionPoint();
387 if (port.direction == Direction::In) {
388 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
390 portAnnotations.push_back(constants.getEmptyArrayAttr());
394 {iDel,
PortInfo(StringAttr::get(context, Twine(port.name) +
"_out"),
395 classOut.getInstanceType(), Direction::Out)});
399 portAnnotations.push_back(constants.getEmptyArrayAttr());
415 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
416 if (!domainAttr || domainAttr.empty()) {
417 portAnnotations.push_back(port.annotations.getArrayAttr());
421 SmallVector<Annotation> newAnnotations;
423 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
425 id = DistinctAttr::create(UnitAttr::get(context));
426 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
427 context, {{
"class", StringAttr::get(context,
"circt.tracker")},
430 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
432 if (!newAnnotations.empty())
433 port.annotations.addAnnotations(newAnnotations);
434 portAnnotations.push_back(port.annotations.getArrayAttr());
438 op.erasePorts(eraseVector);
439 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
443 op.insertPorts(newPorts);
446 for (
auto const &[_, info] : indexToDomain) {
447 auto [object, inputPort, outputPort, temp, associations] =
info;
448 OpBuilder builder(
object);
449 builder.setInsertionPointAfter(
object);
456 auto subDomainInfoIn =
457 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
459 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
460 body->getArgument(*inputPort));
461 splice(temp, body->getArgument(*inputPort));
463 splice(temp, subDomainInfoIn);
467 auto subAssociations =
468 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
469 SmallVector<Value> paths;
470 for (
auto [
id, loc] : associations) {
471 paths.push_back(PathOp::create(
472 builder, loc, TargetKindAttr::get(context, TargetKind::Reference),
475 auto list = ListCreateOp::create(
477 ListType::get(context, cast<PropertyType>(PathType::get(context))),
479 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
482 PropAssignOp::create(builder,
object.
getLoc(),
483 body->getArgument(outputPort),
object);
497 DenseSet<Operation *> conversionsToErase;
498 auto walkResult = op.walk([&conversionsToErase](Operation *walkOp) {
500 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(walkOp)) {
501 for (
auto value : castOp.getDomains()) {
502 auto *conversion = value.getDefiningOp();
503 assert(isa<UnrealizedConversionCastOp>(conversion));
504 conversionsToErase.insert(conversion);
507 castOp.getResult().replaceAllUsesWith(castOp.getInput());
509 return WalkResult::advance();
513 auto defineOp = dyn_cast<DomainDefineOp>(walkOp);
515 return WalkResult::advance();
519 auto src = dyn_cast<UnrealizedConversionCastOp>(
520 defineOp.getSrc().getDefiningOp());
521 auto dest = dyn_cast<UnrealizedConversionCastOp>(
522 defineOp.getDest().getDefiningOp());
524 return WalkResult::advance();
525 assert(src.getNumOperands() == 1 && src.getNumResults() == 1);
526 assert(dest.getNumOperands() == 1 && dest.getNumResults() == 1);
528 conversionsToErase.insert(src);
529 conversionsToErase.insert(dest);
531 OpBuilder builder(defineOp);
532 PropAssignOp::create(builder, defineOp.getLoc(), dest.getOperand(0),
535 return WalkResult::advance();
540 assert(!walkResult.wasInterrupted());
543 for (
auto *op : conversionsToErase)
548 op.setPortAnnotationsAttr(ArrayAttr::get(context, portAnnotations));
553LogicalResult LowerModule::lowerInstances() {
555 if (eraseVector.none() && newPorts.empty())
562 if (!isa<FModuleOp, FExtModuleOp>(op))
565 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
566 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
567 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
569 use->getInstance().emitOpError()
570 <<
"has an unimplemented lowering in LowerDomains";
573 LLVM_DEBUG(llvm::dbgs()
574 <<
" - " << instanceOp.getInstanceName() <<
"\n");
576 for (
auto i : eraseVector.set_bits())
577 indexToDomain[i].temp = stubOut(instanceOp.getResult(i));
579 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
580 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
581 instanceGraph.replaceInstance(instanceOp, inserted);
583 for (
auto &[i, info] : indexToDomain) {
585 if (
info.inputPort) {
587 splicedValue = inserted.getResult(*
info.inputPort);
591 OpBuilder builder(inserted);
592 builder.setInsertionPointAfter(inserted);
593 splicedValue = ObjectSubfieldOp::create(
594 builder, inserted.getLoc(), inserted.getResult(
info.outputPort), 1);
597 splice(
info.temp, splicedValue);
613 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
614 llvm::Statistic &numDomains)
615 : circuit(circuit), instanceGraph(instanceGraph),
616 constants(circuit.getContext()), numDomains(numDomains) {}
619 LogicalResult lowerCircuit();
623 LogicalResult lowerDomain(DomainOp);
635 llvm::Statistic &numDomains;
639 DenseMap<Attribute, Classes> classes;
642LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
643 ImplicitLocOpBuilder builder(op.getLoc(), op);
644 auto *context = op.getContext();
645 auto name = op.getNameAttr();
646 SmallVector<PortInfo> classInPorts;
647 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
648 classInPorts.
append({{builder.getStringAttr(
649 Twine(field.getName().getValue()) +
"_in"),
650 field.getType(), Direction::In},
651 {builder.getStringAttr(
652 Twine(field.getName().getValue()) +
"_out"),
653 field.getType(), Direction::Out}});
654 auto classIn = ClassOp::create(builder, name, classInPorts);
655 auto classInType = classIn.getInstanceType();
657 ListType::get(context, cast<PropertyType>(PathType::get(context)));
659 ClassOp::create(builder, StringAttr::get(context, Twine(name) +
"_out"),
660 {{constants.getDomainInfoIn(),
663 {constants.getDomainInfoOut(),
666 {constants.getAssociationsIn(),
669 {constants.getAssociationsOut(),
673 auto connectPairWise = [&builder](ClassOp &classOp) {
674 builder.setInsertionPointToStart(classOp.getBodyBlock());
675 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
676 PropAssignOp::create(builder, classOp.getArgument(i + 1),
677 classOp.getArgument(i));
679 connectPairWise(classIn);
680 connectPairWise(classOut);
682 classes.insert({name, {classIn, classOut}});
683 instanceGraph.addModule(classIn);
684 instanceGraph.addModule(classOut);
690LogicalResult LowerCircuit::lowerCircuit() {
691 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
692 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
693 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
694 if (failed(lowerDomain(domain)))
698 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
700 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
703 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
704 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
705 if (failed(lowerModule.lowerModule()))
707 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
708 return lowerModule.lowerInstances();
718 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
720 if (failed(lowerCircuit.lowerCircuit()))
721 return signalPassFailure();
723 markAnalysesPreserved<InstanceGraph>();
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.