39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/Support/Debug.h"
47#define DEBUG_TYPE "firrtl-lower-types"
51#define GEN_PASS_DEF_LOWERFIRRTLTYPES
52#include "circt/Dialect/FIRRTL/Passes.h.inc"
57using namespace firrtl;
62struct FlatBundleFieldEntry {
70 SmallString<16> suffix;
75 unsigned fieldID, StringRef suffix,
bool isOutput)
76 : type(type), index(index), fieldID(fieldID), suffix(suffix),
80 llvm::errs() <<
"FBFE{" << type <<
" index<" << index <<
"> fieldID<"
81 << fieldID <<
"> suffix<" << suffix <<
"> isOutput<"
82 << isOutput <<
">}\n";
89 return mapBaseType(type, [&](
auto) {
return fieldType; });
94 auto ftype = type_dyn_cast<FIRRTLType>(type);
103 .
Case<BundleType>([&](
auto bundle) {
return false; })
104 .Case<FVectorType>([&](FVectorType vector) {
106 return vector.getElementType().isGround() &&
107 vector.getNumElements() > 1;
109 .Default([](
auto groundType) {
return true; });
116 .
Case<BundleType>([&](
auto bundle) {
return true; })
117 .Case<FVectorType>([&](FVectorType vector) {
120 .Default([](
auto groundType) {
return false; });
127 if (
auto refType = type_dyn_cast<RefType>(type)) {
129 if (refType.getForceable())
140 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
146 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
158 llvm_unreachable(
"unexpected mode");
164static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
171 if (
auto refType = type_dyn_cast<RefType>(type))
172 type = refType.getType();
174 .
Case<BundleType>([&](
auto bundle) {
175 SmallString<16> tmpSuffix;
177 for (
size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
178 auto elt = bundle.getElement(i);
181 tmpSuffix.push_back(
'_');
182 tmpSuffix.append(elt.name.getValue());
183 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
188 .Case<FVectorType>([&](
auto vector) {
190 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
191 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
192 "_" + std::to_string(i),
false);
196 .Default([](
auto op) {
return false; });
202 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
206 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
207 return arg && sao.getInput().getType().base().getNumElements() != 0;
212 SmallVector<Operation *> retval;
213 auto defOp = op->getOperand(0).getDefiningOp();
214 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
215 retval.push_back(defOp);
216 defOp = defOp->getOperand(0).getDefiningOp();
226 FlatBundleFieldEntry field) {
227 SmallVector<Type, 8> ports;
228 SmallVector<Attribute, 8> portNames;
229 SmallVector<Attribute, 8> portLocations;
231 auto oldPorts = op.getPorts();
232 for (
size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
233 auto port = oldPorts[portIdx];
235 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
236 portNames.push_back(port.first);
241 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
242 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
243 (op.getName() + field.suffix).str(), op.getNameKind(),
244 op.getAnnotations(), op.getPortAnnotations(),
245 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
247 if (op.getInnerSym()) {
248 op.emitError(
"cannot split memory with symbol present");
252 SmallVector<Attribute> newAnnotations;
253 for (
size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
254 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
255 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
256 SmallVector<Attribute> portAnno;
257 for (
auto attr : newMem.getPortAnnotation(portIdx)) {
260 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
264 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
267 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
268 portAnno.push_back(anno.
getDict());
273 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
278 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
279 if (annoFieldID >= fieldID &&
284 annoFieldID - fieldID + portType.getFieldID(targetIndex);
285 anno.
setMember(
"circt.fieldID", b->getI32IntegerAttr(newFieldID));
286 portAnno.push_back(anno.
getDict());
290 portAnno.push_back(attr);
292 newAnnotations.push_back(b->getArrayAttr(portAnno));
294 newMem.setAllPortAnnotations(newAnnotations);
304 AttrCache(MLIRContext *
context) {
305 i64ty = IntegerType::get(
context, 64);
306 nameAttr = StringAttr::get(
context,
"name");
307 nameKindAttr = StringAttr::get(
context,
"nameKind");
308 sPortDirections = StringAttr::get(
context,
"portDirections");
309 sPortNames = StringAttr::get(
context,
"portNames");
310 sPortTypes = StringAttr::get(
context,
"portTypes");
311 sPortSymbols = StringAttr::get(
context,
"portSymbols");
312 sPortLocations = StringAttr::get(
context,
"portLocations");
313 sPortAnnotations = StringAttr::get(
context,
"portAnnotations");
314 sPortDomains = StringAttr::get(
context,
"domainInfo");
315 sEmpty = StringAttr::get(
context,
"");
316 aEmpty = ArrayAttr::get(
context, {});
318 AttrCache(
const AttrCache &) =
default;
321 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
322 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
329class DomainLoweringHelper {
334 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
336 for (
auto [index, typeAttr] :
llvm::enumerate(portTypes))
337 if (
type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
338 domainIndexByOrdinal.push_back(index);
342 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
344 for (
auto [index, type] :
llvm::enumerate(resultTypes))
346 domainIndexByOrdinal.push_back(index);
353 void computeDomainMap(TypeRange types) {
354 size_t i = 0, ord = 0;
355 for (
auto type : types) {
356 if (type_isa<DomainType>(type))
357 domainMap[domainIndexByOrdinal[ord++]] = i;
366 void computeDomainMap(ArrayRef<PortInfo> ports) {
367 size_t i = 0, ord = 0;
368 for (
const auto &port : ports) {
369 if (type_isa<DomainType>(port.type))
370 domainMap[domainIndexByOrdinal[ord++]] = i;
378 void rewriteDomain(Attribute &domain) {
379 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
380 if (!oldAssociations)
382 SmallVector<Attribute> newAssociations;
383 for (
auto oldAttr : oldAssociations)
384 newAssociations.push_back(IntegerAttr::
get(
385 IntegerType::
get(context, 32, IntegerType::Unsigned),
386 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
387 domain = ArrayAttr::get(context, newAssociations);
391 MLIRContext *context;
393 SmallVector<unsigned> domainIndexByOrdinal;
395 DenseMap<unsigned, unsigned> domainMap;
400struct TypeLoweringVisitor :
public FIRRTLVisitor<TypeLoweringVisitor, bool> {
404 Convention bodyConvention,
406 SymbolTable &symTbl,
const AttrCache &cache,
407 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
408 : context(context), defaultAggregatePreservationMode(preserveAggregate),
409 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
410 cache(cache), conventionTable(conventionTable) {
411 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
413 : defaultAggregatePreservationMode;
421 void lowerModule(FModuleLike op);
423 bool lowerArg(FModuleLike module,
size_t argIndex,
size_t argsRemoved,
424 SmallVectorImpl<PortInfo> &newArgs,
425 SmallVectorImpl<Value> &lowering);
426 std::pair<Value, PortInfo> addArg(Operation *module,
unsigned insertPt,
428 const FlatBundleFieldEntry &field,
429 PortInfo &oldArg, hw::InnerSymAttr newSym);
432 bool visitDecl(FExtModuleOp op);
433 bool visitDecl(FModuleOp op);
434 bool visitDecl(InstanceOp op);
435 bool visitDecl(InstanceChoiceOp op);
436 bool visitDecl(MemOp op);
437 bool visitDecl(NodeOp op);
438 bool visitDecl(RegOp op);
439 bool visitDecl(WireOp op);
440 bool visitDecl(RegResetOp op);
441 bool visitExpr(InvalidValueOp op);
442 bool visitExpr(SubaccessOp op);
443 bool visitExpr(VectorCreateOp op);
444 bool visitExpr(BundleCreateOp op);
445 bool visitExpr(ElementwiseAndPrimOp op);
446 bool visitExpr(ElementwiseOrPrimOp op);
447 bool visitExpr(ElementwiseXorPrimOp op);
448 bool visitExpr(MultibitMuxOp op);
449 bool visitExpr(MuxPrimOp op);
450 bool visitExpr(Mux2CellIntrinsicOp op);
451 bool visitExpr(Mux4CellIntrinsicOp op);
452 bool visitExpr(BitCastOp op);
453 bool visitExpr(RefSendOp op);
454 bool visitExpr(RefResolveOp op);
455 bool visitExpr(RefCastOp op);
456 bool visitStmt(ConnectOp op);
457 bool visitStmt(MatchingConnectOp op);
458 bool visitStmt(RefDefineOp op);
459 bool visitStmt(WhenOp op);
460 bool visitStmt(LayerBlockOp op);
461 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
463 bool isFailed()
const {
return encounteredError; }
466 if (
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
467 return visitUnrealizedConversionCast(castOp);
472 void processUsers(Value val, ArrayRef<Value> mapping);
473 bool processSAPath(Operation *);
474 void lowerBlock(Block *);
475 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
485 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
490 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
491 FIRRTLType srcType, FlatBundleFieldEntry field);
495 LogicalResult partitionSymbols(hw::InnerSymAttr sym,
FIRRTLType parentType,
496 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
500 getPreservationModeForPorts(FModuleLike moduleLike);
501 Value getSubWhatever(Value val,
size_t index);
506 ArrayAttr oldPortAnno,
507 llvm::function_ref<Operation *(
508 ArrayRef<Type>, ArrayRef<Direction>, ArrayAttr,
509 ArrayAttr, ArrayAttr, hw::InnerSymAttr)>
512 size_t uniqueIdx = 0;
513 std::string uniqueName() {
514 auto myID = uniqueIdx++;
515 return (Twine(
"__GEN_") + Twine(myID)).str();
526 ImplicitLocOpBuilder *builder;
532 const AttrCache &cache;
534 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
537 bool encounteredError =
false;
544TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
545 auto lookup = conventionTable.find(module);
546 if (lookup == conventionTable.end())
547 return defaultAggregatePreservationMode;
548 switch (lookup->second) {
549 case Convention::Scalarized:
551 case Convention::Internal:
552 return defaultAggregatePreservationMode;
554 llvm_unreachable(
"Unknown convention");
555 return defaultAggregatePreservationMode;
558Value TypeLoweringVisitor::getSubWhatever(Value val,
size_t index) {
559 if (type_isa<BundleType>(val.getType()))
560 return SubfieldOp::create(*builder, val, index);
561 if (type_isa<FVectorType>(val.getType()))
562 return SubindexOp::create(*builder, val, index);
563 if (type_isa<RefType>(val.getType()))
564 return RefSubOp::create(*builder, val, index);
565 llvm_unreachable(
"Unknown aggregate type");
570bool TypeLoweringVisitor::processSAPath(Operation *op) {
573 if (writePath.empty())
576 lowerSAWritePath(op, writePath);
579 op->eraseOperands(0, 2);
581 for (
size_t i = 0; i < writePath.size(); ++i) {
582 if (writePath[i]->use_empty()) {
583 writePath[i]->erase();
591void TypeLoweringVisitor::lowerBlock(Block *block) {
593 for (
auto it = block->rbegin(), e = block->rend(); it != e;) {
595 builder->setInsertionPoint(&iop);
596 builder->setLoc(iop.getLoc());
597 bool removeOp = dispatchVisitor(&iop);
606ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
607 ArrayAttr annotations,
609 FlatBundleFieldEntry field) {
610 SmallVector<Attribute> retval;
611 if (!annotations || annotations.empty())
612 return ArrayAttr::get(ctxt, retval);
613 for (
auto opAttr : annotations) {
615 auto fieldID = anno.getFieldID();
616 anno.removeMember(
"circt.fieldID");
621 retval.push_back(anno.getAttr());
626 if (fieldID < field.fieldID ||
631 if (
auto newFieldID = fieldID - field.fieldID) {
634 anno.setMember(
"circt.fieldID", builder->getI32IntegerAttr(newFieldID));
637 retval.push_back(anno.getAttr());
639 return ArrayAttr::get(ctxt, retval);
642LogicalResult TypeLoweringVisitor::partitionSymbols(
644 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
647 if (!sym || sym.empty())
650 auto *
context = sym.getContext();
654 return mlir::emitError(errorLoc,
655 "unable to partition symbol on unsupported type ")
658 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
659 .Case<BundleType, FVectorType>([&](
auto aggType) -> LogicalResult {
663 hw::InnerSymPropertiesAttr prop;
667 SmallVector<BinningInfo> binning;
668 for (
auto prop : sym) {
669 auto fieldID = prop.getFieldID();
672 return mlir::emitError(errorLoc,
"unable to lower due to symbol ")
674 <<
" with target not preserved by lowering";
675 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
676 binning.push_back({index, relFieldID, prop});
680 llvm::stable_sort(binning, [&](
auto &lhs,
auto &rhs) {
681 return std::tuple(lhs.index, lhs.relFieldID) <
682 std::tuple(rhs.index, rhs.relFieldID);
687 newSyms.resize(aggType.getNumElements());
688 for (
auto binIt = binning.begin(), binEnd = binning.end();
690 auto curIndex = binIt->index;
691 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
693 while (binIt != binEnd && binIt->index == curIndex) {
694 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
695 context, binIt->prop.getName(), binIt->relFieldID,
696 binIt->prop.getSymVisibility()));
700 assert(!newSyms[curIndex]);
701 newSyms[curIndex] = hw::InnerSymAttr::get(
context, propsForIndex);
705 .Default([&](
auto ty) {
706 return mlir::emitError(
707 errorLoc,
"unable to partition symbol on unsupported type ")
712bool TypeLoweringVisitor::lowerProducer(
714 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
718 srcType = op->getResult(0).getType();
719 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
722 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
724 if (!
peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
727 SmallVector<Value> lowered;
729 SmallString<16> loweredName;
730 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
732 if (
auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
733 loweredName = nameAttr.getValue();
734 auto baseNameLen = loweredName.size();
735 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr(
"annotations"));
737 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
738 if (
auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
739 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
741 encounteredError =
true;
746 for (
const auto &[field, sym] :
llvm::zip_equal(fieldTypes, fieldSyms)) {
747 if (!loweredName.empty()) {
748 loweredName.resize(baseNameLen);
749 loweredName += field.suffix;
754 ArrayAttr loweredAttrs =
755 filterAnnotations(
context, oldAnno, srcFType, field);
756 auto newVal = clone(field, loweredAttrs);
762 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
765 "op with inner symbol lowered to op that cannot take inner symbol");
766 newSymOp.setInnerSymbolAttr(sym);
770 if (
auto *newOp = newVal.getDefiningOp()) {
771 if (!loweredName.empty())
772 newOp->setAttr(cache.nameAttr, StringAttr::get(
context, loweredName));
774 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
777 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
779 lowered.push_back(newVal);
782 processUsers(op->getResult(0), lowered);
786void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
787 for (
auto *user :
llvm::make_early_inc_range(val.getUsers())) {
788 TypeSwitch<Operation *, void>(user)
789 .Case<SubindexOp>([mapping](SubindexOp sio) {
790 Value repl = mapping[sio.getIndex()];
791 sio.replaceAllUsesWith(repl);
794 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
796 Value repl = mapping[sfo.getFieldIndex()];
797 sfo.replaceAllUsesWith(repl);
800 .Case<RefSubOp>([mapping](RefSubOp refSub) {
801 Value repl = mapping[refSub.getIndex()];
802 refSub.replaceAllUsesWith(repl);
805 .Default([&](
auto op) {
816 ImplicitLocOpBuilder
b(user->getLoc(), user);
820 assert(llvm::none_of(mapping, [](
auto v) {
821 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
822 return !fbasetype || fbasetype.containsReference();
826 TypeSwitch<Type, Value>(val.getType())
827 .template Case<FVectorType>([&](
auto vecType) {
828 return b.createOrFold<VectorCreateOp>(vecType, mapping);
830 .
template Case<BundleType>([&](
auto bundleType) {
831 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
833 .Default([&](
auto _) -> Value {
return {}; });
835 user->emitError(
"unable to reconstruct source of type ")
837 encounteredError =
true;
840 user->replaceUsesOfWith(val, input);
848 const llvm::BitVector &removalMask) {
849 size_t writeIndex = 0, readIndex = 0;
853 for (
size_t removalIndex : removalMask.set_bits()) {
855 assert(removalIndex >= readIndex &&
"removal index before read index");
856 size_t rangeSize = removalIndex - readIndex;
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
863 vec.begin() + writeIndex);
864 writeIndex += rangeSize;
866 readIndex = removalIndex + 1;
870 size_t remainingSize = vec.size() - readIndex;
871 if (remainingSize > 0) {
872 if (writeIndex != readIndex)
873 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
874 writeIndex += remainingSize;
878 vec.truncate(writeIndex);
881void TypeLoweringVisitor::lowerModule(FModuleLike op) {
882 if (
auto module = llvm::dyn_cast<FModuleOp>(*op))
884 else if (
auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
885 visitDecl(extModule);
891std::pair<Value, PortInfo>
892TypeLoweringVisitor::addArg(Operation *module,
unsigned insertPt,
894 const FlatBundleFieldEntry &field,
PortInfo &oldArg,
895 hw::InnerSymAttr newSym) {
898 if (
auto mod = llvm::dyn_cast<FModuleOp>(module)) {
899 Block *body = mod.getBodyBlock();
901 newValue = body->insertArgument(insertPt, fieldType, oldArg.
loc);
905 auto name = builder->getStringAttr(oldArg.
name.getValue() + field.suffix);
908 auto newAnnotations = filterAnnotations(
913 return std::make_pair(
914 newValue,
PortInfo{name, fieldType, direction, newSym, oldArg.
loc,
919bool TypeLoweringVisitor::lowerArg(FModuleLike module,
size_t argIndex,
921 SmallVectorImpl<PortInfo> &newArgs,
922 SmallVectorImpl<Value> &lowering) {
925 SmallVector<FlatBundleFieldEntry> fieldTypes;
926 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
927 if (!
peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
930 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
931 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
932 newArgs[argIndex].loc))) {
933 encounteredError =
true;
937 for (
const auto &[idx, field, fieldSym] :
938 llvm::enumerate(fieldTypes, fieldSyms)) {
939 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
940 field, newArgs[argIndex], fieldSym);
941 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
943 lowering.push_back(newValue.first);
948static Value
cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
950 if (
auto rop = llvm::dyn_cast<SubfieldOp>(op))
951 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
952 if (
auto rop = llvm::dyn_cast<SubindexOp>(op))
953 return SubindexOp::create(*builder, rhs, rop.getIndex());
954 if (
auto rop = llvm::dyn_cast<SubaccessOp>(op))
955 return SubaccessOp::create(*builder, rhs, rop.getIndex());
956 op->emitError(
"Unknown accessor");
960void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
961 ArrayRef<Operation *> writePath) {
962 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
963 FVectorType saoType = sao.getInput().getType();
964 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
966 for (
size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
967 auto cond = EQPrimOp::create(
968 *builder, sao.getIndex(),
969 builder->createOrFold<ConstantOp>(UIntType::get(
context, selectWidth),
970 APInt(selectWidth, index)));
971 WhenOp::create(*builder, cond,
false, [&]() {
973 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
974 for (
int i = writePath.size() - 2; i >= 0; --i) {
975 if (
auto access =
cloneAccess(builder, writePath[i], leaf))
978 encounteredError =
true;
989bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
990 if (processSAPath(op))
994 SmallVector<FlatBundleFieldEntry> fields;
1001 for (
const auto &field :
llvm::enumerate(fields)) {
1002 Value src = getSubWhatever(op.getSrc(), field.index());
1003 Value dest = getSubWhatever(op.getDest(), field.index());
1004 if (field.value().isOutput)
1005 std::swap(src, dest);
1012bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
1013 if (processSAPath(op))
1017 SmallVector<FlatBundleFieldEntry> fields;
1024 for (
const auto &field :
llvm::enumerate(fields)) {
1025 Value src = getSubWhatever(op.getSrc(), field.index());
1026 Value dest = getSubWhatever(op.getDest(), field.index());
1027 if (field.value().isOutput)
1028 std::swap(src, dest);
1029 MatchingConnectOp::create(*builder, dest, src);
1035bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1037 SmallVector<FlatBundleFieldEntry> fields;
1039 if (!
peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1043 for (
const auto &field :
llvm::enumerate(fields)) {
1044 Value src = getSubWhatever(op.getSrc(), field.index());
1045 Value dest = getSubWhatever(op.getDest(), field.index());
1046 assert(!field.value().isOutput &&
"unexpected flip in reftype destination");
1047 RefDefineOp::create(*builder, dest, src);
1052bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1058 lowerBlock(&op.getThenBlock());
1061 if (op.hasElseRegion())
1062 lowerBlock(&op.getElseBlock());
1067bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1068 lowerBlock(op.getBody());
1074bool TypeLoweringVisitor::visitDecl(MemOp op) {
1076 SmallVector<FlatBundleFieldEntry> fields;
1079 if (!
peelType(op.getDataType(), fields, memoryPreservationMode))
1082 if (op.getInnerSym()) {
1083 op->emitError() <<
"has a symbol, but no symbols may exist on aggregates "
1084 "passed through LowerTypes";
1085 encounteredError =
true;
1089 SmallVector<MemOp> newMemories;
1090 SmallVector<WireOp> oldPorts;
1093 for (
unsigned int index = 0, end = op.getNumResults(); index <
end; ++index) {
1094 auto result = op.getResult(index);
1095 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1096 op.emitOpError(
"cannot lower memory with debug port");
1097 encounteredError =
true;
1101 WireOp::create(*builder, result.getType(),
1102 (op.getName() +
"_" + op.getPortName(index)).str());
1103 oldPorts.push_back(wire);
1104 result.replaceAllUsesWith(wire.getResult());
1111 for (
const auto &field : fields) {
1113 if (!newMemForField) {
1114 op.emitError(
"failed cloning memory for field");
1115 encounteredError =
true;
1118 newMemories.push_back(newMemForField);
1121 for (
size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1122 auto result = oldPorts[index].getResult();
1123 auto rType = type_cast<BundleType>(result.getType());
1124 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
1125 fieldIndex != fend; ++fieldIndex) {
1126 auto name = rType.getElement(fieldIndex).name.getValue();
1127 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1130 if (name ==
"data" || name ==
"mask" || name ==
"wdata" ||
1131 name ==
"wmask" || name ==
"rdata") {
1132 for (
const auto &field : fields) {
1133 auto realOldField = getSubWhatever(oldField, field.index);
1134 auto newField = getSubWhatever(
1135 newMemories[field.index].getResult(index), fieldIndex);
1136 if (rType.getElement(fieldIndex).isFlip)
1137 std::swap(realOldField, newField);
1141 for (
auto mem : newMemories) {
1143 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1152bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1153 ImplicitLocOpBuilder theBuilder(extModule.getLoc(),
context);
1154 builder = &theBuilder;
1160 llvm::BitVector argsToRemove;
1161 auto newArgs = extModule.getPorts();
1162 argsToRemove.reserve(newArgs.size());
1164 DomainLoweringHelper domainHelper(
context, extModule.getPortTypes());
1166 size_t argsRemoved = 0;
1167 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1168 SmallVector<Value> lowering;
1169 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1170 argsToRemove.push_back(
true);
1173 argsToRemove.push_back(
false);
1179 if (argsRemoved != 0)
1182 domainHelper.computeDomainMap(newArgs);
1184 SmallVector<NamedAttribute, 8> newModuleAttrs;
1187 for (
auto attr : extModule->getAttrDictionary())
1190 if (attr.
getName() !=
"portDirections" && attr.
getName() !=
"portNames" &&
1191 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1192 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1193 newModuleAttrs.push_back(attr);
1195 SmallVector<Direction> newArgDirections;
1196 SmallVector<Attribute> newArgNames;
1197 SmallVector<Attribute, 8> newArgTypes;
1198 SmallVector<Attribute, 8> newArgSyms;
1199 SmallVector<Attribute, 8> newArgLocations;
1200 SmallVector<Attribute, 8> newArgAnnotations;
1201 SmallVector<Attribute, 8> newArgDomains;
1203 for (
auto &port : newArgs) {
1204 newArgDirections.push_back(port.direction);
1205 newArgNames.push_back(port.name);
1206 newArgTypes.push_back(TypeAttr::get(port.type));
1207 newArgSyms.push_back(port.sym);
1208 newArgLocations.push_back(port.loc);
1209 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1211 domainHelper.rewriteDomain(port.domains);
1213 port.domains = cache.aEmpty;
1215 newArgDomains.push_back(port.domains);
1218 newModuleAttrs.push_back(
1219 NamedAttribute(cache.sPortDirections,
1222 newModuleAttrs.push_back(
1223 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1225 newModuleAttrs.push_back(
1226 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1228 newModuleAttrs.push_back(NamedAttribute(
1229 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1231 newModuleAttrs.push_back(NamedAttribute(
1232 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1234 newModuleAttrs.push_back(
1235 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1238 extModule->setAttrs(newModuleAttrs);
1239 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1240 extModule.setPortSymbols(newArgSyms);
1245bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1246 auto *body =
module.getBodyBlock();
1248 ImplicitLocOpBuilder theBuilder(module.getLoc(),
context);
1249 builder = &theBuilder;
1255 llvm::BitVector argsToRemove;
1256 auto newArgs =
module.getPorts();
1257 argsToRemove.reserve(newArgs.size());
1259 DomainLoweringHelper domainHelper(
context, module.getPortTypes());
1261 size_t argsRemoved = 0;
1262 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1263 SmallVector<Value> lowerings;
1264 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1265 auto arg =
module.getArgument(argIndex);
1266 processUsers(arg, lowerings);
1267 argsToRemove.push_back(
true);
1270 argsToRemove.push_back(
false);
1275 if (argsRemoved != 0) {
1276 body->eraseArguments(argsToRemove);
1280 domainHelper.computeDomainMap(newArgs);
1282 SmallVector<NamedAttribute, 8> newModuleAttrs;
1285 for (
auto attr : module->getAttrDictionary())
1288 if (attr.
getName() !=
"portNames" && attr.
getName() !=
"portDirections" &&
1289 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1290 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1291 newModuleAttrs.push_back(attr);
1293 SmallVector<Direction> newArgDirections;
1294 SmallVector<Attribute> newArgNames;
1295 SmallVector<Attribute> newArgTypes;
1296 SmallVector<Attribute> newArgSyms;
1297 SmallVector<Attribute> newArgLocations;
1298 SmallVector<Attribute, 8> newArgAnnotations;
1299 SmallVector<Attribute> newPortDomains;
1300 for (
auto &port : newArgs) {
1301 newArgDirections.push_back(port.direction);
1302 newArgNames.push_back(port.name);
1303 newArgTypes.push_back(TypeAttr::get(port.type));
1304 newArgSyms.push_back(port.sym);
1305 newArgLocations.push_back(port.loc);
1306 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1308 domainHelper.rewriteDomain(port.domains);
1310 port.domains = cache.aEmpty;
1312 newPortDomains.push_back(port.domains);
1315 newModuleAttrs.push_back(
1316 NamedAttribute(cache.sPortDirections,
1319 newModuleAttrs.push_back(
1320 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1322 newModuleAttrs.push_back(
1323 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1325 newModuleAttrs.push_back(NamedAttribute(
1326 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1328 newModuleAttrs.push_back(NamedAttribute(
1329 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1331 newModuleAttrs.push_back(NamedAttribute(
1332 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1335 module->setAttrs(newModuleAttrs);
1336 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1337 module.setPortSymbols(newArgSyms);
1342bool TypeLoweringVisitor::visitDecl(WireOp op) {
1343 if (op.isForceable())
1346 auto clone = [&](
const FlatBundleFieldEntry &field,
1347 ArrayAttr attrs) -> Value {
1348 return WireOp::create(*builder,
1350 "", NameKindEnum::DroppableName, attrs, StringAttr{})
1353 return lowerProducer(op, clone);
1357bool TypeLoweringVisitor::visitDecl(RegOp op) {
1358 if (op.isForceable())
1361 auto clone = [&](
const FlatBundleFieldEntry &field,
1362 ArrayAttr attrs) -> Value {
1363 return RegOp::create(*builder, field.type, op.getClockVal(),
"",
1364 NameKindEnum::DroppableName, attrs, StringAttr{})
1367 return lowerProducer(op, clone);
1371bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1372 if (op.isForceable())
1375 auto clone = [&](
const FlatBundleFieldEntry &field,
1376 ArrayAttr attrs) -> Value {
1377 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1378 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1379 op.getResetSignal(), resetVal,
"",
1380 NameKindEnum::DroppableName, attrs, StringAttr{})
1383 return lowerProducer(op, clone);
1387bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1388 if (op.isForceable())
1391 auto clone = [&](
const FlatBundleFieldEntry &field,
1392 ArrayAttr attrs) -> Value {
1393 auto input = getSubWhatever(op.getInput(), field.index);
1394 return NodeOp::create(*builder, input,
"", NameKindEnum::DroppableName,
1398 return lowerProducer(op, clone);
1402bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1403 auto clone = [&](
const FlatBundleFieldEntry &field,
1404 ArrayAttr attrs) -> Value {
1405 return InvalidValueOp::create(*builder, field.type);
1407 return lowerProducer(op, clone);
1411bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1412 auto clone = [&](
const FlatBundleFieldEntry &field,
1413 ArrayAttr attrs) -> Value {
1414 auto high = getSubWhatever(op.getHigh(), field.index);
1415 auto low = getSubWhatever(op.getLow(), field.index);
1416 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1418 return lowerProducer(op, clone);
1422bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1423 auto clone = [&](
const FlatBundleFieldEntry &field,
1424 ArrayAttr attrs) -> Value {
1425 auto high = getSubWhatever(op.getHigh(), field.index);
1426 auto low = getSubWhatever(op.getLow(), field.index);
1427 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1429 return lowerProducer(op, clone);
/// Lower an aggregate-typed Mux4CellIntrinsicOp into one 4:1 mux cell per
/// flattened field. The select signal is shared; the four data operands
/// (v3..v0) are each projected onto the field, keeping their order.
1433bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1434 auto clone = [&](
const FlatBundleFieldEntry &field,
1435 ArrayAttr attrs) -> Value {
1436 auto v3 = getSubWhatever(op.getV3(), field.index);
1437 auto v2 = getSubWhatever(op.getV2(), field.index);
1438 auto v1 = getSubWhatever(op.getV1(), field.index);
1439 auto v0 = getSubWhatever(op.getV0(), field.index);
1440 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1442 return lowerProducer(op, clone);
/// Lower an UnrealizedConversionCastOp into one cast per flattened field (via
/// lowerProducer). Each per-field cast produces the field type and takes the
/// matching field of operand 0.
/// The FIRRTL-type check on operand 0 runs after `clone` is defined.
/// NOTE(review): the statement it guards is not visible in this listing;
/// presumably casts from non-FIRRTL operands are left alone — confirm.
1446bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1447 mlir::UnrealizedConversionCastOp op) {
1448 auto clone = [&](
const FlatBundleFieldEntry &field,
1449 ArrayAttr attrs) -> Value {
1450 auto input = getSubWhatever(op.getOperand(0), field.index);
1451 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1456 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1458 return lowerProducer(op, clone);
/// Lower a BitCastOp whose input and/or result is an aggregate.
///
/// Step 1 turns the input into a single UInt value (`srcLoweredVal`).
/// An aggregate input is handled field by field: each field is bitcast to a
/// UInt of its width, and the results are concatenated. Other inputs are
/// converted with AsUIntPrimOp.
///
/// Step 2 depends on the result type. For a bundle or vector result, the
/// UInt is sliced back into the result's fields via lowerProducer. For a
/// ground result, the UInt (converted to SInt when the result is signed)
/// replaces the op's uses directly.
///
/// NOTE(review): several lines are missing from this listing. These include
/// the call that fills `fields` and the branch conditions around the
/// concatenation. The description above assumes the evident if/else
/// structure — confirm against the full source.
1462bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1463 Value srcLoweredVal = op.getInput();
1467 SmallVector<FlatBundleFieldEntry> fields;
// Running count of bits already concatenated into srcLoweredVal.
1469 size_t uptoBits = 0;
1472 for (
const auto &field :
llvm::enumerate(fields)) {
1473 auto fieldBitwidth = *
getBitWidth(field.value().type);
// Zero-width fields are special-cased; the guarded statement is not
// visible (presumably they are skipped).
1475 if (fieldBitwidth == 0)
1477 Value src = getSubWhatever(op.getInput(), field.index());
// The field may itself be an aggregate, so normalize it to a UInt of the
// same width before concatenating.
1479 src = builder->createOrFold<BitCastOp>(
1480 UIntType::get(
context, fieldBitwidth), src);
1483 srcLoweredVal = src;
// The concatenation order differs by input kind. For bundles, the value
// accumulated so far is the first Cat operand (the high-order bits), so
// earlier fields end up at the MSB end. For vectors, the new element is
// the first operand, so higher indices end up at the MSB end. The
// extraction in step 2 uses the same layout.
1485 if (type_isa<BundleType>(op.getInput().getType())) {
1487 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1490 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1494 uptoBits += fieldBitwidth;
// Non-aggregate input (presumably the else branch of the peel step):
// reinterpret it as a UInt.
1497 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
// Step 2: aggregate result — slice the flattened UInt into its fields.
1501 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
// Bits already extracted for earlier result fields. `clone` captures this
// by reference and advances it on every call. It therefore relies on
// lowerProducer invoking `clone` once per field, in field order
// (presumably true — confirm).
1503 size_t uptoBits = 0;
1504 auto aggregateBits = *
getBitWidth(op.getResult().getType());
1505 auto clone = [&](
const FlatBundleFieldEntry &field,
1506 ArrayAttr attrs) -> Value {
// NOTE(review): the guard for this early return is not visible; presumably
// it handles zero-width fields, which have no bits to extract.
1512 return InvalidValueOp::create(*builder, field.type);
// Bundles are sliced from the MSB end, vectors from the LSB end, matching
// the concatenation layout used in step 1.
1517 if (type_isa<BundleType>(op.getResult().getType())) {
1518 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1519 aggregateBits - uptoBits - 1,
1520 aggregateBits - uptoBits - fieldBits);
1522 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1523 uptoBits + fieldBits - 1, uptoBits);
1525 uptoBits += fieldBits;
// Bitcast the slice back to the field type, which may itself be an
// aggregate.
1526 return BitCastOp::create(*builder, field.type,
extractBits);
1528 return lowerProducer(op, clone);
// Ground result: the flattened UInt replaces the op, converted to SInt when
// the result type is signed.
1532 if (type_isa<SIntType>(op.getType()))
1533 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1534 op.getResult().replaceAllUsesWith(srcLoweredVal);
/// Lower a RefSendOp of an aggregate into one RefSendOp per flattened field,
/// each sending the matching field of the base value.
1538bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1539 auto clone = [&](
const FlatBundleFieldEntry &field,
1540 ArrayAttr attrs) -> Value {
1541 return RefSendOp::create(*builder,
1542 getSubWhatever(op.getBase(), field.index));
1547 return lowerProducer(op, clone);
/// Lower a RefResolveOp into one RefResolveOp per flattened field of the
/// reference; each resolves the matching field of the ref operand.
/// lowerProducer is passed the ref operand's type explicitly. Presumably this
/// makes the fields come from the reference type rather than the op's result
/// type — confirm against lowerProducer's signature.
1550bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1551 auto clone = [&](
const FlatBundleFieldEntry &field,
1552 ArrayAttr attrs) -> Value {
1553 Value src = getSubWhatever(op.getRef(), field.index);
1554 return RefResolveOp::create(*builder, src);
1558 return lowerProducer(op, clone, op.getRef().getType());
/// Lower a RefCastOp into one RefCastOp per flattened field.
/// Each new result type is a RefType of the field type. It keeps the
/// forceable flag and layer of the original result type.
1561bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1562 auto clone = [&](
const FlatBundleFieldEntry &field,
1563 ArrayAttr attrs) -> Value {
1564 auto input = getSubWhatever(op.getInput(), field.index);
1565 return RefCastOp::create(*builder,
1566 RefType::get(field.type,
1567 op.getType().getForceable(),
1568 op.getType().getLayer()),
1571 return lowerProducer(op, clone);
/// Shared lowering for instance-like ops. The InstanceOp and
/// InstanceChoiceOp visitors call it with a `createNewInstance` callback.
///
/// Each result port whose type peelType can split under `mode` is replaced
/// by one port per field:
///  - name: the old port name plus the field suffix;
///  - direction: the old direction, XORed with the field's isOutput flag;
///  - domain: copied from the old port.
/// Ports that peelType does not split are kept as-is, including their
/// annotations (or an empty ArrayAttr when there are none).
///
/// The callback then builds the replacement op from the new port lists, and
/// discardable attributes are copied over. Finally, each old result is
/// rewired to its slice of new results.
1576bool TypeLoweringVisitor::lowerInstanceLike(
1578 ArrayAttr oldPortAnno,
1579 llvm::function_ref<Operation *(ArrayRef<Type>, ArrayRef<Direction>,
1580 ArrayAttr, ArrayAttr, ArrayAttr,
1582 createNewInstance) {
1584 SmallVector<Type, 8> resultTypes;
// Prefix offsets into the new result list: old result i is replaced by new
// results [endFields[i], endFields[i + 1]).
1585 SmallVector<int64_t, 8> endFields;
1586 SmallVector<Direction> newDirs;
1587 SmallVector<Attribute> newNames, newDomains, newPortAnno;
1590 DomainLoweringHelper domainHelper(
context, op->getResultTypes());
1591 auto emptyAnno = builder->getArrayAttr({});
1593 endFields.push_back(0);
1594 for (
size_t i = 0, e = op->getNumResults(); i != e; ++i) {
1595 auto srcType = type_cast<FIRRTLType>(op->getResult(i).getType());
1598 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1599 if (!
peelType(srcType, fieldTypes, mode)) {
1600 newDirs.push_back(op.getPortDirection(i));
1601 newNames.push_back(op.getPortNameAttr(i));
1602 newDomains.push_back(op.getPortDomain(i));
1603 resultTypes.push_back(srcType);
1604 newPortAnno.push_back(oldPortAnno ? oldPortAnno[i] : emptyAnno);
1607 auto oldName = op.getPortName(i);
1608 auto oldDir = op.getPortDirection(i);
1610 for (
const auto &field : fieldTypes) {
// XOR with isOutput inverts the port direction for flipped fields.
1611 newDirs.push_back(
direction::get((
unsigned)oldDir ^ field.isOutput));
1612 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1613 newDomains.push_back(op.getPortDomain(i));
1618 dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]),
// NOTE(review): the unsplit-port path above checks oldPortAnno for null
// before indexing. The field annotations here index oldPortAnno[i] with no
// visible check. Confirm that oldPortAnno is non-null here, or that a guard
// exists in the lines not shown.
1621 newPortAnno.push_back(annos);
1624 endFields.push_back(resultTypes.size());
// Build the domain map over the new port list, then rewrite every new
// port's domain through it. Presumably this remaps domain references from
// old port indices to new ones — confirm in DomainLoweringHelper.
1634 domainHelper.computeDomainMap(resultTypes);
1637 for (
auto &domain : newDomains)
1638 domainHelper.rewriteDomain(domain);
1641 auto *newInstance = createNewInstance(
1642 resultTypes, newDirs, builder->getArrayAttr(newNames),
1643 builder->getArrayAttr(newDomains), builder->getArrayAttr(newPortAnno),
1644 sym ? hw::InnerSymAttr::get(sym) :
hw::InnerSymAttr());
// Preserve any discardable attributes of the original op.
1646 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1648 SmallVector<Value> lowered;
1649 for (
size_t aggIndex = 0, eAgg = op->getNumResults(); aggIndex != eAgg;
1652 for (
size_t fieldIndex = endFields[aggIndex],
1653 eField = endFields[aggIndex + 1];
1654 fieldIndex < eField; ++fieldIndex)
1655 lowered.push_back(newInstance->getResult(fieldIndex));
// If an old result was split, or its replacement has a different type,
// processUsers rewrites its users against the lowered values. Otherwise the
// single same-typed replacement is substituted directly (presumably the
// RAUW below is the else branch).
1656 if (lowered.size() != 1 ||
1657 op->getResult(aggIndex).getType() != resultTypes[endFields[aggIndex]])
1658 processUsers(op->getResult(aggIndex), lowered);
1660 op->getResult(aggIndex).replaceAllUsesWith(lowered[0]);
/// Lower the aggregate ports of an InstanceOp via lowerInstanceLike.
/// The rebuilt InstanceOp keeps these from the original op: module name,
/// instance name, annotations, layers, lowerToBind and doNotPrint. It takes
/// the flattened result types, port names, domains and port annotations.
/// NOTE(review): the line that computes `mode` is only partly visible. It
/// uses the module referenced by this instance (looked up through symTbl) —
/// confirm which helper it calls.
1665bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1668 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1671 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1672 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1673 ArrayAttr newDomains, ArrayAttr newPortAnno,
1674 hw::InnerSymAttr sym) -> Operation * {
1676 return InstanceOp::create(
1677 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1679 newNames, newDomains, op.getAnnotations(), newPortAnno,
1680 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1684 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
/// Lower the aggregate ports of an InstanceChoiceOp via lowerInstanceLike.
/// The port preservation mode comes from the choice's default target module.
/// The rebuilt op keeps these from the original op: module names, case names,
/// instance name and name kind, annotations, layers and instance macro. It
/// takes the new inner symbol and flattened port annotations.
1688bool TypeLoweringVisitor::visitDecl(InstanceChoiceOp op) {
1690 auto *moduleOp = symTbl.lookupNearestSymbolFrom(
1691 op, cast<FlatSymbolRefAttr>(op.getDefaultTargetAttr()));
1692 auto mode = getPreservationModeForPorts(cast<FModuleLike>(moduleOp));
1695 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1696 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1697 ArrayAttr newDomains, ArrayAttr newPortAnno,
1698 hw::InnerSymAttr sym) -> Operation * {
1699 return InstanceChoiceOp::create(
1700 *builder, resultTypes, op.getModuleNames(), op.getCaseNames(),
1701 op.getNameAttr(), op.getNameKindAttr(),
1703 op.getAnnotations(), newPortAnno, op.getLayersAttr(), sym,
1704 op.getInstanceMacroAttr());
1707 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
/// Lower a dynamically indexed vector access (SubaccessOp) into static
/// accesses. There are three cases:
///  - an empty vector yields an invalid value;
///  - a constant index becomes a SubindexOp;
///  - any other index selects among all elements with a MultibitMuxOp.
1711bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1712 auto input = op.getInput();
1713 FVectorType vType = input.getType();
// Accessing an empty vector: replace with an invalid value of the element
// type.
1716 if (vType.getNumElements() == 0) {
1717 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1718 op.replaceAllUsesWith(inv);
// Constant index: use a static SubindexOp.
// NOTE(review): no range check on the constant is visible here, and
// getExtValue() requires the value to fit in 64 bits. Confirm that an
// in-range index is guaranteed earlier.
1723 if (ConstantOp arg =
1724 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1725 auto sio = SubindexOp::create(*builder, op.getInput(),
1726 arg.getValue().getExtValue());
1727 op.replaceAllUsesWith(sio.getResult());
// General case: gather every element in descending index order, then select
// among them with a MultibitMuxOp on the original index. Presumably
// MultibitMuxOp expects the highest-index input first — confirm in its op
// definition.
1732 SmallVector<Value> inputs;
1733 inputs.reserve(vType.getNumElements());
1734 for (
int index = vType.getNumElements() - 1; index >= 0; index--)
1735 inputs.push_back(SubindexOp::create(*builder, input, index));
1737 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1738 op.replaceAllUsesWith(multibitMux);
/// Lower an aggregate VectorCreateOp. Each flattened field of the result is
/// simply the operand at the field's index; no new op is created.
1742bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1743 auto clone = [&](
const FlatBundleFieldEntry &field,
1744 ArrayAttr attrs) -> Value {
1745 return op.getOperand(field.index);
1747 return lowerProducer(op, clone);
/// Lower an aggregate BundleCreateOp. Each flattened field of the result is
/// simply the operand at the field's index; no new op is created.
1750bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1751 auto clone = [&](
const FlatBundleFieldEntry &field,
1752 ArrayAttr attrs) -> Value {
1753 return op.getOperand(field.index);
1755 return lowerProducer(op, clone);
/// Lower an ElementwiseOrPrimOp field by field.
/// If a field is still an aggregate (bundle or vector — presumably possible
/// when aggregates are preserved), a narrower ElementwiseOrPrimOp of the
/// field type is emitted. Otherwise a scalar OrPrimOp is emitted. Either way,
/// the operands are the matching fields of lhs and rhs.
1758bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1759 auto clone = [&](
const FlatBundleFieldEntry &field,
1760 ArrayAttr attrs) -> Value {
1761 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1762 getSubWhatever(op.getRhs(), field.index)};
1763 return type_isa<BundleType, FVectorType>(field.type)
1764 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1766 : (Value)OrPrimOp::create(*builder, operands);
1769 return lowerProducer(op, clone);
/// Lower an ElementwiseAndPrimOp field by field.
/// If a field is still an aggregate (bundle or vector — presumably possible
/// when aggregates are preserved), a narrower ElementwiseAndPrimOp of the
/// field type is emitted. Otherwise a scalar AndPrimOp is emitted. Either
/// way, the operands are the matching fields of lhs and rhs.
1772bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1773 auto clone = [&](
const FlatBundleFieldEntry &field,
1774 ArrayAttr attrs) -> Value {
1775 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1776 getSubWhatever(op.getRhs(), field.index)};
1777 return type_isa<BundleType, FVectorType>(field.type)
1778 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1780 : (Value)AndPrimOp::create(*builder, operands);
1783 return lowerProducer(op, clone);
/// Lower an ElementwiseXorPrimOp field by field.
/// If a field is still an aggregate (bundle or vector — presumably possible
/// when aggregates are preserved), a narrower ElementwiseXorPrimOp of the
/// field type is emitted. Otherwise a scalar XorPrimOp is emitted. Either
/// way, the operands are the matching fields of lhs and rhs.
1786bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1787 auto clone = [&](
const FlatBundleFieldEntry &field,
1788 ArrayAttr attrs) -> Value {
1789 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1790 getSubWhatever(op.getRhs(), field.index)};
1791 return type_isa<BundleType, FVectorType>(field.type)
1792 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1794 : (Value)XorPrimOp::create(*builder, operands);
1797 return lowerProducer(op, clone);
/// Lower an aggregate MultibitMuxOp into one MultibitMuxOp per flattened
/// field. The index is shared. The inputs are the matching field of each
/// original input, in the original input order.
1800bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1801 auto clone = [&](
const FlatBundleFieldEntry &field,
1802 ArrayAttr attrs) -> Value {
1803 SmallVector<Value> newInputs;
1804 newInputs.reserve(op.getInputs().size());
1805 for (
auto input : op.getInputs()) {
1806 auto inputSub = getSubWhatever(input, field.index);
1807 newInputs.push_back(inputSub);
1809 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1811 return lowerProducer(op, clone);
/// Pass wrapper for FIRRTL aggregate-type lowering. It derives from the
/// tablegen-generated LowerFIRRTLTypesBase (enabled by
/// GEN_PASS_DEF_LOWERFIRRTLTYPES at the top of the file).
1819struct LowerTypesPass
1820 :
public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1823 void runOnOperation()
override;
/// Lower aggregate types in every FModuleLike of the circuit.
///
/// First, build a table mapping each module to its port convention. A module
/// instantiated by any InstanceChoiceOp is forced to Convention::Scalarized.
///
/// Then lower every module in parallel with a TypeLoweringVisitor. The
/// visitor's body convention comes from the module's "body_type_lowering"
/// discardable attribute, defaulting to Convention::Internal.
1828void LowerTypesPass::runOnOperation() {
1831 std::vector<FModuleLike> ops;
1832 auto &instanceGraph = getAnalysis<InstanceGraph>();
1834 auto &symTbl = getAnalysis<SymbolTable>();
1836 AttrCache cache(&getContext());
1838 DenseMap<FModuleLike, Convention> conventionTable;
1839 auto circuit = getOperation();
1840 for (
auto module : circuit.getOps<FModuleLike>()) {
1841 auto convention =
module.getConvention();
// If any use of this module is an InstanceChoiceOp, scalarize its ports.
// Presumably this is because every alternative of an instance choice must
// present the same port list — confirm.
1844 if (llvm::any_of(instanceGraph.lookup(module)->uses(),
1846 return use->getInstance<InstanceChoiceOp>();
1848 convention = Convention::Scalarized;
1849 conventionTable.insert({module, convention});
1850 ops.push_back(module);
1854 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1856 Convention convention = Convention::Internal;
1857 if (
auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1858 op->getDiscardableAttr(
"body_type_lowering")))
1859 convention = conventionAttr.getValue();
1862 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1863 preserveMemories, symTbl, cache, conventionTable);
1866 return LogicalResult::failure(tl.isFailed());
// `ops` and `conventionTable` are fully populated above, before the
// parallel walk starts.
1869 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
// The condition guarding this call is not visible; presumably it checks
// `result` for failure.
1872 signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType, wrapped in the same kind of ref as type when type is a ref type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return true if the operation is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
This is an edge in the InstanceGraph.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
@ All
Preserve all aggregate values.
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
@ None
Don't preserve aggregate at all.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Direction
This represents the direction of a single port.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.
AnnotationSet annotations