CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLOps.cpp
Go to the documentation of this file.
1//===- FIRRTLOps.cpp - Implement the FIRRTL operations --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
27#include "circt/Support/Utils.h"
28#include "mlir/IR/BuiltinTypes.h"
29#include "mlir/IR/Diagnostics.h"
30#include "mlir/IR/DialectImplementation.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/IR/SymbolTable.h"
33#include "mlir/Interfaces/FunctionImplementation.h"
34#include "llvm/ADT/BitVector.h"
35#include "llvm/ADT/DenseMap.h"
36#include "llvm/ADT/DenseSet.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/SmallSet.h"
39#include "llvm/ADT/StringExtras.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/Casting.h"
42#include "llvm/Support/FormatVariadic.h"
43
44using llvm::SmallDenseSet;
45using mlir::RegionRange;
46using namespace circt;
47using namespace firrtl;
48using namespace chirrtl;
49
50//===----------------------------------------------------------------------===//
51// Utilities
52//===----------------------------------------------------------------------===//
53
54/// Remove elements from the input array corresponding to set bits in
55/// `indicesToDrop`, returning the elements not mentioned.
56template <typename T>
57static SmallVector<T>
58removeElementsAtIndices(ArrayRef<T> input,
59 const llvm::BitVector &indicesToDrop) {
60#ifndef NDEBUG
61 if (!input.empty()) {
62 int lastIndex = indicesToDrop.find_last();
63 if (lastIndex >= 0)
64 assert((size_t)lastIndex < input.size() && "index out of range");
65 }
66#endif
67
68 // If the input is empty (which is an optimization we do for certain array
69 // attributes), simply return an empty vector.
70 if (input.empty())
71 return {};
72
73 // Copy over the live chunks.
74 size_t lastCopied = 0;
75 SmallVector<T> result;
76 result.reserve(input.size() - indicesToDrop.count());
77
78 for (unsigned indexToDrop : indicesToDrop.set_bits()) {
79 // If we skipped over some valid elements, copy them over.
80 if (indexToDrop > lastCopied) {
81 result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
82 lastCopied = indexToDrop;
83 }
84 // Ignore this value so we don't copy it in the next iteration.
85 ++lastCopied;
86 }
87
88 // If there are live elements at the end, copy them over.
89 if (lastCopied < input.size())
90 result.append(input.begin() + lastCopied, input.end());
91
92 return result;
93}
94
95/// Emit an error if optional location is non-null, return null of return type.
96template <typename RetTy = FIRRTLType, typename... Args>
97static RetTy emitInferRetTypeError(std::optional<Location> loc,
98 const Twine &message, Args &&...args) {
99 if (loc)
100 (mlir::emitError(*loc, message) << ... << std::forward<Args>(args));
101 return {};
102}
103
104bool firrtl::isDuplexValue(Value val) {
105 // Block arguments are not duplex values.
106 while (Operation *op = val.getDefiningOp()) {
107 auto isDuplex =
108 TypeSwitch<Operation *, std::optional<bool>>(op)
109 .Case<SubfieldOp, SubindexOp, SubaccessOp>([&val](auto op) {
110 val = op.getInput();
111 return std::nullopt;
112 })
113 .Case<RegOp, RegResetOp, WireOp>([](auto) { return true; })
114 .Default([](auto) { return false; });
115 if (isDuplex)
116 return *isDuplex;
117 }
118 return false;
119}
120
121SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
122MemOp::computeDataFlow() {
123 // If read result has non-zero latency, then no combinational dependency
124 // exists.
125 if (getReadLatency() > 0)
126 return {};
127 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
128 // Add a dependency from the enable and address fields to the data field.
129 for (auto memPort : getResults())
130 if (auto type = type_dyn_cast<BundleType>(memPort.getType())) {
131 auto enableFieldId = type.getFieldID((unsigned)ReadPortSubfield::en);
132 auto addressFieldId = type.getFieldID((unsigned)ReadPortSubfield::addr);
133 auto dataFieldId = type.getFieldID((unsigned)ReadPortSubfield::data);
134 deps.emplace_back(
135 FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
136 FieldRef(memPort, static_cast<unsigned>(enableFieldId)));
137 deps.emplace_back(
138 FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
139 FieldRef(memPort, static_cast<unsigned>(addressFieldId)));
140 }
141 return deps;
142}
143
144/// Return the kind of port this is given the port type from a 'mem' decl.
145static MemOp::PortKind getMemPortKindFromType(FIRRTLType type) {
146 constexpr unsigned int addr = 1 << 0;
147 constexpr unsigned int en = 1 << 1;
148 constexpr unsigned int clk = 1 << 2;
149 constexpr unsigned int data = 1 << 3;
150 constexpr unsigned int mask = 1 << 4;
151 constexpr unsigned int rdata = 1 << 5;
152 constexpr unsigned int wdata = 1 << 6;
153 constexpr unsigned int wmask = 1 << 7;
154 constexpr unsigned int wmode = 1 << 8;
155 constexpr unsigned int def = 1 << 9;
156 // Get the kind of port based on the fields of the Bundle.
157 auto portType = type_dyn_cast<BundleType>(type);
158 if (!portType)
159 return MemOp::PortKind::Debug;
160 unsigned fields = 0;
161 // Get the kind of port based on the fields of the Bundle.
162 for (auto elem : portType.getElements()) {
163 fields |= llvm::StringSwitch<unsigned>(elem.name.getValue())
164 .Case("addr", addr)
165 .Case("en", en)
166 .Case("clk", clk)
167 .Case("data", data)
168 .Case("mask", mask)
169 .Case("rdata", rdata)
170 .Case("wdata", wdata)
171 .Case("wmask", wmask)
172 .Case("wmode", wmode)
173 .Default(def);
174 }
175 if (fields == (addr | en | clk | data))
176 return MemOp::PortKind::Read;
177 if (fields == (addr | en | clk | data | mask))
178 return MemOp::PortKind::Write;
179 if (fields == (addr | en | clk | wdata | wmask | rdata | wmode))
180 return MemOp::PortKind::ReadWrite;
181 return MemOp::PortKind::Debug;
182}
183
// NOTE(review): the signature line for this body (presumably
// `swapFlow(Flow flow)`, which `foldFlow` calls) is missing from this
// rendering of the file.
// Map a flow to its reverse: source and sink are exchanged, while "none" and
// duplex map to themselves.
185 switch (flow) {
186 case Flow::None:
187 return Flow::None;
188 case Flow::Source:
189 return Flow::Sink;
190 case Flow::Sink:
191 return Flow::Source;
192 case Flow::Duplex:
193 return Flow::Duplex;
194 }
195 // Unreachable but silences warning
196 llvm_unreachable("Unsupported Flow type.");
197}
198
199const char *toString(Flow flow) {
200 switch (flow) {
201 case Flow::None:
202 return "no flow";
203 case Flow::Source:
204 return "source flow";
205 case Flow::Sink:
206 return "sink flow";
207 case Flow::Duplex:
208 return "duplex flow";
209 }
210 // Unreachable but silences warning
211 llvm_unreachable("Unsupported Flow type.");
212}
213
214Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {
215
216 if (auto blockArg = dyn_cast<BlockArgument>(val)) {
217 auto *op = val.getParentBlock()->getParentOp();
218 if (auto moduleLike = dyn_cast<FModuleLike>(op)) {
219 auto direction = moduleLike.getPortDirection(blockArg.getArgNumber());
220 if (direction == Direction::Out)
221 return swapFlow(accumulatedFlow);
222 }
223 return accumulatedFlow;
224 }
225
226 Operation *op = val.getDefiningOp();
227
228 return TypeSwitch<Operation *, Flow>(op)
229 .Case<SubfieldOp, OpenSubfieldOp>([&](auto op) {
230 return foldFlow(op.getInput(), op.isFieldFlipped()
231 ? swapFlow(accumulatedFlow)
232 : accumulatedFlow);
233 })
234 .Case<SubindexOp, SubaccessOp, OpenSubindexOp, RefSubOp>(
235 [&](auto op) { return foldFlow(op.getInput(), accumulatedFlow); })
236 // Registers, Wires, and behavioral memory ports are always Duplex.
237 .Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
238 [](auto) { return Flow::Duplex; })
239 .Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
240 auto resultNo = cast<OpResult>(val).getResultNumber();
241 if (inst.getPortDirection(resultNo) == Direction::Out)
242 return accumulatedFlow;
243 return swapFlow(accumulatedFlow);
244 })
245 .Case<MemOp>([&](auto op) {
246 // only debug ports with RefType have source flow.
247 if (type_isa<RefType>(val.getType()))
248 return Flow::Source;
249 return swapFlow(accumulatedFlow);
250 })
251 .Case<ObjectSubfieldOp>([&](ObjectSubfieldOp op) {
252 auto input = op.getInput();
253 auto *inputOp = input.getDefiningOp();
254
255 // We are directly accessing a port on a local declaration.
256 if (auto objectOp = dyn_cast_or_null<ObjectOp>(inputOp)) {
257 auto classType = input.getType();
258 auto direction = classType.getElement(op.getIndex()).direction;
259 if (direction == Direction::In)
260 return Flow::Sink;
261 return Flow::Source;
262 }
263
264 // We are accessing a remote object. Input ports on remote objects are
265 // inaccessible, and thus have Flow::None. Walk backwards through the
266 // chain of subindexes, to detect if we have indexed through an input
267 // port. At the end, either we did index through an input port, or the
268 // entire path was through output ports with source flow.
269 while (true) {
270 auto classType = input.getType();
271 auto direction = classType.getElement(op.getIndex()).direction;
272 if (direction == Direction::In)
273 return Flow::None;
274
275 op = dyn_cast_or_null<ObjectSubfieldOp>(inputOp);
276 if (op) {
277 input = op.getInput();
278 inputOp = input.getDefiningOp();
279 continue;
280 }
281
282 return accumulatedFlow;
283 };
284 })
285 // Anything else acts like a universal source.
286 .Default([&](auto) { return accumulatedFlow; });
287}
288
// Classify the declaration that ultimately defines a value. The signature
// line for this body (presumably `getDeclarationKind(Value val)`, which the
// body calls recursively) is missing from this rendering of the file.
// - values with no defining op (block arguments) are classified as ports;
// - results of an `InstanceOp` are classified as instances;
// - subfield / subindex / subaccess / open-subfield / open-subindex /
//   ref-sub accesses are looked through to the value they select from;
// - any other defining op is classified as `Other`.
289// TODO: This is doing the same walk as foldFlow. These two functions can be
290// combined and return a (flow, kind) product.
292 Operation *op = val.getDefiningOp();
// No defining op means `val` is a block argument; treat it as a port.
293 if (!op)
294 return DeclKind::Port;
295
296 return TypeSwitch<Operation *, DeclKind>(op)
297 .Case<InstanceOp>([](auto) { return DeclKind::Instance; })
298 .Case<SubfieldOp, SubindexOp, SubaccessOp, OpenSubfieldOp, OpenSubindexOp,
299 RefSubOp>([](auto op) { return getDeclarationKind(op.getInput()); })
300 .Default([](auto) { return DeclKind::Other; });
301}
302
303size_t firrtl::getNumPorts(Operation *op) {
304 if (auto module = dyn_cast<FModuleLike>(op))
305 return module.getNumPorts();
306 return op->getNumResults();
307}
308
309/// Check whether an operation has a `DontTouch` annotation, or a symbol that
310/// should prevent certain types of canonicalizations.
/// An inner symbol attribute on the operation counts as don't-touch.
311bool firrtl::hasDontTouch(Operation *op) {
312 return op->getAttr(hw::InnerSymbolTable::getInnerSymbolAttrName()) ||
 // NOTE(review): the second operand of `||` (original line 313, presumably
 // the `DontTouch` annotation check) is missing from this rendering.
314}
315
316/// Check whether a block argument ("port") or the operation defining a value
317/// has a `DontTouch` annotation, or a symbol that should prevent certain types
318/// of canonicalizations.
319bool firrtl::hasDontTouch(Value value) {
320 if (auto *op = value.getDefiningOp())
321 return hasDontTouch(op);
322 auto arg = dyn_cast<BlockArgument>(value);
323 auto module = dyn_cast<FModuleOp>(arg.getOwner()->getParentOp());
324 if (!module)
325 return false;
326 return (module.getPortSymbolAttr(arg.getArgNumber())) ||
327 AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
328}
329
330/// Get a special name to use when printing the entry block arguments of the
331/// region contained by an operation in this dialect.
332void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region,
333 OpAsmSetValueNameFn setNameFn) {
334 if (region.empty())
335 return;
336 auto *parentOp = op;
337 auto *block = &region.front();
338 // Check to see if the operation containing the arguments has 'firrtl.name'
339 // attributes for them. If so, use that as the name.
340 auto argAttr = parentOp->getAttrOfType<ArrayAttr>("portNames");
341 // Do not crash on invalid IR.
342 if (!argAttr || argAttr.size() != block->getNumArguments())
343 return;
344
345 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
346 auto str = cast<StringAttr>(argAttr[i]).getValue();
347 if (!str.empty())
348 setNameFn(block->getArgument(i), str);
349 }
350}
351
352/// Forward declaration of the parser for `NameKindEnumAttr`, so that it can be
/// referenced before its definition later in the file.
353static ParseResult parseNameKind(OpAsmParser &parser,
354 firrtl::NameKindEnumAttr &result);
355
356//===----------------------------------------------------------------------===//
357// Layer Verification Utilities
358//===----------------------------------------------------------------------===//
359
360/// Get the ambient layers active at the given op.
361static LayerSet getAmbientLayersAt(Operation *op) {
362 // Crawl through the parent ops, accumulating all ambient layers at the given
363 // operation.
364 LayerSet result;
365 for (; op != nullptr; op = op->getParentOp()) {
366 if (auto module = dyn_cast<FModuleLike>(op)) {
367 auto layers = module.getLayersAttr().getAsRange<SymbolRefAttr>();
368 result.insert(layers.begin(), layers.end());
369 break;
370 }
371 if (auto layerblock = dyn_cast<LayerBlockOp>(op)) {
372 result.insert(layerblock.getLayerName());
373 continue;
374 }
375 }
376 return result;
377}
378
379/// Get the ambient layer requirements at the definition site of the value.
380static LayerSet getAmbientLayersFor(Value value) {
381 return getAmbientLayersAt(getFieldRefFromValue(value).getDefiningOp());
382}
383
384/// Get the effective layer requirements for the given value.
385/// The effective layers for a value is the union of
386/// - the ambient layers for the cannonical storage location.
387/// - any explicit layer annotations in the value's type.
388static LayerSet getLayersFor(Value value) {
389 auto result = getAmbientLayersFor(value);
390 if (auto type = dyn_cast<RefType>(value.getType()))
391 if (auto layer = type.getLayer())
392 result.insert(type.getLayer());
393 return result;
394}
395
396/// Check that the source layer is compatible with the destination layer.
397/// Either the source and destination are identical, or the source-layer
398/// is a parent of the destination. For example `A` is compatible with `A.B.C`,
399/// because any definition valid in `A` is also valid in `A.B.C`.
400static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer,
401 mlir::SymbolRefAttr dstLayer) {
402 // A non-colored probe may be cast to any colored probe.
403 if (!srcLayer)
404 return true;
405
406 // A colored probe cannot be cast to an uncolored probe.
407 if (!dstLayer)
408 return false;
409
410 // Return true if the srcLayer is a prefix of the dstLayer.
411 if (srcLayer.getRootReference() != dstLayer.getRootReference())
412 return false;
413
414 auto srcNames = srcLayer.getNestedReferences();
415 auto dstNames = dstLayer.getNestedReferences();
416 if (dstNames.size() < srcNames.size())
417 return false;
418
419 return llvm::all_of(llvm::zip_first(srcNames, dstNames),
420 [](auto x) { return std::get<0>(x) == std::get<1>(x); });
421}
422
423/// Check that the source layer is present in the destination layers.
424static bool isLayerCompatibleWith(SymbolRefAttr srcLayer,
425 const LayerSet &dstLayers) {
426 // fast path: the required layer is directly listed in the provided layers.
427 if (dstLayers.contains(srcLayer))
428 return true;
429
430 // Slow path: the required layer is not directly listed in the provided
431 // layers, but the layer may still be provided by a nested layer.
432 return any_of(dstLayers, [=](SymbolRefAttr dstLayer) {
433 return isLayerCompatibleWith(srcLayer, dstLayer);
434 });
435}
436
437/// Check that the source layers are all present in the destination layers.
438/// True if all source layers are present in the destination.
439/// Outputs the set of source layers that are missing in the destination.
440static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst,
441 SmallVectorImpl<SymbolRefAttr> &missing) {
442 for (auto srcLayer : src)
443 if (!isLayerCompatibleWith(srcLayer, dst))
444 missing.push_back(srcLayer);
445
446 llvm::sort(missing, LayerSetCompare());
447 return missing.empty();
448}
449
450static LogicalResult checkLayerCompatibility(
451 Operation *op, const LayerSet &src, const LayerSet &dst,
452 const Twine &errorMsg,
453 const Twine &noteMsg = Twine("missing layer requirements")) {
454 SmallVector<SymbolRefAttr> missing;
455 if (isLayerSetCompatibleWith(src, dst, missing))
456 return success();
457 interleaveComma(missing, op->emitOpError(errorMsg).attachNote()
458 << noteMsg << ": ");
459 return failure();
460}
461
462//===----------------------------------------------------------------------===//
463// CircuitOp
464//===----------------------------------------------------------------------===//
465
466void CircuitOp::build(OpBuilder &builder, OperationState &result,
467 StringAttr name, ArrayAttr annotations) {
468 // Add an attribute for the name.
469 result.getOrAddProperties<Properties>().setName(name);
470
471 if (!annotations)
472 annotations = builder.getArrayAttr({});
473 result.getOrAddProperties<Properties>().setAnnotations(annotations);
474
475 // Create a region and a block for the body.
476 Region *bodyRegion = result.addRegion();
477 Block *body = new Block();
478 bodyRegion->push_back(body);
479}
480
481static ParseResult parseCircuitOpAttrs(OpAsmParser &parser,
482 NamedAttrList &resultAttrs) {
483 auto result = parser.parseOptionalAttrDictWithKeyword(resultAttrs);
484 if (!resultAttrs.get("annotations"))
485 resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
486
487 return result;
488}
489
490static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op,
491 DictionaryAttr attr) {
492 // "name" is always elided.
493 SmallVector<StringRef> elidedAttrs = {"name"};
494 // Elide "annotations" if it doesn't exist or if it is empty
495 auto annotationsAttr = op->getAttrOfType<ArrayAttr>("annotations");
496 if (annotationsAttr.empty())
497 elidedAttrs.push_back("annotations");
498
499 p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
500}
501
502LogicalResult CircuitOp::verifyRegions() {
503 StringRef main = getName();
504
505 // Check that the circuit has a non-empty name.
506 if (main.empty()) {
507 emitOpError("must have a non-empty name");
508 return failure();
509 }
510
511 mlir::SymbolTable symtbl(getOperation());
512
513 auto *mainModule = symtbl.lookup(main);
514 if (!mainModule)
515 return emitOpError().append(
516 "does not contain module with same name as circuit");
517 if (!isa<FModuleLike>(mainModule))
518 return mainModule->emitError(
519 "entity with name of circuit must be a module");
520 if (symtbl.getSymbolVisibility(mainModule) !=
521 mlir::SymbolTable::Visibility::Public)
522 return mainModule->emitError("main module must be public");
523
524 // Store a mapping of defname to either the first external module
525 // that defines it or, preferentially, the first external module
526 // that defines it and has no parameters.
527 llvm::DenseMap<Attribute, FExtModuleOp> defnameMap;
528
529 auto verifyExtModule = [&](FExtModuleOp extModule) -> LogicalResult {
530 if (!extModule)
531 return success();
532
533 auto defname = extModule.getDefnameAttr();
534 if (!defname)
535 return success();
536
537 // Check that this extmodule's defname does not conflict with
538 // the symbol name of any module.
539 if (auto collidingModule = symtbl.lookup<FModuleOp>(defname.getValue()))
540 return extModule.emitOpError()
541 .append("attribute 'defname' with value ", defname,
542 " conflicts with the name of another module in the circuit")
543 .attachNote(collidingModule.getLoc())
544 .append("previous module declared here");
545
546 // Find an optional extmodule with a defname collision. Update
547 // the defnameMap if this is the first extmodule with that
548 // defname or if the current extmodule takes no parameters and
549 // the collision does. The latter condition improves later
550 // extmodule verification as checking against a parameterless
551 // module is stricter.
552 FExtModuleOp collidingExtModule;
553 if (auto &value = defnameMap[defname]) {
554 collidingExtModule = value;
555 if (!value.getParameters().empty() && extModule.getParameters().empty())
556 value = extModule;
557 } else {
558 value = extModule;
559 // Go to the next extmodule if no extmodule with the same
560 // defname was found.
561 return success();
562 }
563
564 // Check that the number of ports is exactly the same.
565 SmallVector<PortInfo> ports = extModule.getPorts();
566 SmallVector<PortInfo> collidingPorts = collidingExtModule.getPorts();
567
568 if (ports.size() != collidingPorts.size())
569 return extModule.emitOpError()
570 .append("with 'defname' attribute ", defname, " has ", ports.size(),
571 " ports which is different from a previously defined "
572 "extmodule with the same 'defname' which has ",
573 collidingPorts.size(), " ports")
574 .attachNote(collidingExtModule.getLoc())
575 .append("previous extmodule definition occurred here");
576
577 // Check that ports match for name and type. Since parameters
578 // *might* affect widths, ignore widths if either module has
579 // parameters. Note that this allows for misdetections, but
580 // has zero false positives.
581 for (auto p : llvm::zip(ports, collidingPorts)) {
582 StringAttr aName = std::get<0>(p).name, bName = std::get<1>(p).name;
583 Type aType = std::get<0>(p).type, bType = std::get<1>(p).type;
584
585 if (aName != bName)
586 return extModule.emitOpError()
587 .append("with 'defname' attribute ", defname,
588 " has a port with name ", aName,
589 " which does not match the name of the port in the same "
590 "position of a previously defined extmodule with the same "
591 "'defname', expected port to have name ",
592 bName)
593 .attachNote(collidingExtModule.getLoc())
594 .append("previous extmodule definition occurred here");
595
596 if (!extModule.getParameters().empty() ||
597 !collidingExtModule.getParameters().empty()) {
598 // Compare base types as widthless, others must match.
599 if (auto base = type_dyn_cast<FIRRTLBaseType>(aType))
600 aType = base.getWidthlessType();
601 if (auto base = type_dyn_cast<FIRRTLBaseType>(bType))
602 bType = base.getWidthlessType();
603 }
604 if (aType != bType)
605 return extModule.emitOpError()
606 .append("with 'defname' attribute ", defname,
607 " has a port with name ", aName,
608 " which has a different type ", aType,
609 " which does not match the type of the port in the same "
610 "position of a previously defined extmodule with the same "
611 "'defname', expected port to have type ",
612 bType)
613 .attachNote(collidingExtModule.getLoc())
614 .append("previous extmodule definition occurred here");
615 }
616 return success();
617 };
618
619 SmallVector<FModuleOp, 1> dutModules;
620 for (auto &op : *getBodyBlock()) {
621 // Verify modules.
622 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
623 if (AnnotationSet(moduleOp).hasAnnotation(dutAnnoClass))
624 dutModules.push_back(moduleOp);
625 continue;
626 }
627
628 // Verify external modules.
629 if (auto extModule = dyn_cast<FExtModuleOp>(op)) {
630 if (verifyExtModule(extModule).failed())
631 return failure();
632 }
633 }
634
635 // Error if there is more than one design-under-test.
636 if (dutModules.size() > 1) {
637 auto diag = dutModules[0]->emitOpError()
638 << "is annotated as the design-under-test (DUT), but other "
639 "modules are also annotated";
640 for (auto moduleOp : ArrayRef(dutModules).drop_front())
641 diag.attachNote(moduleOp.getLoc()) << "is also annotated as the DUT";
642 return failure();
643 }
644
645 return success();
646}
647
648Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
649
650//===----------------------------------------------------------------------===//
651// FExtModuleOp and FModuleOp
652//===----------------------------------------------------------------------===//
653
654static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
655 SmallVector<PortInfo> results;
656 ArrayRef<Attribute> domains = module.getDomainInfo();
657 for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
658 results.push_back({module.getPortNameAttr(i), module.getPortType(i),
659 module.getPortDirection(i), module.getPortSymbolAttr(i),
660 module.getPortLocation(i),
661 AnnotationSet::forPort(module, i),
662 domains.empty() ? Attribute{} : domains[i]});
663 }
664 return results;
665}
666
667SmallVector<PortInfo> FModuleOp::getPorts() { return ::getPortImpl(*this); }
668
669SmallVector<PortInfo> FExtModuleOp::getPorts() { return ::getPortImpl(*this); }
670
671SmallVector<PortInfo> FIntModuleOp::getPorts() { return ::getPortImpl(*this); }
672
673SmallVector<PortInfo> FMemModuleOp::getPorts() { return ::getPortImpl(*this); }
674
// Convert a FIRRTL port direction to the corresponding HW module port
// direction. The signature line for this body (presumably
// `dirFtoH(Direction dir)`, which `getPortListImpl` and `getPortImpl` call)
// is missing from this rendering of the file.
676 if (dir == Direction::In)
677 return hw::ModulePort::Direction::Input;
678 if (dir == Direction::Out)
679 return hw::ModulePort::Direction::Output;
// Any other value is a programming error; abort() also stops release builds,
// where the assert is compiled out.
680 assert(0 && "invalid direction");
681 abort();
682}
683
684static SmallVector<hw::PortInfo> getPortListImpl(FModuleLike module) {
685 SmallVector<hw::PortInfo> results;
686 auto aname = StringAttr::get(module.getContext(),
687 hw::HWModuleLike::getPortSymbolAttrName());
688 auto emptyDict = DictionaryAttr::get(module.getContext());
689 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
690 auto sym = module.getPortSymbolAttr(i);
691 results.push_back(
692 {{module.getPortNameAttr(i), module.getPortType(i),
693 dirFtoH(module.getPortDirection(i))},
694 i,
695 sym ? DictionaryAttr::get(
696 module.getContext(),
697 ArrayRef<mlir::NamedAttribute>{NamedAttribute{aname, sym}})
698 : emptyDict,
699 module.getPortLocation(i)});
700 }
701 return results;
702}
703
704SmallVector<::circt::hw::PortInfo> FModuleOp::getPortList() {
705 return ::getPortListImpl(*this);
706}
707
708SmallVector<::circt::hw::PortInfo> FExtModuleOp::getPortList() {
709 return ::getPortListImpl(*this);
710}
711
712SmallVector<::circt::hw::PortInfo> FIntModuleOp::getPortList() {
713 return ::getPortListImpl(*this);
714}
715
716SmallVector<::circt::hw::PortInfo> FMemModuleOp::getPortList() {
717 return ::getPortListImpl(*this);
718}
719
720static hw::PortInfo getPortImpl(FModuleLike module, size_t idx) {
721 return {{module.getPortNameAttr(idx), module.getPortType(idx),
722 dirFtoH(module.getPortDirection(idx))},
723 idx,
724 DictionaryAttr::get(
725 module.getContext(),
726 ArrayRef<mlir::NamedAttribute>{NamedAttribute{
727 StringAttr::get(module.getContext(),
728 hw::HWModuleLike::getPortSymbolAttrName()),
729 module.getPortSymbolAttr(idx)}}),
730 module.getPortLocation(idx)};
731}
732
733::circt::hw::PortInfo FModuleOp::getPort(size_t idx) {
734 return ::getPortImpl(*this, idx);
735}
736
737::circt::hw::PortInfo FExtModuleOp::getPort(size_t idx) {
738 return ::getPortImpl(*this, idx);
739}
740
741::circt::hw::PortInfo FIntModuleOp::getPort(size_t idx) {
742 return ::getPortImpl(*this, idx);
743}
744
745::circt::hw::PortInfo FMemModuleOp::getPort(size_t idx) {
746 return ::getPortImpl(*this, idx);
747}
748
749// Return the port with the specified name.
750BlockArgument FModuleOp::getArgument(size_t portNumber) {
751 return getBodyBlock()->getArgument(portNumber);
752}
753
754/// Return an updated domain info Attribute with domain indices updated based on
755/// port insertions.
756static Attribute fixDomainInfoInsertions(MLIRContext *context,
757 Attribute domainInfoAttr,
758 ArrayRef<unsigned> indexMap) {
759 // This is a domain type port. Return the original domain info unmodified.
760 auto di = dyn_cast_or_null<ArrayAttr>(domainInfoAttr);
761 if (!di || di.empty())
762 return domainInfoAttr;
763
764 // This is a non-domain type port. Update any indeices referenced.
765 SmallVector<Attribute> domainInfo;
766 for (auto attr : di) {
767 auto oldIdx = cast<IntegerAttr>(attr).getUInt();
768 auto newIdx = indexMap[oldIdx];
769 if (oldIdx == newIdx)
770 domainInfo.push_back(attr);
771 else
772 domainInfo.push_back(IntegerAttr::get(
773 IntegerType::get(context, 32, IntegerType::Unsigned), newIdx));
774 }
775 return ArrayAttr::get(context, domainInfo);
776}
777
778/// Inserts the given ports. The insertion indices are expected to be in order.
779/// Insertion occurs in-order, such that ports with the same insertion index
780/// appear in the module in the same order they appeared in the list.
781static void insertPorts(FModuleLike op,
782 ArrayRef<std::pair<unsigned, PortInfo>> ports) {
783 if (ports.empty())
784 return;
785 unsigned oldNumArgs = op.getNumPorts();
786 unsigned newNumArgs = oldNumArgs + ports.size();
787
788 // Add direction markers and names for new ports.
789 auto existingDirections = op.getPortDirectionsAttr();
790 ArrayRef<Attribute> existingNames = op.getPortNames();
791 ArrayRef<Attribute> existingTypes = op.getPortTypes();
792 ArrayRef<Attribute> existingLocs = op.getPortLocations();
793 assert(existingDirections.size() == oldNumArgs);
794 assert(existingNames.size() == oldNumArgs);
795 assert(existingTypes.size() == oldNumArgs);
796 assert(existingLocs.size() == oldNumArgs);
797
798 SmallVector<bool> newDirections;
799 SmallVector<Attribute> newNames, newTypes, newDomains, newAnnos, newSyms,
800 newLocs;
801 newDirections.reserve(newNumArgs);
802 newNames.reserve(newNumArgs);
803 newTypes.reserve(newNumArgs);
804 newDomains.reserve(newNumArgs);
805 newAnnos.reserve(newNumArgs);
806 newSyms.reserve(newNumArgs);
807 newLocs.reserve(newNumArgs);
808
809 SmallVector<unsigned> indexMap(oldNumArgs);
810
811 auto emptyArray = ArrayAttr::get(op.getContext(), {});
812
813 unsigned oldIdx = 0;
814 unsigned inserted = 0;
815 auto migrateOldPorts = [&](unsigned untilOldIdx) {
816 while (oldIdx < oldNumArgs && oldIdx < untilOldIdx) {
817 newDirections.push_back(existingDirections[oldIdx]);
818 newNames.push_back(existingNames[oldIdx]);
819 newTypes.push_back(existingTypes[oldIdx]);
820 newDomains.push_back(fixDomainInfoInsertions(
821 op.getContext(), op.getDomainInfoAttrForPort(oldIdx), indexMap));
822 newAnnos.push_back(op.getAnnotationsAttrForPort(oldIdx));
823 newSyms.push_back(op.getPortSymbolAttr(oldIdx));
824 newLocs.push_back(existingLocs[oldIdx]);
825
826 indexMap[oldIdx] = oldIdx + inserted;
827 ++oldIdx;
828 }
829 };
830
831 for (auto [idx, port] : ports) {
832 migrateOldPorts(idx);
833 newDirections.push_back(direction::unGet(port.direction));
834 newNames.push_back(port.name);
835 newTypes.push_back(TypeAttr::get(port.type));
836 newDomains.push_back(fixDomainInfoInsertions(
837 op.getContext(),
838 port.domains ? port.domains : ArrayAttr::get(op.getContext(), {}),
839 indexMap));
840 auto annos = port.annotations.getArrayAttr();
841 newAnnos.push_back(annos ? annos : emptyArray);
842 newSyms.push_back(port.sym);
843 newLocs.push_back(port.loc);
844 ++inserted;
845 }
846 migrateOldPorts(oldNumArgs);
847
848 // The lack of *any* port annotations is represented by an empty
849 // `portAnnotations` array as a shorthand.
850 if (llvm::all_of(newAnnos, [](Attribute attr) {
851 return cast<ArrayAttr>(attr).empty();
852 }))
853 newAnnos.clear();
854
855 // The lack of *any* domains is also represented by an empty `domainInfo`
856 // attribute.
857 if (llvm::all_of(newDomains, [](Attribute attr) {
858 if (!attr)
859 return true;
860 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr))
861 return arrayAttr.empty();
862 return false;
863 }))
864 newDomains.clear();
865
866 // Apply these changed markers.
867 op->setAttr("portDirections",
868 direction::packAttribute(op.getContext(), newDirections));
869 op->setAttr("portNames", ArrayAttr::get(op.getContext(), newNames));
870 op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newTypes));
871 op->setAttr("domainInfo", ArrayAttr::get(op.getContext(), newDomains));
872 op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newAnnos));
873 FModuleLike::fixupPortSymsArray(newSyms, op.getContext());
874 op.setPortSymbols(newSyms);
875 op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newLocs));
876}
877
878// Return an Attribute with updated port domain information based on information
879// about which ports have been deleted. This is necessary because the port
880// domain storage uses an integer to indicate the index of a domain port with
881// which it is associated.
882//
883// Note: this will _always_ return one-entry-per-port. While this is not
884// required for FModuleLike, InstanceOps have this restriction.
885static ArrayAttr fixDomainInfoDeletions(MLIRContext *context,
886 ArrayAttr domainInfoAttr,
887 const llvm::BitVector &portIndices,
888 bool supportsEmptyAttr) {
889 if (supportsEmptyAttr && domainInfoAttr.empty())
890 return domainInfoAttr;
891
892 // Compute an array where each entry is the number of ports before that
893 // index that have been deleted. This is used to update domain information.
894 SmallVector<unsigned> numDeleted;
895 numDeleted.resize(portIndices.size());
896 size_t deletionIndex = portIndices.find_first();
897 for (size_t i = 0, e = portIndices.size(); i != e; ++i) {
898 if (i == deletionIndex) {
899 if (i == 0)
900 numDeleted[i] = 1;
901 else
902 numDeleted[i] = numDeleted[i - 1] + 1;
903 deletionIndex = portIndices.find_next(i);
904 continue;
905 }
906 if (i == 0)
907 numDeleted[i] = 0;
908 else
909 numDeleted[i] = numDeleted[i - 1];
910 }
911
912 // Return a cached empty ArrayAttr.
913 ArrayAttr eEmpty;
914 auto getEmpty = [&]() {
915 if (!eEmpty)
916 eEmpty = ArrayAttr::get(context, {});
917 return eEmpty;
918 };
919
920 // Update the domain indices.
921 SmallVector<Attribute> newDomainInfo;
922 newDomainInfo.reserve(portIndices.size() - portIndices.count());
923 for (size_t i = 0, e = portIndices.size(); i != e; ++i) {
924 // If this port is deleted, then do nothing.
925 if (portIndices.test(i))
926 continue;
927 // If there is no domain info, then add an empty attribute.
928 if (domainInfoAttr.empty()) {
929 newDomainInfo.push_back(getEmpty());
930 continue;
931 }
932 auto attr = domainInfoAttr[i];
933 // If this is a Domain Type (indicating domain kind) or an empty domain.
934 auto domains = dyn_cast<ArrayAttr>(attr);
935 if (!domains || domains.empty()) {
936 newDomainInfo.push_back(attr);
937 continue;
938 }
939 // This contains indexes to domain ports. Update them.
940 SmallVector<Attribute> newDomains;
941 for (auto domain : domains) {
942 // If the domain port was deleted, drop the association.
943 auto oldIdx = cast<IntegerAttr>(domain).getUInt();
944 if (portIndices.test(oldIdx))
945 continue;
946 // If the new index is the same, do nothing.
947 auto newIdx = oldIdx - numDeleted[oldIdx];
948 if (oldIdx == newIdx) {
949 newDomainInfo.push_back(attr);
950 continue;
951 }
952 // Update the index.
953 newDomains.push_back(IntegerAttr::get(
954 IntegerType::get(context, 32, IntegerType::Unsigned), newIdx));
955 }
956 newDomainInfo.push_back(ArrayAttr::get(context, newDomains));
957 }
958
959 return ArrayAttr::get(context, newDomainInfo);
960}
961
962/// Erases the ports that have their corresponding bit set in `portIndices`.
963static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
964 if (portIndices.none())
965 return;
966
967 // Drop the direction markers for dead ports.
968 ArrayRef<bool> portDirections = op.getPortDirectionsAttr().asArrayRef();
969 ArrayRef<Attribute> portNames = op.getPortNames();
970 ArrayRef<Attribute> portTypes = op.getPortTypes();
971 ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
972 ArrayRef<Attribute> portSyms = op.getPortSymbols();
973 ArrayRef<Attribute> portLocs = op.getPortLocations();
974 ArrayRef<Attribute> portDomains = op.getDomainInfo();
975 auto numPorts = op.getNumPorts();
976 (void)numPorts;
977 assert(portDirections.size() == numPorts);
978 assert(portNames.size() == numPorts);
979 assert(portAnnos.size() == numPorts || portAnnos.empty());
980 assert(portTypes.size() == numPorts);
981 assert(portSyms.size() == numPorts || portSyms.empty());
982 assert(portLocs.size() == numPorts);
983 assert(portDomains.size() == numPorts || portDomains.empty());
984
985 SmallVector<bool> newPortDirections =
986 removeElementsAtIndices<bool>(portDirections, portIndices);
987 SmallVector<Attribute> newPortNames, newPortTypes, newPortAnnos, newPortSyms,
988 newPortLocs;
989 newPortNames = removeElementsAtIndices(portNames, portIndices);
990 newPortTypes = removeElementsAtIndices(portTypes, portIndices);
991 newPortAnnos = removeElementsAtIndices(portAnnos, portIndices);
992 newPortSyms = removeElementsAtIndices(portSyms, portIndices);
993 newPortLocs = removeElementsAtIndices(portLocs, portIndices);
994
995 op->setAttr("portDirections",
996 direction::packAttribute(op.getContext(), newPortDirections));
997 op->setAttr("portNames", ArrayAttr::get(op.getContext(), newPortNames));
998 op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newPortAnnos));
999 op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newPortTypes));
1000 FModuleLike::fixupPortSymsArray(newPortSyms, op.getContext());
1001 op->setAttr("portSymbols", ArrayAttr::get(op.getContext(), newPortSyms));
1002 op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newPortLocs));
1003 op->setAttr("domainInfo",
1004 fixDomainInfoDeletions(op.getContext(), op.getDomainInfoAttr(),
1005 portIndices, /*supportsEmptyAttr=*/true));
1006}
1007
1008void FExtModuleOp::erasePorts(const llvm::BitVector &portIndices) {
1009 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1010}
1011
1012void FIntModuleOp::erasePorts(const llvm::BitVector &portIndices) {
1013 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1014}
1015
1016void FMemModuleOp::erasePorts(const llvm::BitVector &portIndices) {
1017 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1018}
1019
1020void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
1021 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1022 getBodyBlock()->eraseArguments(portIndices);
1023}
1024
1025/// Inserts the given ports. The insertion indices are expected to be in order.
1026/// Insertion occurs in-order, such that ports with the same insertion index
1027/// appear in the module in the same order they appeared in the list.
1028void FModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
1029 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
1030
1031 // Insert the block arguments.
1032 auto *body = getBodyBlock();
1033 for (size_t i = 0, e = ports.size(); i < e; ++i) {
1034 // Block arguments are inserted one at a time, so for each argument we
1035 // insert we have to increase the index by 1.
1036 auto &[index, port] = ports[i];
1037 body->insertArgument(index + i, port.type, port.loc);
1038 }
1039}
1040
1041void FExtModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
1042 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
1043}
1044
1045void FIntModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
1046 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
1047}
1048
1049/// Inserts the given ports. The insertion indices are expected to be in order.
1050/// Insertion occurs in-order, such that ports with the same insertion index
1051/// appear in the module in the same order they appeared in the list.
1052void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
1053 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
1054}
1055
1056template <typename OpTy>
1057void buildModuleLike(OpBuilder &builder, OperationState &result,
1058 StringAttr name, ArrayRef<PortInfo> ports) {
1059 // Add an attribute for the name.
1060 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
1061 properties.setSymName(name);
1062
1063 // Record the names of the arguments if present.
1064 SmallVector<Direction, 4> portDirections;
1065 SmallVector<Attribute, 4> portNames;
1066 SmallVector<Attribute, 4> portTypes;
1067 SmallVector<Attribute, 4> portSyms;
1068 SmallVector<Attribute, 4> portLocs;
1069 SmallVector<Attribute, 4> portDomains;
1070 for (const auto &port : ports) {
1071 portDirections.push_back(port.direction);
1072 portNames.push_back(port.name);
1073 portTypes.push_back(TypeAttr::get(port.type));
1074 portSyms.push_back(port.sym);
1075 portLocs.push_back(port.loc);
1076 portDomains.push_back(port.domains);
1077 }
1078 if (llvm::all_of(portDomains, [](Attribute attr) {
1079 if (!attr)
1080 return true;
1081 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr))
1082 return arrayAttr.empty();
1083 return false;
1084 }))
1085 portDomains.clear();
1086
1087 FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1088
1089 // Both attributes are added, even if the module has no ports.
1090 properties.setPortDirections(
1091 direction::packAttribute(builder.getContext(), portDirections));
1092 properties.setPortNames(builder.getArrayAttr(portNames));
1093 properties.setPortTypes(builder.getArrayAttr(portTypes));
1094 properties.setPortSymbols(builder.getArrayAttr(portSyms));
1095 properties.setPortLocations(builder.getArrayAttr(portLocs));
1096 properties.setDomainInfo(builder.getArrayAttr(portDomains));
1097
1098 result.addRegion();
1099}
1100
1101template <typename OpTy>
1102static void buildModule(OpBuilder &builder, OperationState &result,
1103 StringAttr name, ArrayRef<PortInfo> ports,
1104 ArrayAttr annotations, ArrayAttr layers) {
1105 buildModuleLike<OpTy>(builder, result, name, ports);
1106 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
1107 // Annotations.
1108 if (!annotations)
1109 annotations = builder.getArrayAttr({});
1110 properties.setAnnotations(annotations);
1111
1112 // Port annotations. lack of *any* port annotations is represented by an empty
1113 // `portAnnotations` array as a shorthand.
1114 SmallVector<Attribute, 4> portAnnotations;
1115 for (const auto &port : ports)
1116 portAnnotations.push_back(port.annotations.getArrayAttr());
1117 if (llvm::all_of(portAnnotations, [](Attribute attr) {
1118 return cast<ArrayAttr>(attr).empty();
1119 }))
1120 portAnnotations.clear();
1121 properties.setPortAnnotations(builder.getArrayAttr(portAnnotations));
1122
1123 // Layers.
1124 if (!layers)
1125 layers = builder.getArrayAttr({});
1126 properties.setLayers(layers);
1127}
1128
1129template <typename OpTy>
1130static void buildClass(OpBuilder &builder, OperationState &result,
1131 StringAttr name, ArrayRef<PortInfo> ports) {
1132 return buildModuleLike<OpTy>(builder, result, name, ports);
1133}
1134
1135void FModuleOp::build(OpBuilder &builder, OperationState &result,
1136 StringAttr name, ConventionAttr convention,
1137 ArrayRef<PortInfo> ports, ArrayAttr annotations,
1138 ArrayAttr layers) {
1139 buildModule<FModuleOp>(builder, result, name, ports, annotations, layers);
1140 auto &properties = result.getOrAddProperties<Properties>();
1141 properties.setConvention(convention);
1142
1143 // Create a region and a block for the body.
1144 auto *bodyRegion = result.regions[0].get();
1145 Block *body = new Block();
1146 bodyRegion->push_back(body);
1147
1148 // Add arguments to the body block.
1149 for (auto &elt : ports)
1150 body->addArgument(elt.type, elt.loc);
1151}
1152
1153void FExtModuleOp::build(OpBuilder &builder, OperationState &result,
1154 StringAttr name, ConventionAttr convention,
1155 ArrayRef<PortInfo> ports, ArrayAttr knownLayers,
1156 StringRef defnameAttr, ArrayAttr annotations,
1157 ArrayAttr parameters, ArrayAttr layers) {
1158 buildModule<FExtModuleOp>(builder, result, name, ports, annotations, layers);
1159 auto &properties = result.getOrAddProperties<Properties>();
1160 properties.setConvention(convention);
1161 if (!knownLayers)
1162 knownLayers = builder.getArrayAttr({});
1163 properties.setKnownLayers(knownLayers);
1164 if (!defnameAttr.empty())
1165 properties.setDefname(builder.getStringAttr(defnameAttr));
1166 if (!parameters)
1167 parameters = builder.getArrayAttr({});
1168 properties.setParameters(parameters);
1169}
1170
1171void FIntModuleOp::build(OpBuilder &builder, OperationState &result,
1172 StringAttr name, ArrayRef<PortInfo> ports,
1173 StringRef intrinsicNameStr, ArrayAttr annotations,
1174 ArrayAttr parameters, ArrayAttr layers) {
1175 buildModule<FIntModuleOp>(builder, result, name, ports, annotations, layers);
1176 auto &properties = result.getOrAddProperties<Properties>();
1177 properties.setIntrinsic(builder.getStringAttr(intrinsicNameStr));
1178 if (!parameters)
1179 parameters = builder.getArrayAttr({});
1180 properties.setParameters(parameters);
1181}
1182
1183void FMemModuleOp::build(OpBuilder &builder, OperationState &result,
1184 StringAttr name, ArrayRef<PortInfo> ports,
1185 uint32_t numReadPorts, uint32_t numWritePorts,
1186 uint32_t numReadWritePorts, uint32_t dataWidth,
1187 uint32_t maskBits, uint32_t readLatency,
1188 uint32_t writeLatency, uint64_t depth, RUWBehavior ruw,
1189 ArrayAttr annotations, ArrayAttr layers) {
1190 auto *context = builder.getContext();
1191 buildModule<FMemModuleOp>(builder, result, name, ports, annotations, layers);
1192 auto ui32Type = IntegerType::get(context, 32, IntegerType::Unsigned);
1193 auto ui64Type = IntegerType::get(context, 64, IntegerType::Unsigned);
1194 auto &properties = result.getOrAddProperties<Properties>();
1195 properties.setNumReadPorts(IntegerAttr::get(ui32Type, numReadPorts));
1196 properties.setNumWritePorts(IntegerAttr::get(ui32Type, numWritePorts));
1197 properties.setNumReadWritePorts(
1198 IntegerAttr::get(ui32Type, numReadWritePorts));
1199 properties.setDataWidth(IntegerAttr::get(ui32Type, dataWidth));
1200 properties.setMaskBits(IntegerAttr::get(ui32Type, maskBits));
1201 properties.setReadLatency(IntegerAttr::get(ui32Type, readLatency));
1202 properties.setWriteLatency(IntegerAttr::get(ui32Type, writeLatency));
1203 properties.setDepth(IntegerAttr::get(ui64Type, depth));
1204 properties.setExtraPorts(ArrayAttr::get(context, {}));
1205 properties.setRuw(RUWBehaviorAttr::get(context, ruw));
1206}
1207
1208/// Print a list of module ports in the following form:
1209/// in x: !firrtl.uint<1> [{class = "DontTouch}], out "_port": !firrtl.uint<2>
1210///
1211/// When there is no block specified, the port names print as MLIR identifiers,
1212/// wrapping in quotes if not legal to print as-is. When there is no block
1213/// specified, this function always return false, indicating that there was no
1214/// issue printing port names.
1215///
1216/// If there is a block specified, then port names will be printed as SSA
1217/// values. If there is a reason the printed SSA values can't match the true
1218/// port name, then this function will return true. When this happens, the
1219/// caller should print the port names as a part of the `attr-dict`.
1220static bool
1221printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef<bool> portDirections,
1222 ArrayRef<Attribute> portNames, ArrayRef<Attribute> portTypes,
1223 ArrayRef<Attribute> portAnnotations,
1224 ArrayRef<Attribute> portSyms, ArrayRef<Attribute> portLocs,
1225 ArrayRef<Attribute> domainInfo) {
1226 // When printing port names as SSA values, we can fail to print them
1227 // identically.
1228 bool printedNamesDontMatch = false;
1229
1230 mlir::OpPrintingFlags flags;
1231
1232 // If we are printing the ports as block arguments the op must have a first
1233 // block.
1234 SmallString<32> resultNameStr;
1235 DenseMap<unsigned, std::string> domainPortNames;
1236 p << '(';
1237 for (unsigned i = 0, e = portTypes.size(); i < e; ++i) {
1238 if (i > 0)
1239 p << ", ";
1240
1241 // Print the port direction.
1242 p << direction::get(portDirections[i]) << " ";
1243
1244 // Print the port name. If there is a valid block, we print it as a block
1245 // argument.
1246 auto portType = cast<TypeAttr>(portTypes[i]).getValue();
1247 if (block) {
1248 // Get the printed format for the argument name.
1249 resultNameStr.clear();
1250 llvm::raw_svector_ostream tmpStream(resultNameStr);
1251 p.printOperand(block->getArgument(i), tmpStream);
1252 // If the name wasn't printable in a way that agreed with portName, make
1253 // sure to print out an explicit portNames attribute.
1254 auto portName = cast<StringAttr>(portNames[i]).getValue();
1255 if (tmpStream.str().drop_front() != portName)
1256 printedNamesDontMatch = true;
1257 p << tmpStream.str();
1258 if (isa<DomainType>(portType))
1259 domainPortNames[i] = tmpStream.str();
1260 } else {
1261 auto name = cast<StringAttr>(portNames[i]).getValue();
1262 p.printKeywordOrString(name);
1263 if (isa<DomainType>(portType))
1264 domainPortNames[i] = name.str();
1265 }
1266
1267 // Print the port type.
1268 p << ": ";
1269 p.printType(portType);
1270
1271 // Print the optional port symbol.
1272 if (!portSyms.empty()) {
1273 if (!cast<hw::InnerSymAttr>(portSyms[i]).empty()) {
1274 p << " sym ";
1275 cast<hw::InnerSymAttr>(portSyms[i]).print(p);
1276 }
1277 }
1278
1279 // Print domain information.
1280 if (!domainInfo.empty()) {
1281 if (auto domainKind = dyn_cast<FlatSymbolRefAttr>(domainInfo[i])) {
1282 p << " of " << domainKind;
1283 } else {
1284 auto domains = cast<ArrayAttr>(domainInfo[i]);
1285 if (!domains.empty()) {
1286 p << " domains [";
1287 llvm::interleaveComma(domains, p, [&](Attribute attr) {
1288 p << domainPortNames[cast<IntegerAttr>(attr).getUInt()];
1289 });
1290 p << "]";
1291 }
1292 }
1293 }
1294
1295 // Print the port specific annotations. The port annotations array will be
1296 // empty if there are none.
1297 if (!portAnnotations.empty() &&
1298 !cast<ArrayAttr>(portAnnotations[i]).empty()) {
1299 p << " ";
1300 p.printAttribute(portAnnotations[i]);
1301 }
1302
1303 // Print the port location.
1304 // TODO: `printOptionalLocationSpecifier` will emit aliases for locations,
1305 // even if they are not printed. This will have to be fixed upstream. For
1306 // now, use what was specified on the command line.
1307 if (flags.shouldPrintDebugInfo() && !portLocs.empty())
1308 p.printOptionalLocationSpecifier(cast<LocationAttr>(portLocs[i]));
1309 }
1310
1311 p << ')';
1312 return printedNamesDontMatch;
1313}
1314
1315/// Parse a list of module ports. If port names are SSA identifiers, then this
1316/// will populate `entryArgs`.
1317static ParseResult parseModulePorts(
1318 OpAsmParser &parser, bool hasSSAIdentifiers, bool supportsSymbols,
1319 bool supportsDomains, SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
1320 SmallVectorImpl<Direction> &portDirections,
1321 SmallVectorImpl<Attribute> &portNames,
1322 SmallVectorImpl<Attribute> &portTypes,
1323 SmallVectorImpl<Attribute> &portAnnotations,
1324 SmallVectorImpl<Attribute> &portSyms, SmallVectorImpl<Attribute> &portLocs,
1325 SmallVectorImpl<Attribute> &domains) {
1326 auto *context = parser.getContext();
1327
1328 // Mapping of domain name to port index.
1329 DenseMap<Attribute, size_t> domainIndex;
1330
1331 auto parseArgument = [&]() -> ParseResult {
1332 // Parse port direction.
1333 if (succeeded(parser.parseOptionalKeyword("out")))
1334 portDirections.push_back(Direction::Out);
1335 else if (succeeded(parser.parseKeyword("in", " or 'out'")))
1336 portDirections.push_back(Direction::In);
1337 else
1338 return failure();
1339
1340 // This is the location or the port declaration in the IR. If there is no
1341 // other location information, we use this to point to the MLIR.
1342 llvm::SMLoc irLoc;
1343
1344 if (hasSSAIdentifiers) {
1345 OpAsmParser::Argument arg;
1346 if (parser.parseArgument(arg))
1347 return failure();
1348 entryArgs.push_back(arg);
1349
1350 // The name of an argument is of the form "%42" or "%id", and since
1351 // parsing succeeded, we know it always has one character.
1352 assert(arg.ssaName.name.size() > 1 && arg.ssaName.name[0] == '%' &&
1353 "Unknown MLIR name");
1354 if (isdigit(arg.ssaName.name[1]))
1355 portNames.push_back(StringAttr::get(context, ""));
1356 else
1357 portNames.push_back(
1358 StringAttr::get(context, arg.ssaName.name.drop_front()));
1359
1360 // Store the location of the SSA name.
1361 irLoc = arg.ssaName.location;
1362
1363 } else {
1364 // Parse the port name.
1365 irLoc = parser.getCurrentLocation();
1366 std::string portName;
1367 if (parser.parseKeywordOrString(&portName))
1368 return failure();
1369 portNames.push_back(StringAttr::get(context, portName));
1370 }
1371
1372 // Parse the port type.
1373 Type portType;
1374 if (parser.parseColonType(portType))
1375 return failure();
1376 portTypes.push_back(TypeAttr::get(portType));
1377 if (isa<DomainType>(portType))
1378 domainIndex[portNames.back()] = portNames.size() - 1;
1379
1380 if (hasSSAIdentifiers)
1381 entryArgs.back().type = portType;
1382
1383 // Parse the optional port symbol.
1384 if (supportsSymbols) {
1385 hw::InnerSymAttr innerSymAttr;
1386 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1387 NamedAttrList dummyAttrs;
1388 if (parser.parseCustomAttributeWithFallback(
1389 innerSymAttr, ::mlir::Type{},
1391 return ::mlir::failure();
1392 }
1393 }
1394 portSyms.push_back(innerSymAttr);
1395 }
1396
1397 // Parse optional port domain information if it exists.
1398 Attribute domainsAttr;
1399 SmallVector<Attribute> portDomains;
1400 if (supportsDomains) {
1401 if (isa<DomainType>(portType)) {
1402 if (parser.parseKeyword("of"))
1403 return failure();
1404 StringAttr domainKind;
1405 if (parser.parseSymbolName(domainKind))
1406 return failure();
1407 domainsAttr = FlatSymbolRefAttr::get(context, domainKind);
1408 } else if (succeeded(parser.parseOptionalKeyword("domains"))) {
1409 auto result = parser.parseCommaSeparatedList(
1410 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
1411 StringAttr argName;
1412 if (hasSSAIdentifiers) {
1413 OpAsmParser::Argument arg;
1414 if (parser.parseArgument(arg))
1415 return failure();
1416 argName =
1417 StringAttr::get(context, arg.ssaName.name.drop_front());
1418 } else {
1419 std::string portName;
1420 if (parser.parseKeywordOrString(&portName))
1421 return failure();
1422 argName = StringAttr::get(context, portName);
1423 }
1424
1425 auto index = domainIndex.find(argName);
1426 if (index == domainIndex.end()) {
1427 parser.emitError(irLoc)
1428 << "domain name '" << argName << "' not found";
1429 return failure();
1430 }
1431 portDomains.push_back(IntegerAttr::get(
1432 IntegerType::get(context, 32, IntegerType::Unsigned),
1433 index->second));
1434 return success();
1435 });
1436 if (failed(result))
1437 return failure();
1438 domainsAttr = parser.getBuilder().getArrayAttr(portDomains);
1439 }
1440 }
1441 if (!domainsAttr)
1442 domainsAttr = parser.getBuilder().getArrayAttr({});
1443 domains.push_back(domainsAttr);
1444
1445 // Parse the port annotations.
1446 ArrayAttr annos;
1447 auto parseResult = parser.parseOptionalAttribute(annos);
1448 if (!parseResult.has_value())
1449 annos = parser.getBuilder().getArrayAttr({});
1450 else if (failed(*parseResult))
1451 return failure();
1452 portAnnotations.push_back(annos);
1453
1454 // Parse the optional port location.
1455 std::optional<Location> maybeLoc;
1456 if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
1457 return failure();
1458 Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc);
1459 portLocs.push_back(loc);
1460 if (hasSSAIdentifiers)
1461 entryArgs.back().sourceLoc = loc;
1462
1463 return success();
1464 };
1465
1466 // Parse all ports.
1467 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1468 parseArgument);
1469}
1470
1471/// Print a paramter list for a module or instance.
1472static void printParameterList(OpAsmPrinter &p, Operation *op,
1473 ArrayAttr parameters) {
1474 if (!parameters || parameters.empty())
1475 return;
1476
1477 p << '<';
1478 llvm::interleaveComma(parameters, p, [&](Attribute param) {
1479 auto paramAttr = cast<ParamDeclAttr>(param);
1480 p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
1481 if (auto value = paramAttr.getValue()) {
1482 p << " = ";
1483 p.printAttributeWithoutType(value);
1484 }
1485 });
1486 p << '>';
1487}
1488
1489static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
1490 p << " ";
1491
1492 // Print the visibility of the module.
1493 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1494 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1495 p << visibility.getValue() << ' ';
1496
1497 // Print the operation and the function name.
1498 p.printSymbolName(op.getModuleName());
1499
1500 // Print the parameter list (if non-empty).
1501 printParameterList(p, op, op->getAttrOfType<ArrayAttr>("parameters"));
1502
1503 // Both modules and external modules have a body, but it is always empty for
1504 // external modules.
1505 Block *body = nullptr;
1506 if (!op->getRegion(0).empty())
1507 body = &op->getRegion(0).front();
1508
1509 auto needPortNamesAttr = printModulePorts(
1510 p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1511 op.getPortAnnotations(), op.getPortSymbols(), op.getPortLocations(),
1512 op.getDomainInfo());
1513
1514 SmallVector<StringRef, 13> omittedAttrs = {
1515 "sym_name", "portDirections", "portTypes",
1516 "portAnnotations", "portSymbols", "portLocations",
1517 "parameters", visibilityAttrName, "domainInfo"};
1518
1519 if (op.getConvention() == Convention::Internal)
1520 omittedAttrs.push_back("convention");
1521
1522 // We can omit the portNames if they were able to be printed as properly as
1523 // block arguments.
1524 if (!needPortNamesAttr)
1525 omittedAttrs.push_back("portNames");
1526
1527 // If there are no annotations we can omit the empty array.
1528 if (op->getAttrOfType<ArrayAttr>("annotations").empty())
1529 omittedAttrs.push_back("annotations");
1530
1531 // If there are no known layers, then omit the empty array.
1532 if (auto knownLayers = op->getAttrOfType<ArrayAttr>("knownLayers"))
1533 if (knownLayers.empty())
1534 omittedAttrs.push_back("knownLayers");
1535
1536 // If there are no enabled layers, then omit the empty array.
1537 if (auto layers = op->getAttrOfType<ArrayAttr>("layers"))
1538 if (layers.empty())
1539 omittedAttrs.push_back("layers");
1540
1541 p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1542}
1543
1544void FExtModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1545
1546void FIntModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1547
1548void FMemModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1549
1550void FModuleOp::print(OpAsmPrinter &p) {
1551 printFModuleLikeOp(p, *this);
1552
1553 // Print the body if this is not an external function. Since this block does
1554 // not have terminators, printing the terminator actually just prints the last
1555 // operation.
1556 Region &fbody = getBody();
1557 if (!fbody.empty()) {
1558 p << " ";
1559 p.printRegion(fbody, /*printEntryBlockArgs=*/false,
1560 /*printBlockTerminators=*/true);
1561 }
1562}
1563
1564/// Parse an parameter list if present.
1565/// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>`
1566/// parameter-decl ::= identifier `:` type
1567/// parameter-decl ::= identifier `:` type `=` attribute
1568///
1569static ParseResult
1570parseOptionalParameters(OpAsmParser &parser,
1571 SmallVectorImpl<Attribute> &parameters) {
1572
1573 return parser.parseCommaSeparatedList(
1574 OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1575 std::string name;
1576 Type type;
1577 Attribute value;
1578
1579 if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1580 return failure();
1581
1582 // Parse the default value if present.
1583 if (succeeded(parser.parseOptionalEqual())) {
1584 if (parser.parseAttribute(value, type))
1585 return failure();
1586 }
1587
1588 auto &builder = parser.getBuilder();
1589 parameters.push_back(ParamDeclAttr::get(
1590 builder.getContext(), builder.getStringAttr(name), type, value));
1591 return success();
1592 });
1593}
1594
1595/// Shim to use with assemblyFormat, custom<ParameterList>.
1596static ParseResult parseParameterList(OpAsmParser &parser,
1597 ArrayAttr &parameters) {
1598 SmallVector<Attribute> parseParameters;
1599 if (failed(parseOptionalParameters(parser, parseParameters)))
1600 return failure();
1601
1602 parameters = ArrayAttr::get(parser.getContext(), parseParameters);
1603
1604 return success();
1605}
1606
/// Trait that is true when an op's `Properties` struct has a `parameters`
/// member. parseFModuleLikeOp uses it to decide whether to parse a parameter
/// list.
template <typename PropsT, typename = void>
struct HasParameters : std::false_type {};

/// Chosen whenever `PropsT::parameters` names an accessible member.
template <typename PropsT>
struct HasParameters<PropsT,
                     std::void_t<decltype(std::declval<PropsT>().parameters)>>
    : std::true_type {};
1614
1615template <typename OpTy>
1616static ParseResult parseFModuleLikeOp(OpAsmParser &parser,
1617 OperationState &result,
1618 bool hasSSAIdentifiers) {
1619 auto *context = result.getContext();
1620 auto &builder = parser.getBuilder();
1621 using Properties = typename OpTy::Properties;
1622 auto &properties = result.getOrAddProperties<Properties>();
1623
1624 // TODO: this should be using properties.
1625 // Parse the visibility attribute.
1626 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1627
1628 // Parse the name as a symbol.
1629 StringAttr nameAttr;
1630 if (parser.parseSymbolName(nameAttr))
1631 return failure();
1632 properties.setSymName(nameAttr);
1633
1634 // Parse optional parameters.
1635 if constexpr (HasParameters<Properties>::value) {
1636 SmallVector<Attribute, 4> parameters;
1637 if (parseOptionalParameters(parser, parameters))
1638 return failure();
1639 properties.setParameters(builder.getArrayAttr(parameters));
1640 }
1641
1642 // Parse the module ports.
1643 SmallVector<OpAsmParser::Argument> entryArgs;
1644 SmallVector<Direction, 4> portDirections;
1645 SmallVector<Attribute, 4> portNames;
1646 SmallVector<Attribute, 4> portTypes;
1647 SmallVector<Attribute, 4> portAnnotations;
1648 SmallVector<Attribute, 4> portSyms;
1649 SmallVector<Attribute, 4> portLocs;
1650 SmallVector<Attribute, 4> domains;
1651 if (parseModulePorts(parser, hasSSAIdentifiers, /*supportsSymbols=*/true,
1652 /*supportsDomains=*/true, entryArgs, portDirections,
1653 portNames, portTypes, portAnnotations, portSyms,
1654 portLocs, domains))
1655 return failure();
1656
1657 // If module attributes are present, parse them.
1658 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1659 return failure();
1660
1661 assert(portNames.size() == portTypes.size());
1662
1663 // Record the argument and result types as an attribute. This is necessary
1664 // for external modules.
1665
1666 // Add port directions.
1667 properties.setPortDirections(
1668 direction::packAttribute(context, portDirections));
1669
1670 // Add port names.
1671 properties.setPortNames(builder.getArrayAttr(portNames));
1672
1673 // Add the port types.
1674 properties.setPortTypes(ArrayAttr::get(context, portTypes));
1675
1676 // Add the port annotations.
1677 // If there are no portAnnotations, don't add the attribute.
1678 if (llvm::any_of(portAnnotations, [&](Attribute anno) {
1679 return !cast<ArrayAttr>(anno).empty();
1680 }))
1681 properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));
1682 else
1683 properties.setPortAnnotations(builder.getArrayAttr({}));
1684
1685 // Add port symbols.
1686 FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1687 properties.setPortSymbols(builder.getArrayAttr(portSyms));
1688
1689 // Add port locations.
1690 properties.setPortLocations(ArrayAttr::get(context, portLocs));
1691
1692 // The annotations attribute is always present, but not printed when empty.
1693 properties.setAnnotations(builder.getArrayAttr({}));
1694
1695 // Add domains. Use an empty array if none are set.
1696 if (llvm::all_of(domains, [&](Attribute attr) {
1697 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
1698 return arrayAttr && arrayAttr.empty();
1699 }))
1700 properties.setDomainInfo(ArrayAttr::get(context, {}));
1701 else
1702 properties.setDomainInfo(ArrayAttr::get(context, domains));
1703
1704 // Parse the optional function body.
1705 auto *body = result.addRegion();
1706
1707 if (hasSSAIdentifiers) {
1708 if (parser.parseRegion(*body, entryArgs))
1709 return failure();
1710 if (body->empty())
1711 body->push_back(new Block());
1712 }
1713 return success();
1714}
1715
1716ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1717 if (parseFModuleLikeOp<FModuleOp>(parser, result,
1718 /*hasSSAIdentifiers=*/true))
1719 return failure();
1720 auto &properties = result.getOrAddProperties<Properties>();
1721 properties.setConvention(
1722 ConventionAttr::get(result.getContext(), Convention::Internal));
1723 properties.setLayers(ArrayAttr::get(parser.getContext(), {}));
1724 return success();
1725}
1726
1727ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1728 if (parseFModuleLikeOp<FExtModuleOp>(parser, result,
1729 /*hasSSAIdentifiers=*/false))
1730 return failure();
1731 auto &properties = result.getOrAddProperties<Properties>();
1732 properties.setConvention(
1733 ConventionAttr::get(result.getContext(), Convention::Internal));
1734 properties.setKnownLayers(ArrayAttr::get(result.getContext(), {}));
1735 return success();
1736}
1737
1738ParseResult FIntModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1739 return parseFModuleLikeOp<FIntModuleOp>(parser, result,
1740 /*hasSSAIdentifiers=*/false);
1741}
1742
1743ParseResult FMemModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1744 return parseFModuleLikeOp<FMemModuleOp>(parser, result,
1745 /*hasSSAIdentifiers=*/false);
1746}
1747
1748LogicalResult FModuleOp::verify() {
1749 // Verify the block arguments.
1750 auto *body = getBodyBlock();
1751 auto portTypes = getPortTypes();
1752 auto portLocs = getPortLocations();
1753 auto numPorts = portTypes.size();
1754
1755 // Verify that we have the correct number of block arguments.
1756 if (body->getNumArguments() != numPorts)
1757 return emitOpError("entry block must have ")
1758 << numPorts << " arguments to match module signature";
1759
1760 // Verify the block arguments' types and locations match our attributes.
1761 for (auto [arg, type, loc] : zip(body->getArguments(), portTypes, portLocs)) {
1762 if (arg.getType() != cast<TypeAttr>(type).getValue())
1763 return emitOpError("block argument types should match signature types");
1764 if (arg.getLoc() != cast<LocationAttr>(loc))
1765 return emitOpError(
1766 "block argument locations should match signature locations");
1767 }
1768
1769 return success();
1770}
1771
1772LogicalResult FExtModuleOp::verify() {
1773 auto params = getParameters();
1774
1775 auto checkParmValue = [&](Attribute elt) -> bool {
1776 auto param = cast<ParamDeclAttr>(elt);
1777 auto value = param.getValue();
1778 if (isa<IntegerAttr, StringAttr, FloatAttr, hw::ParamVerbatimAttr>(value))
1779 return true;
1780 emitError() << "has unknown extmodule parameter value '"
1781 << param.getName().getValue() << "' = " << value;
1782 return false;
1783 };
1784
1785 if (!llvm::all_of(params, checkParmValue))
1786 return failure();
1787
1788 // Verify that any mentioned layers are marked as known.
1789 LayerSet known;
1790 known.insert_range(getKnownLayersAttr().getAsRange<SymbolRefAttr>());
1791
1792 LayerSet referenced;
1793 referenced.insert_range(getLayersAttr().getAsRange<SymbolRefAttr>());
1794 for (auto attr : getPortTypes()) {
1795 auto type = cast<TypeAttr>(attr).getValue();
1796 if (auto refType = type_dyn_cast<RefType>(type))
1797 if (auto layer = refType.getLayer())
1798 referenced.insert(layer);
1799 }
1800
1801 return checkLayerCompatibility(getOperation(), referenced, known,
1802 "references unknown layers", "unknown layers");
1803}
1804
1805LogicalResult FIntModuleOp::verify() {
1806 auto params = getParameters();
1807 if (params.empty())
1808 return success();
1809
1810 auto checkParmValue = [&](Attribute elt) -> bool {
1811 auto param = cast<ParamDeclAttr>(elt);
1812 auto value = param.getValue();
1813 if (isa<IntegerAttr, StringAttr, FloatAttr>(value))
1814 return true;
1815 emitError() << "has unknown intmodule parameter value '"
1816 << param.getName().getValue() << "' = " << value;
1817 return false;
1818 };
1819
1820 if (!llvm::all_of(params, checkParmValue))
1821 return failure();
1822
1823 return success();
1824}
1825
1826static LogicalResult verifyProbeType(RefType refType, Location loc,
1827 CircuitOp circuitOp,
1828 SymbolTableCollection &symbolTable,
1829 Twine start) {
1830 auto layer = refType.getLayer();
1831 if (!layer)
1832 return success();
1833 auto *layerOp = symbolTable.lookupSymbolIn(circuitOp, layer);
1834 if (!layerOp)
1835 return emitError(loc) << start << " associated with layer '" << layer
1836 << "', but this layer was not defined";
1837 if (!isa<LayerOp>(layerOp)) {
1838 auto diag = emitError(loc)
1839 << start << " associated with layer '" << layer
1840 << "', but symbol '" << layer << "' does not refer to a '"
1841 << LayerOp::getOperationName() << "' op";
1842 return diag.attachNote(layerOp->getLoc()) << "symbol refers to this op";
1843 }
1844 return success();
1845}
1846
1847static LogicalResult verifyPortSymbolUses(FModuleLike module,
1848 SymbolTableCollection &symbolTable) {
1849 // verify types in ports.
1850 auto circuitOp = module->getParentOfType<CircuitOp>();
1851 for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
1852 auto type = module.getPortType(i);
1853
1854 if (auto refType = type_dyn_cast<RefType>(type)) {
1855 if (failed(verifyProbeType(
1856 refType, module.getPortLocation(i), circuitOp, symbolTable,
1857 Twine("probe port '") + module.getPortName(i) + "' is")))
1858 return failure();
1859 continue;
1860 }
1861
1862 if (auto classType = dyn_cast<ClassType>(type)) {
1863 auto className = classType.getNameAttr();
1864 auto classOp = dyn_cast_or_null<ClassLike>(
1865 symbolTable.lookupSymbolIn(circuitOp, className));
1866 if (!classOp)
1867 return module.emitOpError() << "references unknown class " << className;
1868
1869 // verify that the result type agrees with the class definition.
1870 if (failed(classOp.verifyType(classType,
1871 [&]() { return module.emitOpError(); })))
1872 return failure();
1873 continue;
1874 }
1875
1876 if (isa<DomainType>(type)) {
1877 auto domainInfo = module.getDomainInfoAttrForPort(i);
1878 if (auto kind = dyn_cast<FlatSymbolRefAttr>(domainInfo))
1879 if (!dyn_cast_or_null<DomainOp>(
1880 symbolTable.lookupSymbolIn(circuitOp, kind)))
1881 return mlir::emitError(module.getPortLocation(i))
1882 << "domain port '" << module.getPortName(i)
1883 << "' has undefined domain kind '" << kind.getValue() << "'";
1884 }
1885 }
1886
1887 return success();
1888}
1889
1890LogicalResult FModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1891 if (failed(verifyPortSymbolUses(*this, symbolTable)))
1892 return failure();
1893
1894 auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
1895 for (auto layer : getLayers()) {
1896 if (!symbolTable.lookupSymbolIn(circuitOp, cast<SymbolRefAttr>(layer)))
1897 return emitOpError() << "enables undefined layer '" << layer << "'";
1898 }
1899
1900 return success();
1901}
1902
1903LogicalResult
1904FExtModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1905 if (failed(verifyPortSymbolUses(*this, symbolTable)))
1906 return failure();
1907
1908 auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
1909 for (auto layer : getKnownLayersAttr().getAsRange<SymbolRefAttr>()) {
1910 if (!symbolTable.lookupSymbolIn(circuitOp, layer))
1911 return emitOpError() << "knows undefined layer '" << layer << "'";
1912 }
1913 for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>()) {
1914 if (!symbolTable.lookupSymbolIn(circuitOp, layer))
1915 return emitOpError() << "enables undefined layer '" << layer << "'";
1916 }
1917
1918 return success();
1919}
1920
1921LogicalResult
1922FIntModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1923 return verifyPortSymbolUses(*this, symbolTable);
1924}
1925
1926LogicalResult
1927FMemModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1928 return verifyPortSymbolUses(*this, symbolTable);
1929}
1930
1931void FModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1932 mlir::OpAsmSetValueNameFn setNameFn) {
1933 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1934}
1935
1936void FExtModuleOp::getAsmBlockArgumentNames(
1937 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1938 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1939}
1940
1941StringRef FExtModuleOp::getExtModuleName() {
1942 if (auto defname = getDefname(); defname && !defname->empty())
1943 return *defname;
1944 return getName();
1945}
1946
1947void FIntModuleOp::getAsmBlockArgumentNames(
1948 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1949 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1950}
1951
1952void FMemModuleOp::getAsmBlockArgumentNames(
1953 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1954 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1955}
1956
1957ArrayAttr FMemModuleOp::getParameters() { return {}; }
1958
1959ArrayAttr FModuleOp::getParameters() { return {}; }
1960
1961Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1962
1963ConventionAttr FIntModuleOp::getConventionAttr() {
1964 return ConventionAttr::get(getContext(), getConvention());
1965}
1966
1967Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1968
1969ConventionAttr FMemModuleOp::getConventionAttr() {
1970 return ConventionAttr::get(getContext(), getConvention());
1971}
1972
1973//===----------------------------------------------------------------------===//
1974// ClassLike Helpers
1975//===----------------------------------------------------------------------===//
1976
1978 ClassLike classOp, ClassType type,
1979 function_ref<InFlightDiagnostic()> emitError) {
1980 // This check is probably not required, but done for sanity.
1981 auto name = type.getNameAttr().getAttr();
1982 auto expectedName = classOp.getModuleNameAttr();
1983 if (name != expectedName)
1984 return emitError() << "type has wrong name, got " << name << ", expected "
1985 << expectedName;
1986
1987 auto elements = type.getElements();
1988 auto numElements = elements.size();
1989 auto expectedNumElements = classOp.getNumPorts();
1990 if (numElements != expectedNumElements)
1991 return emitError() << "has wrong number of ports, got " << numElements
1992 << ", expected " << expectedNumElements;
1993
1994 auto portNames = classOp.getPortNames();
1995 auto portDirections = classOp.getPortDirections();
1996 auto portTypes = classOp.getPortTypes();
1997
1998 for (unsigned i = 0; i < numElements; ++i) {
1999 auto element = elements[i];
2000
2001 auto name = element.name;
2002 auto expectedName = portNames[i];
2003 if (name != expectedName)
2004 return emitError() << "port #" << i << " has wrong name, got " << name
2005 << ", expected " << expectedName;
2006
2007 auto direction = element.direction;
2008 auto expectedDirection = Direction(portDirections[i]);
2009 if (direction != expectedDirection)
2010 return emitError() << "port " << name << " has wrong direction, got "
2011 << direction::toString(direction) << ", expected "
2012 << direction::toString(expectedDirection);
2013
2014 auto type = element.type;
2015 auto expectedType = cast<TypeAttr>(portTypes[i]).getValue();
2016 if (type != expectedType)
2017 return emitError() << "port " << name << " has wrong type, got " << type
2018 << ", expected " << expectedType;
2019 }
2020
2021 return success();
2022}
2023
2025 auto n = classOp.getNumPorts();
2026 SmallVector<ClassElement> elements;
2027 elements.reserve(n);
2028 for (size_t i = 0; i < n; ++i)
2029 elements.push_back({classOp.getPortNameAttr(i), classOp.getPortType(i),
2030 classOp.getPortDirection(i)});
2031 auto name = FlatSymbolRefAttr::get(classOp.getNameAttr());
2032 return ClassType::get(name, elements);
2033}
2034
2035template <typename OpTy>
2036ParseResult parseClassLike(OpAsmParser &parser, OperationState &result,
2037 bool hasSSAIdentifiers) {
2038 auto *context = result.getContext();
2039 auto &builder = parser.getBuilder();
2040 auto &properties = result.getOrAddProperties<typename OpTy::Properties>();
2041
2042 // TODO: this should use properties.
2043 // Parse the visibility attribute.
2044 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
2045
2046 // Parse the name as a symbol.
2047 StringAttr nameAttr;
2048 if (parser.parseSymbolName(nameAttr))
2049 return failure();
2050 properties.setSymName(nameAttr);
2051
2052 // Parse the module ports.
2053 SmallVector<OpAsmParser::Argument> entryArgs;
2054 SmallVector<Direction, 4> portDirections;
2055 SmallVector<Attribute, 4> portNames;
2056 SmallVector<Attribute, 4> portTypes;
2057 SmallVector<Attribute, 4> portAnnotations;
2058 SmallVector<Attribute, 4> portSyms;
2059 SmallVector<Attribute, 4> portLocs;
2060 SmallVector<Attribute, 4> domains;
2061 if (parseModulePorts(parser, hasSSAIdentifiers,
2062 /*supportsSymbols=*/false, /*supportsDomains=*/false,
2063 entryArgs, portDirections, portNames, portTypes,
2064 portAnnotations, portSyms, portLocs, domains))
2065 return failure();
2066
2067 // Ports on ClassLike ops cannot have annotations
2068 for (auto annos : portAnnotations)
2069 if (!cast<ArrayAttr>(annos).empty())
2070 return failure();
2071
2072 // If attributes are present, parse them.
2073 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2074 return failure();
2075
2076 assert(portNames.size() == portTypes.size());
2077
2078 // Record the argument and result types as an attribute. This is necessary
2079 // for external modules.
2080
2081 // Add port directions.
2082 properties.setPortDirections(
2083 direction::packAttribute(context, portDirections));
2084
2085 // Add port names.
2086 properties.setPortNames(builder.getArrayAttr(portNames));
2087
2088 // Add the port types.
2089 properties.setPortTypes(builder.getArrayAttr(portTypes));
2090
2091 // Add the port symbols.
2092 FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
2093 properties.setPortSymbols(builder.getArrayAttr(portSyms));
2094
2095 // Add port locations.
2096 properties.setPortLocations(ArrayAttr::get(context, portLocs));
2097
2098 // Notably missing compared to other FModuleLike, we do not track port
2099 // annotations, nor port symbols, on classes.
2100
2101 // Add the region (unused by extclass).
2102 auto *bodyRegion = result.addRegion();
2103
2104 if (hasSSAIdentifiers) {
2105 if (parser.parseRegion(*bodyRegion, entryArgs))
2106 return failure();
2107 if (bodyRegion->empty())
2108 bodyRegion->push_back(new Block());
2109 }
2110
2111 return success();
2112}
2113
2114static void printClassLike(OpAsmPrinter &p, ClassLike op) {
2115 p << ' ';
2116
2117 // Print the visibility of the class.
2118 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
2119 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
2120 p << visibility.getValue() << ' ';
2121
2122 // Print the class name.
2123 p.printSymbolName(op.getName());
2124
2125 // Both classes and external classes have a body, but it is always empty for
2126 // external classes.
2127 Region &region = op->getRegion(0);
2128 Block *body = nullptr;
2129 if (!region.empty())
2130 body = &region.front();
2131
2132 auto needPortNamesAttr = printModulePorts(
2133 p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
2134 {}, op.getPortSymbols(), op.getPortLocations(), {});
2135
2136 // Print the attr-dict.
2137 SmallVector<StringRef, 8> omittedAttrs = {
2138 "sym_name", "portNames", "portTypes", "portDirections",
2139 "portSymbols", "portLocations", visibilityAttrName, "domainInfo"};
2140
2141 // We can omit the portNames if they were able to be printed as properly as
2142 // block arguments.
2143 if (!needPortNamesAttr)
2144 omittedAttrs.push_back("portNames");
2145
2146 p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
2147
2148 // print the body if it exists.
2149 if (!region.empty()) {
2150 p << " ";
2151 auto printEntryBlockArgs = false;
2152 auto printBlockTerminators = false;
2153 p.printRegion(region, printEntryBlockArgs, printBlockTerminators);
2154 }
2155}
2156
2157//===----------------------------------------------------------------------===//
2158// ClassOp
2159//===----------------------------------------------------------------------===//
2160
2161void ClassOp::build(OpBuilder &builder, OperationState &result, StringAttr name,
2162 ArrayRef<PortInfo> ports) {
2163 assert(
2164 llvm::all_of(ports,
2165 [](const auto &port) { return port.annotations.empty(); }) &&
2166 "class ports may not have annotations");
2167
2168 buildClass<ClassOp>(builder, result, name, ports);
2169
2170 // Create a region and a block for the body.
2171 auto *bodyRegion = result.regions[0].get();
2172 Block *body = new Block();
2173 bodyRegion->push_back(body);
2174
2175 // Add arguments to the body block.
2176 for (auto &elt : ports)
2177 body->addArgument(elt.type, elt.loc);
2178}
2179
2180void ClassOp::build(::mlir::OpBuilder &odsBuilder,
2181 ::mlir::OperationState &odsState, Twine name,
2182 mlir::ArrayRef<mlir::StringRef> fieldNames,
2183 mlir::ArrayRef<mlir::Type> fieldTypes) {
2184
2185 SmallVector<PortInfo, 10> ports;
2186 for (auto [fieldName, fieldType] : llvm::zip(fieldNames, fieldTypes)) {
2187 ports.emplace_back(odsBuilder.getStringAttr(fieldName + "_in"), fieldType,
2188 Direction::In);
2189 ports.emplace_back(odsBuilder.getStringAttr(fieldName), fieldType,
2190 Direction::Out);
2191 }
2192 build(odsBuilder, odsState, odsBuilder.getStringAttr(name), ports);
2193 // Create a region and a block for the body.
2194 auto &body = odsState.regions[0]->getBlocks().front();
2195 auto prevLoc = odsBuilder.saveInsertionPoint();
2196 odsBuilder.setInsertionPointToEnd(&body);
2197 auto args = body.getArguments();
2198 auto loc = odsState.location;
2199 for (unsigned i = 0, e = ports.size(); i != e; i += 2)
2200 PropAssignOp::create(odsBuilder, loc, args[i + 1], args[i]);
2201
2202 odsBuilder.restoreInsertionPoint(prevLoc);
2203}
2204void ClassOp::print(OpAsmPrinter &p) {
2205 printClassLike(p, cast<ClassLike>(getOperation()));
2206}
2207
2208ParseResult ClassOp::parse(OpAsmParser &parser, OperationState &result) {
2209 auto hasSSAIdentifiers = true;
2210 return parseClassLike<ClassOp>(parser, result, hasSSAIdentifiers);
2211}
2212
2213LogicalResult ClassOp::verify() {
2214 for (auto operand : getBodyBlock()->getArguments()) {
2215 auto type = operand.getType();
2216 if (!isa<PropertyType>(type)) {
2217 emitOpError("ports on a class must be properties");
2218 return failure();
2219 }
2220 }
2221
2222 return success();
2223}
2224
2225LogicalResult
2226ClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2227 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2228}
2229
2230void ClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2231 mlir::OpAsmSetValueNameFn setNameFn) {
2232 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2233}
2234
2235SmallVector<PortInfo> ClassOp::getPorts() {
2236 return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2237}
2238
2239void ClassOp::erasePorts(const llvm::BitVector &portIndices) {
2240 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2241 getBodyBlock()->eraseArguments(portIndices);
2242}
2243
2244void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2245 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2246}
2247
2248Convention ClassOp::getConvention() { return Convention::Internal; }
2249
2250ConventionAttr ClassOp::getConventionAttr() {
2251 return ConventionAttr::get(getContext(), getConvention());
2252}
2253
2254ArrayAttr ClassOp::getParameters() { return {}; }
2255
2256ArrayAttr ClassOp::getPortAnnotationsAttr() {
2257 return ArrayAttr::get(getContext(), {});
2258}
2259
2260ArrayRef<Attribute> ClassOp::getPortAnnotations() { return {}; }
2261
2262void ClassOp::setPortAnnotationsAttr(ArrayAttr annotations) {
2263 llvm_unreachable("classes do not support annotations");
2264}
2265
2266ArrayAttr ClassOp::getLayersAttr() { return ArrayAttr::get(getContext(), {}); }
2267
2268ArrayRef<Attribute> ClassOp::getLayers() { return {}; }
2269
2270SmallVector<::circt::hw::PortInfo> ClassOp::getPortList() {
2271 return ::getPortListImpl(*this);
2272}
2273
2274::circt::hw::PortInfo ClassOp::getPort(size_t idx) {
2275 return ::getPortImpl(*this, idx);
2276}
2277
2278BlockArgument ClassOp::getArgument(size_t portNumber) {
2279 return getBodyBlock()->getArgument(portNumber);
2280}
2281
2282bool ClassOp::canDiscardOnUseEmpty() {
2283 // ClassOps are referenced by ClassTypes, and these uses are not
2284 // discoverable by the symbol infrastructure. Return false here to prevent
2285 // passes like symbolDCE from removing our classes.
2286 return false;
2287}
2288
2289//===----------------------------------------------------------------------===//
2290// ExtClassOp
2291//===----------------------------------------------------------------------===//
2292
2293void ExtClassOp::build(OpBuilder &builder, OperationState &result,
2294 StringAttr name, ArrayRef<PortInfo> ports) {
2295 assert(
2296 llvm::all_of(ports,
2297 [](const auto &port) { return port.annotations.empty(); }) &&
2298 "class ports may not have annotations");
2299 buildClass<ClassOp>(builder, result, name, ports);
2300}
2301
2302void ExtClassOp::print(OpAsmPrinter &p) {
2303 printClassLike(p, cast<ClassLike>(getOperation()));
2304}
2305
2306ParseResult ExtClassOp::parse(OpAsmParser &parser, OperationState &result) {
2307 auto hasSSAIdentifiers = false;
2308 return parseClassLike<ExtClassOp>(parser, result, hasSSAIdentifiers);
2309}
2310
2311LogicalResult
2312ExtClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2313 return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2314}
2315
2316void ExtClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2317 mlir::OpAsmSetValueNameFn setNameFn) {
2318 getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2319}
2320
2321SmallVector<PortInfo> ExtClassOp::getPorts() {
2322 return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2323}
2324
2325void ExtClassOp::erasePorts(const llvm::BitVector &portIndices) {
2326 ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2327}
2328
2329void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2330 ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2331}
2332
2333Convention ExtClassOp::getConvention() { return Convention::Internal; }
2334
2335ConventionAttr ExtClassOp::getConventionAttr() {
2336 return ConventionAttr::get(getContext(), getConvention());
2337}
2338
2339ArrayAttr ExtClassOp::getLayersAttr() {
2340 return ArrayAttr::get(getContext(), {});
2341}
2342
2343ArrayRef<Attribute> ExtClassOp::getLayers() { return {}; }
2344
2345ArrayAttr ExtClassOp::getParameters() { return {}; }
2346
2347ArrayAttr ExtClassOp::getPortAnnotationsAttr() {
2348 return ArrayAttr::get(getContext(), {});
2349}
2350
2351ArrayRef<Attribute> ExtClassOp::getPortAnnotations() { return {}; }
2352
2353void ExtClassOp::setPortAnnotationsAttr(ArrayAttr annotations) {
2354 llvm_unreachable("classes do not support annotations");
2355}
2356
2357SmallVector<::circt::hw::PortInfo> ExtClassOp::getPortList() {
2358 return ::getPortListImpl(*this);
2359}
2360
2361::circt::hw::PortInfo ExtClassOp::getPort(size_t idx) {
2362 return ::getPortImpl(*this, idx);
2363}
2364
2365bool ExtClassOp::canDiscardOnUseEmpty() {
2366 // ClassOps are referenced by ClassTypes, and these uses are not
2367 // discovereable by the symbol infrastructure. Return false here to prevent
2368 // passes like symbolDCE from removing our classes.
2369 return false;
2370}
2371
2372//===----------------------------------------------------------------------===//
2373// InstanceOp
2374//===----------------------------------------------------------------------===//
2375
2376void InstanceOp::build(
2377 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2378 StringRef moduleName, StringRef name, NameKindEnum nameKind,
2379 ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2380 ArrayRef<Attribute> domainInfo, ArrayRef<Attribute> annotations,
2381 ArrayRef<Attribute> portAnnotations, ArrayRef<Attribute> layers,
2382 bool lowerToBind, bool doNotPrint, StringAttr innerSym) {
2383 build(builder, result, resultTypes, moduleName, name, nameKind,
2384 portDirections, portNames, domainInfo, annotations, portAnnotations,
2385 layers, lowerToBind, doNotPrint,
2386 innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2387}
2388
2389void InstanceOp::build(
2390 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2391 StringRef moduleName, StringRef name, NameKindEnum nameKind,
2392 ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2393 ArrayRef<Attribute> domainInfo, ArrayRef<Attribute> annotations,
2394 ArrayRef<Attribute> portAnnotations, ArrayRef<Attribute> layers,
2395 bool lowerToBind, bool doNotPrint, hw::InnerSymAttr innerSym) {
2396 result.addTypes(resultTypes);
2397 result.getOrAddProperties<Properties>().setModuleName(
2398 SymbolRefAttr::get(builder.getContext(), moduleName));
2399 result.getOrAddProperties<Properties>().setName(builder.getStringAttr(name));
2400 result.getOrAddProperties<Properties>().setPortDirections(
2401 direction::packAttribute(builder.getContext(), portDirections));
2402 result.getOrAddProperties<Properties>().setPortNames(
2403 builder.getArrayAttr(portNames));
2404
2405 if (domainInfo.empty()) {
2406 SmallVector<Attribute, 16> domainInfoVec(resultTypes.size(),
2407 builder.getArrayAttr({}));
2408 result.getOrAddProperties<Properties>().setDomainInfo(
2409 builder.getArrayAttr(domainInfoVec));
2410 } else {
2411 assert(domainInfo.size() == resultTypes.size());
2412 result.getOrAddProperties<Properties>().setDomainInfo(
2413 builder.getArrayAttr(domainInfo));
2414 }
2415
2416 result.getOrAddProperties<Properties>().setAnnotations(
2417 builder.getArrayAttr(annotations));
2418 result.getOrAddProperties<Properties>().setLayers(
2419 builder.getArrayAttr(layers));
2420 if (lowerToBind)
2421 result.getOrAddProperties<Properties>().setLowerToBind(
2422 builder.getUnitAttr());
2423 if (doNotPrint)
2424 result.getOrAddProperties<Properties>().setDoNotPrint(
2425 builder.getUnitAttr());
2426 if (innerSym)
2427 result.getOrAddProperties<Properties>().setInnerSym(innerSym);
2428
2429 result.getOrAddProperties<Properties>().setNameKind(
2430 NameKindEnumAttr::get(builder.getContext(), nameKind));
2431
2432 if (portAnnotations.empty()) {
2433 SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2434 builder.getArrayAttr({}));
2435 result.getOrAddProperties<Properties>().setPortAnnotations(
2436 builder.getArrayAttr(portAnnotationsVec));
2437 } else {
2438 assert(portAnnotations.size() == resultTypes.size());
2439 result.getOrAddProperties<Properties>().setPortAnnotations(
2440 builder.getArrayAttr(portAnnotations));
2441 }
2442}
2443
2444void InstanceOp::build(OpBuilder &builder, OperationState &result,
2445 FModuleLike module, StringRef name,
2446 NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2447 ArrayRef<Attribute> portAnnotations, bool lowerToBind,
2448 bool doNotPrint, hw::InnerSymAttr innerSym) {
2449
2450 // Gather the result types.
2451 SmallVector<Type> resultTypes;
2452 resultTypes.reserve(module.getNumPorts());
2453 llvm::transform(
2454 module.getPortTypes(), std::back_inserter(resultTypes),
2455 [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2456
2457 // The storage for annotations and domains on the module uses an empty array
2458 // if there are no annotations or domains. However, the instance op always
2459 // stores one-array-per-port. Do this massaging here.
2460 ArrayAttr portAnnotationsAttr;
2461 if (portAnnotations.empty()) {
2462 portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2463 resultTypes.size(), builder.getArrayAttr({})));
2464 } else {
2465 portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2466 }
2467 ArrayAttr domainInfoAttr = module.getDomainInfoAttr();
2468 if (domainInfoAttr.empty()) {
2469 domainInfoAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2470 resultTypes.size(), builder.getArrayAttr({})));
2471 }
2472
2473 return build(
2474 builder, result, resultTypes,
2475 SymbolRefAttr::get(builder.getContext(), module.getModuleNameAttr()),
2476 builder.getStringAttr(name),
2477 NameKindEnumAttr::get(builder.getContext(), nameKind),
2478 module.getPortDirectionsAttr(), module.getPortNamesAttr(), domainInfoAttr,
2479 builder.getArrayAttr(annotations), portAnnotationsAttr,
2480 module.getLayersAttr(), lowerToBind ? builder.getUnitAttr() : UnitAttr(),
2481 doNotPrint ? builder.getUnitAttr() : UnitAttr(), innerSym);
2482}
2483
2484void InstanceOp::build(OpBuilder &builder, OperationState &odsState,
2485 ArrayRef<PortInfo> ports, StringRef moduleName,
2486 StringRef name, NameKindEnum nameKind,
2487 ArrayRef<Attribute> annotations,
2488 ArrayRef<Attribute> layers, bool lowerToBind,
2489 bool doNotPrint, hw::InnerSymAttr innerSym) {
2490 // Gather the result types.
2491 SmallVector<Type> newResultTypes;
2492 SmallVector<Direction> newPortDirections;
2493 SmallVector<Attribute> newPortNames;
2494 SmallVector<Attribute> newPortAnnotations;
2495 SmallVector<Attribute> newDomainInfo;
2496 for (auto &p : ports) {
2497 newResultTypes.push_back(p.type);
2498 newPortDirections.push_back(p.direction);
2499 newPortNames.push_back(p.name);
2500 newPortAnnotations.push_back(p.annotations.getArrayAttr());
2501 if (p.domains)
2502 newDomainInfo.push_back(p.domains);
2503 else
2504 newDomainInfo.push_back(builder.getArrayAttr({}));
2505 }
2506
2507 return build(builder, odsState, newResultTypes, moduleName, name, nameKind,
2508 newPortDirections, newPortNames, newDomainInfo, annotations,
2509 newPortAnnotations, layers, lowerToBind, doNotPrint, innerSym);
2510}
2511
2512LogicalResult InstanceOp::verify() {
2513 // The instance may only be instantiated under its required layers.
2514 auto ambientLayers = getAmbientLayersAt(getOperation());
2515 SmallVector<SymbolRefAttr> missingLayers;
2516 for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2517 if (!isLayerCompatibleWith(layer, ambientLayers))
2518 missingLayers.push_back(layer);
2519
2520 if (missingLayers.empty())
2521 return success();
2522
2523 auto diag =
2524 emitOpError("ambient layers are insufficient to instantiate module");
2525 auto &note = diag.attachNote();
2526 note << "missing layer requirements: ";
2527 interleaveComma(missingLayers, note);
2528 return failure();
2529}
2530
2532 Operation *op1, Operation *op2,
2533 ArrayRef<std::pair<unsigned, PortInfo>> insertions) {
2534 assert(op1 != op2);
2535 size_t n = insertions.size();
2536 size_t inserted = 0;
2537 for (size_t i = 0, e = op1->getNumResults(); i < e; ++i) {
2538 while (inserted < n) {
2539 auto &[index, portInfo] = insertions[inserted];
2540 if (i < index)
2541 break;
2542 ++inserted;
2543 }
2544 auto r1 = op1->getResult(i);
2545 auto r2 = op2->getResult(i + inserted);
2546 r1.replaceAllUsesWith(r2);
2547 }
2548}
2549
2550static void replaceUsesRespectingErasedPorts(Operation *op1, Operation *op2,
2551 const llvm::BitVector &erasures) {
2552 assert(op1 != op2);
2553 size_t erased = 0;
2554 for (size_t i = 0, e = op1->getNumResults(); i < e; ++i) {
2555 auto r1 = op1->getResult(i);
2556 if (erasures[i]) {
2557 assert(r1.use_empty() && "removed instance port has uses");
2558 ++erased;
2559 continue;
2560 }
2561 auto r2 = op2->getResult(i - erased);
2562 r1.replaceAllUsesWith(r2);
2563 }
2564}
2565
2566InstanceOp InstanceOp::cloneWithErasedPorts(const llvm::BitVector &erasures) {
2567 assert(erasures.size() >= getNumResults() &&
2568 "erasures is not at least as large as getNumResults()");
2569
2570 SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2571 SmallVector<Type>(result_type_begin(), result_type_end()), erasures);
2572 SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2573 direction::unpackAttribute(getPortDirectionsAttr()), erasures);
2574 SmallVector<Attribute> newPortNames =
2575 removeElementsAtIndices(getPortNames().getValue(), erasures);
2576 SmallVector<Attribute> newPortAnnotations =
2577 removeElementsAtIndices(getPortAnnotations().getValue(), erasures);
2578 ArrayAttr newDomainInfo =
2579 fixDomainInfoDeletions(getContext(), getDomainInfoAttr(), erasures,
2580 /*supportsEmptyAttr=*/false);
2581
2582 OpBuilder builder(*this);
2583 auto clone = InstanceOp::create(
2584 builder, getLoc(), newResultTypes, getModuleName(), getName(),
2585 getNameKind(), newPortDirections, newPortNames, newDomainInfo.getValue(),
2586 getAnnotations().getValue(), newPortAnnotations, getLayers(),
2587 getLowerToBind(), getDoNotPrint(), getInnerSymAttr());
2588
2589 if (auto outputFile = (*this)->getAttr("output_file"))
2590 clone->setAttr("output_file", outputFile);
2591
2592 return clone;
2593}
2594
2595InstanceOp InstanceOp::cloneWithErasedPortsAndReplaceUses(
2596 const llvm::BitVector &erasures) {
2597 auto clone = cloneWithErasedPorts(erasures);
2598 replaceUsesRespectingErasedPorts(getOperation(), clone, erasures);
2599 return clone;
2600}
2601
2602ArrayAttr InstanceOp::getPortAnnotation(unsigned portIdx) {
2603 assert(portIdx < getNumResults() &&
2604 "index should be smaller than result number");
2605 return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2606}
2607
2608void InstanceOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2609 assert(annotations.size() == getNumResults() &&
2610 "number of annotations is not equal to result number");
2611 (*this)->setAttr("portAnnotations",
2612 ArrayAttr::get(getContext(), annotations));
2613}
2614
2615ArrayAttr InstanceOp::getPortDomain(unsigned portIdx) {
2616 assert(portIdx < getNumResults() &&
2617 "index should be smaller than result number");
2618 return cast<ArrayAttr>(getDomainInfo()[portIdx]);
2619}
2620
/// Create (but do not wire up) a copy of this instance with extra ports.
///
/// Each entry of `insertions` is (old port index, new port info). The new
/// port is placed in front of the old port at that index; entries whose index
/// is at or past the old port count are appended at the end. `insertions` is
/// scanned once, front to back, so it is presumably expected to be sorted by
/// index -- TODO confirm with callers. The clone is built immediately before
/// this op and keeps the instance-level properties plus "output_file".
InstanceOp InstanceOp::cloneWithInsertedPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertions) {
  auto *context = getContext();
  // Used in place of a new port's domain list when none was provided.
  auto empty = ArrayAttr::get(context, {});

  auto oldPortCount = getNumResults();
  auto numInsertions = insertions.size();
  auto newPortCount = oldPortCount + numInsertions;

  SmallVector<Direction> newPortDirections;
  SmallVector<Attribute> newPortNames;
  SmallVector<Type> newPortTypes;
  SmallVector<Attribute> newPortAnnos;
  SmallVector<Attribute> newDomainInfo;

  newPortDirections.reserve(newPortCount);
  newPortNames.reserve(newPortCount);
  newPortTypes.reserve(newPortCount);
  newPortAnnos.reserve(newPortCount);
  newDomainInfo.reserve(newPortCount);

  // indexMap[i] is the new index of old port i; it is filled in as the loop
  // below advances.
  SmallVector<unsigned> indexMap(oldPortCount);

  // Number of new ports emitted so far.
  size_t inserted = 0;
  for (size_t i = 0; i < oldPortCount; ++i) {
    // First emit every new port that goes in front of old port i.
    while (inserted < numInsertions) {
      auto &[index, info] = insertions[inserted];
      if (index > i)
        break;

      // NOTE(review): only indexMap[0..i) is populated at this point; entries
      // for old port i and later are still 0. If a new port's domains can name
      // a later old port, it would be remapped to 0 -- confirm against the
      // contract of fixDomainInfoInsertions.
      auto domains = fixDomainInfoInsertions(
          context, info.domains ? info.domains : empty, indexMap);
      newPortDirections.push_back(info.direction);
      newPortNames.push_back(info.name);
      newPortTypes.push_back(info.type);
      newPortAnnos.push_back(info.annotations.getArrayAttr());
      newDomainInfo.push_back(domains);
      ++inserted;
    }

    // Then copy old port i.
    // NOTE(review): its domain info is copied as-is rather than passed through
    // indexMap; if it holds old port indices, they are not shifted by the
    // insertions -- confirm this is intended.
    newPortDirections.push_back(getPortDirection(i));
    newPortNames.push_back(getPortNameAttr(i));
    newPortTypes.push_back(getType(i));
    newPortAnnos.push_back(getPortAnnotation(i));
    newDomainInfo.push_back(getDomainInfo()[i]);
    indexMap[i] = i + inserted;
  }

  // Append any remaining new ports (index at or past the old port count).
  while (inserted < numInsertions) {
    auto &[index, info] = insertions[inserted];
    auto domains = fixDomainInfoInsertions(
        context, info.domains ? info.domains : empty, indexMap);
    newPortDirections.push_back(info.direction);
    newPortNames.push_back(info.name);
    newPortTypes.push_back(info.type);
    newPortAnnos.push_back(info.annotations.getArrayAttr());
    newDomainInfo.push_back(domains);
    ++inserted;
  }

  OpBuilder builder(*this);
  auto clone = InstanceOp::create(
      builder, getLoc(), newPortTypes, getModuleName(), getName(),
      getNameKind(), newPortDirections, newPortNames, newDomainInfo,
      getAnnotations().getValue(), newPortAnnos, getLayers(), getLowerToBind(),
      getDoNotPrint(), getInnerSymAttr());

  // "output_file" is not a builder argument, so copy it over by hand.
  if (auto outputFile = (*this)->getAttr("output_file"))
    clone->setAttr("output_file", outputFile);

  return clone;
}
2693
2694InstanceOp InstanceOp::cloneWithInsertedPortsAndReplaceUses(
2695 ArrayRef<std::pair<unsigned, PortInfo>> insertions) {
2696 auto clone = cloneWithInsertedPorts(insertions);
2697 replaceUsesRespectingInsertedPorts(getOperation(), clone, insertions);
2698 return clone;
2699}
2700
2701LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2702 return instance_like_impl::verifyReferencedModule(*this, symbolTable,
2703 getModuleNameAttr());
2704}
2705
2706StringRef InstanceOp::getInstanceName() { return getName(); }
2707
2708StringAttr InstanceOp::getInstanceNameAttr() { return getNameAttr(); }
2709
2710void InstanceOp::print(OpAsmPrinter &p) {
2711 // Print the instance name.
2712 p << " ";
2713 p.printKeywordOrString(getName());
2714 if (auto attr = getInnerSymAttr()) {
2715 p << " sym ";
2716 p.printSymbolName(attr.getSymName());
2717 }
2718 if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2719 p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2720
2721 // Print the attr-dict.
2722 SmallVector<StringRef, 10> omittedAttrs = {
2723 "moduleName", "name", "portDirections",
2724 "portNames", "portTypes", "portAnnotations",
2725 "inner_sym", "nameKind", "domainInfo"};
2726 if (getAnnotations().empty())
2727 omittedAttrs.push_back("annotations");
2728 if (getLayers().empty())
2729 omittedAttrs.push_back("layers");
2730 p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2731
2732 // Print the module name.
2733 p << " ";
2734 p.printSymbolName(getModuleName());
2735
2736 // Collect all the result types as TypeAttrs for printing.
2737 SmallVector<Attribute> portTypes;
2738 portTypes.reserve(getNumResults());
2739 llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2740 &TypeAttr::get);
2741 // This needs to be passed domain information.
2742 printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2743 getPortNames().getValue(), portTypes,
2744 getPortAnnotations().getValue(), {}, {},
2745 getDomainInfo().getValue());
2746}
2747
/// Parse the custom form emitted by InstanceOp::print:
///   name [sym @innerSym] [nameKind] attr-dict @Module(ports...)
/// Result types are taken from the parsed port types.
ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
  auto *context = parser.getContext();
  auto &properties = result.getOrAddProperties<Properties>();

  std::string name;
  hw::InnerSymAttr innerSymAttr;
  FlatSymbolRefAttr moduleName;
  // entryArgs, portSyms and portLocs are filled by parseModulePorts but are
  // not stored on the op below.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Direction, 4> portDirections;
  SmallVector<Attribute, 4> portNames;
  SmallVector<Attribute, 4> portTypes;
  SmallVector<Attribute, 4> portAnnotations;
  SmallVector<Attribute, 4> portSyms;
  SmallVector<Attribute, 4> portLocs;
  SmallVector<Attribute, 4> domains;
  NameKindEnumAttr nameKind;

  if (parser.parseKeywordOrString(&name))
    return failure();
  // Optional `sym` clause; the parsed inner symbol goes into result.attributes.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(
            innerSymAttr, ::mlir::Type{},
            result.attributes)) {
      return ::mlir::failure();
    }
  }
  if (parseNameKind(parser, nameKind) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(moduleName) ||
      parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
                       /*supportsSymbols=*/false, /*supportsDomains=*/true,
                       entryArgs, portDirections, portNames, portTypes,
                       portAnnotations, portSyms, portLocs, domains))
    return failure();

  // Add the attributes. We let attributes defined in the attr-dict override
  // attributes parsed out of the module signature.

  properties.setModuleName(moduleName);
  properties.setName(StringAttr::get(context, name));
  properties.setNameKind(nameKind);
  properties.setPortDirections(
      direction::packAttribute(context, portDirections));
  properties.setPortNames(ArrayAttr::get(context, portNames));
  properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));

  // Annotations and layers are elided from the printed form when empty, so
  // default them to empty arrays here. Flags such as LowerToBind are not set
  // here and are presumably supplied only through the attr-dict.
  properties.setAnnotations(parser.getBuilder().getArrayAttr({}));
  properties.setLayers(parser.getBuilder().getArrayAttr({}));

  // Add domain information.
  properties.setDomainInfo(ArrayAttr::get(context, domains));

  // Add result types: one result per parsed port, unwrapped from TypeAttr.
  result.types.reserve(portTypes.size());
  llvm::transform(
      portTypes, std::back_inserter(result.types),
      [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

  return success();
}
2811
2812void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
2813 StringRef base = getName();
2814 if (base.empty())
2815 base = "inst";
2816
2817 for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
2818 setNameFn(getResult(i), (base + "_" + getPortName(i)).str());
2819 }
2820}
2821
2822std::optional<size_t> InstanceOp::getTargetResultIndex() {
2823 // Inner symbols on instance operations target the op not any result.
2824 return std::nullopt;
2825}
2826
2827// -----------------------------------------------------------------------------
2828// InstanceChoiceOp
2829// -----------------------------------------------------------------------------
2830
2831void InstanceChoiceOp::build(
2832 OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
2833 ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
2834 NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2835 ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2836 // Gather the result types.
2837 SmallVector<Type> resultTypes;
2838 for (Attribute portType : defaultModule.getPortTypes())
2839 resultTypes.push_back(cast<TypeAttr>(portType).getValue());
2840
2841 // Create the port annotations.
2842 ArrayAttr portAnnotationsAttr;
2843 if (portAnnotations.empty()) {
2844 portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2845 resultTypes.size(), builder.getArrayAttr({})));
2846 } else {
2847 portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2848 }
2849
2850 // Create the domain info attribute.
2851 ArrayAttr domainInfoAttr = defaultModule.getDomainInfoAttr();
2852 if (domainInfoAttr.empty()) {
2853 domainInfoAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2854 resultTypes.size(), builder.getArrayAttr({})));
2855 }
2856
2857 // Gather the module & case names.
2858 SmallVector<Attribute> moduleNames, caseNames;
2859 moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
2860 for (auto [caseOption, caseModule] : cases) {
2861 auto caseGroup = caseOption->getParentOfType<OptionOp>();
2862 caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
2863 {SymbolRefAttr::get(caseOption)}));
2864 moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
2865 }
2866
2867 return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
2868 builder.getArrayAttr(caseNames), builder.getStringAttr(name),
2869 NameKindEnumAttr::get(builder.getContext(), nameKind),
2870 defaultModule.getPortDirectionsAttr(),
2871 defaultModule.getPortNamesAttr(), domainInfoAttr,
2872 builder.getArrayAttr(annotations), portAnnotationsAttr,
2873 defaultModule.getLayersAttr(),
2874 innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2875}
2876
2877std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
2878 return std::nullopt;
2879}
2880
2881void InstanceChoiceOp::print(OpAsmPrinter &p) {
2882 // Print the instance name.
2883 p << " ";
2884 p.printKeywordOrString(getName());
2885 if (auto attr = getInnerSymAttr()) {
2886 p << " sym ";
2887 p.printSymbolName(attr.getSymName());
2888 }
2889 if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2890 p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2891
2892 // Print the attr-dict.
2893 SmallVector<StringRef, 11> omittedAttrs = {
2894 "moduleNames", "caseNames", "name",
2895 "portDirections", "portNames", "portTypes",
2896 "portAnnotations", "inner_sym", "nameKind",
2897 "domainInfo"};
2898 if (getAnnotations().empty())
2899 omittedAttrs.push_back("annotations");
2900 if (getLayers().empty())
2901 omittedAttrs.push_back("layers");
2902 p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2903
2904 // Print the module name.
2905 p << ' ';
2906
2907 auto moduleNames = getModuleNamesAttr();
2908 auto caseNames = getCaseNamesAttr();
2909
2910 p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[0]).getValue());
2911
2912 p << " alternatives ";
2913 p.printSymbolName(
2914 cast<SymbolRefAttr>(caseNames[0]).getRootReference().getValue());
2915 p << " { ";
2916 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2917 if (i != 0)
2918 p << ", ";
2919
2920 auto symbol = cast<SymbolRefAttr>(caseNames[i]);
2921 p.printSymbolName(symbol.getNestedReferences()[0].getValue());
2922 p << " -> ";
2923 p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[i + 1]).getValue());
2924 }
2925
2926 p << " } ";
2927
2928 // Collect all the result types as TypeAttrs for printing.
2929 SmallVector<Attribute> portTypes;
2930 portTypes.reserve(getNumResults());
2931 llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2932 &TypeAttr::get);
2933 printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2934 getPortNames().getValue(), portTypes,
2935 getPortAnnotations().getValue(), {}, {},
2936 getDomainInfo().getValue());
2937}
2938
/// Parse the custom form emitted by InstanceChoiceOp::print:
///   name [sym @innerSym] [nameKind] attr-dict @Default
///     alternatives @Option { @Case -> @Module, ... } (ports...)
/// Result types are taken from the parsed port types.
ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto *context = parser.getContext();
  auto &properties = result.getOrAddProperties<Properties>();

  std::string name;
  hw::InnerSymAttr innerSymAttr;
  // moduleNames[0] is the default target; moduleNames[i + 1] pairs with
  // caseNames[i].
  SmallVector<Attribute> moduleNames;
  SmallVector<Attribute> caseNames;
  // entryArgs, portSyms and portLocs are filled by parseModulePorts but are
  // not stored on the op below.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Direction, 4> portDirections;
  SmallVector<Attribute, 4> portNames;
  SmallVector<Attribute, 4> portTypes;
  SmallVector<Attribute, 4> portAnnotations;
  SmallVector<Attribute, 4> portSyms;
  SmallVector<Attribute, 4> portLocs;
  SmallVector<Attribute, 4> domains;
  NameKindEnumAttr nameKind;

  if (parser.parseKeywordOrString(&name))
    return failure();
  // Optional `sym` clause; the parsed inner symbol goes into result.attributes.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(
            innerSymAttr, Type{},
            result.attributes)) {
      return failure();
    }
  }
  if (parseNameKind(parser, nameKind) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  FlatSymbolRefAttr defaultModuleName;
  if (parser.parseAttribute(defaultModuleName))
    return failure();
  moduleNames.push_back(defaultModuleName);

  // alternatives @opt { @case -> @target, ... }
  // Each case is stored as the nested reference @opt::@case.
  {
    FlatSymbolRefAttr optionName;
    if (parser.parseKeyword("alternatives") ||
        parser.parseAttribute(optionName) || parser.parseLBrace())
      return failure();

    FlatSymbolRefAttr moduleName;
    StringAttr caseName;
    while (succeeded(parser.parseOptionalSymbolName(caseName))) {
      if (parser.parseArrow() || parser.parseAttribute(moduleName))
        return failure();
      moduleNames.push_back(moduleName);
      caseNames.push_back(SymbolRefAttr::get(
          optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
      if (failed(parser.parseOptionalComma()))
        break;
    }
    if (parser.parseRBrace())
      return failure();
  }

  if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
                       /*supportsSymbols=*/false, /*supportsDomains=*/true,
                       entryArgs, portDirections, portNames, portTypes,
                       portAnnotations, portSyms, portLocs, domains))
    return failure();

  // Add the attributes. We let attributes defined in the attr-dict override
  // attributes parsed out of the module signature.
  properties.setModuleNames(ArrayAttr::get(context, moduleNames));
  properties.setCaseNames(ArrayAttr::get(context, caseNames));
  properties.setName(StringAttr::get(context, name));
  properties.setNameKind(nameKind);
  properties.setPortDirections(
      direction::packAttribute(context, portDirections));
  properties.setPortNames(ArrayAttr::get(context, portNames));
  properties.setDomainInfo(ArrayAttr::get(context, domains));
  properties.setPortAnnotations(ArrayAttr::get(context, portAnnotations));

  // Annotations and layers are elided from the printed form when empty, so
  // default them to empty arrays here.
  properties.setAnnotations(parser.getBuilder().getArrayAttr({}));
  properties.setLayers(parser.getBuilder().getArrayAttr({}));

  // Add result types: one result per parsed port, unwrapped from TypeAttr.
  result.types.reserve(portTypes.size());
  llvm::transform(
      portTypes, std::back_inserter(result.types),
      [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

  return success();
}
3030
3031void InstanceChoiceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3032 StringRef base = getName().empty() ? "inst" : getName();
3033 for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
3034 setNameFn(result, (base + "_" + cast<StringAttr>(name).getValue()).str());
3035}
3036
3037LogicalResult InstanceChoiceOp::verify() {
3038 if (getCaseNamesAttr().empty())
3039 return emitOpError() << "must have at least one case";
3040 if (getModuleNamesAttr().size() != getCaseNamesAttr().size() + 1)
3041 return emitOpError() << "number of referenced modules does not match the "
3042 "number of options";
3043
3044 // The modules may only be instantiated under their required layers (which
3045 // are the same for all modules).
3046 auto ambientLayers = getAmbientLayersAt(getOperation());
3047 SmallVector<SymbolRefAttr> missingLayers;
3048 for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
3049 if (!isLayerCompatibleWith(layer, ambientLayers))
3050 missingLayers.push_back(layer);
3051
3052 if (missingLayers.empty())
3053 return success();
3054
3055 auto diag =
3056 emitOpError("ambient layers are insufficient to instantiate module");
3057 auto &note = diag.attachNote();
3058 note << "missing layer requirements: ";
3059 interleaveComma(missingLayers, note);
3060 return failure();
3061}
3062
/// Verify the symbols referenced by this instance choice.
/// - Each referenced module name is checked.
/// - Every case must be in the same option group as the first case.
/// - That option must exist, and it must contain the named case.
LogicalResult
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto caseNames = getCaseNamesAttr();
  for (auto moduleName : getModuleNamesAttr()) {
            *this, symbolTable, cast<FlatSymbolRefAttr>(moduleName))))
      return failure();
  }

  // Option group of the first case; all other cases must share it.
  auto root = cast<SymbolRefAttr>(caseNames[0]).getRootReference();
  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
    // Each case reference has the form @Option::@Case.
    auto ref = cast<SymbolRefAttr>(caseNames[i]);
    auto refRoot = ref.getRootReference();
    if (ref.getRootReference() != root)
      return emitOpError() << "case " << ref
                           << " is not in the same option group as "
                           << caseNames[0];

    if (!symbolTable.lookupNearestSymbolFrom<OptionOp>(*this, refRoot))
      return emitOpError() << "option " << refRoot << " does not exist";

    if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
      return emitOpError() << "option " << refRoot
                           << " does not contain option case " << ref;
  }

  return success();
}
3091
3092FlatSymbolRefAttr
3093InstanceChoiceOp::getTargetOrDefaultAttr(OptionCaseOp option) {
3094 auto caseNames = getCaseNamesAttr();
3095 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
3096 StringAttr caseSym = cast<SymbolRefAttr>(caseNames[i]).getLeafReference();
3097 if (caseSym == option.getSymName())
3098 return cast<FlatSymbolRefAttr>(getModuleNamesAttr()[i + 1]);
3099 }
3100 return getDefaultTargetAttr();
3101}
3102
3103SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1>
3104InstanceChoiceOp::getTargetChoices() {
3105 auto caseNames = getCaseNamesAttr();
3106 auto moduleNames = getModuleNamesAttr();
3107 SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1> choices;
3108 for (size_t i = 0; i < caseNames.size(); ++i) {
3109 choices.emplace_back(cast<SymbolRefAttr>(caseNames[i]),
3110 cast<FlatSymbolRefAttr>(moduleNames[i + 1]));
3111 }
3112
3113 return choices;
3114}
3115
/// Create (but do not wire up) a copy of this instance choice with extra
/// ports.
///
/// Each entry of `insertions` is (old port index, new port info). The new
/// port is placed in front of the old port at that index; entries whose index
/// is at or past the old port count are appended at the end. `insertions` is
/// scanned once, front to back, so it is presumably expected to be sorted by
/// index -- TODO confirm with callers. The clone is built immediately before
/// this op, keeps the module/case selections and other op-level properties,
/// and also copies "output_file".
InstanceChoiceOp InstanceChoiceOp::cloneWithInsertedPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertions) {
  auto *context = getContext();
  // Used in place of a new port's domain list when none was provided.
  auto empty = ArrayAttr::get(context, {});

  auto oldPortCount = getNumResults();
  auto numInsertions = insertions.size();
  auto newPortCount = oldPortCount + numInsertions;

  SmallVector<Direction> newPortDirections;
  SmallVector<Attribute> newPortNames;
  SmallVector<Type> newPortTypes;
  SmallVector<Attribute> newPortAnnos;
  SmallVector<Attribute> newDomainInfo;

  newPortDirections.reserve(newPortCount);
  newPortNames.reserve(newPortCount);
  newPortTypes.reserve(newPortCount);
  newPortAnnos.reserve(newPortCount);
  newDomainInfo.reserve(newPortCount);

  // indexMap[i] is the new index of old port i; it is filled in as the loop
  // below advances.
  SmallVector<unsigned> indexMap(oldPortCount);

  // Number of new ports emitted so far.
  size_t inserted = 0;
  for (size_t i = 0; i < oldPortCount; ++i) {
    // First emit every new port that goes in front of old port i.
    while (inserted < numInsertions) {
      auto &[index, info] = insertions[inserted];
      if (i < index)
        break;

      // NOTE(review): only indexMap[0..i) is populated at this point; entries
      // for old port i and later are still 0. If a new port's domains can name
      // a later old port, it would be remapped to 0 -- confirm against the
      // contract of fixDomainInfoInsertions.
      auto domains = fixDomainInfoInsertions(
          context, info.domains ? info.domains : empty, indexMap);
      newPortDirections.push_back(info.direction);
      newPortNames.push_back(info.name);
      newPortTypes.push_back(info.type);
      newPortAnnos.push_back(info.annotations.getArrayAttr());
      newDomainInfo.push_back(domains);
      ++inserted;
    }

    // Then copy old port i.
    // NOTE(review): its domain info is copied as-is rather than passed through
    // indexMap; if it holds old port indices, they are not shifted by the
    // insertions -- confirm this is intended.
    newPortDirections.push_back(getPortDirection(i));
    newPortNames.push_back(getPortNameAttr(i));
    newPortTypes.push_back(getType(i));
    newPortAnnos.push_back(getPortAnnotations()[i]);
    newDomainInfo.push_back(getDomainInfo()[i]);
    indexMap[i] = i + inserted;
  }

  // Append any remaining new ports (index at or past the old port count).
  while (inserted < numInsertions) {
    auto &[index, info] = insertions[inserted];
    auto domains = fixDomainInfoInsertions(
        context, info.domains ? info.domains : empty, indexMap);
    newPortDirections.push_back(info.direction);
    newPortNames.push_back(info.name);
    newPortTypes.push_back(info.type);
    newPortAnnos.push_back(info.annotations.getArrayAttr());
    newDomainInfo.push_back(domains);
    ++inserted;
  }

  OpBuilder builder(*this);
  auto clone = InstanceChoiceOp::create(
      builder, getLoc(), newPortTypes, getModuleNames(), getCaseNames(),
      getName(), getNameKind(),
      direction::packAttribute(context, newPortDirections),
      ArrayAttr::get(context, newPortNames),
      ArrayAttr::get(context, newDomainInfo), getAnnotationsAttr(),
      ArrayAttr::get(context, newPortAnnos), getLayers(), getInnerSymAttr());

  // "output_file" is not a builder argument, so copy it over by hand.
  if (auto outputFile = (*this)->getAttr("output_file"))
    clone->setAttr("output_file", outputFile);

  return clone;
}
3190
3191InstanceChoiceOp InstanceChoiceOp::cloneWithInsertedPortsAndReplaceUses(
3192 ArrayRef<std::pair<unsigned, PortInfo>> insertions) {
3193 auto clone = cloneWithInsertedPorts(insertions);
3194 replaceUsesRespectingInsertedPorts(getOperation(), clone, insertions);
3195 return clone;
3196}
3197
3198InstanceChoiceOp
3199InstanceChoiceOp::cloneWithErasedPorts(const llvm::BitVector &erasures) {
3200 assert(erasures.size() >= getNumResults() &&
3201 "erasures is not at least as large as getNumResults()");
3202
3203 SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
3204 SmallVector<Type>(result_type_begin(), result_type_end()), erasures);
3205 SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
3206 direction::unpackAttribute(getPortDirectionsAttr()), erasures);
3207 SmallVector<Attribute> newPortNames =
3208 removeElementsAtIndices(getPortNames().getValue(), erasures);
3209 SmallVector<Attribute> newPortAnnotations =
3210 removeElementsAtIndices(getPortAnnotations().getValue(), erasures);
3211 ArrayAttr newPortDomains =
3212 fixDomainInfoDeletions(getContext(), getDomainInfoAttr(), erasures,
3213 /*supportsEmptyAttr=*/false);
3214
3215 OpBuilder builder(*this);
3216 auto clone = InstanceChoiceOp::create(
3217 builder, getLoc(), newResultTypes, getModuleNames(), getCaseNames(),
3218 getName(), getNameKind(),
3219 direction::packAttribute(getContext(), newPortDirections),
3220 ArrayAttr::get(getContext(), newPortNames), newPortDomains,
3221 getAnnotationsAttr(), ArrayAttr::get(getContext(), newPortAnnotations),
3222 getLayers(), getInnerSymAttr());
3223
3224 if (auto outputFile = (*this)->getAttr("output_file"))
3225 clone->setAttr("output_file", outputFile);
3226
3227 return clone;
3228}
3229
3230InstanceChoiceOp InstanceChoiceOp::cloneWithErasedPortsAndReplaceUses(
3231 const llvm::BitVector &erasures) {
3232 auto clone = cloneWithErasedPorts(erasures);
3233 replaceUsesRespectingErasedPorts(getOperation(), clone, erasures);
3234 return clone;
3235}
3236
3237//===----------------------------------------------------------------------===//
3238// MemOp
3239//===----------------------------------------------------------------------===//
3240
3241ArrayAttr MemOp::getPortAnnotation(unsigned portIdx) {
3242 assert(portIdx < getNumResults() &&
3243 "index should be smaller than result number");
3244 return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
3245}
3246
3247void MemOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
3248 assert(annotations.size() == getNumResults() &&
3249 "number of annotations is not equal to result number");
3250 (*this)->setAttr("portAnnotations",
3251 ArrayAttr::get(getContext(), annotations));
3252}
3253
3254// Get the number of read, write and read-write ports.
3255void MemOp::getNumPorts(size_t &numReadPorts, size_t &numWritePorts,
3256 size_t &numReadWritePorts, size_t &numDbgsPorts) {
3257 numReadPorts = 0;
3258 numWritePorts = 0;
3259 numReadWritePorts = 0;
3260 numDbgsPorts = 0;
3261 for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3262 auto portKind = getPortKind(i);
3263 if (portKind == MemOp::PortKind::Debug)
3264 ++numDbgsPorts;
3265 else if (portKind == MemOp::PortKind::Read)
3266 ++numReadPorts;
3267 else if (portKind == MemOp::PortKind::Write) {
3268 ++numWritePorts;
3269 } else
3270 ++numReadWritePorts;
3271 }
3272}
3273
/// Verify the correctness of a MemOp.
///
/// Per port, the checks are:
/// - port names are unique;
/// - debug ports are RefTypes of FVectorType;
/// - non-debug ports have a data field;
/// - the data type is passive and free of analog types;
/// - the port type exactly matches the type expected for its kind;
/// - every port uses the same data type.
///
/// Op-wide, the checks are:
/// - the mask is no wider than the data (when the data width is known);
/// - there is one port-annotation entry per result.
LogicalResult MemOp::verify() {

  // Store the port names as we find them. This lets us check quickly
  // for uniqueness.
  llvm::SmallDenseSet<Attribute, 8> portNamesSet;

  // Store the previous data type. This lets us check that the data
  // type is consistent across all ports.
  FIRRTLType oldDataType;

  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
    auto portName = getPortNameAttr(i);

    // The port's bundle type, or null when the result is not a bundle (e.g. a
    // debug port's RefType). It is only dereferenced on the non-debug path
    // below.
    BundleType portBundleType =
        type_dyn_cast<BundleType>(getResult(i).getType());

    // Require that all port names are unique.
    if (!portNamesSet.insert(portName).second) {
      emitOpError() << "has non-unique port name " << portName;
      return failure();
    }

    // Determine the kind of the memory port from its type. A type with the
    // wrong set of fields is caught by the expected-type comparison below.

    auto elt = getPortNamed(portName);
    if (!elt) {
      emitOpError() << "could not get port with name " << portName;
      return failure();
    }
    auto firrtlType = type_cast<FIRRTLType>(elt.getType());
    MemOp::PortKind portKind = getMemPortKindFromType(firrtlType);

    if (portKind == MemOp::PortKind::Debug &&
        !type_isa<RefType>(getResult(i).getType()))
      return emitOpError() << "has an invalid type on port " << portName
                           << " (expected Read/Write/ReadWrite/Debug)";
    // A memory consisting solely of a single debug port is rejected.
    if (type_isa<RefType>(firrtlType) && e == 1)
      return emitOpError()
             << "cannot have only one port of debug type. Debug port can only "
                "exist alongside other read/write/read-write port";

    // Safely search for the "data" field, erroring if it can't be
    // found.
    FIRRTLBaseType dataType;
    if (portKind == MemOp::PortKind::Debug) {
      // Debug ports expose the whole memory as a ref to a vector; the data
      // type is the vector's element type.
      auto resType = type_cast<RefType>(getResult(i).getType());
      if (!(resType && type_isa<FVectorType>(resType.getType())))
        return emitOpError() << "debug ports must be a RefType of FVectorType";
      dataType = type_cast<FVectorType>(resType.getType()).getElementType();
    } else {
      // Read/write ports carry "data"; read-write ports fall back to "wdata".
      // NOTE(review): the diagnostic below mentions "rdata" for read/write
      // ports, but the field actually looked up is "wdata" -- consider
      // aligning the message (check lit tests that match on it first).
      auto dataTypeOption = portBundleType.getElement("data");
      if (!dataTypeOption && portKind == MemOp::PortKind::ReadWrite)
        dataTypeOption = portBundleType.getElement("wdata");
      if (!dataTypeOption) {
        emitOpError() << "has no data field on port " << portName
                      << " (expected to see \"data\" for a read or write "
                         "port or \"rdata\" for a read/write port)";
        return failure();
      }
      dataType = dataTypeOption->type;
      // Read data is expected to be a flip.
      if (portKind == MemOp::PortKind::Read) {
        // FIXME error on missing bundle flip
      }
    }

    // Error if the data type isn't passive.
    if (!dataType.isPassive()) {
      emitOpError() << "has non-passive data type on port " << portName
                    << " (memory types must be passive)";
      return failure();
    }

    // Error if the data type contains analog types.
    if (dataType.containsAnalog()) {
      emitOpError() << "has a data type that contains an analog type on port "
                    << portName
                    << " (memory types cannot contain analog types)";
      return failure();
    }

    // Check that the port type matches the kind that we determined
    // for this port. This catches situations of extraneous port
    // fields being included or the fields being named incorrectly.
    // The op's mask width is used only for ground data types; passing 0 makes
    // getTypeForPort derive the mask type from the data type instead.
    FIRRTLType expectedType =
        getTypeForPort(getDepth(), dataType, portKind,
                       dataType.isGround() ? getMaskBits() : 0);
    // Compare against the full result type of the port.
    auto originalType = getResult(i).getType();
    if (originalType != expectedType) {
      StringRef portKindName;
      switch (portKind) {
      case MemOp::PortKind::Read:
        portKindName = "read";
        break;
      case MemOp::PortKind::Write:
        portKindName = "write";
        break;
      case MemOp::PortKind::ReadWrite:
        portKindName = "readwrite";
        break;
      case MemOp::PortKind::Debug:
        portKindName = "dbg";
        break;
      }
      emitOpError() << "has an invalid type for port " << portName
                    << " of determined kind \"" << portKindName
                    << "\" (expected " << expectedType << ", but got "
                    << originalType << ")";
      return failure();
    }

    // Error if the type of the current port was not the same as the
    // last port, but skip checking the first port.
    if (oldDataType && oldDataType != dataType) {
      emitOpError() << "port " << getPortNameAttr(i)
                    << " has a different type than port "
                    << getPortNameAttr(i - 1) << " (expected " << oldDataType
                    << ", but got " << dataType << ")";
      return failure();
    }

    oldDataType = dataType;
  }

  auto maskWidth = getMaskBits();

  // A non-positive width (the sentinel) means the width is not known, and the
  // mask check is skipped.
  auto dataWidth = getDataType().getBitWidthOrSentinel();
  if (dataWidth > 0 && maskWidth > (size_t)dataWidth)
    return emitOpError("the mask width cannot be greater than "
                       "data width");

  if (getPortAnnotations().size() != getNumResults())
    return emitOpError("the number of result annotations should be "
                       "equal to the number of results");

  return success();
}
3419
3420static size_t getAddressWidth(size_t depth) {
3421 return std::max(1U, llvm::Log2_64_Ceil(depth));
3422}
3423
3424size_t MemOp::getAddrBits() { return getAddressWidth(getDepth()); }
3425
3426FIRRTLType MemOp::getTypeForPort(uint64_t depth, FIRRTLBaseType dataType,
3427 PortKind portKind, size_t maskBits) {
3428
3429 auto *context = dataType.getContext();
3430 if (portKind == PortKind::Debug)
3431 return RefType::get(FVectorType::get(dataType, depth));
3432 FIRRTLBaseType maskType;
3433 // maskBits not specified (==0), then get the mask type from the dataType.
3434 if (maskBits == 0)
3435 maskType = dataType.getMaskType();
3436 else
3437 maskType = UIntType::get(context, maskBits);
3438
3439 auto getId = [&](StringRef name) -> StringAttr {
3440 return StringAttr::get(context, name);
3441 };
3442
3443 SmallVector<BundleType::BundleElement, 7> portFields;
3444
3445 auto addressType = UIntType::get(context, getAddressWidth(depth));
3446
3447 portFields.push_back({getId("addr"), false, addressType});
3448 portFields.push_back({getId("en"), false, UIntType::get(context, 1)});
3449 portFields.push_back({getId("clk"), false, ClockType::get(context)});
3450
3451 switch (portKind) {
3452 case PortKind::Read:
3453 portFields.push_back({getId("data"), true, dataType});
3454 break;
3455
3456 case PortKind::Write:
3457 portFields.push_back({getId("data"), false, dataType});
3458 portFields.push_back({getId("mask"), false, maskType});
3459 break;
3460
3461 case PortKind::ReadWrite:
3462 portFields.push_back({getId("rdata"), true, dataType});
3463 portFields.push_back({getId("wmode"), false, UIntType::get(context, 1)});
3464 portFields.push_back({getId("wdata"), false, dataType});
3465 portFields.push_back({getId("wmask"), false, maskType});
3466 break;
3467 default:
3468 llvm::report_fatal_error("memory port kind not handled");
3469 break;
3470 }
3471
3472 return BundleType::get(context, portFields);
3473}
3474
3475/// Return the name and kind of ports supported by this memory.
3476SmallVector<MemOp::NamedPort> MemOp::getPorts() {
3477 SmallVector<MemOp::NamedPort> result;
3478 // Each entry in the bundle is a port.
3479 for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3480 // Each port is a bundle.
3481 auto portType = type_cast<FIRRTLType>(getResult(i).getType());
3482 result.push_back({getPortNameAttr(i), getMemPortKindFromType(portType)});
3483 }
3484 return result;
3485}
3486
3487/// Return the kind of the specified port.
3488MemOp::PortKind MemOp::getPortKind(StringRef portName) {
3490 type_cast<FIRRTLType>(getPortNamed(portName).getType()));
3491}
3492
3493/// Return the kind of the specified port number.
3494MemOp::PortKind MemOp::getPortKind(size_t resultNo) {
3496 type_cast<FIRRTLType>(getResult(resultNo).getType()));
3497}
3498
3499/// Return the number of bits in the mask for the memory.
3500size_t MemOp::getMaskBits() {
3501
3502 for (auto res : getResults()) {
3503 if (type_isa<RefType>(res.getType()))
3504 continue;
3505 auto firstPortType = type_cast<FIRRTLBaseType>(res.getType());
3506 if (getMemPortKindFromType(firstPortType) == PortKind::Read ||
3507 getMemPortKindFromType(firstPortType) == PortKind::Debug)
3508 continue;
3509
3510 FIRRTLBaseType mType;
3511 for (auto t : type_cast<BundleType>(firstPortType.getPassiveType())) {
3512 if (t.name.getValue().contains("mask"))
3513 mType = t.type;
3514 }
3515 if (type_isa<UIntType>(mType))
3516 return mType.getBitWidthOrSentinel();
3517 }
3518 // Mask of zero bits means, either there are no write/readwrite ports or the
3519 // mask is of aggregate type.
3520 return 0;
3521}
3522
3523/// Return the data-type field of the memory, the type of each element.
3524FIRRTLBaseType MemOp::getDataType() {
3525 assert(getNumResults() != 0 && "Mems with no read/write ports are illegal");
3526
3527 if (auto refType = type_dyn_cast<RefType>(getResult(0).getType()))
3528 return type_cast<FVectorType>(refType.getType()).getElementType();
3529 auto firstPortType = type_cast<FIRRTLBaseType>(getResult(0).getType());
3530
3531 StringRef dataFieldName = "data";
3532 if (getMemPortKindFromType(firstPortType) == PortKind::ReadWrite)
3533 dataFieldName = "rdata";
3534
3535 return type_cast<BundleType>(firstPortType.getPassiveType())
3536 .getElementType(dataFieldName);
3537}
3538
3539StringAttr MemOp::getPortNameAttr(size_t resultNo) {
3540 return cast<StringAttr>(getPortNames()[resultNo]);
3541}
3542
3543FIRRTLBaseType MemOp::getPortType(size_t resultNo) {
3544 return type_cast<FIRRTLBaseType>(getResults()[resultNo].getType());
3545}
3546
3547Value MemOp::getPortNamed(StringAttr name) {
3548 auto namesArray = getPortNames();
3549 for (size_t i = 0, e = namesArray.size(); i != e; ++i) {
3550 if (namesArray[i] == name) {
3551 assert(i < getNumResults() && " names array out of sync with results");
3552 return getResult(i);
3553 }
3554 }
3555 return Value();
3556}
3557
3558// Extract all the relevant attributes from the MemOp and return the FirMemory.
3559FirMemory MemOp::getSummary() {
3560 auto op = *this;
3561 size_t numReadPorts = 0;
3562 size_t numWritePorts = 0;
3563 size_t numReadWritePorts = 0;
3565 SmallVector<int32_t> writeClockIDs;
3566
3567 for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
3568 auto portKind = op.getPortKind(i);
3569 if (portKind == MemOp::PortKind::Read)
3570 ++numReadPorts;
3571 else if (portKind == MemOp::PortKind::Write) {
3572 for (auto *a : op.getResult(i).getUsers()) {
3573 auto subfield = dyn_cast<SubfieldOp>(a);
3574 if (!subfield || subfield.getFieldIndex() != 2)
3575 continue;
3576 auto clockPort = a->getResult(0);
3577 for (auto *b : clockPort.getUsers()) {
3578 if (auto connect = dyn_cast<FConnectLike>(b)) {
3579 if (connect.getDest() == clockPort) {
3580 auto result =
3581 clockToLeader.insert({circt::firrtl::getModuleScopedDriver(
3582 connect.getSrc(), true, true, true),
3583 numWritePorts});
3584 if (result.second) {
3585 writeClockIDs.push_back(numWritePorts);
3586 } else {
3587 writeClockIDs.push_back(result.first->second);
3588 }
3589 }
3590 }
3591 }
3592 break;
3593 }
3594 ++numWritePorts;
3595 } else
3596 ++numReadWritePorts;
3597 }
3598
3599 size_t width = 0;
3600 if (auto widthV = getBitWidth(op.getDataType()))
3601 width = *widthV;
3602 else
3603 op.emitError("'firrtl.mem' should have simple type and known width");
3604 MemoryInitAttr init = op->getAttrOfType<MemoryInitAttr>("init");
3605 StringAttr modName;
3606 if (op->hasAttr("modName"))
3607 modName = op->getAttrOfType<StringAttr>("modName");
3608 else {
3609 SmallString<8> clocks;
3610 for (auto a : writeClockIDs)
3611 clocks.append(Twine((char)(a + 'a')).str());
3612 SmallString<32> initStr;
3613 // If there is a file initialization, then come up with a decent
3614 // representation for this. Use the filename, but only characters
3615 // [a-zA-Z0-9] and the bool/hex and inline booleans.
3616 if (init) {
3617 for (auto c : init.getFilename().getValue())
3618 if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
3619 (c >= '0' && c <= '9'))
3620 initStr.push_back(c);
3621 initStr.push_back('_');
3622 initStr.push_back(init.getIsBinary() ? 't' : 'f');
3623 initStr.push_back('_');
3624 initStr.push_back(init.getIsInline() ? 't' : 'f');
3625 }
3626 modName = StringAttr::get(
3627 op->getContext(),
3628 llvm::formatv(
3629 "{0}FIRRTLMem_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}{11}{12}",
3630 op.getPrefix().value_or(""), numReadPorts, numWritePorts,
3631 numReadWritePorts, (size_t)width, op.getDepth(),
3632 op.getReadLatency(), op.getWriteLatency(), op.getMaskBits(),
3633 (unsigned)op.getRuw(), (unsigned)seq::WUW::PortOrder,
3634 clocks.empty() ? "" : "_" + clocks, init ? initStr.str() : ""));
3635 }
3636 return {numReadPorts,
3637 numWritePorts,
3638 numReadWritePorts,
3639 (size_t)width,
3640 op.getDepth(),
3641 op.getReadLatency(),
3642 op.getWriteLatency(),
3643 op.getMaskBits(),
3644 *seq::symbolizeRUW(unsigned(op.getRuw())),
3645 seq::WUW::PortOrder,
3646 writeClockIDs,
3647 modName,
3648 op.getMaskBits() > 1,
3649 init,
3650 op.getPrefixAttr(),
3651 op.getLoc()};
3652}
3653
3654void MemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3655 StringRef base = getName();
3656 if (base.empty())
3657 base = "mem";
3658
3659 for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
3660 setNameFn(getResult(i), (base + "_" + getPortName(i)).str());
3661 }
3662}
3663
3664std::optional<size_t> MemOp::getTargetResultIndex() {
3665 // Inner symbols on memory operations target the op not any result.
3666 return std::nullopt;
3667}
3668
3669// Construct name of the module which will be used for the memory definition.
3670StringAttr FirMemory::getFirMemoryName() const { return modName; }
3671
3672/// Helper for naming forceable declarations (and their optional ref result).
3673static void forceableAsmResultNames(Forceable op, StringRef name,
3674 OpAsmSetValueNameFn setNameFn) {
3675 if (name.empty())
3676 return;
3677 setNameFn(op.getDataRaw(), name);
3678 if (op.isForceable())
3679 setNameFn(op.getDataRef(), (name + "_ref").str());
3680}
3681
3682void NodeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3683 return forceableAsmResultNames(*this, getName(), setNameFn);
3684}
3685
3686LogicalResult NodeOp::inferReturnTypes(
3687 mlir::MLIRContext *context, std::optional<mlir::Location> location,
3688 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
3689 ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
3690 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
3691 if (operands.empty())
3692 return failure();
3693 Adaptor adaptor(operands, attributes, properties, regions);
3694 inferredReturnTypes.push_back(adaptor.getInput().getType());
3695 if (adaptor.getForceable()) {
3696 auto forceableType = firrtl::detail::getForceableResultType(
3697 true, adaptor.getInput().getType());
3698 if (!forceableType) {
3699 if (location)
3700 ::mlir::emitError(*location, "cannot force a node of type ")
3701 << operands[0].getType();
3702 return failure();
3703 }
3704 inferredReturnTypes.push_back(forceableType);
3705 }
3706 return success();
3707}
3708
3709std::optional<size_t> NodeOp::getTargetResultIndex() { return 0; }
3710
3711void RegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3712 return forceableAsmResultNames(*this, getName(), setNameFn);
3713}
3714
3715std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
3716
3717SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3718RegOp::computeDataFlow() {
3719 // A register does't have any combinational dataflow.
3720 return {};
3721}
3722
3723LogicalResult RegResetOp::verify() {
3724 auto reset = getResetValue();
3725
3726 FIRRTLBaseType resetType = reset.getType();
3727 FIRRTLBaseType regType = getResult().getType();
3728
3729 // The type of the initialiser must be equivalent to the register type.
3730 if (!areTypesEquivalent(regType, resetType))
3731 return emitError("type mismatch between register ")
3732 << regType << " and reset value " << resetType;
3733
3734 return success();
3735}
3736
3737std::optional<size_t> RegResetOp::getTargetResultIndex() { return 0; }
3738
3739void RegResetOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3740 return forceableAsmResultNames(*this, getName(), setNameFn);
3741}
3742
3743//===----------------------------------------------------------------------===//
3744// FormalOp
3745//===----------------------------------------------------------------------===//
3746
3747LogicalResult
3748FormalOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
3749 auto *op = symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
3750 if (!op)
3751 return emitOpError() << "targets unknown module " << getModuleNameAttr();
3752
3753 if (!isa<FModuleLike>(op)) {
3754 auto d = emitOpError() << "target " << getModuleNameAttr()
3755 << " is not a module";
3756 d.attachNote(op->getLoc()) << "target defined here";
3757 return d;
3758 }
3759
3760 return success();
3761}
3762
3763//===----------------------------------------------------------------------===//
3764// SimulationOp
3765//===----------------------------------------------------------------------===//
3766
3767LogicalResult
3768SimulationOp::verifySymbolUses(mlir::SymbolTableCollection &symbolTable) {
3769 auto *op = symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
3770 if (!op)
3771 return emitOpError() << "targets unknown module " << getModuleNameAttr();
3772
3773 auto complain = [&] {
3774 auto d = emitOpError() << "target " << getModuleNameAttr() << " ";
3775 d.attachNote(op->getLoc()) << "target defined here";
3776 return d;
3777 };
3778
3779 auto module = dyn_cast<FModuleLike>(op);
3780 if (!module)
3781 return complain() << "is not a module";
3782
3783 auto numPorts = module.getPortDirections().size();
3784 if (numPorts != 4)
3785 return complain() << "must have 4 ports, got " << numPorts << " instead";
3786
3787 auto checkPort = [&](unsigned idx, StringRef expName, Direction expDir,
3788 llvm::function_ref<bool(Type)> checkType,
3789 StringRef expType) {
3790 auto name = module.getPortNameAttr(idx);
3791 if (name != expName) {
3792 complain() << "port " << idx << " must be called \"" << expName
3793 << "\", got " << name << " instead";
3794 return false;
3795 }
3796 if (auto dir = module.getPortDirection(idx); dir != expDir) {
3797 auto stringify = [](Direction dir) {
3798 return dir == Direction::In ? "an input" : "an output";
3799 };
3800 complain() << "port " << name << " must be " << stringify(expDir)
3801 << ", got " << stringify(dir) << " instead";
3802 return false;
3803 }
3804 if (auto type = module.getPortType(idx); !checkType(type)) {
3805 complain() << "port " << name << " must be a '!firrtl." << expType
3806 << "', got " << type << " instead";
3807 return false;
3808 }
3809 return true;
3810 };
3811
3812 auto isClock = [](Type type) { return isa<ClockType>(type); };
3813 auto isBool = [](Type type) {
3814 if (auto uintType = dyn_cast<UIntType>(type))
3815 return uintType.getWidth() == 1;
3816 return false;
3817 };
3818 return success(checkPort(0, "clock", Direction::In, isClock, "clock") &&
3819 checkPort(1, "init", Direction::In, isBool, "uint<1>") &&
3820 checkPort(2, "done", Direction::Out, isBool, "uint<1>") &&
3821 checkPort(3, "success", Direction::Out, isBool, "uint<1>"));
3822}
3823
3824//===----------------------------------------------------------------------===//
3825// WireOp
3826//===----------------------------------------------------------------------===//
3827
3828void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3829 return forceableAsmResultNames(*this, getName(), setNameFn);
3830}
3831
3832SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3833RegResetOp::computeDataFlow() {
3834 // A register does't have any combinational dataflow.
3835 return {};
3836}
3837
3838std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
3839
3840LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3841 auto refType = type_dyn_cast<RefType>(getType(0));
3842 if (!refType)
3843 return success();
3844
3845 return verifyProbeType(
3846 refType, getLoc(), getOperation()->getParentOfType<CircuitOp>(),
3847 symbolTable, Twine("'") + getOperationName() + "' op is");
3848}
3849
3850//===----------------------------------------------------------------------===//
3851// ContractOp
3852//===----------------------------------------------------------------------===//
3853
3854LogicalResult ContractOp::verify() {
3855 if (getBody().getArgumentTypes() != getInputs().getType())
3856 return emitOpError("result types and region argument types must match");
3857 return success();
3858}
3859
3860//===----------------------------------------------------------------------===//
3861// ObjectOp
3862//===----------------------------------------------------------------------===//
3863
3864void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass,
3865 StringRef name) {
3866 build(builder, state, klass.getInstanceType(),
3867 StringAttr::get(builder.getContext(), name));
3868}
3869
3870LogicalResult ObjectOp::verify() { return success(); }
3871
3872LogicalResult ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3873 auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
3874 auto classType = getType();
3875 auto className = classType.getNameAttr();
3876
3877 // verify that the class exists.
3878 auto classOp = dyn_cast_or_null<ClassLike>(
3879 symbolTable.lookupSymbolIn(circuitOp, className));
3880 if (!classOp)
3881 return emitOpError() << "references unknown class " << className;
3882
3883 // verify that the result type agrees with the class definition.
3884 if (failed(classOp.verifyType(classType, [&]() { return emitOpError(); })))
3885 return failure();
3886
3887 return success();
3888}
3889
3890StringAttr ObjectOp::getClassNameAttr() {
3891 return getType().getNameAttr().getAttr();
3892}
3893
3894StringRef ObjectOp::getClassName() { return getType().getName(); }
3895
3896ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
3897 auto symRef = getType().getNameAttr();
3898 return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
3899}
3900
3901Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
3902 return getReferencedClass(symtbl);
3903}
3904
3905StringRef ObjectOp::getInstanceName() { return getName(); }
3906
3907StringAttr ObjectOp::getInstanceNameAttr() { return getNameAttr(); }
3908
3909StringRef ObjectOp::getReferencedModuleName() { return getClassName(); }
3910
3911StringAttr ObjectOp::getReferencedModuleNameAttr() {
3912 return getClassNameAttr();
3913}
3914
3915void ObjectOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3916 setNameFn(getResult(), getName());
3917}
3918
3919//===----------------------------------------------------------------------===//
3920// Statements
3921//===----------------------------------------------------------------------===//
3922
3923LogicalResult AttachOp::verify() {
3924 // All known widths must match.
3925 std::optional<int32_t> commonWidth;
3926 for (auto operand : getOperands()) {
3927 auto thisWidth = type_cast<AnalogType>(operand.getType()).getWidth();
3928 if (!thisWidth)
3929 continue;
3930 if (!commonWidth) {
3931 commonWidth = thisWidth;
3932 continue;
3933 }
3934 if (commonWidth != thisWidth)
3935 return emitOpError("is inavlid as not all known operand widths match");
3936 }
3937 return success();
3938}
3939
3940/// Check if the source and sink are of appropriate flow.
3941static LogicalResult checkConnectFlow(Operation *connect) {
3942 Value dst = connect->getOperand(0);
3943 Value src = connect->getOperand(1);
3944
3945 // TODO: Relax this to allow reads from output ports,
3946 // instance/memory input ports.
3947 auto srcFlow = foldFlow(src);
3948 if (!isValidSrc(srcFlow)) {
3949 // A sink that is a port output or instance input used as a source is okay,
3950 // as long as it is not a property.
3951 auto kind = getDeclarationKind(src);
3952 if (isa<PropertyType>(src.getType()) ||
3953 (kind != DeclKind::Port && kind != DeclKind::Instance)) {
3954 auto srcRef = getFieldRefFromValue(src, /*lookThroughCasts=*/true);
3955 auto [srcName, rootKnown] = getFieldName(srcRef);
3956 auto diag = emitError(connect->getLoc());
3957 diag << "connect has invalid flow: the source expression ";
3958 if (rootKnown)
3959 diag << "\"" << srcName << "\" ";
3960 diag << "has " << toString(srcFlow) << ", expected source or duplex flow";
3961 return diag.attachNote(srcRef.getLoc()) << "the source was defined here";
3962 }
3963 }
3964
3965 auto dstFlow = foldFlow(dst);
3966 if (!isValidDst(dstFlow)) {
3967 auto dstRef = getFieldRefFromValue(dst, /*lookThroughCasts=*/true);
3968 auto [dstName, rootKnown] = getFieldName(dstRef);
3969 auto diag = emitError(connect->getLoc());
3970 diag << "connect has invalid flow: the destination expression ";
3971 if (rootKnown)
3972 diag << "\"" << dstName << "\" ";
3973 diag << "has " << toString(dstFlow) << ", expected sink or duplex flow";
3974 return diag.attachNote(dstRef.getLoc())
3975 << "the destination was defined here";
3976 }
3977 return success();
3978}
3979
3980// NOLINTBEGIN(misc-no-recursion)
3981/// Checks if the type has any 'const' leaf elements . If `isFlip` is `true`,
3982/// the `const` leaf is not considered to be driven.
3983static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip = false,
3984 bool outerTypeIsConst = false) {
3985 auto typeIsConst = outerTypeIsConst || type.isConst();
3986
3987 if (typeIsConst && type.isPassive())
3988 return !isFlip;
3989
3990 if (auto bundleType = type_dyn_cast<BundleType>(type))
3991 return llvm::any_of(bundleType.getElements(), [&](auto &element) {
3992 return isConstFieldDriven(element.type, isFlip ^ element.isFlip,
3993 typeIsConst);
3994 });
3995
3996 if (auto vectorType = type_dyn_cast<FVectorType>(type))
3997 return isConstFieldDriven(vectorType.getElementType(), isFlip, typeIsConst);
3998
3999 if (typeIsConst)
4000 return !isFlip;
4001 return false;
4002}
4003// NOLINTEND(misc-no-recursion)
4004
4005/// Checks that connections to 'const' destinations are not dependent on
4006/// non-'const' conditions in when blocks.
4007static LogicalResult checkConnectConditionality(FConnectLike connect) {
4008 auto dest = connect.getDest();
4009 auto destType = type_dyn_cast<FIRRTLBaseType>(dest.getType());
4010 auto src = connect.getSrc();
4011 auto srcType = type_dyn_cast<FIRRTLBaseType>(src.getType());
4012 if (!destType || !srcType)
4013 return success();
4014
4015 auto destRefinedType = destType;
4016 auto srcRefinedType = srcType;
4017
4018 /// Looks up the value's defining op until the defining op is null or a
4019 /// declaration of the value. If a SubAccessOp is encountered with a 'const'
4020 /// input, `originalFieldType` is made 'const'.
4021 auto findFieldDeclarationRefiningFieldType =
4022 [](Value value, FIRRTLBaseType &originalFieldType) -> Value {
4023 while (auto *definingOp = value.getDefiningOp()) {
4024 bool shouldContinue = true;
4025 TypeSwitch<Operation *>(definingOp)
4026 .Case<SubfieldOp, SubindexOp>([&](auto op) { value = op.getInput(); })
4027 .Case<SubaccessOp>([&](SubaccessOp op) {
4028 if (op.getInput()
4029 .getType()
4030 .base()
4031 .getElementTypePreservingConst()
4032 .isConst())
4033 originalFieldType = originalFieldType.getConstType(true);
4034 value = op.getInput();
4035 })
4036 .Default([&](Operation *) { shouldContinue = false; });
4037 if (!shouldContinue)
4038 break;
4039 }
4040 return value;
4041 };
4042
4043 auto destDeclaration =
4044 findFieldDeclarationRefiningFieldType(dest, destRefinedType);
4045 auto srcDeclaration =
4046 findFieldDeclarationRefiningFieldType(src, srcRefinedType);
4047
4048 auto checkConstConditionality = [&](Value value, FIRRTLBaseType type,
4049 Value declaration) -> LogicalResult {
4050 auto *declarationBlock = declaration.getParentBlock();
4051 auto *block = connect->getBlock();
4052 while (block && block != declarationBlock) {
4053 auto *parentOp = block->getParentOp();
4054
4055 if (auto whenOp = dyn_cast<WhenOp>(parentOp);
4056 whenOp && !whenOp.getCondition().getType().isConst()) {
4057 if (type.isConst())
4058 return connect.emitOpError()
4059 << "assignment to 'const' type " << type
4060 << " is dependent on a non-'const' condition";
4061 return connect->emitOpError()
4062 << "assignment to nested 'const' member of type " << type
4063 << " is dependent on a non-'const' condition";
4064 }
4065
4066 block = parentOp->getBlock();
4067 }
4068 return success();
4069 };
4070
4071 auto emitSubaccessError = [&] {
4072 return connect.emitError(
4073 "assignment to non-'const' subaccess of 'const' type is disallowed");
4074 };
4075
4076 // Check destination if it contains 'const' leaves
4077 if (destRefinedType.containsConst() && isConstFieldDriven(destRefinedType)) {
4078 // Disallow assignment to non-'const' subaccesses of 'const' types
4079 if (destType != destRefinedType)
4080 return emitSubaccessError();
4081
4082 if (failed(checkConstConditionality(dest, destType, destDeclaration)))
4083 return failure();
4084 }
4085
4086 // Check source if it contains 'const' 'flip' leaves
4087 if (srcRefinedType.containsConst() &&
4088 isConstFieldDriven(srcRefinedType, /*isFlip=*/true)) {
4089 // Disallow assignment to non-'const' subaccesses of 'const' types
4090 if (srcType != srcRefinedType)
4091 return emitSubaccessError();
4092 if (failed(checkConstConditionality(src, srcType, srcDeclaration)))
4093 return failure();
4094 }
4095
4096 return success();
4097}
4098
4099/// Returns success if the given connect is the sole driver of its dest operand.
4100/// Returns failure if there are other connects driving the dest.
4101static LogicalResult checkSingleConnect(FConnectLike connect) {
4102 // For now, refs and domains can't be in bundles so this is sufficient. This
4103 // is insufficient for properties, and insufficient before
4104 // lower-open-aggregates has run. In the future, need to ensure no other
4105 // define's to same "fieldSource". (When aggregates can have references, we
4106 // can define a reference within, but this must be unique. Checking this here
4107 // may be expensive, consider adding something to FModuleLike's to check it
4108 // there instead)
4109 auto dest = connect.getDest();
4110 for (auto *user : dest.getUsers()) {
4111 if (auto c = dyn_cast<FConnectLike>(user);
4112 c && c.getDest() == dest && c != connect) {
4113 auto diag = connect.emitError("destination cannot be driven by multiple "
4114 "operations");
4115 diag.attachNote(c->getLoc()) << "other driver is here";
4116 return failure();
4117 }
4118 }
4119 return success();
4120}
4121
4122LogicalResult ConnectOp::verify() {
4123 auto dstType = getDest().getType();
4124 auto srcType = getSrc().getType();
4125 auto dstBaseType = type_dyn_cast<FIRRTLBaseType>(dstType);
4126 auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(srcType);
4127 if (!dstBaseType || !srcBaseType) {
4128 if (dstType != srcType)
4129 return emitError("may not connect different non-base types");
4130 } else {
4131 // Analog types cannot be connected and must be attached.
4132 if (dstBaseType.containsAnalog() || srcBaseType.containsAnalog())
4133 return emitError("analog types may not be connected");
4134
4135 // Destination and source types must be equivalent.
4136 if (!areTypesEquivalent(dstBaseType, srcBaseType))
4137 return emitError("type mismatch between destination ")
4138 << dstBaseType << " and source " << srcBaseType;
4139
4140 // Truncation is banned in a connection: destination bit width must be
4141 // greater than or equal to source bit width.
4142 if (!isTypeLarger(dstBaseType, srcBaseType))
4143 return emitError("destination ")
4144 << dstBaseType << " is not as wide as the source " << srcBaseType;
4145 }
4146
4147 // Check that the flows make sense.
4148 if (failed(checkConnectFlow(*this)))
4149 return failure();
4150
4151 if (failed(checkConnectConditionality(*this)))
4152 return failure();
4153
4154 return success();
4155}
4156
4157LogicalResult MatchingConnectOp::verify() {
4158 if (auto type = type_dyn_cast<FIRRTLType>(getDest().getType())) {
4159 auto baseType = type_cast<FIRRTLBaseType>(type);
4160
4161 // Analog types cannot be connected and must be attached.
4162 if (baseType && baseType.containsAnalog())
4163 return emitError("analog types may not be connected");
4164
4165 // The anonymous types of operands must be equivalent.
4166 assert(areAnonymousTypesEquivalent(cast<FIRRTLBaseType>(getSrc().getType()),
4167 baseType) &&
4168 "`SameAnonTypeOperands` trait should have already rejected "
4169 "structurally non-equivalent types");
4170 }
4171
4172 // Check that the flows make sense.
4173 if (failed(checkConnectFlow(*this)))
4174 return failure();
4175
4176 if (failed(checkConnectConditionality(*this)))
4177 return failure();
4178
4179 return success();
4180}
4181
4182LogicalResult RefDefineOp::verify() {
4183 if (failed(checkConnectFlow(*this)))
4184 return failure();
4185
4186 if (failed(checkSingleConnect(*this)))
4187 return failure();
4188
4189 if (auto *op = getDest().getDefiningOp()) {
4190 // TODO: Make ref.sub only source flow?
4191 if (isa<RefSubOp>(op))
4192 return emitError(
4193 "destination reference cannot be a sub-element of a reference");
4194 if (isa<RefCastOp>(op)) // Source flow, check anyway for now.
4195 return emitError(
4196 "destination reference cannot be a cast of another reference");
4197 }
4198
4199 // This define is only enabled when its ambient layers are active. Check
4200 // that whenever the destination's layer requirements are met, that this
4201 // op is enabled.
4202 auto ambientLayers = getAmbientLayersAt(getOperation());
4203 auto dstLayers = getLayersFor(getDest());
4204 SmallVector<SymbolRefAttr> missingLayers;
4205
4206 return checkLayerCompatibility(getOperation(), ambientLayers, dstLayers,
4207 "has more layer requirements than destination",
4208 "additional layers required");
4209}
4210
4211LogicalResult PropAssignOp::verify() {
4212 if (failed(checkConnectFlow(*this)))
4213 return failure();
4214
4215 if (failed(checkSingleConnect(*this)))
4216 return failure();
4217
4218 return success();
4219}
4220
4221template <typename T>
4222static FlatSymbolRefAttr getDomainTypeNameOfResult(T op, size_t i) {
4223 auto info = op.getDomainInfo();
4224 if (info.empty())
4225 return {};
4226 return dyn_cast<FlatSymbolRefAttr>(info[i]);
4227}
4228
4229static FlatSymbolRefAttr getDomainTypeName(Value value) {
4230 if (!isa<DomainType>(value.getType()))
4231 return {};
4232
4233 if (auto arg = dyn_cast<BlockArgument>(value)) {
4234 auto *parent = arg.getOwner()->getParentOp();
4235 if (auto module = dyn_cast<FModuleLike>(parent)) {
4236 auto info = module.getDomainInfo();
4237 if (info.empty())
4238 return {};
4239 auto attr = info[arg.getArgNumber()];
4240 return dyn_cast<FlatSymbolRefAttr>(attr);
4241 }
4242
4243 return {};
4244 }
4245
4246 if (auto result = dyn_cast<OpResult>(value)) {
4247 auto *op = result.getDefiningOp();
4248 if (auto instance = dyn_cast<InstanceOp>(op))
4249 return getDomainTypeNameOfResult(instance, result.getResultNumber());
4250 if (auto instance = dyn_cast<InstanceChoiceOp>(op))
4251 return getDomainTypeNameOfResult(instance, result.getResultNumber());
4252 return {};
4253 }
4254
4255 return {};
4256}
4257
4258LogicalResult DomainDefineOp::verify() {
4259 if (failed(checkConnectFlow(*this)))
4260 return failure();
4261
4262 if (failed(checkSingleConnect(*this)))
4263 return failure();
4264
4265 auto dst = getDest();
4266 auto src = getSrc();
4267
4268 auto dstDomain = getDomainTypeName(dst);
4269 if (!dstDomain)
4270 return emitError("could not determine domain-type of destination");
4271
4272 auto srcDomain = getDomainTypeName(src);
4273 if (!srcDomain)
4274 return emitError("could not determine domain-type of source");
4275
4276 if (dstDomain != srcDomain) {
4277 auto diag = emitError()
4278 << "source domain type " << srcDomain
4279 << " does not match destination domain type " << dstDomain;
4280 return diag;
4281 }
4282
4283 return success();
4284}
4285
4286void WhenOp::createElseRegion() {
4287 assert(!hasElseRegion() && "already has an else region");
4288 getElseRegion().push_back(new Block());
4289}
4290
4291void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
4292 bool withElseRegion, std::function<void()> thenCtor,
4293 std::function<void()> elseCtor) {
4294 OpBuilder::InsertionGuard guard(builder);
4295 result.addOperands(condition);
4296
4297 // Create "then" region.
4298 builder.createBlock(result.addRegion());
4299 if (thenCtor)
4300 thenCtor();
4301
4302 // Create "else" region.
4303 Region *elseRegion = result.addRegion();
4304 if (withElseRegion) {
4305 builder.createBlock(elseRegion);
4306 if (elseCtor)
4307 elseCtor();
4308 }
4309}
4310
4311//===----------------------------------------------------------------------===//
4312// MatchOp
4313//===----------------------------------------------------------------------===//
4314
4315LogicalResult MatchOp::verify() {
4316 FEnumType type = getInput().getType();
4317
4318 // Make sure that the number of tags matches the number of regions.
4319 auto numCases = getTags().size();
4320 auto numRegions = getNumRegions();
4321 if (numRegions != numCases)
4322 return emitOpError("expected ")
4323 << numRegions << " tags but got " << numCases;
4324
4325 auto numTags = type.getNumElements();
4326
4327 SmallDenseSet<int64_t> seen;
4328 for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
4329 auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());
4330
4331 // Ensure that the block has a single argument.
4332 if (region.getNumArguments() != 1)
4333 return emitOpError("region should have exactly one argument");
4334
4335 // Make sure that it is a valid tag.
4336 if (tagIndex >= numTags)
4337 return emitOpError("the tag index ")
4338 << tagIndex << " is out of the range of valid tags in " << type;
4339
4340 // Make sure we have not already matched this tag.
4341 auto [it, inserted] = seen.insert(tagIndex);
4342 if (!inserted)
4343 return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
4344 << " is matched more than once";
4345
4346 // Check that the block argument type matches the tag's type.
4347 auto expectedType = type.getElementTypePreservingConst(tagIndex);
4348 auto regionType = region.getArgument(0).getType();
4349 if (regionType != expectedType)
4350 return emitOpError("region type ")
4351 << regionType << " does not match the expected type "
4352 << expectedType;
4353 }
4354
4355 // Check that the match statement is exhaustive.
4356 for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
4357 if (!seen.contains(i))
4358 return emitOpError("missing case for tag ") << type.getElementNameAttr(i);
4359
4360 return success();
4361}
4362
4363void MatchOp::print(OpAsmPrinter &p) {
4364 auto input = getInput();
4365 FEnumType type = input.getType();
4366 auto regions = getRegions();
4367 p << " " << input << " : " << type;
4368 SmallVector<StringRef> elided = {"tags"};
4369 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
4370 p << " {";
4371 p.increaseIndent();
4372 for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
4373 p.printNewline();
4374 p << "case ";
4375 p.printKeywordOrString(
4376 type.getElementName(cast<IntegerAttr>(tag).getInt()));
4377 p << "(";
4378 p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
4379 /*omitType=*/true);
4380 p << ") ";
4381 p.printRegion(region, /*printEntryBlockArgs=*/false);
4382 }
4383 p.decreaseIndent();
4384 p.printNewline();
4385 p << "}";
4386}
4387
4388ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
4389 auto *context = parser.getContext();
4390 auto &properties = result.getOrAddProperties<Properties>();
4391 OpAsmParser::UnresolvedOperand input;
4392 if (parser.parseOperand(input) || parser.parseColon())
4393 return failure();
4394
4395 auto loc = parser.getCurrentLocation();
4396 Type type;
4397 if (parser.parseType(type))
4398 return failure();
4399 auto enumType = type_dyn_cast<FEnumType>(type);
4400 if (!enumType)
4401 return parser.emitError(loc, "expected enumeration type but got") << type;
4402
4403 if (parser.resolveOperand(input, type, result.operands) ||
4404 parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
4405 parser.parseLBrace())
4406 return failure();
4407
4408 auto i32Type = IntegerType::get(context, 32);
4409 SmallVector<Attribute> tags;
4410 while (true) {
4411 // Stop parsing when we don't find another "case" keyword.
4412 if (failed(parser.parseOptionalKeyword("case")))
4413 break;
4414
4415 // Parse the tag and region argument.
4416 auto nameLoc = parser.getCurrentLocation();
4417 std::string name;
4418 OpAsmParser::Argument arg;
4419 auto *region = result.addRegion();
4420 if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
4421 parser.parseArgument(arg) || parser.parseRParen())
4422 return failure();
4423
4424 // Figure out the enum index of the tag.
4425 auto index = enumType.getElementIndex(name);
4426 if (!index)
4427 return parser.emitError(nameLoc, "the tag \"")
4428 << name << "\" is not a member of the enumeration " << enumType;
4429 tags.push_back(IntegerAttr::get(i32Type, *index));
4430
4431 // Parse the region.
4432 arg.type = enumType.getElementTypePreservingConst(*index);
4433 if (parser.parseRegion(*region, arg))
4434 return failure();
4435 }
4436 properties.setTags(ArrayAttr::get(context, tags));
4437
4438 return parser.parseRBrace();
4439}
4440
4441void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
4442 ArrayAttr tags,
4443 MutableArrayRef<std::unique_ptr<Region>> regions) {
4444 auto &properties = result.getOrAddProperties<Properties>();
4445 result.addOperands(input);
4446 properties.setTags(tags);
4447 result.addRegions(regions);
4448}
4449
4450//===----------------------------------------------------------------------===//
4451// Expressions
4452//===----------------------------------------------------------------------===//
4453
4454/// Return true if the specified operation is a firrtl expression.
4455bool firrtl::isExpression(Operation *op) {
4456 struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
4457 bool visitInvalidExpr(Operation *op) { return false; }
4458 bool visitUnhandledExpr(Operation *op) { return true; }
4459 };
4460
4461 return IsExprClassifier().dispatchExprVisitor(op);
4462}
4463
4464void InvalidValueOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
4465 // Set invalid values to have a distinct name.
4466 std::string name;
4467 if (auto ty = type_dyn_cast<IntType>(getType())) {
4468 const char *base = ty.isSigned() ? "invalid_si" : "invalid_ui";
4469 auto width = ty.getWidthOrSentinel();
4470 if (width == -1)
4471 name = base;
4472 else
4473 name = (Twine(base) + Twine(width)).str();
4474 } else if (auto ty = type_dyn_cast<AnalogType>(getType())) {
4475 auto width = ty.getWidthOrSentinel();
4476 if (width == -1)
4477 name = "invalid_analog";
4478 else
4479 name = ("invalid_analog" + Twine(width)).str();
4480 } else if (type_isa<AsyncResetType>(getType()))
4481 name = "invalid_asyncreset";
4482 else if (type_isa<ResetType>(getType()))
4483 name = "invalid_reset";
4484 else if (type_isa<ClockType>(getType()))
4485 name = "invalid_clock";
4486 else
4487 name = "invalid";
4488
4489 setNameFn(getResult(), name);
4490}
4491
4492void ConstantOp::print(OpAsmPrinter &p) {
4493 p << " ";
4494 p.printAttributeWithoutType(getValueAttr());
4495 p << " : ";
4496 p.printType(getType());
4497 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4498}
4499
/// Parse ` <integer> : <IntType> attr-dict`.
///
/// The integer is read before its width is known and then resized to match
/// the result type:
///  - result type without a width: the parsed APInt's own width is kept;
///  - result type wider than the parsed APInt: sign-extended;
///  - result type narrower: truncated, after checking the value fits.
/// The value is stored as a builtin integer attribute whose signedness matches
/// the FIRRTL result type, which is what ConstantOp::verify checks.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &properties = result.getOrAddProperties<Properties>();
  // Parse the constant value, without knowing its width.
  APInt value;
  auto loc = parser.getCurrentLocation();
  auto valueResult = parser.parseOptionalInteger(value);
  // No integer token at all.
  if (!valueResult.has_value())
    return parser.emitError(loc, "expected integer value");

  // Parse the result firrtl integer type.
  IntType resultType;
  // `failed(*valueResult)` means an integer was present but malformed.
  if (failed(*valueResult) || parser.parseColonType(resultType) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.addTypes(resultType);

  // Now that we know the width and sign of the result type, we can munge the
  // APInt as appropriate.
  if (resultType.hasWidth()) {
    auto width = (unsigned)resultType.getWidthOrSentinel();
    if (width > value.getBitWidth()) {
      // sext is always safe here, even for unsigned values, because the
      // parseOptionalInteger method will return something with a zero in the
      // top bits if it is a positive number.
      value = value.sext(width);
    } else if (width < value.getBitWidth()) {
      // The parser can return an unnecessarily wide result with leading
      // zeros. This isn't a problem, but truncating off bits is bad.
      //
      // Non-negative values are measured by their active (unsigned) bits and
      // negative values by their significant (signed) bits, independent of
      // the result type's signedness.
      // NOTE(review): for a signed result type, a non-negative value whose
      // active bits exactly equal `width` passes this check and is then
      // truncated into a negative number (e.g. `8 : !firrtl.sint<4>` would
      // appear to be stored as -8). Confirm that this bit-pattern reading is
      // intended.
      unsigned neededBits = value.isNegative() ? value.getSignificantBits()
                                               : value.getActiveBits();
      if (width < neededBits)
        return parser.emitError(loc, "constant out of range for result type ")
               << resultType;
      value = value.trunc(width);
    }
  }

  // Encode the FIRRTL signedness in the attribute's builtin integer type.
  auto intType = parser.getBuilder().getIntegerType(value.getBitWidth(),
                                                    resultType.isSigned());
  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
  properties.setValue(valueAttr);
  return success();
}
4543
4544LogicalResult ConstantOp::verify() {
4545 // If the result type has a bitwidth, then the attribute must match its width.
4546 IntType intType = getType();
4547 auto width = intType.getWidthOrSentinel();
4548 if (width != -1 && (int)getValue().getBitWidth() != width)
4549 return emitError(
4550 "firrtl.constant attribute bitwidth doesn't match return type");
4551
4552 // The sign of the attribute's integer type must match our integer type sign.
4553 auto attrType = type_cast<IntegerType>(getValueAttr().getType());
4554 if (attrType.isSignless() || attrType.isSigned() != intType.isSigned())
4555 return emitError("firrtl.constant attribute has wrong sign");
4556
4557 return success();
4558}
4559
4560/// Build a ConstantOp from an APInt and a FIRRTL type, handling the attribute
4561/// formation for the 'value' attribute.
4562void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
4563 const APInt &value) {
4564 int32_t width = type.getWidthOrSentinel();
4565 (void)width;
4566 assert((width == -1 || (int32_t)value.getBitWidth() == width) &&
4567 "incorrect attribute bitwidth for firrtl.constant");
4568
4569 auto attr =
4570 IntegerAttr::get(type.getContext(), APSInt(value, !type.isSigned()));
4571 return build(builder, result, type, attr);
4572}
4573
4574/// Build a ConstantOp from an APSInt, handling the attribute formation for the
4575/// 'value' attribute and inferring the FIRRTL type.
4576void ConstantOp::build(OpBuilder &builder, OperationState &result,
4577 const APSInt &value) {
4578 auto attr = IntegerAttr::get(builder.getContext(), value);
4579 auto type =
4580 IntType::get(builder.getContext(), value.isSigned(), value.getBitWidth());
4581 return build(builder, result, type, attr);
4582}
4583
4584void ConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
4585 // For constants in particular, propagate the value into the result name to
4586 // make it easier to read the IR.
4587 IntType intTy = getType();
4588 assert(intTy);
4589
4590 // Otherwise, build a complex name with the value and type.
4591 SmallString<32> specialNameBuffer;
4592 llvm::raw_svector_ostream specialName(specialNameBuffer);
4593 specialName << 'c';
4594 getValue().print(specialName, /*isSigned:*/ intTy.isSigned());
4595
4596 specialName << (intTy.isSigned() ? "_si" : "_ui");
4597 auto width = intTy.getWidthOrSentinel();
4598 if (width != -1)
4599 specialName << width;
4600 setNameFn(getResult(), specialName.str());
4601}
4602
4603void SpecialConstantOp::print(OpAsmPrinter &p) {
4604 p << " ";
4605 // SpecialConstant uses a BoolAttr, and we want to print `true` as `1`.
4606 p << static_cast<unsigned>(getValue());
4607 p << " : ";
4608 p.printType(getType());
4609 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4610}
4611
4612ParseResult SpecialConstantOp::parse(OpAsmParser &parser,
4613 OperationState &result) {
4614 auto &properties = result.getOrAddProperties<Properties>();
4615 // Parse the constant value. SpecialConstant uses bool attributes, but it
4616 // prints as an integer.
4617 APInt value;
4618 auto loc = parser.getCurrentLocation();
4619 auto valueResult = parser.parseOptionalInteger(value);
4620 if (!valueResult.has_value())
4621 return parser.emitError(loc, "expected integer value");
4622
4623 // Clocks and resets can only be 0 or 1.
4624 if (value != 0 && value != 1)
4625 return parser.emitError(loc, "special constants can only be 0 or 1.");
4626
4627 // Parse the result firrtl type.
4628 Type resultType;
4629 if (failed(*valueResult) || parser.parseColonType(resultType) ||
4630 parser.parseOptionalAttrDict(result.attributes))
4631 return failure();
4632 result.addTypes(resultType);
4633
4634 // Create the attribute.
4635 auto valueAttr = parser.getBuilder().getBoolAttr(value == 1);
4636 properties.setValue(valueAttr);
4637 return success();
4638}
4639
4640void SpecialConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
4641 SmallString<32> specialNameBuffer;
4642 llvm::raw_svector_ostream specialName(specialNameBuffer);
4643 specialName << 'c';
4644 specialName << static_cast<unsigned>(getValue());
4645 auto type = getType();
4646 if (type_isa<ClockType>(type)) {
4647 specialName << "_clock";
4648 } else if (type_isa<ResetType>(type)) {
4649 specialName << "_reset";
4650 } else if (type_isa<AsyncResetType>(type)) {
4651 specialName << "_asyncreset";
4652 }
4653 setNameFn(getResult(), specialName.str());
4654}
4655
4656// Checks that an array attr representing an aggregate constant has the correct
4657// shape. This recurses on the type.
4658static bool checkAggConstant(Operation *op, Attribute attr,
4659 FIRRTLBaseType type) {
4660 if (type.isGround()) {
4661 if (!isa<IntegerAttr>(attr)) {
4662 op->emitOpError("Ground type is not an integer attribute");
4663 return false;
4664 }
4665 return true;
4666 }
4667 auto attrlist = dyn_cast<ArrayAttr>(attr);
4668 if (!attrlist) {
4669 op->emitOpError("expected array attribute for aggregate constant");
4670 return false;
4671 }
4672 if (auto array = type_dyn_cast<FVectorType>(type)) {
4673 if (array.getNumElements() != attrlist.size()) {
4674 op->emitOpError("array attribute (")
4675 << attrlist.size() << ") has wrong size for vector constant ("
4676 << array.getNumElements() << ")";
4677 return false;
4678 }
4679 return llvm::all_of(attrlist, [&array, op](Attribute attr) {
4680 return checkAggConstant(op, attr, array.getElementType());
4681 });
4682 }
4683 if (auto bundle = type_dyn_cast<BundleType>(type)) {
4684 if (bundle.getNumElements() != attrlist.size()) {
4685 op->emitOpError("array attribute (")
4686 << attrlist.size() << ") has wrong size for bundle constant ("
4687 << bundle.getNumElements() << ")";
4688 return false;
4689 }
4690 for (size_t i = 0; i < bundle.getNumElements(); ++i) {
4691 if (bundle.getElement(i).isFlip) {
4692 op->emitOpError("Cannot have constant bundle type with flip");
4693 return false;
4694 }
4695 if (!checkAggConstant(op, attrlist[i], bundle.getElement(i).type))
4696 return false;
4697 }
4698 return true;
4699 }
4700 op->emitOpError("Unknown aggregate type");
4701 return false;
4702}
4703
4704LogicalResult AggregateConstantOp::verify() {
4705 if (checkAggConstant(getOperation(), getFields(), getType()))
4706 return success();
4707 return failure();
4708}
4709
4710Attribute AggregateConstantOp::getAttributeFromFieldID(uint64_t fieldID) {
4711 FIRRTLBaseType type = getType();
4712 Attribute value = getFields();
4713 while (fieldID != 0) {
4714 if (auto bundle = type_dyn_cast<BundleType>(type)) {
4715 auto index = bundle.getIndexForFieldID(fieldID);
4716 fieldID -= bundle.getFieldID(index);
4717 type = bundle.getElementType(index);
4718 value = cast<ArrayAttr>(value)[index];
4719 } else {
4720 auto vector = type_cast<FVectorType>(type);
4721 auto index = vector.getIndexForFieldID(fieldID);
4722 fieldID -= vector.getFieldID(index);
4723 type = vector.getElementType();
4724 value = cast<ArrayAttr>(value)[index];
4725 }
4726 }
4727 return value;
4728}
4729
4730void FIntegerConstantOp::print(OpAsmPrinter &p) {
4731 p << " ";
4732 p.printAttributeWithoutType(getValueAttr());
4733 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4734}
4735
4736ParseResult FIntegerConstantOp::parse(OpAsmParser &parser,
4737 OperationState &result) {
4738 auto *context = parser.getContext();
4739 auto &properties = result.getOrAddProperties<Properties>();
4740 APInt value;
4741 if (parser.parseInteger(value) ||
4742 parser.parseOptionalAttrDict(result.attributes))
4743 return failure();
4744 result.addTypes(FIntegerType::get(context));
4745 auto intType =
4746 IntegerType::get(context, value.getBitWidth(), IntegerType::Signed);
4747 auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4748 properties.setValue(valueAttr);
4749 return success();
4750}
4751
4752ParseResult ListCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4753 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
4754 ListType type;
4755
4756 if (parser.parseOperandList(operands) ||
4757 parser.parseOptionalAttrDict(result.attributes) ||
4758 parser.parseColonType(type))
4759 return failure();
4760 result.addTypes(type);
4761
4762 return parser.resolveOperands(operands, type.getElementType(),
4763 result.operands);
4764}
4765
4766void ListCreateOp::print(OpAsmPrinter &p) {
4767 p << " ";
4768 p.printOperands(getElements());
4769 p.printOptionalAttrDict((*this)->getAttrs());
4770 p << " : " << getType();
4771}
4772
4773LogicalResult ListCreateOp::verify() {
4774 if (getElements().empty())
4775 return success();
4776
4777 auto elementType = getElements().front().getType();
4778 auto listElementType = getType().getElementType();
4779 if (elementType != listElementType)
4780 return emitOpError("has elements of type ")
4781 << elementType << " instead of " << listElementType;
4782
4783 return success();
4784}
4785
4786LogicalResult BundleCreateOp::verify() {
4787 BundleType resultType = getType();
4788 if (resultType.getNumElements() != getFields().size())
4789 return emitOpError("number of fields doesn't match type");
4790 for (size_t i = 0; i < resultType.getNumElements(); ++i)
4792 resultType.getElementTypePreservingConst(i),
4793 type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4794 return emitOpError("type of element doesn't match bundle for field ")
4795 << resultType.getElement(i).name;
4796 // TODO: check flow
4797 return success();
4798}
4799
4800LogicalResult VectorCreateOp::verify() {
4801 FVectorType resultType = getType();
4802 if (resultType.getNumElements() != getFields().size())
4803 return emitOpError("number of fields doesn't match type");
4804 auto elemTy = resultType.getElementTypePreservingConst();
4805 for (size_t i = 0; i < resultType.getNumElements(); ++i)
4807 elemTy, type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4808 return emitOpError("type of element doesn't match vector element");
4809 // TODO: check flow
4810 return success();
4811}
4812
4813//===----------------------------------------------------------------------===//
4814// FEnumCreateOp
4815//===----------------------------------------------------------------------===//
4816
4817LogicalResult FEnumCreateOp::verify() {
4818 FEnumType resultType = getResult().getType();
4819 auto elementIndex = resultType.getElementIndex(getFieldName());
4820 if (!elementIndex)
4821 return emitOpError("label ")
4822 << getFieldName() << " is not a member of the enumeration type "
4823 << resultType;
4825 resultType.getElementTypePreservingConst(*elementIndex),
4826 getInput().getType()))
4827 return emitOpError("type of element doesn't match enum element");
4828 return success();
4829}
4830
4831void FEnumCreateOp::print(OpAsmPrinter &printer) {
4832 printer << ' ';
4833 printer.printKeywordOrString(getFieldName());
4834 printer << '(' << getInput() << ')';
4835 SmallVector<StringRef> elidedAttrs = {"fieldIndex"};
4836 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
4837 printer << " : ";
4838 printer.printFunctionalType(ArrayRef<Type>{getInput().getType()},
4839 ArrayRef<Type>{getResult().getType()});
4840}
4841
4842ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4843 auto *context = parser.getContext();
4844 auto &properties = result.getOrAddProperties<Properties>();
4845
4846 OpAsmParser::UnresolvedOperand input;
4847 std::string fieldName;
4848 mlir::FunctionType functionType;
4849 if (parser.parseKeywordOrString(&fieldName) || parser.parseLParen() ||
4850 parser.parseOperand(input) || parser.parseRParen() ||
4851 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4852 parser.parseType(functionType))
4853 return failure();
4854
4855 if (functionType.getNumInputs() != 1)
4856 return parser.emitError(parser.getNameLoc(), "single input type required");
4857 if (functionType.getNumResults() != 1)
4858 return parser.emitError(parser.getNameLoc(), "single result type required");
4859
4860 auto inputType = functionType.getInput(0);
4861 if (parser.resolveOperand(input, inputType, result.operands))
4862 return failure();
4863
4864 auto outputType = functionType.getResult(0);
4865 auto enumType = type_dyn_cast<FEnumType>(outputType);
4866 if (!enumType)
4867 return parser.emitError(parser.getNameLoc(),
4868 "output must be enum type, got ")
4869 << outputType;
4870 auto fieldIndex = enumType.getElementIndex(fieldName);
4871 if (!fieldIndex)
4872 return parser.emitError(parser.getNameLoc(),
4873 "unknown field " + fieldName + " in enum type ")
4874 << enumType;
4875
4876 properties.setFieldIndex(
4877 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4878
4879 result.addTypes(enumType);
4880
4881 return success();
4882}
4883
4884//===----------------------------------------------------------------------===//
4885// IsTagOp
4886//===----------------------------------------------------------------------===//
4887
4888LogicalResult IsTagOp::verify() {
4889 if (getFieldIndex() >= getInput().getType().base().getNumElements())
4890 return emitOpError("element index is greater than the number of fields in "
4891 "the bundle type");
4892 return success();
4893}
4894
4895void IsTagOp::print(::mlir::OpAsmPrinter &printer) {
4896 printer << ' ' << getInput() << ' ';
4897 printer.printKeywordOrString(getFieldName());
4898 SmallVector<::llvm::StringRef, 1> elidedAttrs = {"fieldIndex"};
4899 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4900 printer << " : " << getInput().getType();
4901}
4902
4903ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
4904 auto *context = parser.getContext();
4905 auto &properties = result.getOrAddProperties<Properties>();
4906
4907 OpAsmParser::UnresolvedOperand input;
4908 std::string fieldName;
4909 Type inputType;
4910 if (parser.parseOperand(input) || parser.parseKeywordOrString(&fieldName) ||
4911 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4912 parser.parseType(inputType))
4913 return failure();
4914
4915 if (parser.resolveOperand(input, inputType, result.operands))
4916 return failure();
4917
4918 auto enumType = type_dyn_cast<FEnumType>(inputType);
4919 if (!enumType)
4920 return parser.emitError(parser.getNameLoc(),
4921 "input must be enum type, got ")
4922 << inputType;
4923 auto fieldIndex = enumType.getElementIndex(fieldName);
4924 if (!fieldIndex)
4925 return parser.emitError(parser.getNameLoc(),
4926 "unknown field " + fieldName + " in enum type ")
4927 << enumType;
4928
4929 properties.setFieldIndex(
4930 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4931
4932 result.addTypes(UIntType::get(context, 1, /*isConst=*/false));
4933
4934 return success();
4935}
4936
4937FIRRTLType IsTagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
4938 OpaqueProperties properties,
4939 mlir::RegionRange regions,
4940 std::optional<Location> loc) {
4941 Adaptor adaptor(operands, attrs, properties, regions);
4942 return UIntType::get(attrs.getContext(), 1,
4943 isConst(adaptor.getInput().getType()));
4944}
4945
4946template <typename OpTy>
4947ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) {
4948 auto *context = parser.getContext();
4949
4950 OpAsmParser::UnresolvedOperand input;
4951 std::string fieldName;
4952 Type inputType;
4953 if (parser.parseOperand(input) || parser.parseLSquare() ||
4954 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4955 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4956 parser.parseType(inputType))
4957 return failure();
4958
4959 if (parser.resolveOperand(input, inputType, result.operands))
4960 return failure();
4961
4962 auto bundleType = type_dyn_cast<typename OpTy::InputType>(inputType);
4963 if (!bundleType)
4964 return parser.emitError(parser.getNameLoc(),
4965 "input must be bundle type, got ")
4966 << inputType;
4967 auto fieldIndex = bundleType.getElementIndex(fieldName);
4968 if (!fieldIndex)
4969 return parser.emitError(parser.getNameLoc(),
4970 "unknown field " + fieldName + " in bundle type ")
4971 << bundleType;
4972
4973 result.getOrAddProperties<typename OpTy::Properties>().setFieldIndex(
4974 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4975
4976 auto type = OpTy::inferReturnType(inputType, *fieldIndex, {});
4977 if (!type)
4978 return failure();
4979 result.addTypes(type);
4980
4981 return success();
4982}
4983
4984ParseResult SubtagOp::parse(OpAsmParser &parser, OperationState &result) {
4985 auto *context = parser.getContext();
4986
4987 OpAsmParser::UnresolvedOperand input;
4988 std::string fieldName;
4989 Type inputType;
4990 if (parser.parseOperand(input) || parser.parseLSquare() ||
4991 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4992 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4993 parser.parseType(inputType))
4994 return failure();
4995
4996 if (parser.resolveOperand(input, inputType, result.operands))
4997 return failure();
4998
4999 auto enumType = type_dyn_cast<FEnumType>(inputType);
5000 if (!enumType)
5001 return parser.emitError(parser.getNameLoc(),
5002 "input must be enum type, got ")
5003 << inputType;
5004 auto fieldIndex = enumType.getElementIndex(fieldName);
5005 if (!fieldIndex)
5006 return parser.emitError(parser.getNameLoc(),
5007 "unknown field " + fieldName + " in enum type ")
5008 << enumType;
5009
5010 result.getOrAddProperties<Properties>().setFieldIndex(
5011 IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
5012
5013 SmallVector<Type> inferredReturnTypes;
5014 if (failed(SubtagOp::inferReturnTypes(
5015 context, result.location, result.operands,
5016 result.attributes.getDictionary(context), result.getRawProperties(),
5017 result.regions, inferredReturnTypes)))
5018 return failure();
5019 result.addTypes(inferredReturnTypes);
5020
5021 return success();
5022}
5023
5024ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
5025 return parseSubfieldLikeOp<SubfieldOp>(parser, result);
5026}
5027ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
5028 return parseSubfieldLikeOp<OpenSubfieldOp>(parser, result);
5029}
5030
5031template <typename OpTy>
5032static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) {
5033 printer << ' ' << op.getInput() << '[';
5034 printer.printKeywordOrString(op.getFieldName());
5035 printer << ']';
5036 ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
5037 elidedAttrs.push_back("fieldIndex");
5038 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5039 printer << " : " << op.getInput().getType();
5040}
5041void SubfieldOp::print(::mlir::OpAsmPrinter &printer) {
5042 return printSubfieldLikeOp<SubfieldOp>(*this, printer);
5043}
5044void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) {
5045 return printSubfieldLikeOp<OpenSubfieldOp>(*this, printer);
5046}
5047
5048void SubtagOp::print(::mlir::OpAsmPrinter &printer) {
5049 printer << ' ' << getInput() << '[';
5050 printer.printKeywordOrString(getFieldName());
5051 printer << ']';
5052 ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
5053 elidedAttrs.push_back("fieldIndex");
5054 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
5055 printer << " : " << getInput().getType();
5056}
5057
5058template <typename OpTy>
5059static LogicalResult verifySubfieldLike(OpTy op) {
5060 if (op.getFieldIndex() >=
5061 firrtl::type_cast<typename OpTy::InputType>(op.getInput().getType())
5062 .getNumElements())
5063 return op.emitOpError("subfield element index is greater than the number "
5064 "of fields in the bundle type");
5065 return success();
5066}
5067LogicalResult SubfieldOp::verify() {
5068 return verifySubfieldLike<SubfieldOp>(*this);
5069}
5070LogicalResult OpenSubfieldOp::verify() {
5071 return verifySubfieldLike<OpenSubfieldOp>(*this);
5072}
5073
5074LogicalResult SubtagOp::verify() {
5075 if (getFieldIndex() >= getInput().getType().base().getNumElements())
5076 return emitOpError("subfield element index is greater than the number "
5077 "of fields in the bundle type");
5078 return success();
5079}
5080
/// Return true if `op` is accepted as constant by the worklist walk below.
///
/// As written, the walk:
///  - accepts ConstantOp, SpecialConstantOp and AggregateConstantOp as-is;
///  - for WireOp, SubindexOp and SubfieldOp, queues every *user* of the op's
///    result. The op is constant only if all of those users are accepted in
///    turn. Any other user, such as a connect or an expression reading the
///    value, makes the answer false;
///  - rejects NodeOp, AsSIntPrimOp, AsUIntPrimOp and every other op.
///
/// NOTE(review): despite the intent suggested by the queueing code, nodes and
/// casts are not looked through. Their case sets `constant = false` right
/// after queueing the input's defining op, so the loop exits before that op is
/// examined. The walk also follows users rather than drivers. Confirm the
/// intended semantics before relying on any look-through behavior. Simply
/// dropping the `constant = false` would risk revisiting a wire forever
/// (wire -> node user -> wire), since no visited set is kept.
bool firrtl::isConstant(Operation *op) {
  // Worklist of ops that need to be examined that should all be constant in
  // order for the input operation to be constant.
  SmallVector<Operation *, 8> worklist({op});

  // Mutable state indicating if this op is a constant. Assume it is a constant
  // and look for counterexamples.
  bool constant = true;

  // While we haven't found a counterexample and there are still ops in the
  // worklist, pull ops off the worklist. If it provides a counterexample, set
  // the `constant` to false (and exit on the next loop iteration). Otherwise,
  // look through the op or spawn off more ops to look at.
  while (constant && !(worklist.empty()))
    TypeSwitch<Operation *>(worklist.pop_back_val())
        // Nodes and casts: the push is effectively dead because `constant` is
        // cleared unconditionally (see NOTE above).
        .Case<NodeOp, AsSIntPrimOp, AsUIntPrimOp>([&](auto op) {
          if (auto definingOp = op.getInput().getDefiningOp())
            worklist.push_back(definingOp);
          constant = false;
        })
        // Wires and subaccesses: examine every user of the result.
        .Case<WireOp, SubindexOp, SubfieldOp>([&](auto op) {
          for (auto &use : op.getResult().getUses())
            worklist.push_back(use.getOwner());
        })
        // Leaf constants: nothing further to check.
        .Case<ConstantOp, SpecialConstantOp, AggregateConstantOp>([](auto) {})
        .Default([&](auto) { constant = false; });

  return constant;
}
5113
5114/// Return true if the specified value is a constant. This trivially checks for
5115/// `firrtl.constant` and friends, but also looks through subaccesses and
5116/// correctly handles wires driven with only constant values.
5117bool firrtl::isConstant(Value value) {
5118 if (auto *op = value.getDefiningOp())
5119 return isConstant(op);
5120 return false;
5121}
5122
5123LogicalResult ConstCastOp::verify() {
5124 if (!areTypesConstCastable(getResult().getType(), getInput().getType()))
5125 return emitOpError() << getInput().getType()
5126 << " is not 'const'-castable to "
5127 << getResult().getType();
5128 return success();
5129}
5130
5131FIRRTLType SubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
5132 std::optional<Location> loc) {
5133 auto inType = type_cast<BundleType>(type);
5134
5135 if (fieldIndex >= inType.getNumElements())
5136 return emitInferRetTypeError(loc,
5137 "subfield element index is greater than the "
5138 "number of fields in the bundle type");
5139
5140 // SubfieldOp verifier checks that the field index is valid with number of
5141 // subelements.
5142 return inType.getElementTypePreservingConst(fieldIndex);
5143}
5144
5145FIRRTLType OpenSubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
5146 std::optional<Location> loc) {
5147 auto inType = type_cast<OpenBundleType>(type);
5148
5149 if (fieldIndex >= inType.getNumElements())
5150 return emitInferRetTypeError(loc,
5151 "subfield element index is greater than the "
5152 "number of fields in the bundle type");
5153
5154 // OpenSubfieldOp verifier checks that the field index is valid with number of
5155 // subelements.
5156 return inType.getElementTypePreservingConst(fieldIndex);
5157}
5158
5159bool SubfieldOp::isFieldFlipped() {
5160 BundleType bundle = getInput().getType();
5161 return bundle.getElement(getFieldIndex()).isFlip;
5162}
5163bool OpenSubfieldOp::isFieldFlipped() {
5164 auto bundle = getInput().getType();
5165 return bundle.getElement(getFieldIndex()).isFlip;
5166}
5167
5168FIRRTLType SubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
5169 std::optional<Location> loc) {
5170 if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
5171 if (fieldIndex < vectorType.getNumElements())
5172 return vectorType.getElementTypePreservingConst();
5173 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
5174 "' in vector type ", type);
5175 }
5176 return emitInferRetTypeError(loc, "subindex requires vector operand");
5177}
5178
5179FIRRTLType OpenSubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
5180 std::optional<Location> loc) {
5181 if (auto vectorType = type_dyn_cast<OpenVectorType>(type)) {
5182 if (fieldIndex < vectorType.getNumElements())
5183 return vectorType.getElementTypePreservingConst();
5184 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
5185 "' in vector type ", type);
5186 }
5187
5188 return emitInferRetTypeError(loc, "subindex requires vector operand");
5189}
5190
5191FIRRTLType SubtagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
5192 OpaqueProperties properties,
5193 mlir::RegionRange regions,
5194 std::optional<Location> loc) {
5195 Adaptor adaptor(operands, attrs, properties, regions);
5196 auto inType = type_cast<FEnumType>(adaptor.getInput().getType());
5197 auto fieldIndex = adaptor.getFieldIndex();
5198
5199 if (fieldIndex >= inType.getNumElements())
5200 return emitInferRetTypeError(loc,
5201 "subtag element index is greater than the "
5202 "number of fields in the enum type");
5203
5204 // SubtagOp verifier checks that the field index is valid with number of
5205 // subelements.
5206 auto elementType = inType.getElement(fieldIndex).type;
5207 return elementType.getConstType(elementType.isConst() || inType.isConst());
5208}
5209
5210FIRRTLType SubaccessOp::inferReturnType(Type inType, Type indexType,
5211 std::optional<Location> loc) {
5212 if (!type_isa<UIntType>(indexType))
5213 return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
5214 indexType);
5215
5216 if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
5217 if (isConst(indexType))
5218 return vectorType.getElementTypePreservingConst();
5219 return vectorType.getElementType().getAllConstDroppedType();
5220 }
5221
5222 return emitInferRetTypeError(loc, "subaccess requires vector operand, not ",
5223 inType);
5224}
5225
5226FIRRTLType TagExtractOp::inferReturnType(FIRRTLType input,
5227 std::optional<Location> loc) {
5228 auto inType = type_cast<FEnumType>(input);
5229 return UIntType::get(inType.getContext(), inType.getTagWidth());
5230}
5231
5232ParseResult MultibitMuxOp::parse(OpAsmParser &parser, OperationState &result) {
5233 OpAsmParser::UnresolvedOperand index;
5234 SmallVector<OpAsmParser::UnresolvedOperand, 16> inputs;
5235 Type indexType, elemType;
5236
5237 if (parser.parseOperand(index) || parser.parseComma() ||
5238 parser.parseOperandList(inputs) ||
5239 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
5240 parser.parseType(indexType) || parser.parseComma() ||
5241 parser.parseType(elemType))
5242 return failure();
5243
5244 if (parser.resolveOperand(index, indexType, result.operands))
5245 return failure();
5246
5247 result.addTypes(elemType);
5248
5249 return parser.resolveOperands(inputs, elemType, result.operands);
5250}
5251
5252void MultibitMuxOp::print(OpAsmPrinter &p) {
5253 p << " " << getIndex() << ", ";
5254 p.printOperands(getInputs());
5255 p.printOptionalAttrDict((*this)->getAttrs());
5256 p << " : " << getIndex().getType() << ", " << getType();
5257}
5258
5259FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
5260 DictionaryAttr attrs,
5261 OpaqueProperties properties,
5262 mlir::RegionRange regions,
5263 std::optional<Location> loc) {
5264 if (operands.size() < 2)
5265 return emitInferRetTypeError(loc, "at least one input is required");
5266
5267 // Check all mux inputs have the same type.
5268 if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
5269 return operands[1].getType() == op.getType();
5270 }))
5271 return emitInferRetTypeError(loc, "all inputs must have the same type");
5272
5273 return type_cast<FIRRTLType>(operands[1].getType());
5274}
5275
5276//===----------------------------------------------------------------------===//
5277// ObjectSubfieldOp
5278//===----------------------------------------------------------------------===//
5279
5280LogicalResult ObjectSubfieldOp::inferReturnTypes(
5281 MLIRContext *context, std::optional<mlir::Location> location,
5282 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
5283 RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
5284 auto type =
5285 inferReturnType(operands, attributes, properties, regions, location);
5286 if (!type)
5287 return failure();
5288 inferredReturnTypes.push_back(type);
5289 return success();
5290}
5291
5292Type ObjectSubfieldOp::inferReturnType(Type inType, uint32_t fieldIndex,
5293 std::optional<Location> loc) {
5294 auto classType = dyn_cast<ClassType>(inType);
5295 if (!classType)
5296 return emitInferRetTypeError(loc, "base object is not a class");
5297
5298 if (classType.getNumElements() <= fieldIndex)
5299 return emitInferRetTypeError(loc, "element index is greater than the "
5300 "number of fields in the object");
5301 return classType.getElement(fieldIndex).type;
5302}
5303
5304void ObjectSubfieldOp::print(OpAsmPrinter &p) {
5305 auto input = getInput();
5306 auto classType = input.getType();
5307 p << ' ' << input << "[";
5308 p.printKeywordOrString(classType.getElement(getIndex()).name);
5309 p << "]";
5310 p.printOptionalAttrDict((*this)->getAttrs(), std::array{StringRef("index")});
5311 p << " : " << classType;
5312}
5313
5314ParseResult ObjectSubfieldOp::parse(OpAsmParser &parser,
5315 OperationState &result) {
5316 auto *context = parser.getContext();
5317
5318 OpAsmParser::UnresolvedOperand input;
5319 std::string fieldName;
5320 ClassType inputType;
5321 if (parser.parseOperand(input) || parser.parseLSquare() ||
5322 parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
5323 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
5324 parser.parseType(inputType) ||
5325 parser.resolveOperand(input, inputType, result.operands))
5326 return failure();
5327
5328 auto index = inputType.getElementIndex(fieldName);
5329 if (!index)
5330 return parser.emitError(parser.getNameLoc(),
5331 "unknown field " + fieldName + " in class type ")
5332 << inputType;
5333 result.getOrAddProperties<Properties>().setIndex(
5334 IntegerAttr::get(IntegerType::get(context, 32), *index));
5335
5336 SmallVector<Type> inferredReturnTypes;
5337 if (failed(inferReturnTypes(context, result.location, result.operands,
5338 result.attributes.getDictionary(context),
5339 result.getRawProperties(), result.regions,
5340 inferredReturnTypes)))
5341 return failure();
5342 result.addTypes(inferredReturnTypes);
5343
5344 return success();
5345}
5346
5347//===----------------------------------------------------------------------===//
5348// Binary Primitives
5349//===----------------------------------------------------------------------===//
5350
5351/// If LHS and RHS are both UInt or SInt types, the return true and fill in the
5352/// width of them if known. If unknown, return -1 for the widths.
5353/// The constness of the result is also returned, where if both lhs and rhs are
5354/// const, then the result is const.
5355///
5356/// On failure, this reports and error and returns false. This function should
5357/// not be used if you don't want an error reported.
5358static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth,
5359 int32_t &rhsWidth, bool &isConstResult,
5360 std::optional<Location> loc) {
5361 // Must have two integer types with the same signedness.
5362 auto lhsi = type_dyn_cast<IntType>(lhs);
5363 auto rhsi = type_dyn_cast<IntType>(rhs);
5364 if (!lhsi || !rhsi || lhsi.isSigned() != rhsi.isSigned()) {
5365 if (loc) {
5366 if (lhsi && !rhsi)
5367 mlir::emitError(*loc, "second operand must be an integer type, not ")
5368 << rhs;
5369 else if (!lhsi && rhsi)
5370 mlir::emitError(*loc, "first operand must be an integer type, not ")
5371 << lhs;
5372 else if (!lhsi && !rhsi)
5373 mlir::emitError(*loc, "operands must be integer types, not ")
5374 << lhs << " and " << rhs;
5375 else
5376 mlir::emitError(*loc, "operand signedness must match");
5377 }
5378 return false;
5379 }
5380
5381 lhsWidth = lhsi.getWidthOrSentinel();
5382 rhsWidth = rhsi.getWidthOrSentinel();
5383 isConstResult = lhsi.isConst() && rhsi.isConst();
5384 return true;
5385}
5386
5387LogicalResult impl::verifySameOperandsIntTypeKind(Operation *op) {
5388 assert(op->getNumOperands() == 2 &&
5389 "SameOperandsIntTypeKind on non-binary op");
5390 int32_t lhsWidth, rhsWidth;
5391 bool isConstResult;
5392 return success(isSameIntTypeKind(op->getOperand(0).getType(),
5393 op->getOperand(1).getType(), lhsWidth,
5394 rhsWidth, isConstResult, op->getLoc()));
5395}
5396
5398 std::optional<Location> loc) {
5399 int32_t lhsWidth, rhsWidth, resultWidth = -1;
5400 bool isConstResult = false;
5401 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5402 return {};
5403
5404 if (lhsWidth != -1 && rhsWidth != -1)
5405 resultWidth = std::max(lhsWidth, rhsWidth) + 1;
5406 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
5407 isConstResult);
5408}
5409
5410FIRRTLType MulPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5411 std::optional<Location> loc) {
5412 int32_t lhsWidth, rhsWidth, resultWidth = -1;
5413 bool isConstResult = false;
5414 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5415 return {};
5416
5417 if (lhsWidth != -1 && rhsWidth != -1)
5418 resultWidth = lhsWidth + rhsWidth;
5419
5420 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
5421 isConstResult);
5422}
5423
5424FIRRTLType DivPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5425 std::optional<Location> loc) {
5426 int32_t lhsWidth, rhsWidth;
5427 bool isConstResult = false;
5428 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5429 return {};
5430
5431 // For unsigned, the width is the width of the numerator on the LHS.
5432 if (type_isa<UIntType>(lhs))
5433 return UIntType::get(lhs.getContext(), lhsWidth, isConstResult);
5434
5435 // For signed, the width is the width of the numerator on the LHS, plus 1.
5436 int32_t resultWidth = lhsWidth != -1 ? lhsWidth + 1 : -1;
5437 return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
5438}
5439
5440FIRRTLType RemPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5441 std::optional<Location> loc) {
5442 int32_t lhsWidth, rhsWidth, resultWidth = -1;
5443 bool isConstResult = false;
5444 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5445 return {};
5446
5447 if (lhsWidth != -1 && rhsWidth != -1)
5448 resultWidth = std::min(lhsWidth, rhsWidth);
5449 return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
5450 isConstResult);
5451}
5452
5454 std::optional<Location> loc) {
5455 int32_t lhsWidth, rhsWidth, resultWidth = -1;
5456 bool isConstResult = false;
5457 if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5458 return {};
5459
5460 if (lhsWidth != -1 && rhsWidth != -1) {
5461 resultWidth = std::max(lhsWidth, rhsWidth);
5462 if (lhsWidth == resultWidth && lhs.isConst() == isConstResult &&
5463 isa<UIntType>(lhs))
5464 return lhs;
5465 if (rhsWidth == resultWidth && rhs.isConst() == isConstResult &&
5466 isa<UIntType>(rhs))
5467 return rhs;
5468 }
5469 return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
5470}
5471
5473 std::optional<Location> loc) {
5474 if (!type_isa<FVectorType>(lhs) || !type_isa<FVectorType>(rhs))
5475 return {};
5476
5477 auto lhsVec = type_cast<FVectorType>(lhs);
5478 auto rhsVec = type_cast<FVectorType>(rhs);
5479
5480 if (lhsVec.getNumElements() != rhsVec.getNumElements())
5481 return {};
5482
5483 auto elemType =
5484 impl::inferBitwiseResult(lhsVec.getElementTypePreservingConst(),
5485 rhsVec.getElementTypePreservingConst(), loc);
5486 if (!elemType)
5487 return {};
5488 auto elemBaseType = type_cast<FIRRTLBaseType>(elemType);
5489 return FVectorType::get(elemBaseType, lhsVec.getNumElements(),
5490 lhsVec.isConst() && rhsVec.isConst() &&
5491 elemBaseType.isConst());
5492}
5493
5495 std::optional<Location> loc) {
5496 return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
5497}
5498
5499FIRRTLType CatPrimOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
5500 OpaqueProperties properties,
5501 mlir::RegionRange regions,
5502 std::optional<Location> loc) {
5503 // If no operands, return a 0-bit UInt
5504 if (operands.empty())
5505 return UIntType::get(attrs.getContext(), 0);
5506
5507 // Check that all operands are Int types with same signedness
5508 bool isSigned = type_isa<SIntType>(operands[0].getType());
5509 for (auto operand : operands) {
5510 auto type = type_dyn_cast<IntType>(operand.getType());
5511 if (!type)
5512 return emitInferRetTypeError(loc, "all operands must be Int type");
5513 if (type.isSigned() != isSigned)
5514 return emitInferRetTypeError(loc,
5515 "all operands must have same signedness");
5516 }
5517
5518 // Calculate the total width and determine if result is const
5519 int32_t resultWidth = 0;
5520 bool isConstResult = true;
5521
5522 for (auto operand : operands) {
5523 auto type = type_cast<IntType>(operand.getType());
5524 int32_t width = type.getWidthOrSentinel();
5525
5526 // If any width is unknown, the result width is unknown
5527 if (width == -1) {
5528 resultWidth = -1;
5529 }
5530
5531 if (resultWidth != -1)
5532 resultWidth += width;
5533
5534 // Result is const only if all operands are const
5535 isConstResult &= type.isConst();
5536 }
5537
5538 // Create and return the result type
5539 return UIntType::get(attrs.getContext(), resultWidth, isConstResult);
5540}
5541
5542FIRRTLType DShlPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5543 std::optional<Location> loc) {
5544 auto lhsi = type_dyn_cast<IntType>(lhs);
5545 auto rhsui = type_dyn_cast<UIntType>(rhs);
5546 if (!rhsui || !lhsi)
5547 return emitInferRetTypeError(
5548 loc, "first operand should be integer, second unsigned int");
5549
5550 // If the left or right has unknown result type, then the operation does
5551 // too.
5552 auto width = lhsi.getWidthOrSentinel();
5553 if (width == -1 || !rhsui.getWidth().has_value()) {
5554 width = -1;
5555 } else {
5556 auto amount = *rhsui.getWidth();
5557 if (amount >= 32)
5558 return emitInferRetTypeError(loc,
5559 "shift amount too large: second operand of "
5560 "dshl is wider than 31 bits");
5561 int64_t newWidth = (int64_t)width + ((int64_t)1 << amount) - 1;
5562 if (newWidth > INT32_MAX)
5563 return emitInferRetTypeError(
5564 loc, "shift amount too large: first operand shifted by maximum "
5565 "amount exceeds maximum width");
5566 width = newWidth;
5567 }
5568 return IntType::get(lhs.getContext(), lhsi.isSigned(), width,
5569 lhsi.isConst() && rhsui.isConst());
5570}
5571
5572FIRRTLType DShlwPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5573 std::optional<Location> loc) {
5574 auto lhsi = type_dyn_cast<IntType>(lhs);
5575 auto rhsu = type_dyn_cast<UIntType>(rhs);
5576 if (!lhsi || !rhsu)
5577 return emitInferRetTypeError(
5578 loc, "first operand should be integer, second unsigned int");
5579 return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5580}
5581
5582FIRRTLType DShrPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
5583 std::optional<Location> loc) {
5584 auto lhsi = type_dyn_cast<IntType>(lhs);
5585 auto rhsu = type_dyn_cast<UIntType>(rhs);
5586 if (!lhsi || !rhsu)
5587 return emitInferRetTypeError(
5588 loc, "first operand should be integer, second unsigned int");
5589 return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5590}
5591
5592//===----------------------------------------------------------------------===//
5593// Unary Primitives
5594//===----------------------------------------------------------------------===//
5595
5596FIRRTLType SizeOfIntrinsicOp::inferReturnType(FIRRTLType input,
5597 std::optional<Location> loc) {
5598 return UIntType::get(input.getContext(), 32);
5599}
5600
5601FIRRTLType AsSIntPrimOp::inferReturnType(FIRRTLType input,
5602 std::optional<Location> loc) {
5603 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5604 if (!base)
5605 return emitInferRetTypeError(loc, "operand must be a scalar base type");
5606 int32_t width = base.getBitWidthOrSentinel();
5607 if (width == -2)
5608 return emitInferRetTypeError(loc, "operand must be a scalar type");
5609 return SIntType::get(input.getContext(), width, base.isConst());
5610}
5611
5612FIRRTLType AsUIntPrimOp::inferReturnType(FIRRTLType input,
5613 std::optional<Location> loc) {
5614 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5615 if (!base)
5616 return emitInferRetTypeError(loc, "operand must be a scalar base type");
5617 int32_t width = base.getBitWidthOrSentinel();
5618 if (width == -2)
5619 return emitInferRetTypeError(loc, "operand must be a scalar type");
5620 return UIntType::get(input.getContext(), width, base.isConst());
5621}
5622
5623FIRRTLType AsAsyncResetPrimOp::inferReturnType(FIRRTLType input,
5624 std::optional<Location> loc) {
5625 auto base = type_dyn_cast<FIRRTLBaseType>(input);
5626 if (!base)
5627 return emitInferRetTypeError(loc,
5628 "operand must be single bit scalar base type");
5629 int32_t width = base.getBitWidthOrSentinel();
5630 if (width == -2 || width == 0 || width > 1)
5631 return emitInferRetTypeError(loc, "operand must be single bit scalar type");
5632 return AsyncResetType::get(input.getContext(), base.isConst());
5633}
5634
5635FIRRTLType AsClockPrimOp::inferReturnType(FIRRTLType input,
5636 std::optional<Location> loc) {
5637 return ClockType::get(input.getContext(), isConst(input));
5638}
5639
5640FIRRTLType CvtPrimOp::inferReturnType(FIRRTLType input,
5641 std::optional<Location> loc) {
5642 if (auto uiType = type_dyn_cast<UIntType>(input)) {
5643 auto width = uiType.getWidthOrSentinel();
5644 if (width != -1)
5645 ++width;
5646 return SIntType::get(input.getContext(), width, uiType.isConst());
5647 }
5648
5649 if (type_isa<SIntType>(input))
5650 return input;
5651
5652 return emitInferRetTypeError(loc, "operand must have integer type");
5653}
5654
5655FIRRTLType NegPrimOp::inferReturnType(FIRRTLType input,
5656 std::optional<Location> loc) {
5657 auto inputi = type_dyn_cast<IntType>(input);
5658 if (!inputi)
5659 return emitInferRetTypeError(loc, "operand must have integer type");
5660 int32_t width = inputi.getWidthOrSentinel();
5661 if (width != -1)
5662 ++width;
5663 return SIntType::get(input.getContext(), width, inputi.isConst());
5664}
5665
5666FIRRTLType NotPrimOp::inferReturnType(FIRRTLType input,
5667 std::optional<Location> loc) {
5668 auto inputi = type_dyn_cast<IntType>(input);
5669 if (!inputi)
5670 return emitInferRetTypeError(loc, "operand must have integer type");
5671 if (isa<UIntType>(inputi))
5672 return inputi;
5673 return UIntType::get(input.getContext(), inputi.getWidthOrSentinel(),
5674 inputi.isConst());
5675}
5676
5678 std::optional<Location> loc) {
5679 return UIntType::get(input.getContext(), 1, isConst(input));
5680}
5681
5682//===----------------------------------------------------------------------===//
5683// Other Operations
5684//===----------------------------------------------------------------------===//
5685
5686FIRRTLType BitsPrimOp::inferReturnType(FIRRTLType input, int64_t high,
5687 int64_t low,
5688 std::optional<Location> loc) {
5689 auto inputi = type_dyn_cast<IntType>(input);
5690 if (!inputi)
5691 return emitInferRetTypeError(
5692 loc, "input type should be the int type but got ", input);
5693
5694 // High must be >= low and both most be non-negative.
5695 if (high < low)
5696 return emitInferRetTypeError(
5697 loc, "high must be equal or greater than low, but got high = ", high,
5698 ", low = ", low);
5699
5700 if (low < 0)
5701 return emitInferRetTypeError(loc, "low must be non-negative but got ", low);
5702
5703 // If the input has staticly known width, check it. Both and low must be
5704 // strictly less than width.
5705 int32_t width = inputi.getWidthOrSentinel();
5706 if (width != -1 && high >= width)
5707 return emitInferRetTypeError(
5708 loc,
5709 "high must be smaller than the width of input, but got high = ", high,
5710 ", width = ", width);
5711
5712 return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
5713}
5714
5715FIRRTLType HeadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5716 std::optional<Location> loc) {
5717
5718 auto inputi = type_dyn_cast<IntType>(input);
5719 if (amount < 0 || !inputi)
5720 return emitInferRetTypeError(
5721 loc, "operand must have integer type and amount must be >= 0");
5722
5723 int32_t width = inputi.getWidthOrSentinel();
5724 if (width != -1 && amount > width)
5725 return emitInferRetTypeError(loc, "amount larger than input width");
5726
5727 return UIntType::get(input.getContext(), amount, inputi.isConst());
5728}
5729
/// Infer the result type for a multiplexer given its two operand types, which
/// may be aggregates.
///
/// This essentially performs a pairwise comparison of fields and elements, as
/// follows:
/// - Identical operands are inferred to their common type.
/// - Integer operands are inferred to the wider one if both have a known
///   width, and to the widthless one otherwise.
/// - Enums with the same variants (names and values, in order) are inferred
///   variant by variant.
/// - Vectors of equal length are inferred based on the element type.
/// - Bundles are inferred in a pairwise fashion based on the field types.
/// The outermost result type is const only if the condition and both operands
/// are const.
                                         FIRRTLBaseType low,
                                         bool isConstCondition,
                                         std::optional<Location> loc) {
  // If the types are identical we're done. The type stays as-is (including
  // constness) only for a const condition; otherwise all constness is dropped.
  if (high == low)
    return isConstCondition ? low : low.getAllConstDroppedType();

  // The base types need to be equivalent.
  if (high.getTypeID() != low.getTypeID())
    return emitInferRetTypeError<FIRRTLBaseType>(
        loc, "incompatible mux operand types, true value type: ", high,
        ", false value type: ", low);

  // Constness of the outermost result type. Nested element/field/variant
  // types get their constness from the recursive calls below.
  bool outerTypeIsConst = isConstCondition && low.isConst() && high.isConst();

  // Two different Int types can be compatible. If either has unknown width,
  // then return it. If both are known but different width, then return the
  // larger one. (On equal widths the ternary below picks `high`.)
  if (type_isa<IntType>(low)) {
    int32_t highWidth = high.getBitWidthOrSentinel();
    int32_t lowWidth = low.getBitWidthOrSentinel();
    if (lowWidth == -1)
      return low.getConstType(outerTypeIsConst);
    if (highWidth == -1)
      return high.getConstType(outerTypeIsConst);
    return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
  }

  // Two different Enum types are compatible when they have the same number of
  // variants, matching variant names and values in order, and pairwise
  // mux-compatible variant types (inferred recursively).
  auto highEnum = type_dyn_cast<FEnumType>(high);
  auto lowEnum = type_dyn_cast<FEnumType>(low);
  if (lowEnum && highEnum) {
    if (lowEnum.getNumElements() != highEnum.getNumElements())
      return emitInferRetTypeError<FIRRTLBaseType>(
          loc, "incompatible mux operand types, true value type: ", high,
          ", false value type: ", low);
    SmallVector<FEnumType::EnumElement> elements;
    // NOTE: the structured bindings `high`/`low` shadow the function
    // parameters inside this loop only; after the loop, `high` again refers
    // to the parameter.
    for (auto [high, low] : llvm::zip_equal(highEnum, lowEnum)) {
      // Variants should have the same name and value.
      if (high.name != low.name || high.value != low.value)
        return emitInferRetTypeError<FIRRTLBaseType>(
            loc, "incompatible mux operand types, true value type: ", highEnum,
            ", false value type: ", lowEnum);
      // An enumeration can have constant variants only if the whole
      // enumeration is constant, so this logic can differ a bit from bundles.
      auto inner =
          inferMuxReturnType(high.type, low.type, isConstCondition, loc);
      if (!inner)
        return {};
      elements.emplace_back(high.name, high.value, inner);
    }
    return FEnumType::get(high.getContext(), elements, outerTypeIsConst);
  }

  // Infer vector types by comparing the element types. Vectors of different
  // lengths skip this branch and end in the final "invalid mux operand types"
  // error below.
  auto highVector = type_dyn_cast<FVectorType>(high);
  auto lowVector = type_dyn_cast<FVectorType>(low);
  if (highVector && lowVector &&
      highVector.getNumElements() == lowVector.getNumElements()) {
    auto inner = inferMuxReturnType(highVector.getElementTypePreservingConst(),
                                    lowVector.getElementTypePreservingConst(),
                                    isConstCondition, loc);
    if (!inner)
      return {};
    return FVectorType::get(inner, lowVector.getNumElements(),
                            outerTypeIsConst);
  }

  // Infer bundle types pairwise: fields must match by position, name and
  // flip, and each field type is inferred recursively.
  auto highBundle = type_dyn_cast<BundleType>(high);
  auto lowBundle = type_dyn_cast<BundleType>(low);
  if (highBundle && lowBundle) {
    auto highElements = highBundle.getElements();
    auto lowElements = lowBundle.getElements();
    size_t numElements = highElements.size();

    SmallVector<BundleType::BundleElement> newElements;
    if (numElements == lowElements.size()) {
      // Set on a name or flip mismatch. Like a field-count mismatch, this
      // falls through to the bundle-fields error below.
      bool failed = false;
      for (size_t i = 0; i < numElements; ++i) {
        if (highElements[i].name != lowElements[i].name ||
            highElements[i].isFlip != lowElements[i].isFlip) {
          failed = true;
          break;
        }
        auto element = highElements[i];
        element.type = inferMuxReturnType(
            highBundle.getElementTypePreservingConst(i),
            lowBundle.getElementTypePreservingConst(i), isConstCondition, loc);
        if (!element.type)
          return {};
        newElements.push_back(element);
      }
      if (!failed)
        return BundleType::get(low.getContext(), newElements, outerTypeIsConst);
    }
    return emitInferRetTypeError<FIRRTLBaseType>(
        loc, "incompatible mux operand bundle fields, true value type: ", high,
        ", false value type: ", low);
  }

  // If we arrive here the types of the two mux arms are fundamentally
  // incompatible.
  return emitInferRetTypeError<FIRRTLBaseType>(
      loc, "invalid mux operand types, true value type: ", high,
      ", false value type: ", low);
}
5849
5850FIRRTLType MuxPrimOp::inferReturnType(FIRRTLType sel, FIRRTLType high,
5851 FIRRTLType low,
5852 std::optional<Location> loc) {
5853 auto highType = type_dyn_cast<FIRRTLBaseType>(high);
5854 auto lowType = type_dyn_cast<FIRRTLBaseType>(low);
5855 if (!highType || !lowType)
5856 return emitInferRetTypeError(loc, "operands must be base type");
5857 return inferMuxReturnType(highType, lowType, isConst(sel), loc);
5858}
5859
5860FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
5861 DictionaryAttr attrs,
5862 OpaqueProperties properties,
5863 mlir::RegionRange regions,
5864 std::optional<Location> loc) {
5865 auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5866 auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5867 if (!highType || !lowType)
5868 return emitInferRetTypeError(loc, "operands must be base type");
5869 return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5870 loc);
5871}
5872
5873FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
5874 DictionaryAttr attrs,
5875 OpaqueProperties properties,
5876 mlir::RegionRange regions,
5877 std::optional<Location> loc) {
5878 SmallVector<FIRRTLBaseType> types;
5879 FIRRTLBaseType result;
5880 for (unsigned i = 1; i < 5; i++) {
5881 types.push_back(type_dyn_cast<FIRRTLBaseType>(operands[i].getType()));
5882 if (!types.back())
5883 return emitInferRetTypeError(loc, "operands must be base type");
5884 if (result) {
5885 result = inferMuxReturnType(result, types.back(),
5886 isConst(operands[0].getType()), loc);
5887 if (!result)
5888 return result;
5889 } else {
5890 result = types.back();
5891 }
5892 }
5893 return result;
5894}
5895
5896FIRRTLType PadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5897 std::optional<Location> loc) {
5898 auto inputi = type_dyn_cast<IntType>(input);
5899 if (amount < 0 || !inputi)
5900 return emitInferRetTypeError(
5901 loc, "pad input must be integer and amount must be >= 0");
5902
5903 int32_t width = inputi.getWidthOrSentinel();
5904 if (width == -1)
5905 return inputi;
5906
5907 width = std::max<int32_t>(width, amount);
5908 return IntType::get(input.getContext(), inputi.isSigned(), width,
5909 inputi.isConst());
5910}
5911
5912FIRRTLType ShlPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5913 std::optional<Location> loc) {
5914 auto inputi = type_dyn_cast<IntType>(input);
5915 if (amount < 0 || !inputi)
5916 return emitInferRetTypeError(
5917 loc, "shl input must be integer and amount must be >= 0");
5918
5919 int32_t width = inputi.getWidthOrSentinel();
5920 if (width != -1)
5921 width += amount;
5922
5923 return IntType::get(input.getContext(), inputi.isSigned(), width,
5924 inputi.isConst());
5925}
5926
5927FIRRTLType ShrPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5928 std::optional<Location> loc) {
5929 auto inputi = type_dyn_cast<IntType>(input);
5930 if (amount < 0 || !inputi)
5931 return emitInferRetTypeError(
5932 loc, "shr input must be integer and amount must be >= 0");
5933
5934 int32_t width = inputi.getWidthOrSentinel();
5935 if (width != -1) {
5936 // UInt saturates at 0 bits, SInt at 1 bit
5937 int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
5938 width = std::max<int32_t>(minWidth, width - amount);
5939 }
5940
5941 return IntType::get(input.getContext(), inputi.isSigned(), width,
5942 inputi.isConst());
5943}
5944
5945FIRRTLType TailPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
5946 std::optional<Location> loc) {
5947
5948 auto inputi = type_dyn_cast<IntType>(input);
5949 if (amount < 0 || !inputi)
5950 return emitInferRetTypeError(
5951 loc, "tail input must be integer and amount must be >= 0");
5952
5953 int32_t width = inputi.getWidthOrSentinel();
5954 if (width != -1) {
5955 if (width < amount)
5956 return emitInferRetTypeError(
5957 loc, "amount must be less than or equal operand width");
5958 width -= amount;
5959 }
5960
5961 return IntType::get(input.getContext(), false, width, inputi.isConst());
5962}
5963
5964//===----------------------------------------------------------------------===//
5965// VerbatimExprOp
5966//===----------------------------------------------------------------------===//
5967
5968void VerbatimExprOp::getAsmResultNames(
5969 function_ref<void(Value, StringRef)> setNameFn) {
5970 // If the text is macro like, then use a pretty name. We only take the
5971 // text up to a weird character (like a paren) and currently ignore
5972 // parenthesized expressions.
5973 auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5974 auto name = getText();
5975 // Ignore a leading ` in macro name.
5976 if (name.starts_with("`"))
5977 name = name.drop_front();
5978 name = name.take_while(isOkCharacter);
5979 if (!name.empty())
5980 setNameFn(getResult(), name);
5981}
5982
5983//===----------------------------------------------------------------------===//
5984// VerbatimWireOp
5985//===----------------------------------------------------------------------===//
5986
5987void VerbatimWireOp::getAsmResultNames(
5988 function_ref<void(Value, StringRef)> setNameFn) {
5989 // If the text is macro like, then use a pretty name. We only take the
5990 // text up to a weird character (like a paren) and currently ignore
5991 // parenthesized expressions.
5992 auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5993 auto name = getText();
5994 // Ignore a leading ` in macro name.
5995 if (name.starts_with("`"))
5996 name = name.drop_front();
5997 name = name.take_while(isOkCharacter);
5998 if (!name.empty())
5999 setNameFn(getResult(), name);
6000}
6001
6002//===----------------------------------------------------------------------===//
6003// DPICallIntrinsicOp
6004//===----------------------------------------------------------------------===//
6005
6006static bool isTypeAllowedForDPI(Operation *op, Type type) {
6007 return !type.walk([&](firrtl::IntType intType) -> mlir::WalkResult {
6008 auto width = intType.getWidth();
6009 if (width < 0) {
6010 op->emitError() << "unknown width is not allowed for DPI";
6011 return WalkResult::interrupt();
6012 }
6013 if (width == 1 || width == 8 || width == 16 || width == 32 ||
6014 width >= 64)
6015 return WalkResult::advance();
6016 op->emitError()
6017 << "integer types used by DPI functions must have a "
6018 "specific bit width; "
6019 "it must be equal to 1(bit), 8(byte), 16(shortint), "
6020 "32(int), 64(longint) "
6021 "or greater than 64, but got "
6022 << intType;
6023 return WalkResult::interrupt();
6024 })
6025 .wasInterrupted();
6026}
6027
6028LogicalResult DPICallIntrinsicOp::verify() {
6029 if (auto inputNames = getInputNames()) {
6030 if (getInputs().size() != inputNames->size())
6031 return emitError() << "inputNames has " << inputNames->size()
6032 << " elements but there are " << getInputs().size()
6033 << " input arguments";
6034 }
6035 if (auto outputName = getOutputName())
6036 if (getNumResults() == 0)
6037 return emitError() << "output name is given but there is no result";
6038
6039 auto checkType = [this](Type type) {
6040 return isTypeAllowedForDPI(*this, type);
6041 };
6042 return success(llvm::all_of(this->getResultTypes(), checkType) &&
6043 llvm::all_of(this->getOperandTypes(), checkType));
6044}
6045
6046SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
6047DPICallIntrinsicOp::computeDataFlow() {
6048 if (getClock())
6049 return {};
6050
6051 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
6052
6053 for (auto operand : getOperands()) {
6054 auto type = type_cast<FIRRTLBaseType>(operand.getType());
6055 auto baseFieldRef = getFieldRefFromValue(operand);
6056 SmallVector<circt::FieldRef> operandFields;
6058 type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
6059 operandFields.push_back(baseFieldRef.getSubField(dstIndex));
6060 });
6061
6062 // Record operand -> result dependency.
6063 for (auto result : getResults())
6065 type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
6066 for (auto field : operandFields)
6067 deps.emplace_back(circt::FieldRef(result, dstIndex), field);
6068 });
6069 }
6070 return deps;
6071}
6072
6073//===----------------------------------------------------------------------===//
6074// Conversions to/from structs in the standard dialect.
6075//===----------------------------------------------------------------------===//
6076
6077LogicalResult HWStructCastOp::verify() {
6078 // We must have a bundle and a struct, with matching pairwise fields
6079 BundleType bundleType;
6080 hw::StructType structType;
6081 if ((bundleType = type_dyn_cast<BundleType>(getOperand().getType()))) {
6082 structType = dyn_cast<hw::StructType>(getType());
6083 if (!structType)
6084 return emitError("result type must be a struct");
6085 } else if ((bundleType = type_dyn_cast<BundleType>(getType()))) {
6086 structType = dyn_cast<hw::StructType>(getOperand().getType());
6087 if (!structType)
6088 return emitError("operand type must be a struct");
6089 } else {
6090 return emitError("either source or result type must be a bundle type");
6091 }
6092
6093 auto firFields = bundleType.getElements();
6094 auto hwFields = structType.getElements();
6095 if (firFields.size() != hwFields.size())
6096 return emitError("bundle and struct have different number of fields");
6097
6098 for (size_t findex = 0, fend = firFields.size(); findex < fend; ++findex) {
6099 if (firFields[findex].name.getValue() != hwFields[findex].name)
6100 return emitError("field names don't match '")
6101 << firFields[findex].name.getValue() << "', '"
6102 << hwFields[findex].name.getValue() << "'";
6103 int64_t firWidth =
6104 FIRRTLBaseType(firFields[findex].type).getBitWidthOrSentinel();
6105 int64_t hwWidth = hw::getBitWidth(hwFields[findex].type);
6106 if (firWidth > 0 && hwWidth > 0 && firWidth != hwWidth)
6107 return emitError("size of field '")
6108 << hwFields[findex].name.getValue() << "' don't match " << firWidth
6109 << ", " << hwWidth;
6110 }
6111
6112 return success();
6113}
6114
6115LogicalResult BitCastOp::verify() {
6116 auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true);
6117 auto resTypeBits = getBitWidth(getType());
6118 if (inTypeBits.has_value() && resTypeBits.has_value()) {
6119 // Bitwidths must match for valid bit
6120 if (*inTypeBits == *resTypeBits) {
6121 // non-'const' cannot be casted to 'const'
6122 if (containsConst(getType()) && !isConst(getOperand().getType()))
6123 return emitError("cannot cast non-'const' input type ")
6124 << getOperand().getType() << " to 'const' result type "
6125 << getType();
6126 return success();
6127 }
6128 return emitError("the bitwidth of input (")
6129 << *inTypeBits << ") and result (" << *resTypeBits
6130 << ") don't match";
6131 }
6132 if (!inTypeBits.has_value())
6133 return emitError("bitwidth cannot be determined for input operand type ")
6134 << getInput().getType();
6135 return emitError("bitwidth cannot be determined for result type ")
6136 << getType();
6137}
6138
6139//===----------------------------------------------------------------------===//
6140// Custom attr-dict Directive that Elides Annotations
6141//===----------------------------------------------------------------------===//
6142
6143/// Parse an optional attribute dictionary, adding an empty 'annotations'
6144/// attribute if not specified.
6145static ParseResult parseElideAnnotations(OpAsmParser &parser,
6146 NamedAttrList &resultAttrs) {
6147 auto result = parser.parseOptionalAttrDict(resultAttrs);
6148 if (!resultAttrs.get("annotations"))
6149 resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
6150
6151 return result;
6152}
6153
6154static void printElideAnnotations(OpAsmPrinter &p, Operation *op,
6155 DictionaryAttr attr,
6156 ArrayRef<StringRef> extraElides = {}) {
6157 SmallVector<StringRef> elidedAttrs(extraElides.begin(), extraElides.end());
6158 // Elide "annotations" if it is empty.
6159 if (op->getAttrOfType<ArrayAttr>("annotations").empty())
6160 elidedAttrs.push_back("annotations");
6161 // Elide "nameKind".
6162 elidedAttrs.push_back("nameKind");
6163
6164 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
6165}
6166
6167/// Parse an optional attribute dictionary, adding empty 'annotations' and
6168/// 'portAnnotations' attributes if not specified.
6169static ParseResult parseElidePortAnnotations(OpAsmParser &parser,
6170 NamedAttrList &resultAttrs) {
6171 auto result = parseElideAnnotations(parser, resultAttrs);
6172
6173 if (!resultAttrs.get("portAnnotations")) {
6174 SmallVector<Attribute, 16> portAnnotations(
6175 parser.getNumResults(), parser.getBuilder().getArrayAttr({}));
6176 resultAttrs.append("portAnnotations",
6177 parser.getBuilder().getArrayAttr(portAnnotations));
6178 }
6179 return result;
6180}
6181
6182// Elide 'annotations' and 'portAnnotations' attributes if they are empty.
6183static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op,
6184 DictionaryAttr attr,
6185 ArrayRef<StringRef> extraElides = {}) {
6186 SmallVector<StringRef, 2> elidedAttrs(extraElides.begin(), extraElides.end());
6187
6188 if (llvm::all_of(op->getAttrOfType<ArrayAttr>("portAnnotations"),
6189 [&](Attribute a) { return cast<ArrayAttr>(a).empty(); }))
6190 elidedAttrs.push_back("portAnnotations");
6191 printElideAnnotations(p, op, attr, elidedAttrs);
6192}
6193
6194//===----------------------------------------------------------------------===//
6195// NameKind Custom Directive
6196//===----------------------------------------------------------------------===//
6197
6198static ParseResult parseNameKind(OpAsmParser &parser,
6199 firrtl::NameKindEnumAttr &result) {
6200 StringRef keyword;
6201
6202 if (!parser.parseOptionalKeyword(&keyword,
6203 {"interesting_name", "droppable_name"})) {
6204 auto kind = symbolizeNameKindEnum(keyword);
6205 result = NameKindEnumAttr::get(parser.getContext(), kind.value());
6206 return success();
6207 }
6208
6209 // Default is droppable name.
6210 result =
6211 NameKindEnumAttr::get(parser.getContext(), NameKindEnum::DroppableName);
6212 return success();
6213}
6214
6215static void printNameKind(OpAsmPrinter &p, Operation *op,
6216 firrtl::NameKindEnumAttr attr,
6217 ArrayRef<StringRef> extraElides = {}) {
6218 if (attr.getValue() != NameKindEnum::DroppableName)
6219 p << " " << stringifyNameKindEnum(attr.getValue());
6220}
6221
6222//===----------------------------------------------------------------------===//
6223// ImplicitSSAName Custom Directive
6224//===----------------------------------------------------------------------===//
6225
6226static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser,
6227 NamedAttrList &resultAttrs) {
6228 if (parseElideAnnotations(parser, resultAttrs))
6229 return failure();
6230 inferImplicitSSAName(parser, resultAttrs);
6231 return success();
6232}
6233
/// Print the attribute dictionary of an op that uses the implicit-SSA-name
/// syntax. The attribute names collected in `elides`, including the forceable
/// attribute, are never printed. elideImplicitSSAName receives the same list
/// (NOTE(review): presumably it adds 'name' when it matches the SSA name;
/// confirm against its definition). printElideAnnotations then also drops an
/// empty 'annotations' and 'nameKind'.
static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op,
                                       DictionaryAttr attrs) {
  SmallVector<StringRef, 4> elides;
  elides.push_back(Forceable::getForceableAttrName());
  elideImplicitSSAName(p, op, attrs, elides);
  printElideAnnotations(p, op, attrs, elides);
}
6242
6243//===----------------------------------------------------------------------===//
6244// MemOp Custom attr-dict Directive
6245//===----------------------------------------------------------------------===//
6246
6247static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs) {
6248 return parseElidePortAnnotations(parser, resultAttrs);
6249}
6250
6251/// Always elide "ruw" and elide "annotations" if it exists or if it is empty.
6252static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr) {
6253 // "ruw" and "inner_sym" is always elided.
6254 printElidePortAnnotations(p, op, attr, {"ruw", "inner_sym"});
6255}
6256
6257//===----------------------------------------------------------------------===//
6258// ClassInterface custom directive
6259//===----------------------------------------------------------------------===//
6260
6261static ParseResult parseClassInterface(OpAsmParser &parser, Type &result) {
6262 ClassType type;
6263 if (ClassType::parseInterface(parser, type))
6264 return failure();
6265 result = type;
6266 return success();
6267}
6268
6269static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type) {
6270 type.printInterface(p);
6271}
6272
6273//===----------------------------------------------------------------------===//
6274// Miscellaneous custom elision logic.
6275//===----------------------------------------------------------------------===//
6276
6277static ParseResult parseElideEmptyName(OpAsmParser &p,
6278 NamedAttrList &resultAttrs) {
6279 auto result = p.parseOptionalAttrDict(resultAttrs);
6280 if (!resultAttrs.get("name"))
6281 resultAttrs.append("name", p.getBuilder().getStringAttr(""));
6282
6283 return result;
6284}
6285
6286static void printElideEmptyName(OpAsmPrinter &p, Operation *op,
6287 DictionaryAttr attr,
6288 ArrayRef<StringRef> extraElides = {}) {
6289 SmallVector<StringRef> elides(extraElides.begin(), extraElides.end());
6290 if (op->getAttrOfType<StringAttr>("name").getValue().empty())
6291 elides.push_back("name");
6292
6293 p.printOptionalAttrDict(op->getAttrs(), elides);
6294}
6295
6296static ParseResult parsePrintfAttrs(OpAsmParser &p,
6297 NamedAttrList &resultAttrs) {
6298 return parseElideEmptyName(p, resultAttrs);
6299}
6300
6301static void printPrintfAttrs(OpAsmPrinter &p, Operation *op,
6302 DictionaryAttr attr) {
6303 printElideEmptyName(p, op, attr, {"formatString"});
6304}
6305
6306static ParseResult parseFPrintfAttrs(OpAsmParser &p,
6307 NamedAttrList &resultAttrs) {
6308 return parseElideEmptyName(p, resultAttrs);
6309}
6310
6311static void printFPrintfAttrs(OpAsmPrinter &p, Operation *op,
6312 DictionaryAttr attr) {
6313 printElideEmptyName(p, op, attr,
6314 {"formatString", "outputFile", "operandSegmentSizes"});
6315}
6316
6317static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
6318 return parseElideEmptyName(p, resultAttrs);
6319}
6320
6321static void printStopAttrs(OpAsmPrinter &p, Operation *op,
6322 DictionaryAttr attr) {
6323 printElideEmptyName(p, op, attr, {"exitCode"});
6324}
6325
6326static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
6327 return parseElideEmptyName(p, resultAttrs);
6328}
6329
6330static void printVerifAttrs(OpAsmPrinter &p, Operation *op,
6331 DictionaryAttr attr) {
6332 printElideEmptyName(p, op, attr, {"message"});
6333}
6334
6335//===----------------------------------------------------------------------===//
6336// Various namers.
6337//===----------------------------------------------------------------------===//
6338
6339static void genericAsmResultNames(Operation *op,
6340 OpAsmSetValueNameFn setNameFn) {
6341 // Many firrtl dialect operations have an optional 'name' attribute. If
6342 // present, use it.
6343 if (op->getNumResults() == 1)
6344 if (auto nameAttr = op->getAttrOfType<StringAttr>("name"))
6345 setNameFn(op->getResult(0), nameAttr.getValue());
6346}
6347
6348void AddPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6349 genericAsmResultNames(*this, setNameFn);
6350}
6351
6352void AndPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6353 genericAsmResultNames(*this, setNameFn);
6354}
6355
6356void AndRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6357 genericAsmResultNames(*this, setNameFn);
6358}
6359
6360void SizeOfIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6361 genericAsmResultNames(*this, setNameFn);
6362}
6363void AsAsyncResetPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6364 genericAsmResultNames(*this, setNameFn);
6365}
6366void AsClockPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6367 genericAsmResultNames(*this, setNameFn);
6368}
6369void AsSIntPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6370 genericAsmResultNames(*this, setNameFn);
6371}
6372void AsUIntPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6373 genericAsmResultNames(*this, setNameFn);
6374}
6375void BitsPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6376 genericAsmResultNames(*this, setNameFn);
6377}
6378void CatPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6379 genericAsmResultNames(*this, setNameFn);
6380}
6381void CvtPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6382 genericAsmResultNames(*this, setNameFn);
6383}
6384void DShlPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6385 genericAsmResultNames(*this, setNameFn);
6386}
6387void DShlwPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6388 genericAsmResultNames(*this, setNameFn);
6389}
6390void DShrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6391 genericAsmResultNames(*this, setNameFn);
6392}
6393void DivPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6394 genericAsmResultNames(*this, setNameFn);
6395}
6396void EQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6397 genericAsmResultNames(*this, setNameFn);
6398}
6399void GEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6400 genericAsmResultNames(*this, setNameFn);
6401}
6402void GTPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6403 genericAsmResultNames(*this, setNameFn);
6404}
6405void GenericIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6406 genericAsmResultNames(*this, setNameFn);
6407}
6408void HeadPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6409 genericAsmResultNames(*this, setNameFn);
6410}
6411void IntegerAddOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6412 genericAsmResultNames(*this, setNameFn);
6413}
6414void IntegerMulOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6415 genericAsmResultNames(*this, setNameFn);
6416}
6417void IntegerShrOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6418 genericAsmResultNames(*this, setNameFn);
6419}
6420void IntegerShlOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6421 genericAsmResultNames(*this, setNameFn);
6422}
6423void IsTagOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6424 genericAsmResultNames(*this, setNameFn);
6425}
6426void IsXIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6427 genericAsmResultNames(*this, setNameFn);
6428}
6429void PlusArgsValueIntrinsicOp::getAsmResultNames(
6430 OpAsmSetValueNameFn setNameFn) {
6431 genericAsmResultNames(*this, setNameFn);
6432}
6433void PlusArgsTestIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6434 genericAsmResultNames(*this, setNameFn);
6435}
6436void LEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6437 genericAsmResultNames(*this, setNameFn);
6438}
6439void LTPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6440 genericAsmResultNames(*this, setNameFn);
6441}
6442void MulPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6443 genericAsmResultNames(*this, setNameFn);
6444}
6445void MultibitMuxOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6446 genericAsmResultNames(*this, setNameFn);
6447}
6448void MuxPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6449 genericAsmResultNames(*this, setNameFn);
6450}
6451void Mux4CellIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6452 genericAsmResultNames(*this, setNameFn);
6453}
6454void Mux2CellIntrinsicOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6455 genericAsmResultNames(*this, setNameFn);
6456}
6457void NEQPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6458 genericAsmResultNames(*this, setNameFn);
6459}
6460void NegPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6461 genericAsmResultNames(*this, setNameFn);
6462}
6463void NotPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6464 genericAsmResultNames(*this, setNameFn);
6465}
6466void OrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6467 genericAsmResultNames(*this, setNameFn);
6468}
6469void OrRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6470 genericAsmResultNames(*this, setNameFn);
6471}
6472void PadPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6473 genericAsmResultNames(*this, setNameFn);
6474}
6475void RemPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6476 genericAsmResultNames(*this, setNameFn);
6477}
6478void ShlPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6479 genericAsmResultNames(*this, setNameFn);
6480}
6481void ShrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6482 genericAsmResultNames(*this, setNameFn);
6483}
6484
6485void SubPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6486 genericAsmResultNames(*this, setNameFn);
6487}
6488
6489void SubaccessOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6490 genericAsmResultNames(*this, setNameFn);
6491}
6492
6493void SubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6494 genericAsmResultNames(*this, setNameFn);
6495}
6496
6497void OpenSubfieldOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6498 genericAsmResultNames(*this, setNameFn);
6499}
6500
6501void SubtagOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6502 genericAsmResultNames(*this, setNameFn);
6503}
6504
6505void SubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6506 genericAsmResultNames(*this, setNameFn);
6507}
6508
6509void OpenSubindexOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6510 genericAsmResultNames(*this, setNameFn);
6511}
6512
6513void TagExtractOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6514 genericAsmResultNames(*this, setNameFn);
6515}
6516
6517void TailPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6518 genericAsmResultNames(*this, setNameFn);
6519}
6520
6521void XorPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6522 genericAsmResultNames(*this, setNameFn);
6523}
6524
6525void XorRPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6526 genericAsmResultNames(*this, setNameFn);
6527}
6528
6529void UninferredResetCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6530 genericAsmResultNames(*this, setNameFn);
6531}
6532
6533void ConstCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6534 genericAsmResultNames(*this, setNameFn);
6535}
6536
6537void ElementwiseXorPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6538 genericAsmResultNames(*this, setNameFn);
6539}
6540
6541void ElementwiseOrPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6542 genericAsmResultNames(*this, setNameFn);
6543}
6544
6545void ElementwiseAndPrimOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6546 genericAsmResultNames(*this, setNameFn);
6547}
6548
6549//===----------------------------------------------------------------------===//
6550// RefOps
6551//===----------------------------------------------------------------------===//
6552
6553void RefCastOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6554 genericAsmResultNames(*this, setNameFn);
6555}
6556
6557void RefResolveOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6558 genericAsmResultNames(*this, setNameFn);
6559}
6560
6561void RefSendOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6562 genericAsmResultNames(*this, setNameFn);
6563}
6564
6565void RefSubOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6566 genericAsmResultNames(*this, setNameFn);
6567}
6568
6569void RWProbeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6570 genericAsmResultNames(*this, setNameFn);
6571}
6572
6573FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
6574 DictionaryAttr attrs,
6575 OpaqueProperties properties,
6576 mlir::RegionRange regions,
6577 std::optional<Location> loc) {
6578 Type inType = operands[0].getType();
6579 auto inRefType = type_dyn_cast<RefType>(inType);
6580 if (!inRefType)
6581 return emitInferRetTypeError(
6582 loc, "ref.resolve operand must be ref type, not ", inType);
6583 return inRefType.getType();
6584}
6585
6586FIRRTLType RefSendOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
6587 OpaqueProperties properties,
6588 mlir::RegionRange regions,
6589 std::optional<Location> loc) {
6590 Type inType = operands[0].getType();
6591 auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
6592 if (!inBaseType)
6593 return emitInferRetTypeError(
6594 loc, "ref.send operand must be base type, not ", inType);
6595 return RefType::get(inBaseType.getPassiveType());
6596}
6597
6598FIRRTLType RefSubOp::inferReturnType(Type type, uint32_t fieldIndex,
6599 std::optional<Location> loc) {
6600 auto refType = type_dyn_cast<RefType>(type);
6601 if (!refType)
6602 return emitInferRetTypeError(loc, "input must be of reference type");
6603 auto inType = refType.getType();
6604
6605 // TODO: Determine ref.sub + rwprobe behavior, test.
6606 // Probably best to demote to non-rw, but that has implications
6607 // for any LowerTypes behavior being relied on.
6608 // Allow for now, as need to LowerTypes things generally.
6609 if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
6610 if (fieldIndex < vectorType.getNumElements())
6611 return RefType::get(
6612 vectorType.getElementType().getConstType(
6613 vectorType.isConst() || vectorType.getElementType().isConst()),
6614 refType.getForceable(), refType.getLayer());
6615 return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
6616 "' in RefType of vector type ", refType);
6617 }
6618 if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
6619 if (fieldIndex >= bundleType.getNumElements()) {
6620 return emitInferRetTypeError(loc,
6621 "subfield element index is greater than "
6622 "the number of fields in the bundle type");
6623 }
6624 return RefType::get(
6625 bundleType.getElement(fieldIndex)
6626 .type.getConstType(
6627 bundleType.isConst() ||
6628 bundleType.getElement(fieldIndex).type.isConst()),
6629 refType.getForceable(), refType.getLayer());
6630 }
6631
6632 return emitInferRetTypeError(
6633 loc, "ref.sub op requires a RefType of vector or bundle base type");
6634}
6635
// Verify that the cast does not discard layer requirements. The layers
// required by the input reference (srcLayers) are checked against those
// required by the result (dstLayers), reporting "cannot discard layer
// requirements of input reference" on violation. NOTE(review): presumably
// every layer required by the input must also be required by the result.
LogicalResult RefCastOp::verify() {
  auto srcLayers = getLayersFor(getInput());
  auto dstLayers = getLayersFor(getResult());
      getOperation(), srcLayers, dstLayers,
      "cannot discard layer requirements of input reference",
      "discarding layer requirements");
}
6644
// Verify that the layers enabled where this op appears (its ambient layers)
// are sufficient for the layers required by the reference being resolved.
LogicalResult RefResolveOp::verify() {
  auto srcLayers = getLayersFor(getRef());
  auto dstLayers = getAmbientLayersAt(getOperation());
      getOperation(), srcLayers, dstLayers,
      "ambient layers are insufficient to resolve reference");
}
6652
6653LogicalResult RWProbeOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
6654 auto targetRef = getTarget();
6655 if (targetRef.getModule() !=
6656 (*this)->getParentOfType<FModuleLike>().getModuleNameAttr())
6657 return emitOpError() << "has non-local target";
6658
6659 auto target = ns.lookup(targetRef);
6660 if (!target)
6661 return emitOpError() << "has target that cannot be resolved: " << targetRef;
6662
6663 auto checkFinalType = [&](auto type, Location loc) -> LogicalResult {
6664 // Determine final type.
6665 mlir::Type fType =
6666 hw::FieldIdImpl::getFinalTypeByFieldID(type, target.getField());
6667 // Check.
6668 auto baseType = type_dyn_cast<FIRRTLBaseType>(fType);
6669 if (!baseType || baseType.getPassiveType() != getType().getType()) {
6670 auto diag = emitOpError("has type mismatch: target resolves to ")
6671 << fType << " instead of expected " << getType().getType();
6672 diag.attachNote(loc) << "target resolves here";
6673 return diag;
6674 }
6675 return success();
6676 };
6677 if (target.isPort()) {
6678 auto mod = cast<FModuleLike>(target.getOp());
6679 return checkFinalType(mod.getPortType(target.getPort()),
6680 mod.getPortLocation(target.getPort()));
6681 }
6682 hw::InnerSymbolOpInterface symOp =
6683 cast<hw::InnerSymbolOpInterface>(target.getOp());
6684 if (!symOp.getTargetResult())
6685 return emitOpError("has target that cannot be probed")
6686 .attachNote(symOp.getLoc())
6687 .append("target resolves here");
6688 auto *ancestor =
6689 symOp.getTargetResult().getParentBlock()->findAncestorOpInBlock(**this);
6690 if (!ancestor || !symOp->isBeforeInBlock(ancestor))
6691 return emitOpError("is not dominated by target")
6692 .attachNote(symOp.getLoc())
6693 .append("target here");
6694 return checkFinalType(symOp.getTargetResult().getType(), symOp.getLoc());
6695}
6696
// Verify that the layers enabled where this op appears (its ambient layers)
// are sufficient for the layers required by the forced reference (getDest()).
LogicalResult RefForceOp::verify() {
  auto ambientLayers = getAmbientLayersAt(getOperation());
  auto destLayers = getLayersFor(getDest());
      getOperation(), destLayers, ambientLayers,
      "has insufficient ambient layers to force its reference");
}
6704
// Verify that the layers enabled where this op appears (its ambient layers)
// are sufficient for the layers required by the forced reference (getDest()).
LogicalResult RefForceInitialOp::verify() {
  auto ambientLayers = getAmbientLayersAt(getOperation());
  auto destLayers = getLayersFor(getDest());
      getOperation(), destLayers, ambientLayers,
      "has insufficient ambient layers to force its reference");
}
6712
// Verify that the layers enabled where this op appears (its ambient layers)
// are sufficient for the layers required by the released reference
// (getDest()).
LogicalResult RefReleaseOp::verify() {
  auto ambientLayers = getAmbientLayersAt(getOperation());
  auto destLayers = getLayersFor(getDest());
      getOperation(), destLayers, ambientLayers,
      "has insufficient ambient layers to release its reference");
}
6720
// Verify that the layers enabled where this op appears (its ambient layers)
// are sufficient for the layers required by the released reference
// (getDest()).
LogicalResult RefReleaseInitialOp::verify() {
  auto ambientLayers = getAmbientLayersAt(getOperation());
  auto destLayers = getLayersFor(getDest());
      getOperation(), destLayers, ambientLayers,
      "has insufficient ambient layers to release its reference");
}
6728
6729LogicalResult XMRRefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6730 auto *target = symbolTable.lookupNearestSymbolFrom(*this, getRefAttr());
6731 if (!target)
6732 return emitOpError("has an invalid symbol reference");
6733
6734 if (!isa<hw::HierPathOp>(target))
6735 return emitOpError("does not target a hierpath op");
6736
6737 // TODO: Verify that the target's type matches the type of this op.
6738 return success();
6739}
6740
6741LogicalResult XMRDerefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6742 auto *target = symbolTable.lookupNearestSymbolFrom(*this, getRefAttr());
6743 if (!target)
6744 return emitOpError("has an invalid symbol reference");
6745
6746 if (!isa<hw::HierPathOp>(target))
6747 return emitOpError("does not target a hierpath op");
6748
6749 // TODO: Verify that the target's type matches the type of this op.
6750 return success();
6751}
6752
6753//===----------------------------------------------------------------------===//
6754// Layer Block Operations
6755//===----------------------------------------------------------------------===//
6756
6757LogicalResult LayerBlockOp::verify() {
6758 auto layerName = getLayerName();
6759 auto *parentOp = (*this)->getParentOp();
6760
6761 // Get parent operation that isn't a when or match.
6762 while (isa<WhenOp, MatchOp>(parentOp))
6763 parentOp = parentOp->getParentOp();
6764
6765 // Verify the correctness of the symbol reference. Only verify that this
6766 // layer block makes sense in its parent module or layer block.
6767 auto nestedReferences = layerName.getNestedReferences();
6768 if (nestedReferences.empty()) {
6769 if (!isa<FModuleOp>(parentOp)) {
6770 auto diag = emitOpError() << "has an un-nested layer symbol, but does "
6771 "not have a 'firrtl.module' op as a parent";
6772 return diag.attachNote(parentOp->getLoc())
6773 << "illegal parent op defined here";
6774 }
6775 } else {
6776 auto parentLayerBlock = dyn_cast<LayerBlockOp>(parentOp);
6777 if (!parentLayerBlock) {
6778 auto diag = emitOpError()
6779 << "has a nested layer symbol, but does not have a '"
6780 << getOperationName() << "' op as a parent'";
6781 return diag.attachNote(parentOp->getLoc())
6782 << "illegal parent op defined here";
6783 }
6784 auto parentLayerBlockName = parentLayerBlock.getLayerName();
6785 if (parentLayerBlockName.getRootReference() !=
6786 layerName.getRootReference() ||
6787 parentLayerBlockName.getNestedReferences() !=
6788 layerName.getNestedReferences().drop_back()) {
6789 auto diag = emitOpError() << "is nested under an illegal layer block";
6790 return diag.attachNote(parentLayerBlock->getLoc())
6791 << "illegal parent layer block defined here";
6792 }
6793 }
6794
6795 // Verify the body of the region.
6796 FieldRefCache fieldRefCache;
6797 auto result = getBody(0)->walk<mlir::WalkOrder::PreOrder>(
6798 [&](Operation *op) -> WalkResult {
6799 // Skip nested layer blocks. Those will be verified separately.
6800 if (isa<LayerBlockOp>(op))
6801 return WalkResult::skip();
6802
6803 // Check all the operands of each op to make sure that only legal things
6804 // are captured.
6805 for (auto operand : op->getOperands()) {
6806 // Any value captured from the current layer block is fine.
6807 if (auto *definingOp = operand.getDefiningOp())
6808 if (getOperation()->isAncestor(definingOp))
6809 continue;
6810
6811 auto type = operand.getType();
6812
6813 // Capture of a non-base type, e.g., reference, is allowed.
6814 if (isa<PropertyType>(type)) {
6815 auto diag = emitOpError() << "captures a property operand";
6816 diag.attachNote(operand.getLoc()) << "operand is defined here";
6817 diag.attachNote(op->getLoc()) << "operand is used here";
6818 return WalkResult::interrupt();
6819 }
6820 }
6821
6822 // Ensure that the layer block does not drive any sinks outside.
6823 if (auto connect = dyn_cast<FConnectLike>(op)) {
6824 // ref.define is allowed to drive probes outside the layerblock.
6825 if (isa<RefDefineOp>(connect))
6826 return WalkResult::advance();
6827
6828 // Verify that connects only drive values declared in the layer block.
6829 // If we see a non-passive connect destination, then verify that the
6830 // source is in the same layer block so that the source is not driven.
6831 auto dest =
6832 fieldRefCache.getFieldRefFromValue(connect.getDest()).getValue();
6833 bool passive = true;
6834 if (auto type =
6835 type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType()))
6836 passive = type.isPassive();
6837 // TODO: Improve this verifier. This is intentionally _not_ verifying
6838 // a non-passive ConnectLike because it is hugely annoying to do
6839 // so---it requires a full understanding of if the connect is driving
6840 // destination-to-source, source-to-destination, or bi-directionally
6841 // which requires deep inspection of the type. Eventually, the FIRRTL
6842 // pass pipeline will remove all flips (e.g., canonicalize connect to
6843 // matchingconnect) and this hole won't exist.
6844 if (!passive)
6845 return WalkResult::advance();
6846
6847 if (isAncestorOfValueOwner(getOperation(), dest))
6848 return WalkResult::advance();
6849
6850 auto diag =
6851 connect.emitOpError()
6852 << "connects to a destination which is defined outside its "
6853 "enclosing layer block";
6854 diag.attachNote(getLoc()) << "enclosing layer block is defined here";
6855 diag.attachNote(dest.getLoc()) << "destination is defined here";
6856 return WalkResult::interrupt();
6857 }
6858
6859 return WalkResult::advance();
6860 });
6861
6862 return failure(result.wasInterrupted());
6863}
6864
6865LogicalResult
6866LayerBlockOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6867 auto layerOp =
6868 symbolTable.lookupNearestSymbolFrom<LayerOp>(*this, getLayerNameAttr());
6869 if (!layerOp) {
6870 return emitOpError("invalid symbol reference");
6871 }
6872
6873 return success();
6874}
6875
6876//===----------------------------------------------------------------------===//
6877// Format String operations
6878//===----------------------------------------------------------------------===//
6879
6880void TimeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
6881 setNameFn(getResult(), "time");
6882}
6883
6884void HierarchicalModuleNameOp::getAsmResultNames(
6885 OpAsmSetValueNameFn setNameFn) {
6886 setNameFn(getResult(), "hierarchicalmodulename");
6887}
6888
6889ParseResult FPrintFOp::parse(::mlir::OpAsmParser &parser,
6890 ::mlir::OperationState &result) {
6891 // Parse required clock and condition operands
6892 OpAsmParser::UnresolvedOperand clock, cond;
6893 if (parser.parseOperand(clock) || parser.parseComma() ||
6894 parser.parseOperand(cond) || parser.parseComma())
6895 return failure();
6896
6897 auto parseFormatString =
6898 [&parser](llvm::SMLoc &loc, StringAttr &result,
6899 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands)
6900 -> ParseResult {
6901 loc = parser.getCurrentLocation();
6902 // NOTE: Don't use parseAttribute. "format_string" : !firrtl.clock is
6903 // considered as a typed attr.
6904 std::string resultStr;
6905 if (parser.parseString(&resultStr))
6906 return failure();
6907 result = parser.getBuilder().getStringAttr(resultStr);
6908
6909 // This is an optional
6910 if (parser.parseOperandList(operands, AsmParser::Delimiter::OptionalParen))
6911 return failure();
6912 return success();
6913 };
6914
6915 // Parse output file and format string substitutions
6916 SmallVector<OpAsmParser::UnresolvedOperand> outputFileSubstitutions,
6917 substitutions;
6918 llvm::SMLoc outputFileLoc, formatStringLoc;
6919
6920 if (parseFormatString(
6921 outputFileLoc,
6922 result.getOrAddProperties<FPrintFOp::Properties>().outputFile,
6923 outputFileSubstitutions) ||
6924 parser.parseComma() ||
6925 parseFormatString(
6926 formatStringLoc,
6927 result.getOrAddProperties<FPrintFOp::Properties>().formatString,
6928 substitutions))
6929 return failure();
6930
6931 if (parseFPrintfAttrs(parser, result.attributes))
6932 return failure();
6933
6934 // Parse types
6935 Type clockType, condType;
6936 SmallVector<Type> restTypes;
6937
6938 if (parser.parseColon() || parser.parseType(clockType) ||
6939 parser.parseComma() || parser.parseType(condType))
6940 return failure();
6941
6942 if (succeeded(parser.parseOptionalComma())) {
6943 if (parser.parseTypeList(restTypes))
6944 return failure();
6945 }
6946
6947 // Set operand segment sizes
6948 result.getOrAddProperties<FPrintFOp::Properties>().operandSegmentSizes = {
6949 1, 1, static_cast<int32_t>(outputFileSubstitutions.size()),
6950 static_cast<int32_t>(substitutions.size())};
6951
6952 // Resolve all operands
6953 if (parser.resolveOperand(clock, clockType, result.operands) ||
6954 parser.resolveOperand(cond, condType, result.operands) ||
6955 parser.resolveOperands(
6956 outputFileSubstitutions,
6957 ArrayRef(restTypes).take_front(outputFileSubstitutions.size()),
6958 outputFileLoc, result.operands) ||
6959 parser.resolveOperands(
6960 substitutions,
6961 ArrayRef(restTypes).drop_front(outputFileSubstitutions.size()),
6962 formatStringLoc, result.operands))
6963 return failure();
6964
6965 return success();
6966}
6967
6968void FPrintFOp::print(OpAsmPrinter &p) {
6969 p << " " << getClock() << ", " << getCond() << ", ";
6970 p.printAttributeWithoutType(getOutputFileAttr());
6971 if (!getOutputFileSubstitutions().empty()) {
6972 p << "(";
6973 p.printOperands(getOutputFileSubstitutions());
6974 p << ")";
6975 }
6976 p << ", ";
6977 p.printAttributeWithoutType(getFormatStringAttr());
6978 if (!getSubstitutions().empty()) {
6979 p << "(";
6980 p.printOperands(getSubstitutions());
6981 p << ")";
6982 }
6983 printFPrintfAttrs(p, *this, (*this)->getAttrDictionary());
6984 p << " : " << getClock().getType() << ", " << getCond().getType();
6985 if (!getOutputFileSubstitutions().empty() || !getSubstitutions().empty()) {
6986 for (auto type : getOperands().drop_front(2).getTypes()) {
6987 p << ", ";
6988 p.printType(type);
6989 }
6990 }
6991}
6992
6993//===----------------------------------------------------------------------===//
6994// FFlushOp
6995//===----------------------------------------------------------------------===//
6996
6997LogicalResult FFlushOp::verify() {
6998 if (!getOutputFileAttr() && !getOutputFileSubstitutions().empty())
6999 return emitOpError("substitutions without output file are not allowed");
7000 return success();
7001}
7002
7003//===----------------------------------------------------------------------===//
7004// BindOp
7005//===----------------------------------------------------------------------===//
7006
7007LogicalResult BindOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
7008 auto ref = getInstanceAttr();
7009 auto target = ns.lookup(ref);
7010 if (!target)
7011 return emitError() << "target " << ref << " cannot be resolved";
7012
7013 if (!target.isOpOnly())
7014 return emitError() << "target " << ref << " is not an operation";
7015
7016 auto instance = dyn_cast<InstanceOp>(target.getOp());
7017 if (!instance)
7018 return emitError() << "target " << ref << " is not an instance";
7019
7020 if (!instance.getDoNotPrint())
7021 return emitError() << "target " << ref << " is not marked doNotPrint";
7022
7023 return success();
7024}
7025
7026//===----------------------------------------------------------------------===//
7027// TblGen Generated Logic.
7028//===----------------------------------------------------------------------===//
7029
7030// Provide the autogenerated implementation guts for the Op classes.
7031#define GET_OP_CLASSES
7032#include "circt/Dialect/FIRRTL/FIRRTL.cpp.inc"
static void printNameKind(OpAsmPrinter &p, Operation *op, firrtl::NameKindEnumAttr attr, ArrayRef< StringRef > extraElides={})
static ParseResult parseNameKind(OpAsmParser &parser, firrtl::NameKindEnumAttr &result)
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
MlirType elementType
Definition CHIRRTL.cpp:29
#define isdigit(x)
Definition FIRLexer.cpp:26
static Attribute fixDomainInfoInsertions(MLIRContext *context, Attribute domainInfoAttr, ArrayRef< unsigned > indexMap)
Return an updated domain info Attribute with domain indices updated based on port insertions.
static LogicalResult verifyProbeType(RefType refType, Location loc, CircuitOp circuitOp, SymbolTableCollection &symbolTable, Twine start)
static ArrayAttr fixDomainInfoDeletions(MLIRContext *context, ArrayAttr domainInfoAttr, const llvm::BitVector &portIndices, bool supportsEmptyAttr)
static SmallVector< PortInfo > getPortImpl(FModuleLike module)
static void buildClass(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports)
static FlatSymbolRefAttr getDomainTypeName(Value value)
static void printStopAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
static LayerSet getLayersFor(Value value)
Get the effective layer requirements for the given value.
static SmallVector< hw::PortInfo > getPortListImpl(FModuleLike module)
ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result)
static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth, int32_t &rhsWidth, bool &isConstResult, std::optional< Location > loc)
If LHS and RHS are both UInt or SInt types, then return true and fill in their widths if known.
static LogicalResult verifySubfieldLike(OpTy op)
static void printFPrintfAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
static LogicalResult checkSingleConnect(FConnectLike connect)
Returns success if the given connect is the sole driver of its dest operand.
static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip=false, bool outerTypeIsConst=false)
Checks if the type has any 'const' leaf elements.
static ParseResult parsePrintfAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
static ParseResult parseParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Shim to use with assemblyFormat, custom<ParameterList>.
static RetTy emitInferRetTypeError(std::optional< Location > loc, const Twine &message, Args &&...args)
Emit an error if optional location is non-null, return null of return type.
Definition FIRRTLOps.cpp:97
static SmallVector< T > removeElementsAtIndices(ArrayRef< T > input, const llvm::BitVector &indicesToDrop)
Remove elements from the input array corresponding to set bits in indicesToDrop, returning the elements not mentioned.
Definition FIRRTLOps.cpp:58
static LogicalResult checkLayerCompatibility(Operation *op, const LayerSet &src, const LayerSet &dst, const Twine &errorMsg, const Twine &noteMsg=Twine("missing layer requirements"))
static ParseResult parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers, bool supportsSymbols, bool supportsDomains, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Direction > &portDirections, SmallVectorImpl< Attribute > &portNames, SmallVectorImpl< Attribute > &portTypes, SmallVectorImpl< Attribute > &portAnnotations, SmallVectorImpl< Attribute > &portSyms, SmallVectorImpl< Attribute > &portLocs, SmallVectorImpl< Attribute > &domains)
Parse a list of module ports.
static LogicalResult checkConnectConditionality(FConnectLike connect)
Checks that connections to 'const' destinations are not dependent on non-'const' conditions in when b...
static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices)
Erases the ports that have their corresponding bit set in portIndices.
static ParseResult parseClassInterface(OpAsmParser &parser, Type &result)
static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op, DictionaryAttr attr, ArrayRef< StringRef > extraElides={})
static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
static ParseResult parseNameKind(OpAsmParser &parser, firrtl::NameKindEnumAttr &result)
A forward declaration for NameKind attribute parser.
static size_t getAddressWidth(size_t depth)
static void forceableAsmResultNames(Forceable op, StringRef name, OpAsmSetValueNameFn setNameFn)
Helper for naming forceable declarations (and their optional ref result).
static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op)
static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer)
static FlatSymbolRefAttr getDomainTypeNameOfResult(T op, size_t i)
static bool checkAggConstant(Operation *op, Attribute attr, FIRRTLBaseType type)
static void printClassLike(OpAsmPrinter &p, ClassLike op)
static hw::ModulePort::Direction dirFtoH(Direction dir)
static ParseResult parseOptionalParameters(OpAsmParser &parser, SmallVectorImpl< Attribute > &parameters)
Parse a parameter list if present.
static MemOp::PortKind getMemPortKindFromType(FIRRTLType type)
Return the kind of port this is given the port type from a 'mem' decl.
static void genericAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn)
static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type)
static void printPrintfAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
const char * toString(Flow flow)
static void replaceUsesRespectingInsertedPorts(Operation *op1, Operation *op2, ArrayRef< std::pair< unsigned, PortInfo > > insertions)
static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst, SmallVectorImpl< SymbolRefAttr > &missing)
Check that the source layers are all present in the destination layers.
static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer, mlir::SymbolRefAttr dstLayer)
Check that the source layer is compatible with the destination layer.
static LayerSet getAmbientLayersFor(Value value)
Get the ambient layer requirements at the definition site of the value.
void buildModuleLike(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports)
static LayerSet getAmbientLayersAt(Operation *op)
Get the ambient layers active at the given op.
static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op, DictionaryAttr attrs)
static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser, NamedAttrList &resultAttrs)
static FIRRTLBaseType inferMuxReturnType(FIRRTLBaseType high, FIRRTLBaseType low, bool isConstCondition, std::optional< Location > loc)
Infer the result type for a multiplexer given its two operand types, which may be aggregates.
static ParseResult parseCircuitOpAttrs(OpAsmParser &parser, NamedAttrList &resultAttrs)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static void printElideAnnotations(OpAsmPrinter &p, Operation *op, DictionaryAttr attr, ArrayRef< StringRef > extraElides={})
static ParseResult parseElidePortAnnotations(OpAsmParser &parser, NamedAttrList &resultAttrs)
Parse an optional attribute dictionary, adding empty 'annotations' and 'portAnnotations' attributes i...
static void insertPorts(FModuleLike op, ArrayRef< std::pair< unsigned, PortInfo > > ports)
Inserts the given ports.
static ParseResult parseFPrintfAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs)
static void replaceUsesRespectingErasedPorts(Operation *op1, Operation *op2, const llvm::BitVector &erasures)
static LogicalResult checkConnectFlow(Operation *connect)
Check if the source and sink are of appropriate flow.
static void printParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
static ParseResult parseElideAnnotations(OpAsmParser &parser, NamedAttrList &resultAttrs)
Parse an optional attribute dictionary, adding an empty 'annotations' attribute if not specified.
ParseResult parseClassLike(OpAsmParser &parser, OperationState &result, bool hasSSAIdentifiers)
static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
static LogicalResult verifyPortSymbolUses(FModuleLike module, SymbolTableCollection &symbolTable)
static void printVerifAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
Always elide "ruw" and elide "annotations" if it exists or if it is empty.
static bool isTypeAllowedForDPI(Operation *op, Type type)
static ParseResult parseElideEmptyName(OpAsmParser &p, NamedAttrList &resultAttrs)
static bool printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef< bool > portDirections, ArrayRef< Attribute > portNames, ArrayRef< Attribute > portTypes, ArrayRef< Attribute > portAnnotations, ArrayRef< Attribute > portSyms, ArrayRef< Attribute > portLocs, ArrayRef< Attribute > domainInfo)
Print a list of module ports in the following form: in x: !firrtl.uint<1> [{class = "DontTouch}],...
static void printElideEmptyName(OpAsmPrinter &p, Operation *op, DictionaryAttr attr, ArrayRef< StringRef > extraElides={})
static ParseResult parseFModuleLikeOp(OpAsmParser &parser, OperationState &result, bool hasSSAIdentifiers)
static bool isAncestor(Block *block, Block *other)
Definition LayerSink.cpp:57
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
Definition FieldRef.h:28
Value getValue() const
Get the Value which created this location.
Definition FieldRef.h:37
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool hasDontTouch() const
firrtl.transforms.DontTouchAnnotation
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
ExprVisitor is a visitor for FIRRTL expression nodes.
ResultType dispatchExprVisitor(Operation *op, ExtraArgs... args)
FIRRTLBaseType getConstType(bool isConst) const
Return a 'const' or non-'const' version of this type.
FIRRTLBaseType getMaskType()
Return this type with all ground types replaced with UInt<1>.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
FIRRTLBaseType getAllConstDroppedType()
Return this type with a 'const' modifiers dropped.
bool isPassive() const
Return true if this is a "passive" type - one that contains no "flip" types recursively within itself...
bool isConst() const
Returns true if this is a 'const' type that can only hold compile-time constant values.
bool isConst() const
Returns true if this is a 'const' type that can only hold compile-time constant values.
Caching version of getFieldRefFromValue.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Caching version of getFieldRefFromValue.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
int main(int argc, char **argv)
Definition driver.cpp:48
connect(destination, source)
Definition support.py:39
ClassType getInstanceTypeForClassLike(ClassLike classOp)
LogicalResult verifyTypeAgainstClassLike(ClassLike classOp, ClassType type, function_ref< InFlightDiagnostic()> emitError)
Assuming that the classOp is the source of truth, verify that the type accurately matches the signatu...
RefType getForceableResultType(bool forceable, Type type)
Return null or forceable reference result type.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
static bool unGet(Direction dir)
Convert from Direction to bool. The opposite of get.
Definition FIRRTLEnums.h:39
SmallVector< Direction > unpackAttribute(mlir::DenseBoolArrayAttr directions)
Turn a packed representation of port attributes into a vector that can be worked with.
static Direction get(bool isOutput)
Return an output direction if isOutput is true, otherwise return an input direction.
Definition FIRRTLEnums.h:36
static StringRef toString(Direction direction)
Definition FIRRTLEnums.h:44
FIRRTLType inferElementwiseResult(FIRRTLType lhs, FIRRTLType rhs, std::optional< Location > loc)
FIRRTLType inferBitwiseResult(FIRRTLType lhs, FIRRTLType rhs, std::optional< Location > loc)
FIRRTLType inferAddSubResult(FIRRTLType lhs, FIRRTLType rhs, std::optional< Location > loc)
FIRRTLType inferComparisonResult(FIRRTLType lhs, FIRRTLType rhs, std::optional< Location > loc)
FIRRTLType inferReductionResult(FIRRTLType arg, std::optional< Location > loc)
LogicalResult verifySameOperandsIntTypeKind(Operation *op)
LogicalResult verifyReferencedModule(Operation *instanceOp, SymbolTableCollection &symbolTable, mlir::FlatSymbolRefAttr moduleName)
Verify that the instance refers to a valid FIRRTL module.
BaseTy type_cast(Type type)
Flow swapFlow(Flow flow)
Get a flow's reverse.
Direction
This represents the direction of a single port.
Definition FIRRTLEnums.h:27
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
void walkGroundTypes(FIRRTLType firrtlType, llvm::function_ref< void(uint64_t, FIRRTLBaseType, bool)> fn)
Walk leaf ground types in the firrtlType and apply the function fn.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
constexpr bool isValidDst(Flow flow)
Definition FIRRTLOps.h:69
Flow foldFlow(Value val, Flow accumulatedFlow=Flow::Source)
Compute the flow for a Value, val, as determined by the FIRRTL specification.
constexpr const char * dutAnnoClass
bool areTypesEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false, bool requireSameWidths=false)
Returns whether the two types are equivalent.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
mlir::Type getPassiveType(mlir::Type anyBaseFIRRTLType)
bool isTypeLarger(FIRRTLBaseType dstType, FIRRTLBaseType srcType)
Returns true if the destination is at least as wide as a source.
bool containsConst(Type type)
Returns true if the type is or contains a 'const' type whose value is guaranteed to be unchanging at ...
bool isDuplexValue(Value val)
Returns true if the value results from an expression with duplex flow.
SmallSet< SymbolRefAttr, 4, LayerSetCompare > LayerSet
Definition LayerSet.h:42
constexpr bool isValidSrc(Flow flow)
Definition FIRRTLOps.h:65
Value getModuleScopedDriver(Value val, bool lookThroughWires, bool lookThroughNodes, bool lookThroughCasts)
Return the value that drives another FIRRTL value within module scope.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
BaseTy type_dyn_cast(Type type)
bool isConst(Type type)
Returns true if this is a 'const' type whose value is guaranteed to be unchanging at circuit executio...
bool areTypesConstCastable(FIRRTLType destType, FIRRTLType srcType, bool srcOuterTypeIsConst=false)
Returns whether the srcType can be const-casted to the destType.
bool isExpression(Operation *op)
Return true if the specified operation is a firrtl expression.
DeclKind getDeclarationKind(Value val)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void elideImplicitSSAName(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs, SmallVectorImpl< StringRef > &elides)
Check if the name attribute in attrs matches the SSA name of the operation's first result.
bool isAncestorOfValueOwner(Operation *op, Value value)
Return true if a Value is created "underneath" an operation.
Definition Utils.h:27
bool inferImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs)
Ensure that attrs contains a name attribute by inferring its value from the SSA name of the operation...
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
StringAttr getFirMemoryName() const
Compares two SymbolRefAttr lexicographically, returning true if LHS should be ordered before RHS.
Definition LayerSet.h:19
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds the name, type, direction of a module's ports.