39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/Support/Debug.h"
46#define DEBUG_TYPE "firrtl-lower-types"
50#define GEN_PASS_DEF_LOWERFIRRTLTYPES
51#include "circt/Dialect/FIRRTL/Passes.h.inc"
56using namespace firrtl;
// One element produced by peeling an aggregate type one level. It records:
// - the element's type;
// - its index within the parent;
// - its field ID;
// - the name suffix appended to the parent's name;
// - an isOutput flag (connect lowering swaps source and destination for
//   entries with isOutput set).
61struct FlatBundleFieldEntry {
69 SmallString<16> suffix;
74 unsigned fieldID, StringRef suffix,
bool isOutput)
75 : type(type), index(index), fieldID(fieldID), suffix(suffix),
// Debug helper: prints every member of the entry to llvm::errs().
79 llvm::errs() <<
"FBFE{" << type <<
" index<" << index <<
"> fieldID<"
80 << fieldID <<
"> suffix<" << suffix <<
"> isOutput<"
81 << isOutput <<
">}\n";
// NOTE(review): this span holds fragments of several type helpers.
// Rebuild `type` with its base type replaced by `fieldType` via mapBaseType.
// The callback ignores the original base type.
88 return mapBaseType(type, [&](
auto) {
return fieldType; });
93 auto ftype = type_dyn_cast<FIRRTLType>(type);
// Type predicate:
// - bundles yield false;
// - a vector yields true only when its element type is ground and it has
//   more than one element;
// - every other type yields true.
102 .
Case<BundleType>([&](
auto bundle) {
return false; })
103 .Case<FVectorType>([&](FVectorType vector) {
105 return vector.getElementType().isGround() &&
106 vector.getNumElements() > 1;
108 .Default([](
auto groundType) {
return true; });
// Type predicate: bundles yield true; types that are neither bundles nor
// vectors yield false. The vector case body is not shown here.
115 .
Case<BundleType>([&](
auto bundle) {
return true; })
116 .Case<FVectorType>([&](FVectorType vector) {
119 .Default([](
auto groundType) {
return false; });
// Preservation check:
// - reference types are decided first, with forceable refs special-cased;
// - for base types, non-passive or analog-containing types are tested
//   before the mode-specific rules;
// - reaching an unknown mode is treated as unreachable.
126 if (
auto refType = type_dyn_cast<RefType>(type)) {
128 if (refType.getForceable())
139 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
145 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
157 llvm_unreachable(
"unexpected mode");
// Splits an aggregate into its immediate children, appending one
// FlatBundleFieldEntry per child to `fields`.
// - A RefType is first unwrapped to its inner type.
// - Bundle children get the suffix "_<element name>" and the bundle's
//   per-element field ID.
// - Vector children get the suffix "_<index>" and isOutput = false.
// - Any other type hits the Default case, which returns false.
163static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
170 if (
auto refType = type_dyn_cast<RefType>(type))
171 type = refType.getType();
173 .
Case<BundleType>([&](
auto bundle) {
174 SmallString<16> tmpSuffix;
176 for (
size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
177 auto elt = bundle.getElement(i);
180 tmpSuffix.push_back(
'_');
181 tmpSuffix.append(elt.name.getValue());
182 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
187 .Case<FVectorType>([&](
auto vector) {
189 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
190 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
191 "_" + std::to_string(i),
false);
195 .Default([](
auto op) {
return false; });
// Subaccess check: `arg` is the ConstantOp producing the subaccess index (if
// any). The result is true only when the index is such a constant and the
// accessed vector has at least one element.
201 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
205 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
206 return arg && sao.getInput().getType().base().getNumElements() != 0;
// Collects the accessor chain feeding operand 0 of `op`. It walks upward
// through SubfieldOp / SubindexOp / SubaccessOp producers and records each
// one in `retval`, innermost first.
211 SmallVector<Operation *> retval;
212 auto defOp = op->getOperand(0).getDefiningOp();
213 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
214 retval.push_back(defOp);
215 defOp = defOp->getOperand(0).getDefiningOp();
// Clones memory `op` for a single data field:
// - Port types are rebuilt with MemOp::getTypeForPort using `field.type`.
// - The new memory is named op's name + field.suffix.
// - The "circt.fieldID" values in each port's annotations are then remapped
//   from the old port type to the new one.
// - An error is emitted if the original memory has an inner symbol.
225 FlatBundleFieldEntry field) {
226 SmallVector<Type, 8> ports;
227 SmallVector<Attribute, 8> portNames;
228 SmallVector<Attribute, 8> portLocations;
230 auto oldPorts = op.getPorts();
231 for (
size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
232 auto port = oldPorts[portIdx];
234 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
235 portNames.push_back(port.first);
240 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
241 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
242 (op.getName() + field.suffix).str(), op.getNameKind(),
243 op.getAnnotations(), op.getPortAnnotations(),
244 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
246 if (op.getInnerSym()) {
247 op.emitError(
"cannot split memory with symbol present");
251 SmallVector<Attribute> newAnnotations;
252 for (
size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
253 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
254 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
255 SmallVector<Attribute> portAnno;
256 for (
auto attr : newMem.getPortAnnotation(portIdx)) {
// Find which port element (in the old port type) the annotation targets.
259 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
263 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
266 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
267 portAnno.push_back(anno.
getDict());
// Inside a bundle-typed port element, an annotation is kept only when it
// falls at or after `field`'s position. Its field ID is rebased onto the
// new port type.
272 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
277 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
278 if (annoFieldID >= fieldID &&
283 annoFieldID - fieldID + portType.getFieldID(targetIndex);
284 anno.
setMember(
"circt.fieldID", b->getI32IntegerAttr(newFieldID));
285 portAnno.push_back(anno.
getDict());
289 portAnno.push_back(attr);
291 newAnnotations.push_back(b->getArrayAttr(portAnno));
293 newMem.setAllPortAnnotations(newAnnotations);
// Cache of uniqued attributes and types used by the pass. It holds:
// - a 64-bit integer type;
// - the StringAttr names of the op attributes the pass reads and rewrites
//   ("name", "nameKind" and the module port arrays);
// - an empty string and an empty array.
// Everything is built once from the given MLIRContext.
303 AttrCache(MLIRContext *
context) {
304 i64ty = IntegerType::get(
context, 64);
305 nameAttr = StringAttr::get(
context,
"name");
306 nameKindAttr = StringAttr::get(
context,
"nameKind");
307 sPortDirections = StringAttr::get(
context,
"portDirections");
308 sPortNames = StringAttr::get(
context,
"portNames");
309 sPortTypes = StringAttr::get(
context,
"portTypes");
310 sPortSymbols = StringAttr::get(
context,
"portSymbols");
311 sPortLocations = StringAttr::get(
context,
"portLocations");
312 sPortAnnotations = StringAttr::get(
context,
"portAnnotations");
// Per-port domain associations live under the "domainInfo" attribute name.
313 sPortDomains = StringAttr::get(
context,
"domainInfo");
314 sEmpty = StringAttr::get(
context,
"");
315 aEmpty = ArrayAttr::get(
context, {});
317 AttrCache(
const AttrCache &) =
default;
320 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
321 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
// Renumbers port domain associations after ports are split, inserted or
// removed.
// - domainIndexByOrdinal holds, in order, the original index of each domain
//   port.
// - computeDomainMap maps each original index to the new position of the
//   matching (same ordinal) DomainType entry.
// - rewriteDomain translates an ArrayAttr of original indices into new ones.
328class DomainLoweringHelper {
// Records the original index of every port whose type attribute is a
// DomainType.
333 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
335 for (
auto [index, typeAttr] :
llvm::enumerate(portTypes))
336 if (
type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
337 domainIndexByOrdinal.push_back(index);
// Overload taking result types. The filtering condition sits on a line not
// shown here; presumably it also selects DomainType entries (confirm).
341 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
343 for (
auto [index, type] :
llvm::enumerate(resultTypes))
345 domainIndexByOrdinal.push_back(index);
// Maps the original index of the k-th domain entry to the position `i` at
// which the k-th DomainType appears in the new list. There is one overload
// for raw types and one for PortInfo lists.
352 void computeDomainMap(TypeRange types) {
353 size_t i = 0, ord = 0;
354 for (
auto type : types) {
355 if (type_isa<DomainType>(type))
356 domainMap[domainIndexByOrdinal[ord++]] = i;
365 void computeDomainMap(ArrayRef<PortInfo> ports) {
366 size_t i = 0, ord = 0;
367 for (
const auto &port : ports) {
368 if (type_isa<DomainType>(port.type))
369 domainMap[domainIndexByOrdinal[ord++]] = i;
// Rewrites `domain` in place when it is an ArrayAttr of integer port
// indices. Each index is replaced by its new position, as a 32-bit unsigned
// IntegerAttr.
// NOTE(review): DenseMap::operator[] inserts 0 for an index missing from
// domainMap, so an association to an unrecorded port silently becomes index 0.
377 void rewriteDomain(Attribute &domain) {
378 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
379 if (!oldAssociations)
381 SmallVector<Attribute> newAssociations;
382 for (
auto oldAttr : oldAssociations)
383 newAssociations.push_back(IntegerAttr::
get(
384 IntegerType::
get(context, 32, IntegerType::Unsigned),
385 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
386 domain = ArrayAttr::get(context, newAssociations);
390 MLIRContext *context;
// Original indices of the domain ports, in port order.
392 SmallVector<unsigned> domainIndexByOrdinal;
// Original domain-port index -> new port index.
394 DenseMap<unsigned, unsigned> domainMap;
// FIRRTL visitor that lowers aggregate-typed values, ports and declarations
// into one value per peeled field.
// - Each visit method returns a bool, which lowerBlock stores as `removeOp`.
// - Failures are recorded in encounteredError and reported via isFailed().
399struct TypeLoweringVisitor :
public FIRRTLVisitor<TypeLoweringVisitor, bool> {
403 Convention bodyConvention,
405 SymbolTable &symTbl,
const AttrCache &cache,
406 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
407 : context(context), defaultAggregatePreservationMode(preserveAggregate),
408 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
409 cache(cache), conventionTable(conventionTable) {
// The body's preservation mode depends on whether the module body uses the
// Scalarized convention; otherwise the default mode applies.
410 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
412 : defaultAggregatePreservationMode;
420 void lowerModule(FModuleLike op);
422 bool lowerArg(FModuleLike module,
size_t argIndex,
size_t argsRemoved,
423 SmallVectorImpl<PortInfo> &newArgs,
424 SmallVectorImpl<Value> &lowering);
425 std::pair<Value, PortInfo> addArg(Operation *module,
unsigned insertPt,
427 const FlatBundleFieldEntry &field,
428 PortInfo &oldArg, hw::InnerSymAttr newSym);
431 bool visitDecl(FExtModuleOp op);
432 bool visitDecl(FModuleOp op);
433 bool visitDecl(InstanceOp op);
434 bool visitDecl(MemOp op);
435 bool visitDecl(NodeOp op);
436 bool visitDecl(RegOp op);
437 bool visitDecl(WireOp op);
438 bool visitDecl(RegResetOp op);
439 bool visitExpr(InvalidValueOp op);
440 bool visitExpr(SubaccessOp op);
441 bool visitExpr(VectorCreateOp op);
442 bool visitExpr(BundleCreateOp op);
443 bool visitExpr(ElementwiseAndPrimOp op);
444 bool visitExpr(ElementwiseOrPrimOp op);
445 bool visitExpr(ElementwiseXorPrimOp op);
446 bool visitExpr(MultibitMuxOp op);
447 bool visitExpr(MuxPrimOp op);
448 bool visitExpr(Mux2CellIntrinsicOp op);
449 bool visitExpr(Mux4CellIntrinsicOp op);
450 bool visitExpr(BitCastOp op);
451 bool visitExpr(RefSendOp op);
452 bool visitExpr(RefResolveOp op);
453 bool visitExpr(RefCastOp op);
454 bool visitStmt(ConnectOp op);
455 bool visitStmt(MatchingConnectOp op);
456 bool visitStmt(RefDefineOp op);
457 bool visitStmt(WhenOp op);
458 bool visitStmt(LayerBlockOp op);
459 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
// True once any lowering step has reported an error.
461 bool isFailed()
const {
return encounteredError; }
// Unrealized conversion casts are routed to their dedicated handler.
464 if (
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
465 return visitUnrealizedConversionCast(castOp);
470 void processUsers(Value val, ArrayRef<Value> mapping);
471 bool processSAPath(Operation *);
472 void lowerBlock(Block *);
473 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
483 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
488 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
489 FIRRTLType srcType, FlatBundleFieldEntry field);
493 LogicalResult partitionSymbols(hw::InnerSymAttr sym,
FIRRTLType parentType,
494 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
498 getPreservationModeForPorts(FModuleLike moduleLike);
499 Value getSubWhatever(Value val,
size_t index);
// Counter backing uniqueName(), which yields "__GEN_0", "__GEN_1", ...
501 size_t uniqueIdx = 0;
502 std::string uniqueName() {
503 auto myID = uniqueIdx++;
504 return (Twine(
"__GEN_") + Twine(myID)).str();
// Builder for the op being lowered. Module visitors point it at a local
// ImplicitLocOpBuilder.
515 ImplicitLocOpBuilder *builder;
521 const AttrCache &cache;
523 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
526 bool encounteredError =
false;
// Picks the aggregate-preservation mode for a module's ports from its entry in
// conventionTable. Modules without an entry, and Internal-convention modules,
// use defaultAggregatePreservationMode. An unknown convention is unreachable.
533TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
534 auto lookup = conventionTable.find(module);
535 if (lookup == conventionTable.end())
536 return defaultAggregatePreservationMode;
537 switch (lookup->second) {
538 case Convention::Scalarized:
540 case Convention::Internal:
541 return defaultAggregatePreservationMode;
543 llvm_unreachable(
"Unknown convention");
544 return defaultAggregatePreservationMode;
// Emits an accessor for element `index` of `val`:
// - SubfieldOp for bundles;
// - SubindexOp for vectors;
// - RefSubOp for references.
// Any other type is a programming error.
547Value TypeLoweringVisitor::getSubWhatever(Value val,
size_t index) {
548 if (type_isa<BundleType>(val.getType()))
549 return SubfieldOp::create(*builder, val, index);
550 if (type_isa<FVectorType>(val.getType()))
551 return SubindexOp::create(*builder, val, index);
552 if (type_isa<RefType>(val.getType()))
553 return RefSubOp::create(*builder, val, index);
554 llvm_unreachable(
"Unknown aggregate type");
// Handles a connect-like `op` whose destination goes through a subaccess.
// If `writePath` is non-empty, it:
// 1. expands the write with lowerSAWritePath;
// 2. detaches the op's first two operands;
// 3. erases accessors in the write path that are left without uses.
559bool TypeLoweringVisitor::processSAPath(Operation *op) {
562 if (writePath.empty())
565 lowerSAWritePath(op, writePath);
568 op->eraseOperands(0, 2);
570 for (
size_t i = 0; i < writePath.size(); ++i) {
571 if (writePath[i]->use_empty()) {
572 writePath[i]->erase();
// Visits the operations of `block` back to front. Before dispatching each op,
// the builder's insertion point and location are set to that op. The dispatch
// result is kept as `removeOp`.
580void TypeLoweringVisitor::lowerBlock(Block *block) {
582 for (
auto it = block->rbegin(), e = block->rend(); it != e;) {
584 builder->setInsertionPoint(&iop);
585 builder->setLoc(iop.getLoc());
586 bool removeOp = dispatchVisitor(&iop);
// Returns the annotations that apply to `field`.
// - A null or empty input yields an empty array.
// - "circt.fieldID" is removed from each annotation.
// - Annotations whose field ID lies before field.fieldID are dropped.
// - Kept annotations get "circt.fieldID" re-set relative to the field, but
//   only when the relative ID is non-zero.
595ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
596 ArrayAttr annotations,
598 FlatBundleFieldEntry field) {
599 SmallVector<Attribute> retval;
600 if (!annotations || annotations.empty())
601 return ArrayAttr::get(ctxt, retval);
602 for (
auto opAttr : annotations) {
604 auto fieldID = anno.getFieldID();
605 anno.removeMember(
"circt.fieldID");
610 retval.push_back(anno.getAttr());
615 if (fieldID < field.fieldID ||
620 if (
auto newFieldID = fieldID - field.fieldID) {
623 anno.setMember(
"circt.fieldID", builder->getI32IntegerAttr(newFieldID));
626 retval.push_back(anno.getAttr());
628 return ArrayAttr::get(ctxt, retval);
// Distributes the inner-symbol properties of an aggregate-typed value across
// its immediate children.
// - Each property's field ID is split into (child index, field ID relative
//   to that child).
// - Properties are stably sorted by that pair and grouped into one
//   InnerSymAttr per child in `newSyms`.
// - A null or empty symbol returns early.
// - Unsupported types, and properties whose target is not preserved, emit an
//   error at `errorLoc`.
631LogicalResult TypeLoweringVisitor::partitionSymbols(
633 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
636 if (!sym || sym.empty())
639 auto *
context = sym.getContext();
643 return mlir::emitError(errorLoc,
644 "unable to partition symbol on unsupported type ")
647 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
648 .Case<BundleType, FVectorType>([&](
auto aggType) -> LogicalResult {
652 hw::InnerSymPropertiesAttr prop;
656 SmallVector<BinningInfo> binning;
657 for (
auto prop : sym) {
658 auto fieldID = prop.getFieldID();
661 return mlir::emitError(errorLoc,
"unable to lower due to symbol ")
663 <<
" with target not preserved by lowering";
664 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
665 binning.push_back({index, relFieldID, prop});
669 llvm::stable_sort(binning, [&](
auto &lhs,
auto &rhs) {
670 return std::tuple(lhs.index, lhs.relFieldID) <
671 std::tuple(rhs.index, rhs.relFieldID);
// One InnerSymAttr slot per child. Children without properties keep a
// default-constructed (null) attribute.
676 newSyms.resize(aggType.getNumElements());
677 for (
auto binIt = binning.begin(), binEnd = binning.end();
679 auto curIndex = binIt->index;
680 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
682 while (binIt != binEnd && binIt->index == curIndex) {
683 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
684 context, binIt->prop.getName(), binIt->relFieldID,
685 binIt->prop.getSymVisibility()));
// Sorting groups each child's properties together, so each slot is filled
// at most once.
689 assert(!newSyms[curIndex]);
690 newSyms[curIndex] = hw::InnerSymAttr::get(
context, propsForIndex);
694 .Default([&](
auto ty) {
695 return mlir::emitError(
696 errorLoc,
"unable to partition symbol on unsupported type ")
// Generic lowering for a single-result op of aggregate type:
// 1. Peel the result type with the body preservation mode.
// 2. For each field, call `clone` with that field's filtered annotations.
// 3. Give each new op its partitioned inner symbol, the derived name
//    (original name + field.suffix), the original name kind and the
//    original discardable attributes.
// 4. Redirect users of the old result via processUsers.
701bool TypeLoweringVisitor::lowerProducer(
703 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
707 srcType = op->getResult(0).getType();
708 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
711 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
713 if (!
peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
716 SmallVector<Value> lowered;
718 SmallString<16> loweredName;
719 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
721 if (
auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
722 loweredName = nameAttr.getValue();
723 auto baseNameLen = loweredName.size();
724 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr(
"annotations"));
726 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
727 if (
auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
728 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
730 encounteredError =
true;
735 for (
const auto &[field, sym] :
llvm::zip_equal(fieldTypes, fieldSyms)) {
// Reuse one name buffer: truncate to the base name, then append this
// field's suffix.
736 if (!loweredName.empty()) {
737 loweredName.resize(baseNameLen);
738 loweredName += field.suffix;
743 ArrayAttr loweredAttrs =
744 filterAnnotations(
context, oldAnno, srcFType, field);
745 auto newVal = clone(field, loweredAttrs);
// A field that carries a symbol must be produced by an op that can hold an
// inner symbol.
751 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
754 "op with inner symbol lowered to op that cannot take inner symbol");
755 newSymOp.setInnerSymbolAttr(sym);
759 if (
auto *newOp = newVal.getDefiningOp()) {
760 if (!loweredName.empty())
761 newOp->setAttr(cache.nameAttr, StringAttr::get(
context, loweredName));
763 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
766 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
768 lowered.push_back(newVal);
771 processUsers(op->getResult(0), lowered);
// Replaces uses of aggregate `val` with the per-field values in `mapping`.
// - SubindexOp, SubfieldOp and RefSubOp users are forwarded to the mapped
//   value for their index.
// - Any other user gets the aggregate rebuilt from `mapping` with
//   VectorCreateOp or BundleCreateOp just before it.
// - An error is emitted if the type cannot be rebuilt.
775void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
776 for (
auto *user :
llvm::make_early_inc_range(val.getUsers())) {
777 TypeSwitch<Operation *, void>(user)
778 .Case<SubindexOp>([mapping](SubindexOp sio) {
779 Value repl = mapping[sio.getIndex()];
780 sio.replaceAllUsesWith(repl);
783 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
785 Value repl = mapping[sfo.getFieldIndex()];
786 sfo.replaceAllUsesWith(repl);
789 .Case<RefSubOp>([mapping](RefSubOp refSub) {
790 Value repl = mapping[refSub.getIndex()];
791 refSub.replaceAllUsesWith(repl);
794 .Default([&](
auto op) {
// This user needs the whole aggregate. Recombine the lowered values with a
// builder positioned before the user.
805 ImplicitLocOpBuilder b(user->getLoc(), user);
// Recombination requires every mapped value to be a FIRRTL base type that
// contains no references.
809 assert(llvm::none_of(mapping, [](
auto v) {
810 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
811 return !fbasetype || fbasetype.containsReference();
815 TypeSwitch<Type, Value>(val.getType())
816 .template Case<FVectorType>([&](
auto vecType) {
817 return b.createOrFold<VectorCreateOp>(vecType, mapping);
819 .
template Case<BundleType>([&](
auto bundleType) {
820 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
822 .Default([&](
auto _) -> Value {
return {}; });
824 user->emitError(
"unable to reconstruct source of type ")
826 encounteredError =
true;
829 user->replaceUsesOfWith(val, input);
// Entry point per module: dispatches to the FModuleOp or FExtModuleOp visitor.
834void TypeLoweringVisitor::lowerModule(FModuleLike op) {
835 if (
auto module = llvm::dyn_cast<FModuleOp>(*op))
837 else if (
auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
838 visitDecl(extModule);
// Creates the port for one lowered field of `oldArg`.
// - For an FModuleOp, a block argument of the field type is inserted at
//   `insertPt` with oldArg's location. Only that case sets `newValue` here.
// - The port is named oldArg.name + field.suffix and receives filtered
//   annotations and `newSym`.
// Returns the new value together with the new PortInfo.
844std::pair<Value, PortInfo>
845TypeLoweringVisitor::addArg(Operation *module,
unsigned insertPt,
847 const FlatBundleFieldEntry &field,
PortInfo &oldArg,
848 hw::InnerSymAttr newSym) {
851 if (
auto mod = llvm::dyn_cast<FModuleOp>(module)) {
852 Block *body = mod.getBodyBlock();
854 newValue = body->insertArgument(insertPt, fieldType, oldArg.
loc);
858 auto name = builder->getStringAttr(oldArg.
name.getValue() + field.suffix);
861 auto newAnnotations = filterAnnotations(
866 return std::make_pair(
867 newValue,
PortInfo{name, fieldType, direction, newSym, oldArg.
loc,
// Tries to lower port `argIndex` of `module`:
// 1. Peel its type with the module's port preservation mode.
// 2. Partition its inner symbol; a failure sets encounteredError.
// 3. Insert one new PortInfo per field directly after the original entry
//    in `newArgs`, appending each new value to `lowering`.
872bool TypeLoweringVisitor::lowerArg(FModuleLike module,
size_t argIndex,
874 SmallVectorImpl<PortInfo> &newArgs,
875 SmallVectorImpl<Value> &lowering) {
878 SmallVector<FlatBundleFieldEntry> fieldTypes;
879 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
880 if (!
peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
883 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
884 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
885 newArgs[argIndex].loc))) {
886 encounteredError =
true;
890 for (
const auto &[idx, field, fieldSym] :
891 llvm::enumerate(fieldTypes, fieldSyms)) {
892 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
893 field, newArgs[argIndex], fieldSym);
894 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
896 lowering.push_back(newValue.first);
// Recreates accessor `op` (SubfieldOp, SubindexOp or SubaccessOp) on the new
// base value `rhs`. Any other op gets an "Unknown accessor" error.
901static Value
cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
903 if (
auto rop = llvm::dyn_cast<SubfieldOp>(op))
904 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
905 if (
auto rop = llvm::dyn_cast<SubindexOp>(op))
906 return SubindexOp::create(*builder, rhs, rop.getIndex());
907 if (
auto rop = llvm::dyn_cast<SubaccessOp>(op))
908 return SubaccessOp::create(*builder, rhs, rop.getIndex());
909 op->emitError(
"Unknown accessor");
// Expands a write through a dynamic SubaccessOp (the last entry of
// `writePath`) into one WhenOp per vector element.
// - Each when is guarded by `index == <element>`.
// - The constant is a UInt of width ceil(log2(numElements)).
// - Inside each when, the remaining accessors are re-applied on a constant
//   SubindexOp, from writePath[size-2] down to writePath[0].
913void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
914 ArrayRef<Operation *> writePath) {
915 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
916 FVectorType saoType = sao.getInput().getType();
917 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
919 for (
size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
920 auto cond = EQPrimOp::create(
921 *builder, sao.getIndex(),
922 builder->createOrFold<ConstantOp>(UIntType::get(
context, selectWidth),
923 APInt(selectWidth, index)));
924 WhenOp::create(*builder, cond,
false, [&]() {
// Start from the statically indexed element and rebuild the rest of the
// access path on top of it.
926 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
927 for (
int i = writePath.size() - 2; i >= 0; --i) {
928 if (
auto access =
cloneAccess(builder, writePath[i], leaf))
931 encounteredError =
true;
// Lowers an aggregate ConnectOp.
// - Writes through a subaccess are handled by processSAPath.
// - Otherwise, for each peeled field, src/dest accessors are built. They are
//   swapped for fields whose isOutput flag is set.
942bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
943 if (processSAPath(op))
947 SmallVector<FlatBundleFieldEntry> fields;
954 for (
const auto &field :
llvm::enumerate(fields)) {
955 Value src = getSubWhatever(op.getSrc(), field.index());
956 Value dest = getSubWhatever(op.getDest(), field.index());
957 if (field.value().isOutput)
958 std::swap(src, dest);
// Lowers an aggregate MatchingConnectOp.
// - Writes through a subaccess are handled by processSAPath.
// - Otherwise, one MatchingConnectOp is emitted per peeled field, with
//   source and destination swapped for isOutput fields.
965bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
966 if (processSAPath(op))
970 SmallVector<FlatBundleFieldEntry> fields;
977 for (
const auto &field :
llvm::enumerate(fields)) {
978 Value src = getSubWhatever(op.getSrc(), field.index());
979 Value dest = getSubWhatever(op.getDest(), field.index());
980 if (field.value().isOutput)
981 std::swap(src, dest);
982 MatchingConnectOp::create(*builder, dest, src);
// Lowers a RefDefineOp on an aggregate reference into one RefDefineOp per
// peeled field. Reference destinations are asserted to have no flipped fields.
988bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
990 SmallVector<FlatBundleFieldEntry> fields;
992 if (!
peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
996 for (
const auto &field :
llvm::enumerate(fields)) {
997 Value src = getSubWhatever(op.getSrc(), field.index());
998 Value dest = getSubWhatever(op.getDest(), field.index());
999 assert(!field.value().isOutput &&
"unexpected flip in reftype destination");
1000 RefDefineOp::create(*builder, dest, src);
// Recursively lowers the then-block and, when present, the else-block.
1005bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1011 lowerBlock(&op.getThenBlock());
1014 if (op.hasElseRegion())
1015 lowerBlock(&op.getElseBlock());
// Recursively lowers the body of a layer block.
1020bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1021 lowerBlock(op.getBody());
// Splits a memory with an aggregate data type into one memory per peeled
// field (using memoryPreservationMode). Memories with an inner symbol, or with
// a debug port, are rejected with an error.
1027bool TypeLoweringVisitor::visitDecl(MemOp op) {
1029 SmallVector<FlatBundleFieldEntry> fields;
1032 if (!
peelType(op.getDataType(), fields, memoryPreservationMode))
1035 if (op.getInnerSym()) {
1036 op->emitError() <<
"has a symbol, but no symbols may exist on aggregates "
1037 "passed through LowerTypes";
1038 encounteredError =
true;
1042 SmallVector<MemOp> newMemories;
1043 SmallVector<WireOp> oldPorts;
// Stand in a wire named "<mem>_<port>" for every original port result, so
// existing users keep a value of the old port type.
1046 for (
unsigned int index = 0, end = op.getNumResults(); index <
end; ++index) {
1047 auto result = op.getResult(index);
1048 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1049 op.emitOpError(
"cannot lower memory with debug port");
1050 encounteredError =
true;
1054 WireOp::create(*builder, result.getType(),
1055 (op.getName() +
"_" + op.getPortName(index)).str());
1056 oldPorts.push_back(wire);
1057 result.replaceAllUsesWith(wire.getResult());
// One new memory per field; field.index selects it in newMemories below.
1064 for (
const auto &field : fields) {
1066 if (!newMemForField) {
1067 op.emitError(
"failed cloning memory for field");
1068 encounteredError =
true;
1071 newMemories.push_back(newMemForField);
// Wire each port wire to the new memories.
// - Data/mask port fields (data, mask, wdata, wmask, rdata) are split per
//   field; isFlip decides which side is the source.
// - All other port fields are accessed on every new memory.
1074 for (
size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1075 auto result = oldPorts[index].getResult();
1076 auto rType = type_cast<BundleType>(result.getType());
1077 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
1078 fieldIndex != fend; ++fieldIndex) {
1079 auto name = rType.getElement(fieldIndex).name.getValue();
1080 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1083 if (name ==
"data" || name ==
"mask" || name ==
"wdata" ||
1084 name ==
"wmask" || name ==
"rdata") {
1085 for (
const auto &field : fields) {
1086 auto realOldField = getSubWhatever(oldField, field.index);
1087 auto newField = getSubWhatever(
1088 newMemories[field.index].getResult(index), fieldIndex);
1089 if (rType.getElement(fieldIndex).isFlip)
1090 std::swap(realOldField, newField);
1094 for (
auto mem : newMemories) {
1096 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
// Lowers the ports of an external module.
// 1. Every port that lowerArg can peel is replaced in `newArgs` by its
//    per-field ports, and the original entries are erased.
// 2. Domain associations are renumbered against the new port order.
// 3. Every per-port attribute array (directions, names, types, locations,
//    annotations, domains, symbols) is rebuilt from `newArgs`.
1105bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1106 ImplicitLocOpBuilder theBuilder(extModule.getLoc(),
context);
1107 builder = &theBuilder;
1113 SmallVector<unsigned> argsToRemove;
1114 auto newArgs = extModule.getPorts();
// Record where the domain-typed ports originally were, before any port is
// inserted or erased.
1116 DomainLoweringHelper domainHelper(
context, extModule.getPortTypes());
1118 for (
size_t argIndex = 0, argsRemoved = 0; argIndex < newArgs.size();
1120 SmallVector<Value> lowering;
1121 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1122 argsToRemove.push_back(argIndex);
// Erase the replaced ports back to front so earlier recorded indices stay
// valid.
1129 for (
auto toRemove :
llvm::reverse(argsToRemove))
1130 newArgs.erase(newArgs.begin() + toRemove);
1132 domainHelper.computeDomainMap(newArgs);
1134 SmallVector<NamedAttribute, 8> newModuleAttrs;
// Copy every module attribute except the per-port arrays, which are rebuilt
// below. The domain array (cache.sPortDomains) must be dropped as well: a
// fresh entry with that name is appended later, and setAttrs requires unique
// attribute names.
1137 for (
auto attr : extModule->getAttrDictionary())
1140 if (attr.
getName() !=
"portDirections" && attr.
getName() !=
"portNames" &&
1141 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1142 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations" &&
attr.getName() != cache.sPortDomains)
1143 newModuleAttrs.push_back(attr);
1145 SmallVector<Direction> newArgDirections;
1146 SmallVector<Attribute> newArgNames;
1147 SmallVector<Attribute, 8> newArgTypes;
1148 SmallVector<Attribute, 8> newArgSyms;
1149 SmallVector<Attribute, 8> newArgLocations;
1150 SmallVector<Attribute, 8> newArgAnnotations;
1151 SmallVector<Attribute, 8> newArgDomains;
1153 for (
auto &port : newArgs) {
1154 newArgDirections.push_back(port.direction);
1155 newArgNames.push_back(port.name);
1156 newArgTypes.push_back(TypeAttr::get(port.type));
1157 newArgSyms.push_back(port.sym);
1158 newArgLocations.push_back(port.loc);
1159 newArgAnnotations.push_back(port.annotations.getArrayAttr());
// Non-null domain associations are renumbered through domainHelper. The
// empty array (cache.aEmpty) is assigned otherwise.
1160 if (
auto &domains = port.domains) {
1161 domainHelper.rewriteDomain(port.domains);
1163 port.domains = cache.aEmpty;
1165 newArgDomains.push_back(port.domains);
1168 newModuleAttrs.push_back(
1169 NamedAttribute(cache.sPortDirections,
// NOTE(review): `builder.` (not `builder->`) relies on a local builder object
// shadowing the member pointer on a line not shown here; confirm.
1172 newModuleAttrs.push_back(
1173 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1175 newModuleAttrs.push_back(
1176 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1178 newModuleAttrs.push_back(NamedAttribute(
1179 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1181 newModuleAttrs.push_back(NamedAttribute(
1182 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1184 newModuleAttrs.push_back(
1185 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1188 extModule->setAttrs(newModuleAttrs);
1189 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1190 extModule.setPortSymbols(newArgSyms);
// Lowers the ports of a module with a body.
// 1. Every port that lowerArg can peel gets per-field block arguments and
//    ports, and uses of the old argument are redirected via processUsers.
// 2. The replaced block arguments are erased and `newArgs` is compacted.
// 3. Domain associations are renumbered, and every per-port attribute array
//    is rebuilt.
1195bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1196 auto *body =
module.getBodyBlock();
1198 ImplicitLocOpBuilder theBuilder(module.getLoc(),
context);
1199 builder = &theBuilder;
// One bit per entry in newArgs: set for ports that were replaced by their
// fields.
1205 llvm::BitVector argsToRemove;
1206 auto newArgs =
module.getPorts();
// Record where the domain-typed ports originally were, before any port is
// inserted or erased.
1208 DomainLoweringHelper domainHelper(
context, module.getPortTypes());
1210 size_t argsRemoved = 0;
1211 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1212 SmallVector<Value> lowerings;
1213 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1214 auto arg =
module.getArgument(argIndex);
1215 processUsers(arg, lowerings);
1216 argsToRemove.push_back(
true);
1219 argsToRemove.push_back(
false);
// Drop the replaced block arguments, then compact newArgs in place with the
// same mask and trim the now-unused tail.
1224 if (argsRemoved != 0) {
1225 body->eraseArguments(argsToRemove);
1226 size_t size = newArgs.size();
1227 for (
size_t src = 0, dst = 0; src < size; ++src) {
1228 if (argsToRemove[src])
1230 newArgs[dst] = newArgs[src];
1233 newArgs.erase(newArgs.end() - argsRemoved, newArgs.end());
1236 domainHelper.computeDomainMap(newArgs);
1238 SmallVector<NamedAttribute, 8> newModuleAttrs;
// Copy every module attribute except the per-port arrays, which are rebuilt
// below. The domain array (cache.sPortDomains) must be dropped as well: a
// fresh entry with that name is appended later, and setAttrs requires unique
// attribute names.
1241 for (
auto attr : module->getAttrDictionary())
1244 if (attr.
getName() !=
"portNames" && attr.
getName() !=
"portDirections" &&
1245 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1246 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations" &&
attr.getName() != cache.sPortDomains)
1247 newModuleAttrs.push_back(attr);
1249 SmallVector<Direction> newArgDirections;
1250 SmallVector<Attribute> newArgNames;
1251 SmallVector<Attribute> newArgTypes;
1252 SmallVector<Attribute> newArgSyms;
1253 SmallVector<Attribute> newArgLocations;
1254 SmallVector<Attribute, 8> newArgAnnotations;
1255 SmallVector<Attribute> newPortDomains;
1256 for (
auto &port : newArgs) {
1257 newArgDirections.push_back(port.direction);
1258 newArgNames.push_back(port.name);
1259 newArgTypes.push_back(TypeAttr::get(port.type));
1260 newArgSyms.push_back(port.sym);
1261 newArgLocations.push_back(port.loc);
1262 newArgAnnotations.push_back(port.annotations.getArrayAttr());
// Non-null domain associations are renumbered through domainHelper. The
// empty array (cache.aEmpty) is assigned otherwise.
1263 if (
auto domains = port.domains) {
1264 domainHelper.rewriteDomain(port.domains);
1266 port.domains = cache.aEmpty;
1268 newPortDomains.push_back(port.domains);
1271 newModuleAttrs.push_back(
1272 NamedAttribute(cache.sPortDirections,
1275 newModuleAttrs.push_back(
1276 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1278 newModuleAttrs.push_back(
1279 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1281 newModuleAttrs.push_back(NamedAttribute(
1282 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1284 newModuleAttrs.push_back(NamedAttribute(
1285 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1287 newModuleAttrs.push_back(NamedAttribute(
1288 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1291 module->setAttrs(newModuleAttrs);
1292 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1293 module.setPortSymbols(newArgSyms);
// Lowers an aggregate wire into one droppable-named WireOp per field, via
// lowerProducer. Forceable wires are checked first.
1298bool TypeLoweringVisitor::visitDecl(WireOp op) {
1299 if (op.isForceable())
1302 auto clone = [&](
const FlatBundleFieldEntry &field,
1303 ArrayAttr attrs) -> Value {
1304 return WireOp::create(*builder,
1306 "", NameKindEnum::DroppableName, attrs, StringAttr{})
1309 return lowerProducer(op, clone);
// Lowers an aggregate register into one RegOp per field, all on the original
// clock, via lowerProducer. Forceable registers are checked first.
1313bool TypeLoweringVisitor::visitDecl(RegOp op) {
1314 if (op.isForceable())
1317 auto clone = [&](
const FlatBundleFieldEntry &field,
1318 ArrayAttr attrs) -> Value {
1319 return RegOp::create(*builder, field.type, op.getClockVal(),
"",
1320 NameKindEnum::DroppableName, attrs, StringAttr{})
1323 return lowerProducer(op, clone);
// Lowers an aggregate reset register into one RegResetOp per field.
// - Each field register uses the original clock and reset signal.
// - Its reset value is the matching element of the aggregate reset value.
// Forceable registers are checked first.
1327bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1328 if (op.isForceable())
1331 auto clone = [&](
const FlatBundleFieldEntry &field,
1332 ArrayAttr attrs) -> Value {
1333 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1334 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1335 op.getResetSignal(), resetVal,
"",
1336 NameKindEnum::DroppableName, attrs, StringAttr{})
1339 return lowerProducer(op, clone);
// Lowers an aggregate node into one NodeOp per field. Each node is fed by the
// matching element of the original input. Forceable nodes are checked first.
1343bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1344 if (op.isForceable())
1347 auto clone = [&](
const FlatBundleFieldEntry &field,
1348 ArrayAttr attrs) -> Value {
1349 auto input = getSubWhatever(op.getInput(), field.index);
1350 return NodeOp::create(*builder, input,
"", NameKindEnum::DroppableName,
1354 return lowerProducer(op, clone);
// Lowers an aggregate invalid value into one InvalidValueOp per field type.
// The filtered annotations passed to `clone` are not used here.
1358bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1359 auto clone = [&](
const FlatBundleFieldEntry &field,
1360 ArrayAttr attrs) -> Value {
1361 return InvalidValueOp::create(*builder, field.type);
1363 return lowerProducer(op, clone);
// Lowers an aggregate mux into one MuxPrimOp per field. Every field mux shares
// the original selector and picks between the matching high/low elements.
1367bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1368 auto clone = [&](
const FlatBundleFieldEntry &field,
1369 ArrayAttr attrs) -> Value {
1370 auto high = getSubWhatever(op.getHigh(), field.index);
1371 auto low = getSubWhatever(op.getLow(), field.index);
1372 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1374 return lowerProducer(op, clone);
// Lowers an aggregate 2-input mux cell into one Mux2CellIntrinsicOp per field,
// sharing the original selector.
1378bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1379 auto clone = [&](
const FlatBundleFieldEntry &field,
1380 ArrayAttr attrs) -> Value {
1381 auto high = getSubWhatever(op.getHigh(), field.index);
1382 auto low = getSubWhatever(op.getLow(), field.index);
1383 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1385 return lowerProducer(op, clone);
// Lowers an aggregate 4-input mux cell into one Mux4CellIntrinsicOp per field.
// Each field mux shares the original selector and takes the matching elements
// of v3..v0.
1389bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1390 auto clone = [&](
const FlatBundleFieldEntry &field,
1391 ArrayAttr attrs) -> Value {
1392 auto v3 = getSubWhatever(op.getV3(), field.index);
1393 auto v2 = getSubWhatever(op.getV2(), field.index);
1394 auto v1 = getSubWhatever(op.getV1(), field.index);
1395 auto v0 = getSubWhatever(op.getV0(), field.index);
1396 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1398 return lowerProducer(op, clone);
// Lowers an unrealized conversion cast whose operand is a FIRRTL type into one
// cast per field, each applied to the matching element of the operand. Casts
// from non-FIRRTL operand types are handled before lowerProducer is reached.
1402bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1403 mlir::UnrealizedConversionCastOp op) {
1404 auto clone = [&](
const FlatBundleFieldEntry &field,
1405 ArrayAttr attrs) -> Value {
1406 auto input = getSubWhatever(op.getOperand(0), field.index);
1407 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1412 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1414 return lowerProducer(op, clone);
1418bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1419 Value srcLoweredVal = op.getInput();
1423 SmallVector<FlatBundleFieldEntry> fields;
1425 size_t uptoBits = 0;
1428 for (
const auto &field :
llvm::enumerate(fields)) {
1429 auto fieldBitwidth = *
getBitWidth(field.value().type);
1431 if (fieldBitwidth == 0)
1433 Value src = getSubWhatever(op.getInput(), field.index());
1435 src = builder->createOrFold<BitCastOp>(
1436 UIntType::get(
context, fieldBitwidth), src);
1439 srcLoweredVal = src;
1441 if (type_isa<BundleType>(op.getInput().getType())) {
1443 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1446 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1450 uptoBits += fieldBitwidth;
1453 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
1457 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
1459 size_t uptoBits = 0;
1460 auto aggregateBits = *
getBitWidth(op.getResult().getType());
1461 auto clone = [&](
const FlatBundleFieldEntry &field,
1462 ArrayAttr attrs) -> Value {
1468 return InvalidValueOp::create(*builder, field.type);
1473 if (type_isa<BundleType>(op.getResult().getType())) {
1474 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1475 aggregateBits - uptoBits - 1,
1476 aggregateBits - uptoBits - fieldBits);
1478 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1479 uptoBits + fieldBits - 1, uptoBits);
1481 uptoBits += fieldBits;
1482 return BitCastOp::create(*builder, field.type,
extractBits);
1484 return lowerProducer(op, clone);
1488 if (type_isa<SIntType>(op.getType()))
1489 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1490 op.getResult().replaceAllUsesWith(srcLoweredVal);
1494bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1495 auto clone = [&](
const FlatBundleFieldEntry &field,
1496 ArrayAttr attrs) -> Value {
1497 return RefSendOp::create(*builder,
1498 getSubWhatever(op.getBase(), field.index));
1503 return lowerProducer(op, clone);
1506bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1507 auto clone = [&](
const FlatBundleFieldEntry &field,
1508 ArrayAttr attrs) -> Value {
1509 Value src = getSubWhatever(op.getRef(), field.index);
1510 return RefResolveOp::create(*builder, src);
1514 return lowerProducer(op, clone, op.getRef().getType());
1517bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1518 auto clone = [&](
const FlatBundleFieldEntry &field,
1519 ArrayAttr attrs) -> Value {
1520 auto input = getSubWhatever(op.getInput(), field.index);
1521 return RefCastOp::create(*builder,
1522 RefType::get(field.type,
1523 op.getType().getForceable(),
1524 op.getType().getLayer()),
1527 return lowerProducer(op, clone);
1530bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1532 SmallVector<Type, 8> resultTypes;
1533 SmallVector<int64_t, 8> endFields;
1534 auto oldPortAnno = op.getPortAnnotations();
1535 SmallVector<Direction> newDirs;
1536 SmallVector<Attribute> newNames;
1537 SmallVector<Attribute> newDomains;
1538 SmallVector<Attribute> newPortAnno;
1540 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1543 DomainLoweringHelper domainHelper(
context, op.getResultTypes());
1545 endFields.push_back(0);
1546 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
1547 auto srcType = type_cast<FIRRTLType>(op.getType(i));
1550 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1551 if (!
peelType(srcType, fieldTypes, mode)) {
1552 newDirs.push_back(op.getPortDirection(i));
1553 newNames.push_back(op.getPortNameAttr(i));
1554 newDomains.push_back(op.getPortDomain(i));
1555 resultTypes.push_back(srcType);
1556 newPortAnno.push_back(oldPortAnno[i]);
1559 auto oldName = op.getPortName(i);
1560 auto oldDir = op.getPortDirection(i);
1562 for (
const auto &field : fieldTypes) {
1563 newDirs.push_back(
direction::get((
unsigned)oldDir ^ field.isOutput));
1564 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1565 newDomains.push_back(op.getPortDomain(i));
1567 auto annos = filterAnnotations(
1568 context, dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]), srcType,
1570 newPortAnno.push_back(annos);
1573 endFields.push_back(resultTypes.size());
1583 domainHelper.computeDomainMap(resultTypes);
1586 for (
auto &domain : newDomains)
1587 domainHelper.rewriteDomain(domain);
1590 auto newInstance = InstanceOp::create(
1591 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1593 builder->getArrayAttr(newNames), builder->getArrayAttr(newDomains),
1594 op.getAnnotations(), builder->getArrayAttr(newPortAnno),
1595 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1596 sym ? hw::InnerSymAttr::get(sym) :
hw::InnerSymAttr());
1598 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1600 SmallVector<Value> lowered;
1601 for (
size_t aggIndex = 0, eAgg = op.getNumResults(); aggIndex != eAgg;
1604 for (
size_t fieldIndex = endFields[aggIndex],
1605 eField = endFields[aggIndex + 1];
1606 fieldIndex < eField; ++fieldIndex)
1607 lowered.push_back(newInstance.getResult(fieldIndex));
1608 if (lowered.size() != 1 ||
1609 op.getType(aggIndex) != resultTypes[endFields[aggIndex]])
1610 processUsers(op.getResult(aggIndex), lowered);
1612 op.getResult(aggIndex).replaceAllUsesWith(lowered[0]);
1617bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1618 auto input = op.getInput();
1619 FVectorType vType = input.getType();
1622 if (vType.getNumElements() == 0) {
1623 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1624 op.replaceAllUsesWith(inv);
1629 if (ConstantOp arg =
1630 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1631 auto sio = SubindexOp::create(*builder, op.getInput(),
1632 arg.getValue().getExtValue());
1633 op.replaceAllUsesWith(sio.getResult());
1638 SmallVector<Value> inputs;
1639 inputs.reserve(vType.getNumElements());
1640 for (
int index = vType.getNumElements() - 1; index >= 0; index--)
1641 inputs.push_back(SubindexOp::create(*builder, input, index));
1643 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1644 op.replaceAllUsesWith(multibitMux);
1648bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1649 auto clone = [&](
const FlatBundleFieldEntry &field,
1650 ArrayAttr attrs) -> Value {
1651 return op.getOperand(field.index);
1653 return lowerProducer(op, clone);
1656bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1657 auto clone = [&](
const FlatBundleFieldEntry &field,
1658 ArrayAttr attrs) -> Value {
1659 return op.getOperand(field.index);
1661 return lowerProducer(op, clone);
1664bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1665 auto clone = [&](
const FlatBundleFieldEntry &field,
1666 ArrayAttr attrs) -> Value {
1667 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1668 getSubWhatever(op.getRhs(), field.index)};
1669 return type_isa<BundleType, FVectorType>(field.type)
1670 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1672 : (Value)OrPrimOp::create(*builder, operands);
1675 return lowerProducer(op, clone);
1678bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1679 auto clone = [&](
const FlatBundleFieldEntry &field,
1680 ArrayAttr attrs) -> Value {
1681 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1682 getSubWhatever(op.getRhs(), field.index)};
1683 return type_isa<BundleType, FVectorType>(field.type)
1684 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1686 : (Value)AndPrimOp::create(*builder, operands);
1689 return lowerProducer(op, clone);
1692bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1693 auto clone = [&](
const FlatBundleFieldEntry &field,
1694 ArrayAttr attrs) -> Value {
1695 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1696 getSubWhatever(op.getRhs(), field.index)};
1697 return type_isa<BundleType, FVectorType>(field.type)
1698 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1700 : (Value)XorPrimOp::create(*builder, operands);
1703 return lowerProducer(op, clone);
1706bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1707 auto clone = [&](
const FlatBundleFieldEntry &field,
1708 ArrayAttr attrs) -> Value {
1709 SmallVector<Value> newInputs;
1710 newInputs.reserve(op.getInputs().size());
1711 for (
auto input : op.getInputs()) {
1712 auto inputSub = getSubWhatever(input, field.index);
1713 newInputs.push_back(inputSub);
1715 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1717 return lowerProducer(op, clone);
1725struct LowerTypesPass
1726 :
public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1729 void runOnOperation()
override;
1734void LowerTypesPass::runOnOperation() {
1737 std::vector<FModuleLike> ops;
1739 auto &symTbl = getAnalysis<SymbolTable>();
1741 AttrCache cache(&getContext());
1743 DenseMap<FModuleLike, Convention> conventionTable;
1744 auto circuit = getOperation();
1745 for (
auto module : circuit.getOps<FModuleLike>()) {
1746 conventionTable.insert({module,
module.getConvention()});
1747 ops.push_back(module);
1751 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1753 Convention convention = Convention::Internal;
1754 if (
auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1755 op->getDiscardableAttr(
"body_type_lowering")))
1756 convention = conventionAttr.getValue();
1759 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1760 preserveMemories, symTbl, cache, conventionTable);
1763 return LogicalResult::failure(tl.isFailed());
1766 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
1769 signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
@ All
Preserve all aggregate values.
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
@ None
Don't preserve aggregate at all.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Direction
This represents the direction of a single port.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.
AnnotationSet annotations