CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerTypes.cpp
Go to the documentation of this file.
1//===- LowerTypes.cpp - Lower Aggregate Types -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerTypes pass. This pass replaces aggregate types
10// with expanded values.
11//
12// This pass walks the operations in reverse order. This lets it visit users
13// before defs. Users can usually be expanded out to multiple operations (think
14// mux of a bundle to muxes of each field) with a temporary subWhatever op
15// inserted. When processing an aggregate producer, we blow out the op as
16// appropriate, then walk the users, often those are subWhatever ops which can
17// be bypassed and deleted. Function arguments are logically last on the
18// operation visit order and walked left to right, being peeled one layer at a
19// time with replacements inserted to the right of the original argument.
20//
// Each processing of an op peels one layer of aggregate type off. Because new
// ops are inserted immediately above the current op, the walk will visit them
// next, effectively recursing on the aggregate types without actual recursion.
// These potentially temporary ops (if the aggregate is complex) effectively
// serve as the worklist. Often aggregates are shallow, so the new ops are the
// final ones.
27//
28//===----------------------------------------------------------------------===//
29
#include "circt/Support/Debug.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
46
47#define DEBUG_TYPE "firrtl-lower-types"
48
49namespace circt {
50namespace firrtl {
51#define GEN_PASS_DEF_LOWERFIRRTLTYPES
52#include "circt/Dialect/FIRRTL/Passes.h.inc"
53} // namespace firrtl
54} // namespace circt
55
56using namespace circt;
57using namespace firrtl;
58
59// TODO: check all argument types
60namespace {
61/// This represents a flattened bundle field element.
62struct FlatBundleFieldEntry {
63 /// This is the underlying ground type of the field.
64 FIRRTLBaseType type;
65 /// The index in the parent type
66 size_t index;
67 /// The fieldID
68 unsigned fieldID;
69 /// This is a suffix to add to the field name to make it unique.
70 SmallString<16> suffix;
71 /// This indicates whether the field was flipped to be an output.
72 bool isOutput;
73
74 FlatBundleFieldEntry(const FIRRTLBaseType &type, size_t index,
75 unsigned fieldID, StringRef suffix, bool isOutput)
76 : type(type), index(index), fieldID(fieldID), suffix(suffix),
77 isOutput(isOutput) {}
78
79 void dump() const {
80 llvm::errs() << "FBFE{" << type << " index<" << index << "> fieldID<"
81 << fieldID << "> suffix<" << suffix << "> isOutput<"
82 << isOutput << ">}\n";
83 }
84};
85} // end anonymous namespace
86
87/// Return fieldType or fieldType as same ref as type.
89 return mapBaseType(type, [&](auto) { return fieldType; });
90}
91
92/// Return fieldType or fieldType as same ref as type.
93static Type mapLoweredType(Type type, FIRRTLBaseType fieldType) {
94 auto ftype = type_dyn_cast<FIRRTLType>(type);
95 if (!ftype)
96 return type;
97 return mapLoweredType(ftype, fieldType);
98}
99
100/// Return true if the type is a 1d vector type or ground type.
103 .Case<BundleType>([&](auto bundle) { return false; })
104 .Case<FVectorType>([&](FVectorType vector) {
105 // When the size is 1, lower the vector into a scalar.
106 return vector.getElementType().isGround() &&
107 vector.getNumElements() > 1;
108 })
109 .Default([](auto groundType) { return true; });
110}
111
112// NOLINTBEGIN(misc-no-recursion)
113/// Return true if the type has a bundle type as subtype.
116 .Case<BundleType>([&](auto bundle) { return true; })
117 .Case<FVectorType>([&](FVectorType vector) {
118 return containsBundleType(vector.getElementType());
119 })
120 .Default([](auto groundType) { return false; });
121}
122// NOLINTEND(misc-no-recursion)
123
124/// Return true if we can preserve the type.
125static bool isPreservableAggregateType(Type type,
127 if (auto refType = type_dyn_cast<RefType>(type)) {
128 // Always preserve rwprobe's.
129 if (refType.getForceable())
130 return true;
131 // FIXME: Don't preserve read-only RefType for now. This is workaround for
132 // MemTap which causes type mismatches (issue 4479).
133 return false;
134 }
135
136 // Return false if no aggregate value is preserved.
137 if (mode == PreserveAggregate::None)
138 return false;
139
140 auto firrtlType = type_dyn_cast<FIRRTLBaseType>(type);
141 if (!firrtlType)
142 return false;
143
  // We can preserve the type iff (i) the type is passive, (ii) the type
  // doesn't contain analog, and (iii) the type doesn't contain zero bitwidth.
146 if (!firrtlType.isPassive() || firrtlType.containsAnalog() ||
147 hasZeroBitWidth(firrtlType))
148 return false;
149
150 switch (mode) {
152 return true;
154 return isOneDimVectorType(firrtlType);
156 return !containsBundleType(firrtlType);
157 default:
158 llvm_unreachable("unexpected mode");
159 }
160}
161
162/// Peel one layer of an aggregate type into its components. Type may be
163/// complex, but empty, in which case fields is empty, but the return is true.
164static bool peelType(Type type, SmallVectorImpl<FlatBundleFieldEntry> &fields,
166 // If the aggregate preservation is enabled and the type is preservable,
167 // then just return.
168 if (isPreservableAggregateType(type, mode))
169 return false;
170
171 if (auto refType = type_dyn_cast<RefType>(type))
172 type = refType.getType();
174 .Case<BundleType>([&](auto bundle) {
175 SmallString<16> tmpSuffix;
176 // Otherwise, we have a bundle type. Break it down.
177 for (size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
178 auto elt = bundle.getElement(i);
179 // Construct the suffix to pass down.
180 tmpSuffix.resize(0);
181 tmpSuffix.push_back('_');
182 tmpSuffix.append(elt.name.getValue());
183 fields.emplace_back(elt.type, i, bundle.getFieldID(i), tmpSuffix,
184 elt.isFlip);
185 }
186 return true;
187 })
188 .Case<FVectorType>([&](auto vector) {
189 // Increment the field ID to point to the first element.
190 for (size_t i = 0, e = vector.getNumElements(); i != e; ++i) {
191 fields.emplace_back(vector.getElementType(), i, vector.getFieldID(i),
192 "_" + std::to_string(i), false);
193 }
194 return true;
195 })
196 .Default([](auto op) { return false; });
197}
198
199/// Return if something is not a normal subaccess. Non-normal includes
200/// zero-length vectors and constant indexes (which are really subindexes).
201static bool isNotSubAccess(Operation *op) {
202 SubaccessOp sao = llvm::dyn_cast<SubaccessOp>(op);
203 if (!sao)
204 return true;
205 ConstantOp arg =
206 llvm::dyn_cast_or_null<ConstantOp>(sao.getIndex().getDefiningOp());
207 return arg && sao.getInput().getType().base().getNumElements() != 0;
208}
209
210/// Look through and collect subfields leading to a subaccess.
211static SmallVector<Operation *> getSAWritePath(Operation *op) {
212 SmallVector<Operation *> retval;
213 auto defOp = op->getOperand(0).getDefiningOp();
214 while (isa_and_nonnull<SubfieldOp, SubindexOp, SubaccessOp>(defOp)) {
215 retval.push_back(defOp);
216 defOp = defOp->getOperand(0).getDefiningOp();
217 }
218 // Trim to the subaccess
219 while (!retval.empty() && isNotSubAccess(retval.back()))
220 retval.pop_back();
221 return retval;
222}
223
224/// Clone memory for the specified field. Returns null op on error.
225static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op,
226 FlatBundleFieldEntry field) {
227 SmallVector<Type, 8> ports;
228 SmallVector<Attribute, 8> portNames;
229 SmallVector<Attribute, 8> portLocations;
230
231 auto oldPorts = op.getPorts();
232 for (size_t portIdx = 0, e = oldPorts.size(); portIdx < e; ++portIdx) {
233 auto port = oldPorts[portIdx];
234 ports.push_back(
235 MemOp::getTypeForPort(op.getDepth(), field.type, port.second));
236 portNames.push_back(port.first);
237 }
238
239 // It's easier to duplicate the old annotations, then fix and filter them.
240 auto newMem =
241 MemOp::create(*b, ports, op.getReadLatency(), op.getWriteLatency(),
242 op.getDepth(), op.getRuw(), b->getArrayAttr(portNames),
243 (op.getName() + field.suffix).str(), op.getNameKind(),
244 op.getAnnotations(), op.getPortAnnotations(),
245 op.getInnerSymAttr(), op.getInitAttr(), op.getPrefixAttr());
246
247 if (op.getInnerSym()) {
248 op.emitError("cannot split memory with symbol present");
249 return {};
250 }
251
252 SmallVector<Attribute> newAnnotations;
253 for (size_t portIdx = 0, e = newMem.getNumResults(); portIdx < e; ++portIdx) {
254 auto portType = type_cast<BundleType>(newMem.getResult(portIdx).getType());
255 auto oldPortType = type_cast<BundleType>(op.getResult(portIdx).getType());
256 SmallVector<Attribute> portAnno;
257 for (auto attr : newMem.getPortAnnotation(portIdx)) {
258 Annotation anno(attr);
259 if (auto annoFieldID = anno.getFieldID()) {
260 auto targetIndex = oldPortType.getIndexForFieldID(annoFieldID);
261
262 // Apply annotations to all elements if the target is the whole
263 // sub-field.
264 if (annoFieldID == oldPortType.getFieldID(targetIndex)) {
265 anno.setMember(
266 "circt.fieldID",
267 b->getI32IntegerAttr(portType.getFieldID(targetIndex)));
268 portAnno.push_back(anno.getDict());
269 continue;
270 }
271
272 // Handle aggregate sub-fields, including `(r/w)data` and `(w)mask`.
273 if (type_isa<BundleType>(oldPortType.getElement(targetIndex).type)) {
274 // Check whether the annotation falls into the range of the current
275 // field. Note that the `field` here is peeled from the `data`
276 // sub-field of the memory port, thus we need to add the fieldID of
277 // `data` or `mask` sub-field to get the "real" fieldID.
278 auto fieldID = field.fieldID + oldPortType.getFieldID(targetIndex);
279 if (annoFieldID >= fieldID &&
280 annoFieldID <=
281 fieldID + hw::FieldIdImpl::getMaxFieldID(field.type)) {
282 // Set the field ID of the new annotation.
283 auto newFieldID =
284 annoFieldID - fieldID + portType.getFieldID(targetIndex);
285 anno.setMember("circt.fieldID", b->getI32IntegerAttr(newFieldID));
286 portAnno.push_back(anno.getDict());
287 }
288 }
289 } else
290 portAnno.push_back(attr);
291 }
292 newAnnotations.push_back(b->getArrayAttr(portAnno));
293 }
294 newMem.setAllPortAnnotations(newAnnotations);
295 return newMem;
296}
297
298//===----------------------------------------------------------------------===//
299// Module Type Lowering
300//===----------------------------------------------------------------------===//
301namespace {
302
303struct AttrCache {
304 AttrCache(MLIRContext *context) {
305 i64ty = IntegerType::get(context, 64);
306 nameAttr = StringAttr::get(context, "name");
307 nameKindAttr = StringAttr::get(context, "nameKind");
308 sPortDirections = StringAttr::get(context, "portDirections");
309 sPortNames = StringAttr::get(context, "portNames");
310 sPortTypes = StringAttr::get(context, "portTypes");
311 sPortSymbols = StringAttr::get(context, "portSymbols");
312 sPortLocations = StringAttr::get(context, "portLocations");
313 sPortAnnotations = StringAttr::get(context, "portAnnotations");
314 sPortDomains = StringAttr::get(context, "domainInfo");
315 sEmpty = StringAttr::get(context, "");
316 aEmpty = ArrayAttr::get(context, {});
317 }
318 AttrCache(const AttrCache &) = default;
319
320 Type i64ty;
321 StringAttr nameAttr, nameKindAttr, sPortDirections, sPortNames, sPortTypes,
322 sPortSymbols, sPortLocations, sPortAnnotations, sPortDomains, sEmpty;
323 ArrayAttr aEmpty;
324};
325
326/// Helper class to handle domain lowering consistently across modules,
327/// extmodules, and instances. This class tracks domain port indices and
328/// provides methods to rewrite domain associations after port lowering.
329class DomainLoweringHelper {
330public:
331 /// Construct a helper by scanning the original port types for domain types.
332 /// For modules/extmodules, pass the port types attribute array.
333 /// For instances, pass the result types directly.
334 DomainLoweringHelper(MLIRContext *context, ArrayRef<Attribute> portTypes)
335 : context(context) {
336 for (auto [index, typeAttr] : llvm::enumerate(portTypes))
337 if (type_isa<DomainType>(cast<TypeAttr>(typeAttr).getValue()))
338 domainIndexByOrdinal.push_back(index);
339 }
340
341 /// Construct a helper by scanning instance result types for domain types.
342 DomainLoweringHelper(MLIRContext *context, TypeRange resultTypes)
343 : context(context) {
344 for (auto [index, type] : llvm::enumerate(resultTypes))
345 if (type_isa<DomainType>(type))
346 domainIndexByOrdinal.push_back(index);
347 }
348
349 /// Compute the mapping from old domain port indices to new port indices after
350 /// type lowering. Call this after ports have been lowered but before
351 /// rewriting domain associations. This overload takes a range of types
352 /// directly (e.g., from instances).
353 void computeDomainMap(TypeRange types) {
354 size_t i = 0, ord = 0;
355 for (auto type : types) {
356 if (type_isa<DomainType>(type))
357 domainMap[domainIndexByOrdinal[ord++]] = i;
358 ++i;
359 }
360 }
361
362 /// Compute the mapping from old domain port indices to new port indices after
363 /// type lowering. Call this after ports have been lowered but before
364 /// rewriting domain associations. This overload takes a range of PortInfo
365 /// and extracts types from them (e.g., from modules/extmodules).
366 void computeDomainMap(ArrayRef<PortInfo> ports) {
367 size_t i = 0, ord = 0;
368 for (const auto &port : ports) {
369 if (type_isa<DomainType>(port.type))
370 domainMap[domainIndexByOrdinal[ord++]] = i;
371 ++i;
372 }
373 }
374
375 /// Rewrite a domain attribute to use new port indices. The domain attribute
376 /// contains an array of port indices that need to be updated to reflect the
377 /// new port numbering after type lowering.
378 void rewriteDomain(Attribute &domain) {
379 auto oldAssociations = dyn_cast<ArrayAttr>(domain);
380 if (!oldAssociations)
381 return;
382 SmallVector<Attribute> newAssociations;
383 for (auto oldAttr : oldAssociations)
384 newAssociations.push_back(IntegerAttr::get(
385 IntegerType::get(context, 32, IntegerType::Unsigned),
386 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
387 domain = ArrayAttr::get(context, newAssociations);
388 }
389
390private:
391 MLIRContext *context;
392 /// Maps ordinal position of domain ports to their original indices.
393 SmallVector<unsigned> domainIndexByOrdinal;
394 /// Maps old port indices to new port indices after lowering.
395 DenseMap<unsigned, unsigned> domainMap;
396};
397
398// The visitors all return true if the operation should be deleted, false if
399// not.
400struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
401
402 TypeLoweringVisitor(
403 MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
404 Convention bodyConvention,
405 PreserveAggregate::PreserveMode memoryPreservationMode,
406 SymbolTable &symTbl, const AttrCache &cache,
407 const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
408 : context(context), defaultAggregatePreservationMode(preserveAggregate),
409 memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
410 cache(cache), conventionTable(conventionTable) {
411 bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
413 : defaultAggregatePreservationMode;
414 }
415 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
416 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
417 using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
418
419 /// If the referenced operation is a FModuleOp or an FExtModuleOp, perform
420 /// type lowering on all operations.
421 void lowerModule(FModuleLike op);
422
423 bool lowerArg(FModuleLike module, size_t argIndex, size_t argsRemoved,
424 SmallVectorImpl<PortInfo> &newArgs,
425 SmallVectorImpl<Value> &lowering);
426 std::pair<Value, PortInfo> addArg(Operation *module, unsigned insertPt,
427 unsigned insertPtOffset, FIRRTLType srcType,
428 const FlatBundleFieldEntry &field,
429 PortInfo &oldArg, hw::InnerSymAttr newSym);
430
431 // Helpers to manage state.
432 bool visitDecl(FExtModuleOp op);
433 bool visitDecl(FModuleOp op);
434 bool visitDecl(InstanceOp op);
435 bool visitDecl(InstanceChoiceOp op);
436 bool visitDecl(MemOp op);
437 bool visitDecl(NodeOp op);
438 bool visitDecl(RegOp op);
439 bool visitDecl(WireOp op);
440 bool visitDecl(RegResetOp op);
441 bool visitExpr(InvalidValueOp op);
442 bool visitExpr(SubaccessOp op);
443 bool visitExpr(VectorCreateOp op);
444 bool visitExpr(BundleCreateOp op);
445 bool visitExpr(ElementwiseAndPrimOp op);
446 bool visitExpr(ElementwiseOrPrimOp op);
447 bool visitExpr(ElementwiseXorPrimOp op);
448 bool visitExpr(MultibitMuxOp op);
449 bool visitExpr(MuxPrimOp op);
450 bool visitExpr(Mux2CellIntrinsicOp op);
451 bool visitExpr(Mux4CellIntrinsicOp op);
452 bool visitExpr(BitCastOp op);
453 bool visitExpr(RefSendOp op);
454 bool visitExpr(RefResolveOp op);
455 bool visitExpr(RefCastOp op);
456 bool visitStmt(ConnectOp op);
457 bool visitStmt(MatchingConnectOp op);
458 bool visitStmt(RefDefineOp op);
459 bool visitStmt(WhenOp op);
460 bool visitStmt(LayerBlockOp op);
461 bool visitUnrealizedConversionCast(mlir::UnrealizedConversionCastOp op);
462
463 bool isFailed() const { return encounteredError; }
464
465 bool visitInvalidOp(Operation *op) {
466 if (auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(op))
467 return visitUnrealizedConversionCast(castOp);
468 return false;
469 }
470
471private:
472 void processUsers(Value val, ArrayRef<Value> mapping);
473 bool processSAPath(Operation *);
474 void lowerBlock(Block *);
475 void lowerSAWritePath(Operation *, ArrayRef<Operation *> writePath);
476
477 /// Lower a "producer" operation one layer based on policy.
478 /// Use the provided \p clone function to generate individual ops for
479 /// the expanded subelements/fields. The type used to determine if lowering
480 /// is needed is either \p srcType if provided or from the assumed-to-exist
481 /// first result of the operation. When lowering, the clone callback will be
482 /// invoked with each subelement/field of this type.
483 bool lowerProducer(
484 Operation *op,
485 llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
486 Type srcType = {});
487
  /// Filter out and return the \p annotations whose target includes \p field,
  /// adjusting their fieldIDs as needed to be relative to \p field.
490 ArrayAttr filterAnnotations(MLIRContext *ctxt, ArrayAttr annotations,
491 FIRRTLType srcType, FlatBundleFieldEntry field);
492
493 /// Partition inner symbols on given type. Fails if any symbols
494 /// cannot be assigned to a field, such as inner symbol on root.
495 LogicalResult partitionSymbols(hw::InnerSymAttr sym, FIRRTLType parentType,
496 SmallVectorImpl<hw::InnerSymAttr> &newSyms,
497 Location errorLoc);
498
500 getPreservationModeForPorts(FModuleLike moduleLike);
501 Value getSubWhatever(Value val, size_t index);
502
503 /// Helper function to lower instance-like operations (InstanceOp and
504 /// InstanceChoiceOp).
505 bool lowerInstanceLike(FInstanceLike op, PreserveAggregate::PreserveMode mode,
506 ArrayAttr oldPortAnno,
507 llvm::function_ref<Operation *(
508 ArrayRef<Type>, ArrayRef<Direction>, ArrayAttr,
509 ArrayAttr, ArrayAttr, hw::InnerSymAttr)>
510 createNewInstance);
511
512 size_t uniqueIdx = 0;
513 std::string uniqueName() {
514 auto myID = uniqueIdx++;
515 return (Twine("__GEN_") + Twine(myID)).str();
516 }
517
518 MLIRContext *context;
519
520 /// Aggregate preservation mode.
521 PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
522 PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
523 PreserveAggregate::PreserveMode memoryPreservationMode;
524
525 /// The builder is set and maintained in the main loop.
526 ImplicitLocOpBuilder *builder;
527
528 // Keep a symbol table around for resolving symbols
529 SymbolTable &symTbl;
530
531 // Cache some attributes
532 const AttrCache &cache;
533
534 const llvm::DenseMap<FModuleLike, Convention> &conventionTable;
535
536 // Set true if the lowering failed.
537 bool encounteredError = false;
538};
539} // namespace
540
541/// Return aggregate preservation mode for the module ports. If the module has a
/// scalarized linkage, then we may not preserve its aggregate ports.
544TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
545 auto lookup = conventionTable.find(module);
546 if (lookup == conventionTable.end())
547 return defaultAggregatePreservationMode;
548 switch (lookup->second) {
549 case Convention::Scalarized:
551 case Convention::Internal:
552 return defaultAggregatePreservationMode;
553 }
554 llvm_unreachable("Unknown convention");
555 return defaultAggregatePreservationMode;
556}
557
558Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
559 if (type_isa<BundleType>(val.getType()))
560 return SubfieldOp::create(*builder, val, index);
561 if (type_isa<FVectorType>(val.getType()))
562 return SubindexOp::create(*builder, val, index);
563 if (type_isa<RefType>(val.getType()))
564 return RefSubOp::create(*builder, val, index);
565 llvm_unreachable("Unknown aggregate type");
566 return nullptr;
567}
568
569/// Conditionally expand a subaccessop write path
570bool TypeLoweringVisitor::processSAPath(Operation *op) {
571 // Does this LHS have a subaccessop?
572 SmallVector<Operation *> writePath = getSAWritePath(op);
573 if (writePath.empty())
574 return false;
575
576 lowerSAWritePath(op, writePath);
577 // Unhook the writePath from the connect. This isn't the right type, but we
578 // are deleting the op anyway.
579 op->eraseOperands(0, 2);
580 // See how far up the tree we can delete things.
581 for (size_t i = 0; i < writePath.size(); ++i) {
582 if (writePath[i]->use_empty()) {
583 writePath[i]->erase();
584 } else {
585 break;
586 }
587 }
588 return true;
589}
590
/// Lower every operation in \p block, visiting them from last to first.
/// Each op is handed to dispatchVisitor; if the visitor returns true the op is
/// erased right after it has been visited.
void TypeLoweringVisitor::lowerBlock(Block *block) {
  // Lower the operations bottom up.
  for (auto it = block->rbegin(), e = block->rend(); it != e;) {
    auto &iop = *it;
    // Ops created by the visitor are inserted immediately before `iop` and
    // take its location.
    builder->setInsertionPoint(&iop);
    builder->setLoc(iop.getLoc());
    bool removeOp = dispatchVisitor(&iop);
    // Advance only AFTER visiting. The op now preceding `iop` is typically the
    // last one the visitor inserted, so the new ops are visited next; this is
    // how nested aggregates get peeled layer by layer. Any preceding ops the
    // visitor erased are already unlinked at this point. Advancing before
    // erasing `iop` keeps `it` valid. (An early-increment range would skip the
    // freshly inserted ops.)
    ++it;
    // Erase old ops eagerly so we don't have dangling uses we've already
    // lowered.
    if (removeOp)
      iop.erase();
  }
}
605
606ArrayAttr TypeLoweringVisitor::filterAnnotations(MLIRContext *ctxt,
607 ArrayAttr annotations,
608 FIRRTLType srcType,
609 FlatBundleFieldEntry field) {
610 SmallVector<Attribute> retval;
611 if (!annotations || annotations.empty())
612 return ArrayAttr::get(ctxt, retval);
613 for (auto opAttr : annotations) {
614 Annotation anno(opAttr);
615 auto fieldID = anno.getFieldID();
616 anno.removeMember("circt.fieldID");
617
618 // If no fieldID set, or points to root, forward the annotation without the
619 // fieldID field (which was removed above).
620 if (fieldID == 0) {
621 retval.push_back(anno.getAttr());
622 continue;
623 }
624 // Check whether the annotation falls into the range of the current field.
625
626 if (fieldID < field.fieldID ||
627 fieldID > field.fieldID + hw::FieldIdImpl::getMaxFieldID(field.type))
628 continue;
629
630 // Add fieldID back if non-zero relative to this field.
631 if (auto newFieldID = fieldID - field.fieldID) {
632 // If the target is a subfield/subindex of the current field, create a
633 // new annotation with the correct circt.fieldID.
634 anno.setMember("circt.fieldID", builder->getI32IntegerAttr(newFieldID));
635 }
636
637 retval.push_back(anno.getAttr());
638 }
639 return ArrayAttr::get(ctxt, retval);
640}
641
642LogicalResult TypeLoweringVisitor::partitionSymbols(
643 hw::InnerSymAttr sym, FIRRTLType parentType,
644 SmallVectorImpl<hw::InnerSymAttr> &newSyms, Location errorLoc) {
645
646 // No symbol, nothing to partition.
647 if (!sym || sym.empty())
648 return success();
649
650 auto *context = sym.getContext();
651
652 auto baseType = getBaseType(parentType);
653 if (!baseType)
654 return mlir::emitError(errorLoc,
655 "unable to partition symbol on unsupported type ")
656 << parentType;
657
658 return TypeSwitch<FIRRTLType, LogicalResult>(baseType)
659 .Case<BundleType, FVectorType>([&](auto aggType) -> LogicalResult {
660 struct BinningInfo {
661 uint64_t index;
662 uint64_t relFieldID;
663 hw::InnerSymPropertiesAttr prop;
664 };
665
666 // Walk each inner symbol, compute binning information/assignment.
667 SmallVector<BinningInfo> binning;
668 for (auto prop : sym) {
669 auto fieldID = prop.getFieldID();
670 // Special-case fieldID == 0, helper methods require non-zero fieldID.
671 if (fieldID == 0)
672 return mlir::emitError(errorLoc, "unable to lower due to symbol ")
673 << prop.getName()
674 << " with target not preserved by lowering";
675 auto [index, relFieldID] = aggType.getIndexAndSubfieldID(fieldID);
676 binning.push_back({index, relFieldID, prop});
677 }
678
679 // Sort by index, fieldID.
680 llvm::stable_sort(binning, [&](auto &lhs, auto &rhs) {
681 return std::tuple(lhs.index, lhs.relFieldID) <
682 std::tuple(rhs.index, rhs.relFieldID);
683 });
684 assert(!binning.empty());
685
686 // Populate newSyms, group all symbols on same index.
687 newSyms.resize(aggType.getNumElements());
688 for (auto binIt = binning.begin(), binEnd = binning.end();
689 binIt != binEnd;) {
690 auto curIndex = binIt->index;
691 SmallVector<hw::InnerSymPropertiesAttr> propsForIndex;
692 // Gather all adjacent symbols for this index.
693 while (binIt != binEnd && binIt->index == curIndex) {
694 propsForIndex.push_back(hw::InnerSymPropertiesAttr::get(
695 context, binIt->prop.getName(), binIt->relFieldID,
696 binIt->prop.getSymVisibility()));
697 ++binIt;
698 }
699
700 assert(!newSyms[curIndex]);
701 newSyms[curIndex] = hw::InnerSymAttr::get(context, propsForIndex);
702 }
703 return success();
704 })
705 .Default([&](auto ty) {
706 return mlir::emitError(
707 errorLoc, "unable to partition symbol on unsupported type ")
708 << ty;
709 });
710}
711
/// Split the aggregate result of the "producer" \p op one layer deep.
///
/// The type is \p srcType if given, else the type of result 0. If it is not a
/// FIRRTL type, or peelType decides it should not be split under
/// bodyAggregatePreservationMode, nothing happens and false is returned.
/// Otherwise \p clone is called once per peeled element, with that element's
/// filtered annotations, to build the replacement value. Each new defining op
/// receives the suffixed name, the name kind, any partitioned inner symbol,
/// and the original op's discardable attributes. Users of result 0 are then
/// rewired by processUsers and true is returned. If inner symbols cannot be
/// partitioned, encounteredError is set and false is returned.
bool TypeLoweringVisitor::lowerProducer(
    Operation *op,
    llvm::function_ref<Value(const FlatBundleFieldEntry &, ArrayAttr)> clone,
    Type srcType) {

  if (!srcType)
    srcType = op->getResult(0).getType();
  auto srcFType = type_dyn_cast<FIRRTLType>(srcType);
  if (!srcFType)
    return false;
  SmallVector<FlatBundleFieldEntry, 8> fieldTypes;

  if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
    return false;

  // Replacement values, one per peeled element, in element order.
  SmallVector<Value> lowered;
  // Name of the current element: the op's name (if any) plus the element's
  // suffix. It is rebuilt in place from `baseNameLen` for each element.
  SmallString<16> loweredName;
  auto nameKindAttr = op->getAttrOfType<NameKindEnumAttr>(cache.nameKindAttr);

  if (auto nameAttr = op->getAttrOfType<StringAttr>(cache.nameAttr))
    loweredName = nameAttr.getValue();
  auto baseNameLen = loweredName.size();
  auto oldAnno = dyn_cast_or_null<ArrayAttr>(op->getAttr("annotations"));

  // One (possibly null) inner symbol per element, filled by partitionSymbols.
  SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
  if (auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op)) {
    if (failed(partitionSymbols(symOp.getInnerSymAttr(), srcFType, fieldSyms,
                                symOp.getLoc()))) {
      encounteredError = true;
      return false;
    }
  }

  for (const auto &[field, sym] : llvm::zip_equal(fieldTypes, fieldSyms)) {
    if (!loweredName.empty()) {
      loweredName.resize(baseNameLen);
      loweredName += field.suffix;
    }

    // For all annotations on the parent op, filter them based on the target
    // attribute.
    ArrayAttr loweredAttrs =
        filterAnnotations(context, oldAnno, srcFType, field);
    // The callback creates the new op(s) through the shared builder.
    auto newVal = clone(field, loweredAttrs);

    // If inner symbols on this field, add to new op.
    if (sym) {
      // Splitting up something with symbols on it should lower to ops
      // that also can have symbols on them.
      auto newSymOp = newVal.getDefiningOp<hw::InnerSymbolOpInterface>();
      assert(
          newSymOp &&
          "op with inner symbol lowered to op that cannot take inner symbol");
      newSymOp.setInnerSymbolAttr(sym);
    }

    // Carry over the name, if present.
    if (auto *newOp = newVal.getDefiningOp()) {
      if (!loweredName.empty())
        newOp->setAttr(cache.nameAttr, StringAttr::get(context, loweredName));
      if (nameKindAttr)
        newOp->setAttr(cache.nameKindAttr, nameKindAttr);

      // Clone discardable attributes as well.
      // NOTE(review): setDiscardableAttrs replaces newOp's whole discardable
      // dictionary. If `name`/`nameKind` are not inherent attributes of
      // newOp, the values set just above would be overwritten here. Confirm
      // this ordering is intended.
      newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
    }
    lowered.push_back(newVal);
  }

  processUsers(op->getResult(0), lowered);
  return true;
}
785
786void TypeLoweringVisitor::processUsers(Value val, ArrayRef<Value> mapping) {
787 for (auto *user : llvm::make_early_inc_range(val.getUsers())) {
788 TypeSwitch<Operation *, void>(user)
789 .Case<SubindexOp>([mapping](SubindexOp sio) {
790 Value repl = mapping[sio.getIndex()];
791 sio.replaceAllUsesWith(repl);
792 sio.erase();
793 })
794 .Case<SubfieldOp>([mapping](SubfieldOp sfo) {
795 // Get the input bundle type.
796 Value repl = mapping[sfo.getFieldIndex()];
797 sfo.replaceAllUsesWith(repl);
798 sfo.erase();
799 })
800 .Case<RefSubOp>([mapping](RefSubOp refSub) {
801 Value repl = mapping[refSub.getIndex()];
802 refSub.replaceAllUsesWith(repl);
803 refSub.erase();
804 })
805 .Default([&](auto op) {
806 // This means we have already processed the user, and it didn't lower
807 // its inputs. This is an opaque user, which will continue to have
808 // aggregate type as input, even after LowerTypes. So, construct the
809 // vector/bundle back from the lowered elements to ensure a valid
810 // input into the opaque op. This only supports Bundles and Vectors.
811
812 // This builder ensures that the aggregate construction happens at the
813 // user location, and the LowerTypes algorithm will not touch them any
814 // more, because LowerTypes was reverse iterating on the block and the
815 // user has already been processed.
816 ImplicitLocOpBuilder b(user->getLoc(), user);
817
818 // This shouldn't happen (non-FIRRTLBaseType's in lowered types, or
819 // refs), check explicitly here for clarity/early detection.
820 assert(llvm::none_of(mapping, [](auto v) {
821 auto fbasetype = type_dyn_cast<FIRRTLBaseType>(v.getType());
822 return !fbasetype || fbasetype.containsReference();
823 }));
824
825 Value input =
826 TypeSwitch<Type, Value>(val.getType())
827 .template Case<FVectorType>([&](auto vecType) {
828 return b.createOrFold<VectorCreateOp>(vecType, mapping);
829 })
830 .template Case<BundleType>([&](auto bundleType) {
831 return b.createOrFold<BundleCreateOp>(bundleType, mapping);
832 })
833 .Default([&](auto _) -> Value { return {}; });
834 if (!input) {
835 user->emitError("unable to reconstruct source of type ")
836 << val.getType();
837 encounteredError = true;
838 return;
839 }
840 user->replaceUsesOfWith(val, input);
841 });
842 }
843}
844
845/// Helper function to remove elements from a vector based on a BitVector mask.
846template <typename T>
847static void eraseElementsAtIndices(SmallVectorImpl<T> &vec,
848 const llvm::BitVector &removalMask) {
849 size_t writeIndex = 0, readIndex = 0;
850
851 // Iterate over each set bit (element to remove) in the mask.
852 // Between each removal point, we bulk-copy the range of elements to keep.
853 for (size_t removalIndex : removalMask.set_bits()) {
854 // Copy the range [readIndex, removalIndex) - these are elements to keep.
855 assert(removalIndex >= readIndex && "removal index before read index");
856 size_t rangeSize = removalIndex - readIndex;
857 if (rangeSize > 0) {
858 // Bulk move the range of elements to keep to the write position.
859 // Skip if the read and write positions are the same (= the first
860 // iteration).
861 if (writeIndex != readIndex)
862 std::move(vec.begin() + readIndex, vec.begin() + removalIndex,
863 vec.begin() + writeIndex);
864 writeIndex += rangeSize;
865 }
866 readIndex = removalIndex + 1;
867 }
868
869 // Copy any remaining elements after the last removal point.
870 size_t remainingSize = vec.size() - readIndex;
871 if (remainingSize > 0) {
872 if (writeIndex != readIndex)
873 std::move(vec.begin() + readIndex, vec.end(), vec.begin() + writeIndex);
874 writeIndex += remainingSize;
875 }
876
877 // Truncate the vector to the new size (number of elements kept).
878 vec.truncate(writeIndex);
879}
880
881void TypeLoweringVisitor::lowerModule(FModuleLike op) {
882 if (auto module = llvm::dyn_cast<FModuleOp>(*op))
883 visitDecl(module);
884 else if (auto extModule = llvm::dyn_cast<FExtModuleOp>(*op))
885 visitDecl(extModule);
886}
887
888// Creates and returns a new block argument of the specified type to the
889// module. This also maintains the name attribute for the new argument,
890// possibly with a new suffix appended.
891std::pair<Value, PortInfo>
892TypeLoweringVisitor::addArg(Operation *module, unsigned insertPt,
893 unsigned insertPtOffset, FIRRTLType srcType,
894 const FlatBundleFieldEntry &field, PortInfo &oldArg,
895 hw::InnerSymAttr newSym) {
896 Value newValue;
897 FIRRTLType fieldType = mapLoweredType(srcType, field.type);
898 if (auto mod = llvm::dyn_cast<FModuleOp>(module)) {
899 Block *body = mod.getBodyBlock();
900 // Append the new argument.
901 newValue = body->insertArgument(insertPt, fieldType, oldArg.loc);
902 }
903
904 // Save the name attribute for the new argument.
905 auto name = builder->getStringAttr(oldArg.name.getValue() + field.suffix);
906
907 // Populate the new arg attributes.
908 auto newAnnotations = filterAnnotations(
909 context, oldArg.annotations.getArrayAttr(), srcType, field);
910 // Flip the direction if the field is an output.
911 auto direction = (Direction)((unsigned)oldArg.direction ^ field.isOutput);
912
913 return std::make_pair(
914 newValue, PortInfo{name, fieldType, direction, newSym, oldArg.loc,
915 AnnotationSet(newAnnotations), oldArg.domains});
916}
917
918// Lower arguments with bundle type by flattening them.
919bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
920 size_t argsRemoved,
921 SmallVectorImpl<PortInfo> &newArgs,
922 SmallVectorImpl<Value> &lowering) {
923
924 // Flatten any bundle types.
925 SmallVector<FlatBundleFieldEntry> fieldTypes;
926 auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].type);
927 if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
928 return false;
929
930 SmallVector<hw::InnerSymAttr> fieldSyms(fieldTypes.size());
931 if (failed(partitionSymbols(newArgs[argIndex].sym, srcType, fieldSyms,
932 newArgs[argIndex].loc))) {
933 encounteredError = true;
934 return false;
935 }
936
937 for (const auto &[idx, field, fieldSym] :
938 llvm::enumerate(fieldTypes, fieldSyms)) {
939 auto newValue = addArg(module, 1 + argIndex + idx, argsRemoved, srcType,
940 field, newArgs[argIndex], fieldSym);
941 newArgs.insert(newArgs.begin() + 1 + argIndex + idx, newValue.second);
942 // Lower any other arguments by copying them to keep the relative order.
943 lowering.push_back(newValue.first);
944 }
945 return true;
946}
947
948static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op,
949 Value rhs) {
950 if (auto rop = llvm::dyn_cast<SubfieldOp>(op))
951 return SubfieldOp::create(*builder, rhs, rop.getFieldIndex());
952 if (auto rop = llvm::dyn_cast<SubindexOp>(op))
953 return SubindexOp::create(*builder, rhs, rop.getIndex());
954 if (auto rop = llvm::dyn_cast<SubaccessOp>(op))
955 return SubaccessOp::create(*builder, rhs, rop.getIndex());
956 op->emitError("Unknown accessor");
957 return nullptr;
958}
959
960void TypeLoweringVisitor::lowerSAWritePath(Operation *op,
961 ArrayRef<Operation *> writePath) {
962 SubaccessOp sao = cast<SubaccessOp>(writePath.back());
963 FVectorType saoType = sao.getInput().getType();
964 auto selectWidth = llvm::Log2_64_Ceil(saoType.getNumElements());
965
966 for (size_t index = 0, e = saoType.getNumElements(); index < e; ++index) {
967 auto cond = EQPrimOp::create(
968 *builder, sao.getIndex(),
969 builder->createOrFold<ConstantOp>(UIntType::get(context, selectWidth),
970 APInt(selectWidth, index)));
971 WhenOp::create(*builder, cond, false, [&]() {
972 // Recreate the write Path
973 Value leaf = SubindexOp::create(*builder, sao.getInput(), index);
974 for (int i = writePath.size() - 2; i >= 0; --i) {
975 if (auto access = cloneAccess(builder, writePath[i], leaf))
976 leaf = access;
977 else {
978 encounteredError = true;
979 return;
980 }
981 }
982
983 emitConnect(*builder, leaf, op->getOperand(1));
984 });
985 }
986}
987
988// Expand connects of aggregates
989bool TypeLoweringVisitor::visitStmt(ConnectOp op) {
990 if (processSAPath(op))
991 return true;
992
993 // Attempt to get the bundle types.
994 SmallVector<FlatBundleFieldEntry> fields;
995
996 // We have to expand connections even if the aggregate preservation is true.
997 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
998 return false;
999
1000 // Loop over the leaf aggregates.
1001 for (const auto &field : llvm::enumerate(fields)) {
1002 Value src = getSubWhatever(op.getSrc(), field.index());
1003 Value dest = getSubWhatever(op.getDest(), field.index());
1004 if (field.value().isOutput)
1005 std::swap(src, dest);
1006 emitConnect(*builder, dest, src);
1007 }
1008 return true;
1009}
1010
1011// Expand connects of aggregates
1012bool TypeLoweringVisitor::visitStmt(MatchingConnectOp op) {
1013 if (processSAPath(op))
1014 return true;
1015
1016 // Attempt to get the bundle types.
1017 SmallVector<FlatBundleFieldEntry> fields;
1018
1019 // We have to expand connections even if the aggregate preservation is true.
1020 if (!peelType(op.getDest().getType(), fields, PreserveAggregate::None))
1021 return false;
1022
1023 // Loop over the leaf aggregates.
1024 for (const auto &field : llvm::enumerate(fields)) {
1025 Value src = getSubWhatever(op.getSrc(), field.index());
1026 Value dest = getSubWhatever(op.getDest(), field.index());
1027 if (field.value().isOutput)
1028 std::swap(src, dest);
1029 MatchingConnectOp::create(*builder, dest, src);
1030 }
1031 return true;
1032}
1033
1034// Expand connects of references-of-aggregates
1035bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
1036 // Attempt to get the bundle types.
1037 SmallVector<FlatBundleFieldEntry> fields;
1038
1039 if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
1040 return false;
1041
1042 // Loop over the leaf aggregates.
1043 for (const auto &field : llvm::enumerate(fields)) {
1044 Value src = getSubWhatever(op.getSrc(), field.index());
1045 Value dest = getSubWhatever(op.getDest(), field.index());
1046 assert(!field.value().isOutput && "unexpected flip in reftype destination");
1047 RefDefineOp::create(*builder, dest, src);
1048 }
1049 return true;
1050}
1051
1052bool TypeLoweringVisitor::visitStmt(WhenOp op) {
1053 // The WhenOp itself does not require any lowering, the only value it uses
1054 // is a one-bit predicate. Recursively visit all regions so internal
1055 // operations are lowered.
1056
1057 // Visit operations in the then block.
1058 lowerBlock(&op.getThenBlock());
1059
1060 // Visit operations in the else block.
1061 if (op.hasElseRegion())
1062 lowerBlock(&op.getElseBlock());
1063 return false; // don't delete the when!
1064}
1065
1066/// Lower any types declared in layer blocks.
1067bool TypeLoweringVisitor::visitStmt(LayerBlockOp op) {
1068 lowerBlock(op.getBody());
1069 return false;
1070}
1071
1072/// Lower memory operations. A new memory is created for every leaf
1073/// element in a memory's data type.
1074bool TypeLoweringVisitor::visitDecl(MemOp op) {
1075 // Attempt to get the bundle types.
1076 SmallVector<FlatBundleFieldEntry> fields;
1077
1078 // MemOp should have ground types so we can't preserve aggregates.
1079 if (!peelType(op.getDataType(), fields, memoryPreservationMode))
1080 return false;
1081
1082 if (op.getInnerSym()) {
1083 op->emitError() << "has a symbol, but no symbols may exist on aggregates "
1084 "passed through LowerTypes";
1085 encounteredError = true;
1086 return false;
1087 }
1088
1089 SmallVector<MemOp> newMemories;
1090 SmallVector<WireOp> oldPorts;
1091
1092 // Wires for old ports
1093 for (unsigned int index = 0, end = op.getNumResults(); index < end; ++index) {
1094 auto result = op.getResult(index);
1095 if (op.getPortKind(index) == MemOp::PortKind::Debug) {
1096 op.emitOpError("cannot lower memory with debug port");
1097 encounteredError = true;
1098 return false;
1099 }
1100 auto wire =
1101 WireOp::create(*builder, result.getType(),
1102 (op.getName() + "_" + op.getPortName(index)).str());
1103 oldPorts.push_back(wire);
1104 result.replaceAllUsesWith(wire.getResult());
1105 }
1106 // If annotations targeting fields of an aggregate are present, we cannot
1107 // flatten the memory. It must be split into one memory per aggregate field.
1108 // Do not overwrite the pass flag!
1109
1110 // Memory for each field
1111 for (const auto &field : fields) {
1112 auto newMemForField = cloneMemWithNewType(builder, op, field);
1113 if (!newMemForField) {
1114 op.emitError("failed cloning memory for field");
1115 encounteredError = true;
1116 return false;
1117 }
1118 newMemories.push_back(newMemForField);
1119 }
1120 // Hook up the new memories to the wires the old memory was replaced with.
1121 for (size_t index = 0, rend = op.getNumResults(); index < rend; ++index) {
1122 auto result = oldPorts[index].getResult();
1123 auto rType = type_cast<BundleType>(result.getType());
1124 for (size_t fieldIndex = 0, fend = rType.getNumElements();
1125 fieldIndex != fend; ++fieldIndex) {
1126 auto name = rType.getElement(fieldIndex).name.getValue();
1127 auto oldField = SubfieldOp::create(*builder, result, fieldIndex);
1128 // data and mask depend on the memory type which was split. They can also
1129 // go both directions, depending on the port direction.
1130 if (name == "data" || name == "mask" || name == "wdata" ||
1131 name == "wmask" || name == "rdata") {
1132 for (const auto &field : fields) {
1133 auto realOldField = getSubWhatever(oldField, field.index);
1134 auto newField = getSubWhatever(
1135 newMemories[field.index].getResult(index), fieldIndex);
1136 if (rType.getElement(fieldIndex).isFlip)
1137 std::swap(realOldField, newField);
1138 emitConnect(*builder, newField, realOldField);
1139 }
1140 } else {
1141 for (auto mem : newMemories) {
1142 auto newField =
1143 SubfieldOp::create(*builder, mem.getResult(index), fieldIndex);
1144 emitConnect(*builder, newField, oldField);
1145 }
1146 }
1147 }
1148 }
1149 return true;
1150}
1151
1152bool TypeLoweringVisitor::visitDecl(FExtModuleOp extModule) {
1153 ImplicitLocOpBuilder theBuilder(extModule.getLoc(), context);
1154 builder = &theBuilder;
1155
1156 // Top level builder
1157 OpBuilder builder(context);
1158
1159 // Lower the module block arguments.
1160 llvm::BitVector argsToRemove;
1161 auto newArgs = extModule.getPorts();
1162 argsToRemove.reserve(newArgs.size());
1163
1164 DomainLoweringHelper domainHelper(context, extModule.getPortTypes());
1165
1166 size_t argsRemoved = 0;
1167 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1168 SmallVector<Value> lowering;
1169 if (lowerArg(extModule, argIndex, argsRemoved, newArgs, lowering)) {
1170 argsToRemove.push_back(true);
1171 ++argsRemoved;
1172 } else {
1173 argsToRemove.push_back(false);
1174 }
1175 // lowerArg might have invalidated any reference to newArgs, be careful
1176 }
1177
1178 // Remove block args that have been lowered.
1179 if (argsRemoved != 0)
1180 eraseElementsAtIndices(newArgs, argsToRemove);
1181
1182 domainHelper.computeDomainMap(newArgs);
1183
1184 SmallVector<NamedAttribute, 8> newModuleAttrs;
1185
1186 // Copy over any attributes that weren't original argument attributes.
1187 for (auto attr : extModule->getAttrDictionary())
1188 // Drop old "portNames", directions, and argument attributes. These are
1189 // handled differently below.
1190 if (attr.getName() != "portDirections" && attr.getName() != "portNames" &&
1191 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1192 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1193 newModuleAttrs.push_back(attr);
1194
1195 SmallVector<Direction> newArgDirections;
1196 SmallVector<Attribute> newArgNames;
1197 SmallVector<Attribute, 8> newArgTypes;
1198 SmallVector<Attribute, 8> newArgSyms;
1199 SmallVector<Attribute, 8> newArgLocations;
1200 SmallVector<Attribute, 8> newArgAnnotations;
1201 SmallVector<Attribute, 8> newArgDomains;
1202
1203 for (auto &port : newArgs) {
1204 newArgDirections.push_back(port.direction);
1205 newArgNames.push_back(port.name);
1206 newArgTypes.push_back(TypeAttr::get(port.type));
1207 newArgSyms.push_back(port.sym);
1208 newArgLocations.push_back(port.loc);
1209 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1210 if (port.domains) {
1211 domainHelper.rewriteDomain(port.domains);
1212 } else {
1213 port.domains = cache.aEmpty;
1214 }
1215 newArgDomains.push_back(port.domains);
1216 }
1217
1218 newModuleAttrs.push_back(
1219 NamedAttribute(cache.sPortDirections,
1220 direction::packAttribute(context, newArgDirections)));
1221
1222 newModuleAttrs.push_back(
1223 NamedAttribute(cache.sPortNames, builder.getArrayAttr(newArgNames)));
1224
1225 newModuleAttrs.push_back(
1226 NamedAttribute(cache.sPortTypes, builder.getArrayAttr(newArgTypes)));
1227
1228 newModuleAttrs.push_back(NamedAttribute(
1229 cache.sPortLocations, builder.getArrayAttr(newArgLocations)));
1230
1231 newModuleAttrs.push_back(NamedAttribute(
1232 cache.sPortAnnotations, builder.getArrayAttr(newArgAnnotations)));
1233
1234 newModuleAttrs.push_back(
1235 NamedAttribute(cache.sPortDomains, builder.getArrayAttr(newArgDomains)));
1236
1237 // Update the module's attributes.
1238 extModule->setAttrs(newModuleAttrs);
1239 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1240 extModule.setPortSymbols(newArgSyms);
1241
1242 return false;
1243}
1244
1245bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
1246 auto *body = module.getBodyBlock();
1247
1248 ImplicitLocOpBuilder theBuilder(module.getLoc(), context);
1249 builder = &theBuilder;
1250
1251 // Lower the operations.
1252 lowerBlock(body);
1253
1254 // Lower the module block arguments.
1255 llvm::BitVector argsToRemove;
1256 auto newArgs = module.getPorts();
1257 argsToRemove.reserve(newArgs.size());
1258
1259 DomainLoweringHelper domainHelper(context, module.getPortTypes());
1260
1261 size_t argsRemoved = 0;
1262 for (size_t argIndex = 0; argIndex < newArgs.size(); ++argIndex) {
1263 SmallVector<Value> lowerings;
1264 if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
1265 auto arg = module.getArgument(argIndex);
1266 processUsers(arg, lowerings);
1267 argsToRemove.push_back(true);
1268 ++argsRemoved;
1269 } else
1270 argsToRemove.push_back(false);
1271 // lowerArg might have invalidated any reference to newArgs, be careful
1272 }
1273
1274 // Remove block args that have been lowered.
1275 if (argsRemoved != 0) {
1276 body->eraseArguments(argsToRemove);
1277 eraseElementsAtIndices(newArgs, argsToRemove);
1278 }
1279
1280 domainHelper.computeDomainMap(newArgs);
1281
1282 SmallVector<NamedAttribute, 8> newModuleAttrs;
1283
1284 // Copy over any attributes that weren't original argument attributes.
1285 for (auto attr : module->getAttrDictionary())
1286 // Drop old "portNames", directions, and argument attributes. These are
1287 // handled differently below.
1288 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
1289 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
1290 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
1291 newModuleAttrs.push_back(attr);
1292
1293 SmallVector<Direction> newArgDirections;
1294 SmallVector<Attribute> newArgNames;
1295 SmallVector<Attribute> newArgTypes;
1296 SmallVector<Attribute> newArgSyms;
1297 SmallVector<Attribute> newArgLocations;
1298 SmallVector<Attribute, 8> newArgAnnotations;
1299 SmallVector<Attribute> newPortDomains;
1300 for (auto &port : newArgs) {
1301 newArgDirections.push_back(port.direction);
1302 newArgNames.push_back(port.name);
1303 newArgTypes.push_back(TypeAttr::get(port.type));
1304 newArgSyms.push_back(port.sym);
1305 newArgLocations.push_back(port.loc);
1306 newArgAnnotations.push_back(port.annotations.getArrayAttr());
1307 if (port.domains) {
1308 domainHelper.rewriteDomain(port.domains);
1309 } else {
1310 port.domains = cache.aEmpty;
1311 }
1312 newPortDomains.push_back(port.domains);
1313 }
1314
1315 newModuleAttrs.push_back(
1316 NamedAttribute(cache.sPortDirections,
1317 direction::packAttribute(context, newArgDirections)));
1318
1319 newModuleAttrs.push_back(
1320 NamedAttribute(cache.sPortNames, builder->getArrayAttr(newArgNames)));
1321
1322 newModuleAttrs.push_back(
1323 NamedAttribute(cache.sPortTypes, builder->getArrayAttr(newArgTypes)));
1324
1325 newModuleAttrs.push_back(NamedAttribute(
1326 cache.sPortLocations, builder->getArrayAttr(newArgLocations)));
1327
1328 newModuleAttrs.push_back(NamedAttribute(
1329 cache.sPortAnnotations, builder->getArrayAttr(newArgAnnotations)));
1330
1331 newModuleAttrs.push_back(NamedAttribute(
1332 cache.sPortDomains, builder->getArrayAttr(newPortDomains)));
1333
1334 // Update the module's attributes.
1335 module->setAttrs(newModuleAttrs);
1336 FModuleLike::fixupPortSymsArray(newArgSyms, context);
1337 module.setPortSymbols(newArgSyms);
1338 return false;
1339}
1340
1341/// Lower a wire op with a bundle to multiple non-bundled wires.
1342bool TypeLoweringVisitor::visitDecl(WireOp op) {
1343 if (op.isForceable())
1344 return false;
1345
1346 auto clone = [&](const FlatBundleFieldEntry &field,
1347 ArrayAttr attrs) -> Value {
1348 return WireOp::create(*builder,
1349 mapLoweredType(op.getDataRaw().getType(), field.type),
1350 "", NameKindEnum::DroppableName, attrs, StringAttr{})
1351 .getResult();
1352 };
1353 return lowerProducer(op, clone);
1354}
1355
1356/// Lower a reg op with a bundle to multiple non-bundled regs.
1357bool TypeLoweringVisitor::visitDecl(RegOp op) {
1358 if (op.isForceable())
1359 return false;
1360
1361 auto clone = [&](const FlatBundleFieldEntry &field,
1362 ArrayAttr attrs) -> Value {
1363 return RegOp::create(*builder, field.type, op.getClockVal(), "",
1364 NameKindEnum::DroppableName, attrs, StringAttr{})
1365 .getResult();
1366 };
1367 return lowerProducer(op, clone);
1368}
1369
1370/// Lower a reg op with a bundle to multiple non-bundled regs.
1371bool TypeLoweringVisitor::visitDecl(RegResetOp op) {
1372 if (op.isForceable())
1373 return false;
1374
1375 auto clone = [&](const FlatBundleFieldEntry &field,
1376 ArrayAttr attrs) -> Value {
1377 auto resetVal = getSubWhatever(op.getResetValue(), field.index);
1378 return RegResetOp::create(*builder, field.type, op.getClockVal(),
1379 op.getResetSignal(), resetVal, "",
1380 NameKindEnum::DroppableName, attrs, StringAttr{})
1381 .getResult();
1382 };
1383 return lowerProducer(op, clone);
1384}
1385
1386/// Lower a wire op with a bundle to multiple non-bundled wires.
1387bool TypeLoweringVisitor::visitDecl(NodeOp op) {
1388 if (op.isForceable())
1389 return false;
1390
1391 auto clone = [&](const FlatBundleFieldEntry &field,
1392 ArrayAttr attrs) -> Value {
1393 auto input = getSubWhatever(op.getInput(), field.index);
1394 return NodeOp::create(*builder, input, "", NameKindEnum::DroppableName,
1395 attrs)
1396 .getResult();
1397 };
1398 return lowerProducer(op, clone);
1399}
1400
1401/// Lower an InvalidValue op with a bundle to multiple non-bundled InvalidOps.
1402bool TypeLoweringVisitor::visitExpr(InvalidValueOp op) {
1403 auto clone = [&](const FlatBundleFieldEntry &field,
1404 ArrayAttr attrs) -> Value {
1405 return InvalidValueOp::create(*builder, field.type);
1406 };
1407 return lowerProducer(op, clone);
1408}
1409
1410// Expand muxes of aggregates
1411bool TypeLoweringVisitor::visitExpr(MuxPrimOp op) {
1412 auto clone = [&](const FlatBundleFieldEntry &field,
1413 ArrayAttr attrs) -> Value {
1414 auto high = getSubWhatever(op.getHigh(), field.index);
1415 auto low = getSubWhatever(op.getLow(), field.index);
1416 return MuxPrimOp::create(*builder, op.getSel(), high, low);
1417 };
1418 return lowerProducer(op, clone);
1419}
1420
1421// Expand muxes of aggregates
1422bool TypeLoweringVisitor::visitExpr(Mux2CellIntrinsicOp op) {
1423 auto clone = [&](const FlatBundleFieldEntry &field,
1424 ArrayAttr attrs) -> Value {
1425 auto high = getSubWhatever(op.getHigh(), field.index);
1426 auto low = getSubWhatever(op.getLow(), field.index);
1427 return Mux2CellIntrinsicOp::create(*builder, op.getSel(), high, low);
1428 };
1429 return lowerProducer(op, clone);
1430}
1431
1432// Expand muxes of aggregates
1433bool TypeLoweringVisitor::visitExpr(Mux4CellIntrinsicOp op) {
1434 auto clone = [&](const FlatBundleFieldEntry &field,
1435 ArrayAttr attrs) -> Value {
1436 auto v3 = getSubWhatever(op.getV3(), field.index);
1437 auto v2 = getSubWhatever(op.getV2(), field.index);
1438 auto v1 = getSubWhatever(op.getV1(), field.index);
1439 auto v0 = getSubWhatever(op.getV0(), field.index);
1440 return Mux4CellIntrinsicOp::create(*builder, op.getSel(), v3, v2, v1, v0);
1441 };
1442 return lowerProducer(op, clone);
1443}
1444
1445// Expand UnrealizedConversionCastOp of aggregates
1446bool TypeLoweringVisitor::visitUnrealizedConversionCast(
1447 mlir::UnrealizedConversionCastOp op) {
1448 auto clone = [&](const FlatBundleFieldEntry &field,
1449 ArrayAttr attrs) -> Value {
1450 auto input = getSubWhatever(op.getOperand(0), field.index);
1451 return mlir::UnrealizedConversionCastOp::create(*builder, field.type, input)
1452 .getResult(0);
1453 };
1454 // If the input to the cast is not a FIRRTL type, getSubWhatever cannot handle
1455 // it, donot lower the op.
1456 if (!type_isa<FIRRTLType>(op->getOperand(0).getType()))
1457 return false;
1458 return lowerProducer(op, clone);
1459}
1460
1461// Expand BitCastOp of aggregates
1462bool TypeLoweringVisitor::visitExpr(BitCastOp op) {
1463 Value srcLoweredVal = op.getInput();
1464 // If the input is of aggregate type, then cat all the leaf fields to form a
1465 // UInt type result. That is, first bitcast the aggregate type to a UInt.
1466 // Attempt to get the bundle types.
1467 SmallVector<FlatBundleFieldEntry> fields;
1468 if (peelType(op.getInput().getType(), fields, PreserveAggregate::None)) {
1469 size_t uptoBits = 0;
1470 // Loop over the leaf aggregates and concat each of them to get a UInt.
1471 // Bitcast the fields to handle nested aggregate types.
1472 for (const auto &field : llvm::enumerate(fields)) {
1473 auto fieldBitwidth = *getBitWidth(field.value().type);
1474 // Ignore zero width fields, like empty bundles.
1475 if (fieldBitwidth == 0)
1476 continue;
1477 Value src = getSubWhatever(op.getInput(), field.index());
1478 // The src could be an aggregate type, bitcast it to a UInt type.
1479 src = builder->createOrFold<BitCastOp>(
1480 UIntType::get(context, fieldBitwidth), src);
1481 // Take the first field, or else Cat the previous fields with this field.
1482 if (uptoBits == 0)
1483 srcLoweredVal = src;
1484 else {
1485 if (type_isa<BundleType>(op.getInput().getType())) {
1486 srcLoweredVal =
1487 CatPrimOp::create(*builder, ValueRange{srcLoweredVal, src});
1488 } else {
1489 srcLoweredVal =
1490 CatPrimOp::create(*builder, ValueRange{src, srcLoweredVal});
1491 }
1492 }
1493 // Record the total bits already accumulated.
1494 uptoBits += fieldBitwidth;
1495 }
1496 } else {
1497 srcLoweredVal = builder->createOrFold<AsUIntPrimOp>(srcLoweredVal);
1498 }
1499 // Now the input has been cast to srcLoweredVal, which is of UInt type.
1500 // If the result is an aggregate type, then use lowerProducer.
1501 if (type_isa<BundleType, FVectorType>(op.getResult().getType())) {
1502 // uptoBits is used to keep track of the bits that have been extracted.
1503 size_t uptoBits = 0;
1504 auto aggregateBits = *getBitWidth(op.getResult().getType());
1505 auto clone = [&](const FlatBundleFieldEntry &field,
1506 ArrayAttr attrs) -> Value {
1507 // All the fields must have valid bitwidth, a requirement for BitCastOp.
1508 auto fieldBits = *getBitWidth(field.type);
1509 // If empty field, then it doesnot have any use, so replace it with an
1510 // invalid op, which should be trivially removed.
1511 if (fieldBits == 0)
1512 return InvalidValueOp::create(*builder, field.type);
1513
1514 // Assign the field to the corresponding bits from the input.
1515 // Bitcast the field, incase its an aggregate type.
1516 BitsPrimOp extractBits;
1517 if (type_isa<BundleType>(op.getResult().getType())) {
1518 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1519 aggregateBits - uptoBits - 1,
1520 aggregateBits - uptoBits - fieldBits);
1521 } else {
1522 extractBits = BitsPrimOp::create(*builder, srcLoweredVal,
1523 uptoBits + fieldBits - 1, uptoBits);
1524 }
1525 uptoBits += fieldBits;
1526 return BitCastOp::create(*builder, field.type, extractBits);
1527 };
1528 return lowerProducer(op, clone);
1529 }
1530
1531 // If ground type, then replace the result.
1532 if (type_isa<SIntType>(op.getType()))
1533 srcLoweredVal = AsSIntPrimOp::create(*builder, srcLoweredVal);
1534 op.getResult().replaceAllUsesWith(srcLoweredVal);
1535 return true;
1536}
1537
1538bool TypeLoweringVisitor::visitExpr(RefSendOp op) {
1539 auto clone = [&](const FlatBundleFieldEntry &field,
1540 ArrayAttr attrs) -> Value {
1541 return RefSendOp::create(*builder,
1542 getSubWhatever(op.getBase(), field.index));
1543 };
1544 // Be careful re:what gets lowered, consider ref.send of non-passive
1545 // and whether we're using the ref or the base type to choose
1546 // whether this should be lowered.
1547 return lowerProducer(op, clone);
1548}
1549
1550bool TypeLoweringVisitor::visitExpr(RefResolveOp op) {
1551 auto clone = [&](const FlatBundleFieldEntry &field,
1552 ArrayAttr attrs) -> Value {
1553 Value src = getSubWhatever(op.getRef(), field.index);
1554 return RefResolveOp::create(*builder, src);
1555 };
1556 // Lower according to lowering of the reference.
1557 // Particularly, preserve if rwprobe.
1558 return lowerProducer(op, clone, op.getRef().getType());
1559}
1560
1561bool TypeLoweringVisitor::visitExpr(RefCastOp op) {
1562 auto clone = [&](const FlatBundleFieldEntry &field,
1563 ArrayAttr attrs) -> Value {
1564 auto input = getSubWhatever(op.getInput(), field.index);
1565 return RefCastOp::create(*builder,
1566 RefType::get(field.type,
1567 op.getType().getForceable(),
1568 op.getType().getLayer()),
1569 input);
1570 };
1571 return lowerProducer(op, clone);
1572}
1573
1574/// Helper function to lower instance-like operations. This contains the common
1575/// logic for both InstanceOp and InstanceChoiceOp.
1576bool TypeLoweringVisitor::lowerInstanceLike(
1577 FInstanceLike op, PreserveAggregate::PreserveMode mode,
1578 ArrayAttr oldPortAnno,
1579 llvm::function_ref<Operation *(ArrayRef<Type>, ArrayRef<Direction>,
1580 ArrayAttr, ArrayAttr, ArrayAttr,
1581 hw::InnerSymAttr)>
1582 createNewInstance) {
1583 bool skip = true;
1584 SmallVector<Type, 8> resultTypes;
1585 SmallVector<int64_t, 8> endFields; // Compressed sparse row encoding
1586 SmallVector<Direction> newDirs;
1587 SmallVector<Attribute> newNames, newDomains, newPortAnno;
1588
1589 // Create domain helper to track domain port indices.
1590 DomainLoweringHelper domainHelper(context, op->getResultTypes());
1591 auto emptyAnno = builder->getArrayAttr({});
1592
1593 endFields.push_back(0);
1594 for (size_t i = 0, e = op->getNumResults(); i != e; ++i) {
1595 auto srcType = type_cast<FIRRTLType>(op->getResult(i).getType());
1596
1597 // Flatten any nested bundle types the usual way.
1598 SmallVector<FlatBundleFieldEntry, 8> fieldTypes;
1599 if (!peelType(srcType, fieldTypes, mode)) {
1600 newDirs.push_back(op.getPortDirection(i));
1601 newNames.push_back(op.getPortNameAttr(i));
1602 newDomains.push_back(op.getPortDomain(i));
1603 resultTypes.push_back(srcType);
1604 newPortAnno.push_back(oldPortAnno ? oldPortAnno[i] : emptyAnno);
1605 } else {
1606 skip = false;
1607 auto oldName = op.getPortName(i);
1608 auto oldDir = op.getPortDirection(i);
1609 // Store the flat type for the new bundle type.
1610 for (const auto &field : fieldTypes) {
1611 newDirs.push_back(direction::get((unsigned)oldDir ^ field.isOutput));
1612 newNames.push_back(builder->getStringAttr(oldName + field.suffix));
1613 newDomains.push_back(op.getPortDomain(i));
1614 resultTypes.push_back(mapLoweredType(srcType, field.type));
1615 auto annos =
1616 oldPortAnno
1617 ? filterAnnotations(context,
1618 dyn_cast_or_null<ArrayAttr>(oldPortAnno[i]),
1619 srcType, field)
1620 : emptyAnno;
1621 newPortAnno.push_back(annos);
1622 }
1623 }
1624 endFields.push_back(resultTypes.size());
1625 }
1626
1627 auto sym = getInnerSymName(op);
1628
1629 if (skip) {
1630 return false;
1631 }
1632
1633 // Compute the mapping from old domain indices to new domain indices.
1634 domainHelper.computeDomainMap(resultTypes);
1635
1636 // Rewrite domain associations to use the new port numbers.
1637 for (auto &domain : newDomains)
1638 domainHelper.rewriteDomain(domain);
1639
1640 // Create the new instance using the provided factory function.
1641 auto *newInstance = createNewInstance(
1642 resultTypes, newDirs, builder->getArrayAttr(newNames),
1643 builder->getArrayAttr(newDomains), builder->getArrayAttr(newPortAnno),
1644 sym ? hw::InnerSymAttr::get(sym) : hw::InnerSymAttr());
1645
1646 newInstance->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1647
1648 SmallVector<Value> lowered;
1649 for (size_t aggIndex = 0, eAgg = op->getNumResults(); aggIndex != eAgg;
1650 ++aggIndex) {
1651 lowered.clear();
1652 for (size_t fieldIndex = endFields[aggIndex],
1653 eField = endFields[aggIndex + 1];
1654 fieldIndex < eField; ++fieldIndex)
1655 lowered.push_back(newInstance->getResult(fieldIndex));
1656 if (lowered.size() != 1 ||
1657 op->getResult(aggIndex).getType() != resultTypes[endFields[aggIndex]])
1658 processUsers(op->getResult(aggIndex), lowered);
1659 else
1660 op->getResult(aggIndex).replaceAllUsesWith(lowered[0]);
1661 }
1662 return true;
1663}
1664
1665bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
1666 // Determine preservation mode from the referenced module.
1667 PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
1668 cast<FModuleLike>(op.getReferencedOperation(symTbl)));
1669
1670 // Lambda to create the new InstanceOp with lowered types.
1671 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1672 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1673 ArrayAttr newDomains, ArrayAttr newPortAnno,
1674 hw::InnerSymAttr sym) -> Operation * {
1675 // FIXME: annotation update
1676 return InstanceOp::create(
1677 *builder, resultTypes, op.getModuleNameAttr(), op.getNameAttr(),
1678 op.getNameKindAttr(), direction::packAttribute(context, newDirs),
1679 newNames, newDomains, op.getAnnotations(), newPortAnno,
1680 op.getLayersAttr(), op.getLowerToBindAttr(), op.getDoNotPrintAttr(),
1681 sym);
1682 };
1683
1684 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1685 createNewInstance);
1686}
1687
1688bool TypeLoweringVisitor::visitDecl(InstanceChoiceOp op) {
1689 // Get the default target module to determine preservation mode.
1690 auto *moduleOp = symTbl.lookupNearestSymbolFrom(
1691 op, cast<FlatSymbolRefAttr>(op.getDefaultTargetAttr()));
1692 auto mode = getPreservationModeForPorts(cast<FModuleLike>(moduleOp));
1693
1694 // Lambda to create the new InstanceChoiceOp with lowered types.
1695 auto createNewInstance = [&](ArrayRef<Type> resultTypes,
1696 ArrayRef<Direction> newDirs, ArrayAttr newNames,
1697 ArrayAttr newDomains, ArrayAttr newPortAnno,
1698 hw::InnerSymAttr sym) -> Operation * {
1699 return InstanceChoiceOp::create(
1700 *builder, resultTypes, op.getModuleNames(), op.getCaseNames(),
1701 op.getNameAttr(), op.getNameKindAttr(),
1702 direction::packAttribute(context, newDirs), newNames, newDomains,
1703 op.getAnnotations(), newPortAnno, op.getLayersAttr(), sym,
1704 op.getInstanceMacroAttr());
1705 };
1706
1707 return lowerInstanceLike(op, mode, op.getPortAnnotations(),
1708 createNewInstance);
1709}
1710
1711bool TypeLoweringVisitor::visitExpr(SubaccessOp op) {
1712 auto input = op.getInput();
1713 FVectorType vType = input.getType();
1714
1715 // Check for empty vectors
1716 if (vType.getNumElements() == 0) {
1717 Value inv = InvalidValueOp::create(*builder, vType.getElementType());
1718 op.replaceAllUsesWith(inv);
1719 return true;
1720 }
1721
1722 // Check for constant instances
1723 if (ConstantOp arg =
1724 llvm::dyn_cast_or_null<ConstantOp>(op.getIndex().getDefiningOp())) {
1725 auto sio = SubindexOp::create(*builder, op.getInput(),
1726 arg.getValue().getExtValue());
1727 op.replaceAllUsesWith(sio.getResult());
1728 return true;
1729 }
1730
1731 // Construct a multibit mux
1732 SmallVector<Value> inputs;
1733 inputs.reserve(vType.getNumElements());
1734 for (int index = vType.getNumElements() - 1; index >= 0; index--)
1735 inputs.push_back(SubindexOp::create(*builder, input, index));
1736
1737 Value multibitMux = MultibitMuxOp::create(*builder, op.getIndex(), inputs);
1738 op.replaceAllUsesWith(multibitMux);
1739 return true;
1740}
1741
1742bool TypeLoweringVisitor::visitExpr(VectorCreateOp op) {
1743 auto clone = [&](const FlatBundleFieldEntry &field,
1744 ArrayAttr attrs) -> Value {
1745 return op.getOperand(field.index);
1746 };
1747 return lowerProducer(op, clone);
1748}
1749
1750bool TypeLoweringVisitor::visitExpr(BundleCreateOp op) {
1751 auto clone = [&](const FlatBundleFieldEntry &field,
1752 ArrayAttr attrs) -> Value {
1753 return op.getOperand(field.index);
1754 };
1755 return lowerProducer(op, clone);
1756}
1757
1758bool TypeLoweringVisitor::visitExpr(ElementwiseOrPrimOp op) {
1759 auto clone = [&](const FlatBundleFieldEntry &field,
1760 ArrayAttr attrs) -> Value {
1761 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1762 getSubWhatever(op.getRhs(), field.index)};
1763 return type_isa<BundleType, FVectorType>(field.type)
1764 ? (Value)ElementwiseOrPrimOp::create(*builder, field.type,
1765 operands)
1766 : (Value)OrPrimOp::create(*builder, operands);
1767 };
1768
1769 return lowerProducer(op, clone);
1770}
1771
1772bool TypeLoweringVisitor::visitExpr(ElementwiseAndPrimOp op) {
1773 auto clone = [&](const FlatBundleFieldEntry &field,
1774 ArrayAttr attrs) -> Value {
1775 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1776 getSubWhatever(op.getRhs(), field.index)};
1777 return type_isa<BundleType, FVectorType>(field.type)
1778 ? (Value)ElementwiseAndPrimOp::create(*builder, field.type,
1779 operands)
1780 : (Value)AndPrimOp::create(*builder, operands);
1781 };
1782
1783 return lowerProducer(op, clone);
1784}
1785
1786bool TypeLoweringVisitor::visitExpr(ElementwiseXorPrimOp op) {
1787 auto clone = [&](const FlatBundleFieldEntry &field,
1788 ArrayAttr attrs) -> Value {
1789 Value operands[] = {getSubWhatever(op.getLhs(), field.index),
1790 getSubWhatever(op.getRhs(), field.index)};
1791 return type_isa<BundleType, FVectorType>(field.type)
1792 ? (Value)ElementwiseXorPrimOp::create(*builder, field.type,
1793 operands)
1794 : (Value)XorPrimOp::create(*builder, operands);
1795 };
1796
1797 return lowerProducer(op, clone);
1798}
1799
1800bool TypeLoweringVisitor::visitExpr(MultibitMuxOp op) {
1801 auto clone = [&](const FlatBundleFieldEntry &field,
1802 ArrayAttr attrs) -> Value {
1803 SmallVector<Value> newInputs;
1804 newInputs.reserve(op.getInputs().size());
1805 for (auto input : op.getInputs()) {
1806 auto inputSub = getSubWhatever(input, field.index);
1807 newInputs.push_back(inputSub);
1808 }
1809 return MultibitMuxOp::create(*builder, op.getIndex(), newInputs);
1810 };
1811 return lowerProducer(op, clone);
1812}
1813
1814//===----------------------------------------------------------------------===//
1815// Pass Infrastructure
1816//===----------------------------------------------------------------------===//
1817
1818namespace {
1819struct LowerTypesPass
1820 : public circt::firrtl::impl::LowerFIRRTLTypesBase<LowerTypesPass> {
1821 using Base::Base;
1822
1823 void runOnOperation() override;
1824};
1825} // end anonymous namespace
1826
1827// This is the main entrypoint for the lowering pass.
1828void LowerTypesPass::runOnOperation() {
1830
1831 std::vector<FModuleLike> ops;
1832 auto &instanceGraph = getAnalysis<InstanceGraph>();
1833 // Symbol Table
1834 auto &symTbl = getAnalysis<SymbolTable>();
1835 // Cached attr
1836 AttrCache cache(&getContext());
1837
1838 DenseMap<FModuleLike, Convention> conventionTable;
1839 auto circuit = getOperation();
1840 for (auto module : circuit.getOps<FModuleLike>()) {
1841 auto convention = module.getConvention();
1842 // Instance choices select between modules with a shared port shape, so
1843 // any module instantiated by one must use the scalarized convention.
1844 if (llvm::any_of(instanceGraph.lookup(module)->uses(),
1845 [](InstanceRecord *use) {
1846 return use->getInstance<InstanceChoiceOp>();
1847 }))
1848 convention = Convention::Scalarized;
1849 conventionTable.insert({module, convention});
1850 ops.push_back(module);
1851 }
1852
1853 // This lambda, executes in parallel for each Op within the circt.
1854 auto lowerModules = [&](FModuleLike op) -> LogicalResult {
1855 // Use body type lowering attribute if it exists, otherwise use internal.
1856 Convention convention = Convention::Internal;
1857 if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
1858 op->getDiscardableAttr("body_type_lowering")))
1859 convention = conventionAttr.getValue();
1860
1861 auto tl =
1862 TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
1863 preserveMemories, symTbl, cache, conventionTable);
1864 tl.lowerModule(op);
1865
1866 return LogicalResult::failure(tl.isFailed());
1867 };
1868
1869 auto result = failableParallelForEach(&getContext(), ops, lowerModules);
1870
1871 if (failed(result))
1872 signalPassFailure();
1873}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static void dump(DIModule &module, raw_indented_ostream &os)
static void eraseElementsAtIndices(SmallVectorImpl< T > &vec, const llvm::BitVector &removalMask)
Helper function to remove elements from a vector based on a BitVector mask.
static bool isPreservableAggregateType(Type type, PreserveAggregate::PreserveMode mode)
Return true if we can preserve the type.
static FIRRTLType mapLoweredType(FIRRTLType type, FIRRTLBaseType fieldType)
Return fieldType or fieldType as same ref as type.
static MemOp cloneMemWithNewType(ImplicitLocOpBuilder *b, MemOp op, FlatBundleFieldEntry field)
Clone memory for the specified field. Returns null op on error.
static bool containsBundleType(FIRRTLType type)
Return true if the type has a bundle type as subtype.
static Value cloneAccess(ImplicitLocOpBuilder *builder, Operation *op, Value rhs)
static bool peelType(Type type, SmallVectorImpl< FlatBundleFieldEntry > &fields, PreserveAggregate::PreserveMode mode)
Peel one layer of an aggregate type into its components.
static bool isNotSubAccess(Operation *op)
Return if something is not a normal subaccess.
static SmallVector< Operation * > getSAWritePath(Operation *op)
Look through and collect subfields leading to a subaccess.
static bool isOneDimVectorType(FIRRTLType type)
Return true if the type is a 1d vector type or ground type.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
ArrayAttr getArrayAttr() const
Return this annotation set as an ArrayAttr.
This class provides a read-only projection of an annotation.
DictionaryAttr getDict() const
Get the data dictionary of this attribute.
unsigned getFieldID() const
Get the field id this attribute targets.
void setMember(StringAttr name, Attribute value)
Add or set a member of the annotation to a value.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
ResultType visitInvalidOp(Operation *op, ExtraArgs... args)
visitInvalidOp is an override point for non-FIRRTL dialect operations.
This is an edge in the InstanceGraph.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
@ All
Preserve all aggregate values.
Definition Passes.h:40
@ OneDimVec
Preserve only 1d vectors of ground type (e.g. UInt<2>[3]).
Definition Passes.h:34
@ Vec
Preserve only vectors (e.g. UInt<2>[3][3]).
Definition Passes.h:37
@ None
Don't preserve aggregate at all.
Definition Passes.h:31
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Definition FIRRTLEnums.h:36
Direction
This represents the direction of a single port.
Definition FIRRTLEnums.h:27
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
bool type_isa(Type type)
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition FIRRTLOps.h:108
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
This holds the name and type that describes the module's ports.