26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "mlir/IR/Threading.h"
29 #include "mlir/IR/Visitors.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/ErrorHandling.h"
32 #include "llvm/Support/FormatAdapters.h"
33 #include "llvm/Support/FormatVariadic.h"
37 #define DEBUG_TYPE "firrtl-lower-open-aggs"
39 using namespace circt;
40 using namespace firrtl;
/// NonHWField (fragment): one non-hardware leaf (e.g. a RefType value) split
/// out of an open aggregate by mapType.
/// `suffix` is the accumulated "_<element name>" / "_<index>" path. It is
/// appended to the original value's name to name the split-out port or wire.
53 SmallString<16> suffix;
/// Print this field as an indented, YAML-like list entry.
56 void print(raw_ostream &os,
unsigned indent = 0)
const;
58 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Debug helper: print this field to llvm::errs() with no indentation.
60 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
/// MappingInfo (fragment): describes how one open-aggregate-typed value is split.
/// Non-HW leaves that become separate values (ports, wires, results).
77 SmallVector<NonHWField, 0> fields;
/// Field IDs of positions that get no replacement value at all, e.g. an open
/// bundle with no HW elements, or an unhandled leaf. Users record these as
/// null entries in their value maps.
81 SmallVector<uint64_t, 0> mapToNullInteriors;
/// Inner symbol re-targeted onto field IDs of the HW-only type, if any.
83 hw::InnerSymAttr newSym = {};
/// Number of values this mapping occupies:
/// - one per non-HW field;
/// - one more for the HW-only aggregate when `hwType` is set;
/// - one more for the original (to-be-erased) value when `includeErased` is true.
/// NOTE(review): only the final return is visible in this listing. Confirm
/// there is no earlier special case (e.g. for identity mappings).
86 size_t count(
bool includeErased =
false)
const {
89 return fields.size() + (hwType ? 1 : 0) + (includeErased ? 1 : 0);
/// Print this mapping as an indented, YAML-like block.
93 void print(raw_ostream &os,
unsigned indent = 0)
const;
95 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Debug helper: print this mapping to llvm::errs() with no indentation.
97 LLVM_DUMP_METHOD
void dump()
const { print(llvm::errs()); }
/// Print this field as a YAML-like list entry.
/// The first line ("- type: ...") is padded by `indent` ({0}). The following
/// keys are padded by `indent + 2` ({1}).
/// NOTE(review): the middle format lines (fieldID / isFlip) and the argument
/// bound to {5} are not visible in this listing. {5} is presumably the suffix;
/// confirm.
103 void NonHWField::print(llvm::raw_ostream &os,
unsigned indent)
const {
104 os << llvm::formatv(
"{0}- type: {2}\n"
107 "{1}suffix: \"{5}\"\n",
108 llvm::fmt_pad(
"", indent, 0),
109 llvm::fmt_pad(
"", indent + 2, 0), type, fieldID, isFlip,
/// Print this mapping as YAML-like sections at `indent`:
/// - "hardware": the HW-only part;
/// - "non-hardware": one entry per NonHWField, indented by 2 more;
/// - "mappedToNull": one "- <fieldID>" line per null-mapped position;
/// - "newSym": the re-targeted inner symbol.
112 void MappingInfo::print(llvm::raw_ostream &os,
unsigned indent)
const {
118 os.indent(indent) <<
"hardware: ";
125 os.indent(indent) <<
"non-hardware:\n";
126 for (
auto &field : fields)
127 field.print(os, indent + 2);
129 os.indent(indent) <<
"mappedToNull:\n";
130 for (
auto &
null : mapToNullInteriors)
131 os.indent(indent + 2) <<
"- " <<
null <<
"\n";
133 os.indent(indent) <<
"newSym: ";
/// Walk a range of MappingInfo, calling `callback(index, mapping, start)` for
/// each element.
/// `start` is the running sum of `count(includeErased)` over the preceding
/// mappings, i.e. the position of this element's first value in the expanded
/// port/result list.
/// Presumably returns the first callback failure (the early-return line and
/// the declaration of `count` are not visible in this listing).
140 template <
typename Range>
142 Range &&range,
bool includeErased,
143 llvm::function_ref<LogicalResult(
size_t, MappingInfo &,
size_t)> callback) {
145 for (
const auto &[index, pmi] : llvm::enumerate(range)) {
146 if (failed(callback(index, pmi, count)))
// Advance past every value this mapping occupies.
148 count += pmi.count(includeErased);
/// Per-module FIRRTL visitor that rewrites open aggregates (OpenBundleType /
/// OpenVectorType) into an HW-only aggregate plus separate non-HW values.
158 class Visitor :
public FIRRTLVisitor<Visitor, LogicalResult> {
// NOTE(review): the ';' after the constructor body is an empty declaration;
// it is harmless and can be dropped.
160 explicit Visitor(MLIRContext *context) : context(context){};
/// Rewrite the ports and body of one module-like op.
163 LogicalResult visit(FModuleLike mod);
169 LogicalResult visitDecl(InstanceOp op);
170 LogicalResult visitDecl(WireOp op);
172 LogicalResult visitExpr(OpenSubfieldOp op);
173 LogicalResult visitExpr(OpenSubindexOp op);
/// Fallback for ops with no dedicated handler: fail if any operand or result
/// has an open aggregate type, since this pass would leave it unlowered.
175 LogicalResult visitUnhandledOp(Operation *op) {
176 auto notOpenAggType = [](
auto type) {
177 return !isa<OpenBundleType, OpenVectorType>(type);
179 if (!llvm::all_of(op->getOperandTypes(), notOpenAggType) ||
180 !llvm::all_of(op->getResultTypes(), notOpenAggType))
181 return op->emitOpError(
182 "unhandled use or producer of types containing non-hw types");
/// Non-FIRRTL ops get the same open-aggregate check.
186 LogicalResult visitInvalidOp(Operation *op) {
return visitUnhandledOp(op); }
/// Whether this visitor modified the IR.
189 bool madeChanges()
const {
return changesMade; }
// Tail of the mapType declaration. By its uses, mapType takes
// (type, location, optional inner symbol) and yields something that
// dereferences to a MappingInfo (presumably FailureOr<MappingInfo>).
196 hw::InnerSymAttr sym = {});
/// Accumulate `changed` into `changesMade`.
199 void recordChanges(
bool changed) {
204 MLIRContext *context;
/// (original value, fieldID) -> replacement non-HW value. A null Value marks a
/// position that maps to nothing.
209 DenseMap<FieldRef, Value> nonHWValues;
/// Original (open) value -> its HW-only aggregate replacement.
212 DenseMap<Value, Value> hwOnlyAggMap;
/// Ops queued for erasure; visit() erases them in reverse order.
215 SmallVector<Operation *> opsToErase;
222 bool changesMade =
false;
/// Lower open aggregates in the ports and body of `mod`.
///
/// Steps:
/// 1. Map each port with mapType.
/// 2. For each port marked in `portsToErase`, insert replacement ports right
///    after it: first the HW-only aggregate (when the mapping has a `hwType`),
///    then one port per non-HW field.
/// 3. Record the new block arguments in `hwOnlyAggMap` / `nonHWValues`.
/// 4. Visit the body in pre-order.
/// 5. Erase queued ops in reverse order, then erase the original ports.
///
/// NOTE(review): several lines (identity / failure early-outs) are missing from
/// this listing; this describes only what is visible.
226 LogicalResult Visitor::visit(FModuleLike mod) {
227 auto ports = mod.getPorts();
229 SmallVector<MappingInfo, 16> portMappings;
230 for (
auto &port : ports) {
231 auto pmi = mapType(port.type, port.loc, port.sym);
234 portMappings.push_back(*pmi);
// Size of the port list while both the original and the inserted ports exist
// (between insertPorts and erasePorts below).
239 size_t countWithErased = 0;
240 for (
auto &pmi : portMappings)
241 countWithErased += pmi.count(
true);
244 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
247 BitVector portsToErase(countWithErased);
// Debug dump (YAML-like) of the module name and per-port mappings.
251 llvm::dbgs().indent(2) <<
"- name: "
252 << cast<mlir::SymbolOpInterface>(*mod).getNameAttr()
254 llvm::dbgs().indent(4) <<
"ports:\n";
258 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
260 llvm::dbgs().indent(6) <<
"- name: " << ports[index].name <<
"\n";
261 llvm::dbgs().indent(8) <<
"type: " << ports[index].type <<
"\n";
262 llvm::dbgs().indent(8) <<
"mapping:\n";
263 pmi.print(llvm::dbgs(), 10);
264 llvm::dbgs() <<
"\n";
// Replacement ports are inserted immediately after the original port.
268 auto idxOfInsertPoint = index + 1;
273 auto &port = ports[index];
// The original port is erased once everything has been rewired.
276 portsToErase.set(newIndex);
281 newPort.type = pmi.hwType;
282 newPort.sym = pmi.newSym;
283 newPorts.emplace_back(idxOfInsertPoint, newPort);
// (assert fragment) presumably: any inner symbol on the port must be carried
// over in full by newSym.
286 (pmi.newSym && port.sym.size() == pmi.newSym.size()));
// Annotations are rejected rather than redistributed onto the new ports.
289 if (!port.annotations.empty())
290 return mlir::emitError(port.loc)
291 <<
"annotations on open aggregates not handled yet";
// No HW part: there is nothing an inner symbol could be moved onto.
293 assert(!port.sym && !pmi.newSym);
294 if (!port.annotations.empty())
295 return mlir::emitError(port.loc)
296 <<
"annotations found on aggregate with no HW";
// One new port per non-HW field, named <port name><field suffix>. Its
// direction is the original's, flipped when the field is flipped.
300 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
302 Twine(port.name.strref()) + field.suffix);
304 (
Direction)((
unsigned)port.direction ^ field.isFlip);
305 PortInfo pi(name, field.type, orientation, StringAttr{},
306 port.loc, std::nullopt);
307 newPorts.emplace_back(idxOfInsertPoint, pi);
315 mod.insertPorts(newPorts);
316 recordChanges(!newPorts.empty());
318 assert(mod->getNumRegions() == 1);
// A module-like op whose single region has no blocks has no body to rewrite.
321 auto getBodyBlock = [](
auto mod) {
322 auto &blocks = mod->getRegion(0).getBlocks();
323 return !blocks.empty() ? &blocks.front() :
nullptr;
328 if (
auto *block = getBodyBlock(mod)) {
332 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
// Block arguments still include the original port. The original therefore
// sits at `newIndex` and its replacements follow it, hence the
// pre-increments below.
339 assert(portsToErase.test(newIndex));
340 auto oldPort = block->getArgument(newIndex);
341 auto newPortIndex = newIndex;
345 hwOnlyAggMap[oldPort] =
346 block->getArgument(++newPortIndex);
348 for (
auto &field : pmi.fields) {
349 auto ref =
FieldRef(oldPort, field.fieldID);
350 auto newVal = block->getArgument(++newPortIndex);
351 nonHWValues[ref] = newVal;
// Null entries mark positions with no replacement value.
353 for (
auto fieldID : pmi.mapToNullInteriors) {
354 auto ref =
FieldRef(oldPort, fieldID);
355 assert(!nonHWValues.count(ref));
356 nonHWValues[ref] = {};
// Visit every op in the body in pre-order through the FIRRTL visitor.
365 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"body:\n");
367 ->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
368 return dispatchVisitor(op);
373 assert(opsToErase.empty() || madeChanges());
// Erase queued ops in reverse queueing order, presumably so that users are
// removed before the values they use.
376 for (
auto &op : llvm::reverse(opsToErase))
// Finally, drop the original open-aggregate ports.
381 mod.erasePorts(portsToErase);
382 recordChanges(portsToErase.any());
384 LLVM_DEBUG(refs.printStats(llvm::dbgs()));
/// Rewrite an open_subfield access.
///
/// If the accessed field has an entry in `nonHWValues`, uses are redirected to
/// the split-out value; only a non-null entry provides a replacement (the null
/// case is not visible here).
///
/// Otherwise a SubfieldOp is built on the HW-only aggregate. The field is
/// looked up by name because the HW-only bundle keeps only elements with a HW
/// part, so element indices can shift. The new op is recorded in
/// `hwOnlyAggMap` for nested accesses, and uses are redirected to it when the
/// result is a base type.
389 LogicalResult Visitor::visitExpr(OpenSubfieldOp op) {
415 opsToErase.push_back(op);
420 auto resultRef = refs.getFieldRefFromValue(op.getResult());
421 auto nonHWForResult = nonHWValues.find(resultRef);
422 if (nonHWForResult != nonHWValues.end()) {
424 if (
auto newResult = nonHWForResult->second) {
425 assert(op.getResult().getType() == newResult.getType());
426 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
427 op.getResult().replaceAllUsesWith(newResult);
// NOTE(review): with asserts disabled, a missing entry would be default-inserted
// by operator[] as a null Value; consider lookup() plus an explicit check.
432 assert(hwOnlyAggMap.count(op.getInput()));
434 auto newInput = hwOnlyAggMap[op.getInput()];
437 auto bundleType = type_cast<BundleType>(newInput.getType());
// Re-index by name: the HW-only bundle may have fewer elements.
440 auto fieldName = op.getFieldName();
441 auto newFieldIndex = bundleType.getElementIndex(fieldName);
442 assert(newFieldIndex.has_value());
444 ImplicitLocOpBuilder
builder(op.getLoc(), op);
445 auto newOp =
builder.create<SubfieldOp>(newInput, *newFieldIndex);
446 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
447 newOp->setAttr(
"name", name);
449 hwOnlyAggMap[op.getResult()] = newOp;
451 if (type_isa<FIRRTLBaseType>(op.getType()))
452 op.getResult().replaceAllUsesWith(newOp.getResult());
/// Rewrite an open_subindex access.
///
/// If the accessed element has an entry in `nonHWValues`, uses are redirected
/// to the split-out value; only a non-null entry provides a replacement.
///
/// Otherwise a SubindexOp is built on the HW-only aggregate, reusing the
/// original index. This presumably relies on the HW-only vector keeping the
/// original element count; confirm in mapType. The new op is recorded in
/// `hwOnlyAggMap`, and uses are redirected to it when the result is a base type.
457 LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
462 opsToErase.push_back(op);
467 auto resultRef = refs.getFieldRefFromValue(op.getResult());
468 auto nonHWForResult = nonHWValues.find(resultRef);
469 if (nonHWForResult != nonHWValues.end()) {
471 if (
auto newResult = nonHWForResult->second) {
472 assert(op.getResult().getType() == newResult.getType());
473 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
474 op.getResult().replaceAllUsesWith(newResult);
// NOTE(review): with asserts disabled, a missing entry would be default-inserted
// by operator[] as a null Value; consider lookup() plus an explicit check.
479 assert(hwOnlyAggMap.count(op.getInput()));
481 auto newInput = hwOnlyAggMap[op.getInput()];
484 ImplicitLocOpBuilder
builder(op.getLoc(), op);
485 auto newOp =
builder.create<SubindexOp>(newInput, op.getIndex());
486 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
487 newOp->setAttr(
"name", name);
489 hwOnlyAggMap[op.getResult()] = newOp;
491 if (type_isa<FIRRTLBaseType>(op.getType()))
492 op.getResult().replaceAllUsesWith(newOp.getResult());
/// Rewrite an instance whose results may have open-aggregate types.
///
/// Steps:
/// 1. Map every result with mapType.
/// 2. Insert the new ports (HW-only aggregate, then one per non-HW field)
///    after each original port.
/// 3. Clone the instance with the inserted ports, then erase the originals
///    from that clone to obtain the final instance. Both the intermediate
///    clone and `op` are queued in `opsToErase`.
/// 4. Redirect or record uses of the old results in `hwOnlyAggMap` /
///    `nonHWValues`.
496 LogicalResult Visitor::visitDecl(InstanceOp op) {
499 SmallVector<MappingInfo, 16> portMappings;
501 for (
auto type : op.getResultTypes()) {
502 auto pmi = mapType(type, op.getLoc());
505 portMappings.push_back(*pmi);
// Size of the port list while both original and inserted ports exist.
509 size_t countWithErased = 0;
510 for (
auto &pmi : portMappings)
511 countWithErased += pmi.count(
true);
514 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
517 BitVector portsToErase(countWithErased);
// Debug dump (YAML-like) of the instance and per-port mappings.
521 llvm::dbgs().indent(6) <<
"- instance:\n";
522 llvm::dbgs().indent(10) <<
"name: " << op.getInstanceNameAttr() <<
"\n";
523 llvm::dbgs().indent(10) <<
"module: " << op.getModuleNameAttr() <<
"\n";
524 llvm::dbgs().indent(10) <<
"ports:\n";
528 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
530 llvm::dbgs().indent(12)
531 <<
"- name: " << op.getPortName(index) <<
"\n";
532 llvm::dbgs().indent(14) <<
"type: " << op.getType(index) <<
"\n";
533 llvm::dbgs().indent(14) <<
"mapping:\n";
534 pmi.print(llvm::dbgs(), 16);
535 llvm::dbgs() <<
"\n";
// Replacement ports go immediately after the original port.
539 auto idxOfInsertPoint = index + 1;
545 portsToErase.set(newIndex);
547 auto portName = op.getPortName(index);
548 auto portDirection = op.getPortDirection(index);
549 auto loc = op.getLoc();
// The HW-only port keeps the original name and direction.
553 PortInfo hwPort(portName, pmi.hwType, portDirection,
556 newPorts.emplace_back(idxOfInsertPoint, hwPort);
// Port annotations are rejected rather than redistributed.
559 if (!op.getPortAnnotation(index).empty())
560 return mlir::emitError(op.getLoc())
561 <<
"annotations on open aggregates not handled yet";
563 if (!op.getPortAnnotation(index).empty())
564 return mlir::emitError(op.getLoc())
565 <<
"annotations found on aggregate with no HW";
// One port per non-HW field; direction flips when the field is flipped.
569 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
573 (
Direction)((
unsigned)portDirection ^ field.isFlip);
574 PortInfo pi(name, field.type, orientation, StringAttr{},
576 newPorts.emplace_back(idxOfInsertPoint, pi);
// NOTE(review): the body of this check is not visible; presumably nothing needs
// rewriting when no ports were added.
584 if (newPorts.empty())
// Clone with the inserted ports, then erase the originals to get the final op.
593 auto tempOp = op.cloneAndInsertPorts(newPorts);
594 opsToErase.push_back(tempOp);
595 ImplicitLocOpBuilder
builder(op.getLoc(), op);
596 auto newInst = tempOp.erasePorts(
builder, portsToErase);
// Results of `newInst` no longer include the erased ports, so a result's
// replacements start at `newIndex` itself (post-increments below).
600 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
602 auto oldResult = op.getResult(index);
// Presumably the identity case: the result carries over unchanged.
606 assert(oldResult.getType() == newInst.getType(newIndex));
607 oldResult.replaceAllUsesWith(newInst.getResult(newIndex));
612 auto newPortIndex = newIndex;
614 hwOnlyAggMap[oldResult] = newInst.getResult(newPortIndex++);
616 for (
auto &field : pmi.fields) {
617 auto ref = FieldRef(oldResult, field.fieldID);
618 auto newVal = newInst.getResult(newPortIndex++);
619 assert(newVal.getType() == field.type);
620 nonHWValues[ref] = newVal;
// Null entries mark positions with no replacement value.
622 for (
auto fieldID : pmi.mapToNullInteriors) {
623 auto ref = FieldRef(oldResult, fieldID);
624 assert(!nonHWValues.count(ref));
625 nonHWValues[ref] = {};
629 if (failed(mappingResult))
632 opsToErase.push_back(op);
/// Split a wire of open-aggregate type.
///
/// Creates a wire of the HW-only type, carrying the original name, name kind,
/// annotations and re-targeted inner symbol. Also creates one wire per non-HW
/// field with a droppable name <wire name><field suffix>. The new values are
/// recorded in `hwOnlyAggMap` / `nonHWValues`, and the original wire is queued
/// for erasure.
637 LogicalResult Visitor::visitDecl(WireOp op) {
638 auto pmi = mapType(op.getResultTypes()[0], op.getLoc(), op.getInnerSymAttr());
641 MappingInfo mappings = *pmi;
// Debug dump of the wire and its mapping.
// NOTE(review): "mapping:" is indented 12 while its sibling keys use 10, unlike
// the instance/port dumps where sibling keys align; confirm this is intended.
644 llvm::dbgs().indent(6) <<
"- wire:\n";
645 llvm::dbgs().indent(10) <<
"name: " << op.getNameAttr() <<
"\n";
646 llvm::dbgs().indent(10) <<
"type: " << op.getType(0) <<
"\n";
647 llvm::dbgs().indent(12) <<
"mapping:\n";
648 mappings.print(llvm::dbgs(), 14);
649 llvm::dbgs() <<
"\n";
// Presumably identity mappings leave the wire untouched (branch body not visible).
652 if (mappings.identity)
658 ImplicitLocOpBuilder
builder(op.getLoc(), op);
660 if (!op.getAnnotations().empty())
661 return mlir::emitError(op.getLoc())
662 <<
"annotations on open aggregates not handled yet";
666 hwOnlyAggMap[op.getResult()] =
668 .create<WireOp>(mappings.hwType, op.getName(), op.getNameKind(),
669 op.getAnnotations(), mappings.newSym,
// One droppable-named wire per non-HW field (the flip flag is unused here).
674 for (
auto &[type, fieldID, _, suffix] : mappings.fields)
675 nonHWValues[
FieldRef(op.getResult(), fieldID)] =
677 .create<WireOp>(type,
678 builder.getStringAttr(Twine(op.getName()) + suffix),
679 NameKindEnum::DroppableName)
// Null entries mark positions with no replacement value.
682 for (
auto fieldID : mappings.mapToNullInteriors)
683 nonHWValues[
FieldRef(op.getResult(), fieldID)] = {};
685 opsToErase.push_back(op);
// (Tail of the Visitor::mapType signature.)
// Compute how a value of `type` is split.
//
// Types that are not FIRRTL types, or not OpenBundle/OpenVector, take an
// early path whose body is not visible here (presumably an identity mapping).
// Otherwise `recurse` walks the type:
// - base types are kept as-is;
// - open bundles keep only elements with a HW part, and map to null when none
//   have one;
// - open vectors require a single HW element type across all elements;
// - RefType (and presumably property) leaves become NonHWFields;
// - anything else maps to null.
//
// Inner symbols are re-targeted onto field IDs of the HW-only type
// (`newFieldID`) and rejected when they would land on non-HW parts.
695 hw::InnerSymAttr sym) {
696 MappingInfo pi{
false, {}, {}, {}};
697 auto ftype = type_dyn_cast<FIRRTLType>(type);
699 if (!ftype || !isa<OpenBundleType, OpenVectorType>(ftype)) {
// Re-targeted inner symbol properties, collected during the walk.
704 SmallVector<hw::InnerSymPropertiesAttr> newProps;
// Self-recursive lambda (`f` is the lambda itself). `suffix` and `flip`
// accumulate the naming path and orientation. `fieldID` / `newFieldID` track
// positions in the original and HW-only types respectively.
707 auto recurse = [&](
auto &&f,
FIRRTLType type,
const Twine &suffix =
"",
708 bool flip =
false, uint64_t fieldID = 0,
711 TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
712 .Case<FIRRTLBaseType>([](
auto base) {
return base; })
713 .
template Case<OpenBundleType>([&](OpenBundleType obTy)
715 SmallVector<BundleType::BundleElement> hwElements;
717 for (
const auto &[index, element] :
718 llvm::enumerate(obTy.getElements())) {
720 f(f, element.type, suffix +
"_" + element.name.strref(),
721 flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
722 newFieldID +
id + 1);
726 hwElements.emplace_back(element.name, element.isFlip, *base);
// No element has a HW part: the whole bundle maps to null.
731 if (hwElements.empty()) {
732 pi.mapToNullInteriors.push_back(fieldID);
738 .
template Case<OpenVectorType>([&](OpenVectorType ovTy)
744 for (
auto idx : llvm::seq<size_t>(0U, ovTy.getNumElements())) {
746 f(f, ovTy.getElementType(), suffix +
"_" + Twine(idx),
flip,
747 fieldID + ovTy.getFieldID(idx), newFieldID +
id + 1);
748 if (failed(hwElementType))
// Every element must lower to the same HW element type.
750 assert((!convert || convert == *hwElementType) &&
751 "expected same hw type for all elements");
752 convert = *hwElementType;
// Presumably reached when the element has no HW part.
758 pi.mapToNullInteriors.push_back(fieldID);
// Reference leaves are split out as separate non-HW values.
765 .
template Case<RefType>([&](RefType ref) {
766 auto f = NonHWField{ref, fieldID,
flip, {}};
767 suffix.toVector(f.suffix);
768 pi.fields.emplace_back(std::move(f));
774 auto f = NonHWField{prop, fieldID,
flip, {}};
775 suffix.toVector(f.suffix);
776 pi.fields.emplace_back(std::move(f));
// Any other type has no replacement.
779 .Default([&](
auto _) {
780 pi.mapToNullInteriors.push_back(fieldID);
788 if (
auto symOnThis = sym.getSymIfExists(fieldID)) {
790 return mlir::emitError(errorLoc,
"inner symbol ")
791 << symOnThis <<
" mapped to non-HW type";
// Re-target the symbol onto the corresponding field of the HW-only type.
793 context, symOnThis, newFieldID,
799 auto hwType = recurse(recurse, ftype);
804 assert(pi.hwType != type);
// Every original inner symbol property must have been re-targeted.
808 assert(sym.size() == newProps.size());
810 if (!pi.hwType && !newProps.empty())
811 return mlir::emitError(errorLoc,
"inner symbol on non-HW type");
// Keep inner symbol properties ordered by field ID.
813 llvm::sort(newProps, [](
auto &p,
auto &q) {
814 return p.getFieldID() < q.getFieldID();
/// Pass that lowers open aggregates in every FModuleLike of the operation it
/// runs on.
827 struct LowerOpenAggsPass :
public LowerOpenAggsBase<LowerOpenAggsPass> {
828 LowerOpenAggsPass() =
default;
829 void runOnOperation()
override;
/// Run a fresh Visitor on each FModuleLike in parallel.
/// Per-module state (value maps, erase list) lives in each Visitor, so it is
/// not shared between workers. Only the "made changes" flag is shared, hence
/// the std::atomic.
834 void LowerOpenAggsPass::runOnOperation() {
// Collect the modules up front so the parallel loop works on a fixed list.
836 SmallVector<Operation *, 0> ops(getOperation().getOps<FModuleLike>());
838 LLVM_DEBUG(llvm::dbgs() <<
"Visiting modules:\n");
839 std::atomic<bool> madeChanges =
false;
840 auto result = failableParallelForEach(&getContext(), ops, [&](Operation *op) {
841 Visitor visitor(&getContext());
842 auto result = visitor.visit(cast<FModuleLike>(op));
843 if (visitor.madeChanges())
// Presumably reached only when no module changed (the condition is not
// visible in this listing).
851 markAllAnalysesPreserved();
// Pass constructor: returns a new, default-constructed LowerOpenAggsPass.
856 return std::make_unique<LowerOpenAggsPass>();
assert(baseType &&"element must be base type")
static void dump(DIModule &module, raw_indented_ostream &os)
LogicalResult walkMappings(Range &&range, bool includeErased, llvm::function_ref< LogicalResult(size_t, MappingInfo &, size_t)> callback)
This class represents a reference to a specific field or element of an aggregate value.
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
Caching version of getFieldRefFromValue.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction flip(Direction direction)
Flip a port direction.
Direction
This represents the direction of a single port.
std::unique_ptr< mlir::Pass > createLowerOpenAggsPass()
This is the pass constructor.
uint64_t getMaxFieldID(Type)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
This holds the name and type that describe the module's ports.