CIRCT 20.0.0git
Loading...
Searching...
No Matches
FIRRTLOps.cpp
Go to the documentation of this file.
1//===- FIRRTLOps.cpp - Implement the FIRRTL operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
26#include "circt/Support/Utils.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/Diagnostics.h"
29#include "mlir/IR/DialectImplementation.h"
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/IR/SymbolTable.h"
32#include "mlir/Interfaces/FunctionImplementation.h"
33#include "llvm/ADT/BitVector.h"
34#include "llvm/ADT/DenseMap.h"
35#include "llvm/ADT/DenseSet.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallSet.h"
38#include "llvm/ADT/StringExtras.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/Support/FormatVariadic.h"
41
42using llvm::SmallDenseSet;
43using mlir::RegionRange;
44using namespace circt;
45using namespace firrtl;
46using namespace chirrtl;
47
48//===----------------------------------------------------------------------===//
49// Utilities
50//===----------------------------------------------------------------------===//
51
52/// Remove elements from the input array corresponding to set bits in
53/// `indicesToDrop`, returning the elements not mentioned.
54template <typename T>
55static SmallVector<T>
56removeElementsAtIndices(ArrayRef<T> input,
57 const llvm::BitVector &indicesToDrop) {
58#ifndef NDEBUG
59 if (!input.empty()) {
60 int lastIndex = indicesToDrop.find_last();
61 if (lastIndex >= 0)
62 assert((size_t)lastIndex < input.size() && "index out of range");
63 }
64#endif
65
66 // If the input is empty (which is an optimization we do for certain array
67 // attributes), simply return an empty vector.
68 if (input.empty())
69 return {};
70
71 // Copy over the live chunks.
72 size_t lastCopied = 0;
73 SmallVector<T> result;
74 result.reserve(input.size() - indicesToDrop.count());
75
76 for (unsigned indexToDrop : indicesToDrop.set_bits()) {
77 // If we skipped over some valid elements, copy them over.
78 if (indexToDrop > lastCopied) {
79 result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
80 lastCopied = indexToDrop;
81 }
82 // Ignore this value so we don't copy it in the next iteration.
83 ++lastCopied;
84 }
85
86 // If there are live elements at the end, copy them over.
87 if (lastCopied < input.size())
88 result.append(input.begin() + lastCopied, input.end());
89
90 return result;
91}
92
93/// Emit an error if optional location is non-null, return null of return type.
94template <typename RetTy = FIRRTLType, typename... Args>
95static RetTy emitInferRetTypeError(std::optional<Location> loc,
96 const Twine &message, Args &&...args) {
97 if (loc)
98 (mlir::emitError(*loc, message) << ... << std::forward<Args>(args));
99 return {};
100}
101
102bool firrtl::isDuplexValue(Value val) {
103 // Block arguments are not duplex values.
104 while (Operation *op = val.getDefiningOp()) {
105 auto isDuplex =
106 TypeSwitch<Operation *, std::optional<bool>>(op)
107 .Case<SubfieldOp, SubindexOp, SubaccessOp>([&val](auto op) {
108 val = op.getInput();
109 return std::nullopt;
110 })
111 .Case<RegOp, RegResetOp, WireOp>([](auto) { return true; })
112 .Default([](auto) { return false; });
113 if (isDuplex)
114 return *isDuplex;
115 }
116 return false;
117}
118
119SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
120MemOp::computeDataFlow() {
121 // If read result has non-zero latency, then no combinational dependency
122 // exists.
123 if (getReadLatency() > 0)
124 return {};
125 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
126 // Add a dependency from the enable and address fields to the data field.
127 for (auto memPort : getResults())
128 if (auto type = type_dyn_cast<BundleType>(memPort.getType())) {
129 auto enableFieldId = type.getFieldID((unsigned)ReadPortSubfield::en);
130 auto addressFieldId = type.getFieldID((unsigned)ReadPortSubfield::addr);
131 auto dataFieldId = type.getFieldID((unsigned)ReadPortSubfield::data);
132 deps.emplace_back(
133 FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
134 FieldRef(memPort, static_cast<unsigned>(enableFieldId)));
135 deps.emplace_back(
136 FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
137 FieldRef(memPort, static_cast<unsigned>(addressFieldId)));
138 }
139 return deps;
140}
141
142/// Return the kind of port this is given the port type from a 'mem' decl.
143static MemOp::PortKind getMemPortKindFromType(FIRRTLType type) {
144 constexpr unsigned int addr = 1 << 0;
145 constexpr unsigned int en = 1 << 1;
146 constexpr unsigned int clk = 1 << 2;
147 constexpr unsigned int data = 1 << 3;
148 constexpr unsigned int mask = 1 << 4;
149 constexpr unsigned int rdata = 1 << 5;
150 constexpr unsigned int wdata = 1 << 6;
151 constexpr unsigned int wmask = 1 << 7;
152 constexpr unsigned int wmode = 1 << 8;
153 constexpr unsigned int def = 1 << 9;
154 // Get the kind of port based on the fields of the Bundle.
155 auto portType = type_dyn_cast<BundleType>(type);
156 if (!portType)
157 return MemOp::PortKind::Debug;
158 unsigned fields = 0;
159 // Get the kind of port based on the fields of the Bundle.
160 for (auto elem : portType.getElements()) {
161 fields |= llvm::StringSwitch<unsigned>(elem.name.getValue())
162 .Case("addr", addr)
163 .Case("en", en)
164 .Case("clk", clk)
165 .Case("data", data)
166 .Case("mask", mask)
167 .Case("rdata", rdata)
168 .Case("wdata", wdata)
169 .Case("wmask", wmask)
170 .Case("wmode", wmode)
171 .Default(def);
172 }
173 if (fields == (addr | en | clk | data))
174 return MemOp::PortKind::Read;
175 if (fields == (addr | en | clk | data | mask))
176 return MemOp::PortKind::Write;
177 if (fields == (addr | en | clk | wdata | wmask | rdata | wmode))
178 return MemOp::PortKind::ReadWrite;
179 return MemOp::PortKind::Debug;
180}
181
183 switch (flow) {
184 case Flow::None:
185 return Flow::None;
186 case Flow::Source:
187 return Flow::Sink;
188 case Flow::Sink:
189 return Flow::Source;
190 case Flow::Duplex:
191 return Flow::Duplex;
192 }
193 // Unreachable but silences warning
194 llvm_unreachable("Unsupported Flow type.");
195}
196
197const char *toString(Flow flow) {
198 switch (flow) {
199 case Flow::None:
200 return "no flow";
201 case Flow::Source:
202 return "source flow";
203 case Flow::Sink:
204 return "sink flow";
205 case Flow::Duplex:
206 return "duplex flow";
207 }
208 // Unreachable but silences warning
209 llvm_unreachable("Unsupported Flow type.");
210}
211
/// Compute the flow of `val` by walking from it toward its root declaration.
///
/// `accumulatedFlow` is the flow accumulated along the walk so far. It is
/// swapped when the walk:
/// - passes through a flipped bundle field;
/// - reaches an output port of the enclosing module-like op;
/// - reaches an instance port whose direction is not `Out`;
/// - reaches a non-probe memory port.
///
/// Registers, wires, and behavioral memory ports are always duplex. Object
/// subfields are resolved from the accessed ports' directions (see below).
/// Any other defining op keeps the accumulated flow.
Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {

  // Block arguments: output ports of a module-like parent swap the flow.
  // Block arguments of any other parent keep the accumulated flow.
  if (auto blockArg = dyn_cast<BlockArgument>(val)) {
    auto *op = val.getParentBlock()->getParentOp();
    if (auto moduleLike = dyn_cast<FModuleLike>(op)) {
      auto direction = moduleLike.getPortDirection(blockArg.getArgNumber());
      if (direction == Direction::Out)
        return swapFlow(accumulatedFlow);
    }
    return accumulatedFlow;
  }

  // Not a block argument, so `val` is an op result with a defining op.
  Operation *op = val.getDefiningOp();

  return TypeSwitch<Operation *, Flow>(op)
      // Field access: a flipped field reverses the flow, then keep walking.
      .Case<SubfieldOp, OpenSubfieldOp>([&](auto op) {
        return foldFlow(op.getInput(), op.isFieldFlipped()
                                           ? swapFlow(accumulatedFlow)
                                           : accumulatedFlow);
      })
      // Index accesses and probe sub-accesses leave the flow unchanged.
      .Case<SubindexOp, SubaccessOp, OpenSubindexOp, RefSubOp>(
          [&](auto op) { return foldFlow(op.getInput(), accumulatedFlow); })
      // Registers, Wires, and behavioral memory ports are always Duplex.
      .Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
          [](auto) { return Flow::Duplex; })
      // Instance results: output ports keep the flow, all others swap it.
      .Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
        auto resultNo = cast<OpResult>(val).getResultNumber();
        if (inst.getPortDirection(resultNo) == Direction::Out)
          return accumulatedFlow;
        return swapFlow(accumulatedFlow);
      })
      .Case<MemOp>([&](auto op) {
        // only debug ports with RefType have source flow.
        if (type_isa<RefType>(val.getType()))
          return Flow::Source;
        return swapFlow(accumulatedFlow);
      })
      .Case<ObjectSubfieldOp>([&](ObjectSubfieldOp op) {
        auto input = op.getInput();
        auto *inputOp = input.getDefiningOp();

        // We are directly accessing a port on a local declaration.
        // Input ports are sinks and output ports are sources.
        if (auto objectOp = dyn_cast_or_null<ObjectOp>(inputOp)) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::Sink;
          return Flow::Source;
        }

        // We are accessing a remote object. Input ports on remote objects are
        // inaccessible, and thus have Flow::None. Walk backwards through the
        // chain of subindexes, to detect if we have indexed through an input
        // port. At the end, either we did index through an input port, or the
        // entire path was through output ports with source flow.
        while (true) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::None;

          // Step to the enclosing object-subfield access, if any.
          op = dyn_cast_or_null<ObjectSubfieldOp>(inputOp);
          if (op) {
            input = op.getInput();
            inputOp = input.getDefiningOp();
            continue;
          }

          return accumulatedFlow;
        };
      })
      // Anything else acts like a universal source.
      .Default([&](auto) { return accumulatedFlow; });
}
286
287// TODO: This is doing the same walk as foldFlow. These two functions can be
288// combined and return a (flow, kind) product.
290 Operation *op = val.getDefiningOp();
291 if (!op)
292 return DeclKind::Port;
293
294 return TypeSwitch<Operation *, DeclKind>(op)
295 .Case<InstanceOp>([](auto) { return DeclKind::Instance; })
296 .Case<SubfieldOp, SubindexOp, SubaccessOp, OpenSubfieldOp, OpenSubindexOp,
297 RefSubOp>([](auto op) { return getDeclarationKind(op.getInput()); })
298 .Default([](auto) { return DeclKind::Other; });
299}
300
301size_t firrtl::getNumPorts(Operation *op) {
302 if (auto module = dyn_cast<FModuleLike>(op))
303 return module.getNumPorts();
304 return op->getNumResults();
305}
306
307/// Check whether an operation has a `DontTouch` annotation, or a symbol that
308/// should prevent certain types of canonicalizations.
309bool firrtl::hasDontTouch(Operation *op) {
310 return op->getAttr(hw::InnerSymbolTable::getInnerSymbolAttrName()) ||
312}
313
314/// Check whether a block argument ("port") or the operation defining a value
315/// has a `DontTouch` annotation, or a symbol that should prevent certain types
316/// of canonicalizations.
317bool firrtl::hasDontTouch(Value value) {
318 if (auto *op = value.getDefiningOp())
319 return hasDontTouch(op);
320 auto arg = dyn_cast<BlockArgument>(value);
321 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
322 return (module.getPortSymbolAttr(arg.getArgNumber())) ||
323 AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
324}
325
326/// Get a special name to use when printing the entry block arguments of the
327/// region contained by an operation in this dialect.
328void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region,
329 OpAsmSetValueNameFn setNameFn) {
330 if (region.empty())
331 return;
332 auto *parentOp = op;
333 auto *block = &region.front();
334 // Check to see if the operation containing the arguments has 'firrtl.name'
335 // attributes for them. If so, use that as the name.
336 auto argAttr = parentOp->getAttrOfType<ArrayAttr>("portNames");
337 // Do not crash on invalid IR.
338 if (!argAttr || argAttr.size() != block->getNumArguments())
339 return;
340
341 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
342 auto str = cast<StringAttr>(argAttr[i]).getValue();
343 if (!str.empty())
344 setNameFn(block->getArgument(i), str);
345 }
346}
347
348/// A forward declaration for `NameKind` attribute parser.
349static ParseResult parseNameKind(OpAsmParser &parser,
350 firrtl::NameKindEnumAttr &result);
351
352//===----------------------------------------------------------------------===//
353// Layer Verification Utilities
354//===----------------------------------------------------------------------===//
355
356namespace {
357struct CompareSymbolRefAttr {
358 // True if lhs is lexicographically less than rhs.
359 bool operator()(SymbolRefAttr lhs, SymbolRefAttr rhs) const {
360 auto cmp = lhs.getRootReference().compare(rhs.getRootReference());
361 if (cmp == -1)
362 return true;
363 if (cmp == 1)
364 return false;
365 auto lhsNested = lhs.getNestedReferences();
366 auto rhsNested = rhs.getNestedReferences();
367 auto lhsNestedSize = lhsNested.size();
368 auto rhsNestedSize = rhsNested.size();
369 auto e = std::min(lhsNestedSize, rhsNestedSize);
370 for (unsigned i = 0; i < e; ++i) {
371 auto cmp = lhsNested[i].getAttr().compare(rhsNested[i].getAttr());
372 if (cmp == -1)
373 return true;
374 if (cmp == 1)
375 return false;
376 }
377 return lhsNestedSize < rhsNestedSize;
378 }
379};
380} // namespace
381
383
384/// Get the ambient layers active at the given op.
385static LayerSet getAmbientLayersAt(Operation *op) {
386 // Crawl through the parent ops, accumulating all ambient layers at the given
387 // operation.
388 LayerSet result;
389 for (; op != nullptr; op = op->getParentOp()) {
390 if (auto module = dyn_cast<FModuleLike>(op)) {
391 auto layers = module.getLayersAttr().getAsRange<SymbolRefAttr>();
392 result.insert(layers.begin(), layers.end());
393 break;
394 }
395 if (auto layerblock = dyn_cast<LayerBlockOp>(op)) {
396 result.insert(layerblock.getLayerName());
397 continue;
398 }
399 }
400 return result;
401}
402
403/// Get the ambient layer requirements at the definition site of the value.
404static LayerSet getAmbientLayersFor(Value value) {
405 return getAmbientLayersAt(getFieldRefFromValue(value).getDefiningOp());
406}
407
408/// Get the effective layer requirements for the given value.
409/// The effective layers for a value is the union of
410/// - the ambient layers for the cannonical storage location.
411/// - any explicit layer annotations in the value's type.
412static LayerSet getLayersFor(Value value) {
413 auto result = getAmbientLayersFor(value);
414 if (auto type = dyn_cast<RefType>(value.getType()))
415 if (auto layer = type.getLayer())
416 result.insert(type.getLayer());
417 return result;
418}
419
420/// Check that the source layer is compatible with the destination layer.
421/// Either the source and destination are identical, or the source-layer
422/// is a parent of the destination. For example `A` is compatible with `A.B.C`,
423/// because any definition valid in `A` is also valid in `A.B.C`.
424static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer,
425 mlir::SymbolRefAttr dstLayer) {
426 // A non-colored probe may be cast to any colored probe.
427 if (!srcLayer)
428 return true;
429
430 // A colored probe cannot be cast to an uncolored probe.
431 if (!dstLayer)
432 return false;
433
434 // Return true if the srcLayer is a prefix of the dstLayer.
435 if (srcLayer.getRootReference() != dstLayer.getRootReference())
436 return false;
437
438 auto srcNames = srcLayer.getNestedReferences();
439 auto dstNames = dstLayer.getNestedReferences();
440 if (dstNames.size() < srcNames.size())
441 return false;
442
443 return llvm::all_of(llvm::zip_first(srcNames, dstNames),
444 [](auto x) { return std::get<0>(x) == std::get<1>(x); });
445}
446
447/// Check that the source layer is present in the destination layers.
448static bool isLayerCompatibleWith(SymbolRefAttr srcLayer,
449 const LayerSet &dstLayers) {
450 // fast path: the required layer is directly listed in the provided layers.
451 if (dstLayers.contains(srcLayer))
452 return true;
453
454 // Slow path: the required layer is not directly listed in the provided
455 // layers, but the layer may still be provided by a nested layer.
456 return any_of(dstLayers, [=](SymbolRefAttr dstLayer) {
457 return isLayerCompatibleWith(srcLayer, dstLayer);
458 });
459}
460
461/// Check that the source layers are all present in the destination layers.
462/// True if all source layers are present in the destination.
463/// Outputs the set of source layers that are missing in the destination.
464static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst,
465 SmallVectorImpl<SymbolRefAttr> &missing) {
466 for (auto srcLayer : src)
467 if (!isLayerCompatibleWith(srcLayer, dst))
468 missing.push_back(srcLayer);
469
470 llvm::sort(missing, CompareSymbolRefAttr());
471 return missing.empty();
472}
473
474//===----------------------------------------------------------------------===//
475// CircuitOp
476//===----------------------------------------------------------------------===//
477
478void CircuitOp::build(OpBuilder &builder, OperationState &result,
479 StringAttr name, ArrayAttr annotations) {
480 // Add an attribute for the name.
481 result.getOrAddProperties<Properties>().setName(name);
482
483 if (!annotations)
484 annotations = builder.getArrayAttr({});
485 result.getOrAddProperties<Properties>().setAnnotations(annotations);
486
487 // Create a region and a block for the body.
488 Region *bodyRegion = result.addRegion();
489 Block *body = new Block();
490 bodyRegion->push_back(body);
491}
492
493static ParseResult parseCircuitOpAttrs(OpAsmParser &parser,
494 NamedAttrList &resultAttrs) {
495 auto result = parser.parseOptionalAttrDictWithKeyword(resultAttrs);
496 if (!resultAttrs.get("annotations"))
497 resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
498
499 return result;
500}
501
502static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op,
503 DictionaryAttr attr) {
504 // "name" is always elided.
505 SmallVector<StringRef> elidedAttrs = {"name"};
506 // Elide "annotations" if it doesn't exist or if it is empty
507 auto annotationsAttr = op->getAttrOfType<ArrayAttr>("annotations");
508 if (annotationsAttr.empty())
509 elidedAttrs.push_back("annotations");
510
511 p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
512}
513
/// Verify the circuit and its body:
/// - the circuit has a non-empty name;
/// - a symbol with the circuit's name exists, is module-like, and is public
///   (this is the main module);
/// - no extmodule's `defname` equals the symbol name of an `FModuleOp`;
/// - extmodules sharing a `defname` agree on port count, port names, and port
///   types, with base-type widths ignored when either side has parameters;
/// - at most one `FModuleOp` carries the DUT annotation.
LogicalResult CircuitOp::verifyRegions() {
  StringRef main = getName();

  // Check that the circuit has a non-empty name.
  if (main.empty()) {
    emitOpError("must have a non-empty name");
    return failure();
  }

  mlir::SymbolTable symtbl(getOperation());

  // The main module must exist, be module-like, and be publicly visible.
  auto *mainModule = symtbl.lookup(main);
  if (!mainModule)
    return emitOpError().append(
        "does not contain module with same name as circuit");
  if (!isa<FModuleLike>(mainModule))
    return mainModule->emitError(
        "entity with name of circuit must be a module");
  if (symtbl.getSymbolVisibility(mainModule) !=
      mlir::SymbolTable::Visibility::Public)
    return mainModule->emitError("main module must be public");

  // Store a mapping of defname to either the first external module
  // that defines it or, preferentially, the first external module
  // that defines it and has no parameters.
  llvm::DenseMap<Attribute, FExtModuleOp> defnameMap;

  // Verify one extmodule: its defname must not name an FModuleOp, and it
  // must agree with the extmodule previously recorded for the same defname.
  auto verifyExtModule = [&](FExtModuleOp extModule) -> LogicalResult {
    if (!extModule)
      return success();

    auto defname = extModule.getDefnameAttr();
    if (!defname)
      return success();

    // Check that this extmodule's defname does not conflict with
    // the symbol name of any module.
    if (auto collidingModule = symtbl.lookup<FModuleOp>(defname.getValue()))
      return extModule.emitOpError()
          .append("attribute 'defname' with value ", defname,
                  " conflicts with the name of another module in the circuit")
          .attachNote(collidingModule.getLoc())
          .append("previous module declared here");

    // Find an optional extmodule with a defname collision. Update
    // the defnameMap if this is the first extmodule with that
    // defname or if the current extmodule takes no parameters and
    // the collision does. The latter condition improves later
    // extmodule verification as checking against a parameterless
    // module is stricter.
    FExtModuleOp collidingExtModule;
    if (auto &value = defnameMap[defname]) {
      collidingExtModule = value;
      if (!value.getParameters().empty() && extModule.getParameters().empty())
        value = extModule;
    } else {
      value = extModule;
      // Go to the next extmodule if no extmodule with the same
      // defname was found.
      return success();
    }

    // Check that the number of ports is exactly the same.
    SmallVector<PortInfo> ports = extModule.getPorts();
    SmallVector<PortInfo> collidingPorts = collidingExtModule.getPorts();

    if (ports.size() != collidingPorts.size())
      return extModule.emitOpError()
          .append("with 'defname' attribute ", defname, " has ", ports.size(),
                  " ports which is different from a previously defined "
                  "extmodule with the same 'defname' which has ",
                  collidingPorts.size(), " ports")
          .attachNote(collidingExtModule.getLoc())
          .append("previous extmodule definition occurred here");

    // Check that ports match for name and type. Since parameters
    // *might* affect widths, ignore widths if either module has
    // parameters. Note that this allows for misdetections, but
    // has zero false positives.
    for (auto p : llvm::zip(ports, collidingPorts)) {
      StringAttr aName = std::get<0>(p).name, bName = std::get<1>(p).name;
      Type aType = std::get<0>(p).type, bType = std::get<1>(p).type;

      if (aName != bName)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which does not match the name of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have name ",
                    bName)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");

      if (!extModule.getParameters().empty() ||
          !collidingExtModule.getParameters().empty()) {
        // Compare base types as widthless, others must match.
        if (auto base = type_dyn_cast<FIRRTLBaseType>(aType))
          aType = base.getWidthlessType();
        if (auto base = type_dyn_cast<FIRRTLBaseType>(bType))
          bType = base.getWidthlessType();
      }
      if (aType != bType)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which has a different type ", aType,
                    " which does not match the type of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have type ",
                    bType)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");
    }
    return success();
  };

  // Single pass over the body: collect DUT-annotated modules and verify each
  // extmodule as it is encountered (order matters for defnameMap).
  SmallVector<FModuleOp, 1> dutModules;
  for (auto &op : *getBodyBlock()) {
    // Verify modules.
    if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
      if (AnnotationSet(moduleOp).hasAnnotation(dutAnnoClass))
        dutModules.push_back(moduleOp);
      continue;
    }

    // Verify external modules.
    if (auto extModule = dyn_cast<FExtModuleOp>(op)) {
      if (verifyExtModule(extModule).failed())
        return failure();
    }
  }

  // Error if there is more than one design-under-test.
  if (dutModules.size() > 1) {
    auto diag = dutModules[0]->emitOpError()
                << "is annotated as the design-under-test (DUT), but other "
                   "modules are also annotated";
    for (auto moduleOp : ArrayRef(dutModules).drop_front())
      diag.attachNote(moduleOp.getLoc()) << "is also annotated as the DUT";
    return failure();
  }

  return success();
}
659
660Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
661
662//===----------------------------------------------------------------------===//
663// FExtModuleOp and FModuleOp
664//===----------------------------------------------------------------------===//
665
666static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
667 SmallVector<PortInfo> results;
668 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
669 results.push_back({module.getPortNameAttr(i), module.getPortType(i),
670 module.getPortDirection(i), module.getPortSymbolAttr(i),
671 module.getPortLocation(i),
672 AnnotationSet::forPort(module, i)});
673 }
674 return results;
675}
676
677SmallVector<PortInfo> FModuleOp::getPorts() { return ::getPortImpl(*this); }
678
679SmallVector<PortInfo> FExtModuleOp::getPorts() { return ::getPortImpl(*this); }
680
681SmallVector<PortInfo> FIntModuleOp::getPorts() { return ::getPortImpl(*this); }
682
683SmallVector<PortInfo> FMemModuleOp::getPorts() { return ::getPortImpl(*this); }
684
686 if (dir == Direction::In)
687 return hw::ModulePort::Direction::Input;
688 if (dir == Direction::Out)
689 return hw::ModulePort::Direction::Output;
690 assert(0 && "invalid direction");
691 abort();
692}
693
694static SmallVector<hw::PortInfo> getPortListImpl(FModuleLike module) {
695 SmallVector<hw::PortInfo> results;
696 auto aname = StringAttr::get(module.getContext(),
697 hw::HWModuleLike::getPortSymbolAttrName());
698 auto emptyDict = DictionaryAttr::get(module.getContext());
699 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
700 auto sym = module.getPortSymbolAttr(i);
701 results.push_back(
702 {{module.getPortNameAttr(i), module.getPortType(i),
703 dirFtoH(module.getPortDirection(i))},
704 i,
705 sym ? DictionaryAttr::get(
706 module.getContext(),
707 ArrayRef<mlir::NamedAttribute>{NamedAttribute{aname, sym}})
708 : emptyDict,
709 module.getPortLocation(i)});
710 }
711 return results;
712}
713
714SmallVector<::circt::hw::PortInfo> FModuleOp::getPortList() {
715 return ::getPortListImpl(*this);
716}
717
718SmallVector<::circt::hw::PortInfo> FExtModuleOp::getPortList() {
719 return ::getPortListImpl(*this);
720}
721
722SmallVector<::circt::hw::PortInfo> FIntModuleOp::getPortList() {
723 return ::getPortListImpl(*this);
724}
725
726SmallVector<::circt::hw::PortInfo> FMemModuleOp::getPortList() {
727 return ::getPortListImpl(*this);
728}
729
730static hw::PortInfo getPortImpl(FModuleLike module, size_t idx) {
731 return {{module.getPortNameAttr(idx), module.getPortType(idx),
732 dirFtoH(module.getPortDirection(idx))},
733 idx,
734 DictionaryAttr::get(
735 module.getContext(),
736 ArrayRef<mlir::NamedAttribute>{NamedAttribute{
737 StringAttr::get(module.getContext(),
738 hw::HWModuleLike::getPortSymbolAttrName()),
739 module.getPortSymbolAttr(idx)}}),
740 module.getPortLocation(idx)};
741}
742
743::circt::hw::PortInfo FModuleOp::getPort(size_t idx) {
744 return ::getPortImpl(*this, idx);
745}
746
747::circt::hw::PortInfo FExtModuleOp::getPort(size_t idx) {
748 return ::getPortImpl(*this, idx);
749}
750
751::circt::hw::PortInfo FIntModuleOp::getPort(size_t idx) {
752 return ::getPortImpl(*this, idx);
753}
754
755::circt::hw::PortInfo FMemModuleOp::getPort(size_t idx) {
756 return ::getPortImpl(*this, idx);
757}
758
759// Return the port with the specified name.
760BlockArgument FModuleOp::getArgument(size_t portNumber) {
761 return getBodyBlock()->getArgument(portNumber);
762}
763
764/// Inserts the given ports. The insertion indices are expected to be in order.
765/// Insertion occurs in-order, such that ports with the same insertion index
766/// appear in the module in the same order they appeared in the list.
767static void insertPorts(FModuleLike op,
768 ArrayRef<std::pair<unsigned, PortInfo>> ports,
769 bool supportsInternalPaths = false) {
770 if (ports.empty())
771 return;
772 unsigned oldNumArgs = op.getNumPorts();
773 unsigned newNumArgs = oldNumArgs + ports.size();
774
775 // Add direction markers and names for new ports.
776 auto existingDirections = op.getPortDirectionsAttr();
777 ArrayRef<Attribute> existingNames = op.getPortNames();
778 ArrayRef<Attribute> existingTypes = op.getPortTypes();
779 ArrayRef<Attribute> existingLocs = op.getPortLocations();
780 assert(existingDirections.size() == oldNumArgs);
781 assert(existingNames.size() == oldNumArgs);
782 assert(existingTypes.size() == oldNumArgs);
783 assert(existingLocs.size() == oldNumArgs);
784 SmallVector<Attribute> internalPaths;
785 auto emptyInternalPath = InternalPathAttr::get(op.getContext());
786 if (supportsInternalPaths) {
787 if (auto internalPathsAttr = op->getAttrOfType<ArrayAttr>("internalPaths"))
788 llvm::append_range(internalPaths, internalPathsAttr);
789 else
790 internalPaths.resize(oldNumArgs, emptyInternalPath);
791 assert(internalPaths.size() == oldNumArgs);
792 }
793
794 SmallVector<bool> newDirections;
795 SmallVector<Attribute> newNames, newTypes, newAnnos, newSyms, newLocs,
796 newInternalPaths;
797 newDirections.reserve(newNumArgs);
798 newNames.reserve(newNumArgs);
799 newTypes.reserve(newNumArgs);
800 newAnnos.reserve(newNumArgs);
801 newSyms.reserve(newNumArgs);
802 newLocs.reserve(newNumArgs);
803 newInternalPaths.reserve(newNumArgs);
804
805 auto emptyArray = ArrayAttr::get(op.getContext(), {});
806
807 unsigned oldIdx = 0;
808 auto migrateOldPorts = [&](unsigned untilOldIdx) {
809 while (oldIdx < oldNumArgs && oldIdx < untilOldIdx) {
810 newDirections.push_back(existingDirections[oldIdx]);
811 newNames.push_back(existingNames[oldIdx]);
812 newTypes.push_back(existingTypes[oldIdx]);
813 newAnnos.push_back(op.getAnnotationsAttrForPort(oldIdx));
814 newSyms.push_back(op.getPortSymbolAttr(oldIdx));
815 newLocs.push_back(existingLocs[oldIdx]);
816 if (supportsInternalPaths)
817 newInternalPaths.push_back(internalPaths[oldIdx]);
818 ++oldIdx;
819 }
820 };
821 for (auto pair : llvm::enumerate(ports)) {
822 auto idx = pair.value().first;
823 auto &port = pair.value().second;
824 migrateOldPorts(idx);
825 newDirections.push_back(direction::unGet(port.direction));
826 newNames.push_back(port.name);
827 newTypes.push_back(TypeAttr::get(port.type));
828 auto annos = port.annotations.getArrayAttr();
829 newAnnos.push_back(annos ? annos : emptyArray);
830 newSyms.push_back(port.sym);
831 newLocs.push_back(port.loc);
832 if (supportsInternalPaths)
833 newInternalPaths.push_back(emptyInternalPath);
834 }
835 migrateOldPorts(oldNumArgs);
836
837 // The lack of *any* port annotations is represented by an empty
838 // `portAnnotations` array as a shorthand.
839 if (llvm::all_of(newAnnos, [](Attribute attr) {
840 return cast<ArrayAttr>(attr).empty();
841 }))
842 newAnnos.clear();
843
844 // Apply these changed markers.
845 op->setAttr("portDirections",
846 direction::packAttribute(op.getContext(), newDirections));
847 op->setAttr("portNames", ArrayAttr::get(op.getContext(), newNames));
848 op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newTypes));
849 op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newAnnos));
850 FModuleLike::fixupPortSymsArray(newSyms, op.getContext());
851 op.setPortSymbols(newSyms);
852 op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newLocs));
853 if (supportsInternalPaths) {
854 // Drop if all-empty, otherwise set to new array.
855 auto empty = llvm::all_of(newInternalPaths, [](Attribute attr) {
856 return !cast<InternalPathAttr>(attr).getPath();
857 });
858 if (empty)
859 op->removeAttr("internalPaths");
860 else
861 op->setAttr("internalPaths",
862 ArrayAttr::get(op.getContext(), newInternalPaths));
863 }
864}
865
866/// Erases the ports that have their corresponding bit set in `portIndices`.
867static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
868 if (portIndices.none())
869 return;
870
871 // Drop the direction markers for dead ports.
872 ArrayRef<bool> portDirections = op.getPortDirectionsAttr().asArrayRef();
873 ArrayRef<Attribute> portNames = op.getPortNames();
874 ArrayRef<Attribute> portTypes = op.getPortTypes();
875 ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
876 ArrayRef<Attribute> portSyms = op.getPortSymbols();
877 ArrayRef<Attribute> portLocs = op.getPortLocations();
878 auto numPorts = op.getNumPorts();
879 (void)numPorts;
880 assert(portDirections.size() == numPorts);
881 assert(portNames.size() == numPorts);
882 assert(portAnnos.size() == numPorts || portAnnos.empty());
883 assert(portTypes.size() == numPorts);
884 assert(portSyms.size() == numPorts || portSyms.empty());
885 assert(portLocs.size() == numPorts);
886
887 SmallVector<bool> newPortDirections =
888 removeElementsAtIndices<bool>(portDirections, portIndices);
889 SmallVector<Attribute> newPortNames, newPortTypes, newPortAnnos, newPortSyms,
890 newPortLocs;
891 newPortNames = removeElementsAtIndices(portNames, portIndices);
892 newPortTypes = removeElementsAtIndices(portTypes, portIndices);
893 newPortAnnos = removeElementsAtIndices(portAnnos, portIndices);
894 newPortSyms = removeElementsAtIndices(portSyms, portIndices);
895 newPortLocs = removeElementsAtIndices(portLocs, portIndices);
896 op->setAttr("portDirections",
897 direction::packAttribute(op.getContext(), newPortDirections));
898 op->setAttr("portNames", ArrayAttr::get(op.getContext(), newPortNames));
899 op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newPortAnnos));
900 op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newPortTypes));
901 FModuleLike::fixupPortSymsArray(newPortSyms, op.getContext());
902 op->setAttr("portSymbols", ArrayAttr::get(op.getContext(), newPortSyms));
903 op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newPortLocs));
904}
905
906template <typename T>
907static void eraseInternalPaths(T op, const llvm::BitVector &portIndices) {
908 // Fixup internalPaths array.
909 auto internalPaths = op.getInternalPaths();
910 if (!internalPaths)
911 return;
912
913 auto newPaths =
914 removeElementsAtIndices(internalPaths->getValue(), portIndices);
915
916 // Drop if all-empty, otherwise set to new array.
917 auto empty = llvm::all_of(newPaths, [](Attribute attr) {
918 return !cast<InternalPathAttr>(attr).getPath();
919 });
920 if (empty)
921 op.removeInternalPathsAttr();
922 else
923 op.setInternalPathsAttr(ArrayAttr::get(op.getContext(), newPaths));
924}
925
926void FExtModuleOp::erasePorts(const llvm::BitVector &portIndices) {
927 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
928 eraseInternalPaths(*this, portIndices);
929}
930
931void FIntModuleOp::erasePorts(const llvm::BitVector &portIndices) {
932 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
933 eraseInternalPaths(*this, portIndices);
934}
935
936void FMemModuleOp::erasePorts(const llvm::BitVector &portIndices) {
937 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
938}
939
940void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
941 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
942 getBodyBlock()->eraseArguments(portIndices);
943}
944
945/// Inserts the given ports. The insertion indices are expected to be in order.
946/// Insertion occurs in-order, such that ports with the same insertion index
947/// appear in the module in the same order they appeared in the list.
948void FModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
949 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
950
951 // Insert the block arguments.
952 auto *body = getBodyBlock();
953 for (size_t i = 0, e = ports.size(); i < e; ++i) {
954 // Block arguments are inserted one at a time, so for each argument we
955 // insert we have to increase the index by 1.
956 auto &[index, port] = ports[i];
957 body->insertArgument(index + i, port.type, port.loc);
958 }
959}
960
961void FExtModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
962 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
963 /*supportsInternalPaths=*/true);
964}
965
966void FIntModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
967 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
968 /*supportsInternalPaths=*/true);
969}
970
971/// Inserts the given ports. The insertion indices are expected to be in order.
972/// Insertion occurs in-order, such that ports with the same insertion index
973/// appear in the module in the same order they appeared in the list.
974void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
975 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
976}
977
978template <typename OpTy>
979void buildModuleLike(OpBuilder &builder, OperationState &result,
980 StringAttr name, ArrayRef<PortInfo> ports) {
981 // Add an attribute for the name.
982 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
983 properties.setSymName(name);
984
985 // Record the names of the arguments if present.
986 SmallVector<Direction, 4> portDirections;
987 SmallVector<Attribute, 4> portNames;
988 SmallVector<Attribute, 4> portTypes;
989 SmallVector<Attribute, 4> portSyms;
990 SmallVector<Attribute, 4> portLocs;
991 for (const auto &port : ports) {
992 portDirections.push_back(port.direction);
993 portNames.push_back(port.name);
994 portTypes.push_back(TypeAttr::get(port.type));
995 portSyms.push_back(port.sym);
996 portLocs.push_back(port.loc);
997 }
998
999 FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1000
1001 // Both attributes are added, even if the module has no ports.
1002 properties.setPortDirections(
1003 direction::packAttribute(builder.getContext(), portDirections));
1004 properties.setPortNames(builder.getArrayAttr(portNames));
1005 properties.setPortTypes(builder.getArrayAttr(portTypes));
1006 properties.setPortSymbols(builder.getArrayAttr(portSyms));
1007 properties.setPortLocations(builder.getArrayAttr(portLocs));
1008
1009 result.addRegion();
1010}
1011
1012template <typename OpTy>
1013static void buildModule(OpBuilder &builder, OperationState &result,
1014 StringAttr name, ArrayRef<PortInfo> ports,
1015 ArrayAttr annotations, ArrayAttr layers) {
1016 buildModuleLike<OpTy>(builder, result, name, ports);
1017 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
1018 // Annotations.
1019 if (!annotations)
1020 annotations = builder.getArrayAttr({});
1021 properties.setAnnotations(annotations);
1022
1023 // Port annotations. lack of *any* port annotations is represented by an empty
1024 // `portAnnotations` array as a shorthand.
1025 SmallVector<Attribute, 4> portAnnotations;
1026 for (const auto &port : ports)
1027 portAnnotations.push_back(port.annotations.getArrayAttr());
1028 if (llvm::all_of(portAnnotations, [](Attribute attr) {
1029 return cast<ArrayAttr>(attr).empty();
1030 }))
1031 portAnnotations.clear();
1032 properties.setPortAnnotations(builder.getArrayAttr(portAnnotations));
1033
1034 // Layers.
1035 if (!layers)
1036 layers = builder.getArrayAttr({});
1037 properties.setLayers(layers);
1038}
1039
1040template <typename OpTy>
1041static void buildClass(OpBuilder &builder, OperationState &result,
1042 StringAttr name, ArrayRef<PortInfo> ports) {
1043 return buildModuleLike<OpTy>(builder, result, name, ports);
1044}
1045
1046void FModuleOp::build(OpBuilder &builder, OperationState &result,
1047 StringAttr name, ConventionAttr convention,
1048 ArrayRef<PortInfo> ports, ArrayAttr annotations,
1049 ArrayAttr layers) {
1050 buildModule<FModuleOp>(builder, result, name, ports, annotations, layers);
1051 auto &properties = result.getOrAddProperties<Properties>();
1052 properties.setConvention(convention);
1053
1054 // Create a region and a block for the body.
1055 auto *bodyRegion = result.regions[0].get();
1056 Block *body = new Block();
1057 bodyRegion->push_back(body);
1058
1059 // Add arguments to the body block.
1060 for (auto &elt : ports)
1061 body->addArgument(elt.type, elt.loc);
1062}
1063
1064void FExtModuleOp::build(OpBuilder &builder, OperationState &result,
1065 StringAttr name, ConventionAttr convention,
1066 ArrayRef<PortInfo> ports, StringRef defnameAttr,
1067 ArrayAttr annotations, ArrayAttr parameters,
1068 ArrayAttr internalPaths, ArrayAttr layers) {
1069 buildModule<FExtModuleOp>(builder, result, name, ports, annotations, layers);
1070 auto &properties = result.getOrAddProperties<Properties>();
1071 properties.setConvention(convention);
1072 if (!defnameAttr.empty())
1073 properties.setDefname(builder.getStringAttr(defnameAttr));
1074 if (!parameters)
1075 parameters = builder.getArrayAttr({});
1076 properties.setParameters(parameters);
1077 if (internalPaths && !internalPaths.empty())
1078 properties.setInternalPaths(internalPaths);
1079}
1080
1081void FIntModuleOp::build(OpBuilder &builder, OperationState &result,
1082 StringAttr name, ArrayRef<PortInfo> ports,
1083 StringRef intrinsicNameStr, ArrayAttr annotations,
1084 ArrayAttr parameters, ArrayAttr internalPaths,
1085 ArrayAttr layers) {
1086 buildModule<FIntModuleOp>(builder, result, name, ports, annotations, layers);
1087 auto &properties = result.getOrAddProperties<Properties>();
1088 properties.setIntrinsic(builder.getStringAttr(intrinsicNameStr));
1089 if (!parameters)
1090 parameters = builder.getArrayAttr({});
1091 properties.setParameters(parameters);
1092 if (internalPaths && !internalPaths.empty())
1093 properties.setInternalPaths(internalPaths);
1094}
1095
1096void FMemModuleOp::build(OpBuilder &builder, OperationState &result,
1097 StringAttr name, ArrayRef<PortInfo> ports,
1098 uint32_t numReadPorts, uint32_t numWritePorts,
1099 uint32_t numReadWritePorts, uint32_t dataWidth,
1100 uint32_t maskBits, uint32_t readLatency,
1101 uint32_t writeLatency, uint64_t depth,
1102 ArrayAttr annotations, ArrayAttr layers) {
1103 auto *context = builder.getContext();
1104 buildModule<FMemModuleOp>(builder, result, name, ports, annotations, layers);
1105 auto ui32Type = IntegerType::get(context, 32, IntegerType::Unsigned);
1106 auto ui64Type = IntegerType::get(context, 64, IntegerType::Unsigned);
1107 auto &properties = result.getOrAddProperties<Properties>();
1108 properties.setNumReadPorts(IntegerAttr::get(ui32Type, numReadPorts));
1109 properties.setNumWritePorts(IntegerAttr::get(ui32Type, numWritePorts));
1110 properties.setNumReadWritePorts(
1111 IntegerAttr::get(ui32Type, numReadWritePorts));
1112 properties.setDataWidth(IntegerAttr::get(ui32Type, dataWidth));
1113 properties.setMaskBits(IntegerAttr::get(ui32Type, maskBits));
1114 properties.setReadLatency(IntegerAttr::get(ui32Type, readLatency));
1115 properties.setWriteLatency(IntegerAttr::get(ui32Type, writeLatency));
1116 properties.setDepth(IntegerAttr::get(ui64Type, depth));
1117 properties.setExtraPorts(ArrayAttr::get(context, {}));
1118}
1119
1120/// Print a list of module ports in the following form:
1121/// in x: !firrtl.uint<1> [{class = "DontTouch}], out "_port": !firrtl.uint<2>
1122///
1123/// When there is no block specified, the port names print as MLIR identifiers,
1124/// wrapping in quotes if not legal to print as-is. When there is no block
1125/// specified, this function always return false, indicating that there was no
1126/// issue printing port names.
1127///
1128/// If there is a block specified, then port names will be printed as SSA
1129/// values. If there is a reason the printed SSA values can't match the true
1130/// port name, then this function will return true. When this happens, the
1131/// caller should print the port names as a part of the `attr-dict`.
1132static bool
1133printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef<bool> portDirections,
1134 ArrayRef<Attribute> portNames, ArrayRef<Attribute> portTypes,
1135 ArrayRef<Attribute> portAnnotations,
1136 ArrayRef<Attribute> portSyms, ArrayRef<Attribute> portLocs) {
1137 // When printing port names as SSA values, we can fail to print them
1138 // identically.
1139 bool printedNamesDontMatch = false;
1140
1141 mlir::OpPrintingFlags flags;
1142
1143 // If we are printing the ports as block arguments the op must have a first
1144 // block.
1145 SmallString<32> resultNameStr;
1146 p << '(';
1147 for (unsigned i = 0, e = portTypes.size(); i < e; ++i) {
1148 if (i > 0)
1149 p << ", ";
1150
1151 // Print the port direction.
1152 p << direction::get(portDirections[i]) << " ";
1153
1154 // Print the port name. If there is a valid block, we print it as a block
1155 // argument.
1156 if (block) {
1157 // Get the printed format for the argument name.
1158 resultNameStr.clear();
1159 llvm::raw_svector_ostream tmpStream(resultNameStr);
1160 p.printOperand(block->getArgument(i), tmpStream);
1161 // If the name wasn't printable in a way that agreed with portName, make
1162 // sure to print out an explicit portNames attribute.
1163 auto portName = cast<StringAttr>(portNames[i]).getValue();
1164 if (tmpStream.str().drop_front() != portName)
1165 printedNamesDontMatch = true;
1166 p << tmpStream.str();
1167 } else {
1168 p.printKeywordOrString(cast<StringAttr>(portNames[i]).getValue());
1169 }
1170
1171 // Print the port type.
1172 p << ": ";
1173 auto portType = cast<TypeAttr>(portTypes[i]).getValue();
1174 p.printType(portType);
1175
1176 // Print the optional port symbol.
1177 if (!portSyms.empty()) {
1178 if (!cast<hw::InnerSymAttr>(portSyms[i]).empty()) {
1179 p << " sym ";
1180 cast<hw::InnerSymAttr>(portSyms[i]).print(p);
1181 }
1182 }
1183
1184 // Print the port specific annotations. The port annotations array will be
1185 // empty if there are none.
1186 if (!portAnnotations.empty() &&
1187 !cast<ArrayAttr>(portAnnotations[i]).empty()) {
1188 p << " ";
1189 p.printAttribute(portAnnotations[i]);
1190 }
1191
1192 // Print the port location.
1193 // TODO: `printOptionalLocationSpecifier` will emit aliases for locations,
1194 // even if they are not printed. This will have to be fixed upstream. For
1195 // now, use what was specified on the command line.
1196 if (flags.shouldPrintDebugInfo() && !portLocs.empty())
1197 p.printOptionalLocationSpecifier(cast<LocationAttr>(portLocs[i]));
1198 }
1199
1200 p << ')';
1201 return printedNamesDontMatch;
1202}
1203
1204/// Parse a list of module ports. If port names are SSA identifiers, then this
1205/// will populate `entryArgs`.
1206static ParseResult
1207parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers,
1208 bool supportsSymbols,
1209 SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
1210 SmallVectorImpl<Direction> &portDirections,
1211 SmallVectorImpl<Attribute> &portNames,
1212 SmallVectorImpl<Attribute> &portTypes,
1213 SmallVectorImpl<Attribute> &portAnnotations,
1214 SmallVectorImpl<Attribute> &portSyms,
1215 SmallVectorImpl<Attribute> &portLocs) {
1216 auto *context = parser.getContext();
1217
1218 auto parseArgument = [&]() -> ParseResult {
1219 // Parse port direction.
1220 if (succeeded(parser.parseOptionalKeyword("out")))
1221 portDirections.push_back(Direction::Out);
1222 else if (succeeded(parser.parseKeyword("in", "or 'out'")))
1223 portDirections.push_back(Direction::In);
1224 else
1225 return failure();
1226
1227 // This is the location or the port declaration in the IR. If there is no
1228 // other location information, we use this to point to the MLIR.
1229 llvm::SMLoc irLoc;
1230
1231 if (hasSSAIdentifiers) {
1232 OpAsmParser::Argument arg;
1233 if (parser.parseArgument(arg))
1234 return failure();
1235 entryArgs.push_back(arg);
1236
1237 // The name of an argument is of the form "%42" or "%id", and since
1238 // parsing succeeded, we know it always has one character.
1239 assert(arg.ssaName.name.size() > 1 && arg.ssaName.name[0] == '%' &&
1240 "Unknown MLIR name");
1241 if (isdigit(arg.ssaName.name[1]))
1242 portNames.push_back(StringAttr::get(context, ""));
1243 else
1244 portNames.push_back(
1245 StringAttr::get(context, arg.ssaName.name.drop_front()));
1246
1247 // Store the location of the SSA name.
1248 irLoc = arg.ssaName.location;
1249
1250 } else {
1251 // Parse the port name.
1252 irLoc = parser.getCurrentLocation();
1253 std::string portName;
1254 if (parser.parseKeywordOrString(&portName))
1255 return failure();
1256 portNames.push_back(StringAttr::get(context, portName));
1257 }
1258
1259 // Parse the port type.
1260 Type portType;
1261 if (parser.parseColonType(portType))
1262 return failure();
1263 portTypes.push_back(TypeAttr::get(portType));
1264
1265 if (hasSSAIdentifiers)
1266 entryArgs.back().type = portType;
1267
1268 // Parse the optional port symbol.
1269 if (supportsSymbols) {
1270 hw::InnerSymAttr innerSymAttr;
1271 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1272 NamedAttrList dummyAttrs;
1273 if (parser.parseCustomAttributeWithFallback(
1274 innerSymAttr, ::mlir::Type{},
1276 return ::mlir::failure();
1277 }
1278 }
1279 portSyms.push_back(innerSymAttr);
1280 }
1281
1282 // Parse the port annotations.
1283 ArrayAttr annos;
1284 auto parseResult = parser.parseOptionalAttribute(annos);
1285 if (!parseResult.has_value())
1286 annos = parser.getBuilder().getArrayAttr({});
1287 else if (failed(*parseResult))
1288 return failure();
1289 portAnnotations.push_back(annos);
1290
1291 // Parse the optional port location.
1292 std::optional<Location> maybeLoc;
1293 if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
1294 return failure();
1295 Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc);
1296 portLocs.push_back(loc);
1297 if (hasSSAIdentifiers)
1298 entryArgs.back().sourceLoc = loc;
1299
1300 return success();
1301 };
1302
1303 // Parse all ports.
1304 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1305 parseArgument);
1306}
1307
1308/// Print a paramter list for a module or instance.
1309static void printParameterList(OpAsmPrinter &p, Operation *op,
1310 ArrayAttr parameters) {
1311 if (!parameters || parameters.empty())
1312 return;
1313
1314 p << '<';
1315 llvm::interleaveComma(parameters, p, [&](Attribute param) {
1316 auto paramAttr = cast<ParamDeclAttr>(param);
1317 p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
1318 if (auto value = paramAttr.getValue()) {
1319 p << " = ";
1320 p.printAttributeWithoutType(value);
1321 }
1322 });
1323 p << '>';
1324}
1325
1326static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
1327 p << " ";
1328
1329 // Print the visibility of the module.
1330 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1331 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1332 p << visibility.getValue() << ' ';
1333
1334 // Print the operation and the function name.
1335 p.printSymbolName(op.getModuleName());
1336
1337 // Print the parameter list (if non-empty).
1338 printParameterList(p, op, op->getAttrOfType<ArrayAttr>("parameters"));
1339
1340 // Both modules and external modules have a body, but it is always empty for
1341 // external modules.
1342 Block *body = nullptr;
1343 if (!op->getRegion(0).empty())
1344 body = &op->getRegion(0).front();
1345
1346 auto needPortNamesAttr = printModulePorts(
1347 p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1348 op.getPortAnnotations(), op.getPortSymbols(), op.getPortLocations());
1349
1350 SmallVector<StringRef, 12> omittedAttrs = {
1351 "sym_name", "portDirections", "portTypes", "portAnnotations",
1352 "portSymbols", "portLocations", "parameters", visibilityAttrName};
1353
1354 if (op.getConvention() == Convention::Internal)
1355 omittedAttrs.push_back("convention");
1356
1357 // We can omit the portNames if they were able to be printed as properly as
1358 // block arguments.
1359 if (!needPortNamesAttr)
1360 omittedAttrs.push_back("portNames");
1361
1362 // If there are no annotations we can omit the empty array.
1363 if (op->getAttrOfType<ArrayAttr>("annotations").empty())
1364 omittedAttrs.push_back("annotations");
1365
1366 // If there are no enabled layers, then omit the empty array.
1367 if (auto layers = op->getAttrOfType<ArrayAttr>("layers"))
1368 if (layers.empty())
1369 omittedAttrs.push_back("layers");
1370
1371 p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1372}
1373
1374void FExtModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1375
1376void FIntModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1377
1378void FMemModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1379
1380void FModuleOp::print(OpAsmPrinter &p) {
1381 printFModuleLikeOp(p, *this);
1382
1383 // Print the body if this is not an external function. Since this block does
1384 // not have terminators, printing the terminator actually just prints the last
1385 // operation.
1386 Region &fbody = getBody();
1387 if (!fbody.empty()) {
1388 p << " ";
1389 p.printRegion(fbody, /*printEntryBlockArgs=*/false,
1390 /*printBlockTerminators=*/true);
1391 }
1392}
1393
1394/// Parse an parameter list if present.
1395/// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>`
1396/// parameter-decl ::= identifier `:` type
1397/// parameter-decl ::= identifier `:` type `=` attribute
1398///
1399static ParseResult
1400parseOptionalParameters(OpAsmParser &parser,
1401 SmallVectorImpl<Attribute> &parameters) {
1402
1403 return parser.parseCommaSeparatedList(
1404 OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1405 std::string name;
1406 Type type;
1407 Attribute value;
1408
1409 if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1410 return failure();
1411
1412 // Parse the default value if present.
1413 if (succeeded(parser.parseOptionalEqual())) {
1414 if (parser.parseAttribute(value, type))
1415 return failure();
1416 }
1417
1418 auto &builder = parser.getBuilder();
1419 parameters.push_back(ParamDeclAttr::get(
1420 builder.getContext(), builder.getStringAttr(name), type, value));
1421 return success();
1422 });
1423}
1424
1425/// Shim to use with assemblyFormat, custom<ParameterList>.
1426static ParseResult parseParameterList(OpAsmParser &parser,
1427 ArrayAttr &parameters) {
1428 SmallVector<Attribute> parseParameters;
1429 if (failed(parseOptionalParameters(parser, parseParameters)))
1430 return failure();
1431
1432 parameters = ArrayAttr::get(parser.getContext(), parseParameters);
1433
1434 return success();
1435}
1436
/// Compile-time trait: true when an op's `Properties` struct has a
/// `parameters` member. `parseFModuleLikeOp` uses it to parse a parameter list
/// only for ops that can store one.
template <typename Properties, typename = void>
struct HasParameters : std::false_type {};

template <typename Properties>
struct HasParameters<
    Properties, std::void_t<decltype(std::declval<Properties>().parameters)>>
    : std::true_type {};
1444
/// Parse the custom assembly shared by the FModuleLike ops:
///   [visibility] @name [<parameters>] (ports) [attributes {...}] [body]
///
/// Populates the common properties of `OpTy`:
///   - the symbol name;
///   - the parameter list, only for ops whose `Properties` has a `parameters`
///     member;
///   - the per-port arrays;
///   - an empty module `annotations` array.
/// Op-specific properties (e.g. convention, layers) are left to the caller.
///
/// When `hasSSAIdentifiers` is true, port names are parsed as SSA block
/// arguments and the body region is parsed with them as entry arguments.
/// Otherwise a single empty region is added.
template <typename OpTy>
static ParseResult parseFModuleLikeOp(OpAsmParser &parser,
                                      OperationState &result,
                                      bool hasSSAIdentifiers) {
  auto *context = result.getContext();
  auto &builder = parser.getBuilder();
  using Properties = typename OpTy::Properties;
  auto &properties = result.getOrAddProperties<Properties>();

  // TODO: this should be using properties.
  // Parse the visibility attribute into `result.attributes`. Its result is
  // deliberately ignored: the visibility keyword is optional.
  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr))
    return failure();
  properties.setSymName(nameAttr);

  // Parse optional parameters. Only ops with a `parameters` property accept a
  // parameter list; this is decided at compile time.
  if constexpr (HasParameters<Properties>::value) {
    SmallVector<Attribute, 4> parameters;
    if (parseOptionalParameters(parser, parameters))
      return failure();
    properties.setParameters(builder.getArrayAttr(parameters));
  }

  // Parse the module ports. `entryArgs` is only populated when the port names
  // are SSA identifiers.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Direction, 4> portDirections;
  SmallVector<Attribute, 4> portNames;
  SmallVector<Attribute, 4> portTypes;
  SmallVector<Attribute, 4> portAnnotations;
  SmallVector<Attribute, 4> portSyms;
  SmallVector<Attribute, 4> portLocs;
  if (parseModulePorts(parser, hasSSAIdentifiers, /*supportsSymbols=*/true,
                       entryArgs, portDirections, portNames, portTypes,
                       portAnnotations, portSyms, portLocs))
    return failure();

  // If module attributes are present, parse them into `result.attributes`.
  // NOTE(review): these are collected separately from the properties set
  // below; confirm how inherent names appearing in the dict (e.g. `portNames`,
  // which the printer emits when SSA names cannot represent the port names)
  // are reconciled with these properties at op creation.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  assert(portNames.size() == portTypes.size());

  // Record the per-port information as properties. This is necessary for
  // external modules, which have no block arguments to carry the signature.

  // Add port directions.
  properties.setPortDirections(
      direction::packAttribute(context, portDirections));

  // Add port names.
  properties.setPortNames(builder.getArrayAttr(portNames));

  // Add the port types.
  properties.setPortTypes(ArrayAttr::get(context, portTypes));

  // Add the port annotations. If no port has any annotations, store an empty
  // array, which is the shorthand for "no port annotations".
  if (llvm::any_of(portAnnotations, [&](Attribute anno) {
        return !cast<ArrayAttr>(anno).empty();
      }))
    properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));
  else
    properties.setPortAnnotations(builder.getArrayAttr({}));

  // Add port symbols, after normalizing them with `fixupPortSymsArray`.
  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
  properties.setPortSymbols(builder.getArrayAttr(portSyms));

  // Add port locations.
  properties.setPortLocations(ArrayAttr::get(context, portLocs));

  // The annotations attribute is always present, but not printed when empty.
  properties.setAnnotations(builder.getArrayAttr({}));

  // Parse the optional function body. A region is always added. Only ops whose
  // ports are SSA identifiers parse a body, and if the parsed region has no
  // blocks an empty one is inserted.
  auto *body = result.addRegion();

  if (hasSSAIdentifiers) {
    if (parser.parseRegion(*body, entryArgs))
      return failure();
    if (body->empty())
      body->push_back(new Block());
  }
  return success();
}
1534
1535ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1536 if (parseFModuleLikeOp<FModuleOp>(parser, result,
1537 /*hasSSAIdentifiers=*/true))
1538 return failure();
1539 auto &properties = result.getOrAddProperties<Properties>();
1540 properties.setConvention(
1541 ConventionAttr::get(result.getContext(), Convention::Internal));
1542 properties.setLayers(ArrayAttr::get(parser.getContext(), {}));
1543 return success();
1544}
1545
1546ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1547 if (parseFModuleLikeOp<FExtModuleOp>(parser, result,
1548 /*hasSSAIdentifiers=*/false))
1549 return failure();
1550 auto &properties = result.getOrAddProperties<Properties>();
1551 properties.setConvention(
1552 ConventionAttr::get(result.getContext(), Convention::Internal));
1553 return success();
1554}
1555
1556ParseResult FIntModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1557 return parseFModuleLikeOp<FIntModuleOp>(parser, result,
1558 /*hasSSAIdentifiers=*/false);
1559}
1560
1561ParseResult FMemModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1562 return parseFModuleLikeOp<FMemModuleOp>(parser, result,
1563 /*hasSSAIdentifiers=*/false);
1564}
1565
1566LogicalResult FModuleOp::verify() {
1567 // Verify the block arguments.
1568 auto *body = getBodyBlock();
1569 auto portTypes = getPortTypes();
1570 auto portLocs = getPortLocations();
1571 auto numPorts = portTypes.size();
1572
1573 // Verify that we have the correct number of block arguments.
1574 if (body->getNumArguments() != numPorts)
1575 return emitOpError("entry block must have ")
1576 << numPorts << " arguments to match module signature";
1577
1578 // Verify the block arguments' types and locations match our attributes.
1579 for (auto [arg, type, loc] : zip(body->getArguments(), portTypes, portLocs)) {
1580 if (arg.getType() != cast<TypeAttr>(type).getValue())
1581 return emitOpError("block argument types should match signature types");
1582 if (arg.getLoc() != cast<LocationAttr>(loc))
1583 return emitOpError(
1584 "block argument locations should match signature locations");
1585 }
1586
1587 return success();
1588}
1589
1590static LogicalResult
1591verifyInternalPaths(FModuleLike op,
1592 std::optional<::mlir::ArrayAttr> internalPaths) {
1593 if (!internalPaths)
1594 return success();
1595
1596 // If internal paths are present, should cover all ports.
1597 if (internalPaths->size() != op.getNumPorts())
1598 return op.emitError("module has inconsistent internal path array with ")
1599 << internalPaths->size() << " entries for " << op.getNumPorts()
1600 << " ports";
1601
1602 // No internal paths for non-ref-type ports.
1603 for (auto [idx, path, typeattr] : llvm::enumerate(
1604 internalPaths->getAsRange<InternalPathAttr>(), op.getPortTypes())) {
1605 if (path.getPath() &&
1606 !type_isa<RefType>(cast<TypeAttr>(typeattr).getValue())) {
1607 auto diag =
1608 op.emitError("module has internal path for non-ref-type port ")
1609 << op.getPortNameAttr(idx);
1610 return diag.attachNote(op.getPortLocation(idx)) << "this port";
1611 }
1612 }
1613
1614 return success();
1615}
1616
1617LogicalResult FExtModuleOp::verify() {
1618 if (failed(verifyInternalPaths(*this, getInternalPaths())))
1619 return failure();
1620
1621 auto params = getParameters();
1622 if (params.empty())
1623 return success();
1624
1625 auto checkParmValue = [&](Attribute elt) -> bool {
1626 auto param = cast<ParamDeclAttr>(elt);
1627 auto value = param.getValue();
1628 if (isa<IntegerAttr, StringAttr, FloatAttr, hw::ParamVerbatimAttr>(value))
1629 return true;
1630 emitError() << "has unknown extmodule parameter value '"
1631 << param.getName().getValue() << "' = " << value;
1632 return false;
1633 };
1634
1635 if (!llvm::all_of(params, checkParmValue))
1636 return failure();
1637
1638 return success();
1639}
1640
1641LogicalResult FIntModuleOp::verify() {
1642 if (failed(verifyInternalPaths(*this, getInternalPaths())))
1643 return failure();
1644
1645 auto params = getParameters();
1646 if (params.empty())
1647 return success();
1648
1649 auto checkParmValue = [&](Attribute elt) -> bool {
1650 auto param = cast<ParamDeclAttr>(elt);
1651 auto value = param.getValue();
1652 if (isa<IntegerAttr, StringAttr, FloatAttr>(value))
1653 return true;
1654 emitError() << "has unknown intmodule parameter value '"
1655 << param.getName().getValue() << "' = " << value;
1656 return false;
1657 };
1658
1659 if (!llvm::all_of(params, checkParmValue))
1660 return failure();
1661
1662 return success();
1663}
1664
1665static LogicalResult verifyProbeType(RefType refType, Location loc,
1666 CircuitOp circuitOp,
1667 SymbolTableCollection &symbolTable,
1668 Twine start) {
1669 auto layer = refType.getLayer();
1670 if (!layer)
1671 return success();
1672 auto *layerOp = symbolTable.lookupSymbolIn(circuitOp, layer);
1673 if (!layerOp)
1674 return emitError(loc) << start << " associated with layer '" << layer
1675 << "', but this layer was not defined";
1676 if (!isa<LayerOp>(layerOp)) {
1677 auto diag = emitError(loc)
1678 << start << " associated with layer '" << layer
1679 << "', but symbol '" << layer << "' does not refer to a '"
1680 << LayerOp::getOperationName() << "' op";
1681 return diag.attachNote(layerOp->getLoc()) << "symbol refers to this op";
1682 }
1683 return success();
1684}
1685
1686static LogicalResult verifyPortSymbolUses(FModuleLike module,
1687 SymbolTableCollection &symbolTable) {
1688 // verify types in ports.
1689 auto circuitOp = module->getParentOfType<CircuitOp>();
1690 for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
1691 auto type = module.getPortType(i);
1692
1693 if (auto refType = type_dyn_cast<RefType>(type)) {
1694 if (failed(verifyProbeType(
1695 refType, module.getPortLocation(i), circuitOp, symbolTable,
1696 Twine("probe port '") + module.getPortName(i) + "' is")))
1697 return failure();
1698 continue;
1699 }
1700
1701 if (auto classType = dyn_cast<ClassType>(type)) {
1702 auto className = classType.getNameAttr();
1703 auto classOp = dyn_cast_or_null<ClassLike>(
1704 symbolTable.lookupSymbolIn(circuitOp, className));
1705 if (!classOp)
1706 return module.emitOpError() << "references unknown class " << className;
1707
1708 // verify that the result type agrees with the class definition.
1709 if (failed(classOp.verifyType(classType,
1710 [&]() { return module.emitOpError(); })))
1711 return failure();
1712 continue;
1713 }
1714 }
1715
1716 return success();
1717}
1718
1719LogicalResult FModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1720 if (failed(
1721 verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable)))
1722 return failure();
1723
1724 auto circuitOp = (*this)->getParentOfType<CircuitOp>();
1725 for (auto layer : getLayers()) {
1726 if (!symbolTable.lookupSymbolIn(circuitOp, cast<SymbolRefAttr>(layer)))
1727 return emitOpError() << "enables unknown layer '" << layer << "'";
1728 }
1729
1730 return success();
1731}
1732
1733LogicalResult
1734FExtModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1735 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1736}
1737
1738LogicalResult
1739FIntModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1740 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1741}
1742
1743LogicalResult
1744FMemModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1745 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1746}
1747
1748void FModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1749 mlir::OpAsmSetValueNameFn setNameFn) {
1750 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1751}
1752
1753void FExtModuleOp::getAsmBlockArgumentNames(
1754 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1755 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1756}
1757
1758void FIntModuleOp::getAsmBlockArgumentNames(
1759 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1760 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1761}
1762
1763void FMemModuleOp::getAsmBlockArgumentNames(
1764 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1765 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1766}
1767
1768ArrayAttr FMemModuleOp::getParameters() { return {}; }
1769
1770ArrayAttr FModuleOp::getParameters() { return {}; }
1771
1772Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1773
1774ConventionAttr FIntModuleOp::getConventionAttr() {
1775 return ConventionAttr::get(getContext(), getConvention());
1776}
1777
1778Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1779
1780ConventionAttr FMemModuleOp::getConventionAttr() {
1781 return ConventionAttr::get(getContext(), getConvention());
1782}
1783
1784//===----------------------------------------------------------------------===//
1785// ClassLike Helpers
1786//===----------------------------------------------------------------------===//
1787
    ClassLike classOp, ClassType type,
    function_ref<InFlightDiagnostic()> emitError) {
  // Verify that `type` describes an instance of `classOp`:
  // - the class name must match;
  // - the type must have one element per class port;
  // - each element's name, direction, and type must match the port at the
  //   same index.
  // The first mismatch is reported through `emitError` and yields failure.

  // This name check is probably redundant for callers that resolved `classOp`
  // from the type's own name, but it is kept as a cheap consistency guard.
  auto name = type.getNameAttr().getAttr();
  auto expectedName = classOp.getModuleNameAttr();
  if (name != expectedName)
    return emitError() << "type has wrong name, got " << name << ", expected "
                       << expectedName;

  auto elements = type.getElements();
  auto numElements = elements.size();
  auto expectedNumElements = classOp.getNumPorts();
  if (numElements != expectedNumElements)
    return emitError() << "has wrong number of ports, got " << numElements
                       << ", expected " << expectedNumElements;

  auto portNames = classOp.getPortNames();
  auto portDirections = classOp.getPortDirections();
  auto portTypes = classOp.getPortTypes();

  // Element-wise comparison against the class ports. Inside the loop body,
  // `name`, `expectedName`, and `type` shadow the outer locals and the `type`
  // parameter.
  for (unsigned i = 0; i < numElements; ++i) {
    auto element = elements[i];

    auto name = element.name;
    auto expectedName = portNames[i];
    if (name != expectedName)
      return emitError() << "port #" << i << " has wrong name, got " << name
                         << ", expected " << expectedName;

    auto direction = element.direction;
    auto expectedDirection = Direction(portDirections[i]);
    if (direction != expectedDirection)
      return emitError() << "port " << name << " has wrong direction, got "
                         << direction::toString(direction) << ", expected "
                         << direction::toString(expectedDirection);

    auto type = element.type;
    auto expectedType = cast<TypeAttr>(portTypes[i]).getValue();
    if (type != expectedType)
      return emitError() << "port " << name << " has wrong type, got " << type
                         << ", expected " << expectedType;
  }

  return success();
}
1834
  // Build the ClassType for an instance of `classOp`: one element per port
  // (name, type, direction) in port order. The type is named by a flat symbol
  // reference to the class's name.
  auto n = classOp.getNumPorts();
  SmallVector<ClassElement> elements;
  elements.reserve(n);
  for (size_t i = 0; i < n; ++i)
    elements.push_back({classOp.getPortNameAttr(i), classOp.getPortType(i),
                        classOp.getPortDirection(i)});
  auto name = FlatSymbolRefAttr::get(classOp.getNameAttr());
  return ClassType::get(name, elements);
}
1845
1846template <typename OpTy>
1847ParseResult parseClassLike(OpAsmParser &parser, OperationState &result,
1848 bool hasSSAIdentifiers) {
1849 auto *context = result.getContext();
1850 auto &builder = parser.getBuilder();
1851 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
1852
1853 // TODO: this should use properties.
1854 // Parse the visibility attribute.
1855 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1856
1857 // Parse the name as a symbol.
1858 StringAttr nameAttr;
1859 if (parser.parseSymbolName(nameAttr))
1860 return failure();
1861 properties.setSymName(nameAttr);
1862
1863 // Parse the module ports.
1864 SmallVector<OpAsmParser::Argument> entryArgs;
1865 SmallVector<Direction, 4> portDirections;
1866 SmallVector<Attribute, 4> portNames;
1867 SmallVector<Attribute, 4> portTypes;
1868 SmallVector<Attribute, 4> portAnnotations;
1869 SmallVector<Attribute, 4> portSyms;
1870 SmallVector<Attribute, 4> portLocs;
1871 if (parseModulePorts(parser, hasSSAIdentifiers,
1872 /*supportsSymbols=*/false, entryArgs, portDirections,
1873 portNames, portTypes, portAnnotations, portSyms,
1874 portLocs))
1875 return failure();
1876
1877 // Ports on ClassLike ops cannot have annotations
1878 for (auto annos : portAnnotations)
1879 if (!cast<ArrayAttr>(annos).empty())
1880 return failure();
1881
1882 // If attributes are present, parse them.
1883 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1884 return failure();
1885
1886 assert(portNames.size() == portTypes.size());
1887
1888 // Record the argument and result types as an attribute. This is necessary
1889 // for external modules.
1890
1891 // Add port directions.
1892 properties.setPortDirections(
1893 direction::packAttribute(context, portDirections));
1894
1895 // Add port names.
1896 properties.setPortNames(builder.getArrayAttr(portNames));
1897
1898 // Add the port types.
1899 properties.setPortTypes(builder.getArrayAttr(portTypes));
1900
1901 // Add the port symbols.
1902 FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1903 properties.setPortSymbols(builder.getArrayAttr(portSyms));
1904
1905 // Add port locations.
1906 properties.setPortLocations(ArrayAttr::get(context, portLocs));
1907
1908 // Notably missing compared to other FModuleLike, we do not track port
1909 // annotations, nor port symbols, on classes.
1910
1911 // Add the region (unused by extclass).
1912 auto *bodyRegion = result.addRegion();
1913
1914 if (hasSSAIdentifiers) {
1915 if (parser.parseRegion(*bodyRegion, entryArgs))
1916 return failure();
1917 if (bodyRegion->empty())
1918 bodyRegion->push_back(new Block());
1919 }
1920
1921 return success();
1922}
1923
1924static void printClassLike(OpAsmPrinter &p, ClassLike op) {
1925 p << ' ';
1926
1927 // Print the visibility of the class.
1928 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1929 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1930 p << visibility.getValue() << ' ';
1931
1932 // Print the class name.
1933 p.printSymbolName(op.getName());
1934
1935 // Both classes and external classes have a body, but it is always empty for
1936 // external classes.
1937 Region &region = op->getRegion(0);
1938 Block *body = nullptr;
1939 if (!region.empty())
1940 body = &region.front();
1941
1942 auto needPortNamesAttr = printModulePorts(
1943 p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1944 {}, op.getPortSymbols(), op.getPortLocations());
1945
1946 // Print the attr-dict.
1947 SmallVector<StringRef, 8> omittedAttrs = {
1948 "sym_name", "portNames", "portTypes", "portDirections",
1949 "portSymbols", "portLocations", visibilityAttrName};
1950
1951 // We can omit the portNames if they were able to be printed as properly as
1952 // block arguments.
1953 if (!needPortNamesAttr)
1954 omittedAttrs.push_back("portNames");
1955
1956 p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1957
1958 // print the body if it exists.
1959 if (!region.empty()) {
1960 p << " ";
1961 auto printEntryBlockArgs = false;
1962 auto printBlockTerminators = false;
1963 p.printRegion(region, printEntryBlockArgs, printBlockTerminators);
1964 }
1965}
1966
1967//===----------------------------------------------------------------------===//
1968// ClassOp
1969//===----------------------------------------------------------------------===//
1970
1971void ClassOp::build(OpBuilder &builder, OperationState &result, StringAttr name,
1972 ArrayRef<PortInfo> ports) {
1973 assert(
1974 llvm::all_of(ports,
1975 [](const auto &port) { return port.annotations.empty(); }) &&
1976 "class ports may not have annotations");
1977
1978 buildClass<ClassOp>(builder, result, name, ports);
1979
1980 // Create a region and a block for the body.
1981 auto *bodyRegion = result.regions[0].get();
1982 Block *body = new Block();
1983 bodyRegion->push_back(body);
1984
1985 // Add arguments to the body block.
1986 for (auto &elt : ports)
1987 body->addArgument(elt.type, elt.loc);
1988}
1989
1990void ClassOp::build(::mlir::OpBuilder &odsBuilder,
1991 ::mlir::OperationState &odsState, Twine name,
1992 mlir::ArrayRef<mlir::StringRef> fieldNames,
1993 mlir::ArrayRef<mlir::Type> fieldTypes) {
1994
1995 SmallVector<PortInfo, 10> ports;
1996 for (auto [fieldName, fieldType] : llvm::zip(fieldNames, fieldTypes)) {
1997 ports.emplace_back(odsBuilder.getStringAttr(fieldName + "_in"), fieldType,
1998 Direction::In);
1999 ports.emplace_back(odsBuilder.getStringAttr(fieldName), fieldType,
2000 Direction::Out);
2001 }
2002 build(odsBuilder, odsState, odsBuilder.getStringAttr(name), ports);
2003 // Create a region and a block for the body.
2004 auto &body = odsState.regions[0]->getBlocks().front();
2005 auto prevLoc = odsBuilder.saveInsertionPoint();
2006 odsBuilder.setInsertionPointToEnd(&body);
2007 auto args = body.getArguments();
2008 auto loc = odsState.location;
2009 for (unsigned i = 0, e = ports.size(); i != e; i += 2)
2010 odsBuilder.create<PropAssignOp>(loc, args[i + 1], args[i]);
2011
2012 odsBuilder.restoreInsertionPoint(prevLoc);
2013}
2014void ClassOp::print(OpAsmPrinter &p) {
2015 printClassLike(p, cast<ClassLike>(getOperation()));
2016}
2017
2018ParseResult ClassOp::parse(OpAsmParser &parser, OperationState &result) {
2019 auto hasSSAIdentifiers = true;
2020 return parseClassLike<ClassOp>(parser, result, hasSSAIdentifiers);
2021}
2022
2023LogicalResult ClassOp::verify() {
2024 for (auto operand : getBodyBlock()->getArguments()) {
2025 auto type = operand.getType();
2026 if (!isa<PropertyType>(type)) {
2027 emitOpError("ports on a class must be properties");
2028 return failure();
2029 }
2030 }
2031
2032 return success();
2033}
2034
2035LogicalResult
2036ClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2037 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2038}
2039
2040void ClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2041 mlir::OpAsmSetValueNameFn setNameFn) {
2042 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2043}
2044
2045SmallVector<PortInfo> ClassOp::getPorts() {
2046 return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2047}
2048
2049void ClassOp::erasePorts(const llvm::BitVector &portIndices) {
2050 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2051 getBodyBlock()->eraseArguments(portIndices);
2052}
2053
2054void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2055 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2056}
2057
2058Convention ClassOp::getConvention() { return Convention::Internal; }
2059
2060ConventionAttr ClassOp::getConventionAttr() {
2061 return ConventionAttr::get(getContext(), getConvention());
2062}
2063
2064ArrayAttr ClassOp::getParameters() { return {}; }
2065
2066ArrayAttr ClassOp::getPortAnnotationsAttr() {
2067 return ArrayAttr::get(getContext(), {});
2068}
2069
2070ArrayRef<Attribute> ClassOp::getPortAnnotations() { return {}; }
2071
2072void ClassOp::setPortAnnotationsAttr(ArrayAttr annotations) {
2073 llvm_unreachable("classes do not support annotations");
2074}
2075
2076ArrayAttr ClassOp::getLayersAttr() { return ArrayAttr::get(getContext(), {}); }
2077
2078ArrayRef<Attribute> ClassOp::getLayers() { return {}; }
2079
2080SmallVector<::circt::hw::PortInfo> ClassOp::getPortList() {
2081 return ::getPortListImpl(*this);
2082}
2083
2084::circt::hw::PortInfo ClassOp::getPort(size_t idx) {
2085 return ::getPortImpl(*this, idx);
2086}
2087
2088BlockArgument ClassOp::getArgument(size_t portNumber) {
2089 return getBodyBlock()->getArgument(portNumber);
2090}
2091
2092bool ClassOp::canDiscardOnUseEmpty() {
2093 // ClassOps are referenced by ClassTypes, and these uses are not
2094 // discoverable by the symbol infrastructure. Return false here to prevent
2095 // passes like symbolDCE from removing our classes.
2096 return false;
2097}
2098
2099//===----------------------------------------------------------------------===//
2100// ExtClassOp
2101//===----------------------------------------------------------------------===//
2102
2103void ExtClassOp::build(OpBuilder &builder, OperationState &result,
2104 StringAttr name, ArrayRef<PortInfo> ports) {
2105 assert(
2106 llvm::all_of(ports,
2107 [](const auto &port) { return port.annotations.empty(); }) &&
2108 "class ports may not have annotations");
2109 buildClass<ClassOp>(builder, result, name, ports);
2110}
2111
2112void ExtClassOp::print(OpAsmPrinter &p) {
2113 printClassLike(p, cast<ClassLike>(getOperation()));
2114}
2115
2116ParseResult ExtClassOp::parse(OpAsmParser &parser, OperationState &result) {
2117 auto hasSSAIdentifiers = false;
2118 return parseClassLike<ExtClassOp>(parser, result, hasSSAIdentifiers);
2119}
2120
2121LogicalResult
2122ExtClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2123 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2124}
2125
2126void ExtClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2127 mlir::OpAsmSetValueNameFn setNameFn) {
2128 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2129}
2130
2131SmallVector<PortInfo> ExtClassOp::getPorts() {
2132 return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2133}
2134
2135void ExtClassOp::erasePorts(const llvm::BitVector &portIndices) {
2136 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2137}
2138
2139void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2140 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2141}
2142
2143Convention ExtClassOp::getConvention() { return Convention::Internal; }
2144
2145ConventionAttr ExtClassOp::getConventionAttr() {
2146 return ConventionAttr::get(getContext(), getConvention());
2147}
2148
2149ArrayAttr ExtClassOp::getLayersAttr() {
2150 return ArrayAttr::get(getContext(), {});
2151}
2152
2153ArrayRef<Attribute> ExtClassOp::getLayers() { return {}; }
2154
2155ArrayAttr ExtClassOp::getParameters() { return {}; }
2156
2157ArrayAttr ExtClassOp::getPortAnnotationsAttr() {
2158 return ArrayAttr::get(getContext(), {});
2159}
2160
2161ArrayRef<Attribute> ExtClassOp::getPortAnnotations() { return {}; }
2162
2163void ExtClassOp::setPortAnnotationsAttr(ArrayAttr annotations) {
2164 llvm_unreachable("classes do not support annotations");
2165}
2166
2167SmallVector<::circt::hw::PortInfo> ExtClassOp::getPortList() {
2168 return ::getPortListImpl(*this);
2169}
2170
2171::circt::hw::PortInfo ExtClassOp::getPort(size_t idx) {
2172 return ::getPortImpl(*this, idx);
2173}
2174
2175bool ExtClassOp::canDiscardOnUseEmpty() {
2176 // ClassOps are referenced by ClassTypes, and these uses are not
2177 // discovereable by the symbol infrastructure. Return false here to prevent
2178 // passes like symbolDCE from removing our classes.
2179 return false;
2180}
2181
2182//===----------------------------------------------------------------------===//
2183// LayerOp
2184//===----------------------------------------------------------------------===//
2185
2186LogicalResult LayerOp::verify() {
2187
2188 // A Bind Convention layer may not exist under an Inline Convention layer
2189 // because we haven't implemented a lowering for it. A lowering should be
2190 // possible, but it gets weird. This also transitively disallows a Bind
2191 // Convention under an Inline Convention under a Bind Convention. We don't
2192 // have a lowering for this either. Consequently, just reject this for now as
2193 // it's a niche use case.
2194 //
2195 // TODO: Remove this restriction by defining a lowering for bind-under-inline
2196 // and bind-under-inline-under-bind.
2197 if (getConvention() == LayerConvention::Bind) {
2198 Operation *parentOp = (*this)->getParentOp();
2199 while (auto parentLayer = dyn_cast<LayerOp>(parentOp)) {
2200 if (parentLayer.getConvention() == LayerConvention::Inline) {
2201 auto diag = emitOpError() << "has bind convention and cannot be nested "
2202 "under a layer with inline convention";
2203 diag.attachNote(parentLayer.getLoc())
2204 << "layer with inline convention here";
2205 return failure();
2206 }
2207 parentOp = parentOp->getParentOp();
2208 }
2209 }
2210
2211 return success();
2212}
2213
2214//===----------------------------------------------------------------------===//
2215// InstanceOp
2216//===----------------------------------------------------------------------===//
2217
2218void InstanceOp::build(
2219 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2220 StringRef moduleName, StringRef name, NameKindEnum nameKind,
2221 ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2222 ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2223 ArrayRef<Attribute> layers, bool lowerToBind, StringAttr innerSym) {
2224 build(builder, result, resultTypes, moduleName, name, nameKind,
2225 portDirections, portNames, annotations, portAnnotations, layers,
2226 lowerToBind,
2227 innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2228}
2229
2230void InstanceOp::build(
2231 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2232 StringRef moduleName, StringRef name, NameKindEnum nameKind,
2233 ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2234 ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2235 ArrayRef<Attribute> layers, bool lowerToBind, hw::InnerSymAttr innerSym) {
2236 result.addTypes(resultTypes);
2237 result.getOrAddProperties<Properties>().setModuleName(
2238 SymbolRefAttr::get(builder.getContext(), moduleName));
2239 result.getOrAddProperties<Properties>().setName(builder.getStringAttr(name));
2240 result.getOrAddProperties<Properties>().setPortDirections(
2241 direction::packAttribute(builder.getContext(), portDirections));
2242 result.getOrAddProperties<Properties>().setPortNames(
2243 builder.getArrayAttr(portNames));
2244 result.getOrAddProperties<Properties>().setAnnotations(
2245 builder.getArrayAttr(annotations));
2246 result.getOrAddProperties<Properties>().setLayers(
2247 builder.getArrayAttr(layers));
2248 if (lowerToBind)
2249 result.getOrAddProperties<Properties>().setLowerToBind(
2250 builder.getUnitAttr());
2251 if (innerSym)
2252 result.getOrAddProperties<Properties>().setInnerSym(innerSym);
2253
2254 result.getOrAddProperties<Properties>().setNameKind(
2255 NameKindEnumAttr::get(builder.getContext(), nameKind));
2256
2257 if (portAnnotations.empty()) {
2258 SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2259 builder.getArrayAttr({}));
2260 result.getOrAddProperties<Properties>().setPortAnnotations(
2261 builder.getArrayAttr(portAnnotationsVec));
2262 } else {
2263 assert(portAnnotations.size() == resultTypes.size());
2264 result.getOrAddProperties<Properties>().setPortAnnotations(
2265 builder.getArrayAttr(portAnnotations));
2266 }
2267}
2268
2269void InstanceOp::build(OpBuilder &builder, OperationState &result,
2270 FModuleLike module, StringRef name,
2271 NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2272 ArrayRef<Attribute> portAnnotations, bool lowerToBind,
2273 hw::InnerSymAttr innerSym) {
2274
2275 // Gather the result types.
2276 SmallVector<Type> resultTypes;
2277 resultTypes.reserve(module.getNumPorts());
2278 llvm::transform(
2279 module.getPortTypes(), std::back_inserter(resultTypes),
2280 [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2281
2282 // Create the port annotations.
2283 ArrayAttr portAnnotationsAttr;
2284 if (portAnnotations.empty()) {
2285 portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2286 resultTypes.size(), builder.getArrayAttr({})));
2287 } else {
2288 portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2289 }
2290
2291 return build(
2292 builder, result, resultTypes,
2293 SymbolRefAttr::get(builder.getContext(), module.getModuleNameAttr()),
2294 builder.getStringAttr(name),
2295 NameKindEnumAttr::get(builder.getContext(), nameKind),
2296 module.getPortDirectionsAttr(), module.getPortNamesAttr(),
2297 builder.getArrayAttr(annotations), portAnnotationsAttr,
2298 module.getLayersAttr(), lowerToBind ? builder.getUnitAttr() : UnitAttr(),
2299 innerSym);
2300}
2301
2302void InstanceOp::build(OpBuilder &builder, OperationState &odsState,
2303 ArrayRef<PortInfo> ports, StringRef moduleName,
2304 StringRef name, NameKindEnum nameKind,
2305 ArrayRef<Attribute> annotations,
2306 ArrayRef<Attribute> layers, bool lowerToBind,
2307 hw::InnerSymAttr innerSym) {
2308 // Gather the result types.
2309 SmallVector<Type> newResultTypes;
2310 SmallVector<Direction> newPortDirections;
2311 SmallVector<Attribute> newPortNames;
2312 SmallVector<Attribute> newPortAnnotations;
2313 for (auto &p : ports) {
2314 newResultTypes.push_back(p.type);
2315 newPortDirections.push_back(p.direction);
2316 newPortNames.push_back(p.name);
2317 newPortAnnotations.push_back(p.annotations.getArrayAttr());
2318 }
2319
2320 return build(builder, odsState, newResultTypes, moduleName, name, nameKind,
2321 newPortDirections, newPortNames, annotations, newPortAnnotations,
2322 layers, lowerToBind, innerSym);
2323}
2324
2325LogicalResult InstanceOp::verify() {
2326 // The instance may only be instantiated under its required layers.
2327 auto ambientLayers = getAmbientLayersAt(getOperation());
2328 SmallVector<SymbolRefAttr> missingLayers;
2329 for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2330 if (!isLayerCompatibleWith(layer, ambientLayers))
2331 missingLayers.push_back(layer);
2332
2333 if (missingLayers.empty())
2334 return success();
2335
2336 auto diag =
2337 emitOpError("ambient layers are insufficient to instantiate module");
2338 auto &note = diag.attachNote();
2339 note << "missing layer requirements: ";
2340 interleaveComma(missingLayers, note);
2341 return failure();
2342}
2343
2344/// Builds a new `InstanceOp` with the ports listed in `portIndices` erased, and
2345/// updates any users of the remaining ports to point at the new instance.
2346InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
2347 const llvm::BitVector &portIndices) {
2348 assert(portIndices.size() >= getNumResults() &&
2349 "portIndices is not at least as large as getNumResults()");
2350
2351 if (portIndices.none())
2352 return *this;
2353
2354 SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2355 SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2356 SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2357 direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2358 SmallVector<Attribute> newPortNames =
2359 removeElementsAtIndices(getPortNames().getValue(), portIndices);
2360 SmallVector<Attribute> newPortAnnotations =
2361 removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2362
2363 auto newOp = builder.create<InstanceOp>(
2364 getLoc(), newResultTypes, getModuleName(), getName(), getNameKind(),
2365 newPortDirections, newPortNames, getAnnotations().getValue(),
2366 newPortAnnotations, getLayers(), getLowerToBind(), getInnerSymAttr());
2367
2368 for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2369 oldIdx != numOldPorts; ++oldIdx) {
2370 if (portIndices.test(oldIdx)) {
2371 assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2372 continue;
2373 }
2374 getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2375 ++newIdx;
2376 }
2377
2378 // Compy over "output_file" information so that this is not lost when ports
2379 // are erased.
2380 //
2381 // TODO: Other attributes may need to be copied over.
2382 if (auto outputFile = (*this)->getAttr("output_file"))
2383 newOp->setAttr("output_file", outputFile);
2384
2385 return newOp;
2386}
2387
2388ArrayAttr InstanceOp::getPortAnnotation(unsigned portIdx) {
2389 assert(portIdx < getNumResults() &&
2390 "index should be smaller than result number");
2391 return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2392}
2393
2394void InstanceOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2395 assert(annotations.size() == getNumResults() &&
2396 "number of annotations is not equal to result number");
2397 (*this)->setAttr("portAnnotations",
2398 ArrayAttr::get(getContext(), annotations));
2399}
2400
2401InstanceOp
2402InstanceOp::cloneAndInsertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2403 auto portSize = ports.size();
2404 auto newPortCount = getNumResults() + portSize;
2405 SmallVector<Direction> newPortDirections;
2406 newPortDirections.reserve(newPortCount);
2407 SmallVector<Attribute> newPortNames;
2408 newPortNames.reserve(newPortCount);
2409 SmallVector<Type> newPortTypes;
2410 newPortTypes.reserve(newPortCount);
2411 SmallVector<Attribute> newPortAnnos;
2412 newPortAnnos.reserve(newPortCount);
2413
2414 unsigned oldIndex = 0;
2415 unsigned newIndex = 0;
2416 while (oldIndex + newIndex < newPortCount) {
2417 // Check if we should insert a port here.
2418 if (newIndex < portSize && ports[newIndex].first == oldIndex) {
2419 auto &newPort = ports[newIndex].second;
2420 newPortDirections.push_back(newPort.direction);
2421 newPortNames.push_back(newPort.name);
2422 newPortTypes.push_back(newPort.type);
2423 newPortAnnos.push_back(newPort.annotations.getArrayAttr());
2424 ++newIndex;
2425 } else {
2426 // Copy the next old port.
2427 newPortDirections.push_back(getPortDirection(oldIndex));
2428 newPortNames.push_back(getPortName(oldIndex));
2429 newPortTypes.push_back(getType(oldIndex));
2430 newPortAnnos.push_back(getPortAnnotation(oldIndex));
2431 ++oldIndex;
2432 }
2433 }
2434
2435 // Create a new instance op with the reset inserted.
2436 return OpBuilder(*this).create<InstanceOp>(
2437 getLoc(), newPortTypes, getModuleName(), getName(), getNameKind(),
2438 newPortDirections, newPortNames, getAnnotations().getValue(),
2439 newPortAnnos, getLayers(), getLowerToBind(), getInnerSymAttr());
2440}
2441
2442LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2443 return instance_like_impl::verifyReferencedModule(*this, symbolTable,
2444 getModuleNameAttr());
2445}
2446
2447StringRef InstanceOp::getInstanceName() { return getName(); }
2448
2449StringAttr InstanceOp::getInstanceNameAttr() { return getNameAttr(); }
2450
2451void InstanceOp::print(OpAsmPrinter &p) {
2452 // Print the instance name.
2453 p << " ";
2454 p.printKeywordOrString(getName());
2455 if (auto attr = getInnerSymAttr()) {
2456 p << " sym ";
2457 p.printSymbolName(attr.getSymName());
2458 }
2459 if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2460 p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2461
2462 // Print the attr-dict.
2463 SmallVector<StringRef, 10> omittedAttrs = {
2464 "moduleName", "name", "portDirections",
2465 "portNames", "portTypes", "portAnnotations",
2466 "inner_sym", "nameKind"};
2467 if (getAnnotations().empty())
2468 omittedAttrs.push_back("annotations");
2469 if (getLayers().empty())
2470 omittedAttrs.push_back("layers");
2471 p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2472
2473 // Print the module name.
2474 p << " ";
2475 p.printSymbolName(getModuleName());
2476
2477 // Collect all the result types as TypeAttrs for printing.
2478 SmallVector<Attribute> portTypes;
2479 portTypes.reserve(getNumResults());
2480 llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2481 &TypeAttr::get);
2482 printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2483 getPortNames().getValue(), portTypes,
2484 getPortAnnotations().getValue(), {}, {});
2485}
2486
/// Parse the custom assembly of an instance. The syntax, as read below, is:
///   name [sym <inner-sym>] [name-kind] [attr-dict] @Module(<ports>)
ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
  auto *context = parser.getContext();
  auto &properties = result.getOrAddProperties<Properties>();

  std::string name;
  hw::InnerSymAttr innerSymAttr;
  FlatSymbolRefAttr moduleName;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Direction, 4> portDirections;
  SmallVector<Attribute, 4> portNames;
  SmallVector<Attribute, 4> portTypes;
  SmallVector<Attribute, 4> portAnnotations;
  SmallVector<Attribute, 4> portSyms;
  SmallVector<Attribute, 4> portLocs;
  NameKindEnumAttr nameKind;

  // Instance name: a bare keyword or a quoted string.
  if (parser.parseKeywordOrString(&name))
    return failure();
  // Optional `sym` clause. The parser is given result.attributes to record the
  // parsed inner symbol; note that properties.setInnerSym is not called below.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(
            innerSymAttr, ::mlir::Type{},
            result.attributes)) {
      return ::mlir::failure();
    }
  }
  // Name kind, optional attr-dict, referenced module, then the port list.
  // Ports are parsed without SSA identifiers and without per-port symbols.
  if (parseNameKind(parser, nameKind) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(moduleName) ||
      parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
                       /*supportsSymbols=*/false, entryArgs, portDirections,
                       portNames, portTypes, portAnnotations, portSyms,
                       portLocs))
    return failure();

  // Add the attributes. We let attributes defined in the attr-dict override
  // attributes parsed out of the module signature.

  properties.setModuleName(moduleName);
  properties.setName(StringAttr::get(context, name));
  properties.setNameKind(nameKind);
  properties.setPortDirections(
      direction::packAttribute(context, portDirections));
  properties.setPortNames(ArrayAttr::get(context, portNames));
  properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));

  // Annotations, layers, and LowerToBind are omitted in the printed format
  // if they are empty, empty, and false (respectively).
  properties.setAnnotations(parser.getBuilder().getArrayAttr({}));
  properties.setLayers(parser.getBuilder().getArrayAttr({}));

  // Add result types: one result per port, with that port's type.
  result.types.reserve(portTypes.size());
  llvm::transform(
      portTypes, std::back_inserter(result.types),
      [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

  return success();
}
2546
2547void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
2548 StringRef base = getName();
2549 if (base.empty())
2550 base = "inst";
2551
2552 for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
2553 setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
2554 }
2555}
2556
2557std::optional<size_t> InstanceOp::getTargetResultIndex() {
2558 // Inner symbols on instance operations target the op not any result.
2559 return std::nullopt;
2560}
2561
2562// -----------------------------------------------------------------------------
2563// InstanceChoiceOp
2564// -----------------------------------------------------------------------------
2565
2566void InstanceChoiceOp::build(
2567 OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
2568 ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
2569 NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2570 ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2571 // Gather the result types.
2572 SmallVector<Type> resultTypes;
2573 for (Attribute portType : defaultModule.getPortTypes())
2574 resultTypes.push_back(cast<TypeAttr>(portType).getValue());
2575
2576 // Create the port annotations.
2577 ArrayAttr portAnnotationsAttr;
2578 if (portAnnotations.empty()) {
2579 portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2580 resultTypes.size(), builder.getArrayAttr({})));
2581 } else {
2582 portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2583 }
2584
2585 // Gather the module & case names.
2586 SmallVector<Attribute> moduleNames, caseNames;
2587 moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
2588 for (auto [caseOption, caseModule] : cases) {
2589 auto caseGroup = caseOption->getParentOfType<OptionOp>();
2590 caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
2591 {SymbolRefAttr::get(caseOption)}));
2592 moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
2593 }
2594
2595 return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
2596 builder.getArrayAttr(caseNames), builder.getStringAttr(name),
2597 NameKindEnumAttr::get(builder.getContext(), nameKind),
2598 defaultModule.getPortDirectionsAttr(),
2599 defaultModule.getPortNamesAttr(),
2600 builder.getArrayAttr(annotations), portAnnotationsAttr,
2601 defaultModule.getLayersAttr(),
2602 innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2603}
2604
2605std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
2606 return std::nullopt;
2607}
2608
2609void InstanceChoiceOp::print(OpAsmPrinter &p) {
2610 // Print the instance name.
2611 p << " ";
2612 p.printKeywordOrString(getName());
2613 if (auto attr = getInnerSymAttr()) {
2614 p << " sym ";
2615 p.printSymbolName(attr.getSymName());
2616 }
2617 if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2618 p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2619
2620 // Print the attr-dict.
2621 SmallVector<StringRef, 10> omittedAttrs = {
2622 "moduleNames", "caseNames", "name",
2623 "portDirections", "portNames", "portTypes",
2624 "portAnnotations", "inner_sym", "nameKind"};
2625 if (getAnnotations().empty())
2626 omittedAttrs.push_back("annotations");
2627 if (getLayers().empty())
2628 omittedAttrs.push_back("layers");
2629 p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2630
2631 // Print the module name.
2632 p << ' ';
2633
2634 auto moduleNames = getModuleNamesAttr();
2635 auto caseNames = getCaseNamesAttr();
2636
2637 p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[0]).getValue());
2638
2639 p << " alternatives ";
2640 p.printSymbolName(
2641 cast<SymbolRefAttr>(caseNames[0]).getRootReference().getValue());
2642 p << " { ";
2643 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2644 if (i != 0)
2645 p << ", ";
2646
2647 auto symbol = cast<SymbolRefAttr>(caseNames[i]);
2648 p.printSymbolName(symbol.getNestedReferences()[0].getValue());
2649 p << " -> ";
2650 p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[i + 1]).getValue());
2651 }
2652
2653 p << " } ";
2654
2655 // Collect all the result types as TypeAttrs for printing.
2656 SmallVector<Attribute> portTypes;
2657 portTypes.reserve(getNumResults());
2658 llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2659 &TypeAttr::get);
2660 printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2661 getPortNames().getValue(), portTypes,
2662 getPortAnnotations().getValue(), {}, {});
2663}
2664
2665ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
2666 OperationState &result) {
2667 auto *context = parser.getContext();
2668 auto &properties = result.getOrAddProperties<Properties>();
2669
2670 std::string name;
2671 hw::InnerSymAttr innerSymAttr;
2672 SmallVector<Attribute> moduleNames;
2673 SmallVector<Attribute> caseNames;
2674 SmallVector<OpAsmParser::Argument> entryArgs;
2675 SmallVector<Direction, 4> portDirections;
2676 SmallVector<Attribute, 4> portNames;
2677 SmallVector<Attribute, 4> portTypes;
2678 SmallVector<Attribute, 4> portAnnotations;
2679 SmallVector<Attribute, 4> portSyms;
2680 SmallVector<Attribute, 4> portLocs;
2681 NameKindEnumAttr nameKind;
2682
2683 if (parser.parseKeywordOrString(&name))
2684 return failure();
2685 if (succeeded(parser.parseOptionalKeyword("sym"))) {
2686 if (parser.parseCustomAttributeWithFallback(
2687 innerSymAttr, Type{},
2689 result.attributes)) {
2690 return failure();
2691 }
2692 }
2693 if (parseNameKind(parser, nameKind) ||
2694 parser.parseOptionalAttrDict(result.attributes))
2695 return failure();
2696
2697 FlatSymbolRefAttr defaultModuleName;
2698 if (parser.parseAttribute(defaultModuleName))
2699 return failure();
2700 moduleNames.push_back(defaultModuleName);
2701
2702 // alternatives { @opt::@case -> @target, ... }
2703 {
2704 FlatSymbolRefAttr optionName;
2705 if (parser.parseKeyword("alternatives") ||
2706 parser.parseAttribute(optionName) || parser.parseLBrace())
2707 return failure();
2708
2709 FlatSymbolRefAttr moduleName;
2710 StringAttr caseName;
2711 while (succeeded(parser.parseOptionalSymbolName(caseName))) {
2712 if (parser.parseArrow() || parser.parseAttribute(moduleName))
2713 return failure();
2714 moduleNames.push_back(moduleName);
2715 caseNames.push_back(SymbolRefAttr::get(
2716 optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
2717 if (failed(parser.parseOptionalComma()))
2718 break;
2719 }
2720 if (parser.parseRBrace())
2721 return failure();
2722 }
2723
2724 if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2725 /*supportsSymbols=*/false, entryArgs, portDirections,
2726 portNames, portTypes, portAnnotations, portSyms,
2727 portLocs))
2728 return failure();
2729
2730 // Add the attributes. We let attributes defined in the attr-dict override
2731 // attributes parsed out of the module signature.
2732 properties.setModuleNames(ArrayAttr::get(context, moduleNames));
2733 properties.setCaseNames(ArrayAttr::get(context, caseNames));
2734 properties.setName(StringAttr::get(context, name));
2735 properties.setNameKind(nameKind);
2736 properties.setPortDirections(
2737 direction::packAttribute(context, portDirections));
2738 properties.setPortNames(ArrayAttr::get(context, portNames));
2739 properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));
2740
2741 // Annotations, layers, and LowerToBind are omitted in the printed format if
2742 // they are empty, empty, and false (respectively).
2743 properties.setAnnotations(parser.getBuilder().getArrayAttr({}));
2744 properties.setLayers(parser.getBuilder().getArrayAttr({}));
2745
2746 // Add result types.
2747 result.types.reserve(portTypes.size());
2748 llvm::transform(
2749 portTypes, std::back_inserter(result.types),
2750 [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2751
2752 return success();
2753}
2754
2755void InstanceChoiceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
2756 StringRef base = getName().empty() ? "inst" : getName();
2757 for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
2758 setNameFn(result, (base + "_" + cast<StringAttr>(name).getValue()).str());
2759}
2760
2761LogicalResult InstanceChoiceOp::verify() {
2762 if (getCaseNamesAttr().empty())
2763 return emitOpError() << "must have at least one case";
2764 if (getModuleNamesAttr().size() != getCaseNamesAttr().size() + 1)
2765 return emitOpError() << "number of referenced modules does not match the "
2766 "number of options";
2767
2768 // The modules may only be instantiated under their required layers (which
2769 // are the same for all modules).
2770 auto ambientLayers = getAmbientLayersAt(getOperation());
2771 SmallVector<SymbolRefAttr> missingLayers;
2772 for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2773 if (!isLayerCompatibleWith(layer, ambientLayers))
2774 missingLayers.push_back(layer);
2775
2776 if (missingLayers.empty())
2777 return success();
2778
2779 auto diag =
2780 emitOpError("ambient layers are insufficient to instantiate module");
2781 auto &note = diag.attachNote();
2782 note << "missing layer requirements: ";
2783 interleaveComma(missingLayers, note);
2784 return failure();
2785}
2786
2787LogicalResult
2788InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2789 auto caseNames = getCaseNamesAttr();
2790 for (auto moduleName : getModuleNamesAttr()) {
2792 *this, symbolTable, cast<FlatSymbolRefAttr>(moduleName))))
2793 return failure();
2794 }
2795
2796 auto root = cast<SymbolRefAttr>(caseNames[0]).getRootReference();
2797 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2798 auto ref = cast<SymbolRefAttr>(caseNames[i]);
2799 auto refRoot = ref.getRootReference();
2800 if (ref.getRootReference() != root)
2801 return emitOpError() << "case " << ref
2802 << " is not in the same option group as "
2803 << caseNames[0];
2804
2805 if (!symbolTable.lookupNearestSymbolFrom<OptionOp>(*this, refRoot))
2806 return emitOpError() << "option " << refRoot << " does not exist";
2807
2808 if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
2809 return emitOpError() << "option " << refRoot
2810 << " does not contain option case " << ref;
2811 }
2812
2813 return success();
2814}
2815
2816FlatSymbolRefAttr
2817InstanceChoiceOp::getTargetOrDefaultAttr(OptionCaseOp option) {
2818 auto caseNames = getCaseNamesAttr();
2819 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2820 StringAttr caseSym = cast<SymbolRefAttr>(caseNames[i]).getLeafReference();
2821 if (caseSym == option.getSymName())
2822 return cast<FlatSymbolRefAttr>(getModuleNamesAttr()[i + 1]);
2823 }
2824 return getDefaultTargetAttr();
2825}
2826
2827SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1>
2828InstanceChoiceOp::getTargetChoices() {
2829 auto caseNames = getCaseNamesAttr();
2830 auto moduleNames = getModuleNamesAttr();
2831 SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1> choices;
2832 for (size_t i = 0; i < caseNames.size(); ++i) {
2833 choices.emplace_back(cast<SymbolRefAttr>(caseNames[i]),
2834 cast<FlatSymbolRefAttr>(moduleNames[i + 1]));
2835 }
2836
2837 return choices;
2838}
2839
2840InstanceChoiceOp
2841InstanceChoiceOp::erasePorts(OpBuilder &builder,
2842 const llvm::BitVector &portIndices) {
2843 assert(portIndices.size() >= getNumResults() &&
2844 "portIndices is not at least as large as getNumResults()");
2845
2846 if (portIndices.none())
2847 return *this;
2848
2849 SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2850 SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2851 SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2852 direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2853 SmallVector<Attribute> newPortNames =
2854 removeElementsAtIndices(getPortNames().getValue(), portIndices);
2855 SmallVector<Attribute> newPortAnnotations =
2856 removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2857
2858 auto newOp = builder.create<InstanceChoiceOp>(
2859 getLoc(), newResultTypes, getModuleNames(), getCaseNames(), getName(),
2860 getNameKind(), direction::packAttribute(getContext(), newPortDirections),
2861 ArrayAttr::get(getContext(), newPortNames), getAnnotationsAttr(),
2862 ArrayAttr::get(getContext(), newPortAnnotations), getLayers(),
2863 getInnerSymAttr());
2864
2865 for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2866 oldIdx != numOldPorts; ++oldIdx) {
2867 if (portIndices.test(oldIdx)) {
2868 assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2869 continue;
2870 }
2871 getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2872 ++newIdx;
2873 }
2874
2875 // Copy over "output_file" information so that this is not lost when ports
2876 // are erased.
2877 //
2878 // TODO: Other attributes may need to be copied over.
2879 if (auto outputFile = (*this)->getAttr("output_file"))
2880 newOp->setAttr("output_file", outputFile);
2881
2882 return newOp;
2883}
2884
2885//===----------------------------------------------------------------------===//
2886// MemOp
2887//===----------------------------------------------------------------------===//
2888
2889ArrayAttr MemOp::getPortAnnotation(unsigned portIdx) {
2890 assert(portIdx < getNumResults() &&
2891 "index should be smaller than result number");
2892 return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2893}
2894
2895void MemOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2896 assert(annotations.size() == getNumResults() &&
2897 "number of annotations is not equal to result number");
2898 (*this)->setAttr("portAnnotations",
2899 ArrayAttr::get(getContext(), annotations));
2900}
2901
2902// Get the number of read, write and read-write ports.
2903void MemOp::getNumPorts(size_t &numReadPorts, size_t &numWritePorts,
2904 size_t &numReadWritePorts, size_t &numDbgsPorts) {
2905 numReadPorts = 0;
2906 numWritePorts = 0;
2907 numReadWritePorts = 0;
2908 numDbgsPorts = 0;
2909 for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2910 auto portKind = getPortKind(i);
2911 if (portKind == MemOp::PortKind::Debug)
2912 ++numDbgsPorts;
2913 else if (portKind == MemOp::PortKind::Read)
2914 ++numReadPorts;
2915 else if (portKind == MemOp::PortKind::Write) {
2916 ++numWritePorts;
2917 } else
2918 ++numReadWritePorts;
2919 }
2920}
2921
2922/// Verify the correctness of a MemOp.
2923LogicalResult MemOp::verify() {
2924
2925 // Store the port names as we find them. This lets us check quickly
2926 // for uniqueneess.
2927 llvm::SmallDenseSet<Attribute, 8> portNamesSet;
2928
2929 // Store the previous data type. This lets us check that the data
2930 // type is consistent across all ports.
2931 FIRRTLType oldDataType;
2932
2933 for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2934 auto portName = getPortName(i);
2935
2936 // Get a bundle type representing this port, stripping an outer
2937 // flip if it exists. If this is not a bundle<> or
2938 // flip<bundle<>>, then this is an error.
2939 BundleType portBundleType =
2940 type_dyn_cast<BundleType>(getResult(i).getType());
2941
2942 // Require that all port names are unique.
2943 if (!portNamesSet.insert(portName).second) {
2944 emitOpError() << "has non-unique port name " << portName;
2945 return failure();
2946 }
2947
2948 // Determine the kind of the memory. If the kind cannot be
2949 // determined, then it's indicative of the wrong number of fields
2950 // in the type (but we don't know any more just yet).
2951
2952 auto elt = getPortNamed(portName);
2953 if (!elt) {
2954 emitOpError() << "could not get port with name " << portName;
2955 return failure();
2956 }
2957 auto firrtlType = type_cast<FIRRTLType>(elt.getType());
2958 MemOp::PortKind portKind = getMemPortKindFromType(firrtlType);
2959
2960 if (portKind == MemOp::PortKind::Debug &&
2961 !type_isa<RefType>(getResult(i).getType()))
2962 return emitOpError() << "has an invalid type on port " << portName
2963 << " (expected Read/Write/ReadWrite/Debug)";
2964 if (type_isa<RefType>(firrtlType) && e == 1)
2965 return emitOpError()
2966 << "cannot have only one port of debug type. Debug port can only "
2967 "exist alongside other read/write/read-write port";
2968
2969 // Safely search for the "data" field, erroring if it can't be
2970 // found.
2971 FIRRTLBaseType dataType;
2972 if (portKind == MemOp::PortKind::Debug) {
2973 auto resType = type_cast<RefType>(getResult(i).getType());
2974 if (!(resType && type_isa<FVectorType>(resType.getType())))
2975 return emitOpError() << "debug ports must be a RefType of FVectorType";
2976 dataType = type_cast<FVectorType>(resType.getType()).getElementType();
2977 } else {
2978 auto dataTypeOption = portBundleType.getElement("data");
2979 if (!dataTypeOption && portKind == MemOp::PortKind::ReadWrite)
2980 dataTypeOption = portBundleType.getElement("wdata");
2981 if (!dataTypeOption) {
2982 emitOpError() << "has no data field on port " << portName
2983 << " (expected to see \"data\" for a read or write "
2984 "port or \"rdata\" for a read/write port)";
2985 return failure();
2986 }
2987 dataType = dataTypeOption->type;
2988 // Read data is expected to ba a flip.
2989 if (portKind == MemOp::PortKind::Read) {
2990 // FIXME error on missing bundle flip
2991 }
2992 }
2993
2994 // Error if the data type isn't passive.
2995 if (!dataType.isPassive()) {
2996 emitOpError() << "has non-passive data type on port " << portName
2997 << " (memory types must be passive)";
2998 return failure();
2999 }
3000
3001 // Error if the data type contains analog types.
3002 if (dataType.containsAnalog()) {
3003 emitOpError() << "has a data type that contains an analog type on port "
3004 << portName
3005 << " (memory types cannot contain analog types)";
3006 return failure();
3007 }
3008
3009 // Check that the port type matches the kind that we determined
3010 // for this port. This catches situations of extraneous port
3011 // fields beind included or the fields being named incorrectly.
3012 FIRRTLType expectedType =
3013 getTypeForPort(getDepth(), dataType, portKind,
3014 dataType.isGround() ? getMaskBits() : 0);
3015 // Compute the original port type as portBundleType may have
3016 // stripped outer flip information.
3017 auto originalType = getResult(i).getType();
3018 if (originalType != expectedType) {
3019 StringRef portKindName;
3020 switch (portKind) {
3021 case MemOp::PortKind::Read:
3022 portKindName = "read";
3023 break;
3024 case MemOp::PortKind::Write:
3025 portKindName = "write";
3026 break;
3027 case MemOp::PortKind::ReadWrite:
3028 portKindName = "readwrite";
3029 break;
3030 case MemOp::PortKind::Debug:
3031 portKindName = "dbg";
3032 break;
3033 }
3034 emitOpError() << "has an invalid type for port " << portName
3035 << " of determined kind \"" << portKindName
3036 << "\" (expected " << expectedType << ", but got "
3037 << originalType << ")";
3038 return failure();
3039 }
3040
3041 // Error if the type of the current port was not the same as the
3042 // last port, but skip checking the first port.
3043 if (oldDataType && oldDataType != dataType) {
3044 emitOpError() << "port " << getPortName(i)
3045 << " has a different type than port " << getPortName(i - 1)
3046 << " (expected " << oldDataType << ", but got " << dataType
3047 << ")";
3048 return failure();
3049 }
3050
3051 oldDataType = dataType;
3052 }
3053
3054 auto maskWidth = getMaskBits();
3055
3056 auto dataWidth = getDataType().getBitWidthOrSentinel();
3057 if (dataWidth > 0 && maskWidth > (size_t)dataWidth)
3058 return emitOpError("the mask width cannot be greater than "
3059 "data width");
3060
3061 if (getPortAnnotations().size() != getNumResults())
3062 return emitOpError("the number of result annotations should be "
3063 "equal to the number of results");
3064
3065 return success();
3066}
3067
3068static size_t getAddressWidth(size_t depth) {
3069 return std::max(1U, llvm::Log2_64_Ceil(depth));
3070}
3071
3072size_t MemOp::getAddrBits() { return getAddressWidth(getDepth()); }
3073
3074FIRRTLType MemOp::getTypeForPort(uint64_t depth, FIRRTLBaseType dataType,
3075 PortKind portKind, size_t maskBits) {
3076
3077 auto *context = dataType.getContext();
3078 if (portKind == PortKind::Debug)
3079 return RefType::get(FVectorType::get(dataType, depth));
3080 FIRRTLBaseType maskType;
3081 // maskBits not specified (==0), then get the mask type from the dataType.
3082 if (maskBits == 0)
3083 maskType = dataType.getMaskType();
3084 else
3085 maskType = UIntType::get(context, maskBits);
3086
3087 auto getId = [&](StringRef name) -> StringAttr {
3088 return StringAttr::get(context, name);
3089 };
3090
3091 SmallVector<BundleType::BundleElement, 7> portFields;
3092
3093 auto addressType = UIntType::get(context, getAddressWidth(depth));
3094
3095 portFields.push_back({getId("addr"), false, addressType});
3096 portFields.push_back({getId("en"), false, UIntType::get(context, 1)});
3097 portFields.push_back({getId("clk"), false, ClockType::get(context)});
3098
3099 switch (portKind) {
3100 case PortKind::Read:
3101 portFields.push_back({getId("data"), true, dataType});
3102 break;
3103
3104 case PortKind::Write:
3105 portFields.push_back({getId("data"), false, dataType});
3106 portFields.push_back({getId("mask"), false, maskType});
3107 break;
3108
3109 case PortKind::ReadWrite:
3110 portFields.push_back({getId("rdata"), true, dataType});
3111 portFields.push_back({getId("wmode"), false, UIntType::get(context, 1)});
3112 portFields.push_back({getId("wdata"), false, dataType});
3113 portFields.push_back({getId("wmask"), false, maskType});
3114 break;
3115 default:
3116 llvm::report_fatal_error("memory port kind not handled");
3117 break;
3118 }
3119
3120 return BundleType::get(context, portFields);
3121}
3122
3123/// Return the name and kind of ports supported by this memory.
3124SmallVector<MemOp::NamedPort> MemOp::getPorts() {
3125 SmallVector<MemOp::NamedPort> result;
3126 // Each entry in the bundle is a port.
3127 for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3128 // Each port is a bundle.
3129 auto portType = type_cast<FIRRTLType>(getResult(i).getType());
3130 result.push_back({getPortName(i), getMemPortKindFromType(portType)});
3131 }
3132 return result;
3133}
3134
3135/// Return the kind of the specified port.
3136MemOp::PortKind MemOp::getPortKind(StringRef portName) {
3138 type_cast<FIRRTLType>(getPortNamed(portName).getType()));
3139}
3140
3141/// Return the kind of the specified port number.
3142MemOp::PortKind MemOp::getPortKind(size_t resultNo) {
3144 type_cast<FIRRTLType>(getResult(resultNo).getType()));
3145}
3146
3147/// Return the number of bits in the mask for the memory.
3148size_t MemOp::getMaskBits() {
3149
3150 for (auto res : getResults()) {
3151 if (type_isa<RefType>(res.getType()))
3152 continue;
3153 auto firstPortType = type_cast<FIRRTLBaseType>(res.getType());
3154 if (getMemPortKindFromType(firstPortType) == PortKind::Read ||
3155 getMemPortKindFromType(firstPortType) == PortKind::Debug)
3156 continue;
3157
3158 FIRRTLBaseType mType;
3159 for (auto t : type_cast<BundleType>(firstPortType.getPassiveType())) {
3160 if (t.name.getValue().contains("mask"))
3161 mType = t.type;
3162 }
3163 if (type_isa<UIntType>(mType))
3164 return mType.getBitWidthOrSentinel();
3165 }
3166 // Mask of zero bits means, either there are no write/readwrite ports or the
3167 // mask is of aggregate type.
3168 return 0;
3169}
3170
3171/// Return the data-type field of the memory, the type of each element.
3172FIRRTLBaseType MemOp::getDataType() {
3173 assert(getNumResults() != 0 && "Mems with no read/write ports are illegal");
3174
3175 if (auto refType = type_dyn_cast<RefType>(getResult(0).getType()))
3176 return type_cast<FVectorType>(refType.getType()).getElementType();
3177 auto firstPortType = type_cast<FIRRTLBaseType>(getResult(0).getType());
3178
3179 StringRef dataFieldName = "data";
3180 if (getMemPortKindFromType(firstPortType) == PortKind::ReadWrite)
3181 dataFieldName = "rdata";
3182
3183 return type_cast<BundleType>(firstPortType.getPassiveType())
3184 .getElementType(dataFieldName);
3185}
3186
3187StringAttr MemOp::getPortName(size_t resultNo) {
3188 return cast<StringAttr>(getPortNames()[resultNo]);
3189}
3190
3191FIRRTLBaseType MemOp::getPortType(size_t resultNo) {
3192 return type_cast<FIRRTLBaseType>(getResults()[resultNo].getType());
3193}
3194
3195Value MemOp::getPortNamed(StringAttr name) {
3196 auto namesArray = getPortNames();
3197 for (size_t i = 0, e = namesArray.size(); i != e; ++i) {
3198 if (namesArray[i] == name) {
3199 assert(i < getNumResults() && " names array out of sync with results");
3200 return getResult(i);
3201 }
3202 }
3203 return Value();
3204}
3205
3206// Extract all the relevant attributes from the MemOp and return the FirMemory.
3207FirMemory MemOp::getSummary() {
3208 auto op = *this;
3209 size_t numReadPorts = 0;
3210 size_t numWritePorts = 0;
3211 size_t numReadWritePorts = 0;
3213 SmallVector<int32_t> writeClockIDs;
3214
3215 for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
3216 auto portKind = op.getPortKind(i);
3217 if (portKind == MemOp::PortKind::Read)
3218 ++numReadPorts;
3219 else if (portKind == MemOp::PortKind::Write) {
3220 for (auto *a : op.getResult(i).getUsers()) {
3221 auto subfield = dyn_cast<SubfieldOp>(a);
3222 if (!subfield || subfield.getFieldIndex() != 2)
3223 continue;
3224 auto clockPort = a->getResult(0);
3225 for (auto *b : clockPort.getUsers()) {
3226 if (auto connect = dyn_cast<FConnectLike>(b)) {
3227 if (connect.getDest() == clockPort) {
3228 auto result =
3229 clockToLeader.insert({circt::firrtl::getModuleScopedDriver(
3230 connect.getSrc(), true, true, true),
3231 numWritePorts});
3232 if (result.second) {
3233 writeClockIDs.push_back(numWritePorts);
3234 } else {
3235 writeClockIDs.push_back(result.first->second);
3236 }
3237 }
3238 }
3239 }
3240 break;
3241 }
3242 ++numWritePorts;
3243 } else
3244 ++numReadWritePorts;
3245 }
3246
3247 size_t width = 0;
3248 if (auto widthV = getBitWidth(op.getDataType()))
3249 width = *widthV;
3250 else
3251 op.emitError("'firrtl.mem' should have simple type and known width");
3252 MemoryInitAttr init = op->getAttrOfType<MemoryInitAttr>("init");
3253 StringAttr modName;
3254 if (op->hasAttr("modName"))
3255 modName = op->getAttrOfType<StringAttr>("modName");
3256 else {
3257 SmallString<8> clocks;
3258 for (auto a : writeClockIDs)
3259 clocks.append(Twine((char)(a + 'a')).str());
3260 SmallString<32> initStr;
3261 // If there is a file initialization, then come up with a decent
3262 // representation for this. Use the filename, but only characters
3263 // [a-zA-Z0-9] and the bool/hex and inline booleans.
3264 if (init) {
3265 for (auto c : init.getFilename().getValue())
3266 if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
3267 (c >= '0' && c <= '9'))
3268 initStr.push_back(c);
3269 initStr.push_back('_');
3270 initStr.push_back(init.getIsBinary() ? 't' : 'f');
3271 initStr.push_back('_');
3272 initStr.push_back(init.getIsInline() ? 't' : 'f');
3273 }
3274 modName = StringAttr::get(
3275 op->getContext(),
3276 llvm::formatv(
3277 "{0}FIRRTLMem_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}{11}{12}",
3278 op.getPrefix().value_or(""), numReadPorts, numWritePorts,
3279 numReadWritePorts, (size_t)width, op.getDepth(),
3280 op.getReadLatency(), op.getWriteLatency(), op.getMaskBits(),
3281 (unsigned)op.getRuw(), (unsigned)seq::WUW::PortOrder,
3282 clocks.empty() ? "" : "_" + clocks, init ? initStr.str() : ""));
3283 }
3284 return {numReadPorts,
3285 numWritePorts,
3286 numReadWritePorts,
3287 (size_t)width,
3288 op.getDepth(),
3289 op.getReadLatency(),
3290 op.getWriteLatency(),
3291 op.getMaskBits(),
3292 *seq::symbolizeRUW(unsigned(op.getRuw())),
3293 seq::WUW::PortOrder,
3294 writeClockIDs,
3295 modName,
3296 op.getMaskBits() > 1,
3297 init,
3298 op.getPrefixAttr(),
3299 op.getLoc()};
3300}
3301
3302void MemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3303 StringRef base = getName();
3304 if (base.empty())
3305 base = "mem";
3306
3307 for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
3308 setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
3309 }
3310}
3311
3312std::optional<size_t> MemOp::getTargetResultIndex() {
3313 // Inner symbols on memory operations target the op not any result.
3314 return std::nullopt;
3315}
3316
3317// Construct name of the module which will be used for the memory definition.
3318StringAttr FirMemory::getFirMemoryName() const { return modName; }
3319
3320/// Helper for naming forceable declarations (and their optional ref result).
3321static void forceableAsmResultNames(Forceable op, StringRef name,
3322 OpAsmSetValueNameFn setNameFn) {
3323 if (name.empty())
3324 return;
3325 setNameFn(op.getDataRaw(), name);
3326 if (op.isForceable())
3327 setNameFn(op.getDataRef(), (name + "_ref").str());
3328}
3329
3330void NodeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3331 return forceableAsmResultNames(*this, getName(), setNameFn);
3332}
3333
3334LogicalResult NodeOp::inferReturnTypes(
3335 mlir::MLIRContext *context, std::optional<mlir::Location> location,
3336 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
3337 ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
3338 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
3339 if (operands.empty())
3340 return failure();
3341 Adaptor adaptor(operands, attributes, properties, regions);
3342 inferredReturnTypes.push_back(adaptor.getInput().getType());
3343 if (adaptor.getForceable()) {
3344 auto forceableType = firrtl::detail::getForceableResultType(
3345 true, adaptor.getInput().getType());
3346 if (!forceableType) {
3347 if (location)
3348 ::mlir::emitError(*location, "cannot force a node of type ")
3349 << operands[0].getType();
3350 return failure();
3351 }
3352 inferredReturnTypes.push_back(forceableType);
3353 }
3354 return success();
3355}
3356
3357std::optional<size_t> NodeOp::getTargetResultIndex() { return 0; }
3358
3359void RegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3360 return forceableAsmResultNames(*this, getName(), setNameFn);
3361}
3362
3363std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
3364
3365SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3366RegOp::computeDataFlow() {
3367 // A register does't have any combinational dataflow.
3368 return {};
3369}
3370
3371LogicalResult RegResetOp::verify() {
3372 auto reset = getResetValue();
3373
3374 FIRRTLBaseType resetType = reset.getType();
3375 FIRRTLBaseType regType = getResult().getType();
3376
3377 // The type of the initialiser must be equivalent to the register type.
3378 if (!areTypesEquivalent(regType, resetType))
3379 return emitError("type mismatch between register ")
3380 << regType << " and reset value " << resetType;
3381
3382 return success();
3383}
3384
3385std::optional<size_t> RegResetOp::getTargetResultIndex() { return 0; }
3386
3387void RegResetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3388 return forceableAsmResultNames(*this, getName(), setNameFn);
3389}
3390
3391//===----------------------------------------------------------------------===//
3392// FormalOp
3393//===----------------------------------------------------------------------===//
3394
3395LogicalResult
3396FormalOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
3397 // The referenced symbol is restricted to FModuleOps
3398 auto referencedModule = symbolTable.lookupNearestSymbolFrom<FModuleOp>(
3399 *this, getModuleNameAttr());
3400 if (!referencedModule)
3401 return (*this)->emitOpError("invalid symbol reference");
3402
3403 return success();
3404}
3405
3406//===----------------------------------------------------------------------===//
3407// WireOp
3408//===----------------------------------------------------------------------===//
3409
3410void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3411 return forceableAsmResultNames(*this, getName(), setNameFn);
3412}
3413
3414SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3415RegResetOp::computeDataFlow() {
3416 // A register does't have any combinational dataflow.
3417 return {};
3418}
3419
3420std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
3421
3422LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3423 auto refType = type_dyn_cast<RefType>(getType(0));
3424 if (!refType)
3425 return success();
3426
3427 return verifyProbeType(
3428 refType, getLoc(), getOperation()->getParentOfType<CircuitOp>(),
3429 symbolTable, Twine("'") + getOperationName() + "' op is");
3430}
3431
3432//===----------------------------------------------------------------------===//
3433// ObjectOp
3434//===----------------------------------------------------------------------===//
3435
3436void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass,
3437 StringRef name) {
3438 build(builder, state, klass.getInstanceType(),
3439 StringAttr::get(builder.getContext(), name));
3440}
3441
3442LogicalResult ObjectOp::verify() { return success(); }
3443
3444LogicalResult ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3445 auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
3446 auto classType = getType();
3447 auto className = classType.getNameAttr();
3448
3449 // verify that the class exists.
3450 auto classOp = dyn_cast_or_null<ClassLike>(
3451 symbolTable.lookupSymbolIn(circuitOp, className));
3452 if (!classOp)
3453 return emitOpError() << "references unknown class " << className;
3454
3455 // verify that the result type agrees with the class definition.
3456 if (failed(classOp.verifyType(classType, [&]() { return emitOpError(); })))
3457 return failure();
3458
3459 return success();
3460}
3461
3462StringAttr ObjectOp::getClassNameAttr() {
3463 return getType().getNameAttr().getAttr();
3464}
3465
3466StringRef ObjectOp::getClassName() { return getType().getName(); }
3467
3468ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
3469 auto symRef = getType().getNameAttr();
3470 return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
3471}
3472
3473Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
3474 return getReferencedClass(symtbl);
3475}
3476
3477StringRef ObjectOp::getInstanceName() { return getName(); }
3478
3479StringAttr ObjectOp::getInstanceNameAttr() { return getNameAttr(); }
3480
3481StringRef ObjectOp::getReferencedModuleName() { return getClassName(); }
3482
3483StringAttr ObjectOp::getReferencedModuleNameAttr() {
3484 return getClassNameAttr();
3485}
3486
3487void ObjectOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3488 setNameFn(getResult(), getName());
3489}
3490
3491//===----------------------------------------------------------------------===//
3492// Statements
3493//===----------------------------------------------------------------------===//
3494
3495LogicalResult AttachOp::verify() {
3496 // All known widths must match.
3497 std::optional<int32_t> commonWidth;
3498 for (auto operand : getOperands()) {
3499 auto thisWidth = type_cast<AnalogType>(operand.getType()).getWidth();
3500 if (!thisWidth)
3501 continue;
3502 if (!commonWidth) {
3503 commonWidth = thisWidth;
3504 continue;
3505 }
3506 if (commonWidth != thisWidth)
3507 return emitOpError("is inavlid as not all known operand widths match");
3508 }
3509 return success();
3510}
3511
3512/// Check if the source and sink are of appropriate flow.
3513static LogicalResult checkConnectFlow(Operation *connect) {
3514 Value dst = connect->getOperand(0);
3515 Value src = connect->getOperand(1);
3516
3517 // TODO: Relax this to allow reads from output ports,
3518 // instance/memory input ports.
3519 auto srcFlow = foldFlow(src);
3520 if (!isValidSrc(srcFlow)) {
3521 // A sink that is a port output or instance input used as a source is okay.
3522 auto kind = getDeclarationKind(src);
3523 if (kind != DeclKind::Port && kind != DeclKind::Instance) {
3524 auto srcRef = getFieldRefFromValue(src, /*lookThroughCasts=*/true);
3525 auto [srcName, rootKnown] = getFieldName(srcRef);
3526 auto diag = emitError(connect->getLoc());
3527 diag << "connect has invalid flow: the source expression ";
3528 if (rootKnown)
3529 diag << "\"" << srcName << "\" ";
3530 diag << "has " << toString(srcFlow) << ", expected source or duplex flow";
3531 return diag.attachNote(srcRef.getLoc()) << "the source was defined here";
3532 }
3533 }
3534
3535 auto dstFlow = foldFlow(dst);
3536 if (!isValidDst(dstFlow)) {
3537 auto dstRef = getFieldRefFromValue(dst, /*lookThroughCasts=*/true);
3538 auto [dstName, rootKnown] = getFieldName(dstRef);
3539 auto diag = emitError(connect->getLoc());
3540 diag << "connect has invalid flow: the destination expression ";
3541 if (rootKnown)
3542 diag << "\"" << dstName << "\" ";
3543 diag << "has " << toString(dstFlow) << ", expected sink or duplex flow";
3544 return diag.attachNote(dstRef.getLoc())
3545 << "the destination was defined here";
3546 }
3547 return success();
3548}
3549
3550// NOLINTBEGIN(misc-no-recursion)
3551/// Checks if the type has any 'const' leaf elements . If `isFlip` is `true`,
3552/// the `const` leaf is not considered to be driven.
3553static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip = false,
3554 bool outerTypeIsConst = false) {
3555 auto typeIsConst = outerTypeIsConst || type.isConst();
3556
3557 if (typeIsConst && type.isPassive())
3558 return !isFlip;
3559
3560 if (auto bundleType = type_dyn_cast<BundleType>(type))
3561 return llvm::any_of(bundleType.getElements(), [&](auto &element) {
3562 return isConstFieldDriven(element.type, isFlip ^ element.isFlip,
3563 typeIsConst);
3564 });
3565
3566 if (auto vectorType = type_dyn_cast<FVectorType>(type))
3567 return isConstFieldDriven(vectorType.getElementType(), isFlip, typeIsConst);
3568
3569 if (typeIsConst)
3570 return !isFlip;
3571 return false;
3572}
3573// NOLINTEND(misc-no-recursion)
3574
/// Checks that connections to 'const' destinations are not dependent on
/// non-'const' conditions in when blocks.
///
/// Both operands are examined: the destination for driven 'const' leaves, and
/// the source for 'const' leaves under a flip. Each operand is traced back
/// through subfield/subindex/subaccess ops to its root value. If a
/// `firrtl.when` with a non-'const' condition lies between the connect and the
/// block holding that root, an error is emitted. Connects whose operands are
/// not both FIRRTL base types are accepted without checks.
static LogicalResult checkConnectConditionality(FConnectLike connect) {
  auto dest = connect.getDest();
  auto destType = type_dyn_cast<FIRRTLBaseType>(dest.getType());
  auto src = connect.getSrc();
  auto srcType = type_dyn_cast<FIRRTLBaseType>(src.getType());
  // Only base types are analyzed here; anything else passes.
  if (!destType || !srcType)
    return success();

  // Start as the operand types. They may become 'const' while walking through
  // a subaccess whose input has 'const' elements (see below).
  auto destRefinedType = destType;
  auto srcRefinedType = srcType;

  /// Walks from `value` up through SubfieldOp, SubindexOp and SubaccessOp and
  /// returns the first value not produced by one of them (e.g. a declaration
  /// result or a block argument). If a SubaccessOp is encountered whose input
  /// has a 'const' element type, `originalFieldType` is made 'const'.
  auto findFieldDeclarationRefiningFieldType =
      [](Value value, FIRRTLBaseType &originalFieldType) -> Value {
    while (auto *definingOp = value.getDefiningOp()) {
      bool shouldContinue = true;
      TypeSwitch<Operation *>(definingOp)
          .Case<SubfieldOp, SubindexOp>([&](auto op) { value = op.getInput(); })
          .Case<SubaccessOp>([&](SubaccessOp op) {
            if (op.getInput()
                    .getType()
                    .base()
                    .getElementTypePreservingConst()
                    .isConst())
              originalFieldType = originalFieldType.getConstType(true);
            value = op.getInput();
          })
          // Any other defining op ends the walk at the current value.
          .Default([&](Operation *) { shouldContinue = false; });
      if (!shouldContinue)
        break;
    }
    return value;
  };

  auto destDeclaration =
      findFieldDeclarationRefiningFieldType(dest, destRefinedType);
  auto srcDeclaration =
      findFieldDeclarationRefiningFieldType(src, srcRefinedType);

  // Walk outward from the connect's block until reaching the block that holds
  // `declaration` (or running out of parents). The first enclosing WhenOp on
  // that path with a non-'const' condition makes the assignment illegal. The
  // message depends on whether `type` itself is 'const' or only contains
  // 'const' members.
  // NOTE(review): the `value` parameter is unused in this lambda.
  auto checkConstConditionality = [&](Value value, FIRRTLBaseType type,
                                      Value declaration) -> LogicalResult {
    auto *declarationBlock = declaration.getParentBlock();
    auto *block = connect->getBlock();
    while (block && block != declarationBlock) {
      auto *parentOp = block->getParentOp();

      if (auto whenOp = dyn_cast<WhenOp>(parentOp);
          whenOp && !whenOp.getCondition().getType().isConst()) {
        if (type.isConst())
          return connect.emitOpError()
                 << "assignment to 'const' type " << type
                 << " is dependent on a non-'const' condition";
        return connect->emitOpError()
               << "assignment to nested 'const' member of type " << type
               << " is dependent on a non-'const' condition";
      }

      block = parentOp->getBlock();
    }
    return success();
  };

  auto emitSubaccessError = [&] {
    return connect.emitError(
        "assignment to non-'const' subaccess of 'const' type is disallowed");
  };

  // Check destination if it contains 'const' leaves
  if (destRefinedType.containsConst() && isConstFieldDriven(destRefinedType)) {
    // Disallow assignment to non-'const' subaccesses of 'const' types.
    // A refined type that differs from the operand type means a non-'const'
    // operand was reached through a subaccess with 'const' elements.
    if (destType != destRefinedType)
      return emitSubaccessError();

    if (failed(checkConstConditionality(dest, destType, destDeclaration)))
      return failure();
  }

  // Check source if it contains 'const' 'flip' leaves. With isFlip=true only
  // 'const' leaves under a flip count as driven.
  if (srcRefinedType.containsConst() &&
      isConstFieldDriven(srcRefinedType, /*isFlip=*/true)) {
    // Disallow assignment to non-'const' subaccesses of 'const' types
    if (srcType != srcRefinedType)
      return emitSubaccessError();
    if (failed(checkConstConditionality(src, srcType, srcDeclaration)))
      return failure();
  }

  return success();
}
3668
3669LogicalResult ConnectOp::verify() {
3670 auto dstType = getDest().getType();
3671 auto srcType = getSrc().getType();
3672 auto dstBaseType = type_dyn_cast<FIRRTLBaseType>(dstType);
3673 auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(srcType);
3674 if (!dstBaseType || !srcBaseType) {
3675 if (dstType != srcType)
3676 return emitError("may not connect different non-base types");
3677 } else {
3678 // Analog types cannot be connected and must be attached.
3679 if (dstBaseType.containsAnalog() || srcBaseType.containsAnalog())
3680 return emitError("analog types may not be connected");
3681
3682 // Destination and source types must be equivalent.
3683 if (!areTypesEquivalent(dstBaseType, srcBaseType))
3684 return emitError("type mismatch between destination ")
3685 << dstBaseType << " and source " << srcBaseType;
3686
3687 // Truncation is banned in a connection: destination bit width must be
3688 // greater than or equal to source bit width.
3689 if (!isTypeLarger(dstBaseType, srcBaseType))
3690 return emitError("destination ")
3691 << dstBaseType << " is not as wide as the source " << srcBaseType;
3692 }
3693
3694 // Check that the flows make sense.
3695 if (failed(checkConnectFlow(*this)))
3696 return failure();
3697
3698 if (failed(checkConnectConditionality(*this)))
3699 return failure();
3700
3701 return success();
3702}
3703
3704LogicalResult MatchingConnectOp::verify() {
3705 if (auto type = type_dyn_cast<FIRRTLType>(getDest().getType())) {
3706 auto baseType = type_cast<FIRRTLBaseType>(type);
3707
3708 // Analog types cannot be connected and must be attached.
3709 if (baseType && baseType.containsAnalog())
3710 return emitError("analog types may not be connected");
3711
3712 // The anonymous types of operands must be equivalent.
3713 assert(areAnonymousTypesEquivalent(cast<FIRRTLBaseType>(getSrc().getType()),
3714 baseType) &&
3715 "`SameAnonTypeOperands` trait should have already rejected "
3716 "structurally non-equivalent types");
3717 }
3718
3719 // Check that the flows make sense.
3720 if (failed(checkConnectFlow(*this)))
3721 return failure();
3722
3723 if (failed(checkConnectConditionality(*this)))
3724 return failure();
3725
3726 return success();
3727}
3728
3729LogicalResult RefDefineOp::verify() {
3730 // Check that the flows make sense.
3731 if (failed(checkConnectFlow(*this)))
3732 return failure();
3733
3734 // For now, refs can't be in bundles so this is sufficient.
3735 // In the future need to ensure no other define's to same "fieldSource".
3736 // (When aggregates can have references, we can define a reference within,
3737 // but this must be unique. Checking this here may be expensive,
3738 // consider adding something to FModuleLike's to check it there instead)
3739 for (auto *user : getDest().getUsers()) {
3740 if (auto conn = dyn_cast<FConnectLike>(user);
3741 conn && conn.getDest() == getDest() && conn != *this)
3742 return emitError("destination reference cannot be reused by multiple "
3743 "operations, it can only capture a unique dataflow");
3744 }
3745
3746 // Check "static" source/dest
3747 if (auto *op = getDest().getDefiningOp()) {
3748 // TODO: Make ref.sub only source flow?
3749 if (isa<RefSubOp>(op))
3750 return emitError(
3751 "destination reference cannot be a sub-element of a reference");
3752 if (isa<RefCastOp>(op)) // Source flow, check anyway for now.
3753 return emitError(
3754 "destination reference cannot be a cast of another reference");
3755 }
3756
3757 // This define is only enabled when its ambient layers are active. Check
3758 // that whenever the destination's layer requirements are met, that this
3759 // op is enabled.
3760 auto ambientLayers = getAmbientLayersAt(getOperation());
3761 auto dstLayers = getLayersFor(getDest());
3762 SmallVector<SymbolRefAttr> missingLayers;
3763 if (!isLayerSetCompatibleWith(ambientLayers, dstLayers, missingLayers)) {
3764 auto diag = emitOpError("has more layer requirements than destination");
3765 auto &note = diag.attachNote();
3766 note << "additional layers required: ";
3767 interleaveComma(missingLayers, note);
3768 return failure();
3769 }
3770
3771 return success();
3772}
3773
3774LogicalResult PropAssignOp::verify() {
3775 // Check that the flows make sense.
3776 if (failed(checkConnectFlow(*this)))
3777 return failure();
3778
3779 // Verify that there is a single value driving the destination.
3780 for (auto *user : getDest().getUsers()) {
3781 if (auto conn = dyn_cast<FConnectLike>(user);
3782 conn && conn.getDest() == getDest() && conn != *this)
3783 return emitError("destination property cannot be reused by multiple "
3784 "operations, it can only capture a unique dataflow");
3785 }
3786
3787 return success();
3788}
3789
3790void WhenOp::createElseRegion() {
3791 assert(!hasElseRegion() && "already has an else region");
3792 getElseRegion().push_back(new Block());
3793}
3794
3795void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
3796 bool withElseRegion, std::function<void()> thenCtor,
3797 std::function<void()> elseCtor) {
3798 OpBuilder::InsertionGuard guard(builder);
3799 result.addOperands(condition);
3800
3801 // Create "then" region.
3802 builder.createBlock(result.addRegion());
3803 if (thenCtor)
3804 thenCtor();
3805
3806 // Create "else" region.
3807 Region *elseRegion = result.addRegion();
3808 if (withElseRegion) {
3809 builder.createBlock(elseRegion);
3810 if (elseCtor)
3811 elseCtor();
3812 }
3813}
3814
3815//===----------------------------------------------------------------------===//
3816// MatchOp
3817//===----------------------------------------------------------------------===//
3818
3819LogicalResult MatchOp::verify() {
3820 FEnumType type = getInput().getType();
3821
3822 // Make sure that the number of tags matches the number of regions.
3823 auto numCases = getTags().size();
3824 auto numRegions = getNumRegions();
3825 if (numRegions != numCases)
3826 return emitOpError("expected ")
3827 << numRegions << " tags but got " << numCases;
3828
3829 auto numTags = type.getNumElements();
3830
3831 SmallDenseSet<int64_t> seen;
3832 for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
3833 auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());
3834
3835 // Ensure that the block has a single argument.
3836 if (region.getNumArguments() != 1)
3837 return emitOpError("region should have exactly one argument");
3838
3839 // Make sure that it is a valid tag.
3840 if (tagIndex >= numTags)
3841 return emitOpError("the tag index ")
3842 << tagIndex << " is out of the range of valid tags in " << type;
3843
3844 // Make sure we have not already matched this tag.
3845 auto [it, inserted] = seen.insert(tagIndex);
3846 if (!inserted)
3847 return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
3848 << " is matched more than once";
3849
3850 // Check that the block argument type matches the tag's type.
3851 auto expectedType = type.getElementTypePreservingConst(tagIndex);
3852 auto regionType = region.getArgument(0).getType();
3853 if (regionType != expectedType)
3854 return emitOpError("region type ")
3855 << regionType << " does not match the expected type "
3856 << expectedType;
3857 }
3858
3859 // Check that the match statement is exhaustive.
3860 for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
3861 if (!seen.contains(i))
3862 return emitOpError("missing case for tag ") << type.getElementNameAttr(i);
3863
3864 return success();
3865}
3866
3867void MatchOp::print(OpAsmPrinter &p) {
3868 auto input = getInput();
3869 FEnumType type = input.getType();
3870 auto regions = getRegions();
3871 p << " " << input << " : " << type;
3872 SmallVector<StringRef> elided = {"tags"};
3873 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
3874 p << " {";
3875 p.increaseIndent();
3876 for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
3877 p.printNewline();
3878 p << "case ";
3879 p.printKeywordOrString(
3880 type.getElementName(cast<IntegerAttr>(tag).getInt()));
3881 p << "(";
3882 p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
3883 /*omitType=*/true);
3884 p << ") ";
3885 p.printRegion(region, /*printEntryBlockArgs=*/false);
3886 }
3887 p.decreaseIndent();
3888 p.printNewline();
3889 p << "}";
3890}
3891
3892ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
3893 auto *context = parser.getContext();
3894 auto &properties = result.getOrAddProperties<Properties>();
3895 OpAsmParser::UnresolvedOperand input;
3896 if (parser.parseOperand(input) || parser.parseColon())
3897 return failure();
3898
3899 auto loc = parser.getCurrentLocation();
3900 Type type;
3901 if (parser.parseType(type))
3902 return failure();
3903 auto enumType = type_dyn_cast<FEnumType>(type);
3904 if (!enumType)
3905 return parser.emitError(loc, "expected enumeration type but got") << type;
3906
3907 if (parser.resolveOperand(input, type, result.operands) ||
3908 parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3909 parser.parseLBrace())
3910 return failure();
3911
3912 auto i32Type = IntegerType::get(context, 32);
3913 SmallVector<Attribute> tags;
3914 while (true) {
3915 // Stop parsing when we don't find another "case" keyword.
3916 if (failed(parser.parseOptionalKeyword("case")))
3917 break;
3918
3919 // Parse the tag and region argument.
3920 auto nameLoc = parser.getCurrentLocation();
3921 std::string name;
3922 OpAsmParser::Argument arg;
3923 auto *region = result.addRegion();
3924 if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
3925 parser.parseArgument(arg) || parser.parseRParen())
3926 return failure();
3927
3928 // Figure out the enum index of the tag.
3929 auto index = enumType.getElementIndex(name);
3930 if (!index)
3931 return parser.emitError(nameLoc, "the tag \"")
3932 << name << "\" is not a member of the enumeration " << enumType;
3933 tags.push_back(IntegerAttr::get(i32Type, *index));
3934
3935 // Parse the region.
3936 arg.type = enumType.getElementTypePreservingConst(*index);
3937 if (parser.parseRegion(*region, arg))
3938 return failure();
3939 }
3940 properties.setTags(ArrayAttr::get(context, tags));
3941
3942 return parser.parseRBrace();
3943}
3944
3945void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
3946 ArrayAttr tags,
3947 MutableArrayRef<std::unique_ptr<Region>> regions) {
3948 auto &properties = result.getOrAddProperties<Properties>();
3949 result.addOperands(input);
3950 properties.setTags(tags);
3951 result.addRegions(regions);
3952}
3953
3954//===----------------------------------------------------------------------===//
3955// Expressions
3956//===----------------------------------------------------------------------===//
3957
3958/// Return true if the specified operation is a firrtl expression.
3959bool firrtl::isExpression(Operation *op) {
3960 struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
3961 bool visitInvalidExpr(Operation *op) { return false; }
3962 bool visitUnhandledExpr(Operation *op) { return true; }
3963 };
3964
3965 return IsExprClassifier().dispatchExprVisitor(op);
3966}
3967
3968void InvalidValueOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3969 // Set invalid values to have a distinct name.
3970 std::string name;
3971 if (auto ty = type_dyn_cast<IntType>(getType())) {
3972 const char *base = ty.isSigned() ? "invalid_si" : "invalid_ui";
3973 auto width = ty.getWidthOrSentinel();
3974 if (width == -1)
3975 name = base;
3976 else
3977 name = (Twine(base) + Twine(width)).str();
3978 } else if (auto ty = type_dyn_cast<AnalogType>(getType())) {
3979 auto width = ty.getWidthOrSentinel();
3980 if (width == -1)
3981 name = "invalid_analog";
3982 else
3983 name = ("invalid_analog" + Twine(width)).str();
3984 } else if (type_isa<AsyncResetType>(getType()))
3985 name = "invalid_asyncreset";
3986 else if (type_isa<ResetType>(getType()))
3987 name = "invalid_reset";
3988 else if (type_isa<ClockType>(getType()))
3989 name = "invalid_clock";
3990 else
3991 name = "invalid";
3992
3993 setNameFn(getResult(), name);
3994}
3995
3996void ConstantOp::print(OpAsmPrinter &p) {
3997 p << " ";
3998 p.printAttributeWithoutType(getValueAttr());
3999 p << " : ";
4000 p.printType(getType());
4001 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4002}
4003
/// Parse a constant of the form `<integer> : <firrtl int type> attr-dict`.
///
/// The integer is parsed before its type is known, so its APInt width is
/// whatever the parser produced. If the result type has a width, the value is
/// sign-extended or truncated to that width; truncation that would drop
/// significant bits is an error. The value is stored in the `value` property
/// as an IntegerAttr whose signedness follows the FIRRTL type. Widthless
/// result types keep the parsed width.
/// NOTE(review): a negative literal is not rejected for an unsigned result
/// type; it is extended/truncated to its two's-complement bit pattern.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &properties = result.getOrAddProperties<Properties>();
  // Parse the constant value, without knowing its width.
  APInt value;
  auto loc = parser.getCurrentLocation();
  auto valueResult = parser.parseOptionalInteger(value);
  if (!valueResult.has_value())
    return parser.emitError(loc, "expected integer value");

  // Parse the result firrtl integer type. A present-but-malformed integer
  // (failed(*valueResult)) has already been diagnosed by the parser.
  IntType resultType;
  if (failed(*valueResult) || parser.parseColonType(resultType) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.addTypes(resultType);

  // Now that we know the width and sign of the result type, we can munge the
  // APInt as appropriate.
  if (resultType.hasWidth()) {
    auto width = (unsigned)resultType.getWidthOrSentinel();
    if (width > value.getBitWidth()) {
      // sext is always safe here, even for unsigned values, because the
      // parseOptionalInteger method will return something with a zero in the
      // top bits if it is a positive number.
      value = value.sext(width);
    } else if (width < value.getBitWidth()) {
      // The parser can return an unnecessarily wide result with leading
      // zeros. This isn't a problem, but truncating off bits is bad.
      // Negative values need their sign bit kept, hence significant bits.
      unsigned neededBits = value.isNegative() ? value.getSignificantBits()
                                               : value.getActiveBits();
      if (width < neededBits)
        return parser.emitError(loc, "constant out of range for result type ")
               << resultType;
      value = value.trunc(width);
    }
  }

  // Store the value with a builtin integer type of matching width and sign.
  auto intType = parser.getBuilder().getIntegerType(value.getBitWidth(),
                                                    resultType.isSigned());
  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
  properties.setValue(valueAttr);
  return success();
}
4047
4048LogicalResult ConstantOp::verify() {
4049 // If the result type has a bitwidth, then the attribute must match its width.
4050 IntType intType = getType();
4051 auto width = intType.getWidthOrSentinel();
4052 if (width != -1 && (int)getValue().getBitWidth() != width)
4053 return emitError(
4054 "firrtl.constant attribute bitwidth doesn't match return type");
4055
4056 // The sign of the attribute's integer type must match our integer type sign.
4057 auto attrType = type_cast<IntegerType>(getValueAttr().getType());
4058 if (attrType.isSignless() || attrType.isSigned() != intType.isSigned())
4059 return emitError("firrtl.constant attribute has wrong sign");
4060
4061 return success();
4062}
4063
4064/// Build a ConstantOp from an APInt and a FIRRTL type, handling the attribute
4065/// formation for the 'value' attribute.
4066void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
4067 const APInt &value) {
4068 int32_t width = type.getWidthOrSentinel();
4069 (void)width;
4070 assert((width == -1 || (int32_t)value.getBitWidth() == width) &&
4071 "incorrect attribute bitwidth for firrtl.constant");
4072
4073 auto attr =
4074 IntegerAttr::get(type.getContext(), APSInt(value, !type.isSigned()));
4075 return build(builder, result, type, attr);
4076}
4077
4078/// Build a ConstantOp from an APSInt, handling the attribute formation for the
4079/// 'value' attribute and inferring the FIRRTL type.
4080void ConstantOp::build(OpBuilder &builder, OperationState &result,
4081 const APSInt &value) {
4082 auto attr = IntegerAttr::get(builder.getContext(), value);
4083 auto type =
4084 IntType::get(builder.getContext(), value.isSigned(), value.getBitWidth());
4085 return build(builder, result, type, attr);
4086}
4087
4088void ConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
4089 // For constants in particular, propagate the value into the result name to
4090 // make it easier to read the IR.
4091 IntType intTy = getType();
4092 assert(intTy);
4093
4094 // Otherwise, build a complex name with the value and type.
4095 SmallString<32> specialNameBuffer;
4096 llvm::raw_svector_ostream specialName(specialNameBuffer);
4097 specialName << 'c';
4098 getValue().print(specialName, /*isSigned:*/ intTy.isSigned());
4099
4100 specialName << (intTy.isSigned() ? "_si" : "_ui");
4101 auto width = intTy.getWidthOrSentinel();
4102 if (width != -1)
4103 specialName << width;
4104 setNameFn(getResult(), specialName.str());
4105}
4106
4107void SpecialConstantOp::print(OpAsmPrinter &p) {
4108 p << " ";
4109 // SpecialConstant uses a BoolAttr, and we want to print `true` as `1`.
4110 p << static_cast<unsigned>(getValue());
4111 p << " : ";
4112 p.printType(getType());
4113 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4114}
4115
4116ParseResult SpecialConstantOp::parse(OpAsmParser &parser,
4117 OperationState &result) {
4118 auto &properties = result.getOrAddProperties<Properties>();
4119 // Parse the constant value. SpecialConstant uses bool attributes, but it
4120 // prints as an integer.
4121 APInt value;
4122 auto loc = parser.getCurrentLocation();
4123 auto valueResult = parser.parseOptionalInteger(value);
4124 if (!valueResult.has_value())
4125 return parser.emitError(loc, "expected integer value");
4126
4127 // Clocks and resets can only be 0 or 1.
4128 if (value != 0 && value != 1)
4129 return parser.emitError(loc, "special constants can only be 0 or 1.");
4130
4131 // Parse the result firrtl type.
4132 Type resultType;
4133 if (failed(*valueResult) || parser.parseColonType(resultType) ||
4134 parser.parseOptionalAttrDict(result.attributes))
4135 return failure();
4136 result.addTypes(resultType);
4137
4138 // Create the attribute.
4139 auto valueAttr = parser.getBuilder().getBoolAttr(value == 1);
4140 properties.setValue(valueAttr);
4141 return success();
4142}
4143
4144void SpecialConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
4145 SmallString<32> specialNameBuffer;
4146 llvm::raw_svector_ostream specialName(specialNameBuffer);
4147 specialName << 'c';
4148 specialName << static_cast<unsigned>(getValue());
4149 auto type = getType();
4150 if (type_isa<ClockType>(type)) {
4151 specialName << "_clock";
4152 } else if (type_isa<ResetType>(type)) {
4153 specialName << "_reset";
4154 } else if (type_isa<AsyncResetType>(type)) {
4155 specialName << "_asyncreset";
4156 }
4157 setNameFn(getResult(), specialName.str());
4158}
4159
4160// Checks that an array attr representing an aggregate constant has the correct
4161// shape. This recurses on the type.
4162static bool checkAggConstant(Operation *op, Attribute attr,
4163 FIRRTLBaseType type) {
4164 if (type.isGround()) {
4165 if (!isa<IntegerAttr>(attr)) {
4166 op->emitOpError("Ground type is not an integer attribute");
4167 return false;
4168 }
4169 return true;
4170 }
4171 auto attrlist = dyn_cast<ArrayAttr>(attr);
4172 if (!attrlist) {
4173 op->emitOpError("expected array attribute for aggregate constant");
4174 return false;
4175 }
4176 if (auto array = type_dyn_cast<FVectorType>(type)) {
4177 if (array.getNumElements() != attrlist.size()) {
4178 op->emitOpError("array attribute (")
4179 << attrlist.size() << ") has wrong size for vector constant ("
4180 << array.getNumElements() << ")";
4181 return false;
4182 }
4183 return llvm::all_of(attrlist, [&array, op](Attribute attr) {
4184 return checkAggConstant(op, attr, array.getElementType());
4185 });
4186 }
4187 if (auto bundle = type_dyn_cast<BundleType>(type)) {
4188 if (bundle.getNumElements() != attrlist.size()) {
4189 op->emitOpError("array attribute (")
4190 << attrlist.size() << ") has wrong size for bundle constant ("
4191 << bundle.getNumElements() << ")";
4192 return false;
4193 }
4194 for (size_t i = 0; i < bundle.getNumElements(); ++i) {
4195 if (bundle.getElement(i).isFlip) {
4196 op->emitOpError("Cannot have constant bundle type with flip");
4197 return false;
4198 }
4199 if (!checkAggConstant(op, attrlist[i], bundle.getElement(i).type))
4200 return false;
4201 }
4202 return true;
4203 }
4204 op->emitOpError("Unknown aggregate type");
4205 return false;
4206}
4207
4208LogicalResult AggregateConstantOp::verify() {
4209 if (checkAggConstant(getOperation(), getFields(), getType()))
4210 return success();
4211 return failure();
4212}
4213
4214Attribute AggregateConstantOp::getAttributeFromFieldID(uint64_t fieldID) {
4215 FIRRTLBaseType type = getType();
4216 Attribute value = getFields();
4217 while (fieldID != 0) {
4218 if (auto bundle = type_dyn_cast<BundleType>(type)) {
4219 auto index = bundle.getIndexForFieldID(fieldID);
4220 fieldID -= bundle.getFieldID(index);
4221 type = bundle.getElementType(index);
4222 value = cast<ArrayAttr>(value)[index];
4223 } else {
4224 auto vector = type_cast<FVectorType>(type);
4225 auto index = vector.getIndexForFieldID(fieldID);
4226 fieldID -= vector.getFieldID(index);
4227 type = vector.getElementType();
4228 value = cast<ArrayAttr>(value)[index];
4229 }
4230 }
4231 return value;
4232}
4233
4234void FIntegerConstantOp::print(OpAsmPrinter &p) {
4235 p << " ";
4236 p.printAttributeWithoutType(getValueAttr());
4237 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4238}
4239
4240ParseResult FIntegerConstantOp::parse(OpAsmParser &parser,
4241 OperationState &result) {
4242 auto *context = parser.getContext();
4243 auto &properties = result.getOrAddProperties<Properties>();
4244 APInt value;
4245 if (parser.parseInteger(value) ||
4246 parser.parseOptionalAttrDict(result.attributes))
4247 return failure();
4248 result.addTypes(FIntegerType::get(context));
4249 auto intType =
4250 IntegerType::get(context, value.getBitWidth(), IntegerType::Signed);
4251 auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4252 properties.setValue(valueAttr);
4253 return success();
4254}
4255
4256ParseResult ListCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4257 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
4258 ListType type;
4259
4260 if (parser.parseOperandList(operands) ||
4261 parser.parseOptionalAttrDict(result.attributes) ||
4262 parser.parseColonType(type))
4263 return failure();
4264 result.addTypes(type);
4265
4266 return parser.resolveOperands(operands, type.getElementType(),
4267 result.operands);
4268}
4269
4270void ListCreateOp::print(OpAsmPrinter &p) {
4271 p << " ";
4272 p.printOperands(getElements());
4273 p.printOptionalAttrDict((*this)->getAttrs());
4274 p << " : " << getType();
4275}
4276
4277LogicalResult ListCreateOp::verify() {
4278 if (getElements().empty())
4279 return success();
4280
4281 auto elementType = getElements().front().getType();
4282 auto listElementType = getType().getElementType();
4283 if (elementType != listElementType)
4284 return emitOpError("has elements of type ")
4285 << elementType << " instead of " << listElementType;
4286
4287 return success();
4288}
4289
4290LogicalResult BundleCreateOp::verify() {
4291 BundleType resultType = getType();
4292 if (resultType.getNumElements() != getFields().size())
4293 return emitOpError("number of fields doesn't match type");
4294 for (size_t i = 0; i < resultType.getNumElements(); ++i)
4296 resultType.getElementTypePreservingConst(i),
4297 type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4298 return emitOpError("type of element doesn't match bundle for field ")
4299 << resultType.getElement(i).name;
4300 // TODO: check flow
4301 return success();
4302}
4303
4304LogicalResult VectorCreateOp::verify() {
4305 FVectorType resultType = getType();
4306 if (resultType.getNumElements() != getFields().size())
4307 return emitOpError("number of fields doesn't match type");
4308 auto elemTy = resultType.getElementTypePreservingConst();
4309 for (size_t i = 0; i < resultType.getNumElements(); ++i)
4311 elemTy, type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4312 return emitOpError("type of element doesn't match vector element");
4313 // TODO: check flow
4314 return success();
4315}
4316
4317//===----------------------------------------------------------------------===//
4318// FEnumCreateOp
4319//===----------------------------------------------------------------------===//
4320
4321LogicalResult FEnumCreateOp::verify() {
4322 FEnumType resultType = getResult().getType();
4323 auto elementIndex = resultType.getElementIndex(getFieldName());
4324 if (!elementIndex)
4325 return emitOpError("label ")
4326 << getFieldName() << " is not a member of the enumeration type "
4327 << resultType;
4329 resultType.getElementTypePreservingConst(*elementIndex),
4330 getInput().getType()))
4331 return emitOpError("type of element doesn't match enum element");
4332 return success();
4333}
4334
4335void FEnumCreateOp::print(OpAsmPrinter &printer) {
4336 printer << ' ';
4337 printer.printKeywordOrString(getFieldName());
4338 printer << '(' << getInput() << ')';
4339 SmallVector<StringRef> elidedAttrs = {"fieldIndex"};
4340 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
4341 printer << " : ";
4342 printer.printFunctionalType(ArrayRef<Type>{getInput().getType()},
4343 ArrayRef<Type>{getResult().getType()});
4344}
4345
4346ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4347 auto *context = parser.getContext();
4348 auto &properties = result.getOrAddProperties<Properties>();
4349
4350 OpAsmParser::UnresolvedOperand input;
4351 std::string fieldName;
4352 mlir::FunctionType functionType;
4353 if (parser.parseKeywordOrString(&fieldName) || parser.parseLParen() ||
4354 parser.parseOperand(input) || parser.parseRParen() ||
4355 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4356 parser.parseType(functionType))
4357 return failure();
4358
4359 if (functionType.getNumInputs() != 1)
4360 return parser.emitError(parser.getNameLoc(), "single input type required");
4361 if (functionType.getNumResults() != 1)
4362 return parser.emitError(parser.getNameLoc(), "single result type required");
4363
4364 auto inputType = functionType.getInput(0);
4365 if (parser.resolveOperand(input, inputType, result.operands))
4366 return failure();
4367
4368 auto outputType = functionType.getResult(0);
4369 auto enumType = type_dyn_cast<FEnumType>(outputType);
4370 if (!enumType)
4371 return parser.emitError(parser.getNameLoc(),
4372 "output must be enum type, got ")
4373 << outputType;
4374 auto fieldIndex = enumType.getElementIndex(fieldName);
4375 if (!fieldIndex)
4376 return parser.emitError(parser.getNameLoc(),
4377 "unknown field " + fieldName + " in enum type ")
4378 << enumType;
4379
4380 properties.setFieldIndex(
4381 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4382
4383 result.addTypes(enumType);
4384
4385 return success();
4386}
4387
4388//===----------------------------------------------------------------------===//
4389// IsTagOp
4390//===----------------------------------------------------------------------===//
4391
4392LogicalResult IsTagOp::verify() {
4393 if (getFieldIndex() >= getInput().getType().base().getNumElements())
4394 return emitOpError("element index is greater than the number of fields in "
4395 "the bundle type");
4396 return success();
4397}
4398
4399void IsTagOp::print(::mlir::OpAsmPrinter &printer) {
4400 printer << ' ' << getInput() << ' ';
4401 printer.printKeywordOrString(getFieldName());
4402 SmallVector<::llvm::StringRef, 1> elidedAttrs = {"fieldIndex"};
4403 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4404 printer << " : " << getInput().getType();
4405}
4406
4407ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
4408 auto *context = parser.getContext();
4409 auto &properties = result.getOrAddProperties<Properties>();
4410
4411 OpAsmParser::UnresolvedOperand input;
4412 std::string fieldName;
4413 Type inputType;
4414 if (parser.parseOperand(input) || parser.parseKeywordOrString(&fieldName) ||
4415 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4416 parser.parseType(inputType))
4417 return failure();
4418
4419 if (parser.resolveOperand(input, inputType, result.operands))
4420 return failure();
4421
4422 auto enumType = type_dyn_cast<FEnumType>(inputType);
4423 if (!enumType)
4424 return parser.emitError(parser.getNameLoc(),
4425 "input must be enum type, got ")
4426 << inputType;
4427 auto fieldIndex = enumType.getElementIndex(fieldName);
4428 if (!fieldIndex)
4429 return parser.emitError(parser.getNameLoc(),
4430 "unknown field " + fieldName + " in enum type ")
4431 << enumType;
4432
4433 properties.setFieldIndex(
4434 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4435
4436 result.addTypes(UIntType::get(context, 1, /*isConst=*/false));
4437
4438 return success();
4439}
4440
4441FIRRTLType IsTagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
4442 OpaqueProperties properties,
4443 mlir::RegionRange regions,
4444 std::optional<Location> loc) {
4445 Adaptor adaptor(operands, attrs, properties, regions);
4446 return UIntType::get(attrs.getContext(), 1,
4447 isConst(adaptor.getInput().getType()));
4448}
4449
4450template <typename OpTy>
4451ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) {
4452 auto *context = parser.getContext();
4453
4454 OpAsmParser::UnresolvedOperand input;
4455 std::string fieldName;
4456 Type inputType;
4457 if (parser.parseOperand(input) || parser.parseLSquare() ||
4458 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4459 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4460 parser.parseType(inputType))
4461 return failure();
4462
4463 if (parser.resolveOperand(input, inputType, result.operands))
4464 return failure();
4465
4466 auto bundleType = type_dyn_cast<typename OpTy::InputType>(inputType);
4467 if (!bundleType)
4468 return parser.emitError(parser.getNameLoc(),
4469 "input must be bundle type, got ")
4470 << inputType;
4471 auto fieldIndex = bundleType.getElementIndex(fieldName);
4472 if (!fieldIndex)
4473 return parser.emitError(parser.getNameLoc(),
4474 "unknown field " + fieldName + " in bundle type ")
4475 << bundleType;
4476
4477 result.getOrAddProperties<typename OpTy::Properties>().setFieldIndex(
4478 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4479
4480 auto type = OpTy::inferReturnType(inputType, *fieldIndex, {});
4481 if (!type)
4482 return failure();
4483 result.addTypes(type);
4484
4485 return success();
4486}
4487
4488ParseResult SubtagOp::parse(OpAsmParser &parser, OperationState &result) {
4489 auto *context = parser.getContext();
4490
4491 OpAsmParser::UnresolvedOperand input;
4492 std::string fieldName;
4493 Type inputType;
4494 if (parser.parseOperand(input) || parser.parseLSquare() ||
4495 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4496 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4497 parser.parseType(inputType))
4498 return failure();
4499
4500 if (parser.resolveOperand(input, inputType, result.operands))
4501 return failure();
4502
4503 auto enumType = type_dyn_cast<FEnumType>(inputType);
4504 if (!enumType)
4505 return parser.emitError(parser.getNameLoc(),
4506 "input must be enum type, got ")
4507 << inputType;
4508 auto fieldIndex = enumType.getElementIndex(fieldName);
4509 if (!fieldIndex)
4510 return parser.emitError(parser.getNameLoc(),
4511 "unknown field " + fieldName + " in enum type ")
4512 << enumType;
4513
4514 result.getOrAddProperties<Properties>().setFieldIndex(
4515 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4516
4517 SmallVector<Type> inferredReturnTypes;
4518 if (failed(SubtagOp::inferReturnTypes(
4519 context, result.location, result.operands,
4520 result.attributes.getDictionary(context), result.getRawProperties(),
4521 result.regions, inferredReturnTypes)))
4522 return failure();
4523 result.addTypes(inferredReturnTypes);
4524
4525 return success();
4526}
4527
4528ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4529 return parseSubfieldLikeOp<SubfieldOp>(parser, result);
4530}
4531ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4532 return parseSubfieldLikeOp<OpenSubfieldOp>(parser, result);
4533}
4534
4535template <typename OpTy>
4536static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) {
4537 printer << ' ' << op.getInput() << '[';
4538 printer.printKeywordOrString(op.getFieldName());
4539 printer << ']';
4540 ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4541 elidedAttrs.push_back("fieldIndex");
4542 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4543 printer << " : " << op.getInput().getType();
4544}
4545void SubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4546 return printSubfieldLikeOp<SubfieldOp>(*this, printer);
4547}
4548void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4549 return printSubfieldLikeOp<OpenSubfieldOp>(*this, printer);
4550}
4551
4552void SubtagOp::print(::mlir::OpAsmPrinter &printer) {
4553 printer << ' ' << getInput() << '[';
4554 printer.printKeywordOrString(getFieldName());
4555 printer << ']';
4556 ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4557 elidedAttrs.push_back("fieldIndex");
4558 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4559 printer << " : " << getInput().getType();
4560}
4561
4562template <typename OpTy>
4563static LogicalResult verifySubfieldLike(OpTy op) {
4564 if (op.getFieldIndex() >=
4565 firrtl::type_cast<typename OpTy::InputType>(op.getInput().getType())
4566 .getNumElements())
4567 return op.emitOpError("subfield element index is greater than the number "
4568 "of fields in the bundle type");
4569 return success();
4570}
4571LogicalResult SubfieldOp::verify() {
4572 return verifySubfieldLike<SubfieldOp>(*this);
4573}
4574LogicalResult OpenSubfieldOp::verify() {
4575 return verifySubfieldLike<OpenSubfieldOp>(*this);
4576}
4577
4578LogicalResult SubtagOp::verify() {
4579 if (getFieldIndex() >= getInput().getType().base().getNumElements())
4580 return emitOpError("subfield element index is greater than the number "
4581 "of fields in the bundle type");
4582 return success();
4583}
4584
/// Return true if the specified operation is considered to have a constant
/// value, using a worklist walk.
///
/// - `ConstantOp`, `SpecialConstantOp` and `AggregateConstantOp` are accepted
///   directly.
/// - `WireOp`, `SubindexOp` and `SubfieldOp` queue every user of their result.
/// - `NodeOp`, `AsSIntPrimOp` and `AsUIntPrimOp` queue their input's defining
///   op (see the note below).
/// - Any other op kind makes the result false.
///
/// NOTE(review): the node/asSInt/asUInt case sets `constant = false`
/// unconditionally right after queueing the input's defining op. The loop then
/// exits, so the queued op is never examined and any such op yields false.
///
/// NOTE(review): the wire case queues all users. Connect-like users are not in
/// the case list and hit `Default`, so a wire with any such use is reported
/// non-constant, while a wire with no uses is reported constant. This does not
/// match the "looks through" / "wires driven with only constant values"
/// intent.
///
/// Simply dropping the unconditional `constant = false` is not safe. The
/// node -> wire -> node walk can revisit ops, so a visited set would be needed.
/// Sub-accesses are also followed through their users rather than their
/// input. TODO: confirm the intended semantics with callers before changing.
bool firrtl::isConstant(Operation *op) {
  // Worklist of ops that need to be examined that should all be constant in
  // order for the input operation to be constant.
  SmallVector<Operation *, 8> worklist({op});

  // Mutable state indicating if this op is a constant. Assume it is a constant
  // and look for counterexamples.
  bool constant = true;

  // While we haven't found a counterexample and there are still ops in the
  // worklist, pull ops off the worklist. If it provides a counterexample, set
  // the `constant` to false (and exit on the next loop iteration). Otherwise,
  // look through the op or spawn off more ops to look at.
  while (constant && !(worklist.empty()))
    TypeSwitch<Operation *>(worklist.pop_back_val())
        .Case<NodeOp, AsSIntPrimOp, AsUIntPrimOp>([&](auto op) {
          if (auto definingOp = op.getInput().getDefiningOp())
            worklist.push_back(definingOp);
          // NOTE(review): this always fires, making the push above dead; see
          // the function comment.
          constant = false;
        })
        .Case<WireOp, SubindexOp, SubfieldOp>([&](auto op) {
          // Queues readers as well as drivers of the result.
          for (auto &use : op.getResult().getUses())
            worklist.push_back(use.getOwner());
        })
        .Case<ConstantOp, SpecialConstantOp, AggregateConstantOp>([](auto) {})
        .Default([&](auto) { constant = false; });

  return constant;
}
4617
4618/// Return true if the specified value is a constant. This trivially checks for
4619/// `firrtl.constant` and friends, but also looks through subaccesses and
4620/// correctly handles wires driven with only constant values.
4621bool firrtl::isConstant(Value value) {
4622 if (auto *op = value.getDefiningOp())
4623 return isConstant(op);
4624 return false;
4625}
4626
4627LogicalResult ConstCastOp::verify() {
4628 if (!areTypesConstCastable(getResult().getType(), getInput().getType()))
4629 return emitOpError() << getInput().getType()
4630 << " is not 'const'-castable to "
4631 << getResult().getType();
4632 return success();
4633}
4634
4635FIRRTLType SubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
4636 std::optional<Location> loc) {
4637 auto inType = type_cast<BundleType>(type);
4638
4639 if (fieldIndex >= inType.getNumElements())
4640 return emitInferRetTypeError(loc,
4641 "subfield element index is greater than the "
4642 "number of fields in the bundle type");
4643
4644 // SubfieldOp verifier checks that the field index is valid with number of
4645 // subelements.
4646 return inType.getElementTypePreservingConst(fieldIndex);
4647}
4648
4649FIRRTLType OpenSubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
4650 std::optional<Location> loc) {
4651 auto inType = type_cast<OpenBundleType>(type);
4652
4653 if (fieldIndex >= inType.getNumElements())
4654 return emitInferRetTypeError(loc,
4655 "subfield element index is greater than the "
4656 "number of fields in the bundle type");
4657
4658 // OpenSubfieldOp verifier checks that the field index is valid with number of
4659 // subelements.
4660 return inType.getElementTypePreservingConst(fieldIndex);
4661}
4662
4663bool SubfieldOp::isFieldFlipped() {
4664 BundleType bundle = getInput().getType();
4665 return bundle.getElement(getFieldIndex()).isFlip;
4666}
4667bool OpenSubfieldOp::isFieldFlipped() {
4668 auto bundle = getInput().getType();
4669 return bundle.getElement(getFieldIndex()).isFlip;
4670}
4671
4672FIRRTLType SubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
4673 std::optional<Location> loc) {
4674 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
4675 if (fieldIndex < vectorType.getNumElements())
4676 return vectorType.getElementTypePreservingConst();
4677 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
4678 "' in vector type ", type);
4679 }
4680 return emitInferRetTypeError(loc, "subindex requires vector operand");
4681}
4682
4683FIRRTLType OpenSubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
4684 std::optional<Location> loc) {
4685 if (auto vectorType = type_dyn_cast<OpenVectorType>(type)) {
4686 if (fieldIndex < vectorType.getNumElements())
4687 return vectorType.getElementTypePreservingConst();
4688 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
4689 "' in vector type ", type);
4690 }
4691
4692 return emitInferRetTypeError(loc, "subindex requires vector operand");
4693}
4694
4695FIRRTLType SubtagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
4696 OpaqueProperties properties,
4697 mlir::RegionRange regions,
4698 std::optional<Location> loc) {
4699 Adaptor adaptor(operands, attrs, properties, regions);
4700 auto inType = type_cast<FEnumType>(adaptor.getInput().getType());
4701 auto fieldIndex = adaptor.getFieldIndex();
4702
4703 if (fieldIndex >= inType.getNumElements())
4704 return emitInferRetTypeError(loc,
4705 "subtag element index is greater than the "
4706 "number of fields in the enum type");
4707
4708 // SubtagOp verifier checks that the field index is valid with number of
4709 // subelements.
4710 auto elementType = inType.getElement(fieldIndex).type;
4711 return elementType.getConstType(elementType.isConst() || inType.isConst());
4712}
4713
4714FIRRTLType SubaccessOp::inferReturnType(Type inType, Type indexType,
4715 std::optional<Location> loc) {
4716 if (!type_isa<UIntType>(indexType))
4717 return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
4718 indexType);
4719
4720 if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4721 if (isConst(indexType))
4722 return vectorType.getElementTypePreservingConst();
4723 return vectorType.getElementType().getAllConstDroppedType();
4724 }
4725
4726 return emitInferRetTypeError(loc, "subaccess requires vector operand, not ",
4727 inType);
4728}
4729
4730FIRRTLType TagExtractOp::inferReturnType(ValueRange operands,
4731 DictionaryAttr attrs,
4732 OpaqueProperties properties,
4733 mlir::RegionRange regions,
4734 std::optional<Location> loc) {
4735 Adaptor adaptor(operands, attrs, properties, regions);
4736 auto inType = type_cast<FEnumType>(adaptor.getInput().getType());
4737 auto i = llvm::Log2_32_Ceil(inType.getNumElements());
4738 return UIntType::get(inType.getContext(), i);
4739}
4740
4741ParseResult MultibitMuxOp::parse(OpAsmParser &parser, OperationState &result) {
4742 OpAsmParser::UnresolvedOperand index;
4743 SmallVector<OpAsmParser::UnresolvedOperand, 16> inputs;
4744 Type indexType, elemType;
4745
4746 if (parser.parseOperand(index) || parser.parseComma() ||
4747 parser.parseOperandList(inputs) ||
4748 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4749 parser.parseType(indexType) || parser.parseComma() ||
4750 parser.parseType(elemType))
4751 return failure();
4752
4753 if (parser.resolveOperand(index, indexType, result.operands))
4754 return failure();
4755
4756 result.addTypes(elemType);
4757
4758 return parser.resolveOperands(inputs, elemType, result.operands);
4759}
4760
4761void MultibitMuxOp::print(OpAsmPrinter &p) {
4762 p << " " << getIndex() << ", ";
4763 p.printOperands(getInputs());
4764 p.printOptionalAttrDict((*this)->getAttrs());
4765 p << " : " << getIndex().getType() << ", " << getType();
4766}
4767
4768FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
4769 DictionaryAttr attrs,
4770 OpaqueProperties properties,
4771 mlir::RegionRange regions,
4772 std::optional<Location> loc) {
4773 if (operands.size() < 2)
4774 return emitInferRetTypeError(loc, "at least one input is required");
4775
4776 // Check all mux inputs have the same type.
4777 if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
4778 return operands[1].getType() == op.getType();
4779 }))
4780 return emitInferRetTypeError(loc, "all inputs must have the same type");
4781
4782 return type_cast<FIRRTLType>(operands[1].getType());
4783}
4784
4785//===----------------------------------------------------------------------===//
4786// ObjectSubfieldOp
4787//===----------------------------------------------------------------------===//
4788
4789LogicalResult ObjectSubfieldOp::inferReturnTypes(
4790 MLIRContext *context, std::optional<mlir::Location> location,
4791 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4792 RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
4793 auto type =
4794 inferReturnType(operands, attributes, properties, regions, location);
4795 if (!type)
4796 return failure();
4797 inferredReturnTypes.push_back(type);
4798 return success();
4799}
4800
4801Type ObjectSubfieldOp::inferReturnType(Type inType, uint32_t fieldIndex,
4802 std::optional<Location> loc) {
4803 auto classType = dyn_cast<ClassType>(inType);
4804 if (!classType)
4805 return emitInferRetTypeError(loc, "base object is not a class");
4806
4807 if (classType.getNumElements() <= fieldIndex)
4808 return emitInferRetTypeError(loc, "element index is greater than the "
4809 "number of fields in the object");
4810 return classType.getElement(fieldIndex).type;
4811}
4812
4813void ObjectSubfieldOp::print(OpAsmPrinter &p) {
4814 auto input = getInput();
4815 auto classType = input.getType();
4816 p << ' ' << input << "[";
4817 p.printKeywordOrString(classType.getElement(getIndex()).name);
4818 p << "]";
4819 p.printOptionalAttrDict((*this)->getAttrs(), std::array{StringRef("index")});
4820 p << " : " << classType;
4821}
4822
4823ParseResult ObjectSubfieldOp::parse(OpAsmParser &parser,
4824 OperationState &result) {
4825 auto *context = parser.getContext();
4826
4827 OpAsmParser::UnresolvedOperand input;
4828 std::string fieldName;
4829 ClassType inputType;
4830 if (parser.parseOperand(input) || parser.parseLSquare() ||
4831 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4832 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4833 parser.parseType(inputType) ||
4834 parser.resolveOperand(input, inputType, result.operands))
4835 return failure();
4836
4837 auto index = inputType.getElementIndex(fieldName);
4838 if (!index)
4839 return parser.emitError(parser.getNameLoc(),
4840 "unknown field " + fieldName + " in class type ")
4841 << inputType;
4842 result.getOrAddProperties<Properties>().setIndex(
4843 IntegerAttr::get(IntegerType::get(context, 32), *index));
4844
4845 SmallVector<Type> inferredReturnTypes;
4846 if (failed(inferReturnTypes(context, result.location, result.operands,
4847 result.attributes.getDictionary(context),
4848 result.getRawProperties(), result.regions,
4849 inferredReturnTypes)))
4850 return failure();
4851 result.addTypes(inferredReturnTypes);
4852
4853 return success();
4854}
4855
4856//===----------------------------------------------------------------------===//
4857// Binary Primitives
4858//===----------------------------------------------------------------------===//
4859
4860/// If LHS and RHS are both UInt or SInt types, the return true and fill in the
4861/// width of them if known. If unknown, return -1 for the widths.
4862/// The constness of the result is also returned, where if both lhs and rhs are
4863/// const, then the result is const.
4864///
4865/// On failure, this reports and error and returns false. This function should
4866/// not be used if you don't want an error reported.
4867static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth,
4868 int32_t &rhsWidth, bool &isConstResult,
4869 std::optional<Location> loc) {
4870 // Must have two integer types with the same signedness.
4871 auto lhsi = type_dyn_cast<IntType>(lhs);
4872 auto rhsi = type_dyn_cast<IntType>(rhs);
4873 if (!lhsi || !rhsi || lhsi.isSigned() != rhsi.isSigned()) {
4874 if (loc) {
4875 if (lhsi && !rhsi)
4876 mlir::emitError(*loc, "second operand must be an integer type, not ")
4877 << rhs;
4878 else if (!lhsi && rhsi)
4879 mlir::emitError(*loc, "first operand must be an integer type, not ")
4880 << lhs;
4881 else if (!lhsi && !rhsi)
4882 mlir::emitError(*loc, "operands must be integer types, not ")
4883 << lhs << " and " << rhs;
4884 else
4885 mlir::emitError(*loc, "operand signedness must match");
4886 }
4887 return false;
4888 }
4889
4890 lhsWidth = lhsi.getWidthOrSentinel();
4891 rhsWidth = rhsi.getWidthOrSentinel();
4892 isConstResult = lhsi.isConst() && rhsi.isConst();
4893 return true;
4894}
4895
4896LogicalResult impl::verifySameOperandsIntTypeKind(Operation *op) {
4897 assert(op->getNumOperands() == 2 &&
4898 "SameOperandsIntTypeKind on non-binary op");
4899 int32_t lhsWidth, rhsWidth;
4900 bool isConstResult;
4901 return success(isSameIntTypeKind(op->getOperand(0).getType(),
4902 op->getOperand(1).getType(), lhsWidth,
4903 rhsWidth, isConstResult, op->getLoc()));
4904}
4905
4907 std::optional<Location> loc) {
4908 int32_t lhsWidth, rhsWidth, resultWidth = -1;
4909 bool isConstResult = false;
4910 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4911 return {};
4912
4913 if (lhsWidth != -1 && rhsWidth != -1)
4914 resultWidth = std::max(lhsWidth, rhsWidth) + 1;
4915 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4916 isConstResult);
4917}
4918
4919FIRRTLType MulPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
4920 std::optional<Location> loc) {
4921 int32_t lhsWidth, rhsWidth, resultWidth = -1;
4922 bool isConstResult = false;
4923 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4924 return {};
4925
4926 if (lhsWidth != -1 && rhsWidth != -1)
4927 resultWidth = lhsWidth + rhsWidth;
4928
4929 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4930 isConstResult);
4931}
4932
4933FIRRTLType DivPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
4934 std::optional<Location> loc) {
4935 int32_t lhsWidth, rhsWidth;
4936 bool isConstResult = false;
4937 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4938 return {};
4939
4940 // For unsigned, the width is the width of the numerator on the LHS.
4941 if (type_isa<UIntType>(lhs))
4942 return UIntType::get(lhs.getContext(), lhsWidth, isConstResult);
4943
4944 // For signed, the width is the width of the numerator on the LHS, plus 1.
4945 int32_t resultWidth = lhsWidth != -1 ? lhsWidth + 1 : -1;
4946 return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
4947}
4948
4949FIRRTLType RemPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
4950 std::optional<Location> loc) {
4951 int32_t lhsWidth, rhsWidth, resultWidth = -1;
4952 bool isConstResult = false;
4953 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4954 return {};
4955
4956 if (lhsWidth != -1 && rhsWidth != -1)
4957 resultWidth = std::min(lhsWidth, rhsWidth);
4958 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4959 isConstResult);
4960}
4961
4963 std::optional<Location> loc) {
4964 int32_t lhsWidth, rhsWidth, resultWidth = -1;
4965 bool isConstResult = false;
4966 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4967 return {};
4968
4969 if (lhsWidth != -1 && rhsWidth != -1) {
4970 resultWidth = std::max(lhsWidth, rhsWidth);
4971 if (lhsWidth == resultWidth && lhs.isConst() == isConstResult &&
4972 isa<UIntType>(lhs))
4973 return lhs;
4974 if (rhsWidth == resultWidth && rhs.isConst() == isConstResult &&
4975 isa<UIntType>(rhs))
4976 return rhs;
4977 }
4978 return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
4979}
4980
4982 std::optional<Location> loc) {
4983 if (!type_isa<FVectorType>(lhs) || !type_isa<FVectorType>(rhs))
4984 return {};
4985
4986 auto lhsVec = type_cast<FVectorType>(lhs);
4987 auto rhsVec = type_cast<FVectorType>(rhs);
4988
4989 if (lhsVec.getNumElements() != rhsVec.getNumElements())
4990 return {};
4991
4992 auto elemType =
4993 impl::inferBitwiseResult(lhsVec.getElementTypePreservingConst(),
4994 rhsVec.getElementTypePreservingConst(), loc);
4995 if (!elemType)
4996 return {};
4997 auto elemBaseType = type_cast<FIRRTLBaseType>(elemType);
4998 return FVectorType::get(elemBaseType, lhsVec.getNumElements(),
4999 lhsVec.isConst() && rhsVec.isConst() &&
5000 elemBaseType.isConst());
5001}
5002
5004 std::optional<Location> loc) {
5005 return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
5006}
5007
5008FIRRTLType CatPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5009 std::optional<Location> loc) {
5010 int32_t lhsWidth, rhsWidth, resultWidth = -1;
5011 bool isConstResult = false;
5012 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5013 return {};
5014
5015 if (lhsWidth != -1 && rhsWidth != -1)
5016 resultWidth = lhsWidth + rhsWidth;
5017 return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
5018}
5019
5020FIRRTLType DShlPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5021 std::optional<Location> loc) {
5022 auto lhsi = type_dyn_cast<IntType>(lhs);
5023 auto rhsui = type_dyn_cast<UIntType>(rhs);
5024 if (!rhsui || !lhsi)
5025 return emitInferRetTypeError(
5026 loc, "first operand should be integer, second unsigned int");
5027
5028 // If the left or right has unknown result type, then the operation does
5029 // too.
5030 auto width = lhsi.getWidthOrSentinel();
5031 if (width == -1 || !rhsui.getWidth().has_value()) {
5032 width = -1;
5033 } else {
5034 auto amount = *rhsui.getWidth();
5035 if (amount >= 32)
5036 return emitInferRetTypeError(loc,
5037 "shift amount too large: second operand of "
5038 "dshl is wider than 31 bits");
5039 int64_t newWidth = (int64_t)width + ((int64_t)1 << amount) - 1;
5040 if (newWidth > INT32_MAX)
5041 return emitInferRetTypeError(
5042 loc, "shift amount too large: first operand shifted by maximum "
5043 "amount exceeds maximum width");
5044 width = newWidth;
5045 }
5046 return IntType::get(lhs.getContext(), lhsi.isSigned(), width,
5047 lhsi.isConst() && rhsui.isConst());
5048}
5049
5050FIRRTLType DShlwPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5051 std::optional<Location> loc) {
5052 auto lhsi = type_dyn_cast<IntType>(lhs);
5053 auto rhsu = type_dyn_cast<UIntType>(rhs);
5054 if (!lhsi || !rhsu)
5055 return emitInferRetTypeError(
5056 loc, "first operand should be integer, second unsigned int");
5057 return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5058}
5059
5060FIRRTLType DShrPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5061 std::optional<Location> loc) {
5062 auto lhsi = type_dyn_cast<IntType>(lhs);
5063 auto rhsu = type_dyn_cast<UIntType>(rhs);
5064 if (!lhsi || !rhsu)
5065 return emitInferRetTypeError(
5066 loc, "first operand should be integer, second unsigned int");
5067 return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5068}
5069
5070//===----------------------------------------------------------------------===//
5071// Unary Primitives
5072//===----------------------------------------------------------------------===//
5073
5074FIRRTLType SizeOfIntrinsicOp::inferReturnType(FIRRTLType input,
5075 std::optional<Location> loc) {
5076 return UIntType::get(input.getContext(), 32);
5077}
5078
5079FIRRTLType AsSIntPrimOp::inferReturnType(FIRRTLType input,
5080 std::optional<Location> loc) {
5081 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5082 if (!base)
5083 return emitInferRetTypeError(loc, "operand must be a scalar base type");
5084 int32_t width = base.getBitWidthOrSentinel();
5085 if (width == -2)
5086 return emitInferRetTypeError(loc, "operand must be a scalar type");
5087 return SIntType::get(input.getContext(), width, base.isConst());
5088}
5089
5090FIRRTLType AsUIntPrimOp::inferReturnType(FIRRTLType input,
5091 std::optional<Location> loc) {
5092 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5093 if (!base)
5094 return emitInferRetTypeError(loc, "operand must be a scalar base type");
5095 int32_t width = base.getBitWidthOrSentinel();
5096 if (width == -2)
5097 return emitInferRetTypeError(loc, "operand must be a scalar type");
5098 return UIntType::get(input.getContext(), width, base.isConst());
5099}
5100
5101FIRRTLType AsAsyncResetPrimOp::inferReturnType(FIRRTLType input,
5102 std::optional<Location> loc) {
5103 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5104 if (!base)
5105 return emitInferRetTypeError(loc,
5106 "operand must be single bit scalar base type");
5107 int32_t width = base.getBitWidthOrSentinel();
5108 if (width == -2 || width == 0 || width > 1)
5109 return emitInferRetTypeError(loc, "operand must be single bit scalar type");
5110 return AsyncResetType::get(input.getContext(), base.isConst());
5111}
5112
5113FIRRTLType AsClockPrimOp::inferReturnType(FIRRTLType input,
5114 std::optional<Location> loc) {
5115 return ClockType::get(input.getContext(), isConst(input));
5116}
5117
5118FIRRTLType CvtPrimOp::inferReturnType(FIRRTLType input,
5119 std::optional<Location> loc) {
5120 if (auto uiType = type_dyn_cast<UIntType>(input)) {
5121 auto width = uiType.getWidthOrSentinel();
5122 if (width != -1)
5123 ++width;
5124 return SIntType::get(input.getContext(), width, uiType.isConst());
5125 }
5126
5127 if (type_isa<SIntType>(input))
5128 return input;
5129
5130 return emitInferRetTypeError(loc, "operand must have integer type");
5131}
5132
5133FIRRTLType NegPrimOp::inferReturnType(FIRRTLType input,
5134 std::optional<Location> loc) {
5135 auto inputi = type_dyn_cast<IntType>(input);
5136 if (!inputi)
5137 return emitInferRetTypeError(loc, "operand must have integer type");
5138 int32_t width = inputi.getWidthOrSentinel();
5139 if (width != -1)
5140 ++width;
5141 return SIntType::get(input.getContext(), width, inputi.isConst());
5142}
5143
5144FIRRTLType NotPrimOp::inferReturnType(FIRRTLType input,
5145 std::optional<Location> loc) {
5146 auto inputi = type_dyn_cast<IntType>(input);
5147 if (!inputi)
5148 return emitInferRetTypeError(loc, "operand must have integer type");
5149 if (isa<UIntType>(inputi))
5150 return inputi;
5151 return UIntType::get(input.getContext(), inputi.getWidthOrSentinel(),
5152 inputi.isConst());
5153}
5154
5156 std::optional<Location> loc) {
5157 return UIntType::get(input.getContext(), 1, isConst(input));
5158}
5159
5160//===----------------------------------------------------------------------===//
5161// Other Operations
5162//===----------------------------------------------------------------------===//
5163
5164FIRRTLType BitsPrimOp::inferReturnType(FIRRTLType input, int64_t high,
5165 int64_t low,
5166 std::optional<Location> loc) {
5167 auto inputi = type_dyn_cast<IntType>(input);
5168 if (!inputi)
5169 return emitInferRetTypeError(
5170 loc, "input type should be the int type but got ", input);
5171
5172 // High must be >= low and both most be non-negative.
5173 if (high < low)
5174 return emitInferRetTypeError(
5175 loc, "high must be equal or greater than low, but got high = ", high,
5176 ", low = ", low);
5177
5178 if (low < 0)
5179 return emitInferRetTypeError(loc, "low must be non-negative but got ", low);
5180
5181 // If the input has staticly known width, check it. Both and low must be
5182 // strictly less than width.
5183 int32_t width = inputi.getWidthOrSentinel();
5184 if (width != -1 && high >= width)
5185 return emitInferRetTypeError(
5186 loc,
5187 "high must be smaller than the width of input, but got high = ", high,
5188 ", width = ", width);
5189
5190 return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
5191}
5192
5193FIRRTLType HeadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5194 std::optional<Location> loc) {
5195
5196 auto inputi = type_dyn_cast<IntType>(input);
5197 if (amount < 0 || !inputi)
5198 return emitInferRetTypeError(
5199 loc, "operand must have integer type and amount must be >= 0");
5200
5201 int32_t width = inputi.getWidthOrSentinel();
5202 if (width != -1 && amount > width)
5203 return emitInferRetTypeError(loc, "amount larger than input width");
5204
5205 return UIntType::get(input.getContext(), amount, inputi.isConst());
5206}
5207
5208/// Infer the result type for a multiplexer given its two operand types, which
5209/// may be aggregates.
5210///
5211/// This essentially performs a pairwise comparison of fields and elements, as
5212/// follows:
5213/// - Identical operands inferred to their common type
5214/// - Integer operands inferred to the larger one if both have a known width, a
5215/// widthless integer otherwise.
5216/// - Vectors inferred based on the element type.
5217/// - Bundles inferred in a pairwise fashion based on the field types.
5219 FIRRTLBaseType low,
5220 bool isConstCondition,
5221 std::optional<Location> loc) {
5222 // If the types are identical we're done.
5223 if (high == low)
5224 return isConstCondition ? low : low.getAllConstDroppedType();
5225
5226 // The base types need to be equivalent.
5227 if (high.getTypeID() != low.getTypeID())
5228 return emitInferRetTypeError<FIRRTLBaseType>(
5229 loc, "incompatible mux operand types, true value type: ", high,
5230 ", false value type: ", low);
5231
5232 bool outerTypeIsConst = isConstCondition && low.isConst() && high.isConst();
5233
5234 // Two different Int types can be compatible. If either has unknown width,
5235 // then return it. If both are known but different width, then return the
5236 // larger one.
5237 if (type_isa<IntType>(low)) {
5238 int32_t highWidth = high.getBitWidthOrSentinel();
5239 int32_t lowWidth = low.getBitWidthOrSentinel();
5240 if (lowWidth == -1)
5241 return low.getConstType(outerTypeIsConst);
5242 if (highWidth == -1)
5243 return high.getConstType(outerTypeIsConst);
5244 return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
5245 }
5246
5247 // Infer vector types by comparing the element types.
5248 auto highVector = type_dyn_cast<FVectorType>(high);
5249 auto lowVector = type_dyn_cast<FVectorType>(low);
5250 if (highVector && lowVector &&
5251 highVector.getNumElements() == lowVector.getNumElements()) {
5252 auto inner = inferMuxReturnType(highVector.getElementTypePreservingConst(),
5253 lowVector.getElementTypePreservingConst(),
5254 isConstCondition, loc);
5255 if (!inner)
5256 return {};
5257 return FVectorType::get(inner, lowVector.getNumElements(),
5258 outerTypeIsConst);
5259 }
5260
5261 // Infer bundle types by inferring names in a pairwise fashion.
5262 auto highBundle = type_dyn_cast<BundleType>(high);
5263 auto lowBundle = type_dyn_cast<BundleType>(low);
5264 if (highBundle && lowBundle) {
5265 auto highElements = highBundle.getElements();
5266 auto lowElements = lowBundle.getElements();
5267 size_t numElements = highElements.size();
5268
5269 SmallVector<BundleType::BundleElement> newElements;
5270 if (numElements == lowElements.size()) {
5271 bool failed = false;
5272 for (size_t i = 0; i < numElements; ++i) {
5273 if (highElements[i].name != lowElements[i].name ||
5274 highElements[i].isFlip != lowElements[i].isFlip) {
5275 failed = true;
5276 break;
5277 }
5278 auto element = highElements[i];
5279 element.type = inferMuxReturnType(
5280 highBundle.getElementTypePreservingConst(i),
5281 lowBundle.getElementTypePreservingConst(i), isConstCondition, loc);
5282 if (!element.type)
5283 return {};
5284 newElements.push_back(element);
5285 }
5286 if (!failed)
5287 return BundleType::get(low.getContext(), newElements, outerTypeIsConst);
5288 }
5289 return emitInferRetTypeError<FIRRTLBaseType>(
5290 loc, "incompatible mux operand bundle fields, true value type: ", high,
5291 ", false value type: ", low);
5292 }
5293
5294 // If we arrive here the types of the two mux arms are fundamentally
5295 // incompatible.
5296 return emitInferRetTypeError<FIRRTLBaseType>(
5297 loc, "invalid mux operand types, true value type: ", high,
5298 ", false value type: ", low);
5299}
5300
5301FIRRTLType MuxPrimOp::inferReturnType(FIRRTLType sel, FIRRTLType high,
5302 FIRRTLType low,
5303 std::optional<Location> loc) {
5304 auto highType = type_dyn_cast<FIRRTLBaseType>(high);
5305 auto lowType = type_dyn_cast<FIRRTLBaseType>(low);
5306 if (!highType || !lowType)
5307 return emitInferRetTypeError(loc, "operands must be base type");
5308 return inferMuxReturnType(highType, lowType, isConst(sel), loc);
5309}
5310
5311FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
5312 DictionaryAttr attrs,
5313 OpaqueProperties properties,
5314 mlir::RegionRange regions,
5315 std::optional<Location> loc) {
5316 auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5317 auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5318 if (!highType || !lowType)
5319 return emitInferRetTypeError(loc, "operands must be base type");
5320 return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5321 loc);
5322}
5323
5324FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
5325 DictionaryAttr attrs,
5326 OpaqueProperties properties,
5327 mlir::RegionRange regions,
5328 std::optional<Location> loc) {
5329 SmallVector<FIRRTLBaseType> types;
5330 FIRRTLBaseType result;
5331 for (unsigned i = 1; i < 5; i++) {
5332 types.push_back(type_dyn_cast<FIRRTLBaseType>(operands[i].getType()));
5333 if (!types.back())
5334 return emitInferRetTypeError(loc, "operands must be base type");
5335 if (result) {
5336 result = inferMuxReturnType(result, types.back(),
5337 isConst(operands[0].getType()), loc);
5338 if (!result)
5339 return result;
5340 } else {
5341 result = types.back();
5342 }
5343 }
5344 return result;
5345}
5346
5347FIRRTLType PadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5348 std::optional<Location> loc) {
5349 auto inputi = type_dyn_cast<IntType>(input);
5350 if (amount < 0 || !inputi)
5351 return emitInferRetTypeError(
5352 loc, "pad input must be integer and amount must be >= 0");
5353
5354 int32_t width = inputi.getWidthOrSentinel();
5355 if (width == -1)
5356 return inputi;
5357
5358 width = std::max<int32_t>(width, amount);
5359 return IntType::get(input.getContext(), inputi.isSigned(), width,
5360 inputi.isConst());
5361}
5362
5363FIRRTLType ShlPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5364 std::optional<Location> loc) {
5365 auto inputi = type_dyn_cast<IntType>(input);
5366 if (amount < 0 || !inputi)
5367 return emitInferRetTypeError(
5368 loc, "shl input must be integer and amount must be >= 0");
5369
5370 int32_t width = inputi.getWidthOrSentinel();
5371 if (width != -1)
5372 width += amount;
5373
5374 return IntType::get(input.getContext(), inputi.isSigned(), width,
5375 inputi.isConst());
5376}
5377
5378FIRRTLType ShrPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5379 std::optional<Location> loc) {
5380 auto inputi = type_dyn_cast<IntType>(input);
5381 if (amount < 0 || !inputi)
5382 return emitInferRetTypeError(
5383 loc, "shr input must be integer and amount must be >= 0");
5384
5385 int32_t width = inputi.getWidthOrSentinel();
5386 if (width != -1) {
5387 // UInt saturates at 0 bits, SInt at 1 bit
5388 int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
5389 width = std::max<int32_t>(minWidth, width - amount);
5390 }
5391
5392 return IntType::get(input.getContext(), inputi.isSigned(), width,
5393 inputi.isConst());
5394}
5395
5396FIRRTLType TailPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5397 std::optional<Location> loc) {
5398
5399 auto inputi = type_dyn_cast<IntType>(input);
5400 if (amount < 0 || !inputi)
5401 return emitInferRetTypeError(
5402 loc, "tail input must be integer and amount must be >= 0");
5403
5404 int32_t width = inputi.getWidthOrSentinel();
5405 if (width != -1) {
5406 if (width < amount)
5407 return emitInferRetTypeError(
5408 loc, "amount must be less than or equal operand width");
5409 width -= amount;
5410 }
5411
5412 return IntType::get(input.getContext(), false, width, inputi.isConst());
5413}
5414
5415//===----------------------------------------------------------------------===//
5416// VerbatimExprOp
5417//===----------------------------------------------------------------------===//
5418
5419void VerbatimExprOp::getAsmResultNames(
5420 function_ref<void(Value, StringRef)> setNameFn) {
5421 // If the text is macro like, then use a pretty name. We only take the
5422 // text up to a weird character (like a paren) and currently ignore
5423 // parenthesized expressions.
5424 auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5425 auto name = getText();
5426 // Ignore a leading ` in macro name.
5427 if (name.starts_with("`"))
5428 name = name.drop_front();
5429 name = name.take_while(isOkCharacter);
5430 if (!name.empty())
5431 setNameFn(getResult(), name);
5432}
5433
5434//===----------------------------------------------------------------------===//
5435// VerbatimWireOp
5436//===----------------------------------------------------------------------===//
5437
5438void VerbatimWireOp::getAsmResultNames(
5439 function_ref<void(Value, StringRef)> setNameFn) {
5440 // If the text is macro like, then use a pretty name. We only take the
5441 // text up to a weird character (like a paren) and currently ignore
5442 // parenthesized expressions.
5443 auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5444 auto name = getText();
5445 // Ignore a leading ` in macro name.
5446 if (name.starts_with("`"))
5447 name = name.drop_front();
5448 name = name.take_while(isOkCharacter);
5449 if (!name.empty())
5450 setNameFn(getResult(), name);
5451}
5452
5453//===----------------------------------------------------------------------===//
5454// DPICallIntrinsicOp
5455//===----------------------------------------------------------------------===//
5456
5457static bool isTypeAllowedForDPI(Operation *op, Type type) {
5458 return !type.walk([&](firrtl::IntType intType) -> mlir::WalkResult {
5459 auto width = intType.getWidth();
5460 if (width < 0) {
5461 op->emitError() << "unknown width is not allowed for DPI";
5462 return WalkResult::interrupt();
5463 }
5464 if (width == 1 || width == 8 || width == 16 || width == 32 ||
5465 width >= 64)
5466 return WalkResult::advance();
5467 op->emitError()
5468 << "integer types used by DPI functions must have a "
5469 "specific bit width; "
5470 "it must be equal to 1(bit), 8(byte), 16(shortint), "
5471 "32(int), 64(longint) "
5472 "or greater than 64, but got "
5473 << intType;
5474 return WalkResult::interrupt();
5475 })
5476 .wasInterrupted();
5477}
5478
5479LogicalResult DPICallIntrinsicOp::verify() {
5480 if (auto inputNames = getInputNames()) {
5481 if (getInputs().size() != inputNames->size())
5482 return emitError() << "inputNames has " << inputNames->size()
5483 << " elements but there are " << getInputs().size()
5484 << " input arguments";
5485 }
5486 if (auto outputName = getOutputName())
5487 if (getNumResults() == 0)
5488 return emitError() << "output name is given but there is no result";
5489
5490 auto checkType = [this](Type type) {
5491 return isTypeAllowedForDPI(*this, type);
5492 };
5493 return success(llvm::all_of(this->getResultTypes(), checkType) &&
5494 llvm::all_of(this->getOperandTypes(), checkType));
5495}
5496
5497SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
5498DPICallIntrinsicOp::computeDataFlow() {
5499 if (getClock())
5500 return {};
5501
5502 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
5503
5504 for (auto operand : getOperands()) {
5505 auto type = type_cast<FIRRTLBaseType>(operand.getType());
5506 auto baseFieldRef = getFieldRefFromValue(operand);
5507 SmallVector<circt::FieldRef> operandFields;
5509 type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
5510 operandFields.push_back(baseFieldRef.getSubField(dstIndex));
5511 });
5512
5513 // Record operand -> result dependency.
5514 for (auto result : getResults())
5516 type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
5517 for (auto field : operandFields)
5518 deps.emplace_back(circt::FieldRef(result, dstIndex), field);
5519 });
5520 }
5521 return deps;
5522}
5523
5524//===----------------------------------------------------------------------===//
5525// Conversions to/from structs in the standard dialect.
5526//===----------------------------------------------------------------------===//
5527
5528LogicalResult HWStructCastOp::verify() {
5529 // We must have a bundle and a struct, with matching pairwise fields
5530 BundleType bundleType;
5531 hw::StructType structType;
5532 if ((bundleType = type_dyn_cast<BundleType>(getOperand().getType()))) {
5533 structType = dyn_cast<hw::StructType>(getType());
5534 if (!structType)
5535 return emitError("result type must be a struct");
5536 } else if ((bundleType = type_dyn_cast<BundleType>(getType()))) {
5537 structType = dyn_cast<hw::StructType>(getOperand().getType());
5538 if (!structType)
5539 return emitError("operand type must be a struct");
5540 } else {
5541 return emitError("either source or result type must be a bundle type");
5542 }
5543
5544 auto firFields = bundleType.getElements();
5545 auto hwFields = structType.getElements();
5546 if (firFields.size() != hwFields.size())
5547 return emitError("bundle and struct have different number of fields");
5548
5549 for (size_t findex = 0, fend = firFields.size(); findex < fend; ++findex) {
5550 if (firFields[findex].name.getValue() != hwFields[findex].name)
5551 return emitError("field names don't match '")
5552 << firFields[findex].name.getValue() << "', '"
5553 << hwFields[findex].name.getValue() << "'";
5554 int64_t firWidth =
5555 FIRRTLBaseType(firFields[findex].type).getBitWidthOrSentinel();
5556 int64_t hwWidth = hw::getBitWidth(hwFields[findex].type);
5557 if (firWidth > 0 && hwWidth > 0 && firWidth != hwWidth)
5558 return emitError("size of field '")
5559 << hwFields[findex].name.getValue() << "' don't match " << firWidth
5560 << ", " << hwWidth;
5561 }
5562
5563 return success();
5564}
5565
5566LogicalResult BitCastOp::verify() {
5567 auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true);
5568 auto resTypeBits = getBitWidth(getType());
5569 if (inTypeBits.has_value() && resTypeBits.has_value()) {
5570 // Bitwidths must match for valid bit
5571 if (*inTypeBits == *resTypeBits) {
5572 // non-'const' cannot be casted to 'const'
5573 if (containsConst(getType()) && !isConst(getOperand().getType()))
5574 return emitError("cannot cast non-'const' input type ")
5575 << getOperand().getType() << " to 'const' result type "
5576 << getType();
5577 return success();
5578 }
5579 return emitError("the bitwidth of input (")
5580 << *inTypeBits << ") and result (" << *resTypeBits
5581 << ") don't match";
5582 }
5583 if (!inTypeBits.has_value())
5584 return emitError("bitwidth cannot be determined for input operand type ")
5585 << getInput().getType();
5586 return emitError("bitwidth cannot be determined for result type ")
5587 << getType();
5588}
5589
5590//===----------------------------------------------------------------------===//
5591// Custom attr-dict Directive that Elides Annotations
5592//===----------------------------------------------------------------------===//
5593
5594/// Parse an optional attribute dictionary, adding an empty 'annotations'
5595/// attribute if not specified.
5596static ParseResult parseElideAnnotations(OpAsmParser &parser,
5597 NamedAttrList &resultAttrs) {
5598 auto result = parser.parseOptionalAttrDict(resultAttrs);
5599 if (!resultAttrs.get("annotations"))
5600 resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
5601
5602 return result;
5603}
5604
5605static void printElideAnnotations(OpAsmPrinter &p, Operation *op,
5606 DictionaryAttr attr,
5607 ArrayRef<StringRef> extraElides = {}) {
5608 SmallVector<StringRef> elidedAttrs(extraElides.begin(), extraElides.end());
5609 // Elide "annotations" if it is empty.
5610 if (op->getAttrOfType<ArrayAttr>("annotations").empty())
5611 elidedAttrs.push_back("annotations");
5612 // Elide "nameKind".
5613 elidedAttrs.push_back("nameKind");
5614
5615 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5616}
5617
5618/// Parse an optional attribute dictionary, adding empty 'annotations' and
5619/// 'portAnnotations' attributes if not specified.
5620static ParseResult parseElidePortAnnotations(OpAsmParser &parser,
5621 NamedAttrList &resultAttrs) {
5622 auto result = parseElideAnnotations(parser, resultAttrs);
5623
5624 if (!resultAttrs.get("portAnnotations")) {
5625 SmallVector<Attribute, 16> portAnnotations(
5626 parser.getNumResults(), parser.getBuilder().getArrayAttr({}));
5627 resultAttrs.append("portAnnotations",
5628 parser.getBuilder().getArrayAttr(portAnnotations));
5629 }
5630 return result;
5631}
5632
5633// Elide 'annotations' and 'portAnnotations' attributes if they are empty.
5634static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op,
5635 DictionaryAttr attr,
5636 ArrayRef<StringRef> extraElides = {}) {
5637 SmallVector<StringRef, 2> elidedAttrs(extraElides.begin(), extraElides.end());
5638
5639 if (llvm::all_of(op->getAttrOfType<ArrayAttr>("portAnnotations"),
5640 [&](Attribute a) { return cast<ArrayAttr>(a).empty(); }))
5641 elidedAttrs.push_back("portAnnotations");
5642 printElideAnnotations(p, op, attr, elidedAttrs);
5643}
5644
5645//===----------------------------------------------------------------------===//
5646// NameKind Custom Directive
5647//===----------------------------------------------------------------------===//
5648
5649static ParseResult parseNameKind(OpAsmParser &parser,
5650 firrtl::NameKindEnumAttr &result) {
5651 StringRef keyword;
5652
5653 if (!parser.parseOptionalKeyword(&keyword,
5654 {"interesting_name", "droppable_name"})) {
5655 auto kind = symbolizeNameKindEnum(keyword);
5656 result = NameKindEnumAttr::get(parser.getContext(), kind.value());
5657 return success();
5658 }
5659
5660 // Default is droppable name.
5661 result =
5662 NameKindEnumAttr::get(parser.getContext(), NameKindEnum::DroppableName);
5663 return success();
5664}
5665
5666static void printNameKind(OpAsmPrinter &p, Operation *op,
5667 firrtl::NameKindEnumAttr attr,
5668 ArrayRef<StringRef> extraElides = {}) {
5669 if (attr.getValue() != NameKindEnum::DroppableName)
5670 p << " " << stringifyNameKindEnum(attr.getValue());
5671}
5672
5673//===----------------------------------------------------------------------===//
5674// ImplicitSSAName Custom Directive
5675//===----------------------------------------------------------------------===//
5676
5677static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser,
5678 NamedAttrList &resultAttrs) {
5679 if (parseElideAnnotations(parser, resultAttrs))
5680 return failure();
5681 inferImplicitSSAName(parser, resultAttrs);
5682 return success();
5683}
5684
5685static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op,
5686 DictionaryAttr attrs) {
5687 SmallVector<StringRef, 4> elides;
5689 elides.push_back(Forceable::getForceableAttrName());
5690 elideImplicitSSAName(p, op, attrs, elides);
5691 printElideAnnotations(p, op, attrs, elides);
5692}
5693
5694//===----------------------------------------------------------------------===//
5695// MemOp Custom attr-dict Directive
5696//===----------------------------------------------------------------------===//
5697
5698static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs) {
5699 return parseElidePortAnnotations(parser, resultAttrs);
5700}
5701
5702/// Always elide "ruw" and elide "annotations" if it exists or if it is empty.
5703static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr) {
5704 // "ruw" and "inner_sym" is always elided.
5705 printElidePortAnnotations(p, op, attr, {"ruw", "inner_sym"});
5706}
5707
5708//===----------------------------------------------------------------------===//
5709// ClassInterface custom directive
5710//===----------------------------------------------------------------------===//
5711
5712static ParseResult parseClassInterface(OpAsmParser &parser, Type &result) {
5713 ClassType type;
5714 if (ClassType::parseInterface(parser, type))
5715 return failure();
5716 result = type;
5717 return success();
5718}
5719
5720static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type) {
5721 type.printInterface(p);
5722}
5723
5724//===----------------------------------------------------------------------===//
5725// Miscellaneous custom elision logic.
5726//===----------------------------------------------------------------------===//
5727
5728static ParseResult parseElideEmptyName(OpAsmParser &p,
5729 NamedAttrList &resultAttrs) {
5730 auto result = p.parseOptionalAttrDict(resultAttrs);
5731 if (!resultAttrs.get("name"))
5732 resultAttrs.append("name", p.getBuilder().getStringAttr(""));
5733
5734 return result;
5735}
5736
5737static void printElideEmptyName(OpAsmPrinter &p, Operation *op,
5738 DictionaryAttr attr,
5739 ArrayRef<StringRef> extraElides = {}) {
5740 SmallVector<StringRef> elides(extraElides.begin(), extraElides.end());
5741 if (op->getAttrOfType<StringAttr>("name").getValue().empty())
5742 elides.push_back("name");
5743
5744 p.printOptionalAttrDict(op->getAttrs(), elides);
5745}
5746
5747static ParseResult parsePrintfAttrs(OpAsmParser &p,
5748 NamedAttrList &resultAttrs) {
5749 return parseElideEmptyName(p, resultAttrs);
5750}
5751
5752static void printPrintfAttrs(OpAsmPrinter &p, Operation *op,
5753 DictionaryAttr attr) {
5754 printElideEmptyName(p, op, attr, {"formatString"});
5755}
5756
5757static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5758 return parseElideEmptyName(p, resultAttrs);
5759}
5760
5761static void printStopAttrs(OpAsmPrinter &p, Operation *op,
5762 DictionaryAttr attr) {
5763 printElideEmptyName(p, op, attr, {"exitCode"});
5764}
5765
5766static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5767 return parseElideEmptyName(p, resultAttrs);
5768}
5769
5770static void printVerifAttrs(OpAsmPrinter &p, Operation *op,
5771 DictionaryAttr attr) {
5772 printElideEmptyName(p, op, attr, {"message"});
5773}
5774
5775//===----------------------------------------------------------------------===//
5776// Various namers.
5777//===----------------------------------------------------------------------===//
5778
5779static void genericAsmResultNames(Operation *op,
5780 OpAsmSetValueNameFn setNameFn) {
5781 // Many firrtl dialect operations have an optional 'name' attribute. If
5782 // present, use it.
5783 if (op->getNumResults() == 1)
5784 if (auto nameAttr = op->getAttrOfType<StringAttr>("name"))
5785 setNameFn(op->getResult(0), nameAttr.getValue());
5786}
5787
5788void AddPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5789 genericAsmResultNames(*this, setNameFn);
5790}
5791
5792void AndPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5793 genericAsmResultNames(*this, setNameFn);
5794}
5795
5796void AndRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5797 genericAsmResultNames(*this, setNameFn);
5798}
5799
5800void SizeOfIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5801 genericAsmResultNames(*this, setNameFn);
5802}
5803void AsAsyncResetPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5804 genericAsmResultNames(*this, setNameFn);
5805}
5806void AsClockPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5807 genericAsmResultNames(*this, setNameFn);
5808}
5809void AsSIntPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5810 genericAsmResultNames(*this, setNameFn);
5811}
5812void AsUIntPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5813 genericAsmResultNames(*this, setNameFn);
5814}
5815void BitsPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5816 genericAsmResultNames(*this, setNameFn);
5817}
5818void CatPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5819 genericAsmResultNames(*this, setNameFn);
5820}
5821void CvtPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5822 genericAsmResultNames(*this, setNameFn);
5823}
5824void DShlPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5825 genericAsmResultNames(*this, setNameFn);
5826}
5827void DShlwPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5828 genericAsmResultNames(*this, setNameFn);
5829}
5830void DShrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5831 genericAsmResultNames(*this, setNameFn);
5832}
5833void DivPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5834 genericAsmResultNames(*this, setNameFn);
5835}
5836void EQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5837 genericAsmResultNames(*this, setNameFn);
5838}
5839void GEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5840 genericAsmResultNames(*this, setNameFn);
5841}
5842void GTPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5843 genericAsmResultNames(*this, setNameFn);
5844}
5845void GenericIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5846 genericAsmResultNames(*this, setNameFn);
5847}
5848void HeadPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5849 genericAsmResultNames(*this, setNameFn);
5850}
5851void IntegerAddOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5852 genericAsmResultNames(*this, setNameFn);
5853}
5854void IntegerMulOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5855 genericAsmResultNames(*this, setNameFn);
5856}
5857void IntegerShrOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5858 genericAsmResultNames(*this, setNameFn);
5859}
5860void IntegerShlOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5861 genericAsmResultNames(*this, setNameFn);
5862}
5863void IsTagOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5864 genericAsmResultNames(*this, setNameFn);
5865}
5866void IsXIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5867 genericAsmResultNames(*this, setNameFn);
5868}
5869void PlusArgsValueIntrinsicOp::getAsmResultNames(
5870 OpAsmSetValueNameFn setNameFn) {
5871 genericAsmResultNames(*this, setNameFn);
5872}
5873void PlusArgsTestIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5874 genericAsmResultNames(*this, setNameFn);
5875}
5876void LEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5877 genericAsmResultNames(*this, setNameFn);
5878}
5879void LTPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5880 genericAsmResultNames(*this, setNameFn);
5881}
5882void MulPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5883 genericAsmResultNames(*this, setNameFn);
5884}
5885void MultibitMuxOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5886 genericAsmResultNames(*this, setNameFn);
5887}
5888void MuxPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5889 genericAsmResultNames(*this, setNameFn);
5890}
5891void Mux4CellIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5892 genericAsmResultNames(*this, setNameFn);
5893}
5894void Mux2CellIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5895 genericAsmResultNames(*this, setNameFn);
5896}
5897void NEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5898 genericAsmResultNames(*this, setNameFn);
5899}
5900void NegPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5901 genericAsmResultNames(*this, setNameFn);
5902}
5903void NotPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5904 genericAsmResultNames(*this, setNameFn);
5905}
5906void OrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5907 genericAsmResultNames(*this, setNameFn);
5908}
5909void OrRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5910 genericAsmResultNames(*this, setNameFn);
5911}
5912void PadPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5913 genericAsmResultNames(*this, setNameFn);
5914}
5915void RemPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5916 genericAsmResultNames(*this, setNameFn);
5917}
5918void ShlPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5919 genericAsmResultNames(*this, setNameFn);
5920}
5921void ShrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5922 genericAsmResultNames(*this, setNameFn);
5923}
5924
5925void SubPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5926 genericAsmResultNames(*this, setNameFn);
5927}
5928
5929void SubaccessOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5930 genericAsmResultNames(*this, setNameFn);
5931}
5932
5933void SubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5934 genericAsmResultNames(*this, setNameFn);
5935}
5936
5937void OpenSubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5938 genericAsmResultNames(*this, setNameFn);
5939}
5940
5941void SubtagOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5942 genericAsmResultNames(*this, setNameFn);
5943}
5944
5945void SubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5946 genericAsmResultNames(*this, setNameFn);
5947}
5948
5949void OpenSubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5950 genericAsmResultNames(*this, setNameFn);
5951}
5952
5953void TagExtractOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5954 genericAsmResultNames(*this, setNameFn);
5955}
5956
5957void TailPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5958 genericAsmResultNames(*this, setNameFn);
5959}
5960
5961void XorPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5962 genericAsmResultNames(*this, setNameFn);
5963}
5964
5965void XorRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5966 genericAsmResultNames(*this, setNameFn);
5967}
5968
5969void UninferredResetCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5970 genericAsmResultNames(*this, setNameFn);
5971}
5972
5973void ConstCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5974 genericAsmResultNames(*this, setNameFn);
5975}
5976
5977void ElementwiseXorPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5978 genericAsmResultNames(*this, setNameFn);
5979}
5980
5981void ElementwiseOrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5982 genericAsmResultNames(*this, setNameFn);
5983}
5984
5985void ElementwiseAndPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5986 genericAsmResultNames(*this, setNameFn);
5987}
5988
5989//===----------------------------------------------------------------------===//
5990// RefOps
5991//===----------------------------------------------------------------------===//
5992
5993void RefCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5994 genericAsmResultNames(*this, setNameFn);
5995}
5996
5997void RefResolveOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
5998 genericAsmResultNames(*this, setNameFn);
5999}
6000
6001void RefSendOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6002 genericAsmResultNames(*this, setNameFn);
6003}
6004
6005void RefSubOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6006 genericAsmResultNames(*this, setNameFn);
6007}
6008
6009void RWProbeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6010 genericAsmResultNames(*this, setNameFn);
6011}
6012
6013FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
6014 DictionaryAttr attrs,
6015 OpaqueProperties properties,
6016 mlir::RegionRange regions,
6017 std::optional<Location> loc) {
6018 Type inType = operands[0].getType();
6019 auto inRefType = type_dyn_cast<RefType>(inType);
6020 if (!inRefType)
6021 return emitInferRetTypeError(
6022 loc, "ref.resolve operand must be ref type, not ", inType);
6023 return inRefType.getType();
6024}
6025
6026FIRRTLType RefSendOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
6027 OpaqueProperties properties,
6028 mlir::RegionRange regions,
6029 std::optional<Location> loc) {
6030 Type inType = operands[0].getType();
6031 auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
6032 if (!inBaseType)
6033 return emitInferRetTypeError(
6034 loc, "ref.send operand must be base type, not ", inType);
6035 return RefType::get(inBaseType.getPassiveType());
6036}
6037
6038FIRRTLType RefSubOp::inferReturnType(Type type, uint32_t fieldIndex,
6039 std::optional<Location> loc) {
6040 auto refType = type_dyn_cast<RefType>(type);
6041 if (!refType)
6042 return emitInferRetTypeError(loc, "input must be of reference type");
6043 auto inType = refType.getType();
6044
6045 // TODO: Determine ref.sub + rwprobe behavior, test.
6046 // Probably best to demote to non-rw, but that has implications
6047 // for any LowerTypes behavior being relied on.
6048 // Allow for now, as need to LowerTypes things generally.
6049 if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
6050 if (fieldIndex < vectorType.getNumElements())
6051 return RefType::get(
6052 vectorType.getElementType().getConstType(
6053 vectorType.isConst() || vectorType.getElementType().isConst()),
6054 refType.getForceable(), refType.getLayer());
6055 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
6056 "' in RefType of vector type ", refType);
6057 }
6058 if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
6059 if (fieldIndex >= bundleType.getNumElements()) {
6060 return emitInferRetTypeError(loc,
6061 "subfield element index is greater than "
6062 "the number of fields in the bundle type");
6063 }
6064 return RefType::get(
6065 bundleType.getElement(fieldIndex)
6066 .type.getConstType(
6067 bundleType.isConst() ||
6068 bundleType.getElement(fieldIndex).type.isConst()),
6069 refType.getForceable(), refType.getLayer());
6070 }
6071
6072 return emitInferRetTypeError(
6073 loc, "ref.sub op requires a RefType of vector or bundle base type");
6074}
6075
6076LogicalResult RefCastOp::verify() {
6077 auto srcLayers = getLayersFor(getInput());
6078 auto dstLayers = getLayersFor(getResult());
6079 SmallVector<SymbolRefAttr> missingLayers;
6080 if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6081 auto diag =
6082 emitOpError("cannot discard layer requirements of input reference");
6083 auto &note = diag.attachNote();
6084 note << "discarding layer requirements: ";
6085 llvm::interleaveComma(missingLayers, note);
6086 return failure();
6087 }
6088 return success();
6089}
6090
6091LogicalResult RefResolveOp::verify() {
6092 auto srcLayers = getLayersFor(getRef());
6093 auto dstLayers = getAmbientLayersAt(getOperation());
6094 SmallVector<SymbolRefAttr> missingLayers;
6095 if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6096 auto diag =
6097 emitOpError("ambient layers are insufficient to resolve reference");
6098 auto &note = diag.attachNote();
6099 note << "missing layer requirements: ";
6100 interleaveComma(missingLayers, note);
6101 return failure();
6102 }
6103 return success();
6104}
6105
6106LogicalResult RWProbeOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
6107 auto targetRef = getTarget();
6108 if (targetRef.getModule() !=
6109 (*this)->getParentOfType<FModuleLike>().getModuleNameAttr())
6110 return emitOpError() << "has non-local target";
6111
6112 auto target = ns.lookup(targetRef);
6113 if (!target)
6114 return emitOpError() << "has target that cannot be resolved: " << targetRef;
6115
6116 auto checkFinalType = [&](auto type, Location loc) -> LogicalResult {
6117 // Determine final type.
6118 mlir::Type fType =
6119 hw::FieldIdImpl::getFinalTypeByFieldID(type, target.getField());
6120 // Check.
6121 auto baseType = type_dyn_cast<FIRRTLBaseType>(fType);
6122 if (!baseType || baseType.getPassiveType() != getType().getType()) {
6123 auto diag = emitOpError("has type mismatch: target resolves to ")
6124 << fType << " instead of expected " << getType().getType();
6125 diag.attachNote(loc) << "target resolves here";
6126 return diag;
6127 }
6128 return success();
6129 };
6130
6131 auto checkLayers = [&](Location loc) -> LogicalResult {
6132 auto dstLayers = getAmbientLayersAt(target.getOp());
6133 auto srcLayers = getLayersFor(getResult());
6134 SmallVector<SymbolRefAttr> missingLayers;
6135 if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6136 auto diag = emitOpError("target has insufficient layer requirements");
6137 auto &note = diag.attachNote(loc);
6138 note << "target is missing layer requirements: ";
6139 llvm::interleaveComma(missingLayers, note);
6140 return failure();
6141 }
6142 return success();
6143 };
6144 auto checks = [&](auto type, Location loc) {
6145 if (failed(checkLayers(loc)))
6146 return failure();
6147 return checkFinalType(type, loc);
6148 };
6149
6150 if (target.isPort()) {
6151 auto mod = cast<FModuleLike>(target.getOp());
6152 return checks(mod.getPortType(target.getPort()),
6153 mod.getPortLocation(target.getPort()));
6154 }
6155 hw::InnerSymbolOpInterface symOp =
6156 cast<hw::InnerSymbolOpInterface>(target.getOp());
6157 if (!symOp.getTargetResult())
6158 return emitOpError("has target that cannot be probed")
6159 .attachNote(symOp.getLoc())
6160 .append("target resolves here");
6161 auto *ancestor =
6162 symOp.getTargetResult().getParentBlock()->findAncestorOpInBlock(**this);
6163 if (!ancestor || !symOp->isBeforeInBlock(ancestor))
6164 return emitOpError("is not dominated by target")
6165 .attachNote(symOp.getLoc())
6166 .append("target here");
6167 return checks(symOp.getTargetResult().getType(), symOp.getLoc());
6168}
6169
6170//===----------------------------------------------------------------------===//
6171// Layer Block Operations
6172//===----------------------------------------------------------------------===//
6173
6174LogicalResult LayerBlockOp::verify() {
6175 auto layerName = getLayerName();
6176 auto *parentOp = (*this)->getParentOp();
6177
6178 // Get parent operation that isn't a when or match.
6179 while (isa<WhenOp, MatchOp>(parentOp))
6180 parentOp = parentOp->getParentOp();
6181
6182 // Verify the correctness of the symbol reference. Only verify that this
6183 // layer block makes sense in its parent module or layer block.
6184 auto nestedReferences = layerName.getNestedReferences();
6185 if (nestedReferences.empty()) {
6186 if (!isa<FModuleOp>(parentOp)) {
6187 auto diag = emitOpError() << "has an un-nested layer symbol, but does "
6188 "not have a 'firrtl.module' op as a parent";
6189 return diag.attachNote(parentOp->getLoc())
6190 << "illegal parent op defined here";
6191 }
6192 } else {
6193 auto parentLayerBlock = dyn_cast<LayerBlockOp>(parentOp);
6194 if (!parentLayerBlock) {
6195 auto diag = emitOpError()
6196 << "has a nested layer symbol, but does not have a '"
6197 << getOperationName() << "' op as a parent'";
6198 return diag.attachNote(parentOp->getLoc())
6199 << "illegal parent op defined here";
6200 }
6201 auto parentLayerBlockName = parentLayerBlock.getLayerName();
6202 if (parentLayerBlockName.getRootReference() !=
6203 layerName.getRootReference() ||
6204 parentLayerBlockName.getNestedReferences() !=
6205 layerName.getNestedReferences().drop_back()) {
6206 auto diag = emitOpError() << "is nested under an illegal layer block";
6207 return diag.attachNote(parentLayerBlock->getLoc())
6208 << "illegal parent layer block defined here";
6209 }
6210 }
6211
6212 // Verify the body of the region.
6213 FieldRefCache fieldRefCache;
6214 auto result = getBody(0)->walk<mlir::WalkOrder::PreOrder>(
6215 [&](Operation *op) -> WalkResult {
6216 // Skip nested layer blocks. Those will be verified separately.
6217 if (isa<LayerBlockOp>(op))
6218 return WalkResult::skip();
6219
6220 // Check all the operands of each op to make sure that only legal things
6221 // are captured.
6222 for (auto operand : op->getOperands()) {
6223 // Any value captured from the current layer block is fine.
6224 if (auto *definingOp = operand.getDefiningOp())
6225 if (getOperation()->isAncestor(definingOp))
6226 continue;
6227
6228 auto type = operand.getType();
6229
6230 // Capture of a non-base type, e.g., reference, is allowed.
6231 if (isa<PropertyType>(type)) {
6232 auto diag = emitOpError() << "captures a property operand";
6233 diag.attachNote(operand.getLoc()) << "operand is defined here";
6234 diag.attachNote(op->getLoc()) << "operand is used here";
6235 return WalkResult::interrupt();
6236 }
6237 }
6238
6239 // Ensure that the layer block does not drive any sinks outside.
6240 if (auto connect = dyn_cast<FConnectLike>(op)) {
6241 // ref.define is allowed to drive probes outside the layerblock.
6242 if (isa<RefDefineOp>(connect))
6243 return WalkResult::advance();
6244
6245 // Verify that connects only drive values declared in the layer block.
6246 // If we see a non-passive connect destination, then verify that the
6247 // source is in the same layer block so that the source is not driven.
6248 auto dest =
6249 fieldRefCache.getFieldRefFromValue(connect.getDest()).getValue();
6250 bool passive = true;
6251 if (auto type =
6252 type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType()))
6253 passive = type.isPassive();
6254 // TODO: Improve this verifier. This is intentionally _not_ verifying
6255 // a non-passive ConnectLike because it is hugely annoying to do
6256 // so---it requires a full understanding of if the connect is driving
6257 // destination-to-source, source-to-destination, or bi-directionally
6258 // which requires deep inspection of the type. Eventually, the FIRRTL
6259 // pass pipeline will remove all flips (e.g., canonicalize connect to
6260 // matchingconnect) and this hole won't exist.
6261 if (!passive)
6262 return WalkResult::advance();
6263
6264 if (isAncestorOfValueOwner(getOperation(), dest))
6265 return WalkResult::advance();
6266
6267 auto diag =
6268 connect.emitOpError()
6269 << "connects to a destination which is defined outside its "
6270 "enclosing layer block";
6271 diag.attachNote(getLoc()) << "enclosing layer block is defined here";
6272 diag.attachNote(dest.getLoc()) << "destination is defined here";
6273 return WalkResult::interrupt();
6274 }
6275
6276 return WalkResult::advance();
6277 });
6278
6279 return failure(result.wasInterrupted());
6280}
6281
6282LogicalResult
6283LayerBlockOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6284 auto layerOp =
6285 symbolTable.lookupNearestSymbolFrom<LayerOp>(*this, getLayerNameAttr());
6286 if (!layerOp) {
6287 return emitOpError("invalid symbol reference");
6288 }
6289
6290 return success();
6291}
6292
6293//===----------------------------------------------------------------------===//
6294// TblGen Generated Logic.
6295//===----------------------------------------------------------------------===//
6296
6297// Provide the autogenerated implementation guts for the Op classes.
6298#define GET_OP_CLASSES
6299#include "circt/Dialect/FIRRTL/FIRRTL.cpp.inc"
static bool isAncestor(Block *block, Block *other)
static void printNameKind(OpAsmPrinter &p, Operation *op, firrtl::NameKindEnumAttr attr, ArrayRef< StringRef > extraElides={})
static ParseResult parseNameKind(OpAsmParser &parser, firrtl::NameKindEnumAttr &result)
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
MlirType elementType
Definition CHIRRTL.cpp:29
#define isdigit(x)
Definition FIRLexer.cpp:26
static bool printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef< bool > portDirections, ArrayRef< Attribute > portNames, ArrayRef< Attribute > portTypes, ArrayRef< Attribute > portAnnotations, ArrayRef< Attribute > portSyms, ArrayRef< Attribute > portLocs)
Print a list of module ports in the following form: in x: !firrtl.uint<1> [{class = "DontTouch}],...
static LogicalResult verifyProbeType(RefType refType, Location loc, CircuitOp circuitOp, SymbolTableCollection &symbolTable, Twine start)
static SmallVector< PortInfo > getPortImpl(FModuleLike module)