48#include "llvm/ADT/MapVector.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/Threading.h"
55#define GEN_PASS_DEF_LOWERDOMAINS
56#include "circt/Dialect/FIRRTL/Passes.h.inc"
61using namespace firrtl;
69 impl::LowerDomainsBase<LowerDomainsPass>::getArgumentName().data()
// One recorded association between a port and a domain. Elsewhere in this
// file entries are appended as `{id, port.loc}`, i.e. a DistinctAttr followed
// by a location.
73struct AssociationInfo {
// Identifier attribute for the association.
75 DistinctAttr distinctAttr;
// NOTE(review): the two members below are read back elsewhere through the
// structured binding `[object, inputPort, outputPort, associations]`, so they
// appear to belong to a separate per-domain record rather than to
// AssociationInfo -- confirm against the complete file.
// Index of the input port for the domain; std::nullopt is stored when the
// domain port's direction is not Direction::In.
92 std::optional<unsigned> inputPort;
// Associations collected for the domain; starts out empty.
100 SmallVector<AssociationInfo> associations{};
// Guard handed to llvm::call_once by getEmptyArrayAttr() (as
// `emptyArray.flag`).
120 llvm::once_flag flag;
// Guard handed to llvm::call_once by initClassOut() (as `classOut.flag`); it
// protects the four cached names below.
126 llvm::once_flag flag;
// Cached StringAttr "domainInfo_in".
127 StringAttr domainInfoIn;
// Cached StringAttr "domainInfo_out".
128 StringAttr domainInfoOut;
// Cached StringAttr "associations_in".
129 StringAttr associationsIn;
// Cached StringAttr "associations_out".
130 StringAttr associationsOut;
// Construct the constant cache for `context`. Only the context pointer is
// stored here; the constructor body is empty.
134 Constants(MLIRContext *context) : context(context) {}
// Return an ArrayAttr built from an empty element list. The attribute is
// created inside llvm::call_once (keyed on emptyArray.flag) and the stored
// value is what gets returned.
137 ArrayAttr getEmptyArrayAttr() {
138 llvm::call_once(emptyArray.flag,
139 [&] { emptyArray.attr = ArrayAttr::get(context, {}); });
140 return emptyArray.attr;
// Fill in the four cached name attributes: "domainInfo_in", "domainInfo_out",
// "associations_in" and "associations_out". The assignments run inside
// llvm::call_once keyed on classOut.flag.
145 void initClassOut() {
146 llvm::call_once(classOut.flag, [&] {
147 classOut.domainInfoIn = StringAttr::get(context,
"domainInfo_in");
148 classOut.domainInfoOut = StringAttr::get(context,
"domainInfo_out");
149 classOut.associationsIn = StringAttr::get(context,
"associations_in");
150 classOut.associationsOut = StringAttr::get(context,
"associations_out");
// Accessor for classOut.domainInfoIn, the "domainInfo_in" name that
// initClassOut() assigns.
156 StringAttr getDomainInfoIn() {
158 return classOut.domainInfoIn;
// Accessor for classOut.domainInfoOut, the "domainInfo_out" name that
// initClassOut() assigns.
162 StringAttr getDomainInfoOut() {
164 return classOut.domainInfoOut;
// Accessor for classOut.associationsIn, the "associations_in" name that
// initClassOut() assigns.
168 StringAttr getAssociationsIn() {
170 return classOut.associationsIn;
// Accessor for classOut.associationsOut, the "associations_out" name that
// initClassOut() assigns.
174 StringAttr getAssociationsOut() {
176 return classOut.associationsOut;
// Context used when creating the cached attributes.
181 MLIRContext *context;
// Holds the once_flag and the attribute used by getEmptyArrayAttr().
184 EmptyArray emptyArray;
// Set up the per-module lowering state for `op`. `eraseVector` is sized to the
// module's port count at construction time, and `classes` is bound to the
// `domainToClasses` member, which is declared as a const reference.
197 LowerModule(FModuleLike op,
const DenseMap<Attribute, Classes> &classes,
199 : op(op), eraseVector(op.
getNumPorts()), domainToClasses(classes),
200 constants(constants), instanceGraph(instanceGraph) {}
// Rewrite the ports (and, when a body block exists, the body) of the module.
// The definition erases ports according to eraseVector and inserts newPorts.
204 LogicalResult lowerModule();
// Update the InstanceOps that instantiate the module: the definition clones
// each one with eraseVector's ports erased and newPorts inserted.
208 LogicalResult lowerInstances();
// Process every user of `value`, which callers obtain from a block argument
// or an instance result.
//   - UnsafeDomainCastOp: uses of the cast's result are redirected to the
//     cast's input.
//   - DomainDefineOp: recognized explicitly. NOTE(review): what is done with
//     it is not visible here -- presumably it is erased; confirm.
//   - An op error ("has an unimplemented lowering in the LowerDomains pass.")
//     is emitted on a user and returned; presumably this is the fall-through
//     for any other kind of user.
212 LogicalResult eraseDomainUsers(Value value) {
// Early-inc iteration: the user list may change while it is being walked.
213 for (
auto *user :
llvm::make_early_inc_range(value.getUsers())) {
215 if (
auto castOp = dyn_cast<UnsafeDomainCastOp>(user)) {
216 castOp.getResult().replaceAllUsesWith(castOp.getInput());
221 if (isa<DomainDefineOp>(user)) {
225 return user->emitOpError()
226 <<
"has an unimplemented lowering in the LowerDomains pass.";
// One bit per original port; passed to op.erasePorts() and to
// cloneWithErasedPortsAndReplaceUses() on instances.
235 BitVector eraseVector;
// (index, PortInfo) pairs passed to op.insertPorts() and to
// cloneWithInsertedPortsAndReplaceUses() on instances.
239 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
// Pairs appended as (iDel, iIns) while walking the ports; the debug output
// prints them under "portMap" as old index : new index.
242 SmallVector<std::pair<unsigned, unsigned>> resultMap;
// Lookup from a domain's name attribute to its (classIn, classOut) pair.
245 const DenseMap<Attribute, Classes> &domainToClasses;
// Shared cache of commonly used attributes.
248 Constants &constants;
// Lower the domain information of a single module.
//
// For each port whose `domains` attribute is a FlatSymbolRefAttr, class-typed
// ports are queued in `newPorts` and an object of the domain's output class is
// created in the body. For each remaining port that lists domain indices, a
// "circt.tracker" annotation is added per index and the association is
// recorded on the referenced domain. Afterwards the port list is rewritten
// and each domain object is wired to the new ports.
257LogicalResult LowerModule::lowerModule() {
// Check for a module that carries no domain info. NOTE(review): the guarded
// statement is not visible here -- presumably an early successful return.
260 if (op.getDomainInfo().empty())
// Pick the block to operate on: an FModuleOp yields its body block, an
// FExtModuleOp yields nullptr, and any other op yields std::nullopt.
268 TypeSwitch<Operation *, std::optional<Block *>>(op)
269 .Case<FModuleOp>([](
auto op) {
return op.getBodyBlock(); })
270 .Case<FExtModuleOp>([](
auto) {
return nullptr; })
272 .Default([](
auto) {
return std::nullopt; });
// NOTE(review): dereferences the optional; presumably the std::nullopt case
// has been handled before this point -- confirm.
275 Block *body = *shouldProcess;
277 auto *context = op.getContext();
// Original port index of each domain port -> information gathered for it.
// MapVector keeps insertion order for the wiring loop further down.
281 llvm::MapVector<unsigned, DomainInfo> indexToDomain;
// Annotation arrays collected for the rewritten port list; installed at the
// end through setPortAnnotationsAttr.
285 SmallVector<Attribute> portAnnotations;
// Walk the original ports. `iDel` and `iIns` are running indices that are
// used as positions for newPorts entries and for the new block arguments.
294 auto ports = op.getPorts();
295 for (
unsigned i = 0, iDel = 0, iIns = 0, e = op.getNumPorts(); i != e; ++i) {
296 auto port = cast<PortInfo>(ports[i]);
// Case 1: `domains` is a symbol reference; its attribute is the key into
// domainToClasses.
299 if (
auto domain = dyn_cast_or_null<FlatSymbolRefAttr>(port.domains)) {
303 auto [classIn, classOut] = domainToClasses.at(domain.getAttr());
// Create an object named "<port>_object" at the end of the body and register
// it in the instance graph as an instance of classOut.
// NOTE(review): `body` is nullptr for FExtModuleOp per the TypeSwitch above;
// presumably this section is guarded by a body check -- confirm.
306 auto builder = ImplicitLocOpBuilder::atBlockEnd(port.loc, body);
307 auto object = ObjectOp::create(
309 StringAttr::get(context, Twine(port.name) +
"_object"));
310 instanceGraph.lookup(op)->addInstance(
object,
311 instanceGraph.lookup(classOut));
// Input-direction domain: input port index is iIns and output port index is
// iIns + 1. The other assignment stores no input index and uses iIns as the
// output index.
312 if (port.direction == Direction::In)
313 indexToDomain[i] = {object, iIns, iIns + 1};
315 indexToDomain[i] = {object, std::nullopt, iIns};
// Process the users of the domain port's block argument.
318 if (failed(eraseDomainUsers(body->getArgument(i))))
// Queue the replacement ports at position iDel: for an input-direction
// domain, a port with the original name typed as classIn's instance type;
// and an output port "<port>_out" typed as classOut's instance type. Each
// queued port gets an empty annotation array.
325 if (port.direction == Direction::In) {
326 newPorts.push_back({iDel,
PortInfo(port.name, classIn.getInstanceType(),
328 portAnnotations.push_back(constants.getEmptyArrayAttr());
332 {iDel,
PortInfo(StringAttr::get(context, Twine(port.name) +
"_out"),
333 classOut.getInstanceType(), Direction::Out)});
334 portAnnotations.push_back(constants.getEmptyArrayAttr());
// Record an (iDel, iIns) pair and advance both counters.
// NOTE(review): presumably reached only for ports that are kept; the control
// flow separating this from the domain-port branch is not visible here.
344 resultMap.emplace_back(iDel++, iIns++);
// Case 2: `domains` is an array (or absent). With no entries, the port keeps
// its existing annotations.
351 ArrayAttr domainAttr = cast_or_null<ArrayAttr>(port.domains);
352 if (!domainAttr || domainAttr.empty()) {
353 portAnnotations.push_back(port.annotations.getArrayAttr());
// Otherwise, for every domain index listed on the port: create a fresh
// DistinctAttr, build a "circt.tracker" annotation, and record
// {id, port.loc} on the referenced domain.
357 SmallVector<Annotation> newAnnotations;
359 for (
auto indexAttr : domainAttr.getAsRange<IntegerAttr>()) {
361 id = DistinctAttr::create(UnitAttr::get(context));
362 newAnnotations.push_back(
Annotation(DictionaryAttr::getWithSorted(
363 context, {{
"class", StringAttr::get(context,
"circt.tracker")},
// operator[] default-inserts an entry when the index has none yet.
366 indexToDomain[indexAttr.getUInt()].associations.push_back({id, port.loc});
368 if (!newAnnotations.empty())
369 port.annotations.addAnnotations(newAnnotations);
370 portAnnotations.push_back(port.annotations.getArrayAttr());
// Rewrite the port list: erase the ports marked in eraseVector, reset the
// domain info to an empty array, then insert the queued ports.
374 op.erasePorts(eraseVector);
375 op.setDomainInfoAttr(constants.getEmptyArrayAttr());
379 op.insertPorts(newPorts);
// Wire up each domain object, inserting right after the object.
382 for (
auto const &[_, info] : indexToDomain) {
383 auto [object, inputPort, outputPort, associations] =
info;
384 OpBuilder builder(
object);
385 builder.setInsertionPointAfter(
object);
// Subfield 0 of the object is assigned from the input port's block argument.
// Index 0 matches the position of "domainInfo_in" in the port list used when
// the output class is created in lowerDomain.
// NOTE(review): `*inputPort` dereferences an optional that may be
// std::nullopt; presumably guarded by a check not visible here.
391 auto subDomainInfoIn =
392 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 0);
393 PropAssignOp::create(builder,
object.
getLoc(), subDomainInfoIn,
394 body->getArgument(*inputPort));
// Subfield 2 (position of "associations_in") is assigned a list holding one
// PathOp per recorded association.
396 auto subAssociations =
397 ObjectSubfieldOp::create(builder,
object.
getLoc(),
object, 2);
399 SmallVector<Value> paths;
400 for (
auto [
id, loc] : associations) {
401 paths.push_back(PathOp::create(
402 builder, loc, TargetKindAttr::get(context, TargetKind::Reference),
405 auto list = ListCreateOp::create(
407 ListType::get(context, cast<PropertyType>(PathType::get(context))),
409 PropAssignOp::create(builder,
object.
getLoc(), subAssociations, list);
// Finally, the output port's block argument is assigned the object itself.
411 PropAssignOp::create(builder,
object.
getLoc(),
412 body->getArgument(outputPort),
object);
// Install the annotation arrays collected above.
417 op.setPortAnnotationsAttr(ArrayAttr::get(context, portAnnotations));
// Debug dump of the recorded index pairs.
420 llvm::dbgs() <<
" portMap:\n";
421 for (
auto [oldIndex, newIndex] : resultMap)
422 llvm::dbgs() <<
" - " << oldIndex <<
": " << newIndex <<
"\n";
// Update every instance of the module so that its ports match the rewritten
// module: ports marked in eraseVector are dropped and newPorts are added.
428LogicalResult LowerModule::lowerInstances() {
// Check whether no port is marked for erasure and no port is queued for
// insertion. NOTE(review): the guarded statement is not visible here --
// presumably an early successful return.
430 if (eraseVector.none() && newPorts.empty())
// Check for ops other than FModuleOp / FExtModuleOp. NOTE(review): the
// guarded statement is not visible here.
437 if (!isa<FModuleOp, FExtModuleOp>(op))
// Visit each use of the module recorded in the instance graph. Early-inc
// iteration is used because the instance is replaced inside the loop.
440 auto *node = instanceGraph.lookup(cast<igraph::ModuleOpInterface>(*op));
441 for (
auto *use :
llvm::make_early_inc_range(node->uses())) {
// NOTE(review): dyn_cast may yield a null InstanceOp; the error below is
// presumably emitted in that case -- the condition is not visible here.
442 auto instanceOp = dyn_cast<InstanceOp>(*use->getInstance());
444 use->getInstance().emitOpError()
445 <<
"has an unimplemented lowering in LowerDomains";
448 LLVM_DEBUG(llvm::dbgs()
449 <<
" - " << instanceOp.getInstanceName() <<
"\n");
// Process the users of every instance result whose bit is set in eraseVector.
451 for (
auto bit : eraseVector.set_bits())
452 if (failed(eraseDomainUsers(instanceOp.getResult(bit))))
// Clone in two steps (erase, then insert) and tell the instance graph about
// the replacement instance.
455 auto erased = instanceOp.cloneWithErasedPortsAndReplaceUses(eraseVector);
456 auto inserted = erased.cloneWithInsertedPortsAndReplaceUses(newPorts);
457 instanceGraph.replaceInstance(instanceOp, inserted);
// Set up circuit-level lowering state. The Constants cache is constructed
// from the circuit's MLIRContext; the statistic is bound to the `numDomains`
// member, which is declared as a reference.
472 LowerCircuit(CircuitOp circuit,
InstanceGraph &instanceGraph,
473 llvm::Statistic &numDomains)
474 : circuit(circuit), instanceGraph(instanceGraph),
475 constants(circuit.getContext()), numDomains(numDomains) {}
// Lower the whole circuit: the definition lowers each DomainOp and then runs
// LowerModule on the circuit's modules.
478 LogicalResult lowerCircuit();
// Lower one DomainOp: the definition creates an input class and an output
// class ("<name>_out") and records the pair in `classes`.
482 LogicalResult lowerDomain(DomainOp);
// Pass statistic supplied by the caller. NOTE(review): presumably counts the
// domains lowered; its updates are not visible here.
494 llvm::Statistic &numDomains;
// Domain name attribute -> (classIn, classOut); filled by lowerDomain() and
// handed to each LowerModule.
498 DenseMap<Attribute, Classes> classes;
// Lower a single DomainOp into two classes:
//   - an input class named after the domain, with an "<field>_in" input and
//     an "<field>_out" output port for every domain field;
//   - an output class named "<name>_out", whose ports start with the names
//     "domainInfo_in", "domainInfo_out", "associations_in" and
//     "associations_out".
// Both classes get their ports connected pairwise, are recorded in `classes`
// under the domain's name, and are added to the instance graph.
501LogicalResult LowerCircuit::lowerDomain(DomainOp op) {
// Builder that uses the DomainOp's location and is positioned at the op.
502 ImplicitLocOpBuilder builder(op.getLoc(), op);
503 auto *context = op.getContext();
504 auto name = op.getNameAttr();
// Two ports per field, appended in (in, out) order; connectPairWise below
// relies on this adjacent pairing.
505 SmallVector<PortInfo> classInPorts;
506 for (
auto field : op.getFields().getAsRange<DomainFieldAttr>())
507 classInPorts.
append({{builder.getStringAttr(
508 Twine(field.getName().getValue()) +
"_in"),
509 field.getType(), Direction::In},
510 {builder.getStringAttr(
511 Twine(field.getName().getValue()) +
"_out"),
512 field.getType(), Direction::Out}});
513 auto classIn = ClassOp::create(builder, name, classInPorts);
514 auto classInType = classIn.getInstanceType();
// A list-of-paths property type.
516 ListType::get(context, cast<PropertyType>(PathType::get(context)));
// Output class; the port types and directions are not visible here.
518 ClassOp::create(builder, StringAttr::get(context, Twine(name) +
"_out"),
519 {{constants.getDomainInfoIn(),
522 {constants.getDomainInfoOut(),
525 {constants.getAssociationsIn(),
528 {constants.getAssociationsOut(),
// Helper: at the start of the class body, create a PropAssignOp for each
// adjacent pair of block arguments (i + 1, i). The loop steps by two and
// stops on `i != e`, so it assumes an even number of ports.
532 auto connectPairWise = [&builder](ClassOp &classOp) {
533 builder.setInsertionPointToStart(classOp.getBodyBlock());
534 for (
size_t i = 0, e = classOp.getNumPorts(); i != e; i += 2)
535 PropAssignOp::create(builder, classOp.getArgument(i + 1),
536 classOp.getArgument(i));
538 connectPairWise(classIn);
539 connectPairWise(classOut);
// Publish the class pair and register both classes with the instance graph.
541 classes.insert({name, {classIn, classOut}});
542 instanceGraph.addModule(classIn);
543 instanceGraph.addModule(classOut);
// Circuit-level driver: first lower every DomainOp, then lower each module
// together with its instances.
549LogicalResult LowerCircuit::lowerCircuit() {
550 LLVM_DEBUG(llvm::dbgs() <<
"Processing domains:\n");
// Early-inc iteration over the circuit's DomainOps. NOTE(review): presumably
// the op is erased during lowering; the erase is not visible here.
551 for (
auto domain :
llvm::make_early_inc_range(circuit.getOps<DomainOp>())) {
552 LLVM_DEBUG(llvm::dbgs() <<
" - " << domain.getName() <<
"\n");
553 if (failed(lowerDomain(domain)))
557 LLVM_DEBUG(llvm::dbgs() <<
"Processing modules:\n");
// NOTE(review): `node` comes from an enclosing construct that is not visible
// here -- presumably a traversal of the instance graph. dyn_cast may yield
// null for ops that are not FModuleLike.
559 auto moduleOp = dyn_cast<FModuleLike>(node.
getModule<Operation *>());
562 LLVM_DEBUG(llvm::dbgs() <<
" - module: " << moduleOp.getName() <<
"\n");
// Lower the module first; lowerInstances() reads the eraseVector and
// newPorts state held by the same LowerModule object.
563 LowerModule lowerModule(moduleOp, classes, constants, instanceGraph);
564 if (failed(lowerModule.lowerModule()))
566 LLVM_DEBUG(llvm::dbgs() <<
" instances:\n");
567 return lowerModule.lowerInstances();
// Run the circuit-level lowering on the pass's operation using the
// InstanceGraph analysis; signal pass failure if the lowering fails.
577 LowerCircuit lowerCircuit(getOperation(), getAnalysis<InstanceGraph>(),
579 if (failed(lowerCircuit.lowerCircuit()))
580 return signalPassFailure();
// All analyses are declared preserved; the lowering code in this file updates
// the InstanceGraph itself (addModule / addInstance / replaceInstance).
582 markAllAnalysesPreserved();
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
void runOnOperation() override
This class provides a read-only projection of an annotation.
This graph tracks modules and where they are instantiated.
This is a Node in the InstanceGraph.
auto getModule()
Get the module that this node is tracking.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.