25 #include "mlir/IR/BuiltinAttributes.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Threading.h"
28 #include "mlir/IR/Visitors.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/FormatAdapters.h"
32 #include "llvm/Support/FormatVariadic.h"
36 #define DEBUG_TYPE "firrtl-lower-open-aggs"
38 using namespace circt;
39 using namespace firrtl;
52 SmallString<16> suffix;
55 void print(raw_ostream &os,
unsigned indent = 0)
const;
57 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
76 SmallVector<NonHWField, 0> fields;
80 SmallVector<uint64_t, 0> mapToNullInteriors;
82 hw::InnerSymAttr newSym = {};
85 size_t count(
bool includeErased =
false)
const {
88 return fields.size() + (hwType ? 1 : 0) + (includeErased ? 1 : 0);
92 void print(raw_ostream &os,
unsigned indent = 0)
const;
94 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
102 void NonHWField::print(llvm::raw_ostream &os,
unsigned indent)
const {
103 os << llvm::formatv(
"{0}- type: {2}\n"
106 "{1}suffix: \"{5}\"\n",
107 llvm::fmt_pad(
"", indent, 0),
108 llvm::fmt_pad(
"", indent + 2, 0), type, fieldID, isFlip,
111 void MappingInfo::print(llvm::raw_ostream &os,
unsigned indent)
const {
117 os.indent(indent) <<
"hardware: ";
124 os.indent(indent) <<
"non-hardware:\n";
125 for (
auto &field : fields)
126 field.print(os, indent + 2);
128 os.indent(indent) <<
"mappedToNull:\n";
129 for (
auto &
null : mapToNullInteriors)
130 os.indent(indent + 2) <<
"- " <<
null <<
"\n";
132 os.indent(indent) <<
"newSym: ";
139 template <
typename Range>
141 Range &&range,
bool includeErased,
142 llvm::function_ref<LogicalResult(
size_t, MappingInfo &,
size_t)> callback) {
144 for (
const auto &[index, pmi] : llvm::enumerate(range)) {
145 if (failed(callback(index, pmi, count)))
147 count += pmi.count(includeErased);
157 class Visitor :
public FIRRTLVisitor<Visitor, LogicalResult> {
159 explicit Visitor(MLIRContext *context) : context(context){};
162 LogicalResult visit(FModuleLike mod);
168 LogicalResult visitDecl(InstanceOp op);
169 LogicalResult visitDecl(WireOp op);
171 LogicalResult visitExpr(OpenSubfieldOp op);
172 LogicalResult visitExpr(OpenSubindexOp op);
174 LogicalResult visitUnhandledOp(Operation *op) {
175 auto notOpenAggType = [](
auto type) {
176 return !isa<OpenBundleType, OpenVectorType>(type);
178 if (!llvm::all_of(op->getOperandTypes(), notOpenAggType) ||
179 !llvm::all_of(op->getResultTypes(), notOpenAggType))
180 return op->emitOpError(
181 "unhandled use or producer of types containing non-hw types");
185 LogicalResult visitInvalidOp(Operation *op) {
return visitUnhandledOp(op); }
188 bool madeChanges()
const {
return changesMade; }
195 hw::InnerSymAttr sym = {});
198 void recordChanges(
bool changed) {
203 MLIRContext *context;
208 DenseMap<FieldRef, Value> nonHWValues;
211 DenseMap<Value, Value> hwOnlyAggMap;
214 SmallVector<Operation *> opsToErase;
221 bool changesMade =
false;
225 LogicalResult Visitor::visit(FModuleLike mod) {
226 auto ports = mod.getPorts();
228 SmallVector<MappingInfo, 16> portMappings;
229 for (
auto &port : ports) {
230 auto pmi = mapType(port.type, port.loc, port.sym);
233 portMappings.push_back(*pmi);
238 size_t countWithErased = 0;
239 for (
auto &pmi : portMappings)
240 countWithErased += pmi.count(
true);
243 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
246 BitVector portsToErase(countWithErased);
251 << cast<mlir::SymbolOpInterface>(*mod).getNameAttr()
257 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
259 llvm::dbgs().indent(6) <<
"- name: " << ports[index].name <<
"\n";
260 llvm::dbgs().indent(8) <<
"type: " << ports[index].type <<
"\n";
267 auto idxOfInsertPoint = index + 1;
272 auto &port = ports[index];
275 portsToErase.set(newIndex);
280 newPort.type = pmi.hwType;
281 newPort.sym = pmi.newSym;
282 newPorts.emplace_back(idxOfInsertPoint, newPort);
285 (pmi.newSym && port.sym.size() == pmi.newSym.size()));
288 if (!port.annotations.empty())
289 return mlir::emitError(port.loc)
290 <<
"annotations on open aggregates not handled yet";
292 assert(!port.sym && !pmi.newSym);
293 if (!port.annotations.empty())
294 return mlir::emitError(port.loc)
295 <<
"annotations found on aggregate with no HW";
299 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
301 Twine(port.name.strref()) + field.suffix);
303 (
Direction)((
unsigned)port.direction ^ field.isFlip);
304 PortInfo pi(name, field.type, orientation, StringAttr{},
305 port.loc, std::nullopt);
306 newPorts.emplace_back(idxOfInsertPoint, pi);
314 mod.insertPorts(newPorts);
315 recordChanges(!newPorts.empty());
317 assert(mod->getNumRegions() == 1);
320 auto getBodyBlock = [](
auto mod) {
321 auto &blocks = mod->getRegion(0).getBlocks();
322 return !blocks.empty() ? &blocks.front() :
nullptr;
327 if (
auto *block = getBodyBlock(mod)) {
331 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
338 assert(portsToErase.test(newIndex));
339 auto oldPort = block->getArgument(newIndex);
340 auto newPortIndex = newIndex;
344 hwOnlyAggMap[oldPort] =
345 block->getArgument(++newPortIndex);
347 for (
auto &field : pmi.fields) {
348 auto ref =
FieldRef(oldPort, field.fieldID);
349 auto newVal = block->getArgument(++newPortIndex);
350 nonHWValues[ref] = newVal;
352 for (
auto fieldID : pmi.mapToNullInteriors) {
353 auto ref =
FieldRef(oldPort, fieldID);
354 assert(!nonHWValues.count(ref));
355 nonHWValues[ref] = {};
364 LLVM_DEBUG(
llvm::dbgs().indent(4) <<
"body:\n");
366 ->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
367 return dispatchVisitor(op);
372 assert(opsToErase.empty() || madeChanges());
375 for (
auto &op : llvm::reverse(opsToErase))
380 mod.erasePorts(portsToErase);
381 recordChanges(portsToErase.any());
388 LogicalResult Visitor::visitExpr(OpenSubfieldOp op) {
414 opsToErase.push_back(op);
419 auto resultRef = refs.getFieldRefFromValue(op.getResult());
420 auto nonHWForResult = nonHWValues.find(resultRef);
421 if (nonHWForResult != nonHWValues.end()) {
423 if (
auto newResult = nonHWForResult->second) {
424 assert(op.getResult().getType() == newResult.getType());
425 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
426 op.getResult().replaceAllUsesWith(newResult);
431 assert(hwOnlyAggMap.count(op.getInput()));
433 auto newInput = hwOnlyAggMap[op.getInput()];
436 auto bundleType = type_cast<BundleType>(newInput.getType());
439 auto fieldName = op.getFieldName();
440 auto newFieldIndex = bundleType.getElementIndex(fieldName);
441 assert(newFieldIndex.has_value());
443 ImplicitLocOpBuilder
builder(op.getLoc(), op);
444 auto newOp =
builder.create<SubfieldOp>(newInput, *newFieldIndex);
445 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
446 newOp->setAttr(
"name", name);
448 hwOnlyAggMap[op.getResult()] = newOp;
450 if (type_isa<FIRRTLBaseType>(op.getType()))
451 op.getResult().replaceAllUsesWith(newOp.getResult());
456 LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
461 opsToErase.push_back(op);
466 auto resultRef = refs.getFieldRefFromValue(op.getResult());
467 auto nonHWForResult = nonHWValues.find(resultRef);
468 if (nonHWForResult != nonHWValues.end()) {
470 if (
auto newResult = nonHWForResult->second) {
471 assert(op.getResult().getType() == newResult.getType());
472 assert(!type_isa<FIRRTLBaseType>(newResult.getType()));
473 op.getResult().replaceAllUsesWith(newResult);
478 assert(hwOnlyAggMap.count(op.getInput()));
480 auto newInput = hwOnlyAggMap[op.getInput()];
483 ImplicitLocOpBuilder
builder(op.getLoc(), op);
484 auto newOp =
builder.create<SubindexOp>(newInput, op.getIndex());
485 if (
auto name = op->getAttrOfType<StringAttr>(
"name"))
486 newOp->setAttr(
"name", name);
488 hwOnlyAggMap[op.getResult()] = newOp;
490 if (type_isa<FIRRTLBaseType>(op.getType()))
491 op.getResult().replaceAllUsesWith(newOp.getResult());
495 LogicalResult Visitor::visitDecl(InstanceOp op) {
498 SmallVector<MappingInfo, 16> portMappings;
500 for (
auto type : op.getResultTypes()) {
501 auto pmi = mapType(type, op.getLoc());
504 portMappings.push_back(*pmi);
508 size_t countWithErased = 0;
509 for (
auto &pmi : portMappings)
510 countWithErased += pmi.count(
true);
513 SmallVector<std::pair<unsigned, PortInfo>> newPorts;
516 BitVector portsToErase(countWithErased);
521 llvm::dbgs().indent(10) <<
"name: " << op.getInstanceNameAttr() <<
"\n";
522 llvm::dbgs().indent(10) <<
"module: " << op.getModuleNameAttr() <<
"\n";
527 [&](
auto index,
auto &pmi,
auto newIndex) -> LogicalResult {
530 <<
"- name: " << op.getPortName(index) <<
"\n";
531 llvm::dbgs().indent(14) <<
"type: " << op.getType(index) <<
"\n";
538 auto idxOfInsertPoint = index + 1;
544 portsToErase.set(newIndex);
546 auto portName = op.getPortName(index);
547 auto portDirection = op.getPortDirection(index);
548 auto loc = op.getLoc();
552 PortInfo hwPort(portName, pmi.hwType, portDirection,
555 newPorts.emplace_back(idxOfInsertPoint, hwPort);
558 if (!op.getPortAnnotation(index).empty())
559 return mlir::emitError(op.getLoc())
560 <<
"annotations on open aggregates not handled yet";
562 if (!op.getPortAnnotation(index).empty())
563 return mlir::emitError(op.getLoc())
564 <<
"annotations found on aggregate with no HW";
568 for (
const auto &[findex, field] : llvm::enumerate(pmi.fields)) {
572 (
Direction)((
unsigned)portDirection ^ field.isFlip);
573 PortInfo pi(name, field.type, orientation, StringAttr{},
575 newPorts.emplace_back(idxOfInsertPoint, pi);
583 if (newPorts.empty())
592 auto tempOp = op.cloneAndInsertPorts(newPorts);
593 opsToErase.push_back(tempOp);
594 ImplicitLocOpBuilder
builder(op.getLoc(), op);
595 auto newInst = tempOp.erasePorts(
builder, portsToErase);
599 [&](
auto index, MappingInfo &pmi,
auto newIndex) {
601 auto oldResult = op.getResult(index);
605 assert(oldResult.getType() == newInst.getType(newIndex));
606 oldResult.replaceAllUsesWith(newInst.getResult(newIndex));
611 auto newPortIndex = newIndex;
613 hwOnlyAggMap[oldResult] = newInst.getResult(newPortIndex++);
615 for (
auto &field : pmi.fields) {
616 auto ref = FieldRef(oldResult, field.fieldID);
617 auto newVal = newInst.getResult(newPortIndex++);
618 assert(newVal.getType() == field.type);
619 nonHWValues[ref] = newVal;
621 for (
auto fieldID : pmi.mapToNullInteriors) {
622 auto ref = FieldRef(oldResult, fieldID);
623 assert(!nonHWValues.count(ref));
624 nonHWValues[ref] = {};
628 if (failed(mappingResult))
631 opsToErase.push_back(op);
636 LogicalResult Visitor::visitDecl(WireOp op) {
637 auto pmi = mapType(op.getResultTypes()[0], op.getLoc(), op.getInnerSymAttr());
640 MappingInfo mappings = *pmi;
644 llvm::dbgs().indent(10) <<
"name: " << op.getNameAttr() <<
"\n";
645 llvm::dbgs().indent(10) <<
"type: " << op.getType(0) <<
"\n";
651 if (mappings.identity)
657 ImplicitLocOpBuilder
builder(op.getLoc(), op);
659 if (!op.getAnnotations().empty())
660 return mlir::emitError(op.getLoc())
661 <<
"annotations on open aggregates not handled yet";
665 hwOnlyAggMap[op.getResult()] =
667 .create<WireOp>(mappings.hwType, op.getName(), op.getNameKind(),
668 op.getAnnotations(), mappings.newSym,
673 for (
auto &[type, fieldID, _, suffix] : mappings.fields)
674 nonHWValues[
FieldRef(op.getResult(), fieldID)] =
676 .create<WireOp>(type,
677 builder.getStringAttr(Twine(op.getName()) + suffix),
678 NameKindEnum::DroppableName)
681 for (
auto fieldID : mappings.mapToNullInteriors)
682 nonHWValues[
FieldRef(op.getResult(), fieldID)] = {};
684 opsToErase.push_back(op);
694 hw::InnerSymAttr sym) {
695 MappingInfo pi{
false, {}, {}, {}};
696 auto ftype = type_dyn_cast<FIRRTLType>(type);
698 if (!ftype || !isa<OpenBundleType, OpenVectorType>(ftype)) {
703 SmallVector<hw::InnerSymPropertiesAttr> newProps;
706 auto recurse = [&](
auto &&f,
FIRRTLType type,
const Twine &suffix =
"",
707 bool flip =
false, uint64_t fieldID = 0,
710 TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
711 .Case<FIRRTLBaseType>([](
auto base) {
return base; })
712 .
template Case<OpenBundleType>([&](OpenBundleType obTy)
714 SmallVector<BundleType::BundleElement> hwElements;
716 for (
const auto &[index, element] :
717 llvm::enumerate(obTy.getElements())) {
719 f(f, element.type, suffix +
"_" + element.name.strref(),
720 flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
721 newFieldID +
id + 1);
725 hwElements.emplace_back(element.name, element.isFlip, *base);
730 if (hwElements.empty()) {
731 pi.mapToNullInteriors.push_back(fieldID);
737 .
template Case<OpenVectorType>([&](OpenVectorType ovTy)
743 for (
auto idx : llvm::seq<size_t>(0U, ovTy.getNumElements())) {
745 f(f, ovTy.getElementType(), suffix +
"_" + Twine(idx),
flip,
746 fieldID + ovTy.getFieldID(idx), newFieldID +
id + 1);
747 if (failed(hwElementType))
749 assert((!convert || convert == *hwElementType) &&
750 "expected same hw type for all elements");
751 convert = *hwElementType;
757 pi.mapToNullInteriors.push_back(fieldID);
764 .
template Case<RefType>([&](RefType ref) {
765 auto f = NonHWField{ref, fieldID,
flip, {}};
766 suffix.toVector(f.suffix);
767 pi.fields.emplace_back(std::move(f));
773 auto f = NonHWField{prop, fieldID,
flip, {}};
774 suffix.toVector(f.suffix);
775 pi.fields.emplace_back(std::move(f));
778 .Default([&](
auto _) {
779 pi.mapToNullInteriors.push_back(fieldID);
787 if (
auto symOnThis = sym.getSymIfExists(fieldID)) {
789 return mlir::emitError(errorLoc,
"inner symbol ")
790 << symOnThis <<
" mapped to non-HW type";
792 context, symOnThis, newFieldID,
798 auto hwType = recurse(recurse, ftype);
803 assert(pi.hwType != type);
807 assert(sym.size() == newProps.size());
809 if (!pi.hwType && !newProps.empty())
810 return mlir::emitError(errorLoc,
"inner symbol on non-HW type");
812 llvm::sort(newProps, [](
auto &p,
auto &q) {
813 return p.getFieldID() < q.getFieldID();
826 struct LowerOpenAggsPass :
public LowerOpenAggsBase<LowerOpenAggsPass> {
827 LowerOpenAggsPass() =
default;
828 void runOnOperation()
override;
833 void LowerOpenAggsPass::runOnOperation() {
834 LLVM_DEBUG(
llvm::dbgs() <<
"===- Running Lower Open Aggregates Pass "
835 "-------------------------------------===\n");
836 SmallVector<Operation *, 0> ops(getOperation().getOps<FModuleLike>());
838 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting modules:\n");
839 std::atomic<bool> madeChanges =
false;
840 auto result = failableParallelForEach(&getContext(), ops, [&](Operation *op) {
841 Visitor visitor(&getContext());
842 auto result = visitor.visit(cast<FModuleLike>(op));
843 if (visitor.madeChanges())
851 markAllAnalysesPreserved();
856 return std::make_unique<LowerOpenAggsPass>();
assert(baseType &&"element must be base type")
static void dump(DIVariable &variable, raw_indented_ostream &os)
LogicalResult walkMappings(Range &&range, bool includeErased, llvm::function_ref< LogicalResult(size_t, MappingInfo &, size_t)> callback)
This class represents a reference to a specific field or element of an aggregate value.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
Caching version of getFieldRefFromValue.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction flip(Direction direction)
Flip a port direction.
Direction
This represents the direction of a single port.
std::unique_ptr< mlir::Pass > createLowerOpenAggsPass()
This is the pass constructor.
uint64_t getMaxFieldID(Type)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & errs()
mlir::raw_indented_ostream & dbgs()
This holds the name and type that describes the module's ports.