CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerTypes.cpp
Go to the documentation of this file.
1//===- LowerTypes.cpp - Lower Aggregate Types -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerTypes pass. This pass replaces aggregate types
10// with expanded values.
11//
12// This pass walks the operations in reverse order. This lets it visit users
13// before defs. Users can usually be expanded out to multiple operations (think
14// mux of a bundle to muxes of each field) with a temporary subWhatever op
15// inserted. When processing an aggregate producer, we blow out the op as
16// appropriate, then walk the users, often those are subWhatever ops which can
17// be bypassed and deleted. Function arguments are logically last on the
18// operation visit order and walked left to right, being peeled one layer at a
19// time with replacements inserted to the right of the original argument.
20//
21// Each processing of an op peels one layer of aggregate type off. Because new
// ops are inserted immediately above the current op, the walk will visit them
// next, effectively recursing on the aggregate types without actual recursion.
// These potentially temporary ops (if the aggregate is complex) serve as
25// the worklist. Often aggregates are shallow, so the new ops are the final
26// ones.
27//
28//===----------------------------------------------------------------------===//
29
38#include "circt/Support/Debug.h"
39#include "mlir/IR/ImplicitLocOpBuilder.h"
40#include "mlir/IR/Threading.h"
41#include "mlir/Pass/Pass.h"
42#include "llvm/ADT/APSInt.h"
43#include "llvm/ADT/BitVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/Support/Debug.h"
46
47#define DEBUG_TYPE "firrtl-lower-types"
48
49namespace circt {
50namespace firrtl {
51#define GEN_PASS_DEF_LOWERFIRRTLTYPES
52#include "circt/Dialect/FIRRTL/Passes.h.inc"
53} // namespace firrtl
54} // namespace circt
55
56using namespace circt;
57using namespace firrtl;
58
59// TODO: check all argument types
60namespace {
61/// This represents a flattened bundle field element.
62struct FlatBundleFieldEntry {
63 /// This is the underlying ground type of the field.
64 FIRRTLBaseType type;
65 /// The index in the parent type
66 size_t index;
67 /// The fieldID
68 unsigned fieldID;
69 /// This is a suffix to add to the field name to make it unique.
70 SmallString<16> suffix;
71 /// This indicates whether the field was flipped to be an output.
72 bool isOutput;
73
74 FlatBundleFieldEntry(const FIRRTLBaseType &type, size_t index,
75 unsigned fieldID, StringRef suffix, bool isOutput)
76 : type(type), index(index), fieldID(fieldID), suffix(suffix),
77 isOutput(isOutput) {}
78
79 void dump() const {
80 llvm::errs() << "FBFE{" << type << " index<" << index << "> fieldID<"
81 << fieldID << "> suffix<" << suffix << "> isOutput<"
82 << isOutput << ">}\n";
83 }
84};
85} // end anonymous namespace
86
/// Return fieldType, or fieldType wrapped in the same kind of reference as type when type is a reference.
89 return mapBaseType(type, [&](auto) { return fieldType; });
90}
91
92/// Return fieldType or fieldType as same ref as type.
93static Type mapLoweredType(Type type, FIRRTLBaseType fieldType) {
94 auto ftype = type_dyn_cast<FIRRTLType>(type);
95 if (!ftype)
96 return type;
97 return mapLoweredType(ftype, fieldType);
98}
99
100/// Return true if the type is a 1d vector type or ground type.
103 .Case<BundleType>([&](auto bundle) { return false; })
104 .Case<FVectorType>([&](FVectorType vector) {
105 // When the size is 1, lower the vector into a scalar.
106 return vector.getElementType().isGround() &&
107 vector.getNumElements() > 1;
108 })
109 .Default([](auto groundType) { return true; });
110}
111
112// NOLINTBEGIN(misc-no-recursion)
113/// Return true if the type has a bundle type as subtype.
116 .Case<BundleType>([&](auto bundle) { return true; })
117 .Case<FVectorType>([&](FVectorType vector) {
118 return containsBundleType(vector.getElementType());
119 })
120 .Default([](auto groundType) { return false; });
121}
122// NOLINTEND(misc-no-recursion)
123
124/// Return true if we can preserve the type.
125static bool isPreservableAggregateType(Type type,
127 if (auto refType = type_dyn_cast<RefType>(type)) {
128 // Always preserve rwprobe's.
129 if (refType.getForceable())
130 return true;
131 // FIXME: Don't preserve read-only RefType for now. This is workaround for
132 // MemTap which causes type mismatches (issue 4479).
133 return false;
134 }
135
136 // Return false if no aggregate value is preserved.
137 if (mode == PreserveAggregate::None)
138 return false;
139
140 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
141 if (!firrtlType)
142 return false;
143
144 // We can a preserve the type iff (i) the type is not passive, (ii) the type
145 // doesn't contain analog and (iii) type don't contain zero bitwidth.
146 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
147 hasZeroBitWidth(firrtlType))
148 return false;
149
150 switch (mode) {
152 return true;
154 return isOneDimVectorType(firrtlType);
156 return !containsBundleType(firrtlType);
157 default:
158 llvm_unreachable("unexpected mode");
159 }
160}
161
162/// Peel one layer of an aggregate type into its components. Type may be
163/// complex, but empty, in which case fields is empty, but the return is true.
164static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
166 // If the aggregate preservation is enabled and the type is preservable,
167 // then just return.
168 if (isPreservableAggregateType(type, mode))
169 return false;
170
171 if (auto refType = type_dyn_cast<RefType>(type))
172 type = refType.getType();
174 .Case<BundleType>([&](auto bundle) {
175 SmallString<16> tmpSuffix;
176 // Otherwise, we have a bundle type. Break it down.
177 for (size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
178 auto elt = bundle.getElement(i);
179 // Construct the suffix to pass down.
180 tmpSuffix.resize(0);
181 tmpSuffix.push_back('_');
182 tmpSuffix.append(elt.name.getValue());
183 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
184 elt.isFlip);
185 }
186 return true;
187 })
188 .Case<FVectorType>([&](auto vector) {
189 // Increment the field ID to point to the first element.
190 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
191 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
192 "_" + std::to_string(i), false);
193 }
194 return true;
195 })
196 .Default([](auto op) { return false; });
197}
198
199/// Return if something is not a normal subaccess. Non-normal includes
200/// zero-length vectors and constant indexes (which are really subindexes).
201static bool isNotSubAccess(Operation *op) {
202 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
203 if (!sao)
204 return true;
205 ConstantOp arg =
206 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
207 return arg && sao.getInput().getType().base().getNumElements() != 0;
208}
209
210/// Look through and collect subfields leading to a subaccess.
211static SmallVector<Operation *> getSAWritePath(Operation *op) {
212 SmallVector<Operation *> retval;
213 auto defOp = op->getOperand(0).getDefiningOp();
214 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
215 retval.push_back(defOp);
216 defOp = defOp->getOperand(0).getDefiningOp();
217 }
218 // Trim to the subaccess
219 while (!retval.empty() && isNotSubAccess(retval.back()))
220 retval.pop_back();
221 return retval;
222}
223
224/// Clone memory for the specified field. Returns null op on error.
225static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op,
226 FlatBundleFieldEntry field) {
227 SmallVector<Type, 8> ports;
228 SmallVector<Attribute, 8> portNames;
229 SmallVector<Attribute, 8> portLocations;
230
231 auto oldPorts = op.getPorts();
232 for (size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
233 auto port = oldPorts[portIdx];
234 ports.push_back(
235 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
236 portNames.push_back(port.first);
237 }
238
239 // It's easier to duplicate the old annotations, then fix and filter them.
240 auto newMem =
241 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
242 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
243 (op.getName() + field.suffix).str(), op.getNameKind(),
244 op.getAnnotations(), op.getPortAnnotations(),
245 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
246
247 if (op.getInnerSym()) {
248 op.emitError("cannot split memory with symbol present");
249 return {};
250 }
251
252 SmallVector<Attribute> newAnnotations;
253 for (size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
254 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
255 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
256 SmallVector<Attribute> portAnno;
257 for (auto attr : newMem.getPortAnnotation(portIdx)) {
258 Annotation anno(attr);
259 if (auto annoFieldID = anno.getFieldID()) {
260 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
261
262 // Apply annotations to all elements if the target is the whole
263 // sub-field.
264 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
265 anno.setMember(
266 "circt.fieldID",
267 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
268 portAnno.push_back(anno.getDict());
269 continue;
270 }
271
272 // Handle aggregate sub-fields, including `(r/w)data` and `(w)mask`.
273 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
274 // Check whether the annotation falls into the range of the current
275 // field. Note that the `field` here is peeled from the `data`
276 // sub-field of the memory port, thus we need to add the fieldID of
277 // `data` or `mask` sub-field to get the "real" fieldID.
278 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
279 if (annoFieldID >= fieldID &&
280 annoFieldID <=
281 fieldID + hw::FieldIdImpl::getMaxFieldID(field.type)) {
282 // Set the field ID of the new annotation.
283 auto newFieldID =
284 annoFieldID - fieldID + portType.getFieldID(targetIndex);
285 anno.setMember("circt.fieldID", b->getI32IntegerAttr(newFieldID));
286 portAnno.push_back(anno.getDict());
287 }
288 }
289 } else
290 portAnno.push_back(attr);
291 }
292 newAnnotations.push_back(b->getArrayAttr(portAnno));
293 }
294 newMem.setAllPortAnnotations(newAnnotations);
295 return newMem;
296}
297
298//===----------------------------------------------------------------------===//
299// Module Type Lowering
300//===----------------------------------------------------------------------===//
301namespace {
302
303struct AttrCache {
304 AttrCache(MLIRContext *context) {
305 i64ty = IntegerType::get(context, 64);
306 nameAttr = StringAttr::get(context, "name");
307 nameKindAttr = StringAttr::get(context, "nameKind");
308 sPortDirections = StringAttr::get(context, "portDirections");
309 sPortNames = StringAttr::get(context, "portNames");
310 sPortTypes = StringAttr::get(context, "portTypes");
311 sPortSymbols = StringAttr::get(context, "portSymbols");
312 sPortLocations = StringAttr::get(context, "portLocations");
313 sPortAnnotations = StringAttr::get(context, "portAnnotations");
314 sPortDomains = StringAttr::get(context, "domainInfo");
315 sEmpty = StringAttr::get(context, "");
316 aEmpty = ArrayAttr::get(context, {});
317 }
318 AttrCache(const AttrCache &) = default;
319
320 Type i64ty;
321 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
322 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
323 ArrayAttr aEmpty;
324};
325
326/// Helper class to handle domain lowering consistently across modules,
327/// extmodules, and instances. This class tracks domain port indices and
328/// provides methods to rewrite domain associations after port lowering.
329class DomainLoweringHelper {
330public:
331 /// Construct a helper by scanning the original port types for domain types.
332 /// For modules/extmodules, pass the port types attribute array.
333 /// For instances, pass the result types directly.
334 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
335 : context(context) {
336 for (auto [index, typeAttr] : llvm::enumerate(portTypes))
337 if (type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
338 domainIndexByOrdinal.push_back(index);
339 }
340
341 /// Construct a helper by scanning instance result types for domain types.
342 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
343 : context(context) {
344 for (auto [index, type] : llvm::enumerate(resultTypes))
345 if (type_isa<DomainType>(type))
346 domainIndexByOrdinal.push_back(index);
347 }
348
349 /// Compute the mapping from old domain port indices to new port indices after
350 /// type lowering. Call this after ports have been lowered but before
351 /// rewriting domain associations. This overload takes a range of types
352 /// directly (e.g., from instances).
353 void computeDomainMap(TypeRange types) {
354 size_t i = 0, ord = 0;
355 for (auto type : types) {
356 if (type_isa<DomainType>(type))
357 domainMap[domainIndexByOrdinal[ord++]] = i;
358 ++i;
359 }
360 }
361
362 /// Compute the mapping from old domain port indices to new port indices after
363 /// type lowering. Call this after ports have been lowered but before
364 /// rewriting domain associations. This overload takes a range of PortInfo
365 /// and extracts types from them (e.g., from modules/extmodules).
366 void computeDomainMap(ArrayRef<PortInfo> ports) {
367 size_t i = 0, ord = 0;
368 for (const auto &port : ports) {
369 if (type_isa<DomainType>(port.type))
370 domainMap[domainIndexByOrdinal[ord++]] = i;
371 ++i;
372 }
373 }
374
375 /// Rewrite a domain attribute to use new port indices. The domain attribute
376 /// contains an array of port indices that need to be updated to reflect the
377 /// new port numbering after type lowering.
378 void rewriteDomain(Attribute &domain) {
379 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
380 if (!oldAssociations)
381 return;
382 SmallVector<Attribute> newAssociations;
383 for (auto oldAttr : oldAssociations)
384 newAssociations.push_back(IntegerAttr::get(
385 IntegerType::get(context, 32, IntegerType::Unsigned),
386 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
387 domain = ArrayAttr::get(context, newAssociations);
388 }
389
390private:
391 MLIRContext *context;
392 /// Maps ordinal position of domain ports to their original indices.
393 SmallVector<unsigned> domainIndexByOrdinal;
394 /// Maps old port indices to new port indices after lowering.
395 DenseMap<unsigned, unsigned> domainMap;
396};
397
398// The visitors all return true if the operation should be deleted, false if
399// not.
400struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
401
402 TypeLoweringVisitor(
403 MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
404 Convention bodyConvention,
405 PreserveAggregate::PreserveMode memoryPreservationMode,
406 SymbolTable &symTbl, const AttrCache &cache,
407 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
408 : context(context), defaultAggregatePreservationMode(preserveAggregate),
409 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
410 cache(cache), conventionTable(conventionTable) {
411 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
413 : defaultAggregatePreservationMode;
414 }
415 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
416 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
417 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
418
419 /// If the referenced operation is a FModuleOp or an FExtModuleOp, perform
420 /// type lowering on all operations.
421 void lowerModule(FModuleLike op);
422
423 bool lowerArg(FModuleLike module, size_t argIndex, size_t argsRemoved,
424 SmallVectorImpl<PortInfo> &newArgs,
425 SmallVectorImpl<Value> &lowering);
426 std::pair<Value, PortInfo> addArg(Operation *module, unsigned insertPt,
427 unsigned insertPtOffset, FIRRTLType srcType,
428 const FlatBundleFieldEntry &field,
429 PortInfo &oldArg, hw::InnerSymAttr newSym);
430
431 // Helpers to manage state.
432 bool visitDecl(FExtModuleOp op);
433 bool visitDecl(FModuleOp op);
434 bool visitDecl(InstanceOp op);
435 bool visitDecl(InstanceChoiceOp op);
436 bool visitDecl(MemOp op);
437 bool visitDecl(NodeOp op);
438 bool visitDecl(RegOp op);
439 bool visitDecl(WireOp op);
440 bool visitDecl(RegResetOp op);
441 bool visitExpr(InvalidValueOp op);
442 bool visitExpr(SubaccessOp op);
443 bool visitExpr(VectorCreateOp op);
444 bool visitExpr(BundleCreateOp op);
445 bool visitExpr(ElementwiseAndPrimOp op);
446 bool visitExpr(ElementwiseOrPrimOp op);
447 bool visitExpr(ElementwiseXorPrimOp op);
448 bool visitExpr(MultibitMuxOp op);
449 bool visitExpr(MuxPrimOp op);
450 bool visitExpr(Mux2CellIntrinsicOp op);
451 bool visitExpr(Mux4CellIntrinsicOp op);
452 bool visitExpr(BitCastOp op);
453 bool visitExpr(RefSendOp op);
454 bool visitExpr(RefResolveOp op);
455 bool visitExpr(RefCastOp op);
456 bool visitStmt(ConnectOp op);
457 bool visitStmt(MatchingConnectOp op);
458 bool visitStmt(RefDefineOp op);
459 bool visitStmt(WhenOp op);
460 bool visitStmt(LayerBlockOp op);
461 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
462
463 bool isFailed() const { return encounteredError; }
464
465 bool visitInvalidOp(Operation *op) {
466 if (auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
467 return visitUnrealizedConversionCast(castOp);
468 return false;
469 }
470
471private:
472 void processUsers(Value val, ArrayRef<Value> mapping);
473 bool processSAPath(Operation *);
474 void lowerBlock(Block *);
475 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
476
477 /// Lower a "producer" operation one layer based on policy.
478 /// Use the provided \p clone function to generate individual ops for
479 /// the expanded subelements/fields. The type used to determine if lowering
480 /// is needed is either \p srcType if provided or from the assumed-to-exist
481 /// first result of the operation. When lowering, the clone callback will be
482 /// invoked with each subelement/field of this type.
483 bool lowerProducer(
484 Operation *op,
485 llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
486 Type srcType = {});
487
488 /// Filter out and return \p annotations that target includes \field,
489 /// modifying as needed to adjust fieldID's relative to to \field.
490 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
491 FIRRTLType srcType, FlatBundleFieldEntry field);
492
493 /// Partition inner symbols on given type. Fails if any symbols
494 /// cannot be assigned to a field, such as inner symbol on root.
495 LogicalResult partitionSymbols(hw::InnerSymAttr sym, FIRRTLType parentType,
496 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
497 Location errorLoc);
498
500 getPreservationModeForPorts(FModuleLike moduleLike);
501 Value getSubWhatever(Value val, size_t index);
502
503 /// Helper function to lower instance-like operations (InstanceOp and
504 /// InstanceChoiceOp).
505 bool lowerInstanceLike(FInstanceLike op, PreserveAggregate::PreserveMode mode,
506 ArrayAttr oldPortAnno,
507 llvm::function_ref<Operation *(
508 ArrayRef<Type>, ArrayRef<Direction>, ArrayAttr,
509 ArrayAttr, ArrayAttr, hw::InnerSymAttr)>
510 createNewInstance);
511
512 size_t uniqueIdx = 0;
513 std::string uniqueName() {
514 auto myID = uniqueIdx++;
515 return (Twine("__GEN_") + Twine(myID)).str();
516 }
517
518 MLIRContext *context;
519
520 /// Aggregate preservation mode.
521 PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
522 PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
523 PreserveAggregate::PreserveMode memoryPreservationMode;
524
525 /// The builder is set and maintained in the main loop.
526 ImplicitLocOpBuilder *builder;
527
528 // Keep a symbol table around for resolving symbols
529 SymbolTable &symTbl;
530
531 // Cache some attributes
532 const AttrCache &cache;
533
534 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
535
536 // Set true if the lowering failed.
537 bool encounteredError = false;
538};
539} // namespace
540
541/// Return aggregate preservation mode for the module ports. If the module has a
542/// scalarized linkage, then we may not preserve it's aggregate ports.
544TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
545 auto lookup = conventionTable.find(module);
546 if (lookup == conventionTable.end())
547 return defaultAggregatePreservationMode;
548 switch (lookup->second) {
549 case Convention::Scalarized:
551 case Convention::Internal:
552 return defaultAggregatePreservationMode;
553 }
554 llvm_unreachable("Unknown convention");
555 return defaultAggregatePreservationMode;
556}
557
558Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
559 if (type_isa<BundleType>(val.getType()))
560 return SubfieldOp::create(*builder, val, index);
561 if (type_isa<FVectorType>(val.getType()))
562 return SubindexOp::create(*builder, val, index);
563 if (type_isa<RefType>(val.getType()))
564 return RefSubOp::create(*builder, val, index);
565 llvm_unreachable("Unknown aggregate type");
566 return nullptr;
567}
568
569/// Conditionally expand a subaccessop write path
570bool TypeLoweringVisitor::processSAPath(Operation *op) {
571 // Does this LHS have a subaccessop?
572 SmallVector<Operation *> writePath = getSAWritePath(op);
573 if (writePath.empty())
574 return false;
575
576 lowerSAWritePath(op, writePath);
577 // Unhook the writePath from the connect. This isn't the right type, but we
578 // are deleting the op anyway.
579 op->eraseOperands(0, 2);
580 // See how far up the tree we can delete things.
581 for (size_t i = 0; i < writePath.size(); ++i) {
582 if (writePath[i]->use_empty()) {
583 writePath[i]->erase();
584 } else {
585 break;
586 }
587 }
588 return true;
589}
590
591void TypeLoweringVisitor::lowerBlock(Block *block) {
592 // Lower the operations bottom up.
593 for (auto it = block->rbegin(), e = block->rend(); it != e;) {
594 auto &iop = *it;
595 builder->setInsertionPoint(&iop);
596 builder->setLoc(iop.getLoc());
597 bool removeOp = dispatchVisitor(&iop);
598 ++it;
599 // Erase old ops eagerly so we don't have dangling uses we've already
600 // lowered.
601 if (removeOp)
602 iop.erase();
603 }
604}
605
606ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
607 ArrayAttr annotations,
608 FIRRTLType srcType,
609 FlatBundleFieldEntry field) {
610 SmallVector<Attribute> retval;
611 if (!annotations || annotations.empty())
612 return ArrayAttr::get(ctxt, retval);
613 for (auto opAttr : annotations) {
614 Annotation anno(opAttr);
615 auto fieldID = anno.getFieldID();
616 anno.removeMember("circt.fieldID");
617
618 // If no fieldID set, or points to root, forward the annotation without the
619 // fieldID field (which was removed above).
620 if (fieldID == 0) {
621 retval.push_back(anno.getAttr());
622 continue;
623 }
624 // Check whether the annotation falls into the range of the current field.
625
626 if (fieldID < field.fieldID ||
627 fieldID > field.fieldID + hw::FieldIdImpl::getMaxFieldID(field.type))
628 continue;
629
630 // Add fieldID back if non-zero relative to this field.
631 if (auto newFieldID = fieldID - field.fieldID) {
632 // If the target is a subfield/subindex of the current field, create a
633 // new annotation with the correct circt.fieldID.
634 anno.setMember("circt.fieldID", builder->getI32IntegerAttr(newFieldID));
635 }
636
637 retval.push_back(anno.getAttr());
638 }
639 return ArrayAttr::get(ctxt, retval);
640}
641
642LogicalResult TypeLoweringVisitor::partitionSymbols(
643 hw::InnerSymAttr sym, FIRRTLType parentType,
644 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
645
646 // No symbol, nothing to partition.
647 if (!sym || sym.empty())
648 return success();
649
650 auto *context = sym.getContext();
651
652 auto baseType = getBaseType(parentType);
653 if (!baseType)
654 return mlir::emitError(errorLoc,
655 "unable to partition symbol on unsupported type ")
656 << parentType;
657
658 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
659 .Case<BundleType, FVectorType>([&](auto aggType) -> LogicalResult {
660 struct BinningInfo {
661 uint64_t index;
662 uint64_t relFieldID;
663 hw::InnerSymPropertiesAttr prop;
664 };
665
666 // Walk each inner symbol, compute binning information/assignment.
667 SmallVector<BinningInfo> binning;
668 for (auto prop : sym) {
669 auto fieldID = prop.getFieldID();
670 // Special-case fieldID == 0, helper methods require non-zero fieldID.
671 if (fieldID == 0)
672 return mlir::emitError(errorLoc, "unable to lower due to symbol ")
673 << prop.getName()
674 << " with target not preserved by lowering";
675 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
676 binning.push_back({index, relFieldID, prop});
677 }
678
679 // Sort by index, fieldID.
680 llvm::stable_sort(binning, [&](auto &lhs, auto &rhs) {
681 return std::tuple(lhs.index, lhs.relFieldID) <
682 std::tuple(rhs.index, rhs.relFieldID);
683 });
684 assert(!binning.empty());
685
686 // Populate newSyms, group all symbols on same index.
687 newSyms.resize(aggType.getNumElements());
688 for (auto binIt = binning.begin(), binEnd = binning.end();
689 binIt != binEnd;) {
690 auto curIndex = binIt->index;
691 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
692 // Gather all adjacent symbols for this index.
693 while (binIt != binEnd && binIt->index == curIndex) {
694 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
695 context, binIt->prop.getName(), binIt->relFieldID,
696 binIt->prop.getSymVisibility()));
697 ++binIt;
698 }
699
700 assert(!newSyms[curIndex]);
701 newSyms[curIndex] = hw::InnerSymAttr::get(context, propsForIndex);
702 }
703 return success();
704 })
705 .Default([&](auto ty) {
706 return mlir::emitError(
707 errorLoc, "unable to partition symbol on unsupported type ")
708 << ty;
709 });
710}
711
712bool TypeLoweringVisitor::lowerProducer(
713 Operation *op,
714 llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
715 Type srcType) {
716
717 if (!srcType)
718 srcType = op->getResult(0).getType();
719 auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
720 if (!srcFType)
721 return false;
722 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
723
724 if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
725 return false;
726
727 SmallVector<Value> lowered;
728 // Loop over the leaf aggregates.
729 SmallString<16> loweredName;
730 auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);
731
732 if (auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
733 loweredName = nameAttr.getValue();
734 auto baseNameLen = loweredName.size();
735 auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr("annotations"));
736
737 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
738 if (auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
739 if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
740 symOp.getLoc()))) {
741 encounteredError = true;
742 return false;
743 }
744 }
745
746 for (const auto &[field, sym] : llvm::zip_equal(fieldTypes, fieldSyms)) {
747 if (!loweredName.empty()) {
748 loweredName.resize(baseNameLen);
749 loweredName += field.suffix;
750 }
751
752 // For all annotations on the parent op, filter them based on the target
753 // attribute.
754 ArrayAttr loweredAttrs =
755 filterAnnotations(context, oldAnno, srcFType, field);
756 auto newVal = clone(field, loweredAttrs);
757
758 // If inner symbols on this field, add to new op.
759 if (sym) {
760 // Splitting up something with symbols on it should lower to ops
761 // that also can have symbols on them.
762 auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
763 assert(
764 newSymOp &&
765 "op with inner symbol lowered to op that cannot take inner symbol");
766 newSymOp.setInnerSymbolAttr(sym);
767 }
768
769 // Carry over the name, if present.
770 if (auto *newOp = newVal.getDefiningOp()) {
771 if (!loweredName.empty())
772 newOp->setAttr(cache.nameAttr, StringAttr::get(context, loweredName));
773 if (nameKindAttr)
774 newOp->setAttr(cache.nameKindAttr, nameKindAttr);
775
776 // Clone discardable attributes as well.
777 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
778 }
779 lowered.push_back(newVal);
780 }
781
782 processUsers(op->getResult(0), lowered);
783 return true;
784}
785
786void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
787 for (auto *user : llvm::make_early_inc_range(val.getUsers())) {
788 TypeSwitch<Operation *, void>(user)
789 .Case<SubindexOp>([mapping](SubindexOp sio) {
790 Value repl = mapping[sio.getIndex()];
791 sio.replaceAllUsesWith(repl);
792 sio.erase();
793 })
794 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
795 // Get the input bundle type.
796 Value repl = mapping[sfo.getFieldIndex()];
797 sfo.replaceAllUsesWith(repl);
798 sfo.erase();
799 })
800 .Case<RefSubOp>([mapping](RefSubOp refSub) {
801 Value repl = mapping[refSub.getIndex()];
802 refSub.replaceAllUsesWith(repl);
803 refSub.erase();
804 })
805 .Default([&](auto op) {
806 // This means we have already processed the user, and it didn't lower
807 // its inputs. This is an opaque user, which will continue to have
808 // aggregate type as input, even after LowerTypes. So, construct the
809 // vector/bundle back from the lowered elements to ensure a valid
810 // input into the opaque op. This only supports Bundles and Vectors.
811
812 // This builder ensures that the aggregate construction happens at the
813 // user location, and the LowerTypes algorithm will not touch them any
814 // more, because LowerTypes was reverse iterating on the block and the
815 // user has already been processed.
816 ImplicitLocOpBuilder b(user->getLoc(), user);
817
818 // This shouldn't happen (non-FIRRTLBaseType's in lowered types, or
819 // refs), check explicitly here for clarity/early detection.
820 assert(llvm::none_of(mapping, [](auto v) {
821 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
822 return !fbasetype || fbasetype.containsReference();
823 }));
824
825 Value input =
826 TypeSwitch<Type, Value>(val.getType())
827 .template Case<FVectorType>([&](auto vecType) {
828 return b.createOrFold<VectorCreateOp>(vecType, mapping);
829 })
830 .template Case<BundleType>([&](auto bundleType) {
831 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
832 })
833 .Default([&](auto _) -> Value { return {}; });
834 if (!input) {
835 user->emitError("unable to reconstruct source of type ")
836 << val.getType();
837 encounteredError = true;
838 return;
839 }
840 user->replaceUsesOfWith(val, input);
841 });
842 }
843}
844
845/// Helper function to remove elements from a vector based on a BitVector mask.
846template <typename T>
847static void eraseElementsAtIndices(SmallVectorImpl<T> &vec,
848 const llvm::BitVector &removalMask) {
849 size_t writeIndex = 0, readIndex = 0;
850
851 // Iterate over each set bit (element to remove) in the mask.
852 // Between each removal point, we bulk-copy the range of elements to keep.
853 for (size_t removalIndex : removalMask.set_bits()) {
854 // Copy the range [readIndex, removalIndex) - these are elements to keep.
855 assert(removalIndex >= readIndex && "removal index before read index");
856 size_t rangeSize = removalIndex - readIndex;
857 if (rangeSize > 0) {
858 // Bulk move the range of elements to keep to the write position.
859 // Skip if the read and write positions are the same (= the first
860 // iteration).
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
863 vec.begin() + writeIndex);
864 writeIndex += rangeSize;
865 }
866 readIndex = removalIndex + 1;
867 }
868
869 // Copy any remaining elements after the last removal point.
870 size_t remainingSize = vec.size() - readIndex;
871 if (remainingSize > 0) {
872 if (writeIndex != readIndex)
873 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
874 writeIndex += remainingSize;
875 }
876
877 // Truncate the vector to the new size (number of elements kept).
878 vec.truncate(writeIndex);
879}
880
881void TypeLoweringVisitor::lowerModule(FModuleLike op) {
882 if (auto module = llvm::dyn_cast<FModuleOp>(*op))
883 visitDecl(module);
884 else if (auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
885 visitDecl(extModule);
886}
887
888// Creates and returns a new block argument of the specified type to the
889// module. This also maintains the name attribute for the new argument,
890// possibly with a new suffix appended.
891std::pair<Value, PortInfo>
892TypeLoweringVisitor::addArg(Operation *module, unsigned insertPt,
893 unsigned insertPtOffset, FIRRTLType srcType,
894 const FlatBundleFieldEntry &field, PortInfo &oldArg,
895 hw::InnerSymAttr newSym) {
896 Value newValue;
897 FIRRTLType fieldType = mapLoweredType(srcType, field.type);
898 if (auto mod = llvm::dyn_cast<FModuleOp>(module)) {
899 Block *body = mod.getBodyBlock();
900 // Append the new argument.
901 newValue = body->insertArgument(insertPt, fieldType, oldArg.loc);
902 }
903
904 // Save the name attribute for the new argument.
905 auto name = builder->getStringAttr(oldArg.name.getValue() + field.suffix);
906
907 // Populate the new arg attributes.
908 auto newAnnotations = filterAnnotations(
909 context, oldArg.annotations.getArrayAttr(), srcType, field);
910 // Flip the direction if the field is an output.
911 auto direction = (Direction)((unsigned)oldArg.direction ^ field.isOutput);
912
913 return std::make_pair(
914 newValue, PortInfo{name, fieldType, direction, newSym, oldArg.loc,
915 AnnotationSet(newAnnotations), oldArg.domains});
916}
917
918// Lower arguments with bundle type by flattening them.
919bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
920 size_t argsRemoved,
921 SmallVectorImpl<PortInfo> &newArgs,
922 SmallVectorImpl<Value> &lowering) {
923
924 // Flatten any bundle types.
925 SmallVector<FlatBundleFieldEntry> fieldTypes;
926 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
927 if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
928 return false;
929
930 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
931 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
932 newArgs[argIndex].loc))) {
933 encounteredError = true;
934 return false;
935 }
936
937 for (const auto &[idx, field, fieldSym] :
938 llvm::enumerate(fieldTypes, fieldSyms)) {
939 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
940 field, newArgs[argIndex], fieldSym);
941 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
942 // Lower any other arguments by copying them to keep the relative order.
943 lowering.push_back(newValue.first);
944 }
945 return true;
946}
947
948static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
949 Value rhs) {
950 if (auto rop = llvm::dyn_cast<SubfieldOp>(op))
951 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
952 if (auto rop = llvm::dyn_cast<SubindexOp>(op))
953 return SubindexOp::create(*builder, rhs, rop.getIndex());
954 if (auto rop = llvm::dyn_cast<SubaccessOp>(op))
955 return SubaccessOp::create(*builder, rhs, rop.getIndex());
956 op->emitError("Unknown accessor");
957 return nullptr;
958}
959
960void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
961 ArrayRef<Operation *> writePath) {
962 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
963 FVectorType saoType = sao.getInput().getType();
964 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
965
966 for (size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
967 auto cond = EQPrimOp::create(
968 *builder, sao.getIndex(),
969 builder->createOrFold<ConstantOp>(UIntType::get(context, selectWidth),
970 APInt(selectWidth, index)));
971 WhenOp::create(*builder, cond, false, [&]() {
972 // Recreate the write Path
973 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
974 for (int i = writePath.size() - 2; i >= 0; --i) {
975 if (auto access = cloneAccess(builder, writePath[i], leaf))
976 leaf = access;
977 else {
978 encounteredError = true;
979 return;
980 }
981 }
982
983 emitConnect(*builder, leaf, op->getOperand(1));
984 });
985 }
986}
987
988// Expand connects of aggregates
989bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
990 if (processSAPath(op))
991 return true;
992
993 // Attempt to get the bundle types.
994 SmallVector<FlatBundleFieldEntry> fields;
995
996 // We have to expand connections even if the aggregate preservation is true.
997 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
998 return false;
999
1000 // Loop over the leaf aggregates.
1001 for (const auto &field : llvm::enumerate(fields)) {
1002 Value src = getSubWhatever(op.getSrc(), field.index());
1003 Value dest = getSubWhatever(op.getDest(), field.index());
1004 if (field.value().isOutput)
1005 std::swap(src, dest);
1006 emitConnect(*builder, dest, src);
1007 }
1008 return true;
1009}
1010
1011// Expand connects of aggregates
1012bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
1013 if (processSAPath(op))
1014 return true;
1015
1016 // Attempt to get the bundle types.
1017 SmallVector<FlatBundleFieldEntry> fields;
1018
1019 // We have to expand connections even if the aggregate preservation is true.
1020 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
1021 return false;
1022
1023 // Loop over the leaf aggregates.
1024 for (const auto &field : llvm::enumerate(fields)) {
1025 Value src = getSubWhatever(op.getSrc(), field.index());
1026 Value dest = getSubWhatever(op.getDest(), field.index());
1027 if (field.value().isOutput)
1028 std::swap(src, dest);
1029 MatchingConnectOp::create(*builder, dest, src);
1030 }
1031 return true;
1032}
1033
1034// Expand connects of references-of-aggregates
1035bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1036 // Attempt to get the bundle types.
1037 SmallVector<FlatBundleFieldEntry> fields;
1038
1039 if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1040 return false;
1041
1042 // Loop over the leaf aggregates.
1043 for (const auto &field : llvm::enumerate(fields)) {
1044 Value src = getSubWhatever(op.getSrc(), field.index());
1045 Value dest = getSubWhatever(op.getDest(), field.index());
1046 assert(!field.value().isOutput && "unexpected flip in reftype destination");
1047 RefDefineOp::create(*builder, dest, src);
1048 }
1049 return true;
1050}
1051
1052bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1053 // The WhenOp itself does not require any lowering, the only value it uses
1054 // is a one-bit predicate. Recursively visit all regions so internal
1055 // operations are lowered.
1056
1057 // Visit operations in the then block.
1058 lowerBlock(&op.getThenBlock());
1059
1060 // Visit operations in the else block.
1061 if (op.hasElseRegion())
1062 lowerBlock(&op.getElseBlock());
1063 return false; // don't delete the when!
1064}
1065
1066/// Lower any types declared in layer blocks.
1067bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1068 lowerBlock(op.getBody());
1069 return false;
1070}
1071
1072/// Lower memory operations. A new memory is created for every leaf
1073/// element in a memory's data type.
1074bool TypeLoweringVisitor::visitDecl(MemOp op) {
1075 // Attempt to get the bundle types.
1076 SmallVector<FlatBundleFieldEntry> fields;
1077
1078 // MemOp should have ground types so we can't preserve aggregates.
1079 if (!peelType(op.getDataType(), fields, memoryPreservationMode))
1080 return false;
1081
1082 if (op.getInnerSym()) {
1083 op->emitError() << "has a symbol, but no symbols may exist on aggregates "
1084 "passed through LowerTypes";
1085 encounteredError = true;
1086 return false;
1087 }
1088
1089 SmallVector<MemOp> newMemories;
1090 SmallVector<WireOp> oldPorts;
1091
1092 // Wires for old ports
1093 for (unsigned int index = 0, end = op.getNumResults(); index < end; ++index) {
1094 auto result = op.getResult(index);
1095 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1096 op.emitOpError("cannot lower memory with debug port");
1097 encounteredError = true;
1098 return false;
1099 }
1100 auto wire =
1101 WireOp::create(*builder, result.getType(),
1102 (op.getName() + "_" + op.getPortName(index)).str());
1103 oldPorts.push_back(wire);
1104 result.replaceAllUsesWith(wire.getResult());
1105 }
1106 // If annotations targeting fields of an aggregate are present, we cannot
1107 // flatten the memory. It must be split into one memory per aggregate field.
1108 // Do not overwrite the pass flag!
1109
1110 // Memory for each field
1111 for (const auto &field : fields) {
1112 auto newMemForField = cloneMemWithNewType(builder, op, field);
1113 if (!newMemForField) {
1114 op.emitError("failed cloning memory for field");
1115 encounteredError = true;
1116 return false;
1117 }
1118 newMemories.push_back(newMemForField);
1119 }
1120 // Hook up the new memories to the wires the old memory was replaced with.
1121 for (size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1122 auto result = oldPorts[index].getResult();
1123 auto rType = type_cast<BundleType>(result.getType());
1124 for (size_t fieldIndex = 0, fend = rType.getNumElements();
1125 fieldIndex != fend; ++fieldIndex) {
1126 auto name = rType.getElement(fieldIndex).name.getValue();
1127 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1128 // data and mask depend on the memory type which was split. They can also
1129 // go both directions, depending on the port direction.
1130 if (name == "data" || name == "mask" || name == "wdata" ||
1131 name == "wmask" || name == "rdata") {
1132 for (const auto &field : fields) {
1133 auto realOldField = getSubWhatever(oldField, field.index);
1134 auto newField = getSubWhatever(
1135 newMemories[field.index].getResult(index), fieldIndex);
1136 if (rType.getElement(fieldIndex).isFlip)
1137 std::swap(realOldField, newField);
1138 emitConnect(*builder, newField, realOldField);
1139 }
1140 } else {
1141 for (auto mem : newMemories) {
1142 auto newField =
1143 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1144 emitConnect(*builder, newField, oldField);
1145 }
1146 }
1147 }
1148 }
1149 return true;
1150}
1151
1152bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1153 ImplicitLocOpBuilder theBuilder(extModule.getLoc(), context);
1154 builder = &theBuilder;
1155
1156 // Top level builder
1157 OpBuilder builder(context);
1158
1159 // Lower the module block arguments.
1160 llvm::BitVector argsToRemove;
1161 auto newArgs = extModule.getPorts();
1162 argsToRemove.reserve(newArgs.size());
1163
1164 DomainLoweringHelper domainHelper(context, extModule.getPortTypes());
1165
1166 size_t argsRemoved = 0;
1167 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1168 SmallVector<Value> lowering;
1169 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1170 argsToRemove.push_back(true);
1171 ++argsRemoved;
1172 } else {
1173 argsToRemove.push_back(false);
1174 }
1175 // lowerArg might have invalidated any reference to newArgs, be careful
1176 }
1177
1178 // Remove block args that have been lowered.
1179 if (argsRemoved != 0)
1180 eraseElementsAtIndices(newArgs, argsToRemove);
1181
1182 domainHelper.computeDomainMap(newArgs);
1183
1184 SmallVector<NamedAttribute, 8> newModuleAttrs;
1185
1186 // Copy over any attributes that weren't original argument attributes.
1187 for (auto attr : extModule->getAttrDictionary())
1188 // Drop old "portNames", directions, and argument attributes. These are
1189 // handled differently below.
1190 if (attr.getName() != "portDirections" && attr.getName() != "portNames" &&
1191 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1192 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1193 newModuleAttrs.push_back(attr);
1194
1195 SmallVector<Direction> newArgDirections;
1196 SmallVector<Attribute> newArgNames;
1197 SmallVector<Attribute, 8> newArgTypes;
1198 SmallVector<Attribute, 8> newArgSyms;
1199 SmallVector<Attribute, 8> newArgLocations;
1200 SmallVector<Attribute, 8> newArgAnnotations;
1201 SmallVector<Attribute, 8> newArgDomains;
1202
1203 for (auto &port : newArgs) {
1204 newArgDirections.push_back(port.direction);
1205 newArgNames.push_back(port.name);
1206 newArgTypes.push_back(TypeAttr::get(port.type));
1207 newArgSyms.push_back(port.sym);
1208 newArgLocations.push_back(port.loc);
1209 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1210 if (port.domains) {
1211 domainHelper.rewriteDomain(port.domains);
1212 } else {
1213 port.domains = cache.aEmpty;
1214 }
1215 newArgDomains.push_back(port.domains);
1216 }
1217
1218 newModuleAttrs.push_back(
1219 NamedAttribute(cache.sPortDirections,
1220 direction::packAttribute(context, newArgDirections)));
1221
1222 newModuleAttrs.push_back(
1223 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1224
1225 newModuleAttrs.push_back(
1226 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1227
1228 newModuleAttrs.push_back(NamedAttribute(
1229 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1230
1231 newModuleAttrs.push_back(NamedAttribute(
1232 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1233
1234 newModuleAttrs.push_back(
1235 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1236
1237 // Update the module's attributes.
1238 extModule->setAttrs(newModuleAttrs);
1239 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1240 extModule.setPortSymbols(newArgSyms);
1241
1242 return false;
1243}
1244
1245bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1246 auto *body = module.getBodyBlock();
1247
1248 ImplicitLocOpBuilder theBuilder(module.getLoc(), context);
1249 builder = &theBuilder;
1250
1251 // Lower the operations.
1252 lowerBlock(body);
1253
1254 // Lower the module block arguments.
1255 llvm::BitVector argsToRemove;
1256 auto newArgs = module.getPorts();
1257 argsToRemove.reserve(newArgs.size());
1258
1259 DomainLoweringHelper domainHelper(context, module.getPortTypes());
1260
1261 size_t argsRemoved = 0;
1262 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1263 SmallVector<Value> lowerings;
1264 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1265 auto arg = module.getArgument(argIndex);
1266 processUsers(arg, lowerings);
1267 argsToRemove.push_back(true);
1268 ++argsRemoved;
1269 } else
1270 argsToRemove.push_back(false);
1271 // lowerArg might have invalidated any reference to newArgs, be careful
1272 }
1273
1274 // Remove block args that have been lowered.
1275 if (argsRemoved != 0) {
1276 body->eraseArguments(argsToRemove);
1277 eraseElementsAtIndices(newArgs, argsToRemove);
1278 }
1279
1280 domainHelper.computeDomainMap(newArgs);
1281
1282 SmallVector<NamedAttribute, 8> newModuleAttrs;
1283
1284 // Copy over any attributes that weren't original argument attributes.
1285 for (auto attr : module->getAttrDictionary())
1286 // Drop old "portNames", directions, and argument attributes. These are
1287 // handled differently below.
1288 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
1289 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1290 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1291 newModuleAttrs.push_back(attr);
1292
1293 SmallVector<Direction> newArgDirections;
1294 SmallVector<Attribute> newArgNames;
1295 SmallVector<Attribute> newArgTypes;
1296 SmallVector<Attribute> newArgSyms;
1297 SmallVector<Attribute> newArgLocations;
1298 SmallVector<Attribute, 8> newArgAnnotations;
1299 SmallVector<Attribute> newPortDomains;
1300 for (auto &port : newArgs) {
1301 newArgDirections.push_back(port.direction);
1302 newArgNames.push_back(port.name);
1303 newArgTypes.push_back(TypeAttr::get(port.type));
1304 newArgSyms.push_back(port.sym);
1305 newArgLocations.push_back(port.loc);
1306 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1307 if (port.domains) {
1308 domainHelper.rewriteDomain(port.domains);
1309 } else {
1310 port.domains = cache.aEmpty;
1311 }
1312 newPortDomains.push_back(port.domains);
1313 }
1314
1315 newModuleAttrs.push_back(
1316 NamedAttribute(cache.sPortDirections,
1317 direction::packAttribute(context, newArgDirections)));
1318
1319 newModuleAttrs.push_back(
1320 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1321
1322 newModuleAttrs.push_back(
1323 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1324
1325 newModuleAttrs.push_back(NamedAttribute(
1326 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1327
1328 newModuleAttrs.push_back(NamedAttribute(
1329 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1330
1331 newModuleAttrs.push_back(NamedAttribute(
1332 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1333
1334 // Update the module's attributes.
1335 module->setAttrs(newModuleAttrs);
1336 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1337 module.setPortSymbols(newArgSyms);
1338 return false;
1339}
1340
1341/// Lower a wire op with a bundle to multiple non-bundled wires.
1342bool TypeLoweringVisitor::visitDecl(WireOp op) {
1343 if (op.isForceable())
1344 return false;
1345
1346 auto clone = [&](const FlatBundleFieldEntry &field,
1347 ArrayAttr attrs) -> Value {
1348 return WireOp::create(*builder,
1349 mapLoweredType(op.getDataRaw().getType(), field.type),
1350 "", NameKindEnum::DroppableName, attrs, StringAttr{},
1351 false, op.getDomains())
1352 .getResult();
1353 };
1354 return lowerProducer(op, clone);
1355}
1356
1357/// Lower a reg op with a bundle to multiple non-bundled regs.
1358bool TypeLoweringVisitor::visitDecl(RegOp op) {
1359 if (op.isForceable())
1360 return false;
1361
1362 auto clone = [&](const FlatBundleFieldEntry &field,
1363 ArrayAttr attrs) -> Value {
1364 return RegOp::create(*builder, field.type, op.getClockVal(), "",
1365 NameKindEnum::DroppableName, attrs, StringAttr{})
1366 .getResult();
1367 };
1368 return lowerProducer(op, clone);
1369}
1370
1371/// Lower a reg op with a bundle to multiple non-bundled regs.
1372bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1373 if (op.isForceable())
1374 return false;
1375
1376 auto clone = [&](const FlatBundleFieldEntry &field,
1377 ArrayAttr attrs) -> Value {
1378 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1379 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1380 op.getResetSignal(), resetVal, "",
1381 NameKindEnum::DroppableName, attrs, StringAttr{})
1382 .getResult();
1383 };
1384 return lowerProducer(op, clone);
1385}
1386
1387/// Lower a wire op with a bundle to multiple non-bundled wires.
1388bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1389 if (op.isForceable())
1390 return false;
1391
1392 auto clone = [&](const FlatBundleFieldEntry &field,
1393 ArrayAttr attrs) -> Value {
1394 auto input = getSubWhatever(op.getInput(), field.index);
1395 return NodeOp::create(*builder, input, "", NameKindEnum::DroppableName,
1396 attrs)
1397 .getResult();
1398 };
1399 return lowerProducer(op, clone);
1400}
1401
1402/// Lower an InvalidValue op with a bundle to multiple non-bundled InvalidOps.
1403bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1404 auto clone = [&](const FlatBundleFieldEntry &field,
1405 ArrayAttr attrs) -> Value {
1406 return InvalidValueOp::create(*builder, field.type);
1407 };
1408 return lowerProducer(op, clone);
1409}
1410
1411// Expand muxes of aggregates
1412bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1413 auto clone = [&](const FlatBundleFieldEntry &field,
1414 ArrayAttr attrs) -> Value {
1415 auto high = getSubWhatever(op.getHigh(), field.index);
1416 auto low = getSubWhatever(op.getLow(), field.index);
1417 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1418 };
1419 return lowerProducer(op, clone);
1420}
1421
1422// Expand muxes of aggregates
1423bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1424 auto clone = [&](const FlatBundleFieldEntry &field,
1425 ArrayAttr attrs) -> Value {
1426 auto high = getSubWhatever(op.getHigh(), field.index);
1427 auto low = getSubWhatever(op.getLow(), field.index);
1428 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1429 };
1430 return lowerProducer(op, clone);
1431}
1432
1433// Expand muxes of aggregates
1434bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1435 auto clone = [&](const FlatBundleFieldEntry &field,
1436 ArrayAttr attrs) -> Value {
1437 auto v3 = getSubWhatever(op.getV3(), field.index);
1438 auto v2 = getSubWhatever(op.getV2(), field.index);
1439 auto v1 = getSubWhatever(op.getV1(), field.index);
1440 auto v0 = getSubWhatever(op.getV0(), field.index);
1441 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1442 };
1443 return lowerProducer(op, clone);
1444}
1445
1446// Expand UnrealizedConversionCastOp of aggregates
1447bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1448 mlir::UnrealizedConversionCastOp op) {
1449 auto clone = [&](const FlatBundleFieldEntry &field,
1450 ArrayAttr attrs) -> Value {
1451 auto input = getSubWhatever(op.getOperand(0), field.index);
1452 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1453 .getResult(0);
1454 };
1455 // If the input to the cast is not a FIRRTL type, getSubWhatever cannot handle
1456 // it, donot lower the op.
1457 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1458 return false;
1459 return lowerProducer(op, clone);
1460}
1461
1462// Expand BitCastOp of aggregates
1463bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1464 Value srcLoweredVal = op.getInput();
1465 // If the input is of aggregate type, then cat all the leaf fields to form a
1466 // UInt type result. That is, first bitcast the aggregate type to a UInt.
1467 // Attempt to get the bundle types.
1468 SmallVector<FlatBundleFieldEntry> fields;
1469 if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) {
1470 size_t uptoBits = 0;
1471 // Loop over the leaf aggregates and concat each of them to get a UInt.
1472 // Bitcast the fields to handle nested aggregate types.
1473 for (const auto &field : llvm::enumerate(fields)) {
1474 auto fieldBitwidth = *getBitWidth(field.value().type);
1475 // Ignore zero width fields, like empty bundles.
1476 if (fieldBitwidth == 0)
1477 continue;
1478 Value src = getSubWhatever(op.getInput(), field.index());
1479 // The src could be an aggregate type, bitcast it to a UInt type.
1480 src = builder->createOrFold<BitCastOp>(
1481 UIntType::get(context, fieldBitwidth), src);
1482 // Take the first field, or else Cat the previous fields with this field.
1483 if (uptoBits == 0)
1484 srcLoweredVal = src;
1485 else {
1486 if (type_isa<BundleType>(op.getInput().getType())) {
1487 srcLoweredVal =
1488 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1489 } else {
1490 srcLoweredVal =
1491 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1492 }
1493 }
1494 // Record the total bits already accumulated.
1495 uptoBits += fieldBitwidth;
1496 }
1497 } else {
1498 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
1499 }
1500 // Now the input has been cast to srcLoweredVal, which is of UInt type.
1501 // If the result is an aggregate type, then use lowerProducer.
1502 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
1503 // uptoBits is used to keep track of the bits that have been extracted.
1504 size_t uptoBits = 0;
1505 auto aggregateBits = *getBitWidth(op.getResult().getType());
1506 auto clone = [&](const FlatBundleFieldEntry &field,
1507 ArrayAttr attrs) -> Value {
1508 // All the fields must have valid bitwidth, a requirement for BitCastOp.
1509 auto fieldBits = *getBitWidth(field.type);
1510 // If empty field, then it doesnot have any use, so replace it with an
1511 // invalid op, which should be trivially removed.
1512 if (fieldBits == 0)
1513 return InvalidValueOp::create(*builder, field.type);
1514
1515 // Assign the field to the corresponding bits from the input.
1516 // Bitcast the field, incase its an aggregate type.
1517 BitsPrimOp extractBits;
1518 if (type_isa<BundleType>(op.getResult().getType())) {
1519 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1520 aggregateBits - uptoBits - 1,
1521 aggregateBits - uptoBits - fieldBits);
1522 } else {
1523 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1524 uptoBits + fieldBits - 1, uptoBits);
1525 }
1526 uptoBits += fieldBits;
1527 return BitCastOp::create(*builder, field.type, extractBits);
1528 };
1529 return lowerProducer(op, clone);
1530 }
1531
1532 // If ground type, then replace the result.
1533 if (type_isa<SIntType>(op.getType()))
1534 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1535 op.getResult().replaceAllUsesWith(srcLoweredVal);
1536 return true;
1537}
1538
1539bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1540 auto clone = [&](const FlatBundleFieldEntry &field,
1541 ArrayAttr attrs) -> Value {
1542 return RefSendOp::create(*builder,
1543 getSubWhatever(op.getBase(), field.index));
1544 };
1545 // Be careful re:what gets lowered, consider ref.send of non-passive
1546 // and whether we're using the ref or the base type to choose
1547 // whether this should be lowered.
1548 return lowerProducer(op, clone);
1549}
1550
1551bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1552 auto clone = [&](const FlatBundleFieldEntry &field,
1553 ArrayAttr attrs) -> Value {
1554 Value src = getSubWhatever(op.getRef(), field.index);
1555 return RefResolveOp::create(*builder, src);
1556 };
1557 // Lower according to lowering of the reference.
1558 // Particularly, preserve if rwprobe.
1559 return lowerProducer(op, clone, op.getRef().getType());
1560}
1561
1562bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1563 auto clone = [&](const FlatBundleFieldEntry &field,
1564 ArrayAttr attrs) -> Value {
1565 auto input = getSubWhatever(op.getInput(), field.index);
1566 return RefCastOp::create(*builder,
1567 RefType::get(field.type,
1568 op.getType().getForceable(),
1569 op.getType().getLayer()),
1570 input);
1571 };
1572 return lowerProducer(op, clone);
1573}
1574
1575/// Helper function to lower instance-like operations. This contains the common
1576/// logic for both InstanceOp and InstanceChoiceOp.
1577bool TypeLoweringVisitor::lowerInstanceLike(
1578 FInstanceLike op, PreserveAggregate::PreserveMode mode,
1579 ArrayAttr oldPortAnno,
1580 llvm::function_ref<Operation *(ArrayRef<Type>, ArrayRef<Direction>,
1581 ArrayAttr, ArrayAttr, ArrayAttr,
1582 hw::InnerSymAttr)>
1583 createNewInstance) {
1584 bool skip = true;
1585 SmallVector<Type, 8> resultTypes;
1586 SmallVector<int64_t, 8> endFields; // Compressed sparse row encoding
1587 SmallVector<Direction> newDirs;
1588 SmallVector<Attribute> newNames, newDomains, newPortAnno;
1589
1590 // Create domain helper to track domain port indices.
1591 DomainLoweringHelper domainHelper(context, op->getResultTypes());
1592 auto emptyAnno = builder->getArrayAttr({});
1593
1594 endFields.push_back(0);
1595 for (size_t i = 0, e = op->getNumResults(); i != e; ++i) {
1596 auto srcType = type_cast<FIRRTLType>(op->getResult(i).getType());
1597
1598 // Flatten any nested bundle types the usual way.
1599 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1600 if (!peelType(srcType, fieldTypes, mode)) {
1601 newDirs.push_back(op.getPortDirection(i));
1602 newNames.push_back(op.getPortNameAttr(i));
1603 newDomains.push_back(op.getPortDomain(i));
1604 resultTypes.push_back(srcType);
1605 newPortAnno.push_back(oldPortAnno ? oldPortAnno[i] : emptyAnno);
1606 } else {
1607 skip = false;
1608 auto oldName = op.getPortName(i);
1609 auto oldDir = op.getPortDirection(i);
1610 // Store the flat type for the new bundle type.
1611 for (const auto &field : fieldTypes) {
1612 newDirs.push_back(direction::get((unsigned)oldDir ^ field.isOutput));
1613 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1614 newDomains.push_back(op.getPortDomain(i));
1615 resultTypes.push_back(mapLoweredType(srcType, field.type));
1616 auto annos =
1617 oldPortAnno
1618 ? filterAnnotations(context,
1619 dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]),
1620 srcType, field)
1621 : emptyAnno;
1622 newPortAnno.push_back(annos);
1623 }
1624 }
1625 endFields.push_back(resultTypes.size());
1626 }
1627
1628 auto sym = getInnerSymName(op);
1629
1630 if (skip) {
1631 return false;
1632 }
1633
1634 // Compute the mapping from old domain indices to new domain indices.
1635 domainHelper.computeDomainMap(resultTypes);
1636
1637 // Rewrite domain associations to use the new port numbers.
1638 for (auto &domain : newDomains)
1639 domainHelper.rewriteDomain(domain);
1640
1641 // Create the new instance using the provided factory function.
1642 auto *newInstance = createNewInstance(
1643 resultTypes, newDirs, builder->getArrayAttr(newNames),
1644 builder->getArrayAttr(newDomains), builder->getArrayAttr(newPortAnno),
1645 sym ? hw::InnerSymAttr::get(sym) : hw::InnerSymAttr());
1646
1647 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1648
1649 SmallVector<Value> lowered;
1650 for (size_t aggIndex = 0, eAgg = op->getNumResults(); aggIndex != eAgg;
1651 ++aggIndex) {
1652 lowered.clear();
1653 for (size_t fieldIndex = endFields[aggIndex],
1654 eField = endFields[aggIndex + 1];
1655 fieldIndex < eField; ++fieldIndex)
1656 lowered.push_back(newInstance->getResult(fieldIndex));
1657 if (lowered.size() != 1 ||
1658 op->getResult(aggIndex).getType() != resultTypes[endFields[aggIndex]])
1659 processUsers(op->getResult(aggIndex), lowered);
1660 else
1661 op->getResult(aggIndex).replaceAllUsesWith(lowered[0]);
1662 }
1663 return true;
1664}
1665
1666bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1667 // Determine preservation mode from the referenced module.
1668 PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
1669 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1670
1671 // Lambda to create the new InstanceOp with lowered types.
1672 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1673 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1674 ArrayAttr newDomains, ArrayAttr newPortAnno,
1675 hw::InnerSymAttr sym) -> Operation * {
1676 // FIXME: annotation update
1677 return InstanceOp::create(
1678 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1679 op.getNameKindAttr(), direction::packAttribute(context, newDirs),
1680 newNames, newDomains, op.getAnnotations(), newPortAnno,
1681 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1682 sym);
1683 };
1684
1685 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1686 createNewInstance);
1687}
1688
1689bool TypeLoweringVisitor::visitDecl(InstanceChoiceOp op) {
1690 // Get the default target module to determine preservation mode.
1691 auto *moduleOp = symTbl.lookupNearestSymbolFrom(
1692 op, cast<FlatSymbolRefAttr>(op.getDefaultTargetAttr()));
1693 auto mode = getPreservationModeForPorts(cast<FModuleLike>(moduleOp));
1694
1695 // Lambda to create the new InstanceChoiceOp with lowered types.
1696 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1697 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1698 ArrayAttr newDomains, ArrayAttr newPortAnno,
1699 hw::InnerSymAttr sym) -> Operation * {
1700 return InstanceChoiceOp::create(
1701 *builder, resultTypes, op.getModuleNames(), op.getCaseNames(),
1702 op.getNameAttr(), op.getNameKindAttr(),
1703 direction::packAttribute(context, newDirs), newNames, newDomains,
1704 op.getAnnotations(), newPortAnno, op.getLayersAttr(), sym,
1705 op.getInstanceMacroAttr());
1706 };
1707
1708 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1709 createNewInstance);
1710}
1711
1712bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1713 auto input = op.getInput();
1714 FVectorType vType = input.getType();
1715
1716 // Check for empty vectors
1717 if (vType.getNumElements() == 0) {
1718 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1719 op.replaceAllUsesWith(inv);
1720 return true;
1721 }
1722
1723 // Check for constant instances
1724 if (ConstantOp arg =
1725 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1726 auto sio = SubindexOp::create(*builder, op.getInput(),
1727 arg.getValue().getExtValue());
1728 op.replaceAllUsesWith(sio.getResult());
1729 return true;
1730 }
1731
1732 // Construct a multibit mux
1733 SmallVector<Value> inputs;
1734 inputs.reserve(vType.getNumElements());
1735 for (int index = vType.getNumElements() - 1; index >= 0; index--)
1736 inputs.push_back(SubindexOp::create(*builder, input, index));
1737
1738 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1739 op.replaceAllUsesWith(multibitMux);
1740 return true;
1741}
1742
1743bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1744 auto clone = [&](const FlatBundleFieldEntry &field,
1745 ArrayAttr attrs) -> Value {
1746 return op.getOperand(field.index);
1747 };
1748 return lowerProducer(op, clone);
1749}
1750
1751bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1752 auto clone = [&](const FlatBundleFieldEntry &field,
1753 ArrayAttr attrs) -> Value {
1754 return op.getOperand(field.index);
1755 };
1756 return lowerProducer(op, clone);
1757}
1758
1759bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1760 auto clone = [&](const FlatBundleFieldEntry &field,
1761 ArrayAttr attrs) -> Value {
1762 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1763 getSubWhatever(op.getRhs(), field.index)};
1764 return type_isa<BundleType, FVectorType>(field.type)
1765 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1766 operands)
1767 : (Value)OrPrimOp::create(*builder, operands);
1768 };
1769
1770 return lowerProducer(op, clone);
1771}
1772
1773bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1774 auto clone = [&](const FlatBundleFieldEntry &field,
1775 ArrayAttr attrs) -> Value {
1776 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1777 getSubWhatever(op.getRhs(), field.index)};
1778 return type_isa<BundleType, FVectorType>(field.type)
1779 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1780 operands)
1781 : (Value)AndPrimOp::create(*builder, operands);
1782 };
1783
1784 return lowerProducer(op, clone);
1785}
1786
1787bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1788 auto clone = [&](const FlatBundleFieldEntry &field,
1789 ArrayAttr attrs) -> Value {
1790 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1791 getSubWhatever(op.getRhs(), field.index)};
1792 return type_isa<BundleType, FVectorType>(field.type)
1793 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1794 operands)
1795 : (Value)XorPrimOp::create(*builder, operands);
1796 };
1797
1798 return lowerProducer(op, clone);
1799}
1800
1801bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1802 auto clone = [&](const FlatBundleFieldEntry &field,
1803 ArrayAttr attrs) -> Value {
1804 SmallVector<Value> newInputs;
1805 newInputs.reserve(op.getInputs().size());
1806 for (auto input : op.getInputs()) {
1807 auto inputSub = getSubWhatever(input, field.index);
1808 newInputs.push_back(inputSub);
1809 }
1810 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1811 };
1812 return lowerProducer(op, clone);
1813}
1814
1815//===----------------------------------------------------------------------===//
1816// Pass Infrastructure
1817//===----------------------------------------------------------------------===//
1818
1819namespace {
1820struct LowerTypesPass
1821 : public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1822 using Base::Base;
1823
1824 void runOnOperation() override;
1825};
1826} // end anonymous namespace
1827
1828// This is the main entrypoint for the lowering pass.
1829void LowerTypesPass::runOnOperation() {
1831
1832 std::vector<FModuleLike> ops;
1833 auto &instanceGraph = getAnalysis<InstanceGraph>();
1834 // Symbol Table
1835 auto &symTbl = getAnalysis<SymbolTable>();
1836 // Cached attr
1837 AttrCache cache(&getContext());
1838
1839 DenseMap<FModuleLike, Convention> conventionTable;
1840 auto circuit = getOperation();
1841 for (auto module : circuit.getOps<FModuleLike>()) {
1842 auto convention = module.getConvention();
1843 // Instance choices select between modules with a shared port shape, so
1844 // any module instantiated by one must use the scalarized convention.
1845 if (llvm::any_of(instanceGraph.lookup(module)->uses(),
1846 [](InstanceRecord *use) {
1847 return use->getInstance<InstanceChoiceOp>();
1848 }))
1849 convention = Convention::Scalarized;
1850 conventionTable.insert({module, convention});
1851 ops.push_back(module);
1852 }
1853
1854 // This lambda, executes in parallel for each Op within the circt.
1855 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1856 // Use body type lowering attribute if it exists, otherwise use internal.
1857 Convention convention = Convention::Internal;
1858 if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1859 op->getDiscardableAttr("body_type_lowering")))
1860 convention = conventionAttr.getValue();
1861
1862 auto tl =
1863 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1864 preserveMemories, symTbl, cache, conventionTable);
1865 tl.lowerModule(op);
1866
1867 return LogicalResult::failure(tl.isFailed());
1868 };
1869
1870 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
1871
1872 if (failed(result))
1873 signalPassFailure();
1874}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
This is an edge in the InstanceGraph.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:56
@ All
Preserve all aggregate values.
Definition Passes.h:40
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
Definition Passes.h:34
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
Definition Passes.h:37
@ None
Don't preserve aggregate at all.
Definition Passes.h:31
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Definition FIRRTLEnums.h:36
Direction
This represents the direction of a single port.
Definition FIRRTLEnums.h:27
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
bool type_isa(Type type)
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition FIRRTLOps.h:108
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
This holds the name and type that describes the module's ports.