24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/IR/Visitors.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/ErrorHandling.h"
30#include "llvm/Support/FormatAdapters.h"
31#include "llvm/Support/FormatVariadic.h"
33#define DEBUG_TYPE "firrtl-lower-open-aggs"
37#define GEN_PASS_DEF_LOWEROPENAGGS
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
43using namespace firrtl;
56 SmallString<16> suffix;
59 void print(raw_ostream &os,
unsigned indent = 0)
const;
61#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
63 LLVM_DUMP_METHOD
void dump()
const {
print(llvm::errs()); }
80 SmallVector<NonHWField, 0> fields;
84 SmallVector<uint64_t, 0> mapToNullInteriors;
86 hw::InnerSymAttr newSym = {};
89 size_t count(
bool includeErased =
false)
const {
92 return fields.size() + (hwType ? 1 : 0) + (includeErased ? 1 : 0);
96 void print(raw_ostream &os,
unsigned indent = 0)
const;
98#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
100 LLVM_DUMP_METHOD
void dump()
const {
print(llvm::errs()); }
106void NonHWField::print(llvm::raw_ostream &os,
unsigned indent)
const {
107 os << llvm::formatv(
"{0}- type: {2}\n"
110 "{1}suffix: \"{5}\"\n",
111 llvm::fmt_pad(
"", indent, 0),
112 llvm::fmt_pad(
"", indent + 2, 0), type, fieldID, isFlip,
115void MappingInfo::print(llvm::raw_ostream &os,
unsigned indent)
const {
121 os.indent(indent) <<
"hardware: ";
128 os.indent(indent) <<
"non-hardware:\n";
129 for (
auto &field : fields)
130 field.
print(os, indent + 2);
132 os.indent(indent) <<
"mappedToNull:\n";
133 for (
auto &null : mapToNullInteriors)
134 os.indent(indent + 2) <<
"- " << null <<
"\n";
136 os.indent(indent) <<
"newSym: ";
143template <
typename Range>
145 Range &&range,
bool includeErased,
146 llvm::function_ref<LogicalResult(
size_t, MappingInfo &,
size_t)> callback) {
148 for (
const auto &[index, pmi] : llvm::enumerate(range)) {
149 if (failed(callback(index, pmi, count)))
151 count += pmi.count(includeErased);
161class Visitor :
public FIRRTLVisitor<Visitor, LogicalResult> {
163 explicit Visitor(MLIRContext *context) : context(context){};
166 LogicalResult visit(FModuleLike mod);
172 LogicalResult visitDecl(InstanceOp op);
173 LogicalResult visitDecl(WireOp op);
175 LogicalResult visitExpr(OpenSubfieldOp op);
176 LogicalResult visitExpr(OpenSubindexOp op);
179 auto notOpenAggType = [](
auto type) {
180 return !isa<OpenBundleType, OpenVectorType>(type);
182 if (!llvm::all_of(op->getOperandTypes(), notOpenAggType) ||
183 !llvm::all_of(op->getResultTypes(), notOpenAggType))
184 return op->emitOpError(
185 "unhandled use or producer of types containing non-hw types");
192 bool madeChanges()
const {
return changesMade; }
198 FailureOr<MappingInfo> mapType(Type type, Location errorLoc,
199 hw::InnerSymAttr sym = {});
202 void recordChanges(
bool changed) {
207 MLIRContext *context;
212 DenseMap<FieldRef, Value> nonHWValues;
215 DenseMap<Value, Value> hwOnlyAggMap;
218 SmallVector<Operation *> opsToErase;
225 bool changesMade =
false;
229LogicalResult Visitor::visit(FModuleLike mod) {
230 auto ports = mod.getPorts();
232 SmallVector<MappingInfo, 16> portMappings;
233 for (
auto &port : ports) {
234 auto pmi = mapType(port.type, port.loc, port.sym);
237 portMappings.push_back(*pmi);
242 size_t countWithErased = 0;
243 for (
auto &pmi : portMappings)
244 countWithErased += pmi.count(true);
247 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
250 BitVector portsToErase(countWithErased);
254 llvm::dbgs().indent(2) <<
"- name: "
255 << cast<mlir::SymbolOpInterface>(*mod).getNameAttr()
257 llvm::dbgs().indent(4) <<
"ports:\n";
261 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
263 llvm::dbgs().indent(6) <<
"- name: " << ports[index].name <<
"\n";
264 llvm::dbgs().indent(8) <<
"type: " << ports[index].type <<
"\n";
265 llvm::dbgs().indent(8) <<
"mapping:\n";
266 pmi.print(llvm::dbgs(), 10);
267 llvm::dbgs() <<
"\n";
271 auto idxOfInsertPoint = index + 1;
276 auto &port = ports[index];
279 portsToErase.set(newIndex);
284 newPort.type = pmi.hwType;
285 newPort.sym = pmi.newSym;
286 newPorts.emplace_back(idxOfInsertPoint, newPort);
289 (pmi.newSym && port.sym.size() == pmi.newSym.size()));
292 if (!port.annotations.empty())
293 return mlir::emitError(port.loc)
294 <<
"annotations on open aggregates not handled yet";
296 assert(!port.sym && !pmi.newSym);
297 if (!port.annotations.empty())
298 return mlir::emitError(port.loc)
299 <<
"annotations found on aggregate with no HW";
303 for (
const auto &[findex, field] :
llvm::enumerate(pmi.fields)) {
304 auto name = StringAttr::get(context,
305 Twine(port.name.strref()) + field.suffix);
307 (
Direction)((
unsigned)port.direction ^ field.isFlip);
308 PortInfo pi(name, field.type, orientation, StringAttr{},
309 port.loc, std::nullopt);
310 newPorts.emplace_back(idxOfInsertPoint, pi);
318 mod.insertPorts(newPorts);
319 recordChanges(!newPorts.empty());
321 assert(mod->getNumRegions() == 1);
325 auto &blocks = mod->getRegion(0).getBlocks();
326 return !blocks.empty() ? &blocks.front() :
nullptr;
335 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
342 assert(portsToErase.test(newIndex));
343 auto oldPort = block->getArgument(newIndex);
344 auto newPortIndex = newIndex;
348 hwOnlyAggMap[oldPort] =
349 block->getArgument(++newPortIndex);
351 for (
auto &field : pmi.fields) {
352 auto ref =
FieldRef(oldPort, field.fieldID);
353 auto newVal = block->getArgument(++newPortIndex);
354 nonHWValues[ref] = newVal;
356 for (
auto fieldID : pmi.mapToNullInteriors) {
357 auto ref =
FieldRef(oldPort, fieldID);
358 assert(!nonHWValues.count(ref));
359 nonHWValues[ref] = {};
368 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"body:\n");
370 ->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
371 return dispatchVisitor(op);
376 assert(opsToErase.empty() || madeChanges());
379 for (
auto &op :
llvm::reverse(opsToErase))
384 mod.erasePorts(portsToErase);
385 recordChanges(portsToErase.any());
387 LLVM_DEBUG(refs.printStats(llvm::dbgs()));
392LogicalResult Visitor::visitExpr(OpenSubfieldOp op) {
418 opsToErase.push_back(op);
423 auto resultRef = refs.getFieldRefFromValue(op.getResult());
424 auto nonHWForResult = nonHWValues.find(resultRef);
425 if (nonHWForResult != nonHWValues.end()) {
427 if (
auto newResult = nonHWForResult->second) {
428 assert(op.getResult().getType() == newResult.getType());
429 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
430 op.getResult().replaceAllUsesWith(newResult);
435 assert(hwOnlyAggMap.count(op.getInput()));
437 auto newInput = hwOnlyAggMap[op.getInput()];
440 auto bundleType = type_cast<BundleType>(newInput.getType());
443 auto fieldName = op.getFieldName();
444 auto newFieldIndex = bundleType.getElementIndex(fieldName);
445 assert(newFieldIndex.has_value());
447 ImplicitLocOpBuilder builder(op.getLoc(), op);
448 auto newOp = SubfieldOp::create(builder, newInput, *newFieldIndex);
449 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
450 newOp->setAttr(
"name", name);
452 hwOnlyAggMap[op.getResult()] = newOp;
454 if (type_isa<FIRRTLBaseType>(op.getType()))
455 op.getResult().replaceAllUsesWith(newOp.getResult());
460LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
465 opsToErase.push_back(op);
470 auto resultRef = refs.getFieldRefFromValue(op.getResult());
471 auto nonHWForResult = nonHWValues.find(resultRef);
472 if (nonHWForResult != nonHWValues.end()) {
474 if (
auto newResult = nonHWForResult->second) {
475 assert(op.getResult().getType() == newResult.getType());
476 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
477 op.getResult().replaceAllUsesWith(newResult);
482 assert(hwOnlyAggMap.count(op.getInput()));
484 auto newInput = hwOnlyAggMap[op.getInput()];
487 ImplicitLocOpBuilder builder(op.getLoc(), op);
488 auto newOp = SubindexOp::create(builder, newInput, op.getIndex());
489 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
490 newOp->setAttr(
"name", name);
492 hwOnlyAggMap[op.getResult()] = newOp;
494 if (type_isa<FIRRTLBaseType>(op.getType()))
495 op.getResult().replaceAllUsesWith(newOp.getResult());
499LogicalResult Visitor::visitDecl(InstanceOp op) {
502 SmallVector<MappingInfo, 16> portMappings;
504 for (
auto type : op.getResultTypes()) {
505 auto pmi = mapType(type, op.getLoc());
508 portMappings.push_back(*pmi);
512 size_t countWithErased = 0;
513 for (
auto &pmi : portMappings)
514 countWithErased += pmi.count(true);
517 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
520 BitVector portsToErase(countWithErased);
524 llvm::dbgs().indent(6) <<
"- instance:\n";
525 llvm::dbgs().indent(10) <<
"name: " << op.getInstanceNameAttr() <<
"\n";
526 llvm::dbgs().indent(10) <<
"module: " << op.getModuleNameAttr() <<
"\n";
527 llvm::dbgs().indent(10) <<
"ports:\n";
531 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
533 llvm::dbgs().indent(12)
534 <<
"- name: " << op.getPortName(index) <<
"\n";
535 llvm::dbgs().indent(14) <<
"type: " << op.getType(index) <<
"\n";
536 llvm::dbgs().indent(14) <<
"mapping:\n";
537 pmi.print(llvm::dbgs(), 16);
538 llvm::dbgs() <<
"\n";
542 auto idxOfInsertPoint = index + 1;
548 portsToErase.set(newIndex);
550 auto portName = op.getPortName(index);
551 auto portDirection = op.getPortDirection(index);
552 auto portDomain = op.getPortDomain(index);
553 auto loc = op.getLoc();
557 PortInfo hwPort(portName, pmi.hwType, portDirection,
561 newPorts.emplace_back(idxOfInsertPoint, hwPort);
564 if (!op.getPortAnnotation(index).empty())
565 return mlir::emitError(op.getLoc())
566 <<
"annotations on open aggregates not handled yet";
568 if (!op.getPortAnnotation(index).empty())
569 return mlir::emitError(op.getLoc())
570 <<
"annotations found on aggregate with no HW";
574 for (
const auto &[findex, field] :
llvm::enumerate(pmi.fields)) {
576 StringAttr::get(context, Twine(portName.strref()) + field.suffix);
578 (
Direction)((
unsigned)portDirection ^ field.isFlip);
579 PortInfo pi(name, field.type, orientation, StringAttr{},
580 loc, std::nullopt, portDomain);
581 newPorts.emplace_back(idxOfInsertPoint, pi);
589 if (newPorts.empty())
598 auto tempOp = op.cloneAndInsertPorts(newPorts);
599 opsToErase.push_back(tempOp);
600 ImplicitLocOpBuilder builder(op.getLoc(), op);
601 auto newInst = tempOp.erasePorts(builder, portsToErase);
605 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
607 auto oldResult = op.getResult(index);
611 assert(oldResult.getType() == newInst.getType(newIndex));
612 oldResult.replaceAllUsesWith(newInst.getResult(newIndex));
617 auto newPortIndex = newIndex;
619 hwOnlyAggMap[oldResult] = newInst.getResult(newPortIndex++);
621 for (
auto &field : pmi.fields) {
622 auto ref =
FieldRef(oldResult, field.fieldID);
623 auto newVal = newInst.getResult(newPortIndex++);
624 assert(newVal.getType() == field.type);
625 nonHWValues[ref] = newVal;
627 for (
auto fieldID : pmi.mapToNullInteriors) {
628 auto ref =
FieldRef(oldResult, fieldID);
629 assert(!nonHWValues.count(ref));
630 nonHWValues[ref] = {};
634 if (failed(mappingResult))
637 opsToErase.push_back(op);
642LogicalResult Visitor::visitDecl(WireOp op) {
643 auto pmi = mapType(op.getResultTypes()[0], op.getLoc(), op.getInnerSymAttr());
646 MappingInfo mappings = *pmi;
649 llvm::dbgs().indent(6) <<
"- wire:\n";
650 llvm::dbgs().indent(10) <<
"name: " << op.getNameAttr() <<
"\n";
651 llvm::dbgs().indent(10) <<
"type: " << op.getType(0) <<
"\n";
652 llvm::dbgs().indent(12) <<
"mapping:\n";
653 mappings.print(llvm::dbgs(), 14);
654 llvm::dbgs() <<
"\n";
657 if (mappings.identity)
663 ImplicitLocOpBuilder builder(op.getLoc(), op);
665 if (!op.getAnnotations().empty())
666 return mlir::emitError(op.getLoc())
667 <<
"annotations on open aggregates not handled yet";
671 hwOnlyAggMap[op.getResult()] =
672 WireOp::create(builder, mappings.hwType, op.getName(), op.getNameKind(),
673 op.getAnnotations(), mappings.newSym, op.getForceable())
677 for (
auto &[type, fieldID, _, suffix] : mappings.fields)
678 nonHWValues[
FieldRef(op.getResult(), fieldID)] =
679 WireOp::create(builder, type,
680 builder.getStringAttr(Twine(op.
getName()) + suffix),
681 NameKindEnum::DroppableName)
684 for (
auto fieldID : mappings.mapToNullInteriors)
685 nonHWValues[
FieldRef(op.getResult(), fieldID)] = {};
687 opsToErase.push_back(op);
696FailureOr<MappingInfo> Visitor::mapType(Type type, Location errorLoc,
697 hw::InnerSymAttr sym) {
698 MappingInfo pi{
false, {}, {}, {}};
699 auto ftype = type_dyn_cast<FIRRTLType>(type);
701 if (!ftype || !isa<OpenBundleType, OpenVectorType>(ftype)) {
706 SmallVector<hw::InnerSymPropertiesAttr> newProps;
709 auto recurse = [&](
auto &&f,
FIRRTLType type,
const Twine &suffix =
"",
710 bool flip =
false, uint64_t fieldID = 0,
711 uint64_t newFieldID = 0) -> FailureOr<FIRRTLBaseType> {
713 TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
714 .Case<FIRRTLBaseType>([](
auto base) {
return base; })
715 .
template Case<OpenBundleType>([&](OpenBundleType obTy)
716 -> FailureOr<FIRRTLBaseType> {
717 SmallVector<BundleType::BundleElement> hwElements;
719 for (
const auto &[index, element] :
720 llvm::enumerate(obTy.getElements())) {
722 f(f, element.type, suffix +
"_" + element.name.strref(),
723 flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
724 newFieldID +
id + 1);
728 hwElements.emplace_back(element.name, element.isFlip, *base);
733 if (hwElements.empty()) {
734 pi.mapToNullInteriors.push_back(fieldID);
738 return BundleType::get(context, hwElements, obTy.isConst());
740 .
template Case<OpenVectorType>([&](OpenVectorType ovTy)
741 -> FailureOr<FIRRTLBaseType> {
746 for (
auto idx :
llvm::
seq<size_t>(0U, ovTy.getNumElements())) {
748 f(f, ovTy.getElementType(), suffix +
"_" + Twine(idx),
flip,
749 fieldID + ovTy.getFieldID(idx), newFieldID +
id + 1);
750 if (failed(hwElementType))
753 "expected same hw type for all elements");
760 pi.mapToNullInteriors.push_back(fieldID);
764 return FVectorType::get(
convert, ovTy.getNumElements(),
767 .
template Case<RefType>([&](RefType ref) {
768 auto f = NonHWField{ref, fieldID,
flip, {}};
769 suffix.toVector(f.suffix);
770 pi.fields.emplace_back(std::move(f));
776 auto f = NonHWField{prop, fieldID,
flip, {}};
777 suffix.toVector(f.suffix);
778 pi.fields.emplace_back(std::move(f));
781 .Default([&](
auto _) {
782 pi.mapToNullInteriors.push_back(fieldID);
790 if (
auto symOnThis = sym.getSymIfExists(fieldID)) {
792 return mlir::emitError(errorLoc,
"inner symbol ")
793 << symOnThis <<
" mapped to non-HW type";
794 newProps.push_back(hw::InnerSymPropertiesAttr::get(
795 context, symOnThis, newFieldID,
796 StringAttr::get(context,
"public")));
801 auto hwType = recurse(recurse, ftype);
806 assert(pi.hwType != type);
810 assert(sym.size() == newProps.size());
812 if (!pi.hwType && !newProps.empty())
813 return mlir::emitError(errorLoc,
"inner symbol on non-HW type");
815 llvm::sort(newProps, [](
auto &p,
auto &q) {
816 return p.getFieldID() < q.getFieldID();
818 pi.newSym = hw::InnerSymAttr::get(context, newProps);
829struct LowerOpenAggsPass
830 :
public circt::firrtl::impl::LowerOpenAggsBase<LowerOpenAggsPass> {
831 LowerOpenAggsPass() =
default;
832 void runOnOperation()
override;
837void LowerOpenAggsPass::runOnOperation() {
839 SmallVector<Operation *, 0> ops(getOperation().getOps<FModuleLike>());
841 LLVM_DEBUG(llvm::dbgs() <<
"Visiting modules:\n");
842 std::atomic<bool> madeChanges =
false;
843 auto result = failableParallelForEach(&getContext(), ops, [&](Operation *op) {
844 Visitor visitor(&getContext());
845 auto result = visitor.visit(cast<FModuleLike>(op));
846 if (visitor.madeChanges())
854 markAllAnalysesPreserved();
assert(baseType &&"element must be base type")
static void dump(DIModule &module, raw_indented_ostream &os)
static void print(TypedAttr val, llvm::raw_ostream &os)
LogicalResult walkMappings(Range &&range, bool includeErased, llvm::function_ref< LogicalResult(size_t, MappingInfo &, size_t)> callback)
static LogicalResult convert(StopBIOp op, StopBIOp::Adaptor adaptor, ConversionPatternRewriter &rewriter)
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args)
visitUnhandledOp is an override point for FIRRTL dialect ops that the concrete visitor didn't bother ...
Caching version of getFieldRefFromValue.
Direction
This represents the direction of a single port.
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
This holds the name and type that describes the module's ports.