39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/Support/Debug.h"
47#define DEBUG_TYPE "firrtl-lower-types"
51#define GEN_PASS_DEF_LOWERFIRRTLTYPES
52#include "circt/Dialect/FIRRTL/Passes.h.inc"
57using namespace firrtl;
/// One element produced by peeling a single level off an aggregate type (see
/// peelType). It records the element's type, its index within the parent
/// aggregate, its field ID, the name suffix for the lowered value, and the
/// isOutput flag. When isOutput is set, the connect lowering swaps source and
/// destination for this element.
62struct FlatBundleFieldEntry {
// Appended to the parent value's (or port's) name to name the lowered
// element.
70 SmallString<16> suffix;
75 unsigned fieldID, StringRef suffix,
bool isOutput)
76 : type(type), index(index), fieldID(fieldID), suffix(suffix),
// Debug output: prints every member of the entry to llvm::errs().
80 llvm::errs() <<
"FBFE{" << type <<
" index<" << index <<
"> fieldID<"
81 << fieldID <<
"> suffix<" << suffix <<
"> isOutput<"
82 << isOutput <<
">}\n";
89 return mapBaseType(type, [&](
auto) {
return fieldType; });
94 auto ftype = type_dyn_cast<FIRRTLType>(type);
103 .
Case<BundleType>([&](
auto bundle) {
return false; })
104 .Case<FVectorType>([&](FVectorType vector) {
106 return vector.getElementType().isGround() &&
107 vector.getNumElements() > 1;
109 .Default([](
auto groundType) {
return true; });
116 .
Case<BundleType>([&](
auto bundle) {
return true; })
117 .Case<FVectorType>([&](FVectorType vector) {
120 .Default([](
auto groundType) {
return false; });
127 if (
auto refType = type_dyn_cast<RefType>(type)) {
129 if (refType.getForceable())
140 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
146 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
158 llvm_unreachable(
"unexpected mode");
/// Peel one level off `type`, appending one FlatBundleFieldEntry to `fields`
/// for each bundle field or vector element. A reference type is looked
/// through to its referenced type first. Any type other than a bundle or a
/// vector reaches the Default case, which returns false (not split). Callers
/// pass the aggregate-preservation mode they want applied as the third
/// argument.
164static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
// Look through references to the referenced type.
171 if (
auto refType = type_dyn_cast<RefType>(type))
172 type = refType.getType();
// Bundle: one entry per field, carrying the field's type, index and field ID.
// The suffix is built by appending '_' and the field name to tmpSuffix.
174 .
Case<BundleType>([&](
auto bundle) {
175 SmallString<16> tmpSuffix;
177 for (
size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
178 auto elt = bundle.getElement(i);
181 tmpSuffix.push_back(
'_');
182 tmpSuffix.append(elt.name.getValue());
183 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
// Vector: one entry per element, suffixed "_<index>" and never marked
// isOutput.
188 .Case<FVectorType>([&](
auto vector) {
190 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
191 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
192 "_" + std::to_string(i),
false);
// Every other type is left unsplit.
196 .Default([](
auto op) {
return false; });
202 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
206 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
207 return arg && sao.getInput().getType().base().getNumElements() != 0;
212 SmallVector<Operation *> retval;
213 auto defOp = op->getOperand(0).getDefiningOp();
214 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
215 retval.push_back(defOp);
216 defOp = defOp->getOperand(0).getDefiningOp();
226 FlatBundleFieldEntry field) {
227 SmallVector<Type, 8> ports;
228 SmallVector<Attribute, 8> portNames;
229 SmallVector<Attribute, 8> portLocations;
231 auto oldPorts = op.getPorts();
232 for (
size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
233 auto port = oldPorts[portIdx];
235 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
236 portNames.push_back(port.first);
241 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
242 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
243 (op.getName() + field.suffix).str(), op.getNameKind(),
244 op.getAnnotations(), op.getPortAnnotations(),
245 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
247 if (op.getInnerSym()) {
248 op.emitError(
"cannot split memory with symbol present");
252 SmallVector<Attribute> newAnnotations;
253 for (
size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
254 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
255 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
256 SmallVector<Attribute> portAnno;
257 for (
auto attr : newMem.getPortAnnotation(portIdx)) {
260 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
264 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
267 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
268 portAnno.push_back(anno.
getDict());
273 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
278 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
279 if (annoFieldID >= fieldID &&
284 annoFieldID - fieldID + portType.getFieldID(targetIndex);
285 anno.
setMember(
"circt.fieldID", b->getI32IntegerAttr(newFieldID));
286 portAnno.push_back(anno.
getDict());
290 portAnno.push_back(attr);
292 newAnnotations.push_back(b->getArrayAttr(portAnno));
294 newMem.setAllPortAnnotations(newAnnotations);
304 AttrCache(MLIRContext *
context) {
305 i64ty = IntegerType::get(
context, 64);
306 nameAttr = StringAttr::get(
context,
"name");
307 nameKindAttr = StringAttr::get(
context,
"nameKind");
308 sPortDirections = StringAttr::get(
context,
"portDirections");
309 sPortNames = StringAttr::get(
context,
"portNames");
310 sPortTypes = StringAttr::get(
context,
"portTypes");
311 sPortSymbols = StringAttr::get(
context,
"portSymbols");
312 sPortLocations = StringAttr::get(
context,
"portLocations");
313 sPortAnnotations = StringAttr::get(
context,
"portAnnotations");
314 sPortDomains = StringAttr::get(
context,
"domainInfo");
315 sEmpty = StringAttr::get(
context,
"");
316 aEmpty = ArrayAttr::get(
context, {});
318 AttrCache(
const AttrCache &) =
default;
321 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
322 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
/// Remaps port domain associations after a module's ports are lowered.
/// A port's domain information, when it is an ArrayAttr, holds integer
/// indices of domain-typed ports. Lowering aggregate ports inserts and removes
/// ports, so those indices must be translated from the original port
/// positions to the new ones.
///
/// Usage pattern in the module visitors:
///   1. Construct the helper from the original port types.
///   2. Call computeDomainMap with the lowered ports.
///   3. Call rewriteDomain on each port's domains.
329class DomainLoweringHelper {
// Record, in order, the original index of every port whose type attribute
// is a DomainType.
334 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
336 for (
auto [index, typeAttr] :
llvm::enumerate(portTypes))
337 if (
type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
338 domainIndexByOrdinal.push_back(index);
// Variant of the constructor above that takes a list of result types instead
// of port-type attributes.
342 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
344 for (
auto [index, type] :
llvm::enumerate(resultTypes))
346 domainIndexByOrdinal.push_back(index);
// Fill domainMap: the k-th domain-typed entry of `types` (at new position i)
// is mapped from the k-th recorded original index. Indexing
// domainIndexByOrdinal by ordinal assumes lowering keeps the number and
// relative order of domain-typed entries unchanged.
353 void computeDomainMap(TypeRange types) {
354 size_t i = 0, ord = 0;
355 for (
auto type : types) {
356 if (type_isa<DomainType>(type))
357 domainMap[domainIndexByOrdinal[ord++]] = i;
// Same as above, for a list of PortInfo.
366 void computeDomainMap(ArrayRef<PortInfo> ports) {
367 size_t i = 0, ord = 0;
368 for (
const auto &port : ports) {
369 if (type_isa<DomainType>(port.type))
370 domainMap[domainIndexByOrdinal[ord++]] = i;
// Rewrite `domain` in place when it is an ArrayAttr of port indices. Each
// index is replaced by its new position, encoded as a 32-bit unsigned
// IntegerAttr. Values that are not an ArrayAttr are not rewritten.
// NOTE(review): DenseMap::operator[] default-inserts 0 for an index that is
// missing from domainMap, so an unmapped index silently becomes 0.
378 void rewriteDomain(Attribute &domain) {
379 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
380 if (!oldAssociations)
382 SmallVector<Attribute> newAssociations;
383 for (
auto oldAttr : oldAssociations)
384 newAssociations.push_back(IntegerAttr::
get(
385 IntegerType::
get(context, 32, IntegerType::Unsigned),
386 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
387 domain = ArrayAttr::get(context, newAssociations);
// Context used to build the rewritten attributes.
391 MLIRContext *context;
// Original indices of the domain-typed entries, in order of appearance.
393 SmallVector<unsigned> domainIndexByOrdinal;
// Original domain-port index -> new port index.
395 DenseMap<unsigned, unsigned> domainMap;
400struct TypeLoweringVisitor :
public FIRRTLVisitor<TypeLoweringVisitor, bool> {
404 Convention bodyConvention,
406 SymbolTable &symTbl,
const AttrCache &cache,
407 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
408 : context(context), defaultAggregatePreservationMode(preserveAggregate),
409 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
410 cache(cache), conventionTable(conventionTable) {
411 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
413 : defaultAggregatePreservationMode;
421 void lowerModule(FModuleLike op);
423 bool lowerArg(FModuleLike module,
size_t argIndex,
size_t argsRemoved,
424 SmallVectorImpl<PortInfo> &newArgs,
425 SmallVectorImpl<Value> &lowering);
426 std::pair<Value, PortInfo> addArg(Operation *module,
unsigned insertPt,
428 const FlatBundleFieldEntry &field,
429 PortInfo &oldArg, hw::InnerSymAttr newSym);
432 bool visitDecl(FExtModuleOp op);
433 bool visitDecl(FModuleOp op);
434 bool visitDecl(InstanceOp op);
435 bool visitDecl(InstanceChoiceOp op);
436 bool visitDecl(MemOp op);
437 bool visitDecl(NodeOp op);
438 bool visitDecl(RegOp op);
439 bool visitDecl(WireOp op);
440 bool visitDecl(RegResetOp op);
441 bool visitExpr(InvalidValueOp op);
442 bool visitExpr(SubaccessOp op);
443 bool visitExpr(VectorCreateOp op);
444 bool visitExpr(BundleCreateOp op);
445 bool visitExpr(ElementwiseAndPrimOp op);
446 bool visitExpr(ElementwiseOrPrimOp op);
447 bool visitExpr(ElementwiseXorPrimOp op);
448 bool visitExpr(MultibitMuxOp op);
449 bool visitExpr(MuxPrimOp op);
450 bool visitExpr(Mux2CellIntrinsicOp op);
451 bool visitExpr(Mux4CellIntrinsicOp op);
452 bool visitExpr(BitCastOp op);
453 bool visitExpr(RefSendOp op);
454 bool visitExpr(RefResolveOp op);
455 bool visitExpr(RefCastOp op);
456 bool visitStmt(ConnectOp op);
457 bool visitStmt(MatchingConnectOp op);
458 bool visitStmt(RefDefineOp op);
459 bool visitStmt(WhenOp op);
460 bool visitStmt(LayerBlockOp op);
461 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
463 bool isFailed()
const {
return encounteredError; }
466 if (
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
467 return visitUnrealizedConversionCast(castOp);
472 void processUsers(Value val, ArrayRef<Value> mapping);
473 bool processSAPath(Operation *);
474 void lowerBlock(Block *);
475 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
485 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
490 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
491 FIRRTLType srcType, FlatBundleFieldEntry field);
495 LogicalResult partitionSymbols(hw::InnerSymAttr sym,
FIRRTLType parentType,
496 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
500 getPreservationModeForPorts(FModuleLike moduleLike);
501 Value getSubWhatever(Value val,
size_t index);
506 ArrayAttr oldPortAnno,
507 llvm::function_ref<Operation *(
508 ArrayRef<Type>, ArrayRef<Direction>, ArrayAttr,
509 ArrayAttr, ArrayAttr, hw::InnerSymAttr)>
512 size_t uniqueIdx = 0;
513 std::string uniqueName() {
514 auto myID = uniqueIdx++;
515 return (Twine(
"__GEN_") + Twine(myID)).str();
526 ImplicitLocOpBuilder *builder;
532 const AttrCache &cache;
534 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
537 bool encounteredError =
false;
544TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
545 auto lookup = conventionTable.find(module);
546 if (lookup == conventionTable.end())
547 return defaultAggregatePreservationMode;
548 switch (lookup->second) {
549 case Convention::Scalarized:
551 case Convention::Internal:
552 return defaultAggregatePreservationMode;
554 llvm_unreachable(
"Unknown convention");
555 return defaultAggregatePreservationMode;
/// Return an accessor for element `index` of the aggregate `val`, created with
/// the visitor's current builder:
///   - a SubfieldOp for bundles,
///   - a SubindexOp for vectors,
///   - a RefSubOp for references.
/// Any other type is a programming error (llvm_unreachable).
558Value TypeLoweringVisitor::getSubWhatever(Value val,
size_t index) {
559 if (type_isa<BundleType>(val.getType()))
560 return SubfieldOp::create(*builder, val, index);
561 if (type_isa<FVectorType>(val.getType()))
562 return SubindexOp::create(*builder, val, index);
563 if (type_isa<RefType>(val.getType()))
564 return RefSubOp::create(*builder, val, index);
565 llvm_unreachable(
"Unknown aggregate type");
/// Rewrite a write `op` (the ConnectOp and MatchingConnectOp visitors call this
/// first) whose target is reached through a dynamic subaccess.
/// `writePath` holds the accessor ops leading to the written location; its
/// last entry is the subaccess (see lowerSAWritePath).
570bool TypeLoweringVisitor::processSAPath(Operation *op) {
// An empty path means there is nothing to rewrite. Otherwise the write is
// re-expressed per vector element by lowerSAWritePath.
573 if (writePath.empty())
576 lowerSAWritePath(op, writePath);
// Detach the original op from its first two operands (dest and src of a
// connect).
579 op->eraseOperands(0, 2);
// Erase accessor ops on the path that no longer have any uses.
581 for (
size_t i = 0; i < writePath.size(); ++i) {
582 if (writePath[i]->use_empty()) {
583 writePath[i]->erase();
/// Lower every operation in `block`, walking from the last operation to the
/// first. Before each op is visited, the builder is positioned in front of it
/// and given its location, so replacement ops are created just before it.
591void TypeLoweringVisitor::lowerBlock(Block *block) {
// The loop header has no increment: the iterator is advanced inside the
// body, presumably so the visited op can be erased safely.
593 for (
auto it = block->rbegin(), e = block->rend(); it != e;) {
595 builder->setInsertionPoint(&iop);
596 builder->setLoc(iop.getLoc());
// Presumably a true result means `iop` should be removed (per the variable
// name).
597 bool removeOp = dispatchVisitor(&iop);
606ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
607 ArrayAttr annotations,
609 FlatBundleFieldEntry field) {
610 SmallVector<Attribute> retval;
611 if (!annotations || annotations.empty())
612 return ArrayAttr::get(ctxt, retval);
613 for (
auto opAttr : annotations) {
615 auto fieldID = anno.getFieldID();
616 anno.removeMember(
"circt.fieldID");
621 retval.push_back(anno.getAttr());
626 if (fieldID < field.fieldID ||
631 if (
auto newFieldID = fieldID - field.fieldID) {
634 anno.setMember(
"circt.fieldID", builder->getI32IntegerAttr(newFieldID));
637 retval.push_back(anno.getAttr());
639 return ArrayAttr::get(ctxt, retval);
642LogicalResult TypeLoweringVisitor::partitionSymbols(
644 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
647 if (!sym || sym.empty())
650 auto *
context = sym.getContext();
654 return mlir::emitError(errorLoc,
655 "unable to partition symbol on unsupported type ")
658 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
659 .Case<BundleType, FVectorType>([&](
auto aggType) -> LogicalResult {
663 hw::InnerSymPropertiesAttr prop;
667 SmallVector<BinningInfo> binning;
668 for (
auto prop : sym) {
669 auto fieldID = prop.getFieldID();
672 return mlir::emitError(errorLoc,
"unable to lower due to symbol ")
674 <<
" with target not preserved by lowering";
675 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
676 binning.push_back({index, relFieldID, prop});
680 llvm::stable_sort(binning, [&](
auto &lhs,
auto &rhs) {
681 return std::tuple(lhs.index, lhs.relFieldID) <
682 std::tuple(rhs.index, rhs.relFieldID);
687 newSyms.resize(aggType.getNumElements());
688 for (
auto binIt = binning.begin(), binEnd = binning.end();
690 auto curIndex = binIt->index;
691 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
693 while (binIt != binEnd && binIt->index == curIndex) {
694 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
695 context, binIt->prop.getName(), binIt->relFieldID,
696 binIt->prop.getSymVisibility()));
700 assert(!newSyms[curIndex]);
701 newSyms[curIndex] = hw::InnerSymAttr::get(
context, propsForIndex);
705 .Default([&](
auto ty) {
706 return mlir::emitError(
707 errorLoc,
"unable to partition symbol on unsupported type ")
712bool TypeLoweringVisitor::lowerProducer(
714 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
718 srcType = op->getResult(0).getType();
719 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
722 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
724 if (!
peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
727 SmallVector<Value> lowered;
729 SmallString<16> loweredName;
730 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
732 if (
auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
733 loweredName = nameAttr.getValue();
734 auto baseNameLen = loweredName.size();
735 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr(
"annotations"));
737 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
738 if (
auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
739 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
741 encounteredError =
true;
746 for (
const auto &[field, sym] :
llvm::zip_equal(fieldTypes, fieldSyms)) {
747 if (!loweredName.empty()) {
748 loweredName.resize(baseNameLen);
749 loweredName += field.suffix;
754 ArrayAttr loweredAttrs =
755 filterAnnotations(
context, oldAnno, srcFType, field);
756 auto newVal = clone(field, loweredAttrs);
762 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
765 "op with inner symbol lowered to op that cannot take inner symbol");
766 newSymOp.setInnerSymbolAttr(sym);
770 if (
auto *newOp = newVal.getDefiningOp()) {
771 if (!loweredName.empty())
772 newOp->setAttr(cache.nameAttr, StringAttr::get(
context, loweredName));
774 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
777 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
779 lowered.push_back(newVal);
782 processUsers(op->getResult(0), lowered);
/// Replace the uses of the aggregate `val` with its lowered elements
/// `mapping`, where mapping[i] is the value for element/field i.
/// Subindex, subfield and ref.sub users are forwarded directly to the matching
/// element. Every other user gets the aggregate rebuilt from `mapping` at the
/// user's location, just before the user.
786void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
// Users are rewritten while iterating, hence the early-increment range.
787 for (
auto *user :
llvm::make_early_inc_range(val.getUsers())) {
788 TypeSwitch<Operation *, void>(user)
// Element accessors: forward their uses to the corresponding lowered value.
789 .Case<SubindexOp>([mapping](SubindexOp sio) {
790 Value repl = mapping[sio.getIndex()];
791 sio.replaceAllUsesWith(repl);
794 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
796 Value repl = mapping[sfo.getFieldIndex()];
797 sfo.replaceAllUsesWith(repl);
800 .Case<RefSubOp>([mapping](RefSubOp refSub) {
801 Value repl = mapping[refSub.getIndex()];
802 refSub.replaceAllUsesWith(repl);
// Any other user: rebuild the aggregate from the lowered values right before
// the user, then substitute it for `val`.
805 .Default([&](
auto op) {
816 ImplicitLocOpBuilder
b(user->getLoc(), user);
// Rebuilding requires every lowered value to be a base type without
// references.
820 assert(llvm::none_of(mapping, [](
auto v) {
821 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
822 return !fbasetype || fbasetype.containsReference();
// Vectors and bundles are rebuilt with the matching create op. Other types
// yield a null Value.
826 TypeSwitch<Type, Value>(val.getType())
827 .template Case<FVectorType>([&](
auto vecType) {
828 return b.createOrFold<VectorCreateOp>(vecType, mapping);
830 .
template Case<BundleType>([&](
auto bundleType) {
831 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
833 .Default([&](
auto _) -> Value {
return {}; });
// The aggregate could not be rebuilt: report it and mark the pass as failed.
835 user->emitError(
"unable to reconstruct source of type ")
837 encounteredError =
true;
// Point the user at the rebuilt aggregate instead of `val`.
840 user->replaceUsesOfWith(val, input);
848 const llvm::BitVector &removalMask) {
849 size_t writeIndex = 0, readIndex = 0;
853 for (
size_t removalIndex : removalMask.set_bits()) {
855 assert(removalIndex >= readIndex &&
"removal index before read index");
856 size_t rangeSize = removalIndex - readIndex;
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
863 vec.begin() + writeIndex);
864 writeIndex += rangeSize;
866 readIndex = removalIndex + 1;
870 size_t remainingSize = vec.size() - readIndex;
871 if (remainingSize > 0) {
872 if (writeIndex != readIndex)
873 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
874 writeIndex += remainingSize;
878 vec.truncate(writeIndex);
881void TypeLoweringVisitor::lowerModule(FModuleLike op) {
882 if (
auto module = llvm::dyn_cast<FModuleOp>(*op))
884 else if (
auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
885 visitDecl(extModule);
891std::pair<Value, PortInfo>
892TypeLoweringVisitor::addArg(Operation *module,
unsigned insertPt,
894 const FlatBundleFieldEntry &field,
PortInfo &oldArg,
895 hw::InnerSymAttr newSym) {
898 if (
auto mod = llvm::dyn_cast<FModuleOp>(module)) {
899 Block *body = mod.getBodyBlock();
901 newValue = body->insertArgument(insertPt, fieldType, oldArg.
loc);
905 auto name = builder->getStringAttr(oldArg.
name.getValue() + field.suffix);
908 auto newAnnotations = filterAnnotations(
913 return std::make_pair(
914 newValue,
PortInfo{name, fieldType, direction, newSym, oldArg.
loc,
919bool TypeLoweringVisitor::lowerArg(FModuleLike module,
size_t argIndex,
921 SmallVectorImpl<PortInfo> &newArgs,
922 SmallVectorImpl<Value> &lowering) {
925 SmallVector<FlatBundleFieldEntry> fieldTypes;
926 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
927 if (!
peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
930 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
931 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
932 newArgs[argIndex].loc))) {
933 encounteredError =
true;
937 for (
const auto &[idx, field, fieldSym] :
938 llvm::enumerate(fieldTypes, fieldSyms)) {
939 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
940 field, newArgs[argIndex], fieldSym);
941 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
943 lowering.push_back(newValue.first);
/// Re-create the accessor `op` (subfield, subindex or subaccess) on top of
/// `rhs`, reusing the original's field index or index. Any other op is
/// reported with "Unknown accessor". The caller in lowerSAWritePath treats a
/// null result as a failure.
948static Value
cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
950 if (
auto rop = llvm::dyn_cast<SubfieldOp>(op))
951 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
952 if (
auto rop = llvm::dyn_cast<SubindexOp>(op))
953 return SubindexOp::create(*builder, rhs, rop.getIndex());
// A subaccess keeps its dynamic index operand.
954 if (
auto rop = llvm::dyn_cast<SubaccessOp>(op))
955 return SubaccessOp::create(*builder, rhs, rop.getIndex());
956 op->emitError(
"Unknown accessor");
/// Expand a write `op` that goes through the dynamic subaccess at
/// writePath.back() into one `when` per vector element. Each `when` compares
/// the subaccess index with that element's constant index. Its then-block
/// re-applies the remaining accessors of `writePath` on top of a static
/// subindex of that element.
960void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
961 ArrayRef<Operation *> writePath) {
962 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
963 FVectorType saoType = sao.getInput().getType();
// Constant indices are compared at ceil(log2(N)) bits.
964 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
966 for (
size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
// cond := (subaccess index == index)
967 auto cond = EQPrimOp::create(
968 *builder, sao.getIndex(),
969 builder->createOrFold<ConstantOp>(UIntType::get(
context, selectWidth),
970 APInt(selectWidth, index)));
// Guarded region for this element; no else region is created (false).
971 WhenOp::create(*builder, cond,
false, [&]() {
// Static access to element `index` replaces the dynamic subaccess.
973 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
// Re-apply the other accessors, from the one just outside the subaccess
// (writePath.size() - 2) down to writePath[0].
974 for (
int i = writePath.size() - 2; i >= 0; --i) {
975 if (
auto access =
cloneAccess(builder, writePath[i], leaf))
// cloneAccess failed: mark the pass as failed.
978 encounteredError =
true;
989bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
990 if (processSAPath(op))
994 SmallVector<FlatBundleFieldEntry> fields;
1001 for (
const auto &field :
llvm::enumerate(fields)) {
1002 Value src = getSubWhatever(op.getSrc(), field.index());
1003 Value dest = getSubWhatever(op.getDest(), field.index());
1004 if (field.value().isOutput)
1005 std::swap(src, dest);
1012bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
1013 if (processSAPath(op))
1017 SmallVector<FlatBundleFieldEntry> fields;
1024 for (
const auto &field :
llvm::enumerate(fields)) {
1025 Value src = getSubWhatever(op.getSrc(), field.index());
1026 Value dest = getSubWhatever(op.getDest(), field.index());
1027 if (field.value().isOutput)
1028 std::swap(src, dest);
1029 MatchingConnectOp::create(*builder, dest, src);
1035bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1037 SmallVector<FlatBundleFieldEntry> fields;
1039 if (!
peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1043 for (
const auto &field :
llvm::enumerate(fields)) {
1044 Value src = getSubWhatever(op.getSrc(), field.index());
1045 Value dest = getSubWhatever(op.getDest(), field.index());
1046 assert(!field.value().isOutput &&
"unexpected flip in reftype destination");
1047 RefDefineOp::create(*builder, dest, src);
1052bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1058 lowerBlock(&op.getThenBlock());
1061 if (op.hasElseRegion())
1062 lowerBlock(&op.getElseBlock());
1067bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1068 lowerBlock(op.getBody());
1074bool TypeLoweringVisitor::visitDecl(MemOp op) {
1076 SmallVector<FlatBundleFieldEntry> fields;
1079 if (!
peelType(op.getDataType(), fields, memoryPreservationMode))
1082 if (op.getInnerSym()) {
1083 op->emitError() <<
"has a symbol, but no symbols may exist on aggregates "
1084 "passed through LowerTypes";
1085 encounteredError =
true;
1089 SmallVector<MemOp> newMemories;
1090 SmallVector<WireOp> oldPorts;
1093 for (
unsigned int index = 0, end = op.getNumResults(); index <
end; ++index) {
1094 auto result = op.getResult(index);
1095 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1096 op.emitOpError(
"cannot lower memory with debug port");
1097 encounteredError =
true;
1101 WireOp::create(*builder, result.getType(),
1102 (op.getName() +
"_" + op.getPortName(index)).str());
1103 oldPorts.push_back(wire);
1104 result.replaceAllUsesWith(wire.getResult());
1111 for (
const auto &field : fields) {
1113 if (!newMemForField) {
1114 op.emitError(
"failed cloning memory for field");
1115 encounteredError =
true;
1118 newMemories.push_back(newMemForField);
1121 for (
size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1122 auto result = oldPorts[index].getResult();
1123 auto rType = type_cast<BundleType>(result.getType());
1124 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
1125 fieldIndex != fend; ++fieldIndex) {
1126 auto name = rType.getElement(fieldIndex).name.getValue();
1127 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1130 if (name ==
"data" || name ==
"mask" || name ==
"wdata" ||
1131 name ==
"wmask" || name ==
"rdata") {
1132 for (
const auto &field : fields) {
1133 auto realOldField = getSubWhatever(oldField, field.index);
1134 auto newField = getSubWhatever(
1135 newMemories[field.index].getResult(index), fieldIndex);
1136 if (rType.getElement(fieldIndex).isFlip)
1137 std::swap(realOldField, newField);
1141 for (
auto mem : newMemories) {
1143 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1152bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1153 ImplicitLocOpBuilder theBuilder(extModule.getLoc(),
context);
1154 builder = &theBuilder;
1160 llvm::BitVector argsToRemove;
1161 auto newArgs = extModule.getPorts();
1162 argsToRemove.reserve(newArgs.size());
1164 DomainLoweringHelper domainHelper(
context, extModule.getPortTypes());
1166 size_t argsRemoved = 0;
1167 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1168 SmallVector<Value> lowering;
1169 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1170 argsToRemove.push_back(
true);
1173 argsToRemove.push_back(
false);
1179 if (argsRemoved != 0)
1182 domainHelper.computeDomainMap(newArgs);
1184 SmallVector<NamedAttribute, 8> newModuleAttrs;
1187 for (
auto attr : extModule->getAttrDictionary())
1190 if (attr.
getName() !=
"portDirections" && attr.
getName() !=
"portNames" &&
1191 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1192 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1193 newModuleAttrs.push_back(attr);
1195 SmallVector<Direction> newArgDirections;
1196 SmallVector<Attribute> newArgNames;
1197 SmallVector<Attribute, 8> newArgTypes;
1198 SmallVector<Attribute, 8> newArgSyms;
1199 SmallVector<Attribute, 8> newArgLocations;
1200 SmallVector<Attribute, 8> newArgAnnotations;
1201 SmallVector<Attribute, 8> newArgDomains;
1203 for (
auto &port : newArgs) {
1204 newArgDirections.push_back(port.direction);
1205 newArgNames.push_back(port.name);
1206 newArgTypes.push_back(TypeAttr::get(port.type));
1207 newArgSyms.push_back(port.sym);
1208 newArgLocations.push_back(port.loc);
1209 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1211 domainHelper.rewriteDomain(port.domains);
1213 port.domains = cache.aEmpty;
1215 newArgDomains.push_back(port.domains);
1218 newModuleAttrs.push_back(
1219 NamedAttribute(cache.sPortDirections,
1222 newModuleAttrs.push_back(
1223 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1225 newModuleAttrs.push_back(
1226 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1228 newModuleAttrs.push_back(NamedAttribute(
1229 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1231 newModuleAttrs.push_back(NamedAttribute(
1232 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1234 newModuleAttrs.push_back(
1235 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1238 extModule->setAttrs(newModuleAttrs);
1239 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1240 extModule.setPortSymbols(newArgSyms);
1245bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1246 auto *body =
module.getBodyBlock();
1248 ImplicitLocOpBuilder theBuilder(module.getLoc(),
context);
1249 builder = &theBuilder;
1255 llvm::BitVector argsToRemove;
1256 auto newArgs =
module.getPorts();
1257 argsToRemove.reserve(newArgs.size());
1259 DomainLoweringHelper domainHelper(
context, module.getPortTypes());
1261 size_t argsRemoved = 0;
1262 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1263 SmallVector<Value> lowerings;
1264 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1265 auto arg =
module.getArgument(argIndex);
1266 processUsers(arg, lowerings);
1267 argsToRemove.push_back(
true);
1270 argsToRemove.push_back(
false);
1275 if (argsRemoved != 0) {
1276 body->eraseArguments(argsToRemove);
1280 domainHelper.computeDomainMap(newArgs);
1282 SmallVector<NamedAttribute, 8> newModuleAttrs;
1285 for (
auto attr : module->getAttrDictionary())
1288 if (attr.
getName() !=
"portNames" && attr.
getName() !=
"portDirections" &&
1289 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1290 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1291 newModuleAttrs.push_back(attr);
1293 SmallVector<Direction> newArgDirections;
1294 SmallVector<Attribute> newArgNames;
1295 SmallVector<Attribute> newArgTypes;
1296 SmallVector<Attribute> newArgSyms;
1297 SmallVector<Attribute> newArgLocations;
1298 SmallVector<Attribute, 8> newArgAnnotations;
1299 SmallVector<Attribute> newPortDomains;
1300 for (
auto &port : newArgs) {
1301 newArgDirections.push_back(port.direction);
1302 newArgNames.push_back(port.name);
1303 newArgTypes.push_back(TypeAttr::get(port.type));
1304 newArgSyms.push_back(port.sym);
1305 newArgLocations.push_back(port.loc);
1306 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1308 domainHelper.rewriteDomain(port.domains);
1310 port.domains = cache.aEmpty;
1312 newPortDomains.push_back(port.domains);
1315 newModuleAttrs.push_back(
1316 NamedAttribute(cache.sPortDirections,
1319 newModuleAttrs.push_back(
1320 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1322 newModuleAttrs.push_back(
1323 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1325 newModuleAttrs.push_back(NamedAttribute(
1326 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1328 newModuleAttrs.push_back(NamedAttribute(
1329 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1331 newModuleAttrs.push_back(NamedAttribute(
1332 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1335 module->setAttrs(newModuleAttrs);
1336 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1337 module.setPortSymbols(newArgSyms);
1342bool TypeLoweringVisitor::visitDecl(WireOp op) {
1343 if (op.isForceable())
1346 auto clone = [&](
const FlatBundleFieldEntry &field,
1347 ArrayAttr attrs) -> Value {
1348 return WireOp::create(*builder,
1350 "", NameKindEnum::DroppableName, attrs, StringAttr{},
1351 false, op.getDomains())
1354 return lowerProducer(op, clone);
1358bool TypeLoweringVisitor::visitDecl(RegOp op) {
1359 if (op.isForceable())
1362 auto clone = [&](
const FlatBundleFieldEntry &field,
1363 ArrayAttr attrs) -> Value {
1364 return RegOp::create(*builder, field.type, op.getClockVal(),
"",
1365 NameKindEnum::DroppableName, attrs, StringAttr{})
1368 return lowerProducer(op, clone);
1372bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1373 if (op.isForceable())
1376 auto clone = [&](
const FlatBundleFieldEntry &field,
1377 ArrayAttr attrs) -> Value {
1378 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1379 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1380 op.getResetSignal(), resetVal,
"",
1381 NameKindEnum::DroppableName, attrs, StringAttr{})
1384 return lowerProducer(op, clone);
1388bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1389 if (op.isForceable())
1392 auto clone = [&](
const FlatBundleFieldEntry &field,
1393 ArrayAttr attrs) -> Value {
1394 auto input = getSubWhatever(op.getInput(), field.index);
1395 return NodeOp::create(*builder, input,
"", NameKindEnum::DroppableName,
1399 return lowerProducer(op, clone);
1403bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1404 auto clone = [&](
const FlatBundleFieldEntry &field,
1405 ArrayAttr attrs) -> Value {
1406 return InvalidValueOp::create(*builder, field.type);
1408 return lowerProducer(op, clone);
1412bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1413 auto clone = [&](
const FlatBundleFieldEntry &field,
1414 ArrayAttr attrs) -> Value {
1415 auto high = getSubWhatever(op.getHigh(), field.index);
1416 auto low = getSubWhatever(op.getLow(), field.index);
1417 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1419 return lowerProducer(op, clone);
1423bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1424 auto clone = [&](
const FlatBundleFieldEntry &field,
1425 ArrayAttr attrs) -> Value {
1426 auto high = getSubWhatever(op.getHigh(), field.index);
1427 auto low = getSubWhatever(op.getLow(), field.index);
1428 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1430 return lowerProducer(op, clone);
1434bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1435 auto clone = [&](
const FlatBundleFieldEntry &field,
1436 ArrayAttr attrs) -> Value {
1437 auto v3 = getSubWhatever(op.getV3(), field.index);
1438 auto v2 = getSubWhatever(op.getV2(), field.index);
1439 auto v1 = getSubWhatever(op.getV1(), field.index);
1440 auto v0 = getSubWhatever(op.getV0(), field.index);
1441 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1443 return lowerProducer(op, clone);
1447bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1448 mlir::UnrealizedConversionCastOp op) {
1449 auto clone = [&](
const FlatBundleFieldEntry &field,
1450 ArrayAttr attrs) -> Value {
1451 auto input = getSubWhatever(op.getOperand(0), field.index);
1452 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1457 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1459 return lowerProducer(op, clone);
1463bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1464 Value srcLoweredVal = op.getInput();
1468 SmallVector<FlatBundleFieldEntry> fields;
1470 size_t uptoBits = 0;
1473 for (
const auto &field :
llvm::enumerate(fields)) {
1474 auto fieldBitwidth = *
getBitWidth(field.value().type);
1476 if (fieldBitwidth == 0)
1478 Value src = getSubWhatever(op.getInput(), field.index());
1480 src = builder->createOrFold<BitCastOp>(
1481 UIntType::get(
context, fieldBitwidth), src);
1484 srcLoweredVal = src;
1486 if (type_isa<BundleType>(op.getInput().getType())) {
1488 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1491 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1495 uptoBits += fieldBitwidth;
1498 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
1502 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
1504 size_t uptoBits = 0;
1505 auto aggregateBits = *
getBitWidth(op.getResult().getType());
1506 auto clone = [&](
const FlatBundleFieldEntry &field,
1507 ArrayAttr attrs) -> Value {
1513 return InvalidValueOp::create(*builder, field.type);
1518 if (type_isa<BundleType>(op.getResult().getType())) {
1519 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1520 aggregateBits - uptoBits - 1,
1521 aggregateBits - uptoBits - fieldBits);
1523 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1524 uptoBits + fieldBits - 1, uptoBits);
1526 uptoBits += fieldBits;
1527 return BitCastOp::create(*builder, field.type,
extractBits);
1529 return lowerProducer(op, clone);
1533 if (type_isa<SIntType>(op.getType()))
1534 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1535 op.getResult().replaceAllUsesWith(srcLoweredVal);
1539bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1540 auto clone = [&](
const FlatBundleFieldEntry &field,
1541 ArrayAttr attrs) -> Value {
1542 return RefSendOp::create(*builder,
1543 getSubWhatever(op.getBase(), field.index));
1548 return lowerProducer(op, clone);
1551bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1552 auto clone = [&](
const FlatBundleFieldEntry &field,
1553 ArrayAttr attrs) -> Value {
1554 Value src = getSubWhatever(op.getRef(), field.index);
1555 return RefResolveOp::create(*builder, src);
1559 return lowerProducer(op, clone, op.getRef().getType());
1562bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1563 auto clone = [&](
const FlatBundleFieldEntry &field,
1564 ArrayAttr attrs) -> Value {
1565 auto input = getSubWhatever(op.getInput(), field.index);
1566 return RefCastOp::create(*builder,
1567 RefType::get(field.type,
1568 op.getType().getForceable(),
1569 op.getType().getLayer()),
1572 return lowerProducer(op, clone);
1577bool TypeLoweringVisitor::lowerInstanceLike(
1579 ArrayAttr oldPortAnno,
1580 llvm::function_ref<Operation *(ArrayRef<Type>, ArrayRef<Direction>,
1581 ArrayAttr, ArrayAttr, ArrayAttr,
1583 createNewInstance) {
1585 SmallVector<Type, 8> resultTypes;
1586 SmallVector<int64_t, 8> endFields;
1587 SmallVector<Direction> newDirs;
1588 SmallVector<Attribute> newNames, newDomains, newPortAnno;
1591 DomainLoweringHelper domainHelper(
context, op->getResultTypes());
1592 auto emptyAnno = builder->getArrayAttr({});
1594 endFields.push_back(0);
1595 for (
size_t i = 0, e = op->getNumResults(); i != e; ++i) {
1596 auto srcType = type_cast<FIRRTLType>(op->getResult(i).getType());
1599 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1600 if (!
peelType(srcType, fieldTypes, mode)) {
1601 newDirs.push_back(op.getPortDirection(i));
1602 newNames.push_back(op.getPortNameAttr(i));
1603 newDomains.push_back(op.getPortDomain(i));
1604 resultTypes.push_back(srcType);
1605 newPortAnno.push_back(oldPortAnno ? oldPortAnno[i] : emptyAnno);
1608 auto oldName = op.getPortName(i);
1609 auto oldDir = op.getPortDirection(i);
1611 for (
const auto &field : fieldTypes) {
1612 newDirs.push_back(
direction::get((
unsigned)oldDir ^ field.isOutput));
1613 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1614 newDomains.push_back(op.getPortDomain(i));
1619 dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]),
1622 newPortAnno.push_back(annos);
1625 endFields.push_back(resultTypes.size());
1635 domainHelper.computeDomainMap(resultTypes);
1638 for (
auto &domain : newDomains)
1639 domainHelper.rewriteDomain(domain);
1642 auto *newInstance = createNewInstance(
1643 resultTypes, newDirs, builder->getArrayAttr(newNames),
1644 builder->getArrayAttr(newDomains), builder->getArrayAttr(newPortAnno),
1645 sym ? hw::InnerSymAttr::get(sym) :
hw::InnerSymAttr());
1647 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1649 SmallVector<Value> lowered;
1650 for (
size_t aggIndex = 0, eAgg = op->getNumResults(); aggIndex != eAgg;
1653 for (
size_t fieldIndex = endFields[aggIndex],
1654 eField = endFields[aggIndex + 1];
1655 fieldIndex < eField; ++fieldIndex)
1656 lowered.push_back(newInstance->getResult(fieldIndex));
1657 if (lowered.size() != 1 ||
1658 op->getResult(aggIndex).getType() != resultTypes[endFields[aggIndex]])
1659 processUsers(op->getResult(aggIndex), lowered);
1661 op->getResult(aggIndex).replaceAllUsesWith(lowered[0]);
1666bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1669 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1672 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1673 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1674 ArrayAttr newDomains, ArrayAttr newPortAnno,
1675 hw::InnerSymAttr sym) -> Operation * {
1677 return InstanceOp::create(
1678 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1680 newNames, newDomains, op.getAnnotations(), newPortAnno,
1681 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1685 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1689bool TypeLoweringVisitor::visitDecl(InstanceChoiceOp op) {
1691 auto *moduleOp = symTbl.lookupNearestSymbolFrom(
1692 op, cast<FlatSymbolRefAttr>(op.getDefaultTargetAttr()));
1693 auto mode = getPreservationModeForPorts(cast<FModuleLike>(moduleOp));
1696 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1697 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1698 ArrayAttr newDomains, ArrayAttr newPortAnno,
1699 hw::InnerSymAttr sym) -> Operation * {
1700 return InstanceChoiceOp::create(
1701 *builder, resultTypes, op.getModuleNames(), op.getCaseNames(),
1702 op.getNameAttr(), op.getNameKindAttr(),
1704 op.getAnnotations(), newPortAnno, op.getLayersAttr(), sym,
1705 op.getInstanceMacroAttr());
1708 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1712bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1713 auto input = op.getInput();
1714 FVectorType vType = input.getType();
1717 if (vType.getNumElements() == 0) {
1718 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1719 op.replaceAllUsesWith(inv);
1724 if (ConstantOp arg =
1725 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1726 auto sio = SubindexOp::create(*builder, op.getInput(),
1727 arg.getValue().getExtValue());
1728 op.replaceAllUsesWith(sio.getResult());
1733 SmallVector<Value> inputs;
1734 inputs.reserve(vType.getNumElements());
1735 for (
int index = vType.getNumElements() - 1; index >= 0; index--)
1736 inputs.push_back(SubindexOp::create(*builder, input, index));
1738 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1739 op.replaceAllUsesWith(multibitMux);
1743bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1744 auto clone = [&](
const FlatBundleFieldEntry &field,
1745 ArrayAttr attrs) -> Value {
1746 return op.getOperand(field.index);
1748 return lowerProducer(op, clone);
1751bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1752 auto clone = [&](
const FlatBundleFieldEntry &field,
1753 ArrayAttr attrs) -> Value {
1754 return op.getOperand(field.index);
1756 return lowerProducer(op, clone);
1759bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1760 auto clone = [&](
const FlatBundleFieldEntry &field,
1761 ArrayAttr attrs) -> Value {
1762 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1763 getSubWhatever(op.getRhs(), field.index)};
1764 return type_isa<BundleType, FVectorType>(field.type)
1765 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1767 : (Value)OrPrimOp::create(*builder, operands);
1770 return lowerProducer(op, clone);
1773bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1774 auto clone = [&](
const FlatBundleFieldEntry &field,
1775 ArrayAttr attrs) -> Value {
1776 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1777 getSubWhatever(op.getRhs(), field.index)};
1778 return type_isa<BundleType, FVectorType>(field.type)
1779 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1781 : (Value)AndPrimOp::create(*builder, operands);
1784 return lowerProducer(op, clone);
1787bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1788 auto clone = [&](
const FlatBundleFieldEntry &field,
1789 ArrayAttr attrs) -> Value {
1790 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1791 getSubWhatever(op.getRhs(), field.index)};
1792 return type_isa<BundleType, FVectorType>(field.type)
1793 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1795 : (Value)XorPrimOp::create(*builder, operands);
1798 return lowerProducer(op, clone);
1801bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1802 auto clone = [&](
const FlatBundleFieldEntry &field,
1803 ArrayAttr attrs) -> Value {
1804 SmallVector<Value> newInputs;
1805 newInputs.reserve(op.getInputs().size());
1806 for (
auto input : op.getInputs()) {
1807 auto inputSub = getSubWhatever(input, field.index);
1808 newInputs.push_back(inputSub);
1810 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1812 return lowerProducer(op, clone);
1820struct LowerTypesPass
1821 :
public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1824 void runOnOperation()
override;
1829void LowerTypesPass::runOnOperation() {
1832 std::vector<FModuleLike> ops;
1833 auto &instanceGraph = getAnalysis<InstanceGraph>();
1835 auto &symTbl = getAnalysis<SymbolTable>();
1837 AttrCache cache(&getContext());
1839 DenseMap<FModuleLike, Convention> conventionTable;
1840 auto circuit = getOperation();
1841 for (
auto module : circuit.getOps<FModuleLike>()) {
1842 auto convention =
module.getConvention();
1845 if (llvm::any_of(instanceGraph.lookup(module)->uses(),
1847 return use->getInstance<InstanceChoiceOp>();
1849 convention = Convention::Scalarized;
1850 conventionTable.insert({module, convention});
1851 ops.push_back(module);
1855 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1857 Convention convention = Convention::Internal;
1858 if (
auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1859 op->getDiscardableAttr(
"body_type_lowering")))
1860 convention = conventionAttr.getValue();
1863 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1864 preserveMemories, symTbl, cache, conventionTable);
1867 return LogicalResult::failure(tl.isFailed());
1870 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
1873 signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
This is an edge in the InstanceGraph.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
@ All
Preserve all aggregate values.
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
@ None
Don't preserve aggregate at all.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Direction
This represents the direction of a single port.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.
AnnotationSet annotations