39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/Support/Debug.h"
46#define DEBUG_TYPE "firrtl-lower-types"
50#define GEN_PASS_DEF_LOWERFIRRTLTYPES
51#include "circt/Dialect/FIRRTL/Passes.h.inc"
56using namespace firrtl;
61struct FlatBundleFieldEntry {
69 SmallString<16> suffix;
74 unsigned fieldID, StringRef suffix,
bool isOutput)
75 : type(type), index(index), fieldID(fieldID), suffix(suffix),
79 llvm::errs() <<
"FBFE{" << type <<
" index<" << index <<
"> fieldID<"
80 << fieldID <<
"> suffix<" << suffix <<
"> isOutput<"
81 << isOutput <<
">}\n";
88 return mapBaseType(type, [&](
auto) {
return fieldType; });
93 auto ftype = type_dyn_cast<FIRRTLType>(type);
102 .
Case<BundleType>([&](
auto bundle) {
return false; })
103 .Case<FVectorType>([&](FVectorType vector) {
105 return vector.getElementType().isGround() &&
106 vector.getNumElements() > 1;
108 .Default([](
auto groundType) {
return true; });
115 .
Case<BundleType>([&](
auto bundle) {
return true; })
116 .Case<FVectorType>([&](FVectorType vector) {
119 .Default([](
auto groundType) {
return false; });
126 if (
auto refType = type_dyn_cast<RefType>(type)) {
128 if (refType.getForceable())
139 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
145 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
157 llvm_unreachable(
"unexpected mode");
163static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
170 if (
auto refType = type_dyn_cast<RefType>(type))
171 type = refType.getType();
173 .
Case<BundleType>([&](
auto bundle) {
174 SmallString<16> tmpSuffix;
176 for (
size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
177 auto elt = bundle.getElement(i);
180 tmpSuffix.push_back(
'_');
181 tmpSuffix.append(elt.name.getValue());
182 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
187 .Case<FVectorType>([&](
auto vector) {
189 for (
size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
190 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
191 "_" + std::to_string(i),
false);
195 .Default([](
auto op) {
return false; });
201 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
205 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
206 return arg && sao.getInput().getType().base().getNumElements() != 0;
211 SmallVector<Operation *> retval;
212 auto defOp = op->getOperand(0).getDefiningOp();
213 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
214 retval.push_back(defOp);
215 defOp = defOp->getOperand(0).getDefiningOp();
225 FlatBundleFieldEntry field) {
226 SmallVector<Type, 8> ports;
227 SmallVector<Attribute, 8> portNames;
228 SmallVector<Attribute, 8> portLocations;
230 auto oldPorts = op.getPorts();
231 for (
size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
232 auto port = oldPorts[portIdx];
234 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
235 portNames.push_back(port.first);
240 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
241 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
242 (op.getName() + field.suffix).str(), op.getNameKind(),
243 op.getAnnotations(), op.getPortAnnotations(),
244 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
246 if (op.getInnerSym()) {
247 op.emitError(
"cannot split memory with symbol present");
251 SmallVector<Attribute> newAnnotations;
252 for (
size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
253 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
254 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
255 SmallVector<Attribute> portAnno;
256 for (
auto attr : newMem.getPortAnnotation(portIdx)) {
259 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
263 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
266 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
267 portAnno.push_back(anno.
getDict());
272 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
277 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
278 if (annoFieldID >= fieldID &&
283 annoFieldID - fieldID + portType.getFieldID(targetIndex);
284 anno.
setMember(
"circt.fieldID", b->getI32IntegerAttr(newFieldID));
285 portAnno.push_back(anno.
getDict());
289 portAnno.push_back(attr);
291 newAnnotations.push_back(b->getArrayAttr(portAnno));
293 newMem.setAllPortAnnotations(newAnnotations);
303 AttrCache(MLIRContext *
context) {
304 i64ty = IntegerType::get(
context, 64);
305 nameAttr = StringAttr::get(
context,
"name");
306 nameKindAttr = StringAttr::get(
context,
"nameKind");
307 sPortDirections = StringAttr::get(
context,
"portDirections");
308 sPortNames = StringAttr::get(
context,
"portNames");
309 sPortTypes = StringAttr::get(
context,
"portTypes");
310 sPortSymbols = StringAttr::get(
context,
"portSymbols");
311 sPortLocations = StringAttr::get(
context,
"portLocations");
312 sPortAnnotations = StringAttr::get(
context,
"portAnnotations");
313 sPortDomains = StringAttr::get(
context,
"domainInfo");
314 sEmpty = StringAttr::get(
context,
"");
315 aEmpty = ArrayAttr::get(
context, {});
317 AttrCache(
const AttrCache &) =
default;
320 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
321 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
328class DomainLoweringHelper {
333 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
335 for (
auto [index, typeAttr] :
llvm::enumerate(portTypes))
336 if (
type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
337 domainIndexByOrdinal.push_back(index);
341 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
343 for (
auto [index, type] :
llvm::enumerate(resultTypes))
345 domainIndexByOrdinal.push_back(index);
352 void computeDomainMap(TypeRange types) {
353 size_t i = 0, ord = 0;
354 for (
auto type : types) {
355 if (type_isa<DomainType>(type))
356 domainMap[domainIndexByOrdinal[ord++]] = i;
365 void computeDomainMap(ArrayRef<PortInfo> ports) {
366 size_t i = 0, ord = 0;
367 for (
const auto &port : ports) {
368 if (type_isa<DomainType>(port.type))
369 domainMap[domainIndexByOrdinal[ord++]] = i;
377 void rewriteDomain(Attribute &domain) {
378 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
379 if (!oldAssociations)
381 SmallVector<Attribute> newAssociations;
382 for (
auto oldAttr : oldAssociations)
383 newAssociations.push_back(IntegerAttr::
get(
384 IntegerType::
get(context, 32, IntegerType::Unsigned),
385 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
386 domain = ArrayAttr::get(context, newAssociations);
390 MLIRContext *context;
392 SmallVector<unsigned> domainIndexByOrdinal;
394 DenseMap<unsigned, unsigned> domainMap;
399struct TypeLoweringVisitor :
public FIRRTLVisitor<TypeLoweringVisitor, bool> {
403 Convention bodyConvention,
405 SymbolTable &symTbl,
const AttrCache &cache,
406 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
407 : context(context), defaultAggregatePreservationMode(preserveAggregate),
408 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
409 cache(cache), conventionTable(conventionTable) {
410 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
412 : defaultAggregatePreservationMode;
420 void lowerModule(FModuleLike op);
422 bool lowerArg(FModuleLike module,
size_t argIndex,
size_t argsRemoved,
423 SmallVectorImpl<PortInfo> &newArgs,
424 SmallVectorImpl<Value> &lowering);
425 std::pair<Value, PortInfo> addArg(Operation *module,
unsigned insertPt,
427 const FlatBundleFieldEntry &field,
428 PortInfo &oldArg, hw::InnerSymAttr newSym);
431 bool visitDecl(FExtModuleOp op);
432 bool visitDecl(FModuleOp op);
433 bool visitDecl(InstanceOp op);
434 bool visitDecl(MemOp op);
435 bool visitDecl(NodeOp op);
436 bool visitDecl(RegOp op);
437 bool visitDecl(WireOp op);
438 bool visitDecl(RegResetOp op);
439 bool visitExpr(InvalidValueOp op);
440 bool visitExpr(SubaccessOp op);
441 bool visitExpr(VectorCreateOp op);
442 bool visitExpr(BundleCreateOp op);
443 bool visitExpr(ElementwiseAndPrimOp op);
444 bool visitExpr(ElementwiseOrPrimOp op);
445 bool visitExpr(ElementwiseXorPrimOp op);
446 bool visitExpr(MultibitMuxOp op);
447 bool visitExpr(MuxPrimOp op);
448 bool visitExpr(Mux2CellIntrinsicOp op);
449 bool visitExpr(Mux4CellIntrinsicOp op);
450 bool visitExpr(BitCastOp op);
451 bool visitExpr(RefSendOp op);
452 bool visitExpr(RefResolveOp op);
453 bool visitExpr(RefCastOp op);
454 bool visitStmt(ConnectOp op);
455 bool visitStmt(MatchingConnectOp op);
456 bool visitStmt(RefDefineOp op);
457 bool visitStmt(WhenOp op);
458 bool visitStmt(LayerBlockOp op);
459 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
461 bool isFailed()
const {
return encounteredError; }
464 if (
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
465 return visitUnrealizedConversionCast(castOp);
470 void processUsers(Value val, ArrayRef<Value> mapping);
471 bool processSAPath(Operation *);
472 void lowerBlock(Block *);
473 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
483 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
488 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
489 FIRRTLType srcType, FlatBundleFieldEntry field);
493 LogicalResult partitionSymbols(hw::InnerSymAttr sym,
FIRRTLType parentType,
494 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
498 getPreservationModeForPorts(FModuleLike moduleLike);
499 Value getSubWhatever(Value val,
size_t index);
501 size_t uniqueIdx = 0;
502 std::string uniqueName() {
503 auto myID = uniqueIdx++;
504 return (Twine(
"__GEN_") + Twine(myID)).str();
515 ImplicitLocOpBuilder *builder;
521 const AttrCache &cache;
523 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
526 bool encounteredError =
false;
// Chooses the aggregate preservation mode applied to the ports of `module`,
// based on the convention recorded for it in `conventionTable`.
// NOTE(review): this listing omits several original lines (the return type
// preceding the qualified name, the statement under the Scalarized case, and
// the closing braces) — confirm against the complete file before editing.
533TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
534 auto lookup = conventionTable.find(module);
// A module with no table entry falls back to the pass-wide default mode.
535 if (lookup == conventionTable.end())
536 return defaultAggregatePreservationMode;
537 switch (lookup->second) {
// NOTE(review): the statement for the Scalarized case is not visible here.
538 case Convention::Scalarized:
// The Internal convention uses the pass-wide default mode.
540 case Convention::Internal:
541 return defaultAggregatePreservationMode;
// Every known convention is expected to be handled by the switch above.
543 llvm_unreachable(
"Unknown convention");
544 return defaultAggregatePreservationMode;
547Value TypeLoweringVisitor::getSubWhatever(Value val,
size_t index) {
548 if (type_isa<BundleType>(val.getType()))
549 return SubfieldOp::create(*builder, val, index);
550 if (type_isa<FVectorType>(val.getType()))
551 return SubindexOp::create(*builder, val, index);
552 if (type_isa<RefType>(val.getType()))
553 return RefSubOp::create(*builder, val, index);
554 llvm_unreachable(
"Unknown aggregate type");
559bool TypeLoweringVisitor::processSAPath(Operation *op) {
562 if (writePath.empty())
565 lowerSAWritePath(op, writePath);
568 op->eraseOperands(0, 2);
570 for (
size_t i = 0; i < writePath.size(); ++i) {
571 if (writePath[i]->use_empty()) {
572 writePath[i]->erase();
580void TypeLoweringVisitor::lowerBlock(Block *block) {
582 for (
auto it = block->rbegin(), e = block->rend(); it != e;) {
584 builder->setInsertionPoint(&iop);
585 builder->setLoc(iop.getLoc());
586 bool removeOp = dispatchVisitor(&iop);
595ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
596 ArrayAttr annotations,
598 FlatBundleFieldEntry field) {
599 SmallVector<Attribute> retval;
600 if (!annotations || annotations.empty())
601 return ArrayAttr::get(ctxt, retval);
602 for (
auto opAttr : annotations) {
604 auto fieldID = anno.getFieldID();
605 anno.removeMember(
"circt.fieldID");
610 retval.push_back(anno.getAttr());
615 if (fieldID < field.fieldID ||
620 if (
auto newFieldID = fieldID - field.fieldID) {
623 anno.setMember(
"circt.fieldID", builder->getI32IntegerAttr(newFieldID));
626 retval.push_back(anno.getAttr());
628 return ArrayAttr::get(ctxt, retval);
631LogicalResult TypeLoweringVisitor::partitionSymbols(
633 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
636 if (!sym || sym.empty())
639 auto *
context = sym.getContext();
643 return mlir::emitError(errorLoc,
644 "unable to partition symbol on unsupported type ")
647 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
648 .Case<BundleType, FVectorType>([&](
auto aggType) -> LogicalResult {
652 hw::InnerSymPropertiesAttr prop;
656 SmallVector<BinningInfo> binning;
657 for (
auto prop : sym) {
658 auto fieldID = prop.getFieldID();
661 return mlir::emitError(errorLoc,
"unable to lower due to symbol ")
663 <<
" with target not preserved by lowering";
664 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
665 binning.push_back({index, relFieldID, prop});
669 llvm::stable_sort(binning, [&](
auto &lhs,
auto &rhs) {
670 return std::tuple(lhs.index, lhs.relFieldID) <
671 std::tuple(rhs.index, rhs.relFieldID);
676 newSyms.resize(aggType.getNumElements());
677 for (
auto binIt = binning.begin(), binEnd = binning.end();
679 auto curIndex = binIt->index;
680 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
682 while (binIt != binEnd && binIt->index == curIndex) {
683 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
684 context, binIt->prop.getName(), binIt->relFieldID,
685 binIt->prop.getSymVisibility()));
689 assert(!newSyms[curIndex]);
690 newSyms[curIndex] = hw::InnerSymAttr::get(
context, propsForIndex);
694 .Default([&](
auto ty) {
695 return mlir::emitError(
696 errorLoc,
"unable to partition symbol on unsupported type ")
701bool TypeLoweringVisitor::lowerProducer(
703 llvm::function_ref<Value(
const FlatBundleFieldEntry &, ArrayAttr)> clone,
707 srcType = op->getResult(0).getType();
708 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
711 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
713 if (!
peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
716 SmallVector<Value> lowered;
718 SmallString<16> loweredName;
719 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
721 if (
auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
722 loweredName = nameAttr.getValue();
723 auto baseNameLen = loweredName.size();
724 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr(
"annotations"));
726 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
727 if (
auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
728 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
730 encounteredError =
true;
735 for (
const auto &[field, sym] :
llvm::zip_equal(fieldTypes, fieldSyms)) {
736 if (!loweredName.empty()) {
737 loweredName.resize(baseNameLen);
738 loweredName += field.suffix;
743 ArrayAttr loweredAttrs =
744 filterAnnotations(
context, oldAnno, srcFType, field);
745 auto newVal = clone(field, loweredAttrs);
751 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
754 "op with inner symbol lowered to op that cannot take inner symbol");
755 newSymOp.setInnerSymbolAttr(sym);
759 if (
auto *newOp = newVal.getDefiningOp()) {
760 if (!loweredName.empty())
761 newOp->setAttr(cache.nameAttr, StringAttr::get(
context, loweredName));
763 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
766 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
768 lowered.push_back(newVal);
771 processUsers(op->getResult(0), lowered);
775void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
776 for (
auto *user :
llvm::make_early_inc_range(val.getUsers())) {
777 TypeSwitch<Operation *, void>(user)
778 .Case<SubindexOp>([mapping](SubindexOp sio) {
779 Value repl = mapping[sio.getIndex()];
780 sio.replaceAllUsesWith(repl);
783 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
785 Value repl = mapping[sfo.getFieldIndex()];
786 sfo.replaceAllUsesWith(repl);
789 .Case<RefSubOp>([mapping](RefSubOp refSub) {
790 Value repl = mapping[refSub.getIndex()];
791 refSub.replaceAllUsesWith(repl);
794 .Default([&](
auto op) {
805 ImplicitLocOpBuilder b(user->getLoc(), user);
809 assert(llvm::none_of(mapping, [](
auto v) {
810 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
811 return !fbasetype || fbasetype.containsReference();
815 TypeSwitch<Type, Value>(val.getType())
816 .template Case<FVectorType>([&](
auto vecType) {
817 return b.createOrFold<VectorCreateOp>(vecType, mapping);
819 .
template Case<BundleType>([&](
auto bundleType) {
820 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
822 .Default([&](
auto _) -> Value {
return {}; });
824 user->emitError(
"unable to reconstruct source of type ")
826 encounteredError =
true;
829 user->replaceUsesOfWith(val, input);
837 const llvm::BitVector &removalMask) {
838 size_t writeIndex = 0, readIndex = 0;
842 for (
size_t removalIndex : removalMask.set_bits()) {
844 assert(removalIndex >= readIndex &&
"removal index before read index");
845 size_t rangeSize = removalIndex - readIndex;
850 if (writeIndex != readIndex)
851 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
852 vec.begin() + writeIndex);
853 writeIndex += rangeSize;
855 readIndex = removalIndex + 1;
859 size_t remainingSize = vec.size() - readIndex;
860 if (remainingSize > 0) {
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
863 writeIndex += remainingSize;
867 vec.truncate(writeIndex);
870void TypeLoweringVisitor::lowerModule(FModuleLike op) {
871 if (
auto module = llvm::dyn_cast<FModuleOp>(*op))
873 else if (
auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
874 visitDecl(extModule);
880std::pair<Value, PortInfo>
881TypeLoweringVisitor::addArg(Operation *module,
unsigned insertPt,
883 const FlatBundleFieldEntry &field,
PortInfo &oldArg,
884 hw::InnerSymAttr newSym) {
887 if (
auto mod = llvm::dyn_cast<FModuleOp>(module)) {
888 Block *body = mod.getBodyBlock();
890 newValue = body->insertArgument(insertPt, fieldType, oldArg.
loc);
894 auto name = builder->getStringAttr(oldArg.
name.getValue() + field.suffix);
897 auto newAnnotations = filterAnnotations(
902 return std::make_pair(
903 newValue,
PortInfo{name, fieldType, direction, newSym, oldArg.
loc,
908bool TypeLoweringVisitor::lowerArg(FModuleLike module,
size_t argIndex,
910 SmallVectorImpl<PortInfo> &newArgs,
911 SmallVectorImpl<Value> &lowering) {
914 SmallVector<FlatBundleFieldEntry> fieldTypes;
915 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
916 if (!
peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
919 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
920 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
921 newArgs[argIndex].loc))) {
922 encounteredError =
true;
926 for (
const auto &[idx, field, fieldSym] :
927 llvm::enumerate(fieldTypes, fieldSyms)) {
928 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
929 field, newArgs[argIndex], fieldSym);
930 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
932 lowering.push_back(newValue.first);
// Re-creates the accessor operation `op` on top of the value `rhs`, using
// `builder`. SubfieldOp, SubindexOp and SubaccessOp are handled; the new op
// reuses the field index / index operand of the original accessor.
// NOTE(review): the end of the parameter list (where `rhs` is declared) and
// the function's final lines are not present in this listing — confirm the
// fall-through return value against the complete file.
937static Value
cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
// Bundle field access: same field index, new input.
939 if (
auto rop = llvm::dyn_cast<SubfieldOp>(op))
940 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
// Constant vector index: same index, new input.
941 if (
auto rop = llvm::dyn_cast<SubindexOp>(op))
942 return SubindexOp::create(*builder, rhs, rop.getIndex());
// Dynamic vector index: same index value, new input.
943 if (
auto rop = llvm::dyn_cast<SubaccessOp>(op))
944 return SubaccessOp::create(*builder, rhs, rop.getIndex());
// Any other operation kind is reported as an error on `op`.
945 op->emitError(
"Unknown accessor");
949void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
950 ArrayRef<Operation *> writePath) {
951 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
952 FVectorType saoType = sao.getInput().getType();
953 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
955 for (
size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
956 auto cond = EQPrimOp::create(
957 *builder, sao.getIndex(),
958 builder->createOrFold<ConstantOp>(UIntType::get(
context, selectWidth),
959 APInt(selectWidth, index)));
960 WhenOp::create(*builder, cond,
false, [&]() {
962 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
963 for (
int i = writePath.size() - 2; i >= 0; --i) {
964 if (
auto access =
cloneAccess(builder, writePath[i], leaf))
967 encounteredError =
true;
978bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
979 if (processSAPath(op))
983 SmallVector<FlatBundleFieldEntry> fields;
990 for (
const auto &field :
llvm::enumerate(fields)) {
991 Value src = getSubWhatever(op.getSrc(), field.index());
992 Value dest = getSubWhatever(op.getDest(), field.index());
993 if (field.value().isOutput)
994 std::swap(src, dest);
// Lowers a MatchingConnectOp: for each entry in `fields`, an element accessor
// is created on both the source and the destination and a new
// MatchingConnectOp is emitted between them.
// NOTE(review): the lines that handle the processSAPath result, populate
// `fields`, and end the function are missing from this listing — confirm
// against the complete file.
1001bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
// Subaccess write paths are handled by processSAPath first.
1002 if (processSAPath(op))
1006 SmallVector<FlatBundleFieldEntry> fields;
1013 for (
const auto &field :
llvm::enumerate(fields)) {
1014 Value src = getSubWhatever(op.getSrc(), field.index());
1015 Value dest = getSubWhatever(op.getDest(), field.index());
// For a field marked isOutput the connection direction is reversed.
1016 if (field.value().isOutput)
1017 std::swap(src, dest);
1018 MatchingConnectOp::create(*builder, dest, src);
// Lowers a RefDefineOp whose destination type can be peeled into fields: one
// RefDefineOp is created per field between matching element accessors of the
// source and destination.
// NOTE(review): the statement executed when peelType fails and the function's
// closing lines are missing from this listing — confirm against the complete
// file.
1024bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1026 SmallVector<FlatBundleFieldEntry> fields;
// Peel the destination type using the preservation mode of the module body.
1028 if (!
peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1032 for (
const auto &field :
llvm::enumerate(fields)) {
1033 Value src = getSubWhatever(op.getSrc(), field.index());
1034 Value dest = getSubWhatever(op.getDest(), field.index());
// Fields of a reference destination are asserted not to be marked isOutput.
1035 assert(!field.value().isOutput &&
"unexpected flip in reftype destination");
1036 RefDefineOp::create(*builder, dest, src);
1041bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1047 lowerBlock(&op.getThenBlock());
1050 if (op.hasElseRegion())
1051 lowerBlock(&op.getElseBlock());
1056bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1057 lowerBlock(op.getBody());
1063bool TypeLoweringVisitor::visitDecl(MemOp op) {
1065 SmallVector<FlatBundleFieldEntry> fields;
1068 if (!
peelType(op.getDataType(), fields, memoryPreservationMode))
1071 if (op.getInnerSym()) {
1072 op->emitError() <<
"has a symbol, but no symbols may exist on aggregates "
1073 "passed through LowerTypes";
1074 encounteredError =
true;
1078 SmallVector<MemOp> newMemories;
1079 SmallVector<WireOp> oldPorts;
1082 for (
unsigned int index = 0, end = op.getNumResults(); index <
end; ++index) {
1083 auto result = op.getResult(index);
1084 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1085 op.emitOpError(
"cannot lower memory with debug port");
1086 encounteredError =
true;
1090 WireOp::create(*builder, result.getType(),
1091 (op.getName() +
"_" + op.getPortName(index)).str());
1092 oldPorts.push_back(wire);
1093 result.replaceAllUsesWith(wire.getResult());
1100 for (
const auto &field : fields) {
1102 if (!newMemForField) {
1103 op.emitError(
"failed cloning memory for field");
1104 encounteredError =
true;
1107 newMemories.push_back(newMemForField);
1110 for (
size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1111 auto result = oldPorts[index].getResult();
1112 auto rType = type_cast<BundleType>(result.getType());
1113 for (
size_t fieldIndex = 0, fend = rType.getNumElements();
1114 fieldIndex != fend; ++fieldIndex) {
1115 auto name = rType.getElement(fieldIndex).name.getValue();
1116 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1119 if (name ==
"data" || name ==
"mask" || name ==
"wdata" ||
1120 name ==
"wmask" || name ==
"rdata") {
1121 for (
const auto &field : fields) {
1122 auto realOldField = getSubWhatever(oldField, field.index);
1123 auto newField = getSubWhatever(
1124 newMemories[field.index].getResult(index), fieldIndex);
1125 if (rType.getElement(fieldIndex).isFlip)
1126 std::swap(realOldField, newField);
1130 for (
auto mem : newMemories) {
1132 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1141bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1142 ImplicitLocOpBuilder theBuilder(extModule.getLoc(),
context);
1143 builder = &theBuilder;
1149 llvm::BitVector argsToRemove;
1150 auto newArgs = extModule.getPorts();
1151 argsToRemove.reserve(newArgs.size());
1153 DomainLoweringHelper domainHelper(
context, extModule.getPortTypes());
1155 size_t argsRemoved = 0;
1156 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1157 SmallVector<Value> lowering;
1158 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1159 argsToRemove.push_back(
true);
1162 argsToRemove.push_back(
false);
1168 if (argsRemoved != 0)
1171 domainHelper.computeDomainMap(newArgs);
1173 SmallVector<NamedAttribute, 8> newModuleAttrs;
1176 for (
auto attr : extModule->getAttrDictionary())
1179 if (attr.
getName() !=
"portDirections" && attr.
getName() !=
"portNames" &&
1180 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1181 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1182 newModuleAttrs.push_back(attr);
1184 SmallVector<Direction> newArgDirections;
1185 SmallVector<Attribute> newArgNames;
1186 SmallVector<Attribute, 8> newArgTypes;
1187 SmallVector<Attribute, 8> newArgSyms;
1188 SmallVector<Attribute, 8> newArgLocations;
1189 SmallVector<Attribute, 8> newArgAnnotations;
1190 SmallVector<Attribute, 8> newArgDomains;
1192 for (
auto &port : newArgs) {
1193 newArgDirections.push_back(port.direction);
1194 newArgNames.push_back(port.name);
1195 newArgTypes.push_back(TypeAttr::get(port.type));
1196 newArgSyms.push_back(port.sym);
1197 newArgLocations.push_back(port.loc);
1198 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1200 domainHelper.rewriteDomain(port.domains);
1202 port.domains = cache.aEmpty;
1204 newArgDomains.push_back(port.domains);
1207 newModuleAttrs.push_back(
1208 NamedAttribute(cache.sPortDirections,
1211 newModuleAttrs.push_back(
1212 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1214 newModuleAttrs.push_back(
1215 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1217 newModuleAttrs.push_back(NamedAttribute(
1218 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1220 newModuleAttrs.push_back(NamedAttribute(
1221 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1223 newModuleAttrs.push_back(
1224 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1227 extModule->setAttrs(newModuleAttrs);
1228 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1229 extModule.setPortSymbols(newArgSyms);
1234bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1235 auto *body =
module.getBodyBlock();
1237 ImplicitLocOpBuilder theBuilder(module.getLoc(),
context);
1238 builder = &theBuilder;
1244 llvm::BitVector argsToRemove;
1245 auto newArgs =
module.getPorts();
1246 argsToRemove.reserve(newArgs.size());
1248 DomainLoweringHelper domainHelper(
context, module.getPortTypes());
1250 size_t argsRemoved = 0;
1251 for (
size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1252 SmallVector<Value> lowerings;
1253 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1254 auto arg =
module.getArgument(argIndex);
1255 processUsers(arg, lowerings);
1256 argsToRemove.push_back(
true);
1259 argsToRemove.push_back(
false);
1264 if (argsRemoved != 0) {
1265 body->eraseArguments(argsToRemove);
1269 domainHelper.computeDomainMap(newArgs);
1271 SmallVector<NamedAttribute, 8> newModuleAttrs;
1274 for (
auto attr : module->getAttrDictionary())
1277 if (attr.
getName() !=
"portNames" && attr.
getName() !=
"portDirections" &&
1278 attr.
getName() !=
"portTypes" && attr.
getName() !=
"portAnnotations" &&
1279 attr.
getName() !=
"portSymbols" && attr.
getName() !=
"portLocations")
1280 newModuleAttrs.push_back(attr);
1282 SmallVector<Direction> newArgDirections;
1283 SmallVector<Attribute> newArgNames;
1284 SmallVector<Attribute> newArgTypes;
1285 SmallVector<Attribute> newArgSyms;
1286 SmallVector<Attribute> newArgLocations;
1287 SmallVector<Attribute, 8> newArgAnnotations;
1288 SmallVector<Attribute> newPortDomains;
1289 for (
auto &port : newArgs) {
1290 newArgDirections.push_back(port.direction);
1291 newArgNames.push_back(port.name);
1292 newArgTypes.push_back(TypeAttr::get(port.type));
1293 newArgSyms.push_back(port.sym);
1294 newArgLocations.push_back(port.loc);
1295 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1297 domainHelper.rewriteDomain(port.domains);
1299 port.domains = cache.aEmpty;
1301 newPortDomains.push_back(port.domains);
1304 newModuleAttrs.push_back(
1305 NamedAttribute(cache.sPortDirections,
1308 newModuleAttrs.push_back(
1309 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1311 newModuleAttrs.push_back(
1312 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1314 newModuleAttrs.push_back(NamedAttribute(
1315 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1317 newModuleAttrs.push_back(NamedAttribute(
1318 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1320 newModuleAttrs.push_back(NamedAttribute(
1321 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1324 module->setAttrs(newModuleAttrs);
1325 FModuleLike::fixupPortSymsArray(newArgSyms,
context);
1326 module.setPortSymbols(newArgSyms);
1331bool TypeLoweringVisitor::visitDecl(WireOp op) {
1332 if (op.isForceable())
1335 auto clone = [&](
const FlatBundleFieldEntry &field,
1336 ArrayAttr attrs) -> Value {
1337 return WireOp::create(*builder,
1339 "", NameKindEnum::DroppableName, attrs, StringAttr{})
1342 return lowerProducer(op, clone);
1346bool TypeLoweringVisitor::visitDecl(RegOp op) {
1347 if (op.isForceable())
1350 auto clone = [&](
const FlatBundleFieldEntry &field,
1351 ArrayAttr attrs) -> Value {
1352 return RegOp::create(*builder, field.type, op.getClockVal(),
"",
1353 NameKindEnum::DroppableName, attrs, StringAttr{})
1356 return lowerProducer(op, clone);
1360bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1361 if (op.isForceable())
1364 auto clone = [&](
const FlatBundleFieldEntry &field,
1365 ArrayAttr attrs) -> Value {
1366 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1367 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1368 op.getResetSignal(), resetVal,
"",
1369 NameKindEnum::DroppableName, attrs, StringAttr{})
1372 return lowerProducer(op, clone);
1376bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1377 if (op.isForceable())
1380 auto clone = [&](
const FlatBundleFieldEntry &field,
1381 ArrayAttr attrs) -> Value {
1382 auto input = getSubWhatever(op.getInput(), field.index);
1383 return NodeOp::create(*builder, input,
"", NameKindEnum::DroppableName,
1387 return lowerProducer(op, clone);
1391bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1392 auto clone = [&](
const FlatBundleFieldEntry &field,
1393 ArrayAttr attrs) -> Value {
1394 return InvalidValueOp::create(*builder, field.type);
1396 return lowerProducer(op, clone);
1400bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1401 auto clone = [&](
const FlatBundleFieldEntry &field,
1402 ArrayAttr attrs) -> Value {
1403 auto high = getSubWhatever(op.getHigh(), field.index);
1404 auto low = getSubWhatever(op.getLow(), field.index);
1405 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1407 return lowerProducer(op, clone);
1411bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1412 auto clone = [&](
const FlatBundleFieldEntry &field,
1413 ArrayAttr attrs) -> Value {
1414 auto high = getSubWhatever(op.getHigh(), field.index);
1415 auto low = getSubWhatever(op.getLow(), field.index);
1416 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1418 return lowerProducer(op, clone);
1422bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1423 auto clone = [&](
const FlatBundleFieldEntry &field,
1424 ArrayAttr attrs) -> Value {
1425 auto v3 = getSubWhatever(op.getV3(), field.index);
1426 auto v2 = getSubWhatever(op.getV2(), field.index);
1427 auto v1 = getSubWhatever(op.getV1(), field.index);
1428 auto v0 = getSubWhatever(op.getV0(), field.index);
1429 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1431 return lowerProducer(op, clone);
// Lowers an mlir::UnrealizedConversionCastOp through lowerProducer. The
// callback casts the matching part of operand 0 to the field's type.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the tail of the create expression and the
// statement guarded by the FIRRTLType check.
1435bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1436 mlir::UnrealizedConversionCastOp op) {
1437 auto clone = [&](
const FlatBundleFieldEntry &field,
1438 ArrayAttr attrs) -> Value {
// Only operand 0 of the cast is consulted.
1439 auto input = getSubWhatever(op.getOperand(0), field.index);
1440 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
// Casts whose operand 0 is not a FIRRTL type take the (not visible) guarded
// statement instead of reaching lowerProducer.
1445 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1447 return lowerProducer(op, clone);
// Lowers a BitCastOp in two stages, as far as the visible lines show:
//   1. the input's fields are each bitcast to UInt and concatenated into a
//      single value (srcLoweredVal), which is then converted with AsUIntPrimOp;
//   2. if the result is a bundle or vector, each result field is produced by
//      extracting a bit range from that value and bitcasting it to the field
//      type; otherwise the value (converted to SInt when the result type is
//      SInt) replaces all uses of the original result.
// NOTE(review): this listing omits many original lines (see the gaps in the
// embedded numbering): the call that fills `fields`, several conditions,
// assignment targets, declarations and return statements are not visible.
1451bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1452 Value srcLoweredVal = op.getInput();
// Field entries of the input; populated on a line not shown here.
1456 SmallVector<FlatBundleFieldEntry> fields;
// Running count of input bits consumed so far.
1458 size_t uptoBits = 0;
1461 for (
const auto &field :
llvm::enumerate(fields)) {
// The optional bit width is dereferenced without a visible check.
1462 auto fieldBitwidth = *
getBitWidth(field.value().type);
// Zero-width fields are special-cased; the guarded statement is not visible.
1464 if (fieldBitwidth == 0)
1466 Value src = getSubWhatever(op.getInput(), field.index());
// Normalize this field to an unsigned integer of the same width.
1468 src = builder->createOrFold<BitCastOp>(
1469 UIntType::get(
context, fieldBitwidth), src);
1472 srcLoweredVal = src;
// The operand order of the concatenation depends on whether the input is a
// bundle: the accumulated value comes first for bundles, second otherwise.
// The assignment targets of the two CatPrimOp results are not visible.
1474 if (type_isa<BundleType>(op.getInput().getType())) {
1476 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1479 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1483 uptoBits += fieldBitwidth;
1486 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
// Aggregate result: split the flattened value back into per-field values.
1490 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
// Running count of result bits already handed out (shadows the outer counter).
1492 size_t uptoBits = 0;
1493 auto aggregateBits = *
getBitWidth(op.getResult().getType());
1494 auto clone = [&](
const FlatBundleFieldEntry &field,
1495 ArrayAttr attrs) -> Value {
// Returned under a condition that is not visible in this listing.
1501 return InvalidValueOp::create(*builder, field.type);
// Bundle results take bit ranges counted down from the top of the flattened
// value; other (vector) results take ranges counted up from bit 0.
1506 if (type_isa<BundleType>(op.getResult().getType())) {
1507 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1508 aggregateBits - uptoBits - 1,
1509 aggregateBits - uptoBits - fieldBits);
1511 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1512 uptoBits + fieldBits - 1, uptoBits);
1514 uptoBits += fieldBits;
// Reinterpret the extracted bits as the field's own type.
1515 return BitCastOp::create(*builder, field.type,
extractBits);
1517 return lowerProducer(op, clone);
// Non-aggregate result: the flattened value is unsigned at this point, so an
// SInt result needs an explicit conversion before replacing the uses.
1521 if (type_isa<SIntType>(op.getType()))
1522 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1523 op.getResult().replaceAllUsesWith(srcLoweredVal);
1527bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1528 auto clone = [&](
const FlatBundleFieldEntry &field,
1529 ArrayAttr attrs) -> Value {
1530 return RefSendOp::create(*builder,
1531 getSubWhatever(op.getBase(), field.index));
1536 return lowerProducer(op, clone);
1539bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1540 auto clone = [&](
const FlatBundleFieldEntry &field,
1541 ArrayAttr attrs) -> Value {
1542 Value src = getSubWhatever(op.getRef(), field.index);
1543 return RefResolveOp::create(*builder, src);
1547 return lowerProducer(op, clone, op.getRef().getType());
// Lowers a RefCastOp through lowerProducer. The callback builds one RefCastOp
// per field entry whose result is a RefType over the field's type.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the final argument of the create call.
1550bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1551 auto clone = [&](
const FlatBundleFieldEntry &field,
1552 ArrayAttr attrs) -> Value {
// Pick the part of the cast's input that corresponds to this field.
1553 auto input = getSubWhatever(op.getInput(), field.index);
// The new ref type keeps the original cast's forceable flag and layer.
1554 return RefCastOp::create(*builder,
1555 RefType::get(field.type,
1556 op.getType().getForceable(),
1557 op.getType().getLayer()),
1560 return lowerProducer(op, clone);
// Lowers an InstanceOp by building a replacement instance with a new result
// list. Results whose type peelType does not split are carried over as-is;
// the others are expanded into one result per field entry. Uses of the old
// results are then redirected to the new ones.
// NOTE(review): this listing omits many original lines (see the gaps in the
// embedded numbering): several declarations, conditions, loop bodies and the
// return statements are not visible.
1563bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
// Per-port data for the replacement instance, filled in the loop below.
1565 SmallVector<Type, 8> resultTypes;
// endFields[i] .. endFields[i + 1] is the range of new results that stands
// for original result i.
1566 SmallVector<int64_t, 8> endFields;
1567 auto oldPortAnno = op.getPortAnnotations();
1568 SmallVector<Direction> newDirs;
1569 SmallVector<Attribute> newNames;
1570 SmallVector<Attribute> newDomains;
1571 SmallVector<Attribute> newPortAnno;
// Tail of a statement whose beginning is not visible; it resolves the
// referenced module through the symbol table.
1573 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1576 DomainLoweringHelper domainHelper(
context, op.getResultTypes());
1578 endFields.push_back(0);
1579 for (
size_t i = 0, e = op.getNumResults(); i != e; ++i) {
1580 auto srcType = type_cast<FIRRTLType>(op.getType(i));
1583 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
// Not peelable: copy direction, name, domain, type and annotations unchanged.
1584 if (!
peelType(srcType, fieldTypes, mode)) {
1585 newDirs.push_back(op.getPortDirection(i));
1586 newNames.push_back(op.getPortNameAttr(i));
1587 newDomains.push_back(op.getPortDomain(i));
1588 resultTypes.push_back(srcType);
1589 newPortAnno.push_back(oldPortAnno[i]);
// Peelable: emit one entry per field.
1592 auto oldName = op.getPortName(i);
1593 auto oldDir = op.getPortDirection(i);
1595 for (
const auto &field : fieldTypes) {
// The field's direction is the port direction XOR-ed with field.isOutput.
1596 newDirs.push_back(
direction::get((
unsigned)oldDir ^ field.isOutput));
// New port name is the old name followed by the field's suffix.
1597 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
// Every field inherits the domain of the original port.
1598 newDomains.push_back(op.getPortDomain(i));
// Annotations are filtered per field; the remaining arguments of this call
// are on a line not shown here.
1600 auto annos = filterAnnotations(
1601 context, dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]), srcType,
1603 newPortAnno.push_back(annos);
// Record where the new results for original result i end.
1606 endFields.push_back(resultTypes.size());
// Domains are remapped against the new result list.
1616 domainHelper.computeDomainMap(resultTypes);
1619 for (
auto &domain : newDomains)
1620 domainHelper.rewriteDomain(domain);
// Build the replacement instance; module name, instance name, annotations,
// layers and the bind/print flags are taken from the original op.
1623 auto newInstance = InstanceOp::create(
1624 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1626 builder->getArrayAttr(newNames), builder->getArrayAttr(newDomains),
1627 op.getAnnotations(), builder->getArrayAttr(newPortAnno),
1628 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
// `sym` is defined on a line not shown here; a null InnerSymAttr is passed
// when it is empty.
1629 sym ? hw::InnerSymAttr::get(sym) :
hw::InnerSymAttr());
// Discardable attributes are copied over wholesale.
1631 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Redirect the users of each original result to its new results.
// NOTE(review): `lowered` is declared outside the loop; presumably it is
// reset per iteration on a line not shown here -- confirm.
1633 SmallVector<Value> lowered;
1634 for (
size_t aggIndex = 0, eAgg = op.getNumResults(); aggIndex != eAgg;
1637 for (
size_t fieldIndex = endFields[aggIndex],
1638 eField = endFields[aggIndex + 1];
1639 fieldIndex < eField; ++fieldIndex)
1640 lowered.push_back(newInstance.getResult(fieldIndex));
// More than one new result, or a changed type, goes through processUsers.
// The branch structure around the final replaceAllUsesWith is not visible.
1641 if (lowered.size() != 1 ||
1642 op.getType(aggIndex) != resultTypes[endFields[aggIndex]])
1643 processUsers(op.getResult(aggIndex), lowered);
1645 op.getResult(aggIndex).replaceAllUsesWith(lowered[0]);
// Lowers a SubaccessOp (dynamic vector indexing). Three cases are visible:
//   - zero-element vector: uses are replaced by an InvalidValueOp;
//   - constant index: uses are replaced by a SubindexOp at that index;
//   - otherwise: uses are replaced by a MultibitMuxOp over all elements.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the return statements of each case.
1650bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1651 auto input = op.getInput();
1652 FVectorType vType = input.getType();
// An empty vector has no element to select; substitute an invalid value of
// the element type.
1655 if (vType.getNumElements() == 0) {
1656 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1657 op.replaceAllUsesWith(inv);
// If the index is produced by a ConstantOp, a static subindex suffices.
1662 if (ConstantOp arg =
1663 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1664 auto sio = SubindexOp::create(*builder, op.getInput(),
1665 arg.getValue().getExtValue());
1666 op.replaceAllUsesWith(sio.getResult());
// General case: build one SubindexOp per element, from the last element down
// to element 0, and select among them with the dynamic index.
1671 SmallVector<Value> inputs;
1672 inputs.reserve(vType.getNumElements());
1673 for (
int index = vType.getNumElements() - 1; index >= 0; index--)
1674 inputs.push_back(SubindexOp::create(*builder, input, index));
1676 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1677 op.replaceAllUsesWith(multibitMux);
1681bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1682 auto clone = [&](
const FlatBundleFieldEntry &field,
1683 ArrayAttr attrs) -> Value {
1684 return op.getOperand(field.index);
1686 return lowerProducer(op, clone);
1689bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1690 auto clone = [&](
const FlatBundleFieldEntry &field,
1691 ArrayAttr attrs) -> Value {
1692 return op.getOperand(field.index);
1694 return lowerProducer(op, clone);
// Lowers an ElementwiseOrPrimOp through lowerProducer. For each field entry
// the callback ORs the matching parts of the two operands.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the last argument of the elementwise create.
1697bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1698 auto clone = [&](
const FlatBundleFieldEntry &field,
1699 ArrayAttr attrs) -> Value {
// Left- and right-hand parts for this field, in that order.
1700 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1701 getSubWhatever(op.getRhs(), field.index)};
// A field that is itself a bundle or vector stays elementwise; any other
// field uses the plain OrPrimOp.
1702 return type_isa<BundleType, FVectorType>(field.type)
1703 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1705 : (Value)OrPrimOp::create(*builder, operands);
1708 return lowerProducer(op, clone);
// Lowers an ElementwiseAndPrimOp through lowerProducer. For each field entry
// the callback ANDs the matching parts of the two operands.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the last argument of the elementwise create.
1711bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1712 auto clone = [&](
const FlatBundleFieldEntry &field,
1713 ArrayAttr attrs) -> Value {
// Left- and right-hand parts for this field, in that order.
1714 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1715 getSubWhatever(op.getRhs(), field.index)};
// A field that is itself a bundle or vector stays elementwise; any other
// field uses the plain AndPrimOp.
1716 return type_isa<BundleType, FVectorType>(field.type)
1717 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1719 : (Value)AndPrimOp::create(*builder, operands);
1722 return lowerProducer(op, clone);
// Lowers an ElementwiseXorPrimOp through lowerProducer. For each field entry
// the callback XORs the matching parts of the two operands.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the last argument of the elementwise create.
1725bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1726 auto clone = [&](
const FlatBundleFieldEntry &field,
1727 ArrayAttr attrs) -> Value {
// Left- and right-hand parts for this field, in that order.
1728 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1729 getSubWhatever(op.getRhs(), field.index)};
// A field that is itself a bundle or vector stays elementwise; any other
// field uses the plain XorPrimOp.
1730 return type_isa<BundleType, FVectorType>(field.type)
1731 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1733 : (Value)XorPrimOp::create(*builder, operands);
1736 return lowerProducer(op, clone);
1739bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1740 auto clone = [&](
const FlatBundleFieldEntry &field,
1741 ArrayAttr attrs) -> Value {
1742 SmallVector<Value> newInputs;
1743 newInputs.reserve(op.getInputs().size());
1744 for (
auto input : op.getInputs()) {
1745 auto inputSub = getSubWhatever(input, field.index);
1746 newInputs.push_back(inputSub);
1748 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1750 return lowerProducer(op, clone);
// Pass class for FIRRTL type lowering, derived from the generated
// LowerFIRRTLTypesBase. The pass body is in runOnOperation.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the end of the struct.
1758struct LowerTypesPass
1759 :
public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1762 void runOnOperation()
override;
// Pass entry point: collects every FModuleLike in the circuit, records each
// module's convention, and runs a TypeLoweringVisitor over the modules via
// failableParallelForEach.
// NOTE(review): this listing omits some original lines (see the gaps in the
// embedded numbering), including the visitor's variable declaration, the call
// that runs it, and the condition guarding signalPassFailure().
1767void LowerTypesPass::runOnOperation() {
// Modules to process, in circuit order.
1770 std::vector<FModuleLike> ops;
1772 auto &symTbl = getAnalysis<SymbolTable>();
1774 AttrCache cache(&getContext());
// Convention of every module, captured before any lowering runs.
1776 DenseMap<FModuleLike, Convention> conventionTable;
1777 auto circuit = getOperation();
1778 for (
auto module : circuit.getOps<FModuleLike>()) {
1779 conventionTable.insert({module,
module.getConvention()});
1780 ops.push_back(module);
// Per-module worker handed to failableParallelForEach below.
1784 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
// Defaults to Convention::Internal; a ConventionAttr stored under the
// discardable attribute "body_type_lowering" overrides it.
1786 Convention convention = Convention::Internal;
1787 if (
auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1788 op->getDiscardableAttr(
"body_type_lowering")))
1789 convention = conventionAttr.getValue();
// Visitor construction; the variable name (`tl`, used below) is on a line not
// shown here.
1792 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1793 preserveMemories, symTbl, cache, conventionTable);
// The worker reports failure exactly when the visitor recorded one.
1796 return LogicalResult::failure(tl.isFailed());
1799 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
// Presumably guarded by a check on `result` on a line not shown -- confirm.
1802 signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
@ All
Preserve all aggregate values.
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
@ None
Don't preserve aggregate at all.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Direction
This represents the direction of a single port.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.
AnnotationSet annotations