CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerTypes.cpp
Go to the documentation of this file.
1//===- LowerTypes.cpp - Lower Aggregate Types -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerTypes pass. This pass replaces aggregate types
10// with expanded values.
11//
12// This pass walks the operations in reverse order. This lets it visit users
13// before defs. Users can usually be expanded out to multiple operations (think
14// mux of a bundle to muxes of each field) with a temporary subWhatever op
15// inserted. When processing an aggregate producer, we blow out the op as
16// appropriate, then walk the users, often those are subWhatever ops which can
17// be bypassed and deleted. Function arguments are logically last on the
18// operation visit order and walked left to right, being peeled one layer at a
19// time with replacements inserted to the right of the original argument.
20//
21// Each processing of an op peels one layer of aggregate type off. Because new
22// ops are inserted immediately above the current op, the walk will visit them
23// next, effectively recursing on the aggregate types without recursion. These
24// potentially temporary ops (if the aggregate is complex) effectively serve as
25// the worklist. Often aggregates are shallow, so the new ops are the final
26// ones.
27//
28//===----------------------------------------------------------------------===//
29
38#include "circt/Support/Debug.h"
39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/Support/Debug.h"
45
46#define DEBUG_TYPE "firrtl-lower-types"
47
48namespace circt {
49namespace firrtl {
50#define GEN_PASS_DEF_LOWERFIRRTLTYPES
51#include "circt/Dialect/FIRRTL/Passes.h.inc"
52} // namespace firrtl
53} // namespace circt
54
55using namespace circt;
56using namespace firrtl;
57
58// TODO: check all argument types
59namespace {
60/// This represents a flattened bundle field element.
61struct FlatBundleFieldEntry {
62 /// This is the underlying ground type of the field.
63 FIRRTLBaseType type;
64 /// The index in the parent type
65 size_t index;
66 /// The fieldID
67 unsigned fieldID;
68 /// This is a suffix to add to the field name to make it unique.
69 SmallString<16> suffix;
70 /// This indicates whether the field was flipped to be an output.
71 bool isOutput;
72
73 FlatBundleFieldEntry(const FIRRTLBaseType &type, size_t index,
74 unsigned fieldID, StringRef suffix, bool isOutput)
75 : type(type), index(index), fieldID(fieldID), suffix(suffix),
76 isOutput(isOutput) {}
77
78 void dump() const {
79 llvm::errs() << "FBFE{" << type << " index<" << index << "> fieldID<"
80 << fieldID << "> suffix<" << suffix << "> isOutput<"
81 << isOutput << ">}\n";
82 }
83};
84} // end anonymous namespace
85
86/// Return \p fieldType, or \p fieldType wrapped in the same reference type as
/// \p type when \p type is a reference. The rewrite is delegated to
/// mapBaseType, whose exact behaviour is not visible here.
// NOTE(review): original line 87 (the signature) is absent from this
// extraction; presumably `static FIRRTLType mapLoweredType(FIRRTLType type,
// FIRRTLBaseType fieldType) {` — confirm against upstream.
88 return mapBaseType(type, [&](auto) { return fieldType; });
89}
90
91/// Return fieldType or fieldType as same ref as type.
92static Type mapLoweredType(Type type, FIRRTLBaseType fieldType) {
93 auto ftype = type_dyn_cast<FIRRTLType>(type);
94 if (!ftype)
95 return type;
96 return mapLoweredType(ftype, fieldType);
97}
98
99/// Return true if \p type is a vector with more than one element whose element
/// type is ground, or any type that is neither a bundle nor a vector (the
/// Default case). Bundles return false; single-element vectors return false so
/// that they get lowered to a scalar.
// NOTE(review): original lines 100-101 (the signature and the opening
// `return ...TypeSwitch...(type)` line) are absent from this extraction —
// confirm against upstream.
102 .Case<BundleType>([&](auto bundle) { return false; })
103 .Case<FVectorType>([&](FVectorType vector) {
104 // When the size is 1, lower the vector into a scalar.
105 return vector.getElementType().isGround() &&
106 vector.getNumElements() > 1;
107 })
108 .Default([](auto groundType) { return true; });
109}
110
111// NOLINTBEGIN(misc-no-recursion)
112/// Return true if \p type is a bundle, or a vector whose element type
/// (recursively) contains a bundle. Every other type returns false.
// NOTE(review): original lines 113-114 (the signature and the opening
// TypeSwitch line) are absent from this extraction — confirm against upstream.
115 .Case<BundleType>([&](auto bundle) { return true; })
116 .Case<FVectorType>([&](FVectorType vector) {
117 return containsBundleType(vector.getElementType());
118 })
119 .Default([](auto groundType) { return false; });
120}
121// NOLINTEND(misc-no-recursion)
122
123/// Return true if \p type should be kept as an aggregate under \p mode.
/// Forceable references (rwprobes) are always preserved; read-only references
/// never are. Otherwise \p type must be a passive FIRRTL base type with no
/// analog and no zero-width parts, and \p mode decides which shapes qualify.
124static bool isPreservableAggregateType(Type type,
// NOTE(review): original line 125 (presumably the `mode` parameter and the
// opening brace) is absent from this extraction — confirm against upstream.
126 if (auto refType = type_dyn_cast<RefType>(type)) {
127 // Always preserve rwprobe's.
128 if (refType.getForceable())
129 return true;
130 // FIXME: Don't preserve read-only RefType for now. This is workaround for
131 // MemTap which causes type mismatches (issue 4479).
132 return false;
133 }
134
135 // Return false if no aggregate value is preserved.
136 if (mode == PreserveAggregate::None)
137 return false;
138
139 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
140 if (!firrtlType)
141 return false;
142
143 // We can preserve the type only if (i) the type is passive, (ii) the type
144 // doesn't contain analog and (iii) the type doesn't contain zero bitwidth.
145 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
146 hasZeroBitWidth(firrtlType))
147 return false;
148
149 switch (mode) {
// NOTE(review): original lines 150, 152 and 154 (the `case` labels for the
// three preserving modes) are absent from this extraction — confirm against
// upstream.
151 return true;
153 return isOneDimVectorType(firrtlType);
155 return !containsBundleType(firrtlType);
156 default:
157 llvm_unreachable("unexpected mode");
158 }
159}
160
161/// Peel one layer off \p type, appending one FlatBundleFieldEntry per element
162/// to \p fields. Returns false without touching \p fields when \p type is
/// preservable under the given mode, or (after looking through a reference)
/// is neither a bundle nor a vector. An empty bundle/vector returns true with
/// nothing appended. Bundle elements get the suffix "_<name>" and keep their
/// flip; vector elements get "_<index>" and are never flipped.
163static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
// NOTE(review): original line 164 (presumably the `mode` parameter and the
// opening brace) is absent from this extraction — confirm against upstream.
165 // If the aggregate preservation is enabled and the type is preservable,
166 // then just return.
167 if (isPreservableAggregateType(type, mode))
168 return false;
169
170 if (auto refType = type_dyn_cast<RefType>(type))
171 type = refType.getType();
// NOTE(review): original line 172 (presumably the `return ...TypeSwitch...`
// head of the chain below) is absent from this extraction.
173 .Case<BundleType>([&](auto bundle) {
174 SmallString<16> tmpSuffix;
175 // Otherwise, we have a bundle type. Break it down.
176 for (size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
177 auto elt = bundle.getElement(i);
178 // Construct the suffix to pass down.
179 tmpSuffix.resize(0);
180 tmpSuffix.push_back('_');
181 tmpSuffix.append(elt.name.getValue());
182 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
183 elt.isFlip);
184 }
185 return true;
186 })
187 .Case<FVectorType>([&](auto vector) {
188 // One entry per element, named "_<i>", with its own field ID; never flipped.
189 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
190 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
191 "_" + std::to_string(i), false);
192 }
193 return true;
194 })
195 .Default([](auto op) { return false; });
196}
197
198/// Return if something is not a normal subaccess. Non-normal includes
199/// zero-length vectors and constant indexes (which are really subindexes).
200static bool isNotSubAccess(Operation *op) {
201 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
202 if (!sao)
203 return true;
204 ConstantOp arg =
205 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
206 return arg && sao.getInput().getType().base().getNumElements() != 0;
207}
208
209/// Look through and collect subfields leading to a subaccess.
210static SmallVector<Operation *> getSAWritePath(Operation *op) {
211 SmallVector<Operation *> retval;
212 auto defOp = op->getOperand(0).getDefiningOp();
213 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
214 retval.push_back(defOp);
215 defOp = defOp->getOperand(0).getDefiningOp();
216 }
217 // Trim to the subaccess
218 while (!retval.empty() && isNotSubAccess(retval.back()))
219 retval.pop_back();
220 return retval;
221}
222
223/// Clone memory for the specified field. Returns null op on error.
224static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op,
225 FlatBundleFieldEntry field) {
226 SmallVector<Type, 8> ports;
227 SmallVector<Attribute, 8> portNames;
228 SmallVector<Attribute, 8> portLocations;
229
230 auto oldPorts = op.getPorts();
231 for (size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
232 auto port = oldPorts[portIdx];
233 ports.push_back(
234 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
235 portNames.push_back(port.first);
236 }
237
238 // It's easier to duplicate the old annotations, then fix and filter them.
239 auto newMem =
240 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
241 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
242 (op.getName() + field.suffix).str(), op.getNameKind(),
243 op.getAnnotations(), op.getPortAnnotations(),
244 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
245
246 if (op.getInnerSym()) {
247 op.emitError("cannot split memory with symbol present");
248 return {};
249 }
250
251 SmallVector<Attribute> newAnnotations;
252 for (size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
253 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
254 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
255 SmallVector<Attribute> portAnno;
256 for (auto attr : newMem.getPortAnnotation(portIdx)) {
257 Annotation anno(attr);
258 if (auto annoFieldID = anno.getFieldID()) {
259 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
260
261 // Apply annotations to all elements if the target is the whole
262 // sub-field.
263 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
264 anno.setMember(
265 "circt.fieldID",
266 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
267 portAnno.push_back(anno.getDict());
268 continue;
269 }
270
271 // Handle aggregate sub-fields, including `(r/w)data` and `(w)mask`.
272 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
273 // Check whether the annotation falls into the range of the current
274 // field. Note that the `field` here is peeled from the `data`
275 // sub-field of the memory port, thus we need to add the fieldID of
276 // `data` or `mask` sub-field to get the "real" fieldID.
277 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
278 if (annoFieldID >= fieldID &&
279 annoFieldID <=
280 fieldID + hw::FieldIdImpl::getMaxFieldID(field.type)) {
281 // Set the field ID of the new annotation.
282 auto newFieldID =
283 annoFieldID - fieldID + portType.getFieldID(targetIndex);
284 anno.setMember("circt.fieldID", b->getI32IntegerAttr(newFieldID));
285 portAnno.push_back(anno.getDict());
286 }
287 }
288 } else
289 portAnno.push_back(attr);
290 }
291 newAnnotations.push_back(b->getArrayAttr(portAnno));
292 }
293 newMem.setAllPortAnnotations(newAnnotations);
294 return newMem;
295}
296
297//===----------------------------------------------------------------------===//
298// Module Type Lowering
299//===----------------------------------------------------------------------===//
300namespace {
301
302struct AttrCache {
303 AttrCache(MLIRContext *context) {
304 i64ty = IntegerType::get(context, 64);
305 nameAttr = StringAttr::get(context, "name");
306 nameKindAttr = StringAttr::get(context, "nameKind");
307 sPortDirections = StringAttr::get(context, "portDirections");
308 sPortNames = StringAttr::get(context, "portNames");
309 sPortTypes = StringAttr::get(context, "portTypes");
310 sPortSymbols = StringAttr::get(context, "portSymbols");
311 sPortLocations = StringAttr::get(context, "portLocations");
312 sPortAnnotations = StringAttr::get(context, "portAnnotations");
313 sPortDomains = StringAttr::get(context, "domainInfo");
314 sEmpty = StringAttr::get(context, "");
315 aEmpty = ArrayAttr::get(context, {});
316 }
317 AttrCache(const AttrCache &) = default;
318
319 Type i64ty;
320 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
321 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
322 ArrayAttr aEmpty;
323};
324
325/// Helper class to handle domain lowering consistently across modules,
326/// extmodules, and instances. This class tracks domain port indices and
327/// provides methods to rewrite domain associations after port lowering.
328class DomainLoweringHelper {
329public:
330 /// Construct a helper by scanning the original port types for domain types.
331 /// For modules/extmodules, pass the port types attribute array.
332 /// For instances, pass the result types directly.
333 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
334 : context(context) {
335 for (auto [index, typeAttr] : llvm::enumerate(portTypes))
336 if (type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
337 domainIndexByOrdinal.push_back(index);
338 }
339
340 /// Construct a helper by scanning instance result types for domain types.
341 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
342 : context(context) {
343 for (auto [index, type] : llvm::enumerate(resultTypes))
344 if (type_isa<DomainType>(type))
345 domainIndexByOrdinal.push_back(index);
346 }
347
348 /// Compute the mapping from old domain port indices to new port indices after
349 /// type lowering. Call this after ports have been lowered but before
350 /// rewriting domain associations. This overload takes a range of types
351 /// directly (e.g., from instances).
352 void computeDomainMap(TypeRange types) {
353 size_t i = 0, ord = 0;
354 for (auto type : types) {
355 if (type_isa<DomainType>(type))
356 domainMap[domainIndexByOrdinal[ord++]] = i;
357 ++i;
358 }
359 }
360
361 /// Compute the mapping from old domain port indices to new port indices after
362 /// type lowering. Call this after ports have been lowered but before
363 /// rewriting domain associations. This overload takes a range of PortInfo
364 /// and extracts types from them (e.g., from modules/extmodules).
365 void computeDomainMap(ArrayRef<PortInfo> ports) {
366 size_t i = 0, ord = 0;
367 for (const auto &port : ports) {
368 if (type_isa<DomainType>(port.type))
369 domainMap[domainIndexByOrdinal[ord++]] = i;
370 ++i;
371 }
372 }
373
374 /// Rewrite a domain attribute to use new port indices. The domain attribute
375 /// contains an array of port indices that need to be updated to reflect the
376 /// new port numbering after type lowering.
377 void rewriteDomain(Attribute &domain) {
378 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
379 if (!oldAssociations)
380 return;
381 SmallVector<Attribute> newAssociations;
382 for (auto oldAttr : oldAssociations)
383 newAssociations.push_back(IntegerAttr::get(
384 IntegerType::get(context, 32, IntegerType::Unsigned),
385 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
386 domain = ArrayAttr::get(context, newAssociations);
387 }
388
389private:
390 MLIRContext *context;
391 /// Maps ordinal position of domain ports to their original indices.
392 SmallVector<unsigned> domainIndexByOrdinal;
393 /// Maps old port indices to new port indices after lowering.
394 DenseMap<unsigned, unsigned> domainMap;
395};
396
397// Type-lowering visitor. The visitors all return true if the operation should
398// be deleted (lowerBlock erases it in that case), false if not.
399struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
400
401 TypeLoweringVisitor(
402 MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
403 Convention bodyConvention,
404 PreserveAggregate::PreserveMode memoryPreservationMode,
405 SymbolTable &symTbl, const AttrCache &cache,
406 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
407 : context(context), defaultAggregatePreservationMode(preserveAggregate),
408 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
409 cache(cache), conventionTable(conventionTable) {
410 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
// NOTE(review): original line 411 is absent from this extraction; it is
// presumably the true arm of this conditional (e.g.
// `? PreserveAggregate::None`) — confirm against upstream.
412 : defaultAggregatePreservationMode;
413 }
414 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
415 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
416 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
417
418 /// If the referenced operation is a FModuleOp or an FExtModuleOp, perform
419 /// type lowering on all operations.
420 void lowerModule(FModuleLike op);
421
422 bool lowerArg(FModuleLike module, size_t argIndex, size_t argsRemoved,
423 SmallVectorImpl<PortInfo> &newArgs,
424 SmallVectorImpl<Value> &lowering);
425 std::pair<Value, PortInfo> addArg(Operation *module, unsigned insertPt,
426 unsigned insertPtOffset, FIRRTLType srcType,
427 const FlatBundleFieldEntry &field,
428 PortInfo &oldArg, hw::InnerSymAttr newSym);
429
430 // Per-operation visitors; each returns true if the op should be erased.
431 bool visitDecl(FExtModuleOp op);
432 bool visitDecl(FModuleOp op);
433 bool visitDecl(InstanceOp op);
434 bool visitDecl(MemOp op);
435 bool visitDecl(NodeOp op);
436 bool visitDecl(RegOp op);
437 bool visitDecl(WireOp op);
438 bool visitDecl(RegResetOp op);
439 bool visitExpr(InvalidValueOp op);
440 bool visitExpr(SubaccessOp op);
441 bool visitExpr(VectorCreateOp op);
442 bool visitExpr(BundleCreateOp op);
443 bool visitExpr(ElementwiseAndPrimOp op);
444 bool visitExpr(ElementwiseOrPrimOp op);
445 bool visitExpr(ElementwiseXorPrimOp op);
446 bool visitExpr(MultibitMuxOp op);
447 bool visitExpr(MuxPrimOp op);
448 bool visitExpr(Mux2CellIntrinsicOp op);
449 bool visitExpr(Mux4CellIntrinsicOp op);
450 bool visitExpr(BitCastOp op);
451 bool visitExpr(RefSendOp op);
452 bool visitExpr(RefResolveOp op);
453 bool visitExpr(RefCastOp op);
454 bool visitStmt(ConnectOp op);
455 bool visitStmt(MatchingConnectOp op);
456 bool visitStmt(RefDefineOp op);
457 bool visitStmt(WhenOp op);
458 bool visitStmt(LayerBlockOp op);
459 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
460
 /// True once any lowering step has set encounteredError.
461 bool isFailed() const { return encounteredError; }
462
 /// Fallback for ops the FIRRTL visitor does not dispatch: unrealized
 /// conversion casts are handled, anything else is left in place.
463 bool visitInvalidOp(Operation *op) {
464 if (auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
465 return visitUnrealizedConversionCast(castOp);
466 return false;
467 }
468
469private:
470 void processUsers(Value val, ArrayRef<Value> mapping);
471 bool processSAPath(Operation *);
472 void lowerBlock(Block *);
473 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
474
475 /// Lower a "producer" operation one layer based on policy.
476 /// Use the provided \p clone function to generate individual ops for
477 /// the expanded subelements/fields. The type used to determine if lowering
478 /// is needed is either \p srcType if provided or from the assumed-to-exist
479 /// first result of the operation. When lowering, the clone callback will be
480 /// invoked with each subelement/field of this type.
481 bool lowerProducer(
482 Operation *op,
483 llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
484 Type srcType = {});
485
486 /// Filter out and return the \p annotations that target \p field,
487 /// adjusting their fieldIDs to be relative to \p field.
488 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
489 FIRRTLType srcType, FlatBundleFieldEntry field);
490
491 /// Partition inner symbols on given type. Fails if any symbols
492 /// cannot be assigned to a field, such as inner symbol on root.
493 LogicalResult partitionSymbols(hw::InnerSymAttr sym, FIRRTLType parentType,
494 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
495 Location errorLoc);
496
 // NOTE(review): original line 497 (the return type of the declaration
 // below, presumably `PreserveAggregate::PreserveMode`) is absent from this
 // extraction — confirm against upstream.
498 getPreservationModeForPorts(FModuleLike moduleLike);
499 Value getSubWhatever(Value val, size_t index);
500
 /// Counter backing uniqueName(); returns names "__GEN_0", "__GEN_1", ...
501 size_t uniqueIdx = 0;
502 std::string uniqueName() {
503 auto myID = uniqueIdx++;
504 return (Twine("__GEN_") + Twine(myID)).str();
505 }
506
507 MLIRContext *context;
508
509 /// Aggregate preservation mode.
510 PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
511 PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
512 PreserveAggregate::PreserveMode memoryPreservationMode;
513
514 /// The builder is set and maintained in the main loop.
 /// NOTE(review): not initialized by the constructor; it must be assigned
 /// before any visit method runs.
515 ImplicitLocOpBuilder *builder;
516
517 // Keep a symbol table around for resolving symbols
518 SymbolTable &symTbl;
519
520 // Cache some attributes
521 const AttrCache &cache;
522
523 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
524
525 // Set true if the lowering failed.
526 bool encounteredError = false;
527};
528} // namespace
529
530/// Return the aggregate preservation mode for the ports of \p module. Modules
531/// missing from the convention table, or with the internal convention, use
/// the default mode; a scalarized convention may not preserve its aggregate
/// ports.
// NOTE(review): original line 532 (the return type, presumably
// `PreserveAggregate::PreserveMode`) is absent from this extraction.
533TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
534 auto lookup = conventionTable.find(module);
535 if (lookup == conventionTable.end())
536 return defaultAggregatePreservationMode;
537 switch (lookup->second) {
538 case Convention::Scalarized:
// NOTE(review): original line 539 is absent here. As shown, Scalarized would
// fall through to the Internal case and return the default mode, which
// contradicts the comment above; upstream presumably has
// `return PreserveAggregate::None;` on this line — confirm.
540 case Convention::Internal:
541 return defaultAggregatePreservationMode;
542 }
543 llvm_unreachable("Unknown convention");
544 return defaultAggregatePreservationMode;
545}
546
547Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
548 if (type_isa<BundleType>(val.getType()))
549 return SubfieldOp::create(*builder, val, index);
550 if (type_isa<FVectorType>(val.getType()))
551 return SubindexOp::create(*builder, val, index);
552 if (type_isa<RefType>(val.getType()))
553 return RefSubOp::create(*builder, val, index);
554 llvm_unreachable("Unknown aggregate type");
555 return nullptr;
556}
557
558/// Conditionally expand a subaccessop write path
559bool TypeLoweringVisitor::processSAPath(Operation *op) {
560 // Does this LHS have a subaccessop?
561 SmallVector<Operation *> writePath = getSAWritePath(op);
562 if (writePath.empty())
563 return false;
564
565 lowerSAWritePath(op, writePath);
566 // Unhook the writePath from the connect. This isn't the right type, but we
567 // are deleting the op anyway.
568 op->eraseOperands(0, 2);
569 // See how far up the tree we can delete things.
570 for (size_t i = 0; i < writePath.size(); ++i) {
571 if (writePath[i]->use_empty()) {
572 writePath[i]->erase();
573 } else {
574 break;
575 }
576 }
577 return true;
578}
579
/// Lower every operation in \p block, visiting from the last operation to the
/// first so that users are handled before their definitions.
580void TypeLoweringVisitor::lowerBlock(Block *block) {
581 // Lower the operations bottom up.
582 for (auto it = block->rbegin(), e = block->rend(); it != e;) {
583 auto &iop = *it;
// Ops created by the visitor are inserted immediately before iop.
584 builder->setInsertionPoint(&iop);
585 builder->setLoc(iop.getLoc());
586 bool removeOp = dispatchVisitor(&iop);
// Advance only after the visit: stepping backwards from iop now lands on the
// ops the visitor just inserted before it, so they are visited next. This is
// how deeper aggregate layers get peeled without recursion. Advancing before
// the erase below also keeps the iterator valid.
587 ++it;
588 // Erase old ops eagerly so we don't have dangling uses we've already
589 // lowered.
590 if (removeOp)
591 iop.erase();
592 }
593}
594
595ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
596 ArrayAttr annotations,
597 FIRRTLType srcType,
598 FlatBundleFieldEntry field) {
599 SmallVector<Attribute> retval;
600 if (!annotations || annotations.empty())
601 return ArrayAttr::get(ctxt, retval);
602 for (auto opAttr : annotations) {
603 Annotation anno(opAttr);
604 auto fieldID = anno.getFieldID();
605 anno.removeMember("circt.fieldID");
606
607 // If no fieldID set, or points to root, forward the annotation without the
608 // fieldID field (which was removed above).
609 if (fieldID == 0) {
610 retval.push_back(anno.getAttr());
611 continue;
612 }
613 // Check whether the annotation falls into the range of the current field.
614
615 if (fieldID < field.fieldID ||
616 fieldID > field.fieldID + hw::FieldIdImpl::getMaxFieldID(field.type))
617 continue;
618
619 // Add fieldID back if non-zero relative to this field.
620 if (auto newFieldID = fieldID - field.fieldID) {
621 // If the target is a subfield/subindex of the current field, create a
622 // new annotation with the correct circt.fieldID.
623 anno.setMember("circt.fieldID", builder->getI32IntegerAttr(newFieldID));
624 }
625
626 retval.push_back(anno.getAttr());
627 }
628 return ArrayAttr::get(ctxt, retval);
629}
630
631LogicalResult TypeLoweringVisitor::partitionSymbols(
632 hw::InnerSymAttr sym, FIRRTLType parentType,
633 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
634
635 // No symbol, nothing to partition.
636 if (!sym || sym.empty())
637 return success();
638
639 auto *context = sym.getContext();
640
641 auto baseType = getBaseType(parentType);
642 if (!baseType)
643 return mlir::emitError(errorLoc,
644 "unable to partition symbol on unsupported type ")
645 << parentType;
646
647 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
648 .Case<BundleType, FVectorType>([&](auto aggType) -> LogicalResult {
649 struct BinningInfo {
650 uint64_t index;
651 uint64_t relFieldID;
652 hw::InnerSymPropertiesAttr prop;
653 };
654
655 // Walk each inner symbol, compute binning information/assignment.
656 SmallVector<BinningInfo> binning;
657 for (auto prop : sym) {
658 auto fieldID = prop.getFieldID();
659 // Special-case fieldID == 0, helper methods require non-zero fieldID.
660 if (fieldID == 0)
661 return mlir::emitError(errorLoc, "unable to lower due to symbol ")
662 << prop.getName()
663 << " with target not preserved by lowering";
664 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
665 binning.push_back({index, relFieldID, prop});
666 }
667
668 // Sort by index, fieldID.
669 llvm::stable_sort(binning, [&](auto &lhs, auto &rhs) {
670 return std::tuple(lhs.index, lhs.relFieldID) <
671 std::tuple(rhs.index, rhs.relFieldID);
672 });
673 assert(!binning.empty());
674
675 // Populate newSyms, group all symbols on same index.
676 newSyms.resize(aggType.getNumElements());
677 for (auto binIt = binning.begin(), binEnd = binning.end();
678 binIt != binEnd;) {
679 auto curIndex = binIt->index;
680 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
681 // Gather all adjacent symbols for this index.
682 while (binIt != binEnd && binIt->index == curIndex) {
683 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
684 context, binIt->prop.getName(), binIt->relFieldID,
685 binIt->prop.getSymVisibility()));
686 ++binIt;
687 }
688
689 assert(!newSyms[curIndex]);
690 newSyms[curIndex] = hw::InnerSymAttr::get(context, propsForIndex);
691 }
692 return success();
693 })
694 .Default([&](auto ty) {
695 return mlir::emitError(
696 errorLoc, "unable to partition symbol on unsupported type ")
697 << ty;
698 });
699}
700
/// Peel one layer off the value produced by \p op (result 0, or \p srcType if
/// given) and build one replacement value per component via \p clone.
/// Returns false and lowers nothing when the type is not a FIRRTL type, or
/// when peelType declines (preserved or non-aggregate). Also returns false,
/// with encounteredError set, when inner symbols cannot be partitioned.
/// Otherwise each new value's defining op receives:
///  - the parent name plus the field suffix, and the parent's nameKind;
///  - the parent's discardable attributes;
///  - its share of the annotations and inner symbols.
/// Users of result 0 are then rewired via processUsers, and the function
/// returns true.
701bool TypeLoweringVisitor::lowerProducer(
702 Operation *op,
703 llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
704 Type srcType) {
705
706 if (!srcType)
707 srcType = op->getResult(0).getType();
708 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
709 if (!srcFType)
710 return false;
711 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
712
713 if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
714 return false;
715
716 SmallVector<Value> lowered;
717 // Name of the current lowered value: the op's name plus the field suffix.
718 SmallString<16> loweredName;
719 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
720
721 if (auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
722 loweredName = nameAttr.getValue();
723 auto baseNameLen = loweredName.size();
724 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr("annotations"));
725
 // One (possibly null) symbol per peeled field. partitionSymbols resizes this
 // to the element count, which matches fieldTypes for the zip_equal below.
726 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
727 if (auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
728 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
729 symOp.getLoc()))) {
730 encounteredError = true;
731 return false;
732 }
733 }
734
735 for (const auto &[field, sym] : llvm::zip_equal(fieldTypes, fieldSyms)) {
736 if (!loweredName.empty()) {
737 loweredName.resize(baseNameLen);
738 loweredName += field.suffix;
739 }
740
741 // For all annotations on the parent op, filter them based on the target
742 // attribute.
743 ArrayAttr loweredAttrs =
744 filterAnnotations(context, oldAnno, srcFType, field);
745 auto newVal = clone(field, loweredAttrs);
746
747 // If inner symbols on this field, add to new op.
748 if (sym) {
749 // Splitting up something with symbols on it should lower to ops
750 // that also can have symbols on them.
751 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
752 assert(
753 newSymOp &&
754 "op with inner symbol lowered to op that cannot take inner symbol");
755 newSymOp.setInnerSymbolAttr(sym);
756 }
757
758 // Carry over the name, if present.
759 if (auto *newOp = newVal.getDefiningOp()) {
760 if (!loweredName.empty())
761 newOp->setAttr(cache.nameAttr, StringAttr::get(context, loweredName));
762 if (nameKindAttr)
763 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
764
765 // Clone discardable attributes as well.
 // NOTE(review): this runs after the name/nameKind writes above; presumably
 // safe because those are inherent attributes, which setDiscardableAttrs
 // does not replace — confirm before reordering.
766 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
767 }
768 lowered.push_back(newVal);
769 }
770
771 processUsers(op->getResult(0), lowered);
772 return true;
773}
774
775void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
776 for (auto *user : llvm::make_early_inc_range(val.getUsers())) {
777 TypeSwitch<Operation *, void>(user)
778 .Case<SubindexOp>([mapping](SubindexOp sio) {
779 Value repl = mapping[sio.getIndex()];
780 sio.replaceAllUsesWith(repl);
781 sio.erase();
782 })
783 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
784 // Get the input bundle type.
785 Value repl = mapping[sfo.getFieldIndex()];
786 sfo.replaceAllUsesWith(repl);
787 sfo.erase();
788 })
789 .Case<RefSubOp>([mapping](RefSubOp refSub) {
790 Value repl = mapping[refSub.getIndex()];
791 refSub.replaceAllUsesWith(repl);
792 refSub.erase();
793 })
794 .Default([&](auto op) {
795 // This means we have already processed the user, and it didn't lower
796 // its inputs. This is an opaque user, which will continue to have
797 // aggregate type as input, even after LowerTypes. So, construct the
798 // vector/bundle back from the lowered elements to ensure a valid
799 // input into the opaque op. This only supports Bundles and Vectors.
800
801 // This builder ensures that the aggregate construction happens at the
802 // user location, and the LowerTypes algorithm will not touch them any
803 // more, because LowerTypes was reverse iterating on the block and the
804 // user has already been processed.
805 ImplicitLocOpBuilder b(user->getLoc(), user);
806
807 // This shouldn't happen (non-FIRRTLBaseType's in lowered types, or
808 // refs), check explicitly here for clarity/early detection.
809 assert(llvm::none_of(mapping, [](auto v) {
810 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
811 return !fbasetype || fbasetype.containsReference();
812 }));
813
814 Value input =
815 TypeSwitch<Type, Value>(val.getType())
816 .template Case<FVectorType>([&](auto vecType) {
817 return b.createOrFold<VectorCreateOp>(vecType, mapping);
818 })
819 .template Case<BundleType>([&](auto bundleType) {
820 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
821 })
822 .Default([&](auto _) -> Value { return {}; });
823 if (!input) {
824 user->emitError("unable to reconstruct source of type ")
825 << val.getType();
826 encounteredError = true;
827 return;
828 }
829 user->replaceUsesOfWith(val, input);
830 });
831 }
832}
833
834/// Helper function to remove elements from a vector based on a BitVector mask.
835template <typename T>
836static void eraseElementsAtIndices(SmallVectorImpl<T> &vec,
837 const llvm::BitVector &removalMask) {
838 size_t writeIndex = 0, readIndex = 0;
839
840 // Iterate over each set bit (element to remove) in the mask.
841 // Between each removal point, we bulk-copy the range of elements to keep.
842 for (size_t removalIndex : removalMask.set_bits()) {
843 // Copy the range [readIndex, removalIndex) - these are elements to keep.
844 assert(removalIndex >= readIndex && "removal index before read index");
845 size_t rangeSize = removalIndex - readIndex;
846 if (rangeSize > 0) {
847 // Bulk move the range of elements to keep to the write position.
848 // Skip if the read and write positions are the same (= the first
849 // iteration).
850 if (writeIndex != readIndex)
851 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
852 vec.begin() + writeIndex);
853 writeIndex += rangeSize;
854 }
855 readIndex = removalIndex + 1;
856 }
857
858 // Copy any remaining elements after the last removal point.
859 size_t remainingSize = vec.size() - readIndex;
860 if (remainingSize > 0) {
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
863 writeIndex += remainingSize;
864 }
865
866 // Truncate the vector to the new size (number of elements kept).
867 vec.truncate(writeIndex);
868}
869
870void TypeLoweringVisitor::lowerModule(FModuleLike op) {
871 if (auto module = llvm::dyn_cast<FModuleOp>(*op))
872 visitDecl(module);
873 else if (auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
874 visitDecl(extModule);
875}
876
877// Creates and returns a new block argument of the specified type to the
878// module. This also maintains the name attribute for the new argument,
879// possibly with a new suffix appended.
880std::pair<Value, PortInfo>
881TypeLoweringVisitor::addArg(Operation *module, unsigned insertPt,
882 unsigned insertPtOffset, FIRRTLType srcType,
883 const FlatBundleFieldEntry &field, PortInfo &oldArg,
884 hw::InnerSymAttr newSym) {
885 Value newValue;
886 FIRRTLType fieldType = mapLoweredType(srcType, field.type);
887 if (auto mod = llvm::dyn_cast<FModuleOp>(module)) {
888 Block *body = mod.getBodyBlock();
889 // Append the new argument.
890 newValue = body->insertArgument(insertPt, fieldType, oldArg.loc);
891 }
892
893 // Save the name attribute for the new argument.
894 auto name = builder->getStringAttr(oldArg.name.getValue() + field.suffix);
895
896 // Populate the new arg attributes.
897 auto newAnnotations = filterAnnotations(
898 context, oldArg.annotations.getArrayAttr(), srcType, field);
899 // Flip the direction if the field is an output.
900 auto direction = (Direction)((unsigned)oldArg.direction ^ field.isOutput);
901
902 return std::make_pair(
903 newValue, PortInfo{name, fieldType, direction, newSym, oldArg.loc,
904 AnnotationSet(newAnnotations), oldArg.domains});
905}
906
907// Lower arguments with bundle type by flattening them.
908bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
909 size_t argsRemoved,
910 SmallVectorImpl<PortInfo> &newArgs,
911 SmallVectorImpl<Value> &lowering) {
912
913 // Flatten any bundle types.
914 SmallVector<FlatBundleFieldEntry> fieldTypes;
915 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
916 if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
917 return false;
918
919 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
920 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
921 newArgs[argIndex].loc))) {
922 encounteredError = true;
923 return false;
924 }
925
926 for (const auto &[idx, field, fieldSym] :
927 llvm::enumerate(fieldTypes, fieldSyms)) {
928 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
929 field, newArgs[argIndex], fieldSym);
930 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
931 // Lower any other arguments by copying them to keep the relative order.
932 lowering.push_back(newValue.first);
933 }
934 return true;
935}
936
937static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
938 Value rhs) {
939 if (auto rop = llvm::dyn_cast<SubfieldOp>(op))
940 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
941 if (auto rop = llvm::dyn_cast<SubindexOp>(op))
942 return SubindexOp::create(*builder, rhs, rop.getIndex());
943 if (auto rop = llvm::dyn_cast<SubaccessOp>(op))
944 return SubaccessOp::create(*builder, rhs, rop.getIndex());
945 op->emitError("Unknown accessor");
946 return nullptr;
947}
948
949void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
950 ArrayRef<Operation *> writePath) {
951 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
952 FVectorType saoType = sao.getInput().getType();
953 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
954
955 for (size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
956 auto cond = EQPrimOp::create(
957 *builder, sao.getIndex(),
958 builder->createOrFold<ConstantOp>(UIntType::get(context, selectWidth),
959 APInt(selectWidth, index)));
960 WhenOp::create(*builder, cond, false, [&]() {
961 // Recreate the write Path
962 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
963 for (int i = writePath.size() - 2; i >= 0; --i) {
964 if (auto access = cloneAccess(builder, writePath[i], leaf))
965 leaf = access;
966 else {
967 encounteredError = true;
968 return;
969 }
970 }
971
972 emitConnect(*builder, leaf, op->getOperand(1));
973 });
974 }
975}
976
977// Expand connects of aggregates
978bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
979 if (processSAPath(op))
980 return true;
981
982 // Attempt to get the bundle types.
983 SmallVector<FlatBundleFieldEntry> fields;
984
985 // We have to expand connections even if the aggregate preservation is true.
986 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
987 return false;
988
989 // Loop over the leaf aggregates.
990 for (const auto &field : llvm::enumerate(fields)) {
991 Value src = getSubWhatever(op.getSrc(), field.index());
992 Value dest = getSubWhatever(op.getDest(), field.index());
993 if (field.value().isOutput)
994 std::swap(src, dest);
995 emitConnect(*builder, dest, src);
996 }
997 return true;
998}
999
1000// Expand connects of aggregates
1001bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
1002 if (processSAPath(op))
1003 return true;
1004
1005 // Attempt to get the bundle types.
1006 SmallVector<FlatBundleFieldEntry> fields;
1007
1008 // We have to expand connections even if the aggregate preservation is true.
1009 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
1010 return false;
1011
1012 // Loop over the leaf aggregates.
1013 for (const auto &field : llvm::enumerate(fields)) {
1014 Value src = getSubWhatever(op.getSrc(), field.index());
1015 Value dest = getSubWhatever(op.getDest(), field.index());
1016 if (field.value().isOutput)
1017 std::swap(src, dest);
1018 MatchingConnectOp::create(*builder, dest, src);
1019 }
1020 return true;
1021}
1022
1023// Expand connects of references-of-aggregates
1024bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1025 // Attempt to get the bundle types.
1026 SmallVector<FlatBundleFieldEntry> fields;
1027
1028 if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1029 return false;
1030
1031 // Loop over the leaf aggregates.
1032 for (const auto &field : llvm::enumerate(fields)) {
1033 Value src = getSubWhatever(op.getSrc(), field.index());
1034 Value dest = getSubWhatever(op.getDest(), field.index());
1035 assert(!field.value().isOutput && "unexpected flip in reftype destination");
1036 RefDefineOp::create(*builder, dest, src);
1037 }
1038 return true;
1039}
1040
1041bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1042 // The WhenOp itself does not require any lowering, the only value it uses
1043 // is a one-bit predicate. Recursively visit all regions so internal
1044 // operations are lowered.
1045
1046 // Visit operations in the then block.
1047 lowerBlock(&op.getThenBlock());
1048
1049 // Visit operations in the else block.
1050 if (op.hasElseRegion())
1051 lowerBlock(&op.getElseBlock());
1052 return false; // don't delete the when!
1053}
1054
1055/// Lower any types declared in layer blocks.
1056bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1057 lowerBlock(op.getBody());
1058 return false;
1059}
1060
1061/// Lower memory operations. A new memory is created for every leaf
1062/// element in a memory's data type.
1063bool TypeLoweringVisitor::visitDecl(MemOp op) {
1064 // Attempt to get the bundle types.
1065 SmallVector<FlatBundleFieldEntry> fields;
1066
1067 // MemOp should have ground types so we can't preserve aggregates.
1068 if (!peelType(op.getDataType(), fields, memoryPreservationMode))
1069 return false;
1070
1071 if (op.getInnerSym()) {
1072 op->emitError() << "has a symbol, but no symbols may exist on aggregates "
1073 "passed through LowerTypes";
1074 encounteredError = true;
1075 return false;
1076 }
1077
1078 SmallVector<MemOp> newMemories;
1079 SmallVector<WireOp> oldPorts;
1080
1081 // Wires for old ports
1082 for (unsigned int index = 0, end = op.getNumResults(); index < end; ++index) {
1083 auto result = op.getResult(index);
1084 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1085 op.emitOpError("cannot lower memory with debug port");
1086 encounteredError = true;
1087 return false;
1088 }
1089 auto wire =
1090 WireOp::create(*builder, result.getType(),
1091 (op.getName() + "_" + op.getPortName(index)).str());
1092 oldPorts.push_back(wire);
1093 result.replaceAllUsesWith(wire.getResult());
1094 }
1095 // If annotations targeting fields of an aggregate are present, we cannot
1096 // flatten the memory. It must be split into one memory per aggregate field.
1097 // Do not overwrite the pass flag!
1098
1099 // Memory for each field
1100 for (const auto &field : fields) {
1101 auto newMemForField = cloneMemWithNewType(builder, op, field);
1102 if (!newMemForField) {
1103 op.emitError("failed cloning memory for field");
1104 encounteredError = true;
1105 return false;
1106 }
1107 newMemories.push_back(newMemForField);
1108 }
1109 // Hook up the new memories to the wires the old memory was replaced with.
1110 for (size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1111 auto result = oldPorts[index].getResult();
1112 auto rType = type_cast<BundleType>(result.getType());
1113 for (size_t fieldIndex = 0, fend = rType.getNumElements();
1114 fieldIndex != fend; ++fieldIndex) {
1115 auto name = rType.getElement(fieldIndex).name.getValue();
1116 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1117 // data and mask depend on the memory type which was split. They can also
1118 // go both directions, depending on the port direction.
1119 if (name == "data" || name == "mask" || name == "wdata" ||
1120 name == "wmask" || name == "rdata") {
1121 for (const auto &field : fields) {
1122 auto realOldField = getSubWhatever(oldField, field.index);
1123 auto newField = getSubWhatever(
1124 newMemories[field.index].getResult(index), fieldIndex);
1125 if (rType.getElement(fieldIndex).isFlip)
1126 std::swap(realOldField, newField);
1127 emitConnect(*builder, newField, realOldField);
1128 }
1129 } else {
1130 for (auto mem : newMemories) {
1131 auto newField =
1132 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1133 emitConnect(*builder, newField, oldField);
1134 }
1135 }
1136 }
1137 }
1138 return true;
1139}
1140
1141bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1142 ImplicitLocOpBuilder theBuilder(extModule.getLoc(), context);
1143 builder = &theBuilder;
1144
1145 // Top level builder
1146 OpBuilder builder(context);
1147
1148 // Lower the module block arguments.
1149 llvm::BitVector argsToRemove;
1150 auto newArgs = extModule.getPorts();
1151 argsToRemove.reserve(newArgs.size());
1152
1153 DomainLoweringHelper domainHelper(context, extModule.getPortTypes());
1154
1155 size_t argsRemoved = 0;
1156 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1157 SmallVector<Value> lowering;
1158 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1159 argsToRemove.push_back(true);
1160 ++argsRemoved;
1161 } else {
1162 argsToRemove.push_back(false);
1163 }
1164 // lowerArg might have invalidated any reference to newArgs, be careful
1165 }
1166
1167 // Remove block args that have been lowered.
1168 if (argsRemoved != 0)
1169 eraseElementsAtIndices(newArgs, argsToRemove);
1170
1171 domainHelper.computeDomainMap(newArgs);
1172
1173 SmallVector<NamedAttribute, 8> newModuleAttrs;
1174
1175 // Copy over any attributes that weren't original argument attributes.
1176 for (auto attr : extModule->getAttrDictionary())
1177 // Drop old "portNames", directions, and argument attributes. These are
1178 // handled differently below.
1179 if (attr.getName() != "portDirections" && attr.getName() != "portNames" &&
1180 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1181 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1182 newModuleAttrs.push_back(attr);
1183
1184 SmallVector<Direction> newArgDirections;
1185 SmallVector<Attribute> newArgNames;
1186 SmallVector<Attribute, 8> newArgTypes;
1187 SmallVector<Attribute, 8> newArgSyms;
1188 SmallVector<Attribute, 8> newArgLocations;
1189 SmallVector<Attribute, 8> newArgAnnotations;
1190 SmallVector<Attribute, 8> newArgDomains;
1191
1192 for (auto &port : newArgs) {
1193 newArgDirections.push_back(port.direction);
1194 newArgNames.push_back(port.name);
1195 newArgTypes.push_back(TypeAttr::get(port.type));
1196 newArgSyms.push_back(port.sym);
1197 newArgLocations.push_back(port.loc);
1198 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1199 if (port.domains) {
1200 domainHelper.rewriteDomain(port.domains);
1201 } else {
1202 port.domains = cache.aEmpty;
1203 }
1204 newArgDomains.push_back(port.domains);
1205 }
1206
1207 newModuleAttrs.push_back(
1208 NamedAttribute(cache.sPortDirections,
1209 direction::packAttribute(context, newArgDirections)));
1210
1211 newModuleAttrs.push_back(
1212 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1213
1214 newModuleAttrs.push_back(
1215 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1216
1217 newModuleAttrs.push_back(NamedAttribute(
1218 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1219
1220 newModuleAttrs.push_back(NamedAttribute(
1221 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1222
1223 newModuleAttrs.push_back(
1224 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1225
1226 // Update the module's attributes.
1227 extModule->setAttrs(newModuleAttrs);
1228 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1229 extModule.setPortSymbols(newArgSyms);
1230
1231 return false;
1232}
1233
1234bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1235 auto *body = module.getBodyBlock();
1236
1237 ImplicitLocOpBuilder theBuilder(module.getLoc(), context);
1238 builder = &theBuilder;
1239
1240 // Lower the operations.
1241 lowerBlock(body);
1242
1243 // Lower the module block arguments.
1244 llvm::BitVector argsToRemove;
1245 auto newArgs = module.getPorts();
1246 argsToRemove.reserve(newArgs.size());
1247
1248 DomainLoweringHelper domainHelper(context, module.getPortTypes());
1249
1250 size_t argsRemoved = 0;
1251 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1252 SmallVector<Value> lowerings;
1253 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1254 auto arg = module.getArgument(argIndex);
1255 processUsers(arg, lowerings);
1256 argsToRemove.push_back(true);
1257 ++argsRemoved;
1258 } else
1259 argsToRemove.push_back(false);
1260 // lowerArg might have invalidated any reference to newArgs, be careful
1261 }
1262
1263 // Remove block args that have been lowered.
1264 if (argsRemoved != 0) {
1265 body->eraseArguments(argsToRemove);
1266 eraseElementsAtIndices(newArgs, argsToRemove);
1267 }
1268
1269 domainHelper.computeDomainMap(newArgs);
1270
1271 SmallVector<NamedAttribute, 8> newModuleAttrs;
1272
1273 // Copy over any attributes that weren't original argument attributes.
1274 for (auto attr : module->getAttrDictionary())
1275 // Drop old "portNames", directions, and argument attributes. These are
1276 // handled differently below.
1277 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
1278 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1279 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1280 newModuleAttrs.push_back(attr);
1281
1282 SmallVector<Direction> newArgDirections;
1283 SmallVector<Attribute> newArgNames;
1284 SmallVector<Attribute> newArgTypes;
1285 SmallVector<Attribute> newArgSyms;
1286 SmallVector<Attribute> newArgLocations;
1287 SmallVector<Attribute, 8> newArgAnnotations;
1288 SmallVector<Attribute> newPortDomains;
1289 for (auto &port : newArgs) {
1290 newArgDirections.push_back(port.direction);
1291 newArgNames.push_back(port.name);
1292 newArgTypes.push_back(TypeAttr::get(port.type));
1293 newArgSyms.push_back(port.sym);
1294 newArgLocations.push_back(port.loc);
1295 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1296 if (port.domains) {
1297 domainHelper.rewriteDomain(port.domains);
1298 } else {
1299 port.domains = cache.aEmpty;
1300 }
1301 newPortDomains.push_back(port.domains);
1302 }
1303
1304 newModuleAttrs.push_back(
1305 NamedAttribute(cache.sPortDirections,
1306 direction::packAttribute(context, newArgDirections)));
1307
1308 newModuleAttrs.push_back(
1309 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1310
1311 newModuleAttrs.push_back(
1312 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1313
1314 newModuleAttrs.push_back(NamedAttribute(
1315 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1316
1317 newModuleAttrs.push_back(NamedAttribute(
1318 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1319
1320 newModuleAttrs.push_back(NamedAttribute(
1321 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1322
1323 // Update the module's attributes.
1324 module->setAttrs(newModuleAttrs);
1325 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1326 module.setPortSymbols(newArgSyms);
1327 return false;
1328}
1329
1330/// Lower a wire op with a bundle to multiple non-bundled wires.
1331bool TypeLoweringVisitor::visitDecl(WireOp op) {
1332 if (op.isForceable())
1333 return false;
1334
1335 auto clone = [&](const FlatBundleFieldEntry &field,
1336 ArrayAttr attrs) -> Value {
1337 return WireOp::create(*builder,
1338 mapLoweredType(op.getDataRaw().getType(), field.type),
1339 "", NameKindEnum::DroppableName, attrs, StringAttr{})
1340 .getResult();
1341 };
1342 return lowerProducer(op, clone);
1343}
1344
1345/// Lower a reg op with a bundle to multiple non-bundled regs.
1346bool TypeLoweringVisitor::visitDecl(RegOp op) {
1347 if (op.isForceable())
1348 return false;
1349
1350 auto clone = [&](const FlatBundleFieldEntry &field,
1351 ArrayAttr attrs) -> Value {
1352 return RegOp::create(*builder, field.type, op.getClockVal(), "",
1353 NameKindEnum::DroppableName, attrs, StringAttr{})
1354 .getResult();
1355 };
1356 return lowerProducer(op, clone);
1357}
1358
1359/// Lower a reg op with a bundle to multiple non-bundled regs.
1360bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1361 if (op.isForceable())
1362 return false;
1363
1364 auto clone = [&](const FlatBundleFieldEntry &field,
1365 ArrayAttr attrs) -> Value {
1366 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1367 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1368 op.getResetSignal(), resetVal, "",
1369 NameKindEnum::DroppableName, attrs, StringAttr{})
1370 .getResult();
1371 };
1372 return lowerProducer(op, clone);
1373}
1374
1375/// Lower a wire op with a bundle to multiple non-bundled wires.
1376bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1377 if (op.isForceable())
1378 return false;
1379
1380 auto clone = [&](const FlatBundleFieldEntry &field,
1381 ArrayAttr attrs) -> Value {
1382 auto input = getSubWhatever(op.getInput(), field.index);
1383 return NodeOp::create(*builder, input, "", NameKindEnum::DroppableName,
1384 attrs)
1385 .getResult();
1386 };
1387 return lowerProducer(op, clone);
1388}
1389
1390/// Lower an InvalidValue op with a bundle to multiple non-bundled InvalidOps.
1391bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1392 auto clone = [&](const FlatBundleFieldEntry &field,
1393 ArrayAttr attrs) -> Value {
1394 return InvalidValueOp::create(*builder, field.type);
1395 };
1396 return lowerProducer(op, clone);
1397}
1398
1399// Expand muxes of aggregates
1400bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1401 auto clone = [&](const FlatBundleFieldEntry &field,
1402 ArrayAttr attrs) -> Value {
1403 auto high = getSubWhatever(op.getHigh(), field.index);
1404 auto low = getSubWhatever(op.getLow(), field.index);
1405 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1406 };
1407 return lowerProducer(op, clone);
1408}
1409
1410// Expand muxes of aggregates
1411bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1412 auto clone = [&](const FlatBundleFieldEntry &field,
1413 ArrayAttr attrs) -> Value {
1414 auto high = getSubWhatever(op.getHigh(), field.index);
1415 auto low = getSubWhatever(op.getLow(), field.index);
1416 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1417 };
1418 return lowerProducer(op, clone);
1419}
1420
1421// Expand muxes of aggregates
1422bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1423 auto clone = [&](const FlatBundleFieldEntry &field,
1424 ArrayAttr attrs) -> Value {
1425 auto v3 = getSubWhatever(op.getV3(), field.index);
1426 auto v2 = getSubWhatever(op.getV2(), field.index);
1427 auto v1 = getSubWhatever(op.getV1(), field.index);
1428 auto v0 = getSubWhatever(op.getV0(), field.index);
1429 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1430 };
1431 return lowerProducer(op, clone);
1432}
1433
1434// Expand UnrealizedConversionCastOp of aggregates
1435bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1436 mlir::UnrealizedConversionCastOp op) {
1437 auto clone = [&](const FlatBundleFieldEntry &field,
1438 ArrayAttr attrs) -> Value {
1439 auto input = getSubWhatever(op.getOperand(0), field.index);
1440 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1441 .getResult(0);
1442 };
1443 // If the input to the cast is not a FIRRTL type, getSubWhatever cannot handle
1444 // it, donot lower the op.
1445 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1446 return false;
1447 return lowerProducer(op, clone);
1448}
1449
1450// Expand BitCastOp of aggregates
1451bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1452 Value srcLoweredVal = op.getInput();
1453 // If the input is of aggregate type, then cat all the leaf fields to form a
1454 // UInt type result. That is, first bitcast the aggregate type to a UInt.
1455 // Attempt to get the bundle types.
1456 SmallVector<FlatBundleFieldEntry> fields;
1457 if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) {
1458 size_t uptoBits = 0;
1459 // Loop over the leaf aggregates and concat each of them to get a UInt.
1460 // Bitcast the fields to handle nested aggregate types.
1461 for (const auto &field : llvm::enumerate(fields)) {
1462 auto fieldBitwidth = *getBitWidth(field.value().type);
1463 // Ignore zero width fields, like empty bundles.
1464 if (fieldBitwidth == 0)
1465 continue;
1466 Value src = getSubWhatever(op.getInput(), field.index());
1467 // The src could be an aggregate type, bitcast it to a UInt type.
1468 src = builder->createOrFold<BitCastOp>(
1469 UIntType::get(context, fieldBitwidth), src);
1470 // Take the first field, or else Cat the previous fields with this field.
1471 if (uptoBits == 0)
1472 srcLoweredVal = src;
1473 else {
1474 if (type_isa<BundleType>(op.getInput().getType())) {
1475 srcLoweredVal =
1476 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1477 } else {
1478 srcLoweredVal =
1479 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1480 }
1481 }
1482 // Record the total bits already accumulated.
1483 uptoBits += fieldBitwidth;
1484 }
1485 } else {
1486 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
1487 }
1488 // Now the input has been cast to srcLoweredVal, which is of UInt type.
1489 // If the result is an aggregate type, then use lowerProducer.
1490 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
1491 // uptoBits is used to keep track of the bits that have been extracted.
1492 size_t uptoBits = 0;
1493 auto aggregateBits = *getBitWidth(op.getResult().getType());
1494 auto clone = [&](const FlatBundleFieldEntry &field,
1495 ArrayAttr attrs) -> Value {
1496 // All the fields must have valid bitwidth, a requirement for BitCastOp.
1497 auto fieldBits = *getBitWidth(field.type);
1498 // If empty field, then it doesnot have any use, so replace it with an
1499 // invalid op, which should be trivially removed.
1500 if (fieldBits == 0)
1501 return InvalidValueOp::create(*builder, field.type);
1502
1503 // Assign the field to the corresponding bits from the input.
1504 // Bitcast the field, incase its an aggregate type.
1505 BitsPrimOp extractBits;
1506 if (type_isa<BundleType>(op.getResult().getType())) {
1507 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1508 aggregateBits - uptoBits - 1,
1509 aggregateBits - uptoBits - fieldBits);
1510 } else {
1511 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1512 uptoBits + fieldBits - 1, uptoBits);
1513 }
1514 uptoBits += fieldBits;
1515 return BitCastOp::create(*builder, field.type, extractBits);
1516 };
1517 return lowerProducer(op, clone);
1518 }
1519
1520 // If ground type, then replace the result.
1521 if (type_isa<SIntType>(op.getType()))
1522 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1523 op.getResult().replaceAllUsesWith(srcLoweredVal);
1524 return true;
1525}
1526
1527bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1528 auto clone = [&](const FlatBundleFieldEntry &field,
1529 ArrayAttr attrs) -> Value {
1530 return RefSendOp::create(*builder,
1531 getSubWhatever(op.getBase(), field.index));
1532 };
1533 // Be careful re:what gets lowered, consider ref.send of non-passive
1534 // and whether we're using the ref or the base type to choose
1535 // whether this should be lowered.
1536 return lowerProducer(op, clone);
1537}
1538
1539bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1540 auto clone = [&](const FlatBundleFieldEntry &field,
1541 ArrayAttr attrs) -> Value {
1542 Value src = getSubWhatever(op.getRef(), field.index);
1543 return RefResolveOp::create(*builder, src);
1544 };
1545 // Lower according to lowering of the reference.
1546 // Particularly, preserve if rwprobe.
1547 return lowerProducer(op, clone, op.getRef().getType());
1548}
1549
1550bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1551 auto clone = [&](const FlatBundleFieldEntry &field,
1552 ArrayAttr attrs) -> Value {
1553 auto input = getSubWhatever(op.getInput(), field.index);
1554 return RefCastOp::create(*builder,
1555 RefType::get(field.type,
1556 op.getType().getForceable(),
1557 op.getType().getLayer()),
1558 input);
1559 };
1560 return lowerProducer(op, clone);
1561}
1562
1563bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1564 bool skip = true;
1565 SmallVector<Type, 8> resultTypes;
1566 SmallVector<int64_t, 8> endFields; // Compressed sparse row encoding
1567 auto oldPortAnno = op.getPortAnnotations();
1568 SmallVector<Direction> newDirs;
1569 SmallVector<Attribute> newNames;
1570 SmallVector<Attribute> newDomains;
1571 SmallVector<Attribute> newPortAnno;
1572 PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
1573 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1574
1575 // Create domain helper to track domain port indices.
1576 DomainLoweringHelper domainHelper(context, op.getResultTypes());
1577
1578 endFields.push_back(0);
1579 for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
1580 auto srcType = type_cast<FIRRTLType>(op.getType(i));
1581
1582 // Flatten any nested bundle types the usual way.
1583 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1584 if (!peelType(srcType, fieldTypes, mode)) {
1585 newDirs.push_back(op.getPortDirection(i));
1586 newNames.push_back(op.getPortNameAttr(i));
1587 newDomains.push_back(op.getPortDomain(i));
1588 resultTypes.push_back(srcType);
1589 newPortAnno.push_back(oldPortAnno[i]);
1590 } else {
1591 skip = false;
1592 auto oldName = op.getPortName(i);
1593 auto oldDir = op.getPortDirection(i);
1594 // Store the flat type for the new bundle type.
1595 for (const auto &field : fieldTypes) {
1596 newDirs.push_back(direction::get((unsigned)oldDir ^ field.isOutput));
1597 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1598 newDomains.push_back(op.getPortDomain(i));
1599 resultTypes.push_back(mapLoweredType(srcType, field.type));
1600 auto annos = filterAnnotations(
1601 context, dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]), srcType,
1602 field);
1603 newPortAnno.push_back(annos);
1604 }
1605 }
1606 endFields.push_back(resultTypes.size());
1607 }
1608
1609 auto sym = getInnerSymName(op);
1610
1611 if (skip) {
1612 return false;
1613 }
1614
1615 // Compute the mapping from old domain indices to new domain indices.
1616 domainHelper.computeDomainMap(resultTypes);
1617
1618 // Rewrite domain associations to use the new port numbers.
1619 for (auto &domain : newDomains)
1620 domainHelper.rewriteDomain(domain);
1621
1622 // FIXME: annotation update
1623 auto newInstance = InstanceOp::create(
1624 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1625 op.getNameKindAttr(), direction::packAttribute(context, newDirs),
1626 builder->getArrayAttr(newNames), builder->getArrayAttr(newDomains),
1627 op.getAnnotations(), builder->getArrayAttr(newPortAnno),
1628 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1629 sym ? hw::InnerSymAttr::get(sym) : hw::InnerSymAttr());
1630
1631 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1632
1633 SmallVector<Value> lowered;
1634 for (size_t aggIndex = 0, eAgg = op.getNumResults(); aggIndex != eAgg;
1635 ++aggIndex) {
1636 lowered.clear();
1637 for (size_t fieldIndex = endFields[aggIndex],
1638 eField = endFields[aggIndex + 1];
1639 fieldIndex < eField; ++fieldIndex)
1640 lowered.push_back(newInstance.getResult(fieldIndex));
1641 if (lowered.size() != 1 ||
1642 op.getType(aggIndex) != resultTypes[endFields[aggIndex]])
1643 processUsers(op.getResult(aggIndex), lowered);
1644 else
1645 op.getResult(aggIndex).replaceAllUsesWith(lowered[0]);
1646 }
1647 return true;
1648}
1649
1650bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1651 auto input = op.getInput();
1652 FVectorType vType = input.getType();
1653
1654 // Check for empty vectors
1655 if (vType.getNumElements() == 0) {
1656 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1657 op.replaceAllUsesWith(inv);
1658 return true;
1659 }
1660
1661 // Check for constant instances
1662 if (ConstantOp arg =
1663 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1664 auto sio = SubindexOp::create(*builder, op.getInput(),
1665 arg.getValue().getExtValue());
1666 op.replaceAllUsesWith(sio.getResult());
1667 return true;
1668 }
1669
1670 // Construct a multibit mux
1671 SmallVector<Value> inputs;
1672 inputs.reserve(vType.getNumElements());
1673 for (int index = vType.getNumElements() - 1; index >= 0; index--)
1674 inputs.push_back(SubindexOp::create(*builder, input, index));
1675
1676 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1677 op.replaceAllUsesWith(multibitMux);
1678 return true;
1679}
1680
1681bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1682 auto clone = [&](const FlatBundleFieldEntry &field,
1683 ArrayAttr attrs) -> Value {
1684 return op.getOperand(field.index);
1685 };
1686 return lowerProducer(op, clone);
1687}
1688
1689bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1690 auto clone = [&](const FlatBundleFieldEntry &field,
1691 ArrayAttr attrs) -> Value {
1692 return op.getOperand(field.index);
1693 };
1694 return lowerProducer(op, clone);
1695}
1696
1697bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1698 auto clone = [&](const FlatBundleFieldEntry &field,
1699 ArrayAttr attrs) -> Value {
1700 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1701 getSubWhatever(op.getRhs(), field.index)};
1702 return type_isa<BundleType, FVectorType>(field.type)
1703 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1704 operands)
1705 : (Value)OrPrimOp::create(*builder, operands);
1706 };
1707
1708 return lowerProducer(op, clone);
1709}
1710
1711bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1712 auto clone = [&](const FlatBundleFieldEntry &field,
1713 ArrayAttr attrs) -> Value {
1714 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1715 getSubWhatever(op.getRhs(), field.index)};
1716 return type_isa<BundleType, FVectorType>(field.type)
1717 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1718 operands)
1719 : (Value)AndPrimOp::create(*builder, operands);
1720 };
1721
1722 return lowerProducer(op, clone);
1723}
1724
1725bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1726 auto clone = [&](const FlatBundleFieldEntry &field,
1727 ArrayAttr attrs) -> Value {
1728 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1729 getSubWhatever(op.getRhs(), field.index)};
1730 return type_isa<BundleType, FVectorType>(field.type)
1731 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1732 operands)
1733 : (Value)XorPrimOp::create(*builder, operands);
1734 };
1735
1736 return lowerProducer(op, clone);
1737}
1738
1739bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1740 auto clone = [&](const FlatBundleFieldEntry &field,
1741 ArrayAttr attrs) -> Value {
1742 SmallVector<Value> newInputs;
1743 newInputs.reserve(op.getInputs().size());
1744 for (auto input : op.getInputs()) {
1745 auto inputSub = getSubWhatever(input, field.index);
1746 newInputs.push_back(inputSub);
1747 }
1748 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1749 };
1750 return lowerProducer(op, clone);
1751}
1752
1753//===----------------------------------------------------------------------===//
1754// Pass Infrastructure
1755//===----------------------------------------------------------------------===//
1756
1757namespace {
1758struct LowerTypesPass
1759 : public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1760 using Base::Base;
1761
1762 void runOnOperation() override;
1763};
1764} // end anonymous namespace
1765
1766// This is the main entrypoint for the lowering pass.
1767void LowerTypesPass::runOnOperation() {
1769
1770 std::vector<FModuleLike> ops;
1771 // Symbol Table
1772 auto &symTbl = getAnalysis<SymbolTable>();
1773 // Cached attr
1774 AttrCache cache(&getContext());
1775
1776 DenseMap<FModuleLike, Convention> conventionTable;
1777 auto circuit = getOperation();
1778 for (auto module : circuit.getOps<FModuleLike>()) {
1779 conventionTable.insert({module, module.getConvention()});
1780 ops.push_back(module);
1781 }
1782
1783 // This lambda, executes in parallel for each Op within the circt.
1784 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1785 // Use body type lowering attribute if it exists, otherwise use internal.
1786 Convention convention = Convention::Internal;
1787 if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1788 op->getDiscardableAttr("body_type_lowering")))
1789 convention = conventionAttr.getValue();
1790
1791 auto tl =
1792 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1793 preserveMemories, symTbl, cache, conventionTable);
1794 tl.lowerModule(op);
1795
1796 return LogicalResult::failure(tl.isFailed());
1797 };
1798
1799 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
1800
1801 if (failed(result))
1802 signalPassFailure();
1803}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
@ All
Preserve all aggregate values.
Definition Passes.h:40
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
Definition Passes.h:34
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
Definition Passes.h:37
@ None
Don't preserve aggregate at all.
Definition Passes.h:31
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Definition FIRRTLEnums.h:36
Direction
This represents the direction of a single port.
Definition FIRRTLEnums.h:27
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
bool type_isa(Type type)
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition FIRRTLOps.h:108
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
This holds the name and type that describes the module's ports.