24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/IR/Visitors.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/ErrorHandling.h"
30#include "llvm/Support/FormatAdapters.h"
31#include "llvm/Support/FormatVariadic.h"
33#define DEBUG_TYPE "firrtl-lower-open-aggs"
37#define GEN_PASS_DEF_LOWEROPENAGGS
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
43using namespace firrtl;
// NonHWField (partial listing; the struct header and the lines before 56,
// plus 57-58, are not shown).
// Name suffix for this non-HW field. Callers append it to the original
// port/wire name when naming the new value created for the field.
56 SmallString<16> suffix;
// Print this field to `os`, padded by `indent` spaces.
59 void print(raw_ostream &os,
unsigned indent = 0)
const;
// Debug helper: print to llvm::errs(). Only compiled in when NDEBUG is not
// defined or LLVM_ENABLE_DUMP is defined.
61#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
63 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
// MappingInfo (partial listing): describes how one original value of
// open-aggregate type is split into an HW-only part plus separate non-HW
// values.
//
// Non-HW leaves (ref and property types) that each become their own new
// value/port.
80 SmallVector<NonHWField, 0> fields;
// Field IDs of interior aggregates with no HW content. Uses of these fields
// are mapped to a null Value rather than to a replacement.
84 SmallVector<uint64_t, 0> mapToNullInteriors;
// Inner symbol for the HW-only replacement, re-keyed to its new field IDs.
86 hw::InnerSymAttr newSym = {};
// Number of values this mapping expands to: one per non-HW field, one for the
// HW part if present, and one more for the original value (which is erased
// later) when `includeErased` is true.
// NOTE(review): embedded numbering skips 90-91. Any early-return case
// (e.g. for identity mappings) is not visible here; confirm against the full
// source.
89 size_t count(
bool includeErased =
false)
const {
92 return fields.size() + (hwType ? 1 : 0) + (includeErased ? 1 : 0);
// Print this mapping to `os`, indented by `indent` spaces.
96 void print(raw_ostream &os,
unsigned indent = 0)
const;
// Debug helper: print to llvm::errs() (debug builds or LLVM_ENABLE_DUMP).
98#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
100 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
// Print this field as a YAML-style list entry. The "- type:" line is padded
// by `indent`; the remaining keys (e.g. "suffix:") are padded by `indent + 2`.
// NOTE(review): embedded lines 108-109 and 113 onward are missing from this
// listing, so the format string and argument list below are incomplete.
106void NonHWField::print(llvm::raw_ostream &os,
unsigned indent)
const {
107 os << llvm::formatv(
"{0}- type: {2}\n"
110 "{1}suffix: \"{5}\"\n",
111 llvm::fmt_pad(
"", indent, 0),
112 llvm::fmt_pad(
"", indent + 2, 0), type, fieldID, isFlip,
// Print this mapping as indented YAML-style text. It emits four sections:
// the HW part, the non-HW fields (each printed one level deeper), the field
// IDs mapped to null, and the new inner symbol.
// NOTE(review): embedded lines 116-120, 122-127, 131, 135 and 137 onward are
// missing from this listing.
115void MappingInfo::print(llvm::raw_ostream &os,
unsigned indent)
const {
121 os.indent(indent) <<
"hardware: ";
128 os.indent(indent) <<
"non-hardware:\n";
129 for (
auto &field : fields)
130 field.print(os, indent + 2);
// One "- <fieldID>" entry per interior that maps to null.
132 os.indent(indent) <<
"mappedToNull:\n";
133 for (
auto &null : mapToNullInteriors)
134 os.indent(indent + 2) <<
"- " << null <<
"\n";
136 os.indent(indent) <<
"newSym: ";
// walkMappings: call `callback(index, mapping, newIndex)` for each
// MappingInfo in `range`. `index` is the element's position in `range`.
// After each call, `count` advances by `mapping.count(includeErased)`.
// Assuming `count` starts at 0, `newIndex` is therefore the position of this
// mapping's first value in the expanded port/result list.
// NOTE(review): lines 144, 147, 150 and 152 onward are missing from this
// listing: the function name, the declaration of `count`, and what happens
// after a failed callback. Confirm against the full source.
143template <
typename Range>
145 Range &&range,
bool includeErased,
146 llvm::function_ref<LogicalResult(
size_t, MappingInfo &,
size_t)> callback) {
148 for (
const auto &[index, pmi] : llvm::enumerate(range)) {
149 if (failed(callback(index, pmi, count)))
151 count += pmi.count(includeErased);
// Visitor that lowers open aggregates within one module. Each value of
// OpenBundleType/OpenVectorType is replaced by an HW-only aggregate (tracked
// in hwOnlyAggMap) plus separate values for its non-HW fields (tracked in
// nonHWValues).
// NOTE(review): many lines of this class are missing from this listing; the
// embedded numbering has gaps throughout.
161class Visitor :
public FIRRTLVisitor<Visitor, LogicalResult> {
163 explicit Visitor(MLIRContext *context) : context(context){};
// Entry point: lower the ports and body of one module.
166 LogicalResult visit(FModuleLike mod);
172 LogicalResult visitDecl(InstanceOp op);
173 LogicalResult visitDecl(WireOp op);
175 LogicalResult visitExpr(OpenSubfieldOp op);
176 LogicalResult visitExpr(OpenSubindexOp op);
// Fallback for any other op. Its signature is on lines not shown; it is
// presumably visitUnhandledOp. It errors if any operand or result type is
// still an open aggregate.
179 auto notOpenAggType = [](
auto type) {
180 return !isa<OpenBundleType, OpenVectorType>(type);
182 if (!llvm::all_of(op->getOperandTypes(), notOpenAggType) ||
183 !llvm::all_of(op->getResultTypes(), notOpenAggType))
184 return op->emitOpError(
185 "unhandled use or producer of types containing non-hw types");
// Whether this visitor changed the IR.
192 bool madeChanges()
const {
return changesMade; }
// Compute how `type` (with optional inner symbol `sym`) splits into HW and
// non-HW parts. Diagnostics are reported at `errorLoc`.
198 FailureOr<MappingInfo> mapType(Type type, Location errorLoc,
199 hw::InnerSymAttr sym = {});
// Record whether a step changed the IR. The body is on lines not shown;
// it presumably accumulates into changesMade.
202 void recordChanges(
bool changed) {
// Context used to build new types and attributes.
207 MLIRContext *context;
// Replacement value for each non-HW field of an original value, keyed by
// field reference. A null Value marks an interior with no HW content.
212 DenseMap<FieldRef, Value> nonHWValues;
// Original open-aggregate value -> its HW-only replacement.
215 DenseMap<Value, Value> hwOnlyAggMap;
// Ops to erase after the body walk completes; visit() erases them in reverse
// order.
218 SmallVector<Operation *> opsToErase;
// Returned by madeChanges().
225 bool changesMade =
false;
// Lower open aggregates in one module. The steps are:
//   1. Compute a MappingInfo per port.
//   2. Insert the HW-only and non-HW replacement ports directly after each
//      original port.
//   3. Remap the body's block arguments.
//   4. Visit every op in the body.
//   5. Erase the queued ops and then the original ports.
// NOTE(review): many lines are missing from this listing; the embedded
// numbering jumps, e.g. 230->232, 234->237, 237->242.
229LogicalResult Visitor::visit(FModuleLike mod) {
230 auto ports = mod.getPorts();
// One mapping per original port. The check of mapType's result is on lines
// not shown.
232 SmallVector<MappingInfo, 16> portMappings;
233 for (
auto &port : ports) {
234 auto pmi = mapType(port.type, port.loc, port.sym);
237 portMappings.push_back(*pmi);
// Total port count once new ports are inserted, still counting the originals
// that will be erased. It sizes the portsToErase bit vector.
242 size_t countWithErased = 0;
243 for (
auto &pmi : portMappings)
244 countWithErased += pmi.count(true);
// Ports to insert, as (insertion index, port info) pairs.
247 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
250 BitVector portsToErase(countWithErased);
// Debug output. The enclosing debug wrapper is on lines not shown.
254 llvm::dbgs().indent(2) <<
"- name: "
255 << cast<mlir::SymbolOpInterface>(*mod).getNameAttr()
257 llvm::dbgs().indent(4) <<
"ports:\n";
// Per-port callback, presumably passed to walkMappings on lines not shown.
// `index` is the original port index; `newIndex` is its position once
// replacement ports have been inserted.
261 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
263 llvm::dbgs().indent(6) <<
"- name: " << ports[index].name <<
"\n";
264 llvm::dbgs().indent(8) <<
"type: " << ports[index].type <<
"\n";
265 llvm::dbgs().indent(8) <<
"mapping:\n";
266 pmi.print(llvm::dbgs(), 10);
267 llvm::dbgs() <<
"\n";
// New ports go immediately after the original port.
271 auto idxOfInsertPoint = index + 1;
276 auto &port = ports[index];
// The original port is removed after the body has been rewritten.
279 portsToErase.set(newIndex);
// The HW-only replacement port takes the re-keyed inner symbol.
284 newPort.type = pmi.hwType;
285 newPort.sym = pmi.newSym;
286 newPorts.emplace_back(idxOfInsertPoint, newPort);
289 (pmi.newSym && port.sym.size() == pmi.newSym.size()));
// Annotations on a port being split are rejected with an error. The
// surrounding branch structure is on lines not shown.
292 if (!port.annotations.empty())
293 return mlir::emitError(port.loc)
294 <<
"annotations on open aggregates not handled yet";
296 assert(!port.sym && !pmi.newSym);
297 if (!port.annotations.empty())
298 return mlir::emitError(port.loc)
299 <<
"annotations found on aggregate with no HW";
// One new port per non-HW field, named <port name><suffix>. Its direction
// is the port's direction XOR the field's flip.
303 for (
const auto &[findex, field] :
llvm::enumerate(pmi.fields)) {
304 auto name = StringAttr::get(context,
305 Twine(port.name.strref()) + field.suffix);
307 (
Direction)((
unsigned)port.direction ^ field.isFlip);
308 PortInfo pi(name, field.type, orientation, StringAttr{},
309 port.loc, std::nullopt);
310 newPorts.emplace_back(idxOfInsertPoint, pi);
318 mod.insertPorts(newPorts);
319 recordChanges(!newPorts.empty());
321 assert(mod->getNumRegions() == 1);
// Body-block lookup (enclosing header not shown): the region's first block,
// or null if the region is empty.
325 auto &blocks = mod->getRegion(0).getBlocks();
326 return !blocks.empty() ? &blocks.front() :
nullptr;
// Second pass over the mappings: point each original block argument at its
// replacements. Pre-increment skips the original argument, which is still
// present at newIndex. The HW-only argument comes first (the guarding
// condition is on lines not shown), then one argument per non-HW field.
// Interiors with no HW content map to null.
335 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
342 assert(portsToErase.test(newIndex));
343 auto oldPort = block->getArgument(newIndex);
344 auto newPortIndex = newIndex;
348 hwOnlyAggMap[oldPort] =
349 block->getArgument(++newPortIndex);
351 for (
auto &field : pmi.fields) {
352 auto ref =
FieldRef(oldPort, field.fieldID);
353 auto newVal = block->getArgument(++newPortIndex);
354 nonHWValues[ref] = newVal;
356 for (
auto fieldID : pmi.mapToNullInteriors) {
357 auto ref =
FieldRef(oldPort, fieldID);
358 assert(!nonHWValues.count(ref));
359 nonHWValues[ref] = {};
// Visit every op in the body in pre-order, dispatching to the visit*
// handlers.
368 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"body:\n");
370 ->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
371 return dispatchVisitor(op);
// Erase queued ops in reverse order of queuing, then the original ports.
376 assert(opsToErase.empty() || madeChanges());
379 for (
auto &op :
llvm::reverse(opsToErase))
384 mod.erasePorts(portsToErase);
385 recordChanges(portsToErase.any());
387 LLVM_DEBUG(refs.printStats(llvm::dbgs()));
// Lower an OpenSubfieldOp. The op is always queued for erasure.
// - If the result's field ref has a non-HW replacement, uses are redirected
//   to that value.
// - Otherwise an equivalent SubfieldOp is built on the HW-only replacement
//   of the input and recorded in hwOnlyAggMap.
// NOTE(review): many lines are missing from this listing (embedded numbering
// has gaps).
392LogicalResult Visitor::visitExpr(OpenSubfieldOp op) {
418 opsToErase.push_back(op);
423 auto resultRef = refs.getFieldRefFromValue(op.getResult());
424 auto nonHWForResult = nonHWValues.find(resultRef);
425 if (nonHWForResult != nonHWValues.end()) {
// A null entry marks an interior with no HW content: no replacement value.
427 if (
auto newResult = nonHWForResult->second) {
428 assert(op.getResult().getType() == newResult.getType());
429 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
430 op.getResult().replaceAllUsesWith(newResult);
435 assert(hwOnlyAggMap.count(op.getInput()));
437 auto newInput = hwOnlyAggMap[op.getInput()];
// Look the field up by name: non-HW fields are dropped from the HW-only
// bundle, so the original field index may not match.
440 auto bundleType = type_cast<BundleType>(newInput.getType());
443 auto fieldName = op.getFieldName();
444 auto newFieldIndex = bundleType.getElementIndex(fieldName);
445 assert(newFieldIndex.has_value());
447 ImplicitLocOpBuilder builder(op.getLoc(), op);
448 auto newOp = builder.create<SubfieldOp>(newInput, *newFieldIndex);
449 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
450 newOp->setAttr(
"name", name);
// Record the mapping so later subfield/subindex ops on this result can
// resolve to the HW-only value.
452 hwOnlyAggMap[op.getResult()] = newOp;
// Uses are only replaced directly when the result is a base type.
454 if (type_isa<FIRRTLBaseType>(op.getType()))
455 op.getResult().replaceAllUsesWith(newOp.getResult());
// Lower an OpenSubindexOp. The op is always queued for erasure.
// - If the result's field ref has a non-HW replacement, uses are redirected
//   to that value.
// - Otherwise an equivalent SubindexOp, using the same index, is built on the
//   HW-only replacement of the input and recorded in hwOnlyAggMap.
// NOTE(review): many lines are missing from this listing (embedded numbering
// has gaps).
460LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
465 opsToErase.push_back(op);
470 auto resultRef = refs.getFieldRefFromValue(op.getResult());
471 auto nonHWForResult = nonHWValues.find(resultRef);
472 if (nonHWForResult != nonHWValues.end()) {
// A null entry marks an interior with no HW content: no replacement value.
474 if (
auto newResult = nonHWForResult->second) {
475 assert(op.getResult().getType() == newResult.getType());
476 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
477 op.getResult().replaceAllUsesWith(newResult);
482 assert(hwOnlyAggMap.count(op.getInput()));
484 auto newInput = hwOnlyAggMap[op.getInput()];
// The index is reused as-is: mapType keeps the element count when it builds
// the HW-only vector type.
487 ImplicitLocOpBuilder builder(op.getLoc(), op);
488 auto newOp = builder.create<SubindexOp>(newInput, op.getIndex());
489 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
490 newOp->setAttr(
"name", name);
492 hwOnlyAggMap[op.getResult()] = newOp;
// Uses are only replaced directly when the result is a base type.
494 if (type_isa<FIRRTLBaseType>(op.getType()))
495 op.getResult().replaceAllUsesWith(newOp.getResult());
// Lower an instance's results the same way visit() lowers module ports:
//   1. Map each result type.
//   2. Clone the instance with the replacement ports inserted.
//   3. Erase the original ports from the clone.
//   4. Redirect uses and record the HW-only / non-HW replacements.
// Both the original op and the intermediate clone are queued for erasure.
// NOTE(review): many lines are missing from this listing (embedded numbering
// has gaps).
499LogicalResult Visitor::visitDecl(InstanceOp op) {
// One mapping per result. The check of mapType's result is on lines not
// shown.
502 SmallVector<MappingInfo, 16> portMappings;
504 for (
auto type : op.getResultTypes()) {
505 auto pmi = mapType(type, op.getLoc());
508 portMappings.push_back(*pmi);
// Total port count after insertion, still counting the originals that will
// be erased.
512 size_t countWithErased = 0;
513 for (
auto &pmi : portMappings)
514 countWithErased += pmi.count(true);
517 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
520 BitVector portsToErase(countWithErased);
// Debug output. The enclosing debug wrapper is on lines not shown.
524 llvm::dbgs().indent(6) <<
"- instance:\n";
525 llvm::dbgs().indent(10) <<
"name: " << op.getInstanceNameAttr() <<
"\n";
526 llvm::dbgs().indent(10) <<
"module: " << op.getModuleNameAttr() <<
"\n";
527 llvm::dbgs().indent(10) <<
"ports:\n";
// Per-port callback: queue the replacement ports after each original port
// and mark the original for erasure.
531 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
533 llvm::dbgs().indent(12)
534 <<
"- name: " << op.getPortName(index) <<
"\n";
535 llvm::dbgs().indent(14) <<
"type: " << op.getType(index) <<
"\n";
536 llvm::dbgs().indent(14) <<
"mapping:\n";
537 pmi.print(llvm::dbgs(), 16);
538 llvm::dbgs() <<
"\n";
542 auto idxOfInsertPoint = index + 1;
548 portsToErase.set(newIndex);
550 auto portName = op.getPortName(index);
551 auto portDirection = op.getPortDirection(index);
552 auto loc = op.getLoc();
556 PortInfo hwPort(portName, pmi.hwType, portDirection,
559 newPorts.emplace_back(idxOfInsertPoint, hwPort);
// Annotations on a port being split are rejected with an error.
562 if (!op.getPortAnnotation(index).empty())
563 return mlir::emitError(op.getLoc())
564 <<
"annotations on open aggregates not handled yet";
566 if (!op.getPortAnnotation(index).empty())
567 return mlir::emitError(op.getLoc())
568 <<
"annotations found on aggregate with no HW";
// One new port per non-HW field: name is <port name><suffix>, direction is
// the port's direction XOR the field's flip.
572 for (
const auto &[findex, field] :
llvm::enumerate(pmi.fields)) {
574 StringAttr::get(context, Twine(portName.strref()) + field.suffix);
576 (
Direction)((
unsigned)portDirection ^ field.isFlip);
577 PortInfo pi(name, field.type, orientation, StringAttr{},
579 newPorts.emplace_back(idxOfInsertPoint, pi);
// Nothing to insert; the handling is on lines not shown.
587 if (newPorts.empty())
// Clone with the new ports inserted, then erase the original ports to form
// the final instance. The intermediate clone is queued for erasure.
596 auto tempOp = op.cloneAndInsertPorts(newPorts);
597 opsToErase.push_back(tempOp);
598 ImplicitLocOpBuilder builder(op.getLoc(), op);
599 auto newInst = tempOp.erasePorts(builder, portsToErase);
// Second pass over the mappings. newInst no longer has the original ports,
// so post-increment starts the replacements at newIndex itself. (Presumably
// this walk excludes erased ports; the call is not shown.)
603 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
605 auto oldResult = op.getResult(index);
609 assert(oldResult.getType() == newInst.getType(newIndex));
610 oldResult.replaceAllUsesWith(newInst.getResult(newIndex));
615 auto newPortIndex = newIndex;
617 hwOnlyAggMap[oldResult] = newInst.getResult(newPortIndex++);
619 for (
auto &field : pmi.fields) {
620 auto ref =
FieldRef(oldResult, field.fieldID);
621 auto newVal = newInst.getResult(newPortIndex++);
622 assert(newVal.getType() == field.type);
623 nonHWValues[ref] = newVal;
625 for (
auto fieldID : pmi.mapToNullInteriors) {
626 auto ref =
FieldRef(oldResult, fieldID);
627 assert(!nonHWValues.count(ref));
628 nonHWValues[ref] = {};
632 if (failed(mappingResult))
635 opsToErase.push_back(op);
// Lower a wire of open-aggregate type into up to three kinds of replacement:
// - an HW-only wire, which keeps the name, name kind, annotations and the
//   re-keyed inner symbol;
// - one droppable-named wire per non-HW field;
// - null entries for interiors with no HW content.
// The original wire is queued for erasure.
// NOTE(review): many lines are missing from this listing (embedded numbering
// has gaps).
640LogicalResult Visitor::visitDecl(WireOp op) {
641 auto pmi = mapType(op.getResultTypes()[0], op.getLoc(), op.getInnerSymAttr());
644 MappingInfo mappings = *pmi;
// Debug output. The enclosing debug wrapper is on lines not shown.
647 llvm::dbgs().indent(6) <<
"- wire:\n";
648 llvm::dbgs().indent(10) <<
"name: " << op.getNameAttr() <<
"\n";
649 llvm::dbgs().indent(10) <<
"type: " << op.getType(0) <<
"\n";
650 llvm::dbgs().indent(12) <<
"mapping:\n";
651 mappings.print(llvm::dbgs(), 14);
652 llvm::dbgs() <<
"\n";
// Identity mappings need no rewrite; the handling is on lines not shown.
655 if (mappings.identity)
661 ImplicitLocOpBuilder builder(op.getLoc(), op);
663 if (!op.getAnnotations().empty())
664 return mlir::emitError(op.getLoc())
665 <<
"annotations on open aggregates not handled yet";
669 hwOnlyAggMap[op.getResult()] =
671 .create<WireOp>(mappings.hwType, op.getName(), op.getNameKind(),
672 op.getAnnotations(), mappings.newSym,
// Each non-HW field gets its own wire, named <wire name><suffix>.
677 for (
auto &[type, fieldID, _, suffix] : mappings.fields)
678 nonHWValues[
FieldRef(op.getResult(), fieldID)] =
680 .create<WireOp>(type,
681 builder.getStringAttr(Twine(op.
getName()) + suffix),
682 NameKindEnum::DroppableName)
685 for (
auto fieldID : mappings.mapToNullInteriors)
686 nonHWValues[
FieldRef(op.getResult(), fieldID)] = {};
688 opsToErase.push_back(op);
// Compute how `type` splits into an HW-only type and non-HW fields.
// Types that are not OpenBundleType/OpenVectorType take the early branch;
// its body is on lines not shown. Open aggregates are walked recursively:
// - base types are kept unchanged;
// - bundles keep only their elements that have HW content;
// - vectors keep their element count, with one common HW element type;
// - ref and property types become NonHWField entries carrying the
//   accumulated name suffix and flip;
// - interiors with no HW content (and the Default case) are recorded in
//   mapToNullInteriors.
// Inner symbols are re-keyed to field IDs in the new HW type. A symbol on a
// non-HW field is an error.
// NOTE(review): many lines are missing from this listing (embedded numbering
// has gaps), including the declaration of `id`.
697FailureOr<MappingInfo> Visitor::mapType(Type type, Location errorLoc,
698 hw::InnerSymAttr sym) {
699 MappingInfo pi{
false, {}, {}, {}};
700 auto ftype = type_dyn_cast<FIRRTLType>(type);
702 if (!ftype || !isa<OpenBundleType, OpenVectorType>(ftype)) {
// Inner-symbol properties re-keyed to field IDs in the new HW type.
707 SmallVector<hw::InnerSymPropertiesAttr> newProps;
// Recursive lambda. `f` is the lambda itself, passed explicitly so it can
// call itself. `fieldID` indexes the original type and `newFieldID` indexes
// the HW-only type.
710 auto recurse = [&](
auto &&f,
FIRRTLType type,
const Twine &suffix =
"",
711 bool flip =
false, uint64_t fieldID = 0,
712 uint64_t newFieldID = 0) -> FailureOr<FIRRTLBaseType> {
714 TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
715 .Case<FIRRTLBaseType>([](
auto base) {
return base; })
716 .
template Case<OpenBundleType>([&](OpenBundleType obTy)
717 -> FailureOr<FIRRTLBaseType> {
718 SmallVector<BundleType::BundleElement> hwElements;
// Recurse into each element, extending the name suffix with "_<name>" and
// XOR-ing the flip.
720 for (
const auto &[index, element] :
721 llvm::enumerate(obTy.getElements())) {
723 f(f, element.type, suffix +
"_" + element.name.strref(),
724 flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
725 newFieldID +
id + 1);
729 hwElements.emplace_back(element.name, element.isFlip, *base);
// A bundle with no HW elements maps to null.
734 if (hwElements.empty()) {
735 pi.mapToNullInteriors.push_back(fieldID);
739 return BundleType::get(context, hwElements, obTy.isConst());
741 .
template Case<OpenVectorType>([&](OpenVectorType ovTy)
742 -> FailureOr<FIRRTLBaseType> {
// Recurse into every element, extending the suffix with "_<idx>". All
// elements must map to the same HW type.
747 for (
auto idx :
llvm::
seq<size_t>(0U, ovTy.getNumElements())) {
749 f(f, ovTy.getElementType(), suffix +
"_" + Twine(idx),
flip,
750 fieldID + ovTy.getFieldID(idx), newFieldID +
id + 1);
751 if (failed(hwElementType))
753 assert((!convert || convert == *hwElementType) &&
754 "expected same hw type for all elements");
755 convert = *hwElementType;
761 pi.mapToNullInteriors.push_back(fieldID);
765 return FVectorType::get(convert, ovTy.getNumElements(),
// Ref-typed leaves become non-HW fields.
// NOTE(review): the local `f` shadows the recursive lambda parameter `f`.
768 .
template Case<RefType>([&](RefType ref) {
769 auto f = NonHWField{ref, fieldID,
flip, {}};
770 suffix.toVector(f.suffix);
771 pi.fields.emplace_back(std::move(f));
777 auto f = NonHWField{prop, fieldID,
flip, {}};
778 suffix.toVector(f.suffix);
779 pi.fields.emplace_back(std::move(f));
782 .Default([&](
auto _) {
783 pi.mapToNullInteriors.push_back(fieldID);
// If the original inner symbol names this field, re-key it to newFieldID.
// A symbol on a field with no HW type is an error.
791 if (
auto symOnThis = sym.getSymIfExists(fieldID)) {
793 return mlir::emitError(errorLoc,
"inner symbol ")
794 << symOnThis <<
" mapped to non-HW type";
795 newProps.push_back(hw::InnerSymPropertiesAttr::get(
796 context, symOnThis, newFieldID,
797 StringAttr::get(context,
"public")));
802 auto hwType = recurse(recurse, ftype);
807 assert(pi.hwType != type);
811 assert(sym.size() == newProps.size());
813 if (!pi.hwType && !newProps.empty())
814 return mlir::emitError(errorLoc,
"inner symbol on non-HW type");
// Sort the re-keyed properties by field ID before building the new symbol.
816 llvm::sort(newProps, [](
auto &p,
auto &q) {
817 return p.getFieldID() < q.getFieldID();
819 pi.newSym = hw::InnerSymAttr::get(context, newProps);
// Pass wrapper. runOnOperation() applies a Visitor to every FModuleLike op
// of the pass's operation.
830struct LowerOpenAggsPass
831 :
public circt::firrtl::impl::LowerOpenAggsBase<LowerOpenAggsPass> {
832 LowerOpenAggsPass() =
default;
833 void runOnOperation()
override;
// Lower open aggregates in every FModuleLike op, in parallel, with one
// Visitor per module.
// NOTE(review): lines 848-854 and 856 onward are missing from this listing.
// The failure handling is not visible, and neither is the condition under
// which markAllAnalysesPreserved() runs (presumably when no module changed;
// confirm).
838void LowerOpenAggsPass::runOnOperation() {
// Collect the modules up front; the parallel loop below iterates this list.
840 SmallVector<Operation *, 0> ops(getOperation().getOps<FModuleLike>());
842 LLVM_DEBUG(llvm::dbgs() <<
"Visiting modules:\n");
// Atomic because it is presumably set from inside the parallel callback
// (the setter line is not shown).
843 std::atomic<bool> madeChanges =
false;
844 auto result = failableParallelForEach(&getContext(), ops, [&](Operation *op) {
// A fresh Visitor per module, so visitor state is not shared between
// parallel iterations.
845 Visitor visitor(&getContext());
846 auto result = visitor.visit(cast<FModuleLike>(op));
847 if (visitor.madeChanges())
855 markAllAnalysesPreserved();
// Body of the pass constructor (createLowerOpenAggsPass; its signature is not
// shown here): returns a new, default-constructed LowerOpenAggsPass.
860 return std::make_unique<LowerOpenAggsPass>();
assert(baseType &&"element must be base type")
static void dump(DIModule &module, raw_indented_ostream &os)
LogicalResult walkMappings(Range &&range, bool includeErased, llvm::function_ref< LogicalResult(size_t, MappingInfo &, size_t)> callback)
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
visitUnhandledOp is an override point for FIRRTL dialect ops that the concrete visitor didn't bother ...
Caching version of getFieldRefFromValue.
Direction
This represents the direction of a single port.
std::unique_ptr< mlir::Pass > createLowerOpenAggsPass()
This is the pass constructor.
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
This holds the name and type that describes the module's ports.