25 #include "mlir/IR/BuiltinAttributes.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Threading.h"
28 #include "mlir/IR/Visitors.h"
29 #include "mlir/Pass/Pass.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/ErrorHandling.h"
32 #include "llvm/Support/FormatAdapters.h"
33 #include "llvm/Support/FormatVariadic.h"
37 #define DEBUG_TYPE "firrtl-lower-open-aggs"
41 #define GEN_PASS_DEF_LOWEROPENAGGS
42 #include "circt/Dialect/FIRRTL/Passes.h.inc"
46 using namespace circt;
47 using namespace firrtl;
60 SmallString<16> suffix;
63 void print(raw_ostream &os,
unsigned indent = 0)
const;
/// Debugging helper: prints this object to llvm::errs() using print()'s
/// default indent. Guarded so it is only compiled when NDEBUG is not defined
/// or LLVM_ENABLE_DUMP is defined.
65 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
67 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
84 SmallVector<NonHWField, 0> fields;
88 SmallVector<uint64_t, 0> mapToNullInteriors;
90 hw::InnerSymAttr newSym = {};
/// Number of entries this mapping accounts for: one per element of `fields`,
/// plus one if `hwType` is set, plus one more when `includeErased` is true.
93 size_t count(
bool includeErased =
false)
const {
96 return fields.size() + (hwType ? 1 : 0) + (includeErased ? 1 : 0);
100 void print(raw_ostream &os,
unsigned indent = 0)
const;
/// Debugging helper: prints this object to llvm::errs() using print()'s
/// default indent. Guarded so it is only compiled when NDEBUG is not defined
/// or LLVM_ENABLE_DUMP is defined.
102 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
104 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
/// Writes this field to `os` through llvm::formatv as a list entry that starts
/// with "- type: ...". Placeholder {0} is padding of `indent` columns (used for
/// the first line) and {1} is padding of `indent + 2` columns (used for the
/// lines that follow, such as "suffix: ...").
110 void NonHWField::print(llvm::raw_ostream &os,
unsigned indent)
const {
111 os << llvm::formatv(
"{0}- type: {2}\n"
114 "{1}suffix: \"{5}\"\n",
// Arguments in placeholder order: the two paddings, then the field's data.
115 llvm::fmt_pad(
"", indent, 0),
116 llvm::fmt_pad(
"", indent + 2, 0), type, fieldID, isFlip,
/// Prints a labelled summary of this mapping to `os`, each label indented by
/// `indent`: a "hardware:" entry, every non-HW field under "non-hardware:",
/// every entry of `mapToNullInteriors` under "mappedToNull:", and a "newSym:"
/// entry.
119 void MappingInfo::print(llvm::raw_ostream &os,
unsigned indent)
const {
125 os.indent(indent) <<
"hardware: ";
132 os.indent(indent) <<
"non-hardware:\n";
// Each field prints itself two columns deeper than the section label.
133 for (
auto &field : fields)
134 field.print(os, indent + 2);
136 os.indent(indent) <<
"mappedToNull:\n";
// One "- <value>" line per entry, also two columns deeper than the label.
137 for (
auto &
null : mapToNullInteriors)
138 os.indent(indent + 2) <<
"- " <<
null <<
"\n";
140 os.indent(indent) <<
"newSym: ";
/// Invokes `callback(index, mapping, count)` for every element of `range`.
/// `count` is a running total that is advanced by
/// `mapping.count(includeErased)` after each callback, so the callback sees the
/// total accumulated over the elements visited before it.
/// NOTE(review): the initial value of `count` and the handling of a failed
/// callback are not visible in this excerpt -- confirm against the full file.
147 template <
typename Range>
149 Range &&range,
bool includeErased,
150 llvm::function_ref<LogicalResult(
size_t, MappingInfo &,
size_t)> callback) {
152 for (
const auto &[index, pmi] : llvm::enumerate(range)) {
// The callback receives the total *before* this mapping is added to it.
153 if (failed(callback(index, pmi, count)))
155 count += pmi.count(includeErased);
/// Visitor that performs the lowering for one module. It plugs into
/// FIRRTLVisitor for op dispatch and reports every step as a LogicalResult.
165 class Visitor :
public FIRRTLVisitor<Visitor, LogicalResult> {
167 explicit Visitor(MLIRContext *context) : context(context){};
/// Entry point: processes the given module.
170 LogicalResult visit(FModuleLike mod);
/// Declarations that have dedicated handling.
176 LogicalResult visitDecl(InstanceOp op);
177 LogicalResult visitDecl(WireOp op);
/// Expressions that have dedicated handling.
179 LogicalResult visitExpr(OpenSubfieldOp op);
180 LogicalResult visitExpr(OpenSubindexOp op);
/// Fallback for ops without a dedicated handler: emits an op error when any
/// operand type or result type is an OpenBundleType or OpenVectorType.
182 LogicalResult visitUnhandledOp(Operation *op) {
// Predicate: true for every type that is not an open aggregate.
183 auto notOpenAggType = [](
auto type) {
184 return !isa<OpenBundleType, OpenVectorType>(type);
186 if (!llvm::all_of(op->getOperandTypes(), notOpenAggType) ||
187 !llvm::all_of(op->getResultTypes(), notOpenAggType))
188 return op->emitOpError(
189 "unhandled use or producer of types containing non-hw types");
/// Invalid ops are treated exactly like unhandled ops.
193 LogicalResult visitInvalidOp(Operation *op) {
return visitUnhandledOp(op); }
/// Returns the `changesMade` flag.
196 bool madeChanges()
const {
return changesMade; }
/// Computes the MappingInfo for `type`. Diagnostics are emitted at `errorLoc`;
/// `sym` is the inner symbol attribute attached to the value being mapped and
/// defaults to a null attribute.
202 FailureOr<MappingInfo> mapType(Type type, Location errorLoc,
203 hw::InnerSymAttr sym = {});
206 void recordChanges(
bool changed) {
211 MLIRContext *context;
/// Replacement values keyed by (original value, field ID). A null Value is
/// stored for the field IDs listed in a mapping's `mapToNullInteriors`.
216 DenseMap<FieldRef, Value> nonHWValues;
/// Maps an original value (port, instance result, wire, or open sub-access
/// result) to the value created to replace it.
219 DenseMap<Value, Value> hwOnlyAggMap;
/// Ops queued for removal; visit() iterates this list in reverse.
222 SmallVector<Operation *> opsToErase;
/// Flag reported by madeChanges(); starts out false.
229 bool changesMade =
false;
/// Processes one module: maps every port type, inserts the replacement ports,
/// records how the old block arguments map onto the new ones, walks the body
/// dispatching each op to this visitor, and finally handles the queued ops and
/// erases the ports marked for removal.
/// NOTE(review): several control-flow lines (early returns, branch conditions)
/// are not visible in this excerpt; the comments below describe only the
/// statements that are shown.
233 LogicalResult Visitor::visit(FModuleLike mod) {
234 auto ports = mod.getPorts();
// Compute a mapping for every port before changing anything.
236 SmallVector<MappingInfo, 16> portMappings;
237 for (
auto &port : ports) {
238 auto pmi = mapType(port.type, port.loc, port.sym);
241 portMappings.push_back(*pmi);
// Sum of all mapping counts with the "erased" slot included; this sizes the
// erase bit vector declared below.
246 size_t countWithErased = 0;
247 for (
auto &pmi : portMappings)
248 countWithErased += pmi.count(
true);
// Ports to insert, each paired with the index it is inserted at.
251 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
254 BitVector portsToErase(countWithErased);
258 llvm::dbgs().indent(2) <<
"- name: "
259 << cast<mlir::SymbolOpInterface>(*mod).getNameAttr()
261 llvm::dbgs().indent(4) <<
"ports:\n";
265 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
267 llvm::dbgs().indent(6) <<
"- name: " << ports[index].name <<
"\n";
268 llvm::dbgs().indent(8) <<
"type: " << ports[index].type <<
"\n";
269 llvm::dbgs().indent(8) <<
"mapping:\n";
270 pmi.print(llvm::dbgs(), 10);
271 llvm::dbgs() <<
"\n";
// Replacement ports are inserted directly after the original port's index.
275 auto idxOfInsertPoint = index + 1;
280 auto &port = ports[index];
// Mark the slot identified by the running index for erasure.
283 portsToErase.set(newIndex);
// The HW replacement port takes the mapped HW type and the remapped symbol.
288 newPort.type = pmi.hwType;
289 newPort.sym = pmi.newSym;
290 newPorts.emplace_back(idxOfInsertPoint, newPort);
293 (pmi.newSym && port.sym.size() == pmi.newSym.size()));
// Port annotations are rejected with a diagnostic at the port's location.
296 if (!port.annotations.empty())
297 return mlir::emitError(port.loc)
298 <<
"annotations on open aggregates not handled yet";
// Path for a port with no symbol before or after mapping; annotations are
// rejected here as well, with a different message.
300 assert(!port.sym && !pmi.newSym);
301 if (!port.annotations.empty())
302 return mlir::emitError(port.loc)
303 <<
"annotations found on aggregate with no HW";
// One new port per non-HW field: named from the port name plus the field's
// suffix, with the port direction XOR'd with the field's flip bit.
307 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
309 Twine(port.name.strref()) + field.suffix);
311 (
Direction)((
unsigned)port.direction ^ field.isFlip);
312 PortInfo pi(name, field.type, orientation, StringAttr{},
313 port.loc, std::nullopt);
314 newPorts.emplace_back(idxOfInsertPoint, pi);
// Insert everything collected above; a non-empty list counts as a change.
322 mod.insertPorts(newPorts);
323 recordChanges(!newPorts.empty());
325 assert(mod->getNumRegions() == 1);
// First block of the module's single region, or nullptr if it has no blocks.
329 auto &blocks = mod->getRegion(0).getBlocks();
330 return !blocks.empty() ? &blocks.front() :
nullptr;
339 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
346 assert(portsToErase.test(newIndex));
347 auto oldPort = block->getArgument(newIndex);
// The replacement arguments follow the old argument, hence the
// pre-increments below.
348 auto newPortIndex = newIndex;
352 hwOnlyAggMap[oldPort] =
353 block->getArgument(++newPortIndex);
// Each non-HW field maps to the next block argument in order.
355 for (
auto &field : pmi.fields) {
356 auto ref =
FieldRef(oldPort, field.fieldID);
357 auto newVal = block->getArgument(++newPortIndex);
358 nonHWValues[ref] = newVal;
// Field IDs in mapToNullInteriors get an explicit null entry.
360 for (
auto fieldID : pmi.mapToNullInteriors) {
361 auto ref =
FieldRef(oldPort, fieldID);
362 assert(!nonHWValues.count(ref));
363 nonHWValues[ref] = {};
372 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"body:\n");
// Pre-order walk that hands every op to the visitor dispatch.
374 ->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
375 return dispatchVisitor(op);
// Queuing an op for erasure implies a change was recorded.
380 assert(opsToErase.empty() || madeChanges());
// Queued ops are visited in reverse order of insertion.
383 for (
auto &op : llvm::reverse(opsToErase))
// Remove the ports marked earlier; erasing any port counts as a change.
388 mod.erasePorts(portsToErase);
389 recordChanges(portsToErase.any());
391 LLVM_DEBUG(refs.printStats(llvm::dbgs()));
/// Rewrites an OpenSubfieldOp using the replacements recorded in
/// `nonHWValues` and `hwOnlyAggMap`.
/// NOTE(review): the return statements and closing braces between the steps
/// below are not visible in this excerpt.
396 LogicalResult Visitor::visitExpr(OpenSubfieldOp op) {
// Queue the original op for erasure.
422 opsToErase.push_back(op);
// Check whether this exact field has an entry in the non-HW map.
427 auto resultRef = refs.getFieldRefFromValue(op.getResult());
428 auto nonHWForResult = nonHWValues.find(resultRef);
429 if (nonHWForResult != nonHWValues.end()) {
// An entry may hold a null Value; uses are only replaced for a non-null one.
431 if (
auto newResult = nonHWForResult->second) {
432 assert(op.getResult().getType() == newResult.getType());
433 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
434 op.getResult().replaceAllUsesWith(newResult);
// From here on the input is expected to have a replacement in hwOnlyAggMap.
439 assert(hwOnlyAggMap.count(op.getInput()));
441 auto newInput = hwOnlyAggMap[op.getInput()];
444 auto bundleType = type_cast<BundleType>(newInput.getType());
// Resolve the field by name in the replacement bundle type.
447 auto fieldName = op.getFieldName();
448 auto newFieldIndex = bundleType.getElementIndex(fieldName);
449 assert(newFieldIndex.has_value());
// Create the replacement SubfieldOp with a builder anchored at `op`.
451 ImplicitLocOpBuilder builder(op.getLoc(), op);
452 auto newOp = builder.create<SubfieldOp>(newInput, *newFieldIndex);
// Carry over the "name" attribute when the original op has one.
453 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
454 newOp->setAttr(
"name", name);
// Record the replacement so accesses of this result can find their input.
456 hwOnlyAggMap[op.getResult()] = newOp;
// Uses are forwarded to the new op only when the result is a FIRRTLBaseType.
458 if (type_isa<FIRRTLBaseType>(op.getType()))
459 op.getResult().replaceAllUsesWith(newOp.getResult());
/// Rewrites an OpenSubindexOp using the replacements recorded in
/// `nonHWValues` and `hwOnlyAggMap`.
/// NOTE(review): the return statements and closing braces between the steps
/// below are not visible in this excerpt.
464 LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
// Queue the original op for erasure.
469 opsToErase.push_back(op);
// Check whether this exact element has an entry in the non-HW map.
474 auto resultRef = refs.getFieldRefFromValue(op.getResult());
475 auto nonHWForResult = nonHWValues.find(resultRef);
476 if (nonHWForResult != nonHWValues.end()) {
// An entry may hold a null Value; uses are only replaced for a non-null one.
478 if (
auto newResult = nonHWForResult->second) {
479 assert(op.getResult().getType() == newResult.getType());
480 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
481 op.getResult().replaceAllUsesWith(newResult);
// From here on the input is expected to have a replacement in hwOnlyAggMap.
486 assert(hwOnlyAggMap.count(op.getInput()));
488 auto newInput = hwOnlyAggMap[op.getInput()];
// Create the replacement SubindexOp at the same index, with a builder
// anchored at `op`.
491 ImplicitLocOpBuilder builder(op.getLoc(), op);
492 auto newOp = builder.create<SubindexOp>(newInput, op.getIndex());
// Carry over the "name" attribute when the original op has one.
493 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
494 newOp->setAttr(
"name", name);
// Record the replacement so accesses of this result can find their input.
496 hwOnlyAggMap[op.getResult()] = newOp;
// Uses are forwarded to the new op only when the result is a FIRRTLBaseType.
498 if (type_isa<FIRRTLBaseType>(op.getType()))
499 op.getResult().replaceAllUsesWith(newOp.getResult());
/// Rewrites an instance: maps every result type, builds the list of ports to
/// insert and the set of ports to erase, creates a replacement instance, and
/// records how the old results map onto the new ones.
/// NOTE(review): several control-flow lines (early returns, branch conditions)
/// are not visible in this excerpt; the comments below describe only the
/// statements that are shown.
503 LogicalResult Visitor::visitDecl(InstanceOp op) {
// Compute a mapping for every result type; diagnostics use the op location.
506 SmallVector<MappingInfo, 16> portMappings;
508 for (
auto type : op.getResultTypes()) {
509 auto pmi = mapType(type, op.getLoc());
512 portMappings.push_back(*pmi);
// Sum of all mapping counts with the "erased" slot included; this sizes the
// erase bit vector declared below.
516 size_t countWithErased = 0;
517 for (
auto &pmi : portMappings)
518 countWithErased += pmi.count(
true);
// Ports to insert, each paired with the index it is inserted at.
521 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
524 BitVector portsToErase(countWithErased);
528 llvm::dbgs().indent(6) <<
"- instance:\n";
529 llvm::dbgs().indent(10) <<
"name: " << op.getInstanceNameAttr() <<
"\n";
530 llvm::dbgs().indent(10) <<
"module: " << op.getModuleNameAttr() <<
"\n";
531 llvm::dbgs().indent(10) <<
"ports:\n";
535 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
537 llvm::dbgs().indent(12)
538 <<
"- name: " << op.getPortName(index) <<
"\n";
539 llvm::dbgs().indent(14) <<
"type: " << op.getType(index) <<
"\n";
540 llvm::dbgs().indent(14) <<
"mapping:\n";
541 pmi.print(llvm::dbgs(), 16);
542 llvm::dbgs() <<
"\n";
// Replacement ports are inserted directly after the original port's index.
546 auto idxOfInsertPoint = index + 1;
// Mark the slot identified by the running index for erasure.
552 portsToErase.set(newIndex);
554 auto portName = op.getPortName(index);
555 auto portDirection = op.getPortDirection(index);
556 auto loc = op.getLoc();
// The HW replacement port keeps the original name and direction and takes
// the mapped HW type.
560 PortInfo hwPort(portName, pmi.hwType, portDirection,
563 newPorts.emplace_back(idxOfInsertPoint, hwPort);
// Port annotations are rejected with a diagnostic at the op's location.
566 if (!op.getPortAnnotation(index).empty())
567 return mlir::emitError(op.getLoc())
568 <<
"annotations on open aggregates not handled yet";
570 if (!op.getPortAnnotation(index).empty())
571 return mlir::emitError(op.getLoc())
572 <<
"annotations found on aggregate with no HW";
// One new port per non-HW field, with the port direction XOR'd with the
// field's flip bit.
576 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
580 (
Direction)((
unsigned)portDirection ^ field.isFlip);
581 PortInfo pi(name, field.type, orientation, StringAttr{},
583 newPorts.emplace_back(idxOfInsertPoint, pi);
// Case where no replacement ports were produced for this instance.
591 if (newPorts.empty())
// Two-step replacement: clone with the new ports inserted (the clone itself is
// queued for erasure), then derive the final instance from the clone with the
// marked ports erased.
600 auto tempOp = op.cloneAndInsertPorts(newPorts);
601 opsToErase.push_back(tempOp);
602 ImplicitLocOpBuilder builder(op.getLoc(), op);
603 auto newInst = tempOp.erasePorts(builder, portsToErase);
607 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
609 auto oldResult = op.getResult(index);
// Types match here, so uses are forwarded directly to the new result.
613 assert(oldResult.getType() == newInst.getType(newIndex));
614 oldResult.replaceAllUsesWith(newInst.getResult(newIndex));
// Post-increments: in `newInst` the replacement results start at `newIndex`
// itself.
619 auto newPortIndex = newIndex;
621 hwOnlyAggMap[oldResult] = newInst.getResult(newPortIndex++);
// Each non-HW field maps to the next result in order.
623 for (
auto &field : pmi.fields) {
624 auto ref = FieldRef(oldResult, field.fieldID);
625 auto newVal = newInst.getResult(newPortIndex++);
626 assert(newVal.getType() == field.type);
627 nonHWValues[ref] = newVal;
// Field IDs in mapToNullInteriors get an explicit null entry.
629 for (
auto fieldID : pmi.mapToNullInteriors) {
630 auto ref = FieldRef(oldResult, fieldID);
631 assert(!nonHWValues.count(ref));
632 nonHWValues[ref] = {};
636 if (failed(mappingResult))
// Queue the original instance for erasure.
639 opsToErase.push_back(op);
/// Rewrites a wire: maps its result type (together with its inner symbol),
/// creates replacement wires, and records them in `hwOnlyAggMap` and
/// `nonHWValues`.
/// NOTE(review): several control-flow lines (early returns, branch conditions)
/// are not visible in this excerpt; the comments below describe only the
/// statements that are shown.
644 LogicalResult Visitor::visitDecl(WireOp op) {
645 auto pmi = mapType(op.getResultTypes()[0], op.getLoc(), op.getInnerSymAttr());
648 MappingInfo mappings = *pmi;
651 llvm::dbgs().indent(6) <<
"- wire:\n";
652 llvm::dbgs().indent(10) <<
"name: " << op.getNameAttr() <<
"\n";
653 llvm::dbgs().indent(10) <<
"type: " << op.getType(0) <<
"\n";
654 llvm::dbgs().indent(12) <<
"mapping:\n";
655 mappings.print(llvm::dbgs(), 14);
656 llvm::dbgs() <<
"\n";
// Identity mappings are special-cased here.
659 if (mappings.identity)
// New wires are created with a builder anchored at the original op.
665 ImplicitLocOpBuilder builder(op.getLoc(), op);
// Annotations on the wire are rejected with a diagnostic.
667 if (!op.getAnnotations().empty())
668 return mlir::emitError(op.getLoc())
669 <<
"annotations on open aggregates not handled yet";
// HW replacement wire: mapped HW type, same name and name kind, the
// original annotations, and the remapped inner symbol.
673 hwOnlyAggMap[op.getResult()] =
675 .create<WireOp>(mappings.hwType, op.getName(), op.getNameKind(),
676 op.getAnnotations(), mappings.newSym,
// One extra wire per non-HW field, named from the wire name plus the
// field's suffix and given a droppable name kind.
681 for (
auto &[type, fieldID, _, suffix] : mappings.fields)
682 nonHWValues[
FieldRef(op.getResult(), fieldID)] =
684 .create<WireOp>(type,
685 builder.getStringAttr(Twine(op.getName()) + suffix),
686 NameKindEnum::DroppableName)
// Field IDs in mapToNullInteriors get an explicit null entry.
689 for (
auto fieldID : mappings.mapToNullInteriors)
690 nonHWValues[
FieldRef(op.getResult(), fieldID)] = {};
// Queue the original wire for erasure.
692 opsToErase.push_back(op);
/// Computes the MappingInfo for `type` by recursing through open bundle and
/// open vector types. Non-HW leaves (RefType and one further case) are
/// appended to `pi.fields`; field IDs that end up with no HW content are
/// recorded in `pi.mapToNullInteriors`; inner symbols from `sym` are re-keyed
/// to the new field IDs. Errors are emitted at `errorLoc`.
/// NOTE(review): a number of interior lines (returns, variable declarations
/// such as `id` and `convert`, closing braces) are not visible in this excerpt;
/// the comments below describe only the statements that are shown.
701 FailureOr<MappingInfo> Visitor::mapType(Type type, Location errorLoc,
702 hw::InnerSymAttr sym) {
703 MappingInfo pi{
false, {}, {}, {}};
704 auto ftype = type_dyn_cast<FIRRTLType>(type);
// Anything that is not an open bundle / open vector takes this branch.
706 if (!ftype || !isa<OpenBundleType, OpenVectorType>(ftype)) {
// Symbol properties rebuilt against the new field IDs.
711 SmallVector<hw::InnerSymPropertiesAttr> newProps;
// Recursive worker, passed to itself as `f`. It tracks the accumulated name
// suffix, the accumulated flip, the field ID in the original type, and the
// field ID in the new type.
714 auto recurse = [&](
auto &&f,
FIRRTLType type,
const Twine &suffix =
"",
715 bool flip =
false, uint64_t fieldID = 0,
716 uint64_t newFieldID = 0) -> FailureOr<FIRRTLBaseType> {
718 TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
// A base (HW) type maps to itself.
719 .Case<FIRRTLBaseType>([](
auto base) {
return base; })
// Open bundle: recurse into every element, extending the suffix with
// "_<element name>" and XOR-ing in the element's flip.
720 .
template Case<OpenBundleType>([&](OpenBundleType obTy)
721 -> FailureOr<FIRRTLBaseType> {
722 SmallVector<BundleType::BundleElement> hwElements;
724 for (
const auto &[index, element] :
725 llvm::enumerate(obTy.getElements())) {
727 f(f, element.type, suffix +
"_" + element.name.strref(),
728 flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
729 newFieldID +
id + 1);
733 hwElements.emplace_back(element.name, element.isFlip, *base);
// A bundle with no HW elements is recorded as mapping to null.
738 if (hwElements.empty()) {
739 pi.mapToNullInteriors.push_back(fieldID);
// Open vector: recurse into each index, extending the suffix with
// "_<index>" and leaving the flip unchanged.
745 .
template Case<OpenVectorType>([&](OpenVectorType ovTy)
746 -> FailureOr<FIRRTLBaseType> {
751 for (
auto idx : llvm::seq<size_t>(0U, ovTy.getNumElements())) {
753 f(f, ovTy.getElementType(), suffix +
"_" + Twine(idx),
flip,
754 fieldID + ovTy.getFieldID(idx), newFieldID +
id + 1);
755 if (failed(hwElementType))
// Every element is expected to map to the same HW type.
757 assert((!convert || convert == *hwElementType) &&
758 "expected same hw type for all elements");
759 convert = *hwElementType;
765 pi.mapToNullInteriors.push_back(fieldID);
// RefType leaf: record it as a non-HW field with the accumulated suffix.
772 .
template Case<RefType>([&](RefType ref) {
773 auto f = NonHWField{ref, fieldID,
flip, {}};
774 suffix.toVector(f.suffix);
775 pi.fields.emplace_back(std::move(f));
// Second non-HW leaf kind (`prop`): recorded the same way.
781 auto f = NonHWField{prop, fieldID,
flip, {}};
782 suffix.toVector(f.suffix);
783 pi.fields.emplace_back(std::move(f));
// Any other type is recorded as mapping to null.
786 .Default([&](
auto _) {
787 pi.mapToNullInteriors.push_back(fieldID);
// Handle an inner symbol attached to this field ID, if any.
795 if (
auto symOnThis = sym.getSymIfExists(fieldID)) {
797 return mlir::emitError(errorLoc,
"inner symbol ")
798 << symOnThis <<
" mapped to non-HW type";
800 context, symOnThis, newFieldID,
// Kick off the recursion at the root with the default arguments.
806 auto hwType = recurse(recurse, ftype);
811 assert(pi.hwType != type);
// Every incoming symbol must have produced exactly one new property.
815 assert(sym.size() == newProps.size());
// Symbols cannot be kept when no HW type remains.
817 if (!pi.hwType && !newProps.empty())
818 return mlir::emitError(errorLoc,
"inner symbol on non-HW type");
// Order the rebuilt properties by ascending field ID.
820 llvm::sort(newProps, [](
auto &p,
auto &q) {
821 return p.getFieldID() < q.getFieldID();
/// Pass class built on the generated LowerOpenAggsBase; all of the work is
/// done in runOnOperation().
834 struct LowerOpenAggsPass
835 :
public circt::firrtl::impl::LowerOpenAggsBase<LowerOpenAggsPass> {
836 LowerOpenAggsPass() =
default;
837 void runOnOperation()
override;
/// Runs a fresh Visitor over every FModuleLike op of the operation this pass
/// is scheduled on, using failableParallelForEach.
/// NOTE(review): the handling of `result` and the condition guarding
/// markAllAnalysesPreserved() are not visible in this excerpt.
842 void LowerOpenAggsPass::runOnOperation() {
// Collect the module-like ops into a vector before processing them.
844 SmallVector<Operation *, 0> ops(getOperation().getOps<FModuleLike>());
846 LLVM_DEBUG(llvm::dbgs() <<
"Visiting modules:\n");
// Atomic because it is captured by reference in the lambda handed to
// failableParallelForEach.
847 std::atomic<bool> madeChanges =
false;
848 auto result = failableParallelForEach(&getContext(), ops, [&](Operation *op) {
// Each module gets its own Visitor instance.
849 Visitor visitor(&getContext());
850 auto result = visitor.visit(cast<FModuleLike>(op));
851 if (visitor.madeChanges())
859 markAllAnalysesPreserved();
864 return std::make_unique<LowerOpenAggsPass>();
assert(baseType &&"element must be base type")
static void dump(DIModule &module, raw_indented_ostream &os)
LogicalResult walkMappings(Range &&range, bool includeErased, llvm::function_ref< LogicalResult(size_t, MappingInfo &, size_t)> callback)
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
Caching version of getFieldRefFromValue.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction flip(Direction direction)
Flip a port direction.
Direction
This represents the direction of a single port.
std::unique_ptr< mlir::Pass > createLowerOpenAggsPass()
This is the pass constructor.
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
This holds the name and type that describes the module's ports.