CIRCT  20.0.0git
FIRRTLOps.cpp
Go to the documentation of this file.
1 //===- FIRRTLOps.cpp - Implement the FIRRTL operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "circt/Support/Utils.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Diagnostics.h"
28 #include "mlir/IR/DialectImplementation.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/SymbolTable.h"
31 #include "mlir/Interfaces/FunctionImplementation.h"
32 #include "llvm/ADT/BitVector.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/SmallSet.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/FormatVariadic.h"
40 
41 using llvm::SmallDenseSet;
42 using mlir::RegionRange;
43 using namespace circt;
44 using namespace firrtl;
45 using namespace chirrtl;
46 
47 //===----------------------------------------------------------------------===//
48 // Utilities
49 //===----------------------------------------------------------------------===//
50 
51 /// Remove elements from the input array corresponding to set bits in
52 /// `indicesToDrop`, returning the elements not mentioned.
53 template <typename T>
54 static SmallVector<T>
55 removeElementsAtIndices(ArrayRef<T> input,
56  const llvm::BitVector &indicesToDrop) {
57 #ifndef NDEBUG
58  if (!input.empty()) {
59  int lastIndex = indicesToDrop.find_last();
60  if (lastIndex >= 0)
61  assert((size_t)lastIndex < input.size() && "index out of range");
62  }
63 #endif
64 
65  // If the input is empty (which is an optimization we do for certain array
66  // attributes), simply return an empty vector.
67  if (input.empty())
68  return {};
69 
70  // Copy over the live chunks.
71  size_t lastCopied = 0;
72  SmallVector<T> result;
73  result.reserve(input.size() - indicesToDrop.count());
74 
75  for (unsigned indexToDrop : indicesToDrop.set_bits()) {
76  // If we skipped over some valid elements, copy them over.
77  if (indexToDrop > lastCopied) {
78  result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
79  lastCopied = indexToDrop;
80  }
81  // Ignore this value so we don't copy it in the next iteration.
82  ++lastCopied;
83  }
84 
85  // If there are live elements at the end, copy them over.
86  if (lastCopied < input.size())
87  result.append(input.begin() + lastCopied, input.end());
88 
89  return result;
90 }
91 
92 /// Emit an error if optional location is non-null, return null of return type.
93 template <typename RetTy = FIRRTLType, typename... Args>
94 static RetTy emitInferRetTypeError(std::optional<Location> loc,
95  const Twine &message, Args &&...args) {
96  if (loc)
97  (mlir::emitError(*loc, message) << ... << std::forward<Args>(args));
98  return {};
99 }
100 
101 bool firrtl::isDuplexValue(Value val) {
102  // Block arguments are not duplex values.
103  while (Operation *op = val.getDefiningOp()) {
104  auto isDuplex =
105  TypeSwitch<Operation *, std::optional<bool>>(op)
106  .Case<SubfieldOp, SubindexOp, SubaccessOp>([&val](auto op) {
107  val = op.getInput();
108  return std::nullopt;
109  })
110  .Case<RegOp, RegResetOp, WireOp>([](auto) { return true; })
111  .Default([](auto) { return false; });
112  if (isDuplex)
113  return *isDuplex;
114  }
115  return false;
116 }
117 
118 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
119 MemOp::computeDataFlow() {
120  // If read result has non-zero latency, then no combinational dependency
121  // exists.
122  if (getReadLatency() > 0)
123  return {};
124  SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
125  // Add a dependency from the enable and address fields to the data field.
126  for (auto memPort : getResults())
127  if (auto type = type_dyn_cast<BundleType>(memPort.getType())) {
128  auto enableFieldId = type.getFieldID((unsigned)ReadPortSubfield::en);
129  auto addressFieldId = type.getFieldID((unsigned)ReadPortSubfield::addr);
130  auto dataFieldId = type.getFieldID((unsigned)ReadPortSubfield::data);
131  deps.emplace_back(
132  FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
133  FieldRef(memPort, static_cast<unsigned>(enableFieldId)));
134  deps.emplace_back(
135  FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
136  FieldRef(memPort, static_cast<unsigned>(addressFieldId)));
137  }
138  return deps;
139 }
140 
141 /// Return the kind of port this is given the port type from a 'mem' decl.
142 static MemOp::PortKind getMemPortKindFromType(FIRRTLType type) {
143  constexpr unsigned int addr = 1 << 0;
144  constexpr unsigned int en = 1 << 1;
145  constexpr unsigned int clk = 1 << 2;
146  constexpr unsigned int data = 1 << 3;
147  constexpr unsigned int mask = 1 << 4;
148  constexpr unsigned int rdata = 1 << 5;
149  constexpr unsigned int wdata = 1 << 6;
150  constexpr unsigned int wmask = 1 << 7;
151  constexpr unsigned int wmode = 1 << 8;
152  constexpr unsigned int def = 1 << 9;
153  // Get the kind of port based on the fields of the Bundle.
154  auto portType = type_dyn_cast<BundleType>(type);
155  if (!portType)
156  return MemOp::PortKind::Debug;
157  unsigned fields = 0;
158  // Get the kind of port based on the fields of the Bundle.
159  for (auto elem : portType.getElements()) {
160  fields |= llvm::StringSwitch<unsigned>(elem.name.getValue())
161  .Case("addr", addr)
162  .Case("en", en)
163  .Case("clk", clk)
164  .Case("data", data)
165  .Case("mask", mask)
166  .Case("rdata", rdata)
167  .Case("wdata", wdata)
168  .Case("wmask", wmask)
169  .Case("wmode", wmode)
170  .Default(def);
171  }
172  if (fields == (addr | en | clk | data))
173  return MemOp::PortKind::Read;
174  if (fields == (addr | en | clk | data | mask))
175  return MemOp::PortKind::Write;
176  if (fields == (addr | en | clk | wdata | wmask | rdata | wmode))
177  return MemOp::PortKind::ReadWrite;
178  return MemOp::PortKind::Debug;
179 }
180 
182  switch (flow) {
183  case Flow::None:
184  return Flow::None;
185  case Flow::Source:
186  return Flow::Sink;
187  case Flow::Sink:
188  return Flow::Source;
189  case Flow::Duplex:
190  return Flow::Duplex;
191  }
192  // Unreachable but silences warning
193  llvm_unreachable("Unsupported Flow type.");
194 }
195 
196 const char *toString(Flow flow) {
197  switch (flow) {
198  case Flow::None:
199  return "no flow";
200  case Flow::Source:
201  return "source flow";
202  case Flow::Sink:
203  return "sink flow";
204  case Flow::Duplex:
205  return "duplex flow";
206  }
207  // Unreachable but silences warning
208  llvm_unreachable("Unsupported Flow type.");
209 }
210 
/// Compute the flow of `val` by walking from `val` back towards the
/// declaration it refers to, adjusting `accumulatedFlow` along the way.
///
/// - A block argument of an `FModuleLike` op swaps the flow if the matching
///   port is an output; any other block argument keeps `accumulatedFlow`.
/// - `SubfieldOp`/`OpenSubfieldOp` recurse into their input, swapping the flow
///   when the accessed field is flipped.
/// - `SubindexOp`, `SubaccessOp`, `OpenSubindexOp` and `RefSubOp` recurse into
///   their input with the flow unchanged.
/// - `RegOp`, `RegResetOp`, `WireOp` and `MemoryPortOp` are always Duplex.
/// - Instance results keep the flow for output ports and swap it otherwise.
/// - `MemOp` results of `RefType` are Source; other results swap the flow.
/// - `ObjectSubfieldOp` is handled specially (see the comments below).
/// - Any other defining op keeps `accumulatedFlow`.
Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {

  if (auto blockArg = dyn_cast<BlockArgument>(val)) {
    auto *op = val.getParentBlock()->getParentOp();
    if (auto moduleLike = dyn_cast<FModuleLike>(op)) {
      auto direction = moduleLike.getPortDirection(blockArg.getArgNumber());
      if (direction == Direction::Out)
        return swapFlow(accumulatedFlow);
    }
    return accumulatedFlow;
  }

  // Not a block argument, so `val` is an op result with a defining op.
  Operation *op = val.getDefiningOp();

  return TypeSwitch<Operation *, Flow>(op)
      .Case<SubfieldOp, OpenSubfieldOp>([&](auto op) {
        return foldFlow(op.getInput(), op.isFieldFlipped()
                                           ? swapFlow(accumulatedFlow)
                                           : accumulatedFlow);
      })
      .Case<SubindexOp, SubaccessOp, OpenSubindexOp, RefSubOp>(
          [&](auto op) { return foldFlow(op.getInput(), accumulatedFlow); })
      // Registers, Wires, and behavioral memory ports are always Duplex.
      .Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
          [](auto) { return Flow::Duplex; })
      .Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
        // Instance results correspond one-to-one with the instance's ports.
        auto resultNo = cast<OpResult>(val).getResultNumber();
        if (inst.getPortDirection(resultNo) == Direction::Out)
          return accumulatedFlow;
        return swapFlow(accumulatedFlow);
      })
      .Case<MemOp>([&](auto op) {
        // only debug ports with RefType have source flow.
        if (type_isa<RefType>(val.getType()))
          return Flow::Source;
        return swapFlow(accumulatedFlow);
      })
      .Case<ObjectSubfieldOp>([&](ObjectSubfieldOp op) {
        auto input = op.getInput();
        auto *inputOp = input.getDefiningOp();

        // We are directly accessing a port on a local declaration.
        if (auto objectOp = dyn_cast_or_null<ObjectOp>(inputOp)) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::Sink;
          return Flow::Source;
        }

        // We are accessing a remote object. Input ports on remote objects are
        // inaccessible, and thus have Flow::None. Walk backwards through the
        // chain of subindexes, to detect if we have indexed through an input
        // port. At the end, either we did index through an input port, or the
        // entire path was through output ports with source flow.
        // Note: the lambda parameter `op` is reassigned on each step so that
        // it always names the ObjectSubfieldOp currently being inspected.
        while (true) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::None;

          op = dyn_cast_or_null<ObjectSubfieldOp>(inputOp);
          if (op) {
            input = op.getInput();
            inputOp = input.getDefiningOp();
            continue;
          }

          return accumulatedFlow;
        };
      })
      // Anything else acts like a universal source.
      .Default([&](auto) { return accumulatedFlow; });
}
285 
286 // TODO: This is doing the same walk as foldFlow. These two functions can be
287 // combined and return a (flow, kind) product.
289  Operation *op = val.getDefiningOp();
290  if (!op)
291  return DeclKind::Port;
292 
293  return TypeSwitch<Operation *, DeclKind>(op)
294  .Case<InstanceOp>([](auto) { return DeclKind::Instance; })
295  .Case<SubfieldOp, SubindexOp, SubaccessOp, OpenSubfieldOp, OpenSubindexOp,
296  RefSubOp>([](auto op) { return getDeclarationKind(op.getInput()); })
297  .Default([](auto) { return DeclKind::Other; });
298 }
299 
300 size_t firrtl::getNumPorts(Operation *op) {
301  if (auto module = dyn_cast<FModuleLike>(op))
302  return module.getNumPorts();
303  return op->getNumResults();
304 }
305 
306 /// Check whether an operation has a `DontTouch` annotation, or a symbol that
307 /// should prevent certain types of canonicalizations.
308 bool firrtl::hasDontTouch(Operation *op) {
309  return op->getAttr(hw::InnerSymbolTable::getInnerSymbolAttrName()) ||
311 }
312 
313 /// Check whether a block argument ("port") or the operation defining a value
314 /// has a `DontTouch` annotation, or a symbol that should prevent certain types
315 /// of canonicalizations.
316 bool firrtl::hasDontTouch(Value value) {
317  if (auto *op = value.getDefiningOp())
318  return hasDontTouch(op);
319  auto arg = dyn_cast<BlockArgument>(value);
320  auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
321  return (module.getPortSymbolAttr(arg.getArgNumber())) ||
322  AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
323 }
324 
325 /// Get a special name to use when printing the entry block arguments of the
326 /// region contained by an operation in this dialect.
327 void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region,
328  OpAsmSetValueNameFn setNameFn) {
329  if (region.empty())
330  return;
331  auto *parentOp = op;
332  auto *block = &region.front();
333  // Check to see if the operation containing the arguments has 'firrtl.name'
334  // attributes for them. If so, use that as the name.
335  auto argAttr = parentOp->getAttrOfType<ArrayAttr>("portNames");
336  // Do not crash on invalid IR.
337  if (!argAttr || argAttr.size() != block->getNumArguments())
338  return;
339 
340  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
341  auto str = cast<StringAttr>(argAttr[i]).getValue();
342  if (!str.empty())
343  setNameFn(block->getArgument(i), str);
344  }
345 }
346 
/// Forward declaration of the parser for the `NameKind` enum attribute.
/// The declaration is `static`, so its definition must appear elsewhere in
/// this file.
static ParseResult parseNameKind(OpAsmParser &parser,
                                 firrtl::NameKindEnumAttr &result);
350 
351 //===----------------------------------------------------------------------===//
352 // Layer Verification Utilities
353 //===----------------------------------------------------------------------===//
354 
355 namespace {
356 struct CompareSymbolRefAttr {
357  // True if lhs is lexicographically less than rhs.
358  bool operator()(SymbolRefAttr lhs, SymbolRefAttr rhs) const {
359  auto cmp = lhs.getRootReference().compare(rhs.getRootReference());
360  if (cmp == -1)
361  return true;
362  if (cmp == 1)
363  return false;
364  auto lhsNested = lhs.getNestedReferences();
365  auto rhsNested = rhs.getNestedReferences();
366  auto lhsNestedSize = lhsNested.size();
367  auto rhsNestedSize = rhsNested.size();
368  auto e = std::min(lhsNestedSize, rhsNestedSize);
369  for (unsigned i = 0; i < e; ++i) {
370  auto cmp = lhsNested[i].getAttr().compare(rhsNested[i].getAttr());
371  if (cmp == -1)
372  return true;
373  if (cmp == 1)
374  return false;
375  }
376  return lhsNestedSize < rhsNestedSize;
377  }
378 };
379 } // namespace
380 
382 
383 /// Get the ambient layers active at the given op.
384 static LayerSet getAmbientLayersAt(Operation *op) {
385  // Crawl through the parent ops, accumulating all ambient layers at the given
386  // operation.
387  LayerSet result;
388  for (; op != nullptr; op = op->getParentOp()) {
389  if (auto module = dyn_cast<FModuleLike>(op)) {
390  auto layers = module.getLayersAttr().getAsRange<SymbolRefAttr>();
391  result.insert(layers.begin(), layers.end());
392  break;
393  }
394  if (auto layerblock = dyn_cast<LayerBlockOp>(op)) {
395  result.insert(layerblock.getLayerName());
396  continue;
397  }
398  }
399  return result;
400 }
401 
402 /// Get the ambient layer requirements at the definition site of the value.
403 static LayerSet getAmbientLayersFor(Value value) {
404  return getAmbientLayersAt(getFieldRefFromValue(value).getDefiningOp());
405 }
406 
407 /// Get the effective layer requirements for the given value.
408 /// The effective layers for a value is the union of
409 /// - the ambient layers for the cannonical storage location.
410 /// - any explicit layer annotations in the value's type.
411 static LayerSet getLayersFor(Value value) {
412  auto result = getAmbientLayersFor(value);
413  if (auto type = dyn_cast<RefType>(value.getType()))
414  if (auto layer = type.getLayer())
415  result.insert(type.getLayer());
416  return result;
417 }
418 
419 /// Check that the source layer is compatible with the destination layer.
420 /// Either the source and destination are identical, or the source-layer
421 /// is a parent of the destination. For example `A` is compatible with `A.B.C`,
422 /// because any definition valid in `A` is also valid in `A.B.C`.
423 static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer,
424  mlir::SymbolRefAttr dstLayer) {
425  // A non-colored probe may be cast to any colored probe.
426  if (!srcLayer)
427  return true;
428 
429  // A colored probe cannot be cast to an uncolored probe.
430  if (!dstLayer)
431  return false;
432 
433  // Return true if the srcLayer is a prefix of the dstLayer.
434  if (srcLayer.getRootReference() != dstLayer.getRootReference())
435  return false;
436 
437  auto srcNames = srcLayer.getNestedReferences();
438  auto dstNames = dstLayer.getNestedReferences();
439  if (dstNames.size() < srcNames.size())
440  return false;
441 
442  return llvm::all_of(llvm::zip_first(srcNames, dstNames),
443  [](auto x) { return std::get<0>(x) == std::get<1>(x); });
444 }
445 
446 /// Check that the source layer is present in the destination layers.
447 static bool isLayerCompatibleWith(SymbolRefAttr srcLayer,
448  const LayerSet &dstLayers) {
449  // fast path: the required layer is directly listed in the provided layers.
450  if (dstLayers.contains(srcLayer))
451  return true;
452 
453  // Slow path: the required layer is not directly listed in the provided
454  // layers, but the layer may still be provided by a nested layer.
455  return any_of(dstLayers, [=](SymbolRefAttr dstLayer) {
456  return isLayerCompatibleWith(srcLayer, dstLayer);
457  });
458 }
459 
460 /// Check that the source layers are all present in the destination layers.
461 /// True if all source layers are present in the destination.
462 /// Outputs the set of source layers that are missing in the destination.
463 static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst,
464  SmallVectorImpl<SymbolRefAttr> &missing) {
465  for (auto srcLayer : src)
466  if (!isLayerCompatibleWith(srcLayer, dst))
467  missing.push_back(srcLayer);
468 
469  llvm::sort(missing, CompareSymbolRefAttr());
470  return missing.empty();
471 }
472 
473 //===----------------------------------------------------------------------===//
474 // CircuitOp
475 //===----------------------------------------------------------------------===//
476 
477 void CircuitOp::build(OpBuilder &builder, OperationState &result,
478  StringAttr name, ArrayAttr annotations) {
479  // Add an attribute for the name.
480  result.addAttribute(builder.getStringAttr("name"), name);
481 
482  if (!annotations)
483  annotations = builder.getArrayAttr({});
484  result.addAttribute("annotations", annotations);
485 
486  // Create a region and a block for the body.
487  Region *bodyRegion = result.addRegion();
488  Block *body = new Block();
489  bodyRegion->push_back(body);
490 }
491 
492 static ParseResult parseCircuitOpAttrs(OpAsmParser &parser,
493  NamedAttrList &resultAttrs) {
494  auto result = parser.parseOptionalAttrDictWithKeyword(resultAttrs);
495  if (!resultAttrs.get("annotations"))
496  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
497 
498  return result;
499 }
500 
501 static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op,
502  DictionaryAttr attr) {
503  // "name" is always elided.
504  SmallVector<StringRef> elidedAttrs = {"name"};
505  // Elide "annotations" if it doesn't exist or if it is empty
506  auto annotationsAttr = op->getAttrOfType<ArrayAttr>("annotations");
507  if (annotationsAttr.empty())
508  elidedAttrs.push_back("annotations");
509 
510  p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
511 }
512 
/// Verify the circuit body:
/// - the circuit name is non-empty;
/// - a symbol with the circuit's name exists in the body, is `FModuleLike`,
///   and has public visibility;
/// - no external module's `defname` equals the symbol name of an `FModuleOp`;
/// - external modules sharing a `defname` have the same number of ports, and
///   their ports match positionally by name and type. Widths are ignored when
///   either module has parameters.
LogicalResult CircuitOp::verifyRegions() {
  StringRef main = getName();

  // Check that the circuit has a non-empty name.
  if (main.empty()) {
    emitOpError("must have a non-empty name");
    return failure();
  }

  mlir::SymbolTable symtbl(getOperation());

  // The main module is the symbol whose name equals the circuit's name.
  auto *mainModule = symtbl.lookup(main);
  if (!mainModule)
    return emitOpError().append(
        "does not contain module with same name as circuit");
  if (!isa<FModuleLike>(mainModule))
    return mainModule->emitError(
        "entity with name of circuit must be a module");
  if (symtbl.getSymbolVisibility(mainModule) !=
      mlir::SymbolTable::Visibility::Public)
    return mainModule->emitError("main module must be public");

  // Store a mapping of defname to either the first external module
  // that defines it or, preferentially, the first external module
  // that defines it and has no parameters.
  llvm::DenseMap<Attribute, FExtModuleOp> defnameMap;

  // Check one external module against the module symbol table and against
  // the representative extmodule previously recorded for the same defname.
  auto verifyExtModule = [&](FExtModuleOp extModule) -> LogicalResult {
    if (!extModule)
      return success();

    // Extmodules without a defname have nothing to compare.
    auto defname = extModule.getDefnameAttr();
    if (!defname)
      return success();

    // Check that this extmodule's defname does not conflict with
    // the symbol name of any module.
    if (auto collidingModule = symtbl.lookup<FModuleOp>(defname.getValue()))
      return extModule.emitOpError()
          .append("attribute 'defname' with value ", defname,
                  " conflicts with the name of another module in the circuit")
          .attachNote(collidingModule.getLoc())
          .append("previous module declared here");

    // Find an optional extmodule with a defname collision. Update
    // the defnameMap if this is the first extmodule with that
    // defname or if the current extmodule takes no parameters and
    // the collision does. The latter condition improves later
    // extmodule verification as checking against a parameterless
    // module is stricter.
    // `value` is a reference into the map, so assigning to it updates the
    // recorded representative.
    FExtModuleOp collidingExtModule;
    if (auto &value = defnameMap[defname]) {
      collidingExtModule = value;
      if (!value.getParameters().empty() && extModule.getParameters().empty())
        value = extModule;
    } else {
      value = extModule;
      // Go to the next extmodule if no extmodule with the same
      // defname was found.
      return success();
    }

    // Check that the number of ports is exactly the same.
    SmallVector<PortInfo> ports = extModule.getPorts();
    SmallVector<PortInfo> collidingPorts = collidingExtModule.getPorts();

    if (ports.size() != collidingPorts.size())
      return extModule.emitOpError()
          .append("with 'defname' attribute ", defname, " has ", ports.size(),
                  " ports which is different from a previously defined "
                  "extmodule with the same 'defname' which has ",
                  collidingPorts.size(), " ports")
          .attachNote(collidingExtModule.getLoc())
          .append("previous extmodule definition occurred here");

    // Check that ports match for name and type. Since parameters
    // *might* affect widths, ignore widths if either module has
    // parameters. Note that this allows for misdetections, but
    // has zero false positives.
    for (auto p : llvm::zip(ports, collidingPorts)) {
      StringAttr aName = std::get<0>(p).name, bName = std::get<1>(p).name;
      Type aType = std::get<0>(p).type, bType = std::get<1>(p).type;

      if (aName != bName)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which does not match the name of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have name ",
                    bName)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");

      if (!extModule.getParameters().empty() ||
          !collidingExtModule.getParameters().empty()) {
        // Compare base types as widthless, others must match.
        if (auto base = type_dyn_cast<FIRRTLBaseType>(aType))
          aType = base.getWidthlessType();
        if (auto base = type_dyn_cast<FIRRTLBaseType>(bType))
          bType = base.getWidthlessType();
      }
      if (aType != bType)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which has a different type ", aType,
                    " which does not match the type of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have type ",
                    bType)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");
    }
    return success();
  };

  // Visit the body in order; the first extmodule failure aborts
  // verification.
  for (auto &op : *getBodyBlock()) {
    // Verify external modules.
    if (auto extModule = dyn_cast<FExtModuleOp>(op)) {
      if (verifyExtModule(extModule).failed())
        return failure();
    }
  }

  return success();
}
640 
641 Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
642 
643 //===----------------------------------------------------------------------===//
644 // FExtModuleOp and FModuleOp
645 //===----------------------------------------------------------------------===//
646 
647 static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
648  SmallVector<PortInfo> results;
649  for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
650  results.push_back({module.getPortNameAttr(i), module.getPortType(i),
651  module.getPortDirection(i), module.getPortSymbolAttr(i),
652  module.getPortLocation(i),
653  AnnotationSet::forPort(module, i)});
654  }
655  return results;
656 }
657 
658 SmallVector<PortInfo> FModuleOp::getPorts() { return ::getPortImpl(*this); }
659 
660 SmallVector<PortInfo> FExtModuleOp::getPorts() { return ::getPortImpl(*this); }
661 
662 SmallVector<PortInfo> FIntModuleOp::getPorts() { return ::getPortImpl(*this); }
663 
664 SmallVector<PortInfo> FMemModuleOp::getPorts() { return ::getPortImpl(*this); }
665 
667  if (dir == Direction::In)
669  if (dir == Direction::Out)
671  assert(0 && "invalid direction");
672  abort();
673 }
674 
675 static SmallVector<hw::PortInfo> getPortListImpl(FModuleLike module) {
676  SmallVector<hw::PortInfo> results;
677  auto aname = StringAttr::get(module.getContext(),
678  hw::HWModuleLike::getPortSymbolAttrName());
679  auto emptyDict = DictionaryAttr::get(module.getContext());
680  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
681  auto sym = module.getPortSymbolAttr(i);
682  results.push_back(
683  {{module.getPortNameAttr(i), module.getPortType(i),
684  dirFtoH(module.getPortDirection(i))},
685  i,
686  sym ? DictionaryAttr::get(
687  module.getContext(),
688  ArrayRef<mlir::NamedAttribute>{NamedAttribute{aname, sym}})
689  : emptyDict,
690  module.getPortLocation(i)});
691  }
692  return results;
693 }
694 
695 SmallVector<::circt::hw::PortInfo> FModuleOp::getPortList() {
697 }
698 
699 SmallVector<::circt::hw::PortInfo> FExtModuleOp::getPortList() {
701 }
702 
703 SmallVector<::circt::hw::PortInfo> FIntModuleOp::getPortList() {
705 }
706 
707 SmallVector<::circt::hw::PortInfo> FMemModuleOp::getPortList() {
709 }
710 
711 static hw::PortInfo getPortImpl(FModuleLike module, size_t idx) {
712  return {{module.getPortNameAttr(idx), module.getPortType(idx),
713  dirFtoH(module.getPortDirection(idx))},
714  idx,
716  module.getContext(),
717  ArrayRef<mlir::NamedAttribute>{NamedAttribute{
718  StringAttr::get(module.getContext(),
719  hw::HWModuleLike::getPortSymbolAttrName()),
720  module.getPortSymbolAttr(idx)}}),
721  module.getPortLocation(idx)};
722 }
723 
725  return ::getPortImpl(*this, idx);
726 }
727 
729  return ::getPortImpl(*this, idx);
730 }
731 
733  return ::getPortImpl(*this, idx);
734 }
735 
737  return ::getPortImpl(*this, idx);
738 }
739 
740 // Return the port with the specified name.
741 BlockArgument FModuleOp::getArgument(size_t portNumber) {
742  return getBodyBlock()->getArgument(portNumber);
743 }
744 
/// Inserts the given ports. The insertion indices are expected to be in order.
/// Insertion occurs in-order, such that ports with the same insertion index
/// appear in the module in the same order they appeared in the list.
///
/// Each entry of `ports` is (index, new port). The new port is placed before
/// the existing port at that index, or after all existing ports if the index
/// is not smaller than the current port count. The per-port attribute arrays
/// are rebuilt with the new ports spliced in: directions, names, types,
/// annotations, symbols, and locations. When `supportsInternalPaths` is set,
/// the `internalPaths` array is rebuilt as well. Does nothing if `ports` is
/// empty.
///
/// NOTE(review): the ordering precondition is not checked. An index smaller
/// than one already processed inserts the port right after the existing
/// ports copied so far.
static void insertPorts(FModuleLike op,
                        ArrayRef<std::pair<unsigned, PortInfo>> ports,
                        bool supportsInternalPaths = false) {
  if (ports.empty())
    return;
  unsigned oldNumArgs = op.getNumPorts();
  unsigned newNumArgs = oldNumArgs + ports.size();

  // Add direction markers and names for new ports.
  auto existingDirections = op.getPortDirectionsAttr();
  ArrayRef<Attribute> existingNames = op.getPortNames();
  ArrayRef<Attribute> existingTypes = op.getPortTypes();
  ArrayRef<Attribute> existingLocs = op.getPortLocations();
  assert(existingDirections.size() == oldNumArgs);
  assert(existingNames.size() == oldNumArgs);
  assert(existingTypes.size() == oldNumArgs);
  assert(existingLocs.size() == oldNumArgs);
  // A missing `internalPaths` attribute is treated as one empty path per
  // existing port.
  SmallVector<Attribute> internalPaths;
  auto emptyInternalPath = InternalPathAttr::get(op.getContext());
  if (supportsInternalPaths) {
    if (auto internalPathsAttr = op->getAttrOfType<ArrayAttr>("internalPaths"))
      llvm::append_range(internalPaths, internalPathsAttr);
    else
      internalPaths.resize(oldNumArgs, emptyInternalPath);
    assert(internalPaths.size() == oldNumArgs);
  }

  SmallVector<bool> newDirections;
  SmallVector<Attribute> newNames, newTypes, newAnnos, newSyms, newLocs,
      newInternalPaths;
  newDirections.reserve(newNumArgs);
  newNames.reserve(newNumArgs);
  newTypes.reserve(newNumArgs);
  newAnnos.reserve(newNumArgs);
  newSyms.reserve(newNumArgs);
  newLocs.reserve(newNumArgs);
  newInternalPaths.reserve(newNumArgs);

  auto emptyArray = ArrayAttr::get(op.getContext(), {});

  // Copy existing ports [oldIdx, untilOldIdx) into the new arrays, advancing
  // `oldIdx`; never copies past the last existing port.
  unsigned oldIdx = 0;
  auto migrateOldPorts = [&](unsigned untilOldIdx) {
    while (oldIdx < oldNumArgs && oldIdx < untilOldIdx) {
      newDirections.push_back(existingDirections[oldIdx]);
      newNames.push_back(existingNames[oldIdx]);
      newTypes.push_back(existingTypes[oldIdx]);
      newAnnos.push_back(op.getAnnotationsAttrForPort(oldIdx));
      newSyms.push_back(op.getPortSymbolAttr(oldIdx));
      newLocs.push_back(existingLocs[oldIdx]);
      if (supportsInternalPaths)
        newInternalPaths.push_back(internalPaths[oldIdx]);
      ++oldIdx;
    }
  };
  for (auto pair : llvm::enumerate(ports)) {
    auto idx = pair.value().first;
    auto &port = pair.value().second;
    migrateOldPorts(idx);
    newDirections.push_back(direction::unGet(port.direction));
    newNames.push_back(port.name);
    newTypes.push_back(TypeAttr::get(port.type));
    auto annos = port.annotations.getArrayAttr();
    newAnnos.push_back(annos ? annos : emptyArray);
    newSyms.push_back(port.sym);
    newLocs.push_back(port.loc);
    if (supportsInternalPaths)
      newInternalPaths.push_back(emptyInternalPath);
  }
  // Copy any existing ports that come after the last insertion point.
  migrateOldPorts(oldNumArgs);

  // The lack of *any* port annotations is represented by an empty
  // `portAnnotations` array as a shorthand.
  if (llvm::all_of(newAnnos, [](Attribute attr) {
        return cast<ArrayAttr>(attr).empty();
      }))
    newAnnos.clear();

  // Apply these changed markers.
  op->setAttr("portDirections",
              direction::packAttribute(op.getContext(), newDirections));
  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newNames));
  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newTypes));
  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newAnnos));
  op.setPortSymbols(newSyms);
  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newLocs));
  if (supportsInternalPaths) {
    // Drop if all-empty, otherwise set to new array.
    auto empty = llvm::all_of(newInternalPaths, [](Attribute attr) {
      return !cast<InternalPathAttr>(attr).getPath();
    });
    if (empty)
      op->removeAttr("internalPaths");
    else
      op->setAttr("internalPaths",
                  ArrayAttr::get(op.getContext(), newInternalPaths));
  }
}
845 
846 /// Erases the ports that have their corresponding bit set in `portIndices`.
847 static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
848  if (portIndices.none())
849  return;
850 
851  // Drop the direction markers for dead ports.
852  ArrayRef<bool> portDirections = op.getPortDirectionsAttr().asArrayRef();
853  ArrayRef<Attribute> portNames = op.getPortNames();
854  ArrayRef<Attribute> portTypes = op.getPortTypes();
855  ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
856  ArrayRef<Attribute> portSyms = op.getPortSymbols();
857  ArrayRef<Attribute> portLocs = op.getPortLocations();
858  auto numPorts = op.getNumPorts();
859  (void)numPorts;
860  assert(portDirections.size() == numPorts);
861  assert(portNames.size() == numPorts);
862  assert(portAnnos.size() == numPorts || portAnnos.empty());
863  assert(portTypes.size() == numPorts);
864  assert(portSyms.size() == numPorts || portSyms.empty());
865  assert(portLocs.size() == numPorts);
866 
867  SmallVector<bool> newPortDirections =
868  removeElementsAtIndices<bool>(portDirections, portIndices);
869  SmallVector<Attribute> newPortNames, newPortTypes, newPortAnnos, newPortSyms,
870  newPortLocs;
871  newPortNames = removeElementsAtIndices(portNames, portIndices);
872  newPortTypes = removeElementsAtIndices(portTypes, portIndices);
873  newPortAnnos = removeElementsAtIndices(portAnnos, portIndices);
874  newPortSyms = removeElementsAtIndices(portSyms, portIndices);
875  newPortLocs = removeElementsAtIndices(portLocs, portIndices);
876  op->setAttr("portDirections",
877  direction::packAttribute(op.getContext(), newPortDirections));
878  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newPortNames));
879  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newPortAnnos));
880  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newPortTypes));
881  FModuleLike::fixupPortSymsArray(newPortSyms, op.getContext());
882  op->setAttr("portSyms", ArrayAttr::get(op.getContext(), newPortSyms));
883  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newPortLocs));
884 }
885 
886 template <typename T>
887 static void eraseInternalPaths(T op, const llvm::BitVector &portIndices) {
888  // Fixup internalPaths array.
889  auto internalPaths = op.getInternalPaths();
890  if (!internalPaths)
891  return;
892 
893  auto newPaths =
894  removeElementsAtIndices(internalPaths->getValue(), portIndices);
895 
896  // Drop if all-empty, otherwise set to new array.
897  auto empty = llvm::all_of(newPaths, [](Attribute attr) {
898  return !cast<InternalPathAttr>(attr).getPath();
899  });
900  if (empty)
901  op.removeInternalPathsAttr();
902  else
903  op.setInternalPathsAttr(ArrayAttr::get(op.getContext(), newPaths));
904 }
905 
906 void FExtModuleOp::erasePorts(const llvm::BitVector &portIndices) {
907  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
908  eraseInternalPaths(*this, portIndices);
909 }
910 
911 void FIntModuleOp::erasePorts(const llvm::BitVector &portIndices) {
912  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
913  eraseInternalPaths(*this, portIndices);
914 }
915 
916 void FMemModuleOp::erasePorts(const llvm::BitVector &portIndices) {
917  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
918 }
919 
920 void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
921  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
922  getBodyBlock()->eraseArguments(portIndices);
923 }
924 
925 /// Inserts the given ports. The insertion indices are expected to be in order.
926 /// Insertion occurs in-order, such that ports with the same insertion index
927 /// appear in the module in the same order they appeared in the list.
928 void FModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
929  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
930 
931  // Insert the block arguments.
932  auto *body = getBodyBlock();
933  for (size_t i = 0, e = ports.size(); i < e; ++i) {
934  // Block arguments are inserted one at a time, so for each argument we
935  // insert we have to increase the index by 1.
936  auto &[index, port] = ports[i];
937  body->insertArgument(index + i, port.type, port.loc);
938  }
939 }
940 
941 void FExtModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
942  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
943  /*supportsInternalPaths=*/true);
944 }
945 
946 void FIntModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
947  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
948  /*supportsInternalPaths=*/true);
949 }
950 
951 /// Inserts the given ports. The insertion indices are expected to be in order.
952 /// Insertion occurs in-order, such that ports with the same insertion index
953 /// appear in the module in the same order they appeared in the list.
954 void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
955  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
956 }
957 
958 static void buildModuleLike(OpBuilder &builder, OperationState &result,
959  StringAttr name, ArrayRef<PortInfo> ports,
960  ArrayAttr annotations, ArrayAttr layers,
961  bool withAnnotations, bool withLayers) {
962  // Add an attribute for the name.
963  result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), name);
964 
965  // Record the names of the arguments if present.
966  SmallVector<Direction, 4> portDirections;
967  SmallVector<Attribute, 4> portNames;
968  SmallVector<Attribute, 4> portTypes;
969  SmallVector<Attribute, 4> portAnnotations;
970  SmallVector<Attribute, 4> portSyms;
971  SmallVector<Attribute, 4> portLocs;
972  for (const auto &port : ports) {
973  portDirections.push_back(port.direction);
974  portNames.push_back(port.name);
975  portTypes.push_back(TypeAttr::get(port.type));
976  portAnnotations.push_back(port.annotations.getArrayAttr());
977  portSyms.push_back(port.sym);
978  portLocs.push_back(port.loc);
979  }
980 
981  // The lack of *any* port annotations is represented by an empty
982  // `portAnnotations` array as a shorthand.
983  if (llvm::all_of(portAnnotations, [](Attribute attr) {
984  return cast<ArrayAttr>(attr).empty();
985  }))
986  portAnnotations.clear();
987 
988  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
989  // Both attributes are added, even if the module has no ports.
990  result.addAttribute(
991  "portDirections",
992  direction::packAttribute(builder.getContext(), portDirections));
993  result.addAttribute("portNames", builder.getArrayAttr(portNames));
994  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
995  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
996  result.addAttribute("portLocations", builder.getArrayAttr(portLocs));
997 
998  if (withAnnotations) {
999  if (!annotations)
1000  annotations = builder.getArrayAttr({});
1001  result.addAttribute("annotations", annotations);
1002  result.addAttribute("portAnnotations",
1003  builder.getArrayAttr(portAnnotations));
1004  }
1005 
1006  if (withLayers) {
1007  if (!layers)
1008  layers = builder.getArrayAttr({});
1009  result.addAttribute("layers", layers);
1010  }
1011 
1012  result.addRegion();
1013 }
1014 
1015 static void buildModule(OpBuilder &builder, OperationState &result,
1016  StringAttr name, ArrayRef<PortInfo> ports,
1017  ArrayAttr annotations, ArrayAttr layers) {
1018  buildModuleLike(builder, result, name, ports, annotations, layers,
1019  /*withAnnotations=*/true, /*withLayers=*/true);
1020 }
1021 
1022 static void buildClass(OpBuilder &builder, OperationState &result,
1023  StringAttr name, ArrayRef<PortInfo> ports) {
1024  return buildModuleLike(builder, result, name, ports, {}, {},
1025  /*withAnnotations=*/false,
1026  /*withLayers=*/false);
1027 }
1028 
1029 void FModuleOp::build(OpBuilder &builder, OperationState &result,
1030  StringAttr name, ConventionAttr convention,
1031  ArrayRef<PortInfo> ports, ArrayAttr annotations,
1032  ArrayAttr layers) {
1033  buildModule(builder, result, name, ports, annotations, layers);
1034  result.addAttribute("convention", convention);
1035 
1036  // Create a region and a block for the body.
1037  auto *bodyRegion = result.regions[0].get();
1038  Block *body = new Block();
1039  bodyRegion->push_back(body);
1040 
1041  // Add arguments to the body block.
1042  for (auto &elt : ports)
1043  body->addArgument(elt.type, elt.loc);
1044 }
1045 
1046 void FExtModuleOp::build(OpBuilder &builder, OperationState &result,
1047  StringAttr name, ConventionAttr convention,
1048  ArrayRef<PortInfo> ports, StringRef defnameAttr,
1049  ArrayAttr annotations, ArrayAttr parameters,
1050  ArrayAttr internalPaths, ArrayAttr layers) {
1051  buildModule(builder, result, name, ports, annotations, layers);
1052  result.addAttribute("convention", convention);
1053  if (!defnameAttr.empty())
1054  result.addAttribute("defname", builder.getStringAttr(defnameAttr));
1055  if (!parameters)
1056  parameters = builder.getArrayAttr({});
1057  result.addAttribute(getParametersAttrName(result.name), parameters);
1058  if (internalPaths && !internalPaths.empty())
1059  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1060 }
1061 
1062 void FIntModuleOp::build(OpBuilder &builder, OperationState &result,
1063  StringAttr name, ArrayRef<PortInfo> ports,
1064  StringRef intrinsicNameStr, ArrayAttr annotations,
1065  ArrayAttr parameters, ArrayAttr internalPaths,
1066  ArrayAttr layers) {
1067  buildModule(builder, result, name, ports, annotations, layers);
1068  result.addAttribute("intrinsic", builder.getStringAttr(intrinsicNameStr));
1069  if (!parameters)
1070  parameters = builder.getArrayAttr({});
1071  result.addAttribute(getParametersAttrName(result.name), parameters);
1072  if (internalPaths && !internalPaths.empty())
1073  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1074 }
1075 
1076 void FMemModuleOp::build(OpBuilder &builder, OperationState &result,
1077  StringAttr name, ArrayRef<PortInfo> ports,
1078  uint32_t numReadPorts, uint32_t numWritePorts,
1079  uint32_t numReadWritePorts, uint32_t dataWidth,
1080  uint32_t maskBits, uint32_t readLatency,
1081  uint32_t writeLatency, uint64_t depth,
1082  ArrayAttr annotations, ArrayAttr layers) {
1083  auto *context = builder.getContext();
1084  buildModule(builder, result, name, ports, annotations, layers);
1085  auto ui32Type = IntegerType::get(context, 32, IntegerType::Unsigned);
1086  auto ui64Type = IntegerType::get(context, 64, IntegerType::Unsigned);
1087  result.addAttribute("numReadPorts", IntegerAttr::get(ui32Type, numReadPorts));
1088  result.addAttribute("numWritePorts",
1089  IntegerAttr::get(ui32Type, numWritePorts));
1090  result.addAttribute("numReadWritePorts",
1091  IntegerAttr::get(ui32Type, numReadWritePorts));
1092  result.addAttribute("dataWidth", IntegerAttr::get(ui32Type, dataWidth));
1093  result.addAttribute("maskBits", IntegerAttr::get(ui32Type, maskBits));
1094  result.addAttribute("readLatency", IntegerAttr::get(ui32Type, readLatency));
1095  result.addAttribute("writeLatency", IntegerAttr::get(ui32Type, writeLatency));
1096  result.addAttribute("depth", IntegerAttr::get(ui64Type, depth));
1097  result.addAttribute("extraPorts", ArrayAttr::get(context, {}));
1098 }
1099 
1100 /// Print a list of module ports in the following form:
1101 /// in x: !firrtl.uint<1> [{class = "DontTouch}], out "_port": !firrtl.uint<2>
1102 ///
1103 /// When there is no block specified, the port names print as MLIR identifiers,
1104 /// wrapping in quotes if not legal to print as-is. When there is no block
1105 /// specified, this function always return false, indicating that there was no
1106 /// issue printing port names.
1107 ///
1108 /// If there is a block specified, then port names will be printed as SSA
1109 /// values. If there is a reason the printed SSA values can't match the true
1110 /// port name, then this function will return true. When this happens, the
1111 /// caller should print the port names as a part of the `attr-dict`.
1112 static bool
1113 printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef<bool> portDirections,
1114  ArrayRef<Attribute> portNames, ArrayRef<Attribute> portTypes,
1115  ArrayRef<Attribute> portAnnotations,
1116  ArrayRef<Attribute> portSyms, ArrayRef<Attribute> portLocs) {
1117  // When printing port names as SSA values, we can fail to print them
1118  // identically.
1119  bool printedNamesDontMatch = false;
1120 
1121  mlir::OpPrintingFlags flags;
1122 
1123  // If we are printing the ports as block arguments the op must have a first
1124  // block.
1125  SmallString<32> resultNameStr;
1126  p << '(';
1127  for (unsigned i = 0, e = portTypes.size(); i < e; ++i) {
1128  if (i > 0)
1129  p << ", ";
1130 
1131  // Print the port direction.
1132  p << direction::get(portDirections[i]) << " ";
1133 
1134  // Print the port name. If there is a valid block, we print it as a block
1135  // argument.
1136  if (block) {
1137  // Get the printed format for the argument name.
1138  resultNameStr.clear();
1139  llvm::raw_svector_ostream tmpStream(resultNameStr);
1140  p.printOperand(block->getArgument(i), tmpStream);
1141  // If the name wasn't printable in a way that agreed with portName, make
1142  // sure to print out an explicit portNames attribute.
1143  auto portName = cast<StringAttr>(portNames[i]).getValue();
1144  if (!portName.empty() && tmpStream.str().drop_front() != portName)
1145  printedNamesDontMatch = true;
1146  p << tmpStream.str();
1147  } else {
1148  p.printKeywordOrString(cast<StringAttr>(portNames[i]).getValue());
1149  }
1150 
1151  // Print the port type.
1152  p << ": ";
1153  auto portType = cast<TypeAttr>(portTypes[i]).getValue();
1154  p.printType(portType);
1155 
1156  // Print the optional port symbol.
1157  if (!portSyms.empty()) {
1158  if (!cast<hw::InnerSymAttr>(portSyms[i]).empty()) {
1159  p << " sym ";
1160  cast<hw::InnerSymAttr>(portSyms[i]).print(p);
1161  }
1162  }
1163 
1164  // Print the port specific annotations. The port annotations array will be
1165  // empty if there are none.
1166  if (!portAnnotations.empty() &&
1167  !cast<ArrayAttr>(portAnnotations[i]).empty()) {
1168  p << " ";
1169  p.printAttribute(portAnnotations[i]);
1170  }
1171 
1172  // Print the port location.
1173  // TODO: `printOptionalLocationSpecifier` will emit aliases for locations,
1174  // even if they are not printed. This will have to be fixed upstream. For
1175  // now, use what was specified on the command line.
1176  if (flags.shouldPrintDebugInfo() && !portLocs.empty())
1177  p.printOptionalLocationSpecifier(cast<LocationAttr>(portLocs[i]));
1178  }
1179 
1180  p << ')';
1181  return printedNamesDontMatch;
1182 }
1183 
1184 /// Parse a list of module ports. If port names are SSA identifiers, then this
1185 /// will populate `entryArgs`.
1186 static ParseResult
1187 parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers,
1188  bool supportsSymbols,
1189  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
1190  SmallVectorImpl<Direction> &portDirections,
1191  SmallVectorImpl<Attribute> &portNames,
1192  SmallVectorImpl<Attribute> &portTypes,
1193  SmallVectorImpl<Attribute> &portAnnotations,
1194  SmallVectorImpl<Attribute> &portSyms,
1195  SmallVectorImpl<Attribute> &portLocs) {
1196  auto *context = parser.getContext();
1197 
1198  auto parseArgument = [&]() -> ParseResult {
1199  // Parse port direction.
1200  if (succeeded(parser.parseOptionalKeyword("out")))
1201  portDirections.push_back(Direction::Out);
1202  else if (succeeded(parser.parseKeyword("in", "or 'out'")))
1203  portDirections.push_back(Direction::In);
1204  else
1205  return failure();
1206 
1207  // This is the location or the port declaration in the IR. If there is no
1208  // other location information, we use this to point to the MLIR.
1209  llvm::SMLoc irLoc;
1210 
1211  if (hasSSAIdentifiers) {
1212  OpAsmParser::Argument arg;
1213  if (parser.parseArgument(arg))
1214  return failure();
1215  entryArgs.push_back(arg);
1216 
1217  // The name of an argument is of the form "%42" or "%id", and since
1218  // parsing succeeded, we know it always has one character.
1219  assert(arg.ssaName.name.size() > 1 && arg.ssaName.name[0] == '%' &&
1220  "Unknown MLIR name");
1221  if (isdigit(arg.ssaName.name[1]))
1222  portNames.push_back(StringAttr::get(context, ""));
1223  else
1224  portNames.push_back(
1225  StringAttr::get(context, arg.ssaName.name.drop_front()));
1226 
1227  // Store the location of the SSA name.
1228  irLoc = arg.ssaName.location;
1229 
1230  } else {
1231  // Parse the port name.
1232  irLoc = parser.getCurrentLocation();
1233  std::string portName;
1234  if (parser.parseKeywordOrString(&portName))
1235  return failure();
1236  portNames.push_back(StringAttr::get(context, portName));
1237  }
1238 
1239  // Parse the port type.
1240  Type portType;
1241  if (parser.parseColonType(portType))
1242  return failure();
1243  portTypes.push_back(TypeAttr::get(portType));
1244 
1245  if (hasSSAIdentifiers)
1246  entryArgs.back().type = portType;
1247 
1248  // Parse the optional port symbol.
1249  if (supportsSymbols) {
1250  hw::InnerSymAttr innerSymAttr;
1251  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1252  NamedAttrList dummyAttrs;
1253  if (parser.parseCustomAttributeWithFallback(
1254  innerSymAttr, ::mlir::Type{},
1256  return ::mlir::failure();
1257  }
1258  }
1259  portSyms.push_back(innerSymAttr);
1260  }
1261 
1262  // Parse the port annotations.
1263  ArrayAttr annos;
1264  auto parseResult = parser.parseOptionalAttribute(annos);
1265  if (!parseResult.has_value())
1266  annos = parser.getBuilder().getArrayAttr({});
1267  else if (failed(*parseResult))
1268  return failure();
1269  portAnnotations.push_back(annos);
1270 
1271  // Parse the optional port location.
1272  std::optional<Location> maybeLoc;
1273  if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
1274  return failure();
1275  Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc);
1276  portLocs.push_back(loc);
1277  if (hasSSAIdentifiers)
1278  entryArgs.back().sourceLoc = loc;
1279 
1280  return success();
1281  };
1282 
1283  // Parse all ports.
1284  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1285  parseArgument);
1286 }
1287 
1288 /// Print a paramter list for a module or instance.
1289 static void printParameterList(OpAsmPrinter &p, Operation *op,
1290  ArrayAttr parameters) {
1291  if (!parameters || parameters.empty())
1292  return;
1293 
1294  p << '<';
1295  llvm::interleaveComma(parameters, p, [&](Attribute param) {
1296  auto paramAttr = cast<ParamDeclAttr>(param);
1297  p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
1298  if (auto value = paramAttr.getValue()) {
1299  p << " = ";
1300  p.printAttributeWithoutType(value);
1301  }
1302  });
1303  p << '>';
1304 }
1305 
1306 static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
1307  p << " ";
1308 
1309  // Print the visibility of the module.
1310  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1311  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1312  p << visibility.getValue() << ' ';
1313 
1314  // Print the operation and the function name.
1315  p.printSymbolName(op.getModuleName());
1316 
1317  // Print the parameter list (if non-empty).
1318  printParameterList(p, op, op->getAttrOfType<ArrayAttr>("parameters"));
1319 
1320  // Both modules and external modules have a body, but it is always empty for
1321  // external modules.
1322  Block *body = nullptr;
1323  if (!op->getRegion(0).empty())
1324  body = &op->getRegion(0).front();
1325 
1326  auto needPortNamesAttr = printModulePorts(
1327  p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1328  op.getPortAnnotations(), op.getPortSymbols(), op.getPortLocations());
1329 
1330  SmallVector<StringRef, 12> omittedAttrs = {
1331  "sym_name", "portDirections", "portTypes", "portAnnotations",
1332  "portSyms", "portLocations", "parameters", visibilityAttrName};
1333 
1334  if (op.getConvention() == Convention::Internal)
1335  omittedAttrs.push_back("convention");
1336 
1337  // We can omit the portNames if they were able to be printed as properly as
1338  // block arguments.
1339  if (!needPortNamesAttr)
1340  omittedAttrs.push_back("portNames");
1341 
1342  // If there are no annotations we can omit the empty array.
1343  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
1344  omittedAttrs.push_back("annotations");
1345 
1346  // If there are no enabled layers, then omit the empty array.
1347  if (auto layers = op->getAttrOfType<ArrayAttr>("layers"))
1348  if (layers.empty())
1349  omittedAttrs.push_back("layers");
1350 
1351  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1352 }
1353 
1354 void FExtModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1355 
1356 void FIntModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1357 
1358 void FMemModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1359 
1360 void FModuleOp::print(OpAsmPrinter &p) {
1361  printFModuleLikeOp(p, *this);
1362 
1363  // Print the body if this is not an external function. Since this block does
1364  // not have terminators, printing the terminator actually just prints the last
1365  // operation.
1366  Region &fbody = getBody();
1367  if (!fbody.empty()) {
1368  p << " ";
1369  p.printRegion(fbody, /*printEntryBlockArgs=*/false,
1370  /*printBlockTerminators=*/true);
1371  }
1372 }
1373 
1374 /// Parse an parameter list if present.
1375 /// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>`
1376 /// parameter-decl ::= identifier `:` type
1377 /// parameter-decl ::= identifier `:` type `=` attribute
1378 ///
1379 static ParseResult
1380 parseOptionalParameters(OpAsmParser &parser,
1381  SmallVectorImpl<Attribute> &parameters) {
1382 
1383  return parser.parseCommaSeparatedList(
1384  OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1385  std::string name;
1386  Type type;
1387  Attribute value;
1388 
1389  if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1390  return failure();
1391 
1392  // Parse the default value if present.
1393  if (succeeded(parser.parseOptionalEqual())) {
1394  if (parser.parseAttribute(value, type))
1395  return failure();
1396  }
1397 
1398  auto &builder = parser.getBuilder();
1399  parameters.push_back(ParamDeclAttr::get(
1400  builder.getContext(), builder.getStringAttr(name), type, value));
1401  return success();
1402  });
1403 }
1404 
1405 /// Shim to use with assemblyFormat, custom<ParameterList>.
1406 static ParseResult parseParameterList(OpAsmParser &parser,
1407  ArrayAttr &parameters) {
1408  SmallVector<Attribute> parseParameters;
1409  if (failed(parseOptionalParameters(parser, parseParameters)))
1410  return failure();
1411 
1412  parameters = ArrayAttr::get(parser.getContext(), parseParameters);
1413 
1414  return success();
1415 }
1416 
1417 static ParseResult parseFModuleLikeOp(OpAsmParser &parser,
1418  OperationState &result,
1419  bool hasSSAIdentifiers) {
1420  auto *context = result.getContext();
1421  auto &builder = parser.getBuilder();
1422 
1423  // Parse the visibility attribute.
1424  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1425 
1426  // Parse the name as a symbol.
1427  StringAttr nameAttr;
1428  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1429  result.attributes))
1430  return failure();
1431 
1432  // Parse optional parameters.
1433  SmallVector<Attribute, 4> parameters;
1434  if (parseOptionalParameters(parser, parameters))
1435  return failure();
1436  result.addAttribute("parameters", builder.getArrayAttr(parameters));
1437 
1438  // Parse the module ports.
1439  SmallVector<OpAsmParser::Argument> entryArgs;
1440  SmallVector<Direction, 4> portDirections;
1441  SmallVector<Attribute, 4> portNames;
1442  SmallVector<Attribute, 4> portTypes;
1443  SmallVector<Attribute, 4> portAnnotations;
1444  SmallVector<Attribute, 4> portSyms;
1445  SmallVector<Attribute, 4> portLocs;
1446  if (parseModulePorts(parser, hasSSAIdentifiers, /*supportsSymbols=*/true,
1447  entryArgs, portDirections, portNames, portTypes,
1448  portAnnotations, portSyms, portLocs))
1449  return failure();
1450 
1451  // If module attributes are present, parse them.
1452  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1453  return failure();
1454 
1455  assert(portNames.size() == portTypes.size());
1456 
1457  // Record the argument and result types as an attribute. This is necessary
1458  // for external modules.
1459 
1460  // Add port directions.
1461  if (!result.attributes.get("portDirections"))
1462  result.addAttribute("portDirections",
1463  direction::packAttribute(context, portDirections));
1464 
1465  // Add port names.
1466  if (!result.attributes.get("portNames"))
1467  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1468 
1469  // Add the port types.
1470  if (!result.attributes.get("portTypes"))
1471  result.addAttribute("portTypes", ArrayAttr::get(context, portTypes));
1472 
1473  // Add the port annotations.
1474  if (!result.attributes.get("portAnnotations")) {
1475  // If there are no portAnnotations, don't add the attribute.
1476  if (llvm::any_of(portAnnotations, [&](Attribute anno) {
1477  return !cast<ArrayAttr>(anno).empty();
1478  }))
1479  result.addAttribute("portAnnotations",
1480  ArrayAttr::get(context, portAnnotations));
1481  }
1482 
1483  // Add port symbols.
1484  if (!result.attributes.get("portSyms")) {
1485  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1486  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1487  }
1488 
1489  // Add port locations.
1490  if (!result.attributes.get("portLocations"))
1491  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1492 
1493  // The annotations attribute is always present, but not printed when empty.
1494  if (!result.attributes.get("annotations"))
1495  result.addAttribute("annotations", builder.getArrayAttr({}));
1496 
1497  // The portAnnotations attribute is always present, but not printed when
1498  // empty.
1499  if (!result.attributes.get("portAnnotations"))
1500  result.addAttribute("portAnnotations", builder.getArrayAttr({}));
1501 
1502  // Parse the optional function body.
1503  auto *body = result.addRegion();
1504 
1505  if (hasSSAIdentifiers) {
1506  if (parser.parseRegion(*body, entryArgs))
1507  return failure();
1508  if (body->empty())
1509  body->push_back(new Block());
1510  }
1511  return success();
1512 }
1513 
1514 ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1515  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/true))
1516  return failure();
1517  if (!result.attributes.get("convention"))
1518  result.addAttribute(
1519  "convention",
1520  ConventionAttr::get(result.getContext(), Convention::Internal));
1521  if (!result.attributes.get("layers"))
1522  result.addAttribute("layers", ArrayAttr::get(parser.getContext(), {}));
1523  return success();
1524 }
1525 
1526 ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1527  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false))
1528  return failure();
1529  if (!result.attributes.get("convention"))
1530  result.addAttribute(
1531  "convention",
1532  ConventionAttr::get(result.getContext(), Convention::Internal));
1533  return success();
1534 }
1535 
1536 ParseResult FIntModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1537  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1538 }
1539 
1540 ParseResult FMemModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1541  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1542 }
1543 
1544 LogicalResult FModuleOp::verify() {
1545  // Verify the block arguments.
1546  auto *body = getBodyBlock();
1547  auto portTypes = getPortTypes();
1548  auto portLocs = getPortLocations();
1549  auto numPorts = portTypes.size();
1550 
1551  // Verify that we have the correct number of block arguments.
1552  if (body->getNumArguments() != numPorts)
1553  return emitOpError("entry block must have ")
1554  << numPorts << " arguments to match module signature";
1555 
1556  // Verify the block arguments' types and locations match our attributes.
1557  for (auto [arg, type, loc] : zip(body->getArguments(), portTypes, portLocs)) {
1558  if (arg.getType() != cast<TypeAttr>(type).getValue())
1559  return emitOpError("block argument types should match signature types");
1560  if (arg.getLoc() != cast<LocationAttr>(loc))
1561  return emitOpError(
1562  "block argument locations should match signature locations");
1563  }
1564 
1565  return success();
1566 }
1567 
1568 static LogicalResult
1569 verifyInternalPaths(FModuleLike op,
1570  std::optional<::mlir::ArrayAttr> internalPaths) {
1571  if (!internalPaths)
1572  return success();
1573 
1574  // If internal paths are present, should cover all ports.
1575  if (internalPaths->size() != op.getNumPorts())
1576  return op.emitError("module has inconsistent internal path array with ")
1577  << internalPaths->size() << " entries for " << op.getNumPorts()
1578  << " ports";
1579 
1580  // No internal paths for non-ref-type ports.
1581  for (auto [idx, path, typeattr] : llvm::enumerate(
1582  internalPaths->getAsRange<InternalPathAttr>(), op.getPortTypes())) {
1583  if (path.getPath() &&
1584  !type_isa<RefType>(cast<TypeAttr>(typeattr).getValue())) {
1585  auto diag =
1586  op.emitError("module has internal path for non-ref-type port ")
1587  << op.getPortNameAttr(idx);
1588  return diag.attachNote(op.getPortLocation(idx)) << "this port";
1589  }
1590  }
1591 
1592  return success();
1593 }
1594 
1595 LogicalResult FExtModuleOp::verify() {
1596  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1597  return failure();
1598 
1599  auto params = getParameters();
1600  if (params.empty())
1601  return success();
1602 
1603  auto checkParmValue = [&](Attribute elt) -> bool {
1604  auto param = cast<ParamDeclAttr>(elt);
1605  auto value = param.getValue();
1606  if (isa<IntegerAttr, StringAttr, FloatAttr, hw::ParamVerbatimAttr>(value))
1607  return true;
1608  emitError() << "has unknown extmodule parameter value '"
1609  << param.getName().getValue() << "' = " << value;
1610  return false;
1611  };
1612 
1613  if (!llvm::all_of(params, checkParmValue))
1614  return failure();
1615 
1616  return success();
1617 }
1618 
1619 LogicalResult FIntModuleOp::verify() {
1620  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1621  return failure();
1622 
1623  auto params = getParameters();
1624  if (params.empty())
1625  return success();
1626 
1627  auto checkParmValue = [&](Attribute elt) -> bool {
1628  auto param = cast<ParamDeclAttr>(elt);
1629  auto value = param.getValue();
1630  if (isa<IntegerAttr, StringAttr, FloatAttr>(value))
1631  return true;
1632  emitError() << "has unknown intmodule parameter value '"
1633  << param.getName().getValue() << "' = " << value;
1634  return false;
1635  };
1636 
1637  if (!llvm::all_of(params, checkParmValue))
1638  return failure();
1639 
1640  return success();
1641 }
1642 
1643 static LogicalResult verifyProbeType(RefType refType, Location loc,
1644  CircuitOp circuitOp,
1645  SymbolTableCollection &symbolTable,
1646  Twine start) {
1647  auto layer = refType.getLayer();
1648  if (!layer)
1649  return success();
1650  auto *layerOp = symbolTable.lookupSymbolIn(circuitOp, layer);
1651  if (!layerOp)
1652  return emitError(loc) << start << " associated with layer '" << layer
1653  << "', but this layer was not defined";
1654  if (!isa<LayerOp>(layerOp)) {
1655  auto diag = emitError(loc)
1656  << start << " associated with layer '" << layer
1657  << "', but symbol '" << layer << "' does not refer to a '"
1658  << LayerOp::getOperationName() << "' op";
1659  return diag.attachNote(layerOp->getLoc()) << "symbol refers to this op";
1660  }
1661  return success();
1662 }
1663 
1664 static LogicalResult verifyPortSymbolUses(FModuleLike module,
1665  SymbolTableCollection &symbolTable) {
1666  // verify types in ports.
1667  auto circuitOp = module->getParentOfType<CircuitOp>();
1668  for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
1669  auto type = module.getPortType(i);
1670 
1671  if (auto refType = type_dyn_cast<RefType>(type)) {
1672  if (failed(verifyProbeType(
1673  refType, module.getPortLocation(i), circuitOp, symbolTable,
1674  Twine("probe port '") + module.getPortName(i) + "' is")))
1675  return failure();
1676  continue;
1677  }
1678 
1679  if (auto classType = dyn_cast<ClassType>(type)) {
1680  auto className = classType.getNameAttr();
1681  auto classOp = dyn_cast_or_null<ClassLike>(
1682  symbolTable.lookupSymbolIn(circuitOp, className));
1683  if (!classOp)
1684  return module.emitOpError() << "references unknown class " << className;
1685 
1686  // verify that the result type agrees with the class definition.
1687  if (failed(classOp.verifyType(classType,
1688  [&]() { return module.emitOpError(); })))
1689  return failure();
1690  continue;
1691  }
1692  }
1693 
1694  return success();
1695 }
1696 
1697 LogicalResult FModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1698  if (failed(
1699  verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable)))
1700  return failure();
1701 
1702  auto circuitOp = (*this)->getParentOfType<CircuitOp>();
1703  for (auto layer : getLayers()) {
1704  if (!symbolTable.lookupSymbolIn(circuitOp, cast<SymbolRefAttr>(layer)))
1705  return emitOpError() << "enables unknown layer '" << layer << "'";
1706  }
1707 
1708  return success();
1709 }
1710 
1711 LogicalResult
1712 FExtModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1713  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1714 }
1715 
1716 LogicalResult
1717 FIntModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1718  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1719 }
1720 
1721 LogicalResult
1722 FMemModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1723  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1724 }
1725 
1726 void FModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1727  mlir::OpAsmSetValueNameFn setNameFn) {
1728  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1729 }
1730 
1731 void FExtModuleOp::getAsmBlockArgumentNames(
1732  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1733  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1734 }
1735 
1736 void FIntModuleOp::getAsmBlockArgumentNames(
1737  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1738  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1739 }
1740 
1741 void FMemModuleOp::getAsmBlockArgumentNames(
1742  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1743  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1744 }
1745 
1746 ArrayAttr FMemModuleOp::getParameters() { return {}; }
1747 
1748 ArrayAttr FModuleOp::getParameters() { return {}; }
1749 
1750 Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1751 
1752 ConventionAttr FIntModuleOp::getConventionAttr() {
1753  return ConventionAttr::get(getContext(), getConvention());
1754 }
1755 
1756 Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1757 
1758 ConventionAttr FMemModuleOp::getConventionAttr() {
1759  return ConventionAttr::get(getContext(), getConvention());
1760 }
1761 
1762 //===----------------------------------------------------------------------===//
1763 // ClassLike Helpers
1764 //===----------------------------------------------------------------------===//
1765 
    ClassLike classOp, ClassType type,
    function_ref<InFlightDiagnostic()> emitError) {
  // Defensive check: the type is expected to name the class-like op it is
  // verified against (callers such as verifyPortSymbolUses resolve the op
  // from that very name), but compare the names anyway.
  auto name = type.getNameAttr().getAttr();
  auto expectedName = classOp.getModuleNameAttr();
  if (name != expectedName)
    return emitError() << "type has wrong name, got " << name << ", expected "
                       << expectedName;

  // The type must have exactly one element per port.
  auto elements = type.getElements();
  auto numElements = elements.size();
  auto expectedNumElements = classOp.getNumPorts();
  if (numElements != expectedNumElements)
    return emitError() << "has wrong number of ports, got " << numElements
                       << ", expected " << expectedNumElements;

  auto portNames = classOp.getPortNames();
  auto portDirections = classOp.getPortDirections();
  auto portTypes = classOp.getPortTypes();

  // Compare element i against port i: name, then direction, then type. The
  // first mismatch found is reported.
  for (unsigned i = 0; i < numElements; ++i) {
    auto element = elements[i];

    // NOTE: `name` / `expectedName` here shadow the class-level locals above.
    auto name = element.name;
    auto expectedName = portNames[i];
    if (name != expectedName)
      return emitError() << "port #" << i << " has wrong name, got " << name
                         << ", expected " << expectedName;

    auto direction = element.direction;
    auto expectedDirection = Direction(portDirections[i]);
    if (direction != expectedDirection)
      return emitError() << "port " << name << " has wrong direction, got "
                         << direction::toString(direction) << ", expected "
                         << direction::toString(expectedDirection);

    auto type = element.type;
    auto expectedType = cast<TypeAttr>(portTypes[i]).getValue();
    if (type != expectedType)
      return emitError() << "port " << name << " has wrong type, got " << type
                         << ", expected " << expectedType;
  }

  return success();
}
1812 
1813 ClassType firrtl::detail::getInstanceTypeForClassLike(ClassLike classOp) {
1814  auto n = classOp.getNumPorts();
1815  SmallVector<ClassElement> elements;
1816  elements.reserve(n);
1817  for (size_t i = 0; i < n; ++i)
1818  elements.push_back({classOp.getPortNameAttr(i), classOp.getPortType(i),
1819  classOp.getPortDirection(i)});
1820  auto name = FlatSymbolRefAttr::get(classOp.getNameAttr());
1821  return ClassType::get(name, elements);
1822 }
1823 
1824 static ParseResult parseClassLike(OpAsmParser &parser, OperationState &result,
1825  bool hasSSAIdentifiers) {
1826  auto *context = result.getContext();
1827  auto &builder = parser.getBuilder();
1828 
1829  // Parse the visibility attribute.
1830  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1831 
1832  // Parse the name as a symbol.
1833  StringAttr nameAttr;
1834  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1835  result.attributes))
1836  return failure();
1837 
1838  // Parse the module ports.
1839  SmallVector<OpAsmParser::Argument> entryArgs;
1840  SmallVector<Direction, 4> portDirections;
1841  SmallVector<Attribute, 4> portNames;
1842  SmallVector<Attribute, 4> portTypes;
1843  SmallVector<Attribute, 4> portAnnotations;
1844  SmallVector<Attribute, 4> portSyms;
1845  SmallVector<Attribute, 4> portLocs;
1846  if (parseModulePorts(parser, hasSSAIdentifiers,
1847  /*supportsSymbols=*/false, entryArgs, portDirections,
1848  portNames, portTypes, portAnnotations, portSyms,
1849  portLocs))
1850  return failure();
1851 
1852  // Ports on ClassLike ops cannot have annotations
1853  for (auto annos : portAnnotations)
1854  if (!cast<ArrayAttr>(annos).empty())
1855  return failure();
1856 
1857  // If attributes are present, parse them.
1858  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1859  return failure();
1860 
1861  assert(portNames.size() == portTypes.size());
1862 
1863  // Record the argument and result types as an attribute. This is necessary
1864  // for external modules.
1865 
1866  // Add port directions.
1867  if (!result.attributes.get("portDirections"))
1868  result.addAttribute("portDirections",
1869  direction::packAttribute(context, portDirections));
1870 
1871  // Add port names.
1872  if (!result.attributes.get("portNames"))
1873  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1874 
1875  // Add the port types.
1876  if (!result.attributes.get("portTypes"))
1877  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
1878 
1879  // Add the port symbols.
1880  if (!result.attributes.get("portSyms")) {
1881  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1882  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1883  }
1884 
1885  // Add port locations.
1886  if (!result.attributes.get("portLocations"))
1887  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1888 
1889  // Notably missing compared to other FModuleLike, we do not track port
1890  // annotations, nor port symbols, on classes.
1891 
1892  // Add the region (unused by extclass).
1893  auto *bodyRegion = result.addRegion();
1894 
1895  if (hasSSAIdentifiers) {
1896  if (parser.parseRegion(*bodyRegion, entryArgs))
1897  return failure();
1898  if (bodyRegion->empty())
1899  bodyRegion->push_back(new Block());
1900  }
1901 
1902  return success();
1903 }
1904 
1905 static void printClassLike(OpAsmPrinter &p, ClassLike op) {
1906  p << ' ';
1907 
1908  // Print the visibility of the class.
1909  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1910  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1911  p << visibility.getValue() << ' ';
1912 
1913  // Print the class name.
1914  p.printSymbolName(op.getName());
1915 
1916  // Both classes and external classes have a body, but it is always empty for
1917  // external classes.
1918  Region &region = op->getRegion(0);
1919  Block *body = nullptr;
1920  if (!region.empty())
1921  body = &region.front();
1922 
1923  auto needPortNamesAttr = printModulePorts(
1924  p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1925  {}, op.getPortSymbols(), op.getPortLocations());
1926 
1927  // Print the attr-dict.
1928  SmallVector<StringRef, 8> omittedAttrs = {
1929  "sym_name", "portNames", "portTypes", "portDirections",
1930  "portSyms", "portLocations", visibilityAttrName};
1931 
1932  // We can omit the portNames if they were able to be printed as properly as
1933  // block arguments.
1934  if (!needPortNamesAttr)
1935  omittedAttrs.push_back("portNames");
1936 
1937  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1938 
1939  // print the body if it exists.
1940  if (!region.empty()) {
1941  p << " ";
1942  auto printEntryBlockArgs = false;
1943  auto printBlockTerminators = false;
1944  p.printRegion(region, printEntryBlockArgs, printBlockTerminators);
1945  }
1946 }
1947 
1948 //===----------------------------------------------------------------------===//
1949 // ClassOp
1950 //===----------------------------------------------------------------------===//
1951 
1952 void ClassOp::build(OpBuilder &builder, OperationState &result, StringAttr name,
1953  ArrayRef<PortInfo> ports) {
1954  assert(
1955  llvm::all_of(ports,
1956  [](const auto &port) { return port.annotations.empty(); }) &&
1957  "class ports may not have annotations");
1958 
1959  buildClass(builder, result, name, ports);
1960 
1961  // Create a region and a block for the body.
1962  auto *bodyRegion = result.regions[0].get();
1963  Block *body = new Block();
1964  bodyRegion->push_back(body);
1965 
1966  // Add arguments to the body block.
1967  for (auto &elt : ports)
1968  body->addArgument(elt.type, elt.loc);
1969 }
1970 
1971 void ClassOp::print(OpAsmPrinter &p) {
1972  printClassLike(p, cast<ClassLike>(getOperation()));
1973 }
1974 
1975 ParseResult ClassOp::parse(OpAsmParser &parser, OperationState &result) {
1976  auto hasSSAIdentifiers = true;
1977  return parseClassLike(parser, result, hasSSAIdentifiers);
1978 }
1979 
1980 LogicalResult ClassOp::verify() {
1981  for (auto operand : getBodyBlock()->getArguments()) {
1982  auto type = operand.getType();
1983  if (!isa<PropertyType>(type)) {
1984  emitOpError("ports on a class must be properties");
1985  return failure();
1986  }
1987  }
1988 
1989  return success();
1990 }
1991 
1992 LogicalResult
1993 ClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
1994  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1995 }
1996 
1997 void ClassOp::getAsmBlockArgumentNames(mlir::Region &region,
1998  mlir::OpAsmSetValueNameFn setNameFn) {
1999  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2000 }
2001 
2002 SmallVector<PortInfo> ClassOp::getPorts() {
2003  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2004 }
2005 
2006 void ClassOp::erasePorts(const llvm::BitVector &portIndices) {
2007  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2008  getBodyBlock()->eraseArguments(portIndices);
2009 }
2010 
2011 void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2012  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2013 }
2014 
2015 Convention ClassOp::getConvention() { return Convention::Internal; }
2016 
2017 ConventionAttr ClassOp::getConventionAttr() {
2018  return ConventionAttr::get(getContext(), getConvention());
2019 }
2020 
2021 ArrayAttr ClassOp::getParameters() { return {}; }
2022 
2023 ArrayAttr ClassOp::getPortAnnotationsAttr() {
2024  return ArrayAttr::get(getContext(), {});
2025 }
2026 
2027 ArrayAttr ClassOp::getLayersAttr() { return ArrayAttr::get(getContext(), {}); }
2028 
2029 ArrayRef<Attribute> ClassOp::getLayers() { return getLayersAttr(); }
2030 
2031 SmallVector<::circt::hw::PortInfo> ClassOp::getPortList() {
2032  return ::getPortListImpl(*this);
2033 }
2034 
  // Look up the single port `idx` through the shared getPortImpl helper.
  return ::getPortImpl(*this, idx);
}
2038 
2039 BlockArgument ClassOp::getArgument(size_t portNumber) {
2040  return getBodyBlock()->getArgument(portNumber);
2041 }
2042 
2043 bool ClassOp::canDiscardOnUseEmpty() {
2044  // ClassOps are referenced by ClassTypes, and these uses are not
2045  // discoverable by the symbol infrastructure. Return false here to prevent
2046  // passes like symbolDCE from removing our classes.
2047  return false;
2048 }
2049 
2050 //===----------------------------------------------------------------------===//
2051 // ExtClassOp
2052 //===----------------------------------------------------------------------===//
2053 
2054 void ExtClassOp::build(OpBuilder &builder, OperationState &result,
2055  StringAttr name, ArrayRef<PortInfo> ports) {
2056  assert(
2057  llvm::all_of(ports,
2058  [](const auto &port) { return port.annotations.empty(); }) &&
2059  "class ports may not have annotations");
2060 
2061  buildClass(builder, result, name, ports);
2062 }
2063 
2064 void ExtClassOp::print(OpAsmPrinter &p) {
2065  printClassLike(p, cast<ClassLike>(getOperation()));
2066 }
2067 
2068 ParseResult ExtClassOp::parse(OpAsmParser &parser, OperationState &result) {
2069  auto hasSSAIdentifiers = false;
2070  return parseClassLike(parser, result, hasSSAIdentifiers);
2071 }
2072 
2073 LogicalResult
2074 ExtClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2075  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2076 }
2077 
2078 void ExtClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2079  mlir::OpAsmSetValueNameFn setNameFn) {
2080  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2081 }
2082 
2083 SmallVector<PortInfo> ExtClassOp::getPorts() {
2084  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2085 }
2086 
2087 void ExtClassOp::erasePorts(const llvm::BitVector &portIndices) {
2088  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2089 }
2090 
2091 void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2092  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2093 }
2094 
2095 Convention ExtClassOp::getConvention() { return Convention::Internal; }
2096 
2097 ConventionAttr ExtClassOp::getConventionAttr() {
2098  return ConventionAttr::get(getContext(), getConvention());
2099 }
2100 
2101 ArrayAttr ExtClassOp::getLayersAttr() {
2102  return ArrayAttr::get(getContext(), {});
2103 }
2104 
2105 ArrayRef<Attribute> ExtClassOp::getLayers() { return getLayersAttr(); }
2106 
2107 ArrayAttr ExtClassOp::getParameters() { return {}; }
2108 
2109 ArrayAttr ExtClassOp::getPortAnnotationsAttr() {
2110  return ArrayAttr::get(getContext(), {});
2111 }
2112 
2113 SmallVector<::circt::hw::PortInfo> ExtClassOp::getPortList() {
2114  return ::getPortListImpl(*this);
2115 }
2116 
  // Look up the single port `idx` through the shared getPortImpl helper.
  return ::getPortImpl(*this, idx);
}
2120 
2121 bool ExtClassOp::canDiscardOnUseEmpty() {
2122  // ClassOps are referenced by ClassTypes, and these uses are not
2123  // discovereable by the symbol infrastructure. Return false here to prevent
2124  // passes like symbolDCE from removing our classes.
2125  return false;
2126 }
2127 
2128 //===----------------------------------------------------------------------===//
2129 // InstanceOp
2130 //===----------------------------------------------------------------------===//
2131 
2132 void InstanceOp::build(
2133  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2134  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2135  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2136  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2137  ArrayRef<Attribute> layers, bool lowerToBind, StringAttr innerSym) {
2138  build(builder, result, resultTypes, moduleName, name, nameKind,
2139  portDirections, portNames, annotations, portAnnotations, layers,
2140  lowerToBind,
2141  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2142 }
2143 
2144 void InstanceOp::build(
2145  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2146  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2147  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2148  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2149  ArrayRef<Attribute> layers, bool lowerToBind, hw::InnerSymAttr innerSym) {
2150  result.addTypes(resultTypes);
2151  result.addAttribute("moduleName",
2152  SymbolRefAttr::get(builder.getContext(), moduleName));
2153  result.addAttribute("name", builder.getStringAttr(name));
2154  result.addAttribute(
2155  "portDirections",
2156  direction::packAttribute(builder.getContext(), portDirections));
2157  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2158  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2159  result.addAttribute("layers", builder.getArrayAttr(layers));
2160  if (lowerToBind)
2161  result.addAttribute("lowerToBind", builder.getUnitAttr());
2162  if (innerSym)
2163  result.addAttribute("inner_sym", innerSym);
2164  result.addAttribute("nameKind",
2165  NameKindEnumAttr::get(builder.getContext(), nameKind));
2166 
2167  if (portAnnotations.empty()) {
2168  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2169  builder.getArrayAttr({}));
2170  result.addAttribute("portAnnotations",
2171  builder.getArrayAttr(portAnnotationsVec));
2172  } else {
2173  assert(portAnnotations.size() == resultTypes.size());
2174  result.addAttribute("portAnnotations",
2175  builder.getArrayAttr(portAnnotations));
2176  }
2177 }
2178 
2179 void InstanceOp::build(OpBuilder &builder, OperationState &result,
2180  FModuleLike module, StringRef name,
2181  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2182  ArrayRef<Attribute> portAnnotations, bool lowerToBind,
2183  hw::InnerSymAttr innerSym) {
2184 
2185  // Gather the result types.
2186  SmallVector<Type> resultTypes;
2187  resultTypes.reserve(module.getNumPorts());
2188  llvm::transform(
2189  module.getPortTypes(), std::back_inserter(resultTypes),
2190  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2191 
2192  // Create the port annotations.
2193  ArrayAttr portAnnotationsAttr;
2194  if (portAnnotations.empty()) {
2195  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2196  resultTypes.size(), builder.getArrayAttr({})));
2197  } else {
2198  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2199  }
2200 
2201  return build(
2202  builder, result, resultTypes,
2203  SymbolRefAttr::get(builder.getContext(), module.getModuleNameAttr()),
2204  builder.getStringAttr(name),
2205  NameKindEnumAttr::get(builder.getContext(), nameKind),
2206  module.getPortDirectionsAttr(), module.getPortNamesAttr(),
2207  builder.getArrayAttr(annotations), portAnnotationsAttr,
2208  module.getLayersAttr(), lowerToBind ? builder.getUnitAttr() : UnitAttr(),
2209  innerSym);
2210 }
2211 
2212 void InstanceOp::build(OpBuilder &builder, OperationState &odsState,
2213  ArrayRef<PortInfo> ports, StringRef moduleName,
2214  StringRef name, NameKindEnum nameKind,
2215  ArrayRef<Attribute> annotations,
2216  ArrayRef<Attribute> layers, bool lowerToBind,
2217  hw::InnerSymAttr innerSym) {
2218  // Gather the result types.
2219  SmallVector<Type> newResultTypes;
2220  SmallVector<Direction> newPortDirections;
2221  SmallVector<Attribute> newPortNames;
2222  SmallVector<Attribute> newPortAnnotations;
2223  for (auto &p : ports) {
2224  newResultTypes.push_back(p.type);
2225  newPortDirections.push_back(p.direction);
2226  newPortNames.push_back(p.name);
2227  newPortAnnotations.push_back(p.annotations.getArrayAttr());
2228  }
2229 
2230  return build(builder, odsState, newResultTypes, moduleName, name, nameKind,
2231  newPortDirections, newPortNames, annotations, newPortAnnotations,
2232  layers, lowerToBind, innerSym);
2233 }
2234 
2235 LogicalResult InstanceOp::verify() {
2236  // The instance may only be instantiated under its required layers.
2237  auto ambientLayers = getAmbientLayersAt(getOperation());
2238  SmallVector<SymbolRefAttr> missingLayers;
2239  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2240  if (!isLayerCompatibleWith(layer, ambientLayers))
2241  missingLayers.push_back(layer);
2242 
2243  if (missingLayers.empty())
2244  return success();
2245 
2246  auto diag =
2247  emitOpError("ambient layers are insufficient to instantiate module");
2248  auto &note = diag.attachNote();
2249  note << "missing layer requirements: ";
2250  interleaveComma(missingLayers, note);
2251  return failure();
2252 }
2253 
2254 /// Builds a new `InstanceOp` with the ports listed in `portIndices` erased, and
2255 /// updates any users of the remaining ports to point at the new instance.
2256 InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
2257  const llvm::BitVector &portIndices) {
2258  assert(portIndices.size() >= getNumResults() &&
2259  "portIndices is not at least as large as getNumResults()");
2260 
2261  if (portIndices.none())
2262  return *this;
2263 
2264  SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2265  SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2266  SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2267  direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2268  SmallVector<Attribute> newPortNames =
2269  removeElementsAtIndices(getPortNames().getValue(), portIndices);
2270  SmallVector<Attribute> newPortAnnotations =
2271  removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2272 
2273  auto newOp = builder.create<InstanceOp>(
2274  getLoc(), newResultTypes, getModuleName(), getName(), getNameKind(),
2275  newPortDirections, newPortNames, getAnnotations().getValue(),
2276  newPortAnnotations, getLayers(), getLowerToBind(), getInnerSymAttr());
2277 
2278  for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2279  oldIdx != numOldPorts; ++oldIdx) {
2280  if (portIndices.test(oldIdx)) {
2281  assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2282  continue;
2283  }
2284  getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2285  ++newIdx;
2286  }
2287 
2288  // Compy over "output_file" information so that this is not lost when ports
2289  // are erased.
2290  //
2291  // TODO: Other attributes may need to be copied over.
2292  if (auto outputFile = (*this)->getAttr("output_file"))
2293  newOp->setAttr("output_file", outputFile);
2294 
2295  return newOp;
2296 }
2297 
2298 ArrayAttr InstanceOp::getPortAnnotation(unsigned portIdx) {
2299  assert(portIdx < getNumResults() &&
2300  "index should be smaller than result number");
2301  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2302 }
2303 
2304 void InstanceOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2305  assert(annotations.size() == getNumResults() &&
2306  "number of annotations is not equal to result number");
2307  (*this)->setAttr("portAnnotations",
2308  ArrayAttr::get(getContext(), annotations));
2309 }
2310 
2311 InstanceOp
2312 InstanceOp::cloneAndInsertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2313  auto portSize = ports.size();
2314  auto newPortCount = getNumResults() + portSize;
2315  SmallVector<Direction> newPortDirections;
2316  newPortDirections.reserve(newPortCount);
2317  SmallVector<Attribute> newPortNames;
2318  newPortNames.reserve(newPortCount);
2319  SmallVector<Type> newPortTypes;
2320  newPortTypes.reserve(newPortCount);
2321  SmallVector<Attribute> newPortAnnos;
2322  newPortAnnos.reserve(newPortCount);
2323 
2324  unsigned oldIndex = 0;
2325  unsigned newIndex = 0;
2326  while (oldIndex + newIndex < newPortCount) {
2327  // Check if we should insert a port here.
2328  if (newIndex < portSize && ports[newIndex].first == oldIndex) {
2329  auto &newPort = ports[newIndex].second;
2330  newPortDirections.push_back(newPort.direction);
2331  newPortNames.push_back(newPort.name);
2332  newPortTypes.push_back(newPort.type);
2333  newPortAnnos.push_back(newPort.annotations.getArrayAttr());
2334  ++newIndex;
2335  } else {
2336  // Copy the next old port.
2337  newPortDirections.push_back(getPortDirection(oldIndex));
2338  newPortNames.push_back(getPortName(oldIndex));
2339  newPortTypes.push_back(getType(oldIndex));
2340  newPortAnnos.push_back(getPortAnnotation(oldIndex));
2341  ++oldIndex;
2342  }
2343  }
2344 
2345  // Create a new instance op with the reset inserted.
2346  return OpBuilder(*this).create<InstanceOp>(
2347  getLoc(), newPortTypes, getModuleName(), getName(), getNameKind(),
2348  newPortDirections, newPortNames, getAnnotations().getValue(),
2349  newPortAnnos, getLayers(), getLowerToBind(), getInnerSymAttr());
2350 }
2351 
2352 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2353  return instance_like_impl::verifyReferencedModule(*this, symbolTable,
2354  getModuleNameAttr());
2355 }
2356 
2357 StringRef InstanceOp::getInstanceName() { return getName(); }
2358 
2359 StringAttr InstanceOp::getInstanceNameAttr() { return getNameAttr(); }
2360 
2361 void InstanceOp::print(OpAsmPrinter &p) {
2362  // Print the instance name.
2363  p << " ";
2364  p.printKeywordOrString(getName());
2365  if (auto attr = getInnerSymAttr()) {
2366  p << " sym ";
2367  p.printSymbolName(attr.getSymName());
2368  }
2369  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2370  p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2371 
2372  // Print the attr-dict.
2373  SmallVector<StringRef, 10> omittedAttrs = {
2374  "moduleName", "name", "portDirections",
2375  "portNames", "portTypes", "portAnnotations",
2376  "inner_sym", "nameKind"};
2377  if (getAnnotations().empty())
2378  omittedAttrs.push_back("annotations");
2379  if (getLayers().empty())
2380  omittedAttrs.push_back("layers");
2381  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2382 
2383  // Print the module name.
2384  p << " ";
2385  p.printSymbolName(getModuleName());
2386 
2387  // Collect all the result types as TypeAttrs for printing.
2388  SmallVector<Attribute> portTypes;
2389  portTypes.reserve(getNumResults());
2390  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2391  &TypeAttr::get);
2392  printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2393  getPortNames().getValue(), portTypes,
2394  getPortAnnotations().getValue(), {}, {});
2395 }
2396 
/// Parse an instance. The expected syntax is, in order:
///   - the instance name (keyword or string),
///   - an optional `sym <inner-symbol>`,
///   - the name kind (via parseNameKind; the printer omits it when it is
///     DroppableName),
///   - an optional attribute dictionary,
///   - the referenced module symbol, and
///   - the port list, where ports are written by name rather than as SSA values.
/// Port and name attributes written in the attribute dictionary take
/// precedence over the ones derived from the rest of the syntax.
ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
  auto *context = parser.getContext();
  auto &resultAttrs = result.attributes;

  std::string name;
  hw::InnerSymAttr innerSymAttr;
  FlatSymbolRefAttr moduleName;
  // Outputs of parseModulePorts. Instances have no SSA port names and do not
  // store port symbols or locations, so entryArgs, portSyms, and portLocs are
  // filled in but not used below.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Direction, 4> portDirections;
  SmallVector<Attribute, 4> portNames;
  SmallVector<Attribute, 4> portTypes;
  SmallVector<Attribute, 4> portAnnotations;
  SmallVector<Attribute, 4> portSyms;
  SmallVector<Attribute, 4> portLocs;
  NameKindEnumAttr nameKind;

  if (parser.parseKeywordOrString(&name))
    return failure();
  // The inner symbol is recorded directly in result.attributes under the
  // inner-symbol attribute name.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    if (parser.parseCustomAttributeWithFallback(
            innerSymAttr, ::mlir::Type{},
            hw::InnerSymbolTable::getInnerSymbolAttrName(),
            result.attributes)) {
      return ::mlir::failure();
    }
  }
  if (parseNameKind(parser, nameKind) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(moduleName, "moduleName", resultAttrs) ||
      parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
                       /*supportsSymbols=*/false, entryArgs, portDirections,
                       portNames, portTypes, portAnnotations, portSyms,
                       portLocs))
    return failure();

  // Add the attributes. We let attributes defined in the attr-dict override
  // attributes parsed out of the module signature.
  // NOTE(review): parseAttribute(..., "moduleName", resultAttrs) above already
  // appends `moduleName`, so this guarded add appears never to fire. Confirm
  // before relying on it.
  if (!resultAttrs.get("moduleName"))
    result.addAttribute("moduleName", moduleName);
  if (!resultAttrs.get("name"))
    result.addAttribute("name", StringAttr::get(context, name));
  // NOTE(review): unlike its neighbours this is unconditional, so a `nameKind`
  // written in the attr-dict would be added a second time. The printer always
  // elides it, so only hand-written IR is affected.
  result.addAttribute("nameKind", nameKind);
  if (!resultAttrs.get("portDirections"))
    result.addAttribute("portDirections",
                        direction::packAttribute(context, portDirections));
  if (!resultAttrs.get("portNames"))
    result.addAttribute("portNames", ArrayAttr::get(context, portNames));
  if (!resultAttrs.get("portAnnotations"))
    result.addAttribute("portAnnotations",
                        ArrayAttr::get(context, portAnnotations));

  // Annotations, layers, and LowerToBind are omitted in the printed format
  // if they are empty, empty, and false (respectively).
  if (!resultAttrs.get("annotations"))
    resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
  if (!resultAttrs.get("layers"))
    resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));

  // Add result types: one result per port, typed like the port.
  result.types.reserve(portTypes.size());
  llvm::transform(
      portTypes, std::back_inserter(result.types),
      [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });

  return success();
}
2463 
  // Suggest "<instance>_<port>" as the asm name of each result. An unnamed
  // instance uses "inst" as the prefix.
  StringRef base = getName();
  if (base.empty())
    base = "inst";

  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
    setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
  }
}
2473 
2474 std::optional<size_t> InstanceOp::getTargetResultIndex() {
2475  // Inner symbols on instance operations target the op not any result.
2476  return std::nullopt;
2477 }
2478 
2479 // -----------------------------------------------------------------------------
2480 // InstanceChoiceOp
2481 // -----------------------------------------------------------------------------
2482 
2483 void InstanceChoiceOp::build(
2484  OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
2485  ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
2486  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2487  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2488  // Gather the result types.
2489  SmallVector<Type> resultTypes;
2490  for (Attribute portType : defaultModule.getPortTypes())
2491  resultTypes.push_back(cast<TypeAttr>(portType).getValue());
2492 
2493  // Create the port annotations.
2494  ArrayAttr portAnnotationsAttr;
2495  if (portAnnotations.empty()) {
2496  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2497  resultTypes.size(), builder.getArrayAttr({})));
2498  } else {
2499  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2500  }
2501 
2502  // Gather the module & case names.
2503  SmallVector<Attribute> moduleNames, caseNames;
2504  moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
2505  for (auto [caseOption, caseModule] : cases) {
2506  auto caseGroup = caseOption->getParentOfType<OptionOp>();
2507  caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
2508  {SymbolRefAttr::get(caseOption)}));
2509  moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
2510  }
2511 
2512  return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
2513  builder.getArrayAttr(caseNames), builder.getStringAttr(name),
2514  NameKindEnumAttr::get(builder.getContext(), nameKind),
2515  defaultModule.getPortDirectionsAttr(),
2516  defaultModule.getPortNamesAttr(),
2517  builder.getArrayAttr(annotations), portAnnotationsAttr,
2518  defaultModule.getLayersAttr(),
2519  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2520 }
2521 
2522 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
2523  return std::nullopt;
2524 }
2525 
2526 void InstanceChoiceOp::print(OpAsmPrinter &p) {
2527  // Print the instance name.
2528  p << " ";
2529  p.printKeywordOrString(getName());
2530  if (auto attr = getInnerSymAttr()) {
2531  p << " sym ";
2532  p.printSymbolName(attr.getSymName());
2533  }
2534  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2535  p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2536 
2537  // Print the attr-dict.
2538  SmallVector<StringRef, 10> omittedAttrs = {
2539  "moduleNames", "caseNames", "name",
2540  "portDirections", "portNames", "portTypes",
2541  "portAnnotations", "inner_sym", "nameKind"};
2542  if (getAnnotations().empty())
2543  omittedAttrs.push_back("annotations");
2544  if (getLayers().empty())
2545  omittedAttrs.push_back("layers");
2546  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2547 
2548  // Print the module name.
2549  p << ' ';
2550 
2551  auto moduleNames = getModuleNamesAttr();
2552  auto caseNames = getCaseNamesAttr();
2553 
2554  p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[0]).getValue());
2555 
2556  p << " alternatives ";
2557  p.printSymbolName(
2558  cast<SymbolRefAttr>(caseNames[0]).getRootReference().getValue());
2559  p << " { ";
2560  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2561  if (i != 0)
2562  p << ", ";
2563 
2564  auto symbol = cast<SymbolRefAttr>(caseNames[i]);
2565  p.printSymbolName(symbol.getNestedReferences()[0].getValue());
2566  p << " -> ";
2567  p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[i + 1]).getValue());
2568  }
2569 
2570  p << " } ";
2571 
2572  // Collect all the result types as TypeAttrs for printing.
2573  SmallVector<Attribute> portTypes;
2574  portTypes.reserve(getNumResults());
2575  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2576  &TypeAttr::get);
2577  printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2578  getPortNames().getValue(), portTypes,
2579  getPortAnnotations().getValue(), {}, {});
2580 }
2581 
2582 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
2583  OperationState &result) {
2584  auto *context = parser.getContext();
2585  auto &resultAttrs = result.attributes;
2586 
2587  std::string name;
2588  hw::InnerSymAttr innerSymAttr;
2589  SmallVector<Attribute> moduleNames;
2590  SmallVector<Attribute> caseNames;
2591  SmallVector<OpAsmParser::Argument> entryArgs;
2592  SmallVector<Direction, 4> portDirections;
2593  SmallVector<Attribute, 4> portNames;
2594  SmallVector<Attribute, 4> portTypes;
2595  SmallVector<Attribute, 4> portAnnotations;
2596  SmallVector<Attribute, 4> portSyms;
2597  SmallVector<Attribute, 4> portLocs;
2598  NameKindEnumAttr nameKind;
2599 
2600  if (parser.parseKeywordOrString(&name))
2601  return failure();
2602  if (succeeded(parser.parseOptionalKeyword("sym"))) {
2603  if (parser.parseCustomAttributeWithFallback(
2604  innerSymAttr, Type{},
2605  hw::InnerSymbolTable::getInnerSymbolAttrName(),
2606  result.attributes)) {
2607  return failure();
2608  }
2609  }
2610  if (parseNameKind(parser, nameKind) ||
2611  parser.parseOptionalAttrDict(result.attributes))
2612  return failure();
2613 
2614  FlatSymbolRefAttr defaultModuleName;
2615  if (parser.parseAttribute(defaultModuleName))
2616  return failure();
2617  moduleNames.push_back(defaultModuleName);
2618 
2619  // alternatives { @opt::@case -> @target, ... }
2620  {
2621  FlatSymbolRefAttr optionName;
2622  if (parser.parseKeyword("alternatives") ||
2623  parser.parseAttribute(optionName) || parser.parseLBrace())
2624  return failure();
2625 
2626  FlatSymbolRefAttr moduleName;
2627  StringAttr caseName;
2628  while (succeeded(parser.parseOptionalSymbolName(caseName))) {
2629  if (parser.parseArrow() || parser.parseAttribute(moduleName))
2630  return failure();
2631  moduleNames.push_back(moduleName);
2632  caseNames.push_back(SymbolRefAttr::get(
2633  optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
2634  if (failed(parser.parseOptionalComma()))
2635  break;
2636  }
2637  if (parser.parseRBrace())
2638  return failure();
2639  }
2640 
2641  if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2642  /*supportsSymbols=*/false, entryArgs, portDirections,
2643  portNames, portTypes, portAnnotations, portSyms,
2644  portLocs))
2645  return failure();
2646 
2647  // Add the attributes. We let attributes defined in the attr-dict override
2648  // attributes parsed out of the module signature.
2649  if (!resultAttrs.get("moduleNames"))
2650  result.addAttribute("moduleNames", ArrayAttr::get(context, moduleNames));
2651  if (!resultAttrs.get("caseNames"))
2652  result.addAttribute("caseNames", ArrayAttr::get(context, caseNames));
2653  if (!resultAttrs.get("name"))
2654  result.addAttribute("name", StringAttr::get(context, name));
2655  result.addAttribute("nameKind", nameKind);
2656  if (!resultAttrs.get("portDirections"))
2657  result.addAttribute("portDirections",
2658  direction::packAttribute(context, portDirections));
2659  if (!resultAttrs.get("portNames"))
2660  result.addAttribute("portNames", ArrayAttr::get(context, portNames));
2661  if (!resultAttrs.get("portAnnotations"))
2662  result.addAttribute("portAnnotations",
2663  ArrayAttr::get(context, portAnnotations));
2664 
2665  // Annotations, layers, and LowerToBind are omitted in the printed format if
2666  // they are empty, empty, and false (respectively).
2667  if (!resultAttrs.get("annotations"))
2668  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
2669  if (!resultAttrs.get("layers"))
2670  resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));
2671 
2672  // Add result types.
2673  result.types.reserve(portTypes.size());
2674  llvm::transform(
2675  portTypes, std::back_inserter(result.types),
2676  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2677 
2678  return success();
2679 }
2680 
// Suggest "<instance>_<port>" as the SSA name of each result, using "inst" as
// the prefix when the instance has no name. Results and port names are
// walked in lockstep.
2682  StringRef base = getName().empty() ? "inst" : getName();
2683  for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
2684  setNameFn(result, (base + "_" + cast<StringAttr>(name).getValue()).str());
2685 }
2686 
2687 LogicalResult InstanceChoiceOp::verify() {
2688  if (getCaseNamesAttr().empty())
2689  return emitOpError() << "must have at least one case";
2690  if (getModuleNamesAttr().size() != getCaseNamesAttr().size() + 1)
2691  return emitOpError() << "number of referenced modules does not match the "
2692  "number of options";
2693 
2694  // The modules may only be instantiated under their required layers (which
2695  // are the same for all modules).
2696  auto ambientLayers = getAmbientLayersAt(getOperation());
2697  SmallVector<SymbolRefAttr> missingLayers;
2698  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2699  if (!isLayerCompatibleWith(layer, ambientLayers))
2700  missingLayers.push_back(layer);
2701 
2702  if (missingLayers.empty())
2703  return success();
2704 
2705  auto diag =
2706  emitOpError("ambient layers are insufficient to instantiate module");
2707  auto &note = diag.attachNote();
2708  note << "missing layer requirements: ";
2709  interleaveComma(missingLayers, note);
2710  return failure();
2711 }
2712 
/// Resolve every symbol the instance choice refers to: each referenced module,
/// the option group, and each option case. All cases must share the option
/// group (root reference) of the first case.
2713 LogicalResult
2714 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2715  auto caseNames = getCaseNamesAttr();
// Verify each referenced module symbol, stopping at the first failure.
2716  for (auto moduleName : getModuleNamesAttr()) {
2718  *this, symbolTable, cast<FlatSymbolRefAttr>(moduleName))))
2719  return failure();
2720  }
2721 
// NOTE(review): indexes caseNames[0] unconditionally; presumably relies on
// verify() having rejected ops without cases -- confirm ordering.
2722  auto root = cast<SymbolRefAttr>(caseNames[0]).getRootReference();
2723  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2724  auto ref = cast<SymbolRefAttr>(caseNames[i]);
2725  auto refRoot = ref.getRootReference();
// Every case must live in the same option group as the first one.
2726  if (ref.getRootReference() != root)
2727  return emitOpError() << "case " << ref
2728  << " is not in the same option group as "
2729  << caseNames[0];
2730 
// The option group itself must exist...
2731  if (!symbolTable.lookupNearestSymbolFrom<OptionOp>(*this, refRoot))
2732  return emitOpError() << "option " << refRoot << " does not exist";
2733 
// ...and contain the nested case symbol.
2734  if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
2735  return emitOpError() << "option " << refRoot
2736  << " does not contain option case " << ref;
2737  }
2738 
2739  return success();
2740 }
2741 
2742 FlatSymbolRefAttr
2743 InstanceChoiceOp::getTargetOrDefaultAttr(OptionCaseOp option) {
2744  auto caseNames = getCaseNamesAttr();
2745  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2746  StringAttr caseSym = cast<SymbolRefAttr>(caseNames[i]).getLeafReference();
2747  if (caseSym == option.getSymName())
2748  return cast<FlatSymbolRefAttr>(getModuleNamesAttr()[i + 1]);
2749  }
2750  return getDefaultTargetAttr();
2751 }
2752 
2753 SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1>
2754 InstanceChoiceOp::getTargetChoices() {
2755  auto caseNames = getCaseNamesAttr();
2756  auto moduleNames = getModuleNamesAttr();
2757  SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1> choices;
2758  for (size_t i = 0; i < caseNames.size(); ++i) {
2759  choices.emplace_back(cast<SymbolRefAttr>(caseNames[i]),
2760  cast<FlatSymbolRefAttr>(moduleNames[i + 1]));
2761  }
2762 
2763  return choices;
2764 }
2765 
2766 InstanceChoiceOp
2767 InstanceChoiceOp::erasePorts(OpBuilder &builder,
2768  const llvm::BitVector &portIndices) {
2769  assert(portIndices.size() >= getNumResults() &&
2770  "portIndices is not at least as large as getNumResults()");
2771 
2772  if (portIndices.none())
2773  return *this;
2774 
2775  SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2776  SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2777  SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2778  direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2779  SmallVector<Attribute> newPortNames =
2780  removeElementsAtIndices(getPortNames().getValue(), portIndices);
2781  SmallVector<Attribute> newPortAnnotations =
2782  removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2783 
2784  auto newOp = builder.create<InstanceChoiceOp>(
2785  getLoc(), newResultTypes, getModuleNames(), getCaseNames(), getName(),
2786  getNameKind(), direction::packAttribute(getContext(), newPortDirections),
2787  ArrayAttr::get(getContext(), newPortNames), getAnnotationsAttr(),
2788  ArrayAttr::get(getContext(), newPortAnnotations), getLayers(),
2789  getInnerSymAttr());
2790 
2791  for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2792  oldIdx != numOldPorts; ++oldIdx) {
2793  if (portIndices.test(oldIdx)) {
2794  assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2795  continue;
2796  }
2797  getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2798  ++newIdx;
2799  }
2800 
2801  // Copy over "output_file" information so that this is not lost when ports
2802  // are erased.
2803  //
2804  // TODO: Other attributes may need to be copied over.
2805  if (auto outputFile = (*this)->getAttr("output_file"))
2806  newOp->setAttr("output_file", outputFile);
2807 
2808  return newOp;
2809 }
2810 
2811 //===----------------------------------------------------------------------===//
2812 // MemOp
2813 //===----------------------------------------------------------------------===//
2814 
2815 void MemOp::build(OpBuilder &builder, OperationState &result,
2816  TypeRange resultTypes, uint32_t readLatency,
2817  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2818  ArrayRef<Attribute> portNames, StringRef name,
2819  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2820  ArrayRef<Attribute> portAnnotations,
2821  hw::InnerSymAttr innerSym) {
2822  result.addAttribute(
2823  "readLatency",
2824  builder.getIntegerAttr(builder.getIntegerType(32), readLatency));
2825  result.addAttribute(
2826  "writeLatency",
2827  builder.getIntegerAttr(builder.getIntegerType(32), writeLatency));
2828  result.addAttribute(
2829  "depth", builder.getIntegerAttr(builder.getIntegerType(64), depth));
2830  result.addAttribute("ruw", ::RUWAttrAttr::get(builder.getContext(), ruw));
2831  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2832  result.addAttribute("name", builder.getStringAttr(name));
2833  result.addAttribute("nameKind",
2834  NameKindEnumAttr::get(builder.getContext(), nameKind));
2835  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2836  if (innerSym)
2837  result.addAttribute("inner_sym", innerSym);
2838  result.addTypes(resultTypes);
2839 
2840  if (portAnnotations.empty()) {
2841  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2842  builder.getArrayAttr({}));
2843  result.addAttribute("portAnnotations",
2844  builder.getArrayAttr(portAnnotationsVec));
2845  } else {
2846  assert(portAnnotations.size() == resultTypes.size());
2847  result.addAttribute("portAnnotations",
2848  builder.getArrayAttr(portAnnotations));
2849  }
2850 }
2851 
2852 void MemOp::build(OpBuilder &builder, OperationState &result,
2853  TypeRange resultTypes, uint32_t readLatency,
2854  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2855  ArrayRef<Attribute> portNames, StringRef name,
2856  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2857  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2858  build(builder, result, resultTypes, readLatency, writeLatency, depth, ruw,
2859  portNames, name, nameKind, annotations, portAnnotations,
2860  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2861 }
2862 
2863 ArrayAttr MemOp::getPortAnnotation(unsigned portIdx) {
2864  assert(portIdx < getNumResults() &&
2865  "index should be smaller than result number");
2866  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2867 }
2868 
2869 void MemOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2870  assert(annotations.size() == getNumResults() &&
2871  "number of annotations is not equal to result number");
2872  (*this)->setAttr("portAnnotations",
2873  ArrayAttr::get(getContext(), annotations));
2874 }
2875 
2876 // Get the number of read, write and read-write ports.
2877 void MemOp::getNumPorts(size_t &numReadPorts, size_t &numWritePorts,
2878  size_t &numReadWritePorts, size_t &numDbgsPorts) {
2879  numReadPorts = 0;
2880  numWritePorts = 0;
2881  numReadWritePorts = 0;
2882  numDbgsPorts = 0;
2883  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2884  auto portKind = getPortKind(i);
2885  if (portKind == MemOp::PortKind::Debug)
2886  ++numDbgsPorts;
2887  else if (portKind == MemOp::PortKind::Read)
2888  ++numReadPorts;
2889  else if (portKind == MemOp::PortKind::Write) {
2890  ++numWritePorts;
2891  } else
2892  ++numReadWritePorts;
2893  }
2894 }
2895 
2896 /// Verify the correctness of a MemOp.
2897 LogicalResult MemOp::verify() {
2898 
2899  // Store the port names as we find them. This lets us check quickly
2900  // for uniqueneess.
2901  llvm::SmallDenseSet<Attribute, 8> portNamesSet;
2902 
2903  // Store the previous data type. This lets us check that the data
2904  // type is consistent across all ports.
2905  FIRRTLType oldDataType;
2906 
2907  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2908  auto portName = getPortName(i);
2909 
2910  // Get a bundle type representing this port, stripping an outer
2911  // flip if it exists. If this is not a bundle<> or
2912  // flip<bundle<>>, then this is an error.
2913  BundleType portBundleType =
2914  type_dyn_cast<BundleType>(getResult(i).getType());
2915 
2916  // Require that all port names are unique.
2917  if (!portNamesSet.insert(portName).second) {
2918  emitOpError() << "has non-unique port name " << portName;
2919  return failure();
2920  }
2921 
2922  // Determine the kind of the memory. If the kind cannot be
2923  // determined, then it's indicative of the wrong number of fields
2924  // in the type (but we don't know any more just yet).
2925 
2926  auto elt = getPortNamed(portName);
2927  if (!elt) {
2928  emitOpError() << "could not get port with name " << portName;
2929  return failure();
2930  }
2931  auto firrtlType = type_cast<FIRRTLType>(elt.getType());
2932  MemOp::PortKind portKind = getMemPortKindFromType(firrtlType);
2933 
2934  if (portKind == MemOp::PortKind::Debug &&
2935  !type_isa<RefType>(getResult(i).getType()))
2936  return emitOpError() << "has an invalid type on port " << portName
2937  << " (expected Read/Write/ReadWrite/Debug)";
2938  if (type_isa<RefType>(firrtlType) && e == 1)
2939  return emitOpError()
2940  << "cannot have only one port of debug type. Debug port can only "
2941  "exist alongside other read/write/read-write port";
2942 
2943  // Safely search for the "data" field, erroring if it can't be
2944  // found.
2945  FIRRTLBaseType dataType;
2946  if (portKind == MemOp::PortKind::Debug) {
2947  auto resType = type_cast<RefType>(getResult(i).getType());
2948  if (!(resType && type_isa<FVectorType>(resType.getType())))
2949  return emitOpError() << "debug ports must be a RefType of FVectorType";
2950  dataType = type_cast<FVectorType>(resType.getType()).getElementType();
2951  } else {
2952  auto dataTypeOption = portBundleType.getElement("data");
2953  if (!dataTypeOption && portKind == MemOp::PortKind::ReadWrite)
2954  dataTypeOption = portBundleType.getElement("wdata");
2955  if (!dataTypeOption) {
2956  emitOpError() << "has no data field on port " << portName
2957  << " (expected to see \"data\" for a read or write "
2958  "port or \"rdata\" for a read/write port)";
2959  return failure();
2960  }
2961  dataType = dataTypeOption->type;
2962  // Read data is expected to ba a flip.
2963  if (portKind == MemOp::PortKind::Read) {
2964  // FIXME error on missing bundle flip
2965  }
2966  }
2967 
2968  // Error if the data type isn't passive.
2969  if (!dataType.isPassive()) {
2970  emitOpError() << "has non-passive data type on port " << portName
2971  << " (memory types must be passive)";
2972  return failure();
2973  }
2974 
2975  // Error if the data type contains analog types.
2976  if (dataType.containsAnalog()) {
2977  emitOpError() << "has a data type that contains an analog type on port "
2978  << portName
2979  << " (memory types cannot contain analog types)";
2980  return failure();
2981  }
2982 
2983  // Check that the port type matches the kind that we determined
2984  // for this port. This catches situations of extraneous port
2985  // fields beind included or the fields being named incorrectly.
2986  FIRRTLType expectedType =
2987  getTypeForPort(getDepth(), dataType, portKind,
2988  dataType.isGround() ? getMaskBits() : 0);
2989  // Compute the original port type as portBundleType may have
2990  // stripped outer flip information.
2991  auto originalType = getResult(i).getType();
2992  if (originalType != expectedType) {
2993  StringRef portKindName;
2994  switch (portKind) {
2995  case MemOp::PortKind::Read:
2996  portKindName = "read";
2997  break;
2998  case MemOp::PortKind::Write:
2999  portKindName = "write";
3000  break;
3001  case MemOp::PortKind::ReadWrite:
3002  portKindName = "readwrite";
3003  break;
3004  case MemOp::PortKind::Debug:
3005  portKindName = "dbg";
3006  break;
3007  }
3008  emitOpError() << "has an invalid type for port " << portName
3009  << " of determined kind \"" << portKindName
3010  << "\" (expected " << expectedType << ", but got "
3011  << originalType << ")";
3012  return failure();
3013  }
3014 
3015  // Error if the type of the current port was not the same as the
3016  // last port, but skip checking the first port.
3017  if (oldDataType && oldDataType != dataType) {
3018  emitOpError() << "port " << getPortName(i)
3019  << " has a different type than port " << getPortName(i - 1)
3020  << " (expected " << oldDataType << ", but got " << dataType
3021  << ")";
3022  return failure();
3023  }
3024 
3025  oldDataType = dataType;
3026  }
3027 
3028  auto maskWidth = getMaskBits();
3029 
3030  auto dataWidth = getDataType().getBitWidthOrSentinel();
3031  if (dataWidth > 0 && maskWidth > (size_t)dataWidth)
3032  return emitOpError("the mask width cannot be greater than "
3033  "data width");
3034 
3035  if (getPortAnnotations().size() != getNumResults())
3036  return emitOpError("the number of result annotations should be "
3037  "equal to the number of results");
3038 
3039  return success();
3040 }
3041 
3042 static size_t getAddressWidth(size_t depth) {
3043  return std::max(1U, llvm::Log2_64_Ceil(depth));
3044 }
3045 
3046 size_t MemOp::getAddrBits() { return getAddressWidth(getDepth()); }
3047 
3048 FIRRTLType MemOp::getTypeForPort(uint64_t depth, FIRRTLBaseType dataType,
3049  PortKind portKind, size_t maskBits) {
3050 
3051  auto *context = dataType.getContext();
3052  if (portKind == PortKind::Debug)
3053  return RefType::get(FVectorType::get(dataType, depth));
3054  FIRRTLBaseType maskType;
3055  // maskBits not specified (==0), then get the mask type from the dataType.
3056  if (maskBits == 0)
3057  maskType = dataType.getMaskType();
3058  else
3059  maskType = UIntType::get(context, maskBits);
3060 
3061  auto getId = [&](StringRef name) -> StringAttr {
3062  return StringAttr::get(context, name);
3063  };
3064 
3065  SmallVector<BundleType::BundleElement, 7> portFields;
3066 
3067  auto addressType = UIntType::get(context, getAddressWidth(depth));
3068 
3069  portFields.push_back({getId("addr"), false, addressType});
3070  portFields.push_back({getId("en"), false, UIntType::get(context, 1)});
3071  portFields.push_back({getId("clk"), false, ClockType::get(context)});
3072 
3073  switch (portKind) {
3074  case PortKind::Read:
3075  portFields.push_back({getId("data"), true, dataType});
3076  break;
3077 
3078  case PortKind::Write:
3079  portFields.push_back({getId("data"), false, dataType});
3080  portFields.push_back({getId("mask"), false, maskType});
3081  break;
3082 
3083  case PortKind::ReadWrite:
3084  portFields.push_back({getId("rdata"), true, dataType});
3085  portFields.push_back({getId("wmode"), false, UIntType::get(context, 1)});
3086  portFields.push_back({getId("wdata"), false, dataType});
3087  portFields.push_back({getId("wmask"), false, maskType});
3088  break;
3089  default:
3090  llvm::report_fatal_error("memory port kind not handled");
3091  break;
3092  }
3093 
3094  return BundleType::get(context, portFields);
3095 }
3096 
3097 /// Return the name and kind of ports supported by this memory.
3098 SmallVector<MemOp::NamedPort> MemOp::getPorts() {
3099  SmallVector<MemOp::NamedPort> result;
3100  // Each entry in the bundle is a port.
3101  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3102  // Each port is a bundle.
3103  auto portType = type_cast<FIRRTLType>(getResult(i).getType());
3104  result.push_back({getPortName(i), getMemPortKindFromType(portType)});
3105  }
3106  return result;
3107 }
3108 
3109 /// Return the kind of the specified port.
3110 MemOp::PortKind MemOp::getPortKind(StringRef portName) {
3111  return getMemPortKindFromType(
3112  type_cast<FIRRTLType>(getPortNamed(portName).getType()));
3113 }
3114 
3115 /// Return the kind of the specified port number.
3116 MemOp::PortKind MemOp::getPortKind(size_t resultNo) {
3117  return getMemPortKindFromType(
3118  type_cast<FIRRTLType>(getResult(resultNo).getType()));
3119 }
3120 
3121 /// Return the number of bits in the mask for the memory.
3122 size_t MemOp::getMaskBits() {
3123 
3124  for (auto res : getResults()) {
3125  if (type_isa<RefType>(res.getType()))
3126  continue;
3127  auto firstPortType = type_cast<FIRRTLBaseType>(res.getType());
3128  if (getMemPortKindFromType(firstPortType) == PortKind::Read ||
3129  getMemPortKindFromType(firstPortType) == PortKind::Debug)
3130  continue;
3131 
3132  FIRRTLBaseType mType;
3133  for (auto t : type_cast<BundleType>(firstPortType.getPassiveType())) {
3134  if (t.name.getValue().contains("mask"))
3135  mType = t.type;
3136  }
3137  if (type_isa<UIntType>(mType))
3138  return mType.getBitWidthOrSentinel();
3139  }
3140  // Mask of zero bits means, either there are no write/readwrite ports or the
3141  // mask is of aggregate type.
3142  return 0;
3143 }
3144 
3145 /// Return the data-type field of the memory, the type of each element.
3146 FIRRTLBaseType MemOp::getDataType() {
3147  assert(getNumResults() != 0 && "Mems with no read/write ports are illegal");
3148 
3149  if (auto refType = type_dyn_cast<RefType>(getResult(0).getType()))
3150  return type_cast<FVectorType>(refType.getType()).getElementType();
3151  auto firstPortType = type_cast<FIRRTLBaseType>(getResult(0).getType());
3152 
3153  StringRef dataFieldName = "data";
3154  if (getMemPortKindFromType(firstPortType) == PortKind::ReadWrite)
3155  dataFieldName = "rdata";
3156 
3157  return type_cast<BundleType>(firstPortType.getPassiveType())
3158  .getElementType(dataFieldName);
3159 }
3160 
3161 StringAttr MemOp::getPortName(size_t resultNo) {
3162  return cast<StringAttr>(getPortNames()[resultNo]);
3163 }
3164 
3165 FIRRTLBaseType MemOp::getPortType(size_t resultNo) {
3166  return type_cast<FIRRTLBaseType>(getResults()[resultNo].getType());
3167 }
3168 
3169 Value MemOp::getPortNamed(StringAttr name) {
3170  auto namesArray = getPortNames();
3171  for (size_t i = 0, e = namesArray.size(); i != e; ++i) {
3172  if (namesArray[i] == name) {
3173  assert(i < getNumResults() && " names array out of sync with results");
3174  return getResult(i);
3175  }
3176  }
3177  return Value();
3178 }
3179 
3180 // Extract all the relevant attributes from the MemOp and return the FirMemory.
//
// Besides counting ports by kind, this groups write ports by clock. For each
// connect that drives a write port's clock subfield (field index 2, "clk"),
// the module-scoped driver of the connect's source is looked up in
// `clockToLeader`. The index of the first write port seen with that driver is
// appended to `writeClockIDs`. If the op has no "modName" attribute, a module
// name is synthesized from the prefix, port counts, width, depth, latencies,
// mask bits, RUW/WUW modes, clock grouping and memory initialization.
3182  auto op = *this;
3183  size_t numReadPorts = 0;
3184  size_t numWritePorts = 0;
3185  size_t numReadWritePorts = 0;
3187  SmallVector<int32_t> writeClockIDs;
3188 
3189  for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
3190  auto portKind = op.getPortKind(i);
3191  if (portKind == MemOp::PortKind::Read)
3192  ++numReadPorts;
3193  else if (portKind == MemOp::PortKind::Write) {
// Find the subfield user that extracts the clock (field index 2).
3194  for (auto *a : op.getResult(i).getUsers()) {
3195  auto subfield = dyn_cast<SubfieldOp>(a);
3196  if (!subfield || subfield.getFieldIndex() != 2)
3197  continue;
3198  auto clockPort = a->getResult(0);
// NOTE(review): an ID is appended only for a connect whose destination is
// the clock, so writeClockIDs can end up shorter (no such connect) or longer
// (several connects) than numWritePorts.
3199  for (auto *b : clockPort.getUsers()) {
3200  if (auto connect = dyn_cast<FConnectLike>(b)) {
3201  if (connect.getDest() == clockPort) {
// The first write port with a given clock driver becomes its leader; later
// ports with the same driver reuse the leader's index.
3202  auto result =
3203  clockToLeader.insert({circt::firrtl::getModuleScopedDriver(
3204  connect.getSrc(), true, true, true),
3205  numWritePorts});
3206  if (result.second) {
3207  writeClockIDs.push_back(numWritePorts);
3208  } else {
3209  writeClockIDs.push_back(result.first->second);
3210  }
3211  }
3212  }
3213  }
// Only the first clock subfield user of this port is inspected.
3214  break;
3215  }
3216  ++numWritePorts;
3217  } else
// NOTE(review): every kind other than Read/Write, including Debug, is counted
// as read-write here (unlike MemOp::getNumPorts); presumably debug ports are
// gone by the time this runs -- confirm.
3218  ++numReadWritePorts;
3219  }
3220 
// A data type without a known bit width is diagnosed, but the summary is still
// built with width 0.
3221  size_t width = 0;
3222  if (auto widthV = getBitWidth(op.getDataType()))
3223  width = *widthV;
3224  else
3225  op.emitError("'firrtl.mem' should have simple type and known width");
3226  MemoryInitAttr init = op->getAttrOfType<MemoryInitAttr>("init");
3227  StringAttr modName;
// An explicit "modName" attribute overrides the synthesized name.
3228  if (op->hasAttr("modName"))
3229  modName = op->getAttrOfType<StringAttr>("modName");
3230  else {
// Encode the write-clock grouping as letters: leader index 0 -> 'a', 1 -> 'b', ...
3231  SmallString<8> clocks;
3232  for (auto a : writeClockIDs)
3233  clocks.append(Twine((char)(a + 'a')).str());
3234  SmallString<32> initStr;
3235  // If there is a file initialization, then come up with a decent
3236  // representation for this. Use the filename, but only characters
3237  // [a-zA-Z0-9] and the bool/hex and inline booleans.
3238  if (init) {
3239  for (auto c : init.getFilename().getValue())
3240  if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
3241  (c >= '0' && c <= '9'))
3242  initStr.push_back(c);
3243  initStr.push_back('_');
3244  initStr.push_back(init.getIsBinary() ? 't' : 'f');
3245  initStr.push_back('_');
3246  initStr.push_back(init.getIsInline() ? 't' : 'f');
3247  }
3248  modName = StringAttr::get(
3249  op->getContext(),
3250  llvm::formatv(
3251  "{0}FIRRTLMem_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}{11}{12}",
3252  op.getPrefix().value_or(""), numReadPorts, numWritePorts,
3253  numReadWritePorts, (size_t)width, op.getDepth(),
3254  op.getReadLatency(), op.getWriteLatency(), op.getMaskBits(),
3255  (unsigned)op.getRuw(), (unsigned)seq::WUW::PortOrder,
3256  clocks.empty() ? "" : "_" + clocks, init ? initStr.str() : ""));
3257  }
// NOTE(review): `op.getMaskBits() > 1` presumably fills the summary's
// "is masked" field -- confirm against the FirMemory declaration.
3258  return {numReadPorts,
3259  numWritePorts,
3260  numReadWritePorts,
3261  (size_t)width,
3262  op.getDepth(),
3263  op.getReadLatency(),
3264  op.getWriteLatency(),
3265  op.getMaskBits(),
3266  *seq::symbolizeRUW(unsigned(op.getRuw())),
3267  seq::WUW::PortOrder,
3268  writeClockIDs,
3269  modName,
3270  op.getMaskBits() > 1,
3271  init,
3272  op.getPrefixAttr(),
3273  op.getLoc()};
3274 }
3275 
// Suggest an SSA name "<memory>_<port>" for every port result, falling back to
// "mem" as the prefix when the memory has no name.
3277  StringRef base = getName();
3278  if (base.empty())
3279  base = "mem";
3280 
// One name per result; result i corresponds to port i.
3281  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
3282  setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
3283  }
3284 }
3285 
3286 std::optional<size_t> MemOp::getTargetResultIndex() {
3287  // Inner symbols on memory operations target the op not any result.
3288  return std::nullopt;
3289 }
3290 
3291 // Construct name of the module which will be used for the memory definition.
3292 StringAttr FirMemory::getFirMemoryName() const { return modName; }
3293 
3294 /// Helper for naming forceable declarations (and their optional ref result).
3295 static void forceableAsmResultNames(Forceable op, StringRef name,
3296  OpAsmSetValueNameFn setNameFn) {
3297  if (name.empty())
3298  return;
3299  setNameFn(op.getDataRaw(), name);
3300  if (op.isForceable())
3301  setNameFn(op.getDataRef(), (name + "_ref").str());
3302 }
3303 
// Name the data result after this op's name (and the ref result "<name>_ref"
// when forceable) via forceableAsmResultNames.
3305  return forceableAsmResultNames(*this, getName(), setNameFn);
3306 }
3307 
3308 LogicalResult NodeOp::inferReturnTypes(
3309  mlir::MLIRContext *context, std::optional<mlir::Location> location,
3310  ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
3311  ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
3312  ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
3313  if (operands.empty())
3314  return failure();
3315  inferredReturnTypes.push_back(operands[0].getType());
3316  for (auto &attr : attributes)
3317  if (attr.getName() == Forceable::getForceableAttrName()) {
3318  auto forceableType =
3319  firrtl::detail::getForceableResultType(true, operands[0].getType());
3320  if (!forceableType) {
3321  if (location)
3322  ::mlir::emitError(*location, "cannot force a node of type ")
3323  << operands[0].getType();
3324  return failure();
3325  }
3326  inferredReturnTypes.push_back(forceableType);
3327  }
3328  return success();
3329 }
3330 
3331 std::optional<size_t> NodeOp::getTargetResultIndex() { return 0; }
3332 
// Name the data result after this op's name (and the ref result "<name>_ref"
// when forceable) via forceableAsmResultNames.
3334  return forceableAsmResultNames(*this, getName(), setNameFn);
3335 }
3336 
3337 std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
3338 
3339 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3340 RegOp::computeDataFlow() {
3341  // A register does't have any combinational dataflow.
3342  return {};
3343 }
3344 
3345 LogicalResult RegResetOp::verify() {
3346  auto reset = getResetValue();
3347 
3348  FIRRTLBaseType resetType = reset.getType();
3349  FIRRTLBaseType regType = getResult().getType();
3350 
3351  // The type of the initialiser must be equivalent to the register type.
3352  if (!areTypesEquivalent(regType, resetType))
3353  return emitError("type mismatch between register ")
3354  << regType << " and reset value " << resetType;
3355 
3356  return success();
3357 }
3358 
3359 std::optional<size_t> RegResetOp::getTargetResultIndex() { return 0; }
3360 
3362  return forceableAsmResultNames(*this, getName(), setNameFn);
3363 }
3364 
3365 //===----------------------------------------------------------------------===//
3366 // FormalOp
3367 //===----------------------------------------------------------------------===//
3368 
3369 LogicalResult
3370 FormalOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
3371  // The referenced symbol is restricted to FModuleOps
3372  auto referencedModule = symbolTable.lookupNearestSymbolFrom<FModuleOp>(
3373  *this, getModuleNameAttr());
3374  if (!referencedModule)
3375  return (*this)->emitOpError("invalid symbol reference");
3376 
3377  return success();
3378 }
3379 
3380 //===----------------------------------------------------------------------===//
3381 // WireOp
3382 //===----------------------------------------------------------------------===//
3383 
3385  return forceableAsmResultNames(*this, getName(), setNameFn);
3386 }
3387 
3388 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
3389 RegResetOp::computeDataFlow() {
3390  // A register does't have any combinational dataflow.
3391  return {};
3392 }
3393 
3394 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
3395 
3396 LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3397  auto refType = type_dyn_cast<RefType>(getType(0));
3398  if (!refType)
3399  return success();
3400 
3401  return verifyProbeType(
3402  refType, getLoc(), getOperation()->getParentOfType<CircuitOp>(),
3403  symbolTable, Twine("'") + getOperationName() + "' op is");
3404 }
3405 
3406 //===----------------------------------------------------------------------===//
3407 // ObjectOp
3408 //===----------------------------------------------------------------------===//
3409 
3410 void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass,
3411  StringRef name) {
3412  build(builder, state, klass.getInstanceType(),
3413  StringAttr::get(builder.getContext(), name));
3414 }
3415 
3416 LogicalResult ObjectOp::verify() { return success(); }
3417 
3418 LogicalResult ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3419  auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
3420  auto classType = getType();
3421  auto className = classType.getNameAttr();
3422 
3423  // verify that the class exists.
3424  auto classOp = dyn_cast_or_null<ClassLike>(
3425  symbolTable.lookupSymbolIn(circuitOp, className));
3426  if (!classOp)
3427  return emitOpError() << "references unknown class " << className;
3428 
3429  // verify that the result type agrees with the class definition.
3430  if (failed(classOp.verifyType(classType, [&]() { return emitOpError(); })))
3431  return failure();
3432 
3433  return success();
3434 }
3435 
3436 StringAttr ObjectOp::getClassNameAttr() {
3437  return getType().getNameAttr().getAttr();
3438 }
3439 
3440 StringRef ObjectOp::getClassName() { return getType().getName(); }
3441 
3442 ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
3443  auto symRef = getType().getNameAttr();
3444  return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
3445 }
3446 
3447 Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
3448  return getReferencedClass(symtbl);
3449 }
3450 
3451 StringRef ObjectOp::getInstanceName() { return getName(); }
3452 
3453 StringAttr ObjectOp::getInstanceNameAttr() { return getNameAttr(); }
3454 
3455 StringRef ObjectOp::getReferencedModuleName() { return getClassName(); }
3456 
3457 StringAttr ObjectOp::getReferencedModuleNameAttr() {
3458  return getClassNameAttr();
3459 }
3460 
3462  setNameFn(getResult(), getName());
3463 }
3464 
3465 //===----------------------------------------------------------------------===//
3466 // Statements
3467 //===----------------------------------------------------------------------===//
3468 
3469 LogicalResult AttachOp::verify() {
3470  // All known widths must match.
3471  std::optional<int32_t> commonWidth;
3472  for (auto operand : getOperands()) {
3473  auto thisWidth = type_cast<AnalogType>(operand.getType()).getWidth();
3474  if (!thisWidth)
3475  continue;
3476  if (!commonWidth) {
3477  commonWidth = thisWidth;
3478  continue;
3479  }
3480  if (commonWidth != thisWidth)
3481  return emitOpError("is inavlid as not all known operand widths match");
3482  }
3483  return success();
3484 }
3485 
/// Check if the source and sink are of appropriate flow.
///
/// Operand 0 of `connect` is treated as the destination and operand 1 as the
/// source. The source's folded flow must satisfy `isValidSrc`. A source whose
/// declaration kind is a port or an instance is tolerated even when its flow
/// is not a valid source flow. The destination's folded flow must satisfy
/// `isValidDst`. On failure, an error is emitted at the connect's location,
/// with a note at the location of the offending value's field reference.
static LogicalResult checkConnectFlow(Operation *connect) {
  Value dst = connect->getOperand(0);
  Value src = connect->getOperand(1);

  // TODO: Relax this to allow reads from output ports,
  // instance/memory input ports.
  auto srcFlow = foldFlow(src);
  if (!isValidSrc(srcFlow)) {
    // A sink that is a port output or instance input used as a source is okay.
    auto kind = getDeclarationKind(src);
    if (kind != DeclKind::Port && kind != DeclKind::Instance) {
      // Look through casts so the diagnostic refers to the underlying value.
      auto srcRef = getFieldRefFromValue(src, /*lookThroughCasts=*/true);
      auto [srcName, rootKnown] = getFieldName(srcRef);
      auto diag = emitError(connect->getLoc());
      diag << "connect has invalid flow: the source expression ";
      // The name is only printed when getFieldName reports the root as known.
      if (rootKnown)
        diag << "\"" << srcName << "\" ";
      diag << "has " << toString(srcFlow) << ", expected source or duplex flow";
      return diag.attachNote(srcRef.getLoc()) << "the source was defined here";
    }
  }

  // Unlike the source, the destination has no port/instance exemption.
  auto dstFlow = foldFlow(dst);
  if (!isValidDst(dstFlow)) {
    auto dstRef = getFieldRefFromValue(dst, /*lookThroughCasts=*/true);
    auto [dstName, rootKnown] = getFieldName(dstRef);
    auto diag = emitError(connect->getLoc());
    diag << "connect has invalid flow: the destination expression ";
    if (rootKnown)
      diag << "\"" << dstName << "\" ";
    diag << "has " << toString(dstFlow) << ", expected sink or duplex flow";
    return diag.attachNote(dstRef.getLoc())
           << "the destination was defined here";
  }
  return success();
}
3523 
3524 // NOLINTBEGIN(misc-no-recursion)
3525 /// Checks if the type has any 'const' leaf elements . If `isFlip` is `true`,
3526 /// the `const` leaf is not considered to be driven.
3527 static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip = false,
3528  bool outerTypeIsConst = false) {
3529  auto typeIsConst = outerTypeIsConst || type.isConst();
3530 
3531  if (typeIsConst && type.isPassive())
3532  return !isFlip;
3533 
3534  if (auto bundleType = type_dyn_cast<BundleType>(type))
3535  return llvm::any_of(bundleType.getElements(), [&](auto &element) {
3536  return isConstFieldDriven(element.type, isFlip ^ element.isFlip,
3537  typeIsConst);
3538  });
3539 
3540  if (auto vectorType = type_dyn_cast<FVectorType>(type))
3541  return isConstFieldDriven(vectorType.getElementType(), isFlip, typeIsConst);
3542 
3543  if (typeIsConst)
3544  return !isFlip;
3545  return false;
3546 }
3547 // NOLINTEND(misc-no-recursion)
3548 
/// Checks that connections to 'const' destinations are not dependent on
/// non-'const' conditions in when blocks.
///
/// Only applies when both the destination and source are FIRRTL base types;
/// otherwise succeeds immediately. The destination is checked when it has
/// driven 'const' leaves. The source is checked when it has 'const' leaves
/// under a flip (those are driven through the source side). For each checked
/// side:
///  - if a non-'const' subaccess was used to reach it (its type was "refined"
///    to 'const' while walking up), the assignment is rejected outright;
///  - otherwise, every `when` between the connect and the value's declaration
///    must have a 'const' condition.
static LogicalResult checkConnectConditionality(FConnectLike connect) {
  auto dest = connect.getDest();
  auto destType = type_dyn_cast<FIRRTLBaseType>(dest.getType());
  auto src = connect.getSrc();
  auto srcType = type_dyn_cast<FIRRTLBaseType>(src.getType());
  if (!destType || !srcType)
    return success();

  // These start equal to the value types and may be made 'const' by the walk
  // below if a 'const' vector is indexed through a subaccess.
  auto destRefinedType = destType;
  auto srcRefinedType = srcType;

  /// Looks up the value's defining op until the defining op is null or a
  /// declaration of the value. If a SubAccessOp is encountered with a 'const'
  /// input, `originalFieldType` is made 'const'.
  auto findFieldDeclarationRefiningFieldType =
      [](Value value, FIRRTLBaseType &originalFieldType) -> Value {
    while (auto *definingOp = value.getDefiningOp()) {
      bool shouldContinue = true;
      TypeSwitch<Operation *>(definingOp)
          .Case<SubfieldOp, SubindexOp>([&](auto op) { value = op.getInput(); })
          .Case<SubaccessOp>([&](SubaccessOp op) {
            if (op.getInput()
                    .getType()
                    .base()
                    .getElementTypePreservingConst()
                    .isConst())
              originalFieldType = originalFieldType.getConstType(true);
            value = op.getInput();
          })
          // Any other defining op ends the walk; `value` is the declaration.
          .Default([&](Operation *) { shouldContinue = false; });
      if (!shouldContinue)
        break;
    }
    return value;
  };

  auto destDeclaration =
      findFieldDeclarationRefiningFieldType(dest, destRefinedType);
  auto srcDeclaration =
      findFieldDeclarationRefiningFieldType(src, srcRefinedType);

  // Walk from the connect's block outward until reaching the block that holds
  // the declaration, failing at the first enclosing `when` whose condition is
  // not 'const'. NOTE(review): the `value` parameter is currently unused.
  auto checkConstConditionality = [&](Value value, FIRRTLBaseType type,
                                      Value declaration) -> LogicalResult {
    auto *declarationBlock = declaration.getParentBlock();
    auto *block = connect->getBlock();
    while (block && block != declarationBlock) {
      auto *parentOp = block->getParentOp();

      if (auto whenOp = dyn_cast<WhenOp>(parentOp);
          whenOp && !whenOp.getCondition().getType().isConst()) {
        // Distinguish a wholly-'const' type from one with nested 'const'
        // members in the diagnostic.
        if (type.isConst())
          return connect.emitOpError()
                 << "assignment to 'const' type " << type
                 << " is dependent on a non-'const' condition";
        return connect->emitOpError()
               << "assignment to nested 'const' member of type " << type
               << " is dependent on a non-'const' condition";
      }

      block = parentOp->getBlock();
    }
    return success();
  };

  auto emitSubaccessError = [&] {
    return connect.emitError(
        "assignment to non-'const' subaccess of 'const' type is disallowed");
  };

  // Check destination if it contains 'const' leaves
  if (destRefinedType.containsConst() && isConstFieldDriven(destRefinedType)) {
    // Disallow assignment to non-'const' subaccesses of 'const' types
    if (destType != destRefinedType)
      return emitSubaccessError();

    if (failed(checkConstConditionality(dest, destType, destDeclaration)))
      return failure();
  }

  // Check source if it contains 'const' 'flip' leaves
  if (srcRefinedType.containsConst() &&
      isConstFieldDriven(srcRefinedType, /*isFlip=*/true)) {
    // Disallow assignment to non-'const' subaccesses of 'const' types
    if (srcType != srcRefinedType)
      return emitSubaccessError();
    if (failed(checkConstConditionality(src, srcType, srcDeclaration)))
      return failure();
  }

  return success();
}
3642 
3643 LogicalResult ConnectOp::verify() {
3644  auto dstType = getDest().getType();
3645  auto srcType = getSrc().getType();
3646  auto dstBaseType = type_dyn_cast<FIRRTLBaseType>(dstType);
3647  auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(srcType);
3648  if (!dstBaseType || !srcBaseType) {
3649  if (dstType != srcType)
3650  return emitError("may not connect different non-base types");
3651  } else {
3652  // Analog types cannot be connected and must be attached.
3653  if (dstBaseType.containsAnalog() || srcBaseType.containsAnalog())
3654  return emitError("analog types may not be connected");
3655 
3656  // Destination and source types must be equivalent.
3657  if (!areTypesEquivalent(dstBaseType, srcBaseType))
3658  return emitError("type mismatch between destination ")
3659  << dstBaseType << " and source " << srcBaseType;
3660 
3661  // Truncation is banned in a connection: destination bit width must be
3662  // greater than or equal to source bit width.
3663  if (!isTypeLarger(dstBaseType, srcBaseType))
3664  return emitError("destination ")
3665  << dstBaseType << " is not as wide as the source " << srcBaseType;
3666  }
3667 
3668  // Check that the flows make sense.
3669  if (failed(checkConnectFlow(*this)))
3670  return failure();
3671 
3672  if (failed(checkConnectConditionality(*this)))
3673  return failure();
3674 
3675  return success();
3676 }
3677 
3678 LogicalResult MatchingConnectOp::verify() {
3679  if (auto type = type_dyn_cast<FIRRTLType>(getDest().getType())) {
3680  auto baseType = type_cast<FIRRTLBaseType>(type);
3681 
3682  // Analog types cannot be connected and must be attached.
3683  if (baseType && baseType.containsAnalog())
3684  return emitError("analog types may not be connected");
3685 
3686  // The anonymous types of operands must be equivalent.
3687  assert(areAnonymousTypesEquivalent(cast<FIRRTLBaseType>(getSrc().getType()),
3688  baseType) &&
3689  "`SameAnonTypeOperands` trait should have already rejected "
3690  "structurally non-equivalent types");
3691  }
3692 
3693  // Check that the flows make sense.
3694  if (failed(checkConnectFlow(*this)))
3695  return failure();
3696 
3697  if (failed(checkConnectConditionality(*this)))
3698  return failure();
3699 
3700  return success();
3701 }
3702 
3703 LogicalResult RefDefineOp::verify() {
3704  // Check that the flows make sense.
3705  if (failed(checkConnectFlow(*this)))
3706  return failure();
3707 
3708  // For now, refs can't be in bundles so this is sufficient.
3709  // In the future need to ensure no other define's to same "fieldSource".
3710  // (When aggregates can have references, we can define a reference within,
3711  // but this must be unique. Checking this here may be expensive,
3712  // consider adding something to FModuleLike's to check it there instead)
3713  for (auto *user : getDest().getUsers()) {
3714  if (auto conn = dyn_cast<FConnectLike>(user);
3715  conn && conn.getDest() == getDest() && conn != *this)
3716  return emitError("destination reference cannot be reused by multiple "
3717  "operations, it can only capture a unique dataflow");
3718  }
3719 
3720  // Check "static" source/dest
3721  if (auto *op = getDest().getDefiningOp()) {
3722  // TODO: Make ref.sub only source flow?
3723  if (isa<RefSubOp>(op))
3724  return emitError(
3725  "destination reference cannot be a sub-element of a reference");
3726  if (isa<RefCastOp>(op)) // Source flow, check anyway for now.
3727  return emitError(
3728  "destination reference cannot be a cast of another reference");
3729  }
3730 
3731  // This define is only enabled when its ambient layers are active. Check
3732  // that whenever the destination's layer requirements are met, that this
3733  // op is enabled.
3734  auto ambientLayers = getAmbientLayersAt(getOperation());
3735  auto dstLayers = getLayersFor(getDest());
3736  SmallVector<SymbolRefAttr> missingLayers;
3737  if (!isLayerSetCompatibleWith(ambientLayers, dstLayers, missingLayers)) {
3738  auto diag = emitOpError("has more layer requirements than destination");
3739  auto &note = diag.attachNote();
3740  note << "additional layers required: ";
3741  interleaveComma(missingLayers, note);
3742  return failure();
3743  }
3744 
3745  return success();
3746 }
3747 
3748 LogicalResult PropAssignOp::verify() {
3749  // Check that the flows make sense.
3750  if (failed(checkConnectFlow(*this)))
3751  return failure();
3752 
3753  // Verify that there is a single value driving the destination.
3754  for (auto *user : getDest().getUsers()) {
3755  if (auto conn = dyn_cast<FConnectLike>(user);
3756  conn && conn.getDest() == getDest() && conn != *this)
3757  return emitError("destination property cannot be reused by multiple "
3758  "operations, it can only capture a unique dataflow");
3759  }
3760 
3761  return success();
3762 }
3763 
3764 void WhenOp::createElseRegion() {
3765  assert(!hasElseRegion() && "already has an else region");
3766  getElseRegion().push_back(new Block());
3767 }
3768 
3769 void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
3770  bool withElseRegion, std::function<void()> thenCtor,
3771  std::function<void()> elseCtor) {
3772  OpBuilder::InsertionGuard guard(builder);
3773  result.addOperands(condition);
3774 
3775  // Create "then" region.
3776  builder.createBlock(result.addRegion());
3777  if (thenCtor)
3778  thenCtor();
3779 
3780  // Create "else" region.
3781  Region *elseRegion = result.addRegion();
3782  if (withElseRegion) {
3783  builder.createBlock(elseRegion);
3784  if (elseCtor)
3785  elseCtor();
3786  }
3787 }
3788 
3789 //===----------------------------------------------------------------------===//
3790 // MatchOp
3791 //===----------------------------------------------------------------------===//
3792 
3793 LogicalResult MatchOp::verify() {
3794  FEnumType type = getInput().getType();
3795 
3796  // Make sure that the number of tags matches the number of regions.
3797  auto numCases = getTags().size();
3798  auto numRegions = getNumRegions();
3799  if (numRegions != numCases)
3800  return emitOpError("expected ")
3801  << numRegions << " tags but got " << numCases;
3802 
3803  auto numTags = type.getNumElements();
3804 
3805  SmallDenseSet<int64_t> seen;
3806  for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
3807  auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());
3808 
3809  // Ensure that the block has a single argument.
3810  if (region.getNumArguments() != 1)
3811  return emitOpError("region should have exactly one argument");
3812 
3813  // Make sure that it is a valid tag.
3814  if (tagIndex >= numTags)
3815  return emitOpError("the tag index ")
3816  << tagIndex << " is out of the range of valid tags in " << type;
3817 
3818  // Make sure we have not already matched this tag.
3819  auto [it, inserted] = seen.insert(tagIndex);
3820  if (!inserted)
3821  return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
3822  << " is matched more than once";
3823 
3824  // Check that the block argument type matches the tag's type.
3825  auto expectedType = type.getElementTypePreservingConst(tagIndex);
3826  auto regionType = region.getArgument(0).getType();
3827  if (regionType != expectedType)
3828  return emitOpError("region type ")
3829  << regionType << " does not match the expected type "
3830  << expectedType;
3831  }
3832 
3833  // Check that the match statement is exhaustive.
3834  for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
3835  if (!seen.contains(i))
3836  return emitOpError("missing case for tag ") << type.getElementNameAttr(i);
3837 
3838  return success();
3839 }
3840 
3841 void MatchOp::print(OpAsmPrinter &p) {
3842  auto input = getInput();
3843  FEnumType type = input.getType();
3844  auto regions = getRegions();
3845  p << " " << input << " : " << type;
3846  SmallVector<StringRef> elided = {"tags"};
3847  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
3848  p << " {";
3849  p.increaseIndent();
3850  for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
3851  p.printNewline();
3852  p << "case ";
3853  p.printKeywordOrString(
3854  type.getElementName(cast<IntegerAttr>(tag).getInt()));
3855  p << "(";
3856  p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
3857  /*omitType=*/true);
3858  p << ") ";
3859  p.printRegion(region, /*printEntryBlockArgs=*/false);
3860  }
3861  p.decreaseIndent();
3862  p.printNewline();
3863  p << "}";
3864 }
3865 
3866 ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
3867  auto *context = parser.getContext();
3868  OpAsmParser::UnresolvedOperand input;
3869  if (parser.parseOperand(input) || parser.parseColon())
3870  return failure();
3871 
3872  auto loc = parser.getCurrentLocation();
3873  Type type;
3874  if (parser.parseType(type))
3875  return failure();
3876  auto enumType = type_dyn_cast<FEnumType>(type);
3877  if (!enumType)
3878  return parser.emitError(loc, "expected enumeration type but got") << type;
3879 
3880  if (parser.resolveOperand(input, type, result.operands) ||
3881  parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3882  parser.parseLBrace())
3883  return failure();
3884 
3885  auto i32Type = IntegerType::get(context, 32);
3886  SmallVector<Attribute> tags;
3887  while (true) {
3888  // Stop parsing when we don't find another "case" keyword.
3889  if (failed(parser.parseOptionalKeyword("case")))
3890  break;
3891 
3892  // Parse the tag and region argument.
3893  auto nameLoc = parser.getCurrentLocation();
3894  std::string name;
3895  OpAsmParser::Argument arg;
3896  auto *region = result.addRegion();
3897  if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
3898  parser.parseArgument(arg) || parser.parseRParen())
3899  return failure();
3900 
3901  // Figure out the enum index of the tag.
3902  auto index = enumType.getElementIndex(name);
3903  if (!index)
3904  return parser.emitError(nameLoc, "the tag \"")
3905  << name << "\" is not a member of the enumeration " << enumType;
3906  tags.push_back(IntegerAttr::get(i32Type, *index));
3907 
3908  // Parse the region.
3909  arg.type = enumType.getElementTypePreservingConst(*index);
3910  if (parser.parseRegion(*region, arg))
3911  return failure();
3912  }
3913  result.addAttribute("tags", ArrayAttr::get(context, tags));
3914 
3915  return parser.parseRBrace();
3916 }
3917 
3918 void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
3919  ArrayAttr tags,
3920  MutableArrayRef<std::unique_ptr<Region>> regions) {
3921  result.addOperands(input);
3922  result.addAttribute("tags", tags);
3923  result.addRegions(regions);
3924 }
3925 
3926 //===----------------------------------------------------------------------===//
3927 // Expressions
3928 //===----------------------------------------------------------------------===//
3929 
3930 /// Type inference adaptor that narrows from the very generic MLIR
3931 /// `InferTypeOpInterface` to what we need in the FIRRTL dialect: just operands
3932 /// and attributes, no context or regions. Also, we only ever produce a single
3933 /// result value, so the FIRRTL-specific type inference ops directly return the
3934 /// inferred type rather than pushing into the `results` vector.
3935 LogicalResult impl::inferReturnTypes(
3936  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
3937  DictionaryAttr attrs, mlir::OpaqueProperties properties,
3938  RegionRange regions, SmallVectorImpl<Type> &results,
3939  llvm::function_ref<FIRRTLType(ValueRange, ArrayRef<NamedAttribute>,
3940  std::optional<Location>)>
3941  callback) {
3942  auto type = callback(
3943  operands, attrs ? attrs.getValue() : ArrayRef<NamedAttribute>{}, loc);
3944  if (type) {
3945  results.push_back(type);
3946  return success();
3947  }
3948  return failure();
3949 }
3950 
3951 /// Get an attribute by name from a list of named attributes. Return null if no
3952 /// attribute is found with that name.
3953 static Attribute maybeGetAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3954  for (auto attr : attrs)
3955  if (attr.getName() == name)
3956  return attr.getValue();
3957  return {};
3958 }
3959 
3960 /// Get an attribute by name from a list of named attributes. Aborts if the
3961 /// attribute does not exist.
3962 static Attribute getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3963  if (auto attr = maybeGetAttr(attrs, name))
3964  return attr;
3965  llvm::report_fatal_error("attribute '" + name + "' not found");
3966 }
3967 
3968 /// Same as above, but casts the attribute to a specific type.
3969 template <typename AttrClass>
3970 AttrClass getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3971  return cast<AttrClass>(getAttr(attrs, name));
3972 }
3973 
3974 /// Return true if the specified operation is a firrtl expression.
3975 bool firrtl::isExpression(Operation *op) {
3976  struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
3977  bool visitInvalidExpr(Operation *op) { return false; }
3978  bool visitUnhandledExpr(Operation *op) { return true; }
3979  };
3980 
3981  return IsExprClassifier().dispatchExprVisitor(op);
3982 }
3983 
3985  // Set invalid values to have a distinct name.
3986  std::string name;
3987  if (auto ty = type_dyn_cast<IntType>(getType())) {
3988  const char *base = ty.isSigned() ? "invalid_si" : "invalid_ui";
3989  auto width = ty.getWidthOrSentinel();
3990  if (width == -1)
3991  name = base;
3992  else
3993  name = (Twine(base) + Twine(width)).str();
3994  } else if (auto ty = type_dyn_cast<AnalogType>(getType())) {
3995  auto width = ty.getWidthOrSentinel();
3996  if (width == -1)
3997  name = "invalid_analog";
3998  else
3999  name = ("invalid_analog" + Twine(width)).str();
4000  } else if (type_isa<AsyncResetType>(getType()))
4001  name = "invalid_asyncreset";
4002  else if (type_isa<ResetType>(getType()))
4003  name = "invalid_reset";
4004  else if (type_isa<ClockType>(getType()))
4005  name = "invalid_clock";
4006  else
4007  name = "invalid";
4008 
4009  setNameFn(getResult(), name);
4010 }
4011 
4012 void ConstantOp::print(OpAsmPrinter &p) {
4013  p << " ";
4014  p.printAttributeWithoutType(getValueAttr());
4015  p << " : ";
4016  p.printType(getType());
4017  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4018 }
4019 
/// Parse `<integer> : <firrtl int type> [attr-dict]`.
///
/// The integer literal is parsed first, before its width is known. Once the
/// result type is known, the APInt is sign-extended or truncated to the
/// type's width (when the type has one). The value is then stored in a
/// `value` IntegerAttr whose signedness matches the result type. Truncation
/// that would drop significant bits is rejected with a range error.
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the constant value, without knowing its width.
  APInt value;
  auto loc = parser.getCurrentLocation();
  auto valueResult = parser.parseOptionalInteger(value);
  // No value at all means no integer token was present.
  if (!valueResult.has_value())
    return parser.emitError(loc, "expected integer value");

  // Parse the result firrtl integer type.
  IntType resultType;
  // `*valueResult` holds whether the integer token itself parsed successfully.
  if (failed(*valueResult) || parser.parseColonType(resultType) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.addTypes(resultType);

  // Now that we know the width and sign of the result type, we can munge the
  // APInt as appropriate.
  if (resultType.hasWidth()) {
    auto width = (unsigned)resultType.getWidthOrSentinel();
    if (width > value.getBitWidth()) {
      // sext is always safe here, even for unsigned values, because the
      // parseOptionalInteger method will return something with a zero in the
      // top bits if it is a positive number.
      value = value.sext(width);
    } else if (width < value.getBitWidth()) {
      // The parser can return an unnecessarily wide result with leading
      // zeros. This isn't a problem, but truncating off bits is bad.
      // NOTE(review): a non-negative literal is measured by its active
      // (unsigned) bits, so e.g. a value whose top bit fills the width is
      // accepted even for a signed type — confirm this is intended.
      unsigned neededBits = value.isNegative() ? value.getSignificantBits()
                                               : value.getActiveBits();
      if (width < neededBits)
        return parser.emitError(loc, "constant out of range for result type ")
               << resultType;
      value = value.trunc(width);
    }
  }

  // The attribute's integer type carries the final bit width and the
  // signedness of the FIRRTL result type.
  auto intType = parser.getBuilder().getIntegerType(value.getBitWidth(),
                                                    resultType.isSigned());
  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
  result.addAttribute("value", valueAttr);
  return success();
}
4062 
4063 LogicalResult ConstantOp::verify() {
4064  // If the result type has a bitwidth, then the attribute must match its width.
4065  IntType intType = getType();
4066  auto width = intType.getWidthOrSentinel();
4067  if (width != -1 && (int)getValue().getBitWidth() != width)
4068  return emitError(
4069  "firrtl.constant attribute bitwidth doesn't match return type");
4070 
4071  // The sign of the attribute's integer type must match our integer type sign.
4072  auto attrType = type_cast<IntegerType>(getValueAttr().getType());
4073  if (attrType.isSignless() || attrType.isSigned() != intType.isSigned())
4074  return emitError("firrtl.constant attribute has wrong sign");
4075 
4076  return success();
4077 }
4078 
4079 /// Build a ConstantOp from an APInt and a FIRRTL type, handling the attribute
4080 /// formation for the 'value' attribute.
4081 void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
4082  const APInt &value) {
4083  int32_t width = type.getWidthOrSentinel();
4084  (void)width;
4085  assert((width == -1 || (int32_t)value.getBitWidth() == width) &&
4086  "incorrect attribute bitwidth for firrtl.constant");
4087 
4088  auto attr =
4089  IntegerAttr::get(type.getContext(), APSInt(value, !type.isSigned()));
4090  return build(builder, result, type, attr);
4091 }
4092 
4093 /// Build a ConstantOp from an APSInt, handling the attribute formation for the
4094 /// 'value' attribute and inferring the FIRRTL type.
4095 void ConstantOp::build(OpBuilder &builder, OperationState &result,
4096  const APSInt &value) {
4097  auto attr = IntegerAttr::get(builder.getContext(), value);
4098  auto type =
4099  IntType::get(builder.getContext(), value.isSigned(), value.getBitWidth());
4100  return build(builder, result, type, attr);
4101 }
4102 
4104  // For constants in particular, propagate the value into the result name to
4105  // make it easier to read the IR.
4106  IntType intTy = getType();
4107  assert(intTy);
4108 
4109  // Otherwise, build a complex name with the value and type.
4110  SmallString<32> specialNameBuffer;
4111  llvm::raw_svector_ostream specialName(specialNameBuffer);
4112  specialName << 'c';
4113  getValue().print(specialName, /*isSigned:*/ intTy.isSigned());
4114 
4115  specialName << (intTy.isSigned() ? "_si" : "_ui");
4116  auto width = intTy.getWidthOrSentinel();
4117  if (width != -1)
4118  specialName << width;
4119  setNameFn(getResult(), specialName.str());
4120 }
4121 
4122 void SpecialConstantOp::print(OpAsmPrinter &p) {
4123  p << " ";
4124  // SpecialConstant uses a BoolAttr, and we want to print `true` as `1`.
4125  p << static_cast<unsigned>(getValue());
4126  p << " : ";
4127  p.printType(getType());
4128  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4129 }
4130 
4131 ParseResult SpecialConstantOp::parse(OpAsmParser &parser,
4132  OperationState &result) {
4133  // Parse the constant value. SpecialConstant uses bool attributes, but it
4134  // prints as an integer.
4135  APInt value;
4136  auto loc = parser.getCurrentLocation();
4137  auto valueResult = parser.parseOptionalInteger(value);
4138  if (!valueResult.has_value())
4139  return parser.emitError(loc, "expected integer value");
4140 
4141  // Clocks and resets can only be 0 or 1.
4142  if (value != 0 && value != 1)
4143  return parser.emitError(loc, "special constants can only be 0 or 1.");
4144 
4145  // Parse the result firrtl type.
4146  Type resultType;
4147  if (failed(*valueResult) || parser.parseColonType(resultType) ||
4148  parser.parseOptionalAttrDict(result.attributes))
4149  return failure();
4150  result.addTypes(resultType);
4151 
4152  // Create the attribute.
4153  auto valueAttr = parser.getBuilder().getBoolAttr(value == 1);
4154  result.addAttribute("value", valueAttr);
4155  return success();
4156 }
4157 
4159  SmallString<32> specialNameBuffer;
4160  llvm::raw_svector_ostream specialName(specialNameBuffer);
4161  specialName << 'c';
4162  specialName << static_cast<unsigned>(getValue());
4163  auto type = getType();
4164  if (type_isa<ClockType>(type)) {
4165  specialName << "_clock";
4166  } else if (type_isa<ResetType>(type)) {
4167  specialName << "_reset";
4168  } else if (type_isa<AsyncResetType>(type)) {
4169  specialName << "_asyncreset";
4170  }
4171  setNameFn(getResult(), specialName.str());
4172 }
4173 
4174 // Checks that an array attr representing an aggregate constant has the correct
4175 // shape. This recurses on the type.
4176 static bool checkAggConstant(Operation *op, Attribute attr,
4177  FIRRTLBaseType type) {
4178  if (type.isGround()) {
4179  if (!isa<IntegerAttr>(attr)) {
4180  op->emitOpError("Ground type is not an integer attribute");
4181  return false;
4182  }
4183  return true;
4184  }
4185  auto attrlist = dyn_cast<ArrayAttr>(attr);
4186  if (!attrlist) {
4187  op->emitOpError("expected array attribute for aggregate constant");
4188  return false;
4189  }
4190  if (auto array = type_dyn_cast<FVectorType>(type)) {
4191  if (array.getNumElements() != attrlist.size()) {
4192  op->emitOpError("array attribute (")
4193  << attrlist.size() << ") has wrong size for vector constant ("
4194  << array.getNumElements() << ")";
4195  return false;
4196  }
4197  return llvm::all_of(attrlist, [&array, op](Attribute attr) {
4198  return checkAggConstant(op, attr, array.getElementType());
4199  });
4200  }
4201  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4202  if (bundle.getNumElements() != attrlist.size()) {
4203  op->emitOpError("array attribute (")
4204  << attrlist.size() << ") has wrong size for bundle constant ("
4205  << bundle.getNumElements() << ")";
4206  return false;
4207  }
4208  for (size_t i = 0; i < bundle.getNumElements(); ++i) {
4209  if (bundle.getElement(i).isFlip) {
4210  op->emitOpError("Cannot have constant bundle type with flip");
4211  return false;
4212  }
4213  if (!checkAggConstant(op, attrlist[i], bundle.getElement(i).type))
4214  return false;
4215  }
4216  return true;
4217  }
4218  op->emitOpError("Unknown aggregate type");
4219  return false;
4220 }
4221 
4222 LogicalResult AggregateConstantOp::verify() {
4223  if (checkAggConstant(getOperation(), getFields(), getType()))
4224  return success();
4225  return failure();
4226 }
4227 
4228 Attribute AggregateConstantOp::getAttributeFromFieldID(uint64_t fieldID) {
4229  FIRRTLBaseType type = getType();
4230  Attribute value = getFields();
4231  while (fieldID != 0) {
4232  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4233  auto index = bundle.getIndexForFieldID(fieldID);
4234  fieldID -= bundle.getFieldID(index);
4235  type = bundle.getElementType(index);
4236  value = cast<ArrayAttr>(value)[index];
4237  } else {
4238  auto vector = type_cast<FVectorType>(type);
4239  auto index = vector.getIndexForFieldID(fieldID);
4240  fieldID -= vector.getFieldID(index);
4241  type = vector.getElementType();
4242  value = cast<ArrayAttr>(value)[index];
4243  }
4244  }
4245  return value;
4246 }
4247 
4248 void FIntegerConstantOp::print(OpAsmPrinter &p) {
4249  p << " ";
4250  p.printAttributeWithoutType(getValueAttr());
4251  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4252 }
4253 
4254 ParseResult FIntegerConstantOp::parse(OpAsmParser &parser,
4255  OperationState &result) {
4256  auto *context = parser.getContext();
4257  APInt value;
4258  if (parser.parseInteger(value) ||
4259  parser.parseOptionalAttrDict(result.attributes))
4260  return failure();
4261  result.addTypes(FIntegerType::get(context));
4262  auto intType =
4263  IntegerType::get(context, value.getBitWidth(), IntegerType::Signed);
4264  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4265  result.addAttribute("value", valueAttr);
4266  return success();
4267 }
4268 
4269 ParseResult ListCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4270  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
4271  ListType type;
4272 
4273  if (parser.parseOperandList(operands) ||
4274  parser.parseOptionalAttrDict(result.attributes) ||
4275  parser.parseColonType(type))
4276  return failure();
4277  result.addTypes(type);
4278 
4279  return parser.resolveOperands(operands, type.getElementType(),
4280  result.operands);
4281 }
4282 
4283 void ListCreateOp::print(OpAsmPrinter &p) {
4284  p << " ";
4285  p.printOperands(getElements());
4286  p.printOptionalAttrDict((*this)->getAttrs());
4287  p << " : " << getType();
4288 }
4289 
4290 LogicalResult ListCreateOp::verify() {
4291  if (getElements().empty())
4292  return success();
4293 
4294  auto elementType = getElements().front().getType();
4295  auto listElementType = getType().getElementType();
4296  if (elementType != listElementType)
4297  return emitOpError("has elements of type ")
4298  << elementType << " instead of " << listElementType;
4299 
4300  return success();
4301 }
4302 
4303 LogicalResult BundleCreateOp::verify() {
4304  BundleType resultType = getType();
4305  if (resultType.getNumElements() != getFields().size())
4306  return emitOpError("number of fields doesn't match type");
4307  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4308  if (!areTypesConstCastable(
4309  resultType.getElementTypePreservingConst(i),
4310  type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4311  return emitOpError("type of element doesn't match bundle for field ")
4312  << resultType.getElement(i).name;
4313  // TODO: check flow
4314  return success();
4315 }
4316 
4317 LogicalResult VectorCreateOp::verify() {
4318  FVectorType resultType = getType();
4319  if (resultType.getNumElements() != getFields().size())
4320  return emitOpError("number of fields doesn't match type");
4321  auto elemTy = resultType.getElementTypePreservingConst();
4322  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4323  if (!areTypesConstCastable(
4324  elemTy, type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4325  return emitOpError("type of element doesn't match vector element");
4326  // TODO: check flow
4327  return success();
4328 }
4329 
4330 //===----------------------------------------------------------------------===//
4331 // FEnumCreateOp
4332 //===----------------------------------------------------------------------===//
4333 
4334 LogicalResult FEnumCreateOp::verify() {
4335  FEnumType resultType = getResult().getType();
4336  auto elementIndex = resultType.getElementIndex(getFieldName());
4337  if (!elementIndex)
4338  return emitOpError("label ")
4339  << getFieldName() << " is not a member of the enumeration type "
4340  << resultType;
4341  if (!areTypesConstCastable(
4342  resultType.getElementTypePreservingConst(*elementIndex),
4343  getInput().getType()))
4344  return emitOpError("type of element doesn't match enum element");
4345  return success();
4346 }
4347 
4348 void FEnumCreateOp::print(OpAsmPrinter &printer) {
4349  printer << ' ';
4350  printer.printKeywordOrString(getFieldName());
4351  printer << '(' << getInput() << ')';
4352  SmallVector<StringRef> elidedAttrs = {"fieldIndex"};
4353  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
4354  printer << " : ";
4355  printer.printFunctionalType(ArrayRef<Type>{getInput().getType()},
4356  ArrayRef<Type>{getResult().getType()});
4357 }
4358 
4359 ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4360  auto *context = parser.getContext();
4361 
4362  OpAsmParser::UnresolvedOperand input;
4363  std::string fieldName;
4364  mlir::FunctionType functionType;
4365  if (parser.parseKeywordOrString(&fieldName) || parser.parseLParen() ||
4366  parser.parseOperand(input) || parser.parseRParen() ||
4367  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4368  parser.parseType(functionType))
4369  return failure();
4370 
4371  if (functionType.getNumInputs() != 1)
4372  return parser.emitError(parser.getNameLoc(), "single input type required");
4373  if (functionType.getNumResults() != 1)
4374  return parser.emitError(parser.getNameLoc(), "single result type required");
4375 
4376  auto inputType = functionType.getInput(0);
4377  if (parser.resolveOperand(input, inputType, result.operands))
4378  return failure();
4379 
4380  auto outputType = functionType.getResult(0);
4381  auto enumType = type_dyn_cast<FEnumType>(outputType);
4382  if (!enumType)
4383  return parser.emitError(parser.getNameLoc(),
4384  "output must be enum type, got ")
4385  << outputType;
4386  auto fieldIndex = enumType.getElementIndex(fieldName);
4387  if (!fieldIndex)
4388  return parser.emitError(parser.getNameLoc(),
4389  "unknown field " + fieldName + " in enum type ")
4390  << enumType;
4391 
4392  result.addAttribute(
4393  "fieldIndex",
4394  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4395 
4396  result.addTypes(enumType);
4397 
4398  return success();
4399 }
4400 
4401 //===----------------------------------------------------------------------===//
4402 // IsTagOp
4403 //===----------------------------------------------------------------------===//
4404 
4405 LogicalResult IsTagOp::verify() {
4406  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4407  return emitOpError("element index is greater than the number of fields in "
4408  "the bundle type");
4409  return success();
4410 }
4411 
4412 void IsTagOp::print(::mlir::OpAsmPrinter &printer) {
4413  printer << ' ' << getInput() << ' ';
4414  printer.printKeywordOrString(getFieldName());
4415  SmallVector<::llvm::StringRef, 1> elidedAttrs = {"fieldIndex"};
4416  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4417  printer << " : " << getInput().getType();
4418 }
4419 
4420 ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
4421  auto *context = parser.getContext();
4422 
4423  OpAsmParser::UnresolvedOperand input;
4424  std::string fieldName;
4425  Type inputType;
4426  if (parser.parseOperand(input) || parser.parseKeywordOrString(&fieldName) ||
4427  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4428  parser.parseType(inputType))
4429  return failure();
4430 
4431  if (parser.resolveOperand(input, inputType, result.operands))
4432  return failure();
4433 
4434  auto enumType = type_dyn_cast<FEnumType>(inputType);
4435  if (!enumType)
4436  return parser.emitError(parser.getNameLoc(),
4437  "input must be enum type, got ")
4438  << inputType;
4439  auto fieldIndex = enumType.getElementIndex(fieldName);
4440  if (!fieldIndex)
4441  return parser.emitError(parser.getNameLoc(),
4442  "unknown field " + fieldName + " in enum type ")
4443  << enumType;
4444 
4445  result.addAttribute(
4446  "fieldIndex",
4447  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4448 
4449  result.addTypes(UIntType::get(context, 1, /*isConst=*/false));
4450 
4451  return success();
4452 }
4453 
4454 FIRRTLType IsTagOp::inferReturnType(ValueRange operands,
4455  ArrayRef<NamedAttribute> attrs,
4456  std::optional<Location> loc) {
4457  return UIntType::get(operands[0].getContext(), 1,
4458  isConst(operands[0].getType()));
4459 }
4460 
4461 template <typename OpTy>
4462 ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) {
4463  auto *context = parser.getContext();
4464 
4465  OpAsmParser::UnresolvedOperand input;
4466  std::string fieldName;
4467  Type inputType;
4468  if (parser.parseOperand(input) || parser.parseLSquare() ||
4469  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4470  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4471  parser.parseType(inputType))
4472  return failure();
4473 
4474  if (parser.resolveOperand(input, inputType, result.operands))
4475  return failure();
4476 
4477  auto bundleType = type_dyn_cast<typename OpTy::InputType>(inputType);
4478  if (!bundleType)
4479  return parser.emitError(parser.getNameLoc(),
4480  "input must be bundle type, got ")
4481  << inputType;
4482  auto fieldIndex = bundleType.getElementIndex(fieldName);
4483  if (!fieldIndex)
4484  return parser.emitError(parser.getNameLoc(),
4485  "unknown field " + fieldName + " in bundle type ")
4486  << bundleType;
4487 
4488  result.addAttribute(
4489  "fieldIndex",
4490  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4491 
4492  SmallVector<Type> inferredReturnTypes;
4493  if (failed(OpTy::inferReturnTypes(context, result.location, result.operands,
4494  result.attributes.getDictionary(context),
4495  result.getRawProperties(), result.regions,
4496  inferredReturnTypes)))
4497  return failure();
4498  result.addTypes(inferredReturnTypes);
4499 
4500  return success();
4501 }
4502 
4503 ParseResult SubtagOp::parse(OpAsmParser &parser, OperationState &result) {
4504  auto *context = parser.getContext();
4505 
4506  OpAsmParser::UnresolvedOperand input;
4507  std::string fieldName;
4508  Type inputType;
4509  if (parser.parseOperand(input) || parser.parseLSquare() ||
4510  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4511  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4512  parser.parseType(inputType))
4513  return failure();
4514 
4515  if (parser.resolveOperand(input, inputType, result.operands))
4516  return failure();
4517 
4518  auto enumType = type_dyn_cast<FEnumType>(inputType);
4519  if (!enumType)
4520  return parser.emitError(parser.getNameLoc(),
4521  "input must be enum type, got ")
4522  << inputType;
4523  auto fieldIndex = enumType.getElementIndex(fieldName);
4524  if (!fieldIndex)
4525  return parser.emitError(parser.getNameLoc(),
4526  "unknown field " + fieldName + " in enum type ")
4527  << enumType;
4528 
4529  result.addAttribute(
4530  "fieldIndex",
4531  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4532 
4533  SmallVector<Type> inferredReturnTypes;
4534  if (failed(SubtagOp::inferReturnTypes(
4535  context, result.location, result.operands,
4536  result.attributes.getDictionary(context), result.getRawProperties(),
4537  result.regions, inferredReturnTypes)))
4538  return failure();
4539  result.addTypes(inferredReturnTypes);
4540 
4541  return success();
4542 }
4543 
4544 ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4545  return parseSubfieldLikeOp<SubfieldOp>(parser, result);
4546 }
4547 ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4548  return parseSubfieldLikeOp<OpenSubfieldOp>(parser, result);
4549 }
4550 
4551 template <typename OpTy>
4552 static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) {
4553  printer << ' ' << op.getInput() << '[';
4554  printer.printKeywordOrString(op.getFieldName());
4555  printer << ']';
4556  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4557  elidedAttrs.push_back("fieldIndex");
4558  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4559  printer << " : " << op.getInput().getType();
4560 }
4561 void SubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4562  return printSubfieldLikeOp<SubfieldOp>(*this, printer);
4563 }
4564 void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4565  return printSubfieldLikeOp<OpenSubfieldOp>(*this, printer);
4566 }
4567 
4568 void SubtagOp::print(::mlir::OpAsmPrinter &printer) {
4569  printer << ' ' << getInput() << '[';
4570  printer.printKeywordOrString(getFieldName());
4571  printer << ']';
4572  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4573  elidedAttrs.push_back("fieldIndex");
4574  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4575  printer << " : " << getInput().getType();
4576 }
4577 
4578 template <typename OpTy>
4579 static LogicalResult verifySubfieldLike(OpTy op) {
4580  if (op.getFieldIndex() >=
4581  firrtl::type_cast<typename OpTy::InputType>(op.getInput().getType())
4582  .getNumElements())
4583  return op.emitOpError("subfield element index is greater than the number "
4584  "of fields in the bundle type");
4585  return success();
4586 }
4587 LogicalResult SubfieldOp::verify() {
4588  return verifySubfieldLike<SubfieldOp>(*this);
4589 }
4590 LogicalResult OpenSubfieldOp::verify() {
4591  return verifySubfieldLike<OpenSubfieldOp>(*this);
4592 }
4593 
4594 LogicalResult SubtagOp::verify() {
4595  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4596  return emitOpError("subfield element index is greater than the number "
4597  "of fields in the bundle type");
4598  return success();
4599 }
4600 
/// Return true if the specified operation has a constant value. This trivially
/// checks for `firrtl.constant` and friends, but also looks through subaccesses
/// and correctly handles wires driven with only constant values.
///
/// Implementation: a worklist walk starting at `op` that stops as soon as a
/// non-constant op is found:
///   - ConstantOp, SpecialConstantOp and AggregateConstantOp are constant.
///   - WireOp, SubindexOp and SubfieldOp are looked through by queueing the
///     *users* of their result (not the ops that drive them).
///   - NodeOp, AsSIntPrimOp and AsUIntPrimOp queue their input's defining op
///     but also clear `constant`, so reaching one always yields false.
///   - Every other op (the Default case) yields false.
/// NOTE(review): because of the node/cast case, and because any user of a wire
/// not listed above (presumably including connects) hits the Default case,
/// the "looks through" / "wires driven with only constant values" wording above
/// appears stronger than what the code does -- confirm the intended semantics.
/// Simply dropping `constant = false` from the node/cast case is not safe as-is:
/// a node reading a wire would re-queue the wire, which re-queues its users
/// (including that node), so the walk could cycle forever.
bool firrtl::isConstant(Operation *op) {
  // Worklist of ops that need to be examined that should all be constant in
  // order for the input operation to be constant.
  SmallVector<Operation *, 8> worklist({op});

  // Mutable state indicating if this op is a constant. Assume it is a constant
  // and look for counterexamples.
  bool constant = true;

  // While we haven't found a counterexample and there are still ops in the
  // worklist, pull ops off the worklist. If it provides a counterexample, set
  // the `constant` to false (and exit on the next loop iteration). Otherwise,
  // look through the op or spawn off more ops to look at.
  while (constant && !(worklist.empty()))
    TypeSwitch<Operation *>(worklist.pop_back_val())
        .Case<NodeOp, AsSIntPrimOp, AsUIntPrimOp>([&](auto op) {
          // The defining op is queued, but `constant = false` below ends the
          // loop before it is ever examined.
          if (auto definingOp = op.getInput().getDefiningOp())
            worklist.push_back(definingOp);
          constant = false;
        })
        .Case<WireOp, SubindexOp, SubfieldOp>([&](auto op) {
          // Look through by examining every op that uses this result.
          for (auto &use : op.getResult().getUses())
            worklist.push_back(use.getOwner());
        })
        // Constant-producing ops contribute nothing further.
        .Case<ConstantOp, SpecialConstantOp, AggregateConstantOp>([](auto) {})
        .Default([&](auto) { constant = false; });

  return constant;
}
4633 
4634 /// Return true if the specified value is a constant. This trivially checks for
4635 /// `firrtl.constant` and friends, but also looks through subaccesses and
4636 /// correctly handles wires driven with only constant values.
4637 bool firrtl::isConstant(Value value) {
4638  if (auto *op = value.getDefiningOp())
4639  return isConstant(op);
4640  return false;
4641 }
4642 
4643 LogicalResult ConstCastOp::verify() {
4644  if (!areTypesConstCastable(getResult().getType(), getInput().getType()))
4645  return emitOpError() << getInput().getType()
4646  << " is not 'const'-castable to "
4647  << getResult().getType();
4648  return success();
4649 }
4650 
4651 FIRRTLType SubfieldOp::inferReturnType(ValueRange operands,
4652  ArrayRef<NamedAttribute> attrs,
4653  std::optional<Location> loc) {
4654  auto inType = type_cast<BundleType>(operands[0].getType());
4655  auto fieldIndex =
4656  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4657 
4658  if (fieldIndex >= inType.getNumElements())
4659  return emitInferRetTypeError(loc,
4660  "subfield element index is greater than the "
4661  "number of fields in the bundle type");
4662 
4663  // SubfieldOp verifier checks that the field index is valid with number of
4664  // subelements.
4665  return inType.getElementTypePreservingConst(fieldIndex);
4666 }
4667 
4668 FIRRTLType OpenSubfieldOp::inferReturnType(ValueRange operands,
4669  ArrayRef<NamedAttribute> attrs,
4670  std::optional<Location> loc) {
4671  auto inType = type_cast<OpenBundleType>(operands[0].getType());
4672  auto fieldIndex =
4673  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4674 
4675  if (fieldIndex >= inType.getNumElements())
4676  return emitInferRetTypeError(loc,
4677  "subfield element index is greater than the "
4678  "number of fields in the bundle type");
4679 
4680  // OpenSubfieldOp verifier checks that the field index is valid with number of
4681  // subelements.
4682  return inType.getElementTypePreservingConst(fieldIndex);
4683 }
4684 
4685 bool SubfieldOp::isFieldFlipped() {
4686  BundleType bundle = getInput().getType();
4687  return bundle.getElement(getFieldIndex()).isFlip;
4688 }
4689 bool OpenSubfieldOp::isFieldFlipped() {
4690  auto bundle = getInput().getType();
4691  return bundle.getElement(getFieldIndex()).isFlip;
4692 }
4693 
4694 FIRRTLType SubindexOp::inferReturnType(ValueRange operands,
4695  ArrayRef<NamedAttribute> attrs,
4696  std::optional<Location> loc) {
4697  Type inType = operands[0].getType();
4698  auto fieldIdx =
4699  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4700 
4701  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4702  if (fieldIdx < vectorType.getNumElements())
4703  return vectorType.getElementTypePreservingConst();
4704  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4705  "' in vector type ", inType);
4706  }
4707 
4708  return emitInferRetTypeError(loc, "subindex requires vector operand");
4709 }
4710 
4711 FIRRTLType OpenSubindexOp::inferReturnType(ValueRange operands,
4712  ArrayRef<NamedAttribute> attrs,
4713  std::optional<Location> loc) {
4714  Type inType = operands[0].getType();
4715  auto fieldIdx =
4716  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4717 
4718  if (auto vectorType = type_dyn_cast<OpenVectorType>(inType)) {
4719  if (fieldIdx < vectorType.getNumElements())
4720  return vectorType.getElementTypePreservingConst();
4721  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4722  "' in vector type ", inType);
4723  }
4724 
4725  return emitInferRetTypeError(loc, "subindex requires vector operand");
4726 }
4727 
4728 FIRRTLType SubtagOp::inferReturnType(ValueRange operands,
4729  ArrayRef<NamedAttribute> attrs,
4730  std::optional<Location> loc) {
4731  auto inType = type_cast<FEnumType>(operands[0].getType());
4732  auto fieldIndex =
4733  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4734 
4735  if (fieldIndex >= inType.getNumElements())
4736  return emitInferRetTypeError(loc,
4737  "subtag element index is greater than the "
4738  "number of fields in the enum type");
4739 
4740  // SubtagOp verifier checks that the field index is valid with number of
4741  // subelements.
4742  auto elementType = inType.getElement(fieldIndex).type;
4743  return elementType.getConstType(elementType.isConst() || inType.isConst());
4744 }
4745 
4746 FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
4747  ArrayRef<NamedAttribute> attrs,
4748  std::optional<Location> loc) {
4749  auto inType = operands[0].getType();
4750  auto indexType = operands[1].getType();
4751 
4752  if (!type_isa<UIntType>(indexType))
4753  return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
4754  indexType);
4755 
4756  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4757  if (isConst(indexType))
4758  return vectorType.getElementTypePreservingConst();
4759  return vectorType.getElementType().getAllConstDroppedType();
4760  }
4761 
4762  return emitInferRetTypeError(loc, "subaccess requires vector operand, not ",
4763  inType);
4764 }
4765 
4766 FIRRTLType TagExtractOp::inferReturnType(ValueRange operands,
4767  ArrayRef<NamedAttribute> attrs,
4768  std::optional<Location> loc) {
4769  auto inType = type_cast<FEnumType>(operands[0].getType());
4770  auto i = llvm::Log2_32_Ceil(inType.getNumElements());
4771  return UIntType::get(inType.getContext(), i);
4772 }
4773 
4774 ParseResult MultibitMuxOp::parse(OpAsmParser &parser, OperationState &result) {
4775  OpAsmParser::UnresolvedOperand index;
4776  SmallVector<OpAsmParser::UnresolvedOperand, 16> inputs;
4777  Type indexType, elemType;
4778 
4779  if (parser.parseOperand(index) || parser.parseComma() ||
4780  parser.parseOperandList(inputs) ||
4781  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4782  parser.parseType(indexType) || parser.parseComma() ||
4783  parser.parseType(elemType))
4784  return failure();
4785 
4786  if (parser.resolveOperand(index, indexType, result.operands))
4787  return failure();
4788 
4789  result.addTypes(elemType);
4790 
4791  return parser.resolveOperands(inputs, elemType, result.operands);
4792 }
4793 
4794 void MultibitMuxOp::print(OpAsmPrinter &p) {
4795  p << " " << getIndex() << ", ";
4796  p.printOperands(getInputs());
4797  p.printOptionalAttrDict((*this)->getAttrs());
4798  p << " : " << getIndex().getType() << ", " << getType();
4799 }
4800 
4801 FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
4802  ArrayRef<NamedAttribute> attrs,
4803  std::optional<Location> loc) {
4804  if (operands.size() < 2)
4805  return emitInferRetTypeError(loc, "at least one input is required");
4806 
4807  // Check all mux inputs have the same type.
4808  if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
4809  return operands[1].getType() == op.getType();
4810  }))
4811  return emitInferRetTypeError(loc, "all inputs must have the same type");
4812 
4813  return type_cast<FIRRTLType>(operands[1].getType());
4814 }
4815 
4816 //===----------------------------------------------------------------------===//
4817 // ObjectSubfieldOp
4818 //===----------------------------------------------------------------------===//
4819 
    MLIRContext *context, std::optional<mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
  // Delegate to the single-type inferReturnType with the raw attribute list.
  // A null result (presumably after that function emitted a diagnostic) is
  // reported as failure. `context`, `properties` and `regions` are unused.
  auto type = inferReturnType(operands, attributes.getValue(), location);
  if (!type)
    return failure();
  inferredReturnTypes.push_back(type);
  return success();
}
4830 
4831 Type ObjectSubfieldOp::inferReturnType(ValueRange operands,
4832  ArrayRef<NamedAttribute> attrs,
4833  std::optional<Location> loc) {
4834  auto classType = dyn_cast<ClassType>(operands[0].getType());
4835  if (!classType)
4836  return emitInferRetTypeError(loc, "base object is not a class");
4837 
4838  auto index = getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4839  if (classType.getNumElements() <= index)
4840  return emitInferRetTypeError(loc, "element index is greater than the "
4841  "number of fields in the object");
4842 
4843  return classType.getElement(index).type;
4844 }
4845 
4846 void ObjectSubfieldOp::print(OpAsmPrinter &p) {
4847  auto input = getInput();
4848  auto classType = input.getType();
4849  p << ' ' << input << "[";
4850  p.printKeywordOrString(classType.getElement(getIndex()).name);
4851  p << "]";
4852  p.printOptionalAttrDict((*this)->getAttrs(), std::array{StringRef("index")});
4853  p << " : " << classType;
4854 }
4855 
4856 ParseResult ObjectSubfieldOp::parse(OpAsmParser &parser,
4857  OperationState &result) {
4858  auto *context = parser.getContext();
4859 
4860  OpAsmParser::UnresolvedOperand input;
4861  std::string fieldName;
4862  ClassType inputType;
4863  if (parser.parseOperand(input) || parser.parseLSquare() ||
4864  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4865  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4866  parser.parseType(inputType) ||
4867  parser.resolveOperand(input, inputType, result.operands))
4868  return failure();
4869 
4870  auto index = inputType.getElementIndex(fieldName);
4871  if (!index)
4872  return parser.emitError(parser.getNameLoc(),
4873  "unknown field " + fieldName + " in class type ")
4874  << inputType;
4875  result.addAttribute("index",
4876  IntegerAttr::get(IntegerType::get(context, 32), *index));
4877 
4878  SmallVector<Type> inferredReturnTypes;
4879  if (failed(inferReturnTypes(context, result.location, result.operands,
4880  result.attributes.getDictionary(context),
4881  result.getRawProperties(), result.regions,
4882  inferredReturnTypes)))
4883  return failure();
4884  result.addTypes(inferredReturnTypes);
4885 
4886  return success();
4887 }
4888 
4889 //===----------------------------------------------------------------------===//
4890 // Binary Primitives
4891 //===----------------------------------------------------------------------===//
4892 
4893 /// If LHS and RHS are both UInt or SInt types, the return true and fill in the
4894 /// width of them if known. If unknown, return -1 for the widths.
4895 /// The constness of the result is also returned, where if both lhs and rhs are
4896 /// const, then the result is const.
4897 ///
4898 /// On failure, this reports and error and returns false. This function should
4899 /// not be used if you don't want an error reported.
4900 static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth,
4901  int32_t &rhsWidth, bool &isConstResult,
4902  std::optional<Location> loc) {
4903  // Must have two integer types with the same signedness.
4904  auto lhsi = type_dyn_cast<IntType>(lhs);
4905  auto rhsi = type_dyn_cast<IntType>(rhs);
4906  if (!lhsi || !rhsi || lhsi.isSigned() != rhsi.isSigned()) {
4907  if (loc) {
4908  if (lhsi && !rhsi)
4909  mlir::emitError(*loc, "second operand must be an integer type, not ")
4910  << rhs;
4911  else if (!lhsi && rhsi)
4912  mlir::emitError(*loc, "first operand must be an integer type, not ")
4913  << lhs;
4914  else if (!lhsi && !rhsi)
4915  mlir::emitError(*loc, "operands must be integer types, not ")
4916  << lhs << " and " << rhs;
4917  else
4918  mlir::emitError(*loc, "operand signedness must match");
4919  }
4920  return false;
4921  }
4922 
4923  lhsWidth = lhsi.getWidthOrSentinel();
4924  rhsWidth = rhsi.getWidthOrSentinel();
4925  isConstResult = lhsi.isConst() && rhsi.isConst();
4926  return true;
4927 }
4928 
4929 LogicalResult impl::verifySameOperandsIntTypeKind(Operation *op) {
4930  assert(op->getNumOperands() == 2 &&
4931  "SameOperandsIntTypeKind on non-binary op");
4932  int32_t lhsWidth, rhsWidth;
4933  bool isConstResult;
4934  return success(isSameIntTypeKind(op->getOperand(0).getType(),
4935  op->getOperand(1).getType(), lhsWidth,
4936  rhsWidth, isConstResult, op->getLoc()));
4937 }
4938 
4939 LogicalResult impl::validateBinaryOpArguments(ValueRange operands,
4940  ArrayRef<NamedAttribute> attrs,
4941  Location loc) {
4942  if (operands.size() != 2 || !attrs.empty()) {
4943  mlir::emitError(loc, "operation requires two operands and no constants");
4944  return failure();
4945  }
4946  return success();
4947 }
4948 
                                   std::optional<Location> loc) {
  // Both operands must be integers of the same signedness (checked, and
  // diagnosed when `loc` is set, by isSameIntTypeKind). The result keeps that
  // signedness and is one bit wider than the wider operand; if either width
  // is unknown the result width is unknown (-1). The result is const only
  // when both operands are const.
  int32_t lhsWidth, rhsWidth, resultWidth = -1;
  bool isConstResult = false;
  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
    return {};

  if (lhsWidth != -1 && rhsWidth != -1)
    resultWidth = std::max(lhsWidth, rhsWidth) + 1;
  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
                      isConstResult);
}
4961 
4962 FIRRTLType MulPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4963  std::optional<Location> loc) {
4964  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4965  bool isConstResult = false;
4966  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4967  return {};
4968 
4969  if (lhsWidth != -1 && rhsWidth != -1)
4970  resultWidth = lhsWidth + rhsWidth;
4971 
4972  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4973  isConstResult);
4974 }
4975 
4976 FIRRTLType DivPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4977  std::optional<Location> loc) {
4978  int32_t lhsWidth, rhsWidth;
4979  bool isConstResult = false;
4980  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4981  return {};
4982 
4983  // For unsigned, the width is the width of the numerator on the LHS.
4984  if (type_isa<UIntType>(lhs))
4985  return UIntType::get(lhs.getContext(), lhsWidth, isConstResult);
4986 
4987  // For signed, the width is the width of the numerator on the LHS, plus 1.
4988  int32_t resultWidth = lhsWidth != -1 ? lhsWidth + 1 : -1;
4989  return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
4990 }
4991 
4992 FIRRTLType RemPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4993  std::optional<Location> loc) {
4994  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4995  bool isConstResult = false;
4996  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4997  return {};
4998 
4999  if (lhsWidth != -1 && rhsWidth != -1)
5000  resultWidth = std::min(lhsWidth, rhsWidth);
5001  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
5002  isConstResult);
5003 }
5004 
5006  std::optional<Location> loc) {
 // Infer an unsigned result type for a binary op on two integers that pass
 // isSameIntTypeKind. The result is always a UInt: as wide as the wider
 // operand when both widths are known, otherwise the unknown width (-1).
 // Constness is whatever isSameIntTypeKind reports in `isConstResult`.
5007  int32_t lhsWidth, rhsWidth, resultWidth = -1;
5008  bool isConstResult = false;
5009  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5010  return {};
5011 
5012  if (lhsWidth != -1 && rhsWidth != -1) {
5013  resultWidth = std::max(lhsWidth, rhsWidth);
 // If an operand already is exactly the result type (unsigned, full
 // result width, matching constness), return that type as-is instead of
 // building a new one.
5014  if (lhsWidth == resultWidth && lhs.isConst() == isConstResult &&
5015  isa<UIntType>(lhs))
5016  return lhs;
5017  if (rhsWidth == resultWidth && rhs.isConst() == isConstResult &&
5018  isa<UIntType>(rhs))
5019  return rhs;
5020  }
 // Signed operands and the unknown-width case end up here.
5021  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
5022 }
5023 
5025  std::optional<Location> loc) {
 // Infer the result type of an element-wise binary op on two vectors. Both
 // operands must be FVectorType with the same number of elements. The
 // element type is inferred with impl::inferBitwiseResult.
 // NOTE(review): the non-vector and length-mismatch paths below return a
 // null type without emitting a diagnostic (unlike the emitInferRetTypeError
 // paths elsewhere) -- confirm callers report these failures.
5026  if (!type_isa<FVectorType>(lhs) || !type_isa<FVectorType>(rhs))
5027  return {};
5028 
5029  auto lhsVec = type_cast<FVectorType>(lhs);
5030  auto rhsVec = type_cast<FVectorType>(rhs);
5031 
5032  if (lhsVec.getNumElements() != rhsVec.getNumElements())
5033  return {};
5034 
 // Element types are taken with constness preserved so the element-level
 // inference can see const elements inside non-const vectors.
5035  auto elemType =
5036  impl::inferBitwiseResult(lhsVec.getElementTypePreservingConst(),
5037  rhsVec.getElementTypePreservingConst(), loc);
5038  if (!elemType)
5039  return {};
5040  auto elemBaseType = type_cast<FIRRTLBaseType>(elemType);
 // The vector itself is const only if both operand vectors and the inferred
 // element type are const.
5041  return FVectorType::get(elemBaseType, lhsVec.getNumElements(),
5042  lhsVec.isConst() && rhsVec.isConst() &&
5043  elemBaseType.isConst());
5044 }
5045 
5047  std::optional<Location> loc) {
 // The result is always a 1-bit UInt, const only when both operands are
 // const. Operand kinds and widths are not checked here (presumably they are
 // validated elsewhere -- confirm), and `loc` is unused.
5048  return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
5049 }
5050 
5051 FIRRTLType CatPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5052  std::optional<Location> loc) {
5053  int32_t lhsWidth, rhsWidth, resultWidth = -1;
5054  bool isConstResult = false;
5055  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5056  return {};
5057 
5058  if (lhsWidth != -1 && rhsWidth != -1)
5059  resultWidth = lhsWidth + rhsWidth;
5060  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
5061 }
5062 
5063 FIRRTLType DShlPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5064  std::optional<Location> loc) {
5065  auto lhsi = type_dyn_cast<IntType>(lhs);
5066  auto rhsui = type_dyn_cast<UIntType>(rhs);
5067  if (!rhsui || !lhsi)
5068  return emitInferRetTypeError(
5069  loc, "first operand should be integer, second unsigned int");
5070 
5071  // If the left or right has unknown result type, then the operation does
5072  // too.
5073  auto width = lhsi.getWidthOrSentinel();
5074  if (width == -1 || !rhsui.getWidth().has_value()) {
5075  width = -1;
5076  } else {
5077  auto amount = *rhsui.getWidth();
5078  if (amount >= 32)
5079  return emitInferRetTypeError(loc,
5080  "shift amount too large: second operand of "
5081  "dshl is wider than 31 bits");
5082  int64_t newWidth = (int64_t)width + ((int64_t)1 << amount) - 1;
5083  if (newWidth > INT32_MAX)
5084  return emitInferRetTypeError(
5085  loc, "shift amount too large: first operand shifted by maximum "
5086  "amount exceeds maximum width");
5087  width = newWidth;
5088  }
5089  return IntType::get(lhs.getContext(), lhsi.isSigned(), width,
5090  lhsi.isConst() && rhsui.isConst());
5091 }
5092 
5093 FIRRTLType DShlwPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5094  std::optional<Location> loc) {
5095  auto lhsi = type_dyn_cast<IntType>(lhs);
5096  auto rhsu = type_dyn_cast<UIntType>(rhs);
5097  if (!lhsi || !rhsu)
5098  return emitInferRetTypeError(
5099  loc, "first operand should be integer, second unsigned int");
5100  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5101 }
5102 
5103 FIRRTLType DShrPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5104  std::optional<Location> loc) {
5105  auto lhsi = type_dyn_cast<IntType>(lhs);
5106  auto rhsu = type_dyn_cast<UIntType>(rhs);
5107  if (!lhsi || !rhsu)
5108  return emitInferRetTypeError(
5109  loc, "first operand should be integer, second unsigned int");
5110  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5111 }
5112 
5113 //===----------------------------------------------------------------------===//
5114 // Unary Primitives
5115 //===----------------------------------------------------------------------===//
5116 
5117 LogicalResult impl::validateUnaryOpArguments(ValueRange operands,
5118  ArrayRef<NamedAttribute> attrs,
5119  Location loc) {
5120  if (operands.size() != 1 || !attrs.empty()) {
5121  mlir::emitError(loc, "operation requires one operand and no constants");
5122  return failure();
5123  }
5124  return success();
5125 }
5126 
5127 FIRRTLType
5128 SizeOfIntrinsicOp::inferUnaryReturnType(FIRRTLType input,
5129  std::optional<Location> loc) {
5130  return UIntType::get(input.getContext(), 32);
5131 }
5132 
5133 FIRRTLType AsSIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5134  std::optional<Location> loc) {
5135  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5136  if (!base)
5137  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5138  int32_t width = base.getBitWidthOrSentinel();
5139  if (width == -2)
5140  return emitInferRetTypeError(loc, "operand must be a scalar type");
5141  return SIntType::get(input.getContext(), width, base.isConst());
5142 }
5143 
5144 FIRRTLType AsUIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5145  std::optional<Location> loc) {
5146  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5147  if (!base)
5148  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5149  int32_t width = base.getBitWidthOrSentinel();
5150  if (width == -2)
5151  return emitInferRetTypeError(loc, "operand must be a scalar type");
5152  return UIntType::get(input.getContext(), width, base.isConst());
5153 }
5154 
5155 FIRRTLType
5156 AsAsyncResetPrimOp::inferUnaryReturnType(FIRRTLType input,
5157  std::optional<Location> loc) {
5158  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5159  if (!base)
5160  return emitInferRetTypeError(loc,
5161  "operand must be single bit scalar base type");
5162  int32_t width = base.getBitWidthOrSentinel();
5163  if (width == -2 || width == 0 || width > 1)
5164  return emitInferRetTypeError(loc, "operand must be single bit scalar type");
5165  return AsyncResetType::get(input.getContext(), base.isConst());
5166 }
5167 
5168 FIRRTLType AsClockPrimOp::inferUnaryReturnType(FIRRTLType input,
5169  std::optional<Location> loc) {
5170  return ClockType::get(input.getContext(), isConst(input));
5171 }
5172 
5173 FIRRTLType CvtPrimOp::inferUnaryReturnType(FIRRTLType input,
5174  std::optional<Location> loc) {
5175  if (auto uiType = type_dyn_cast<UIntType>(input)) {
5176  auto width = uiType.getWidthOrSentinel();
5177  if (width != -1)
5178  ++width;
5179  return SIntType::get(input.getContext(), width, uiType.isConst());
5180  }
5181 
5182  if (type_isa<SIntType>(input))
5183  return input;
5184 
5185  return emitInferRetTypeError(loc, "operand must have integer type");
5186 }
5187 
5188 FIRRTLType NegPrimOp::inferUnaryReturnType(FIRRTLType input,
5189  std::optional<Location> loc) {
5190  auto inputi = type_dyn_cast<IntType>(input);
5191  if (!inputi)
5192  return emitInferRetTypeError(loc, "operand must have integer type");
5193  int32_t width = inputi.getWidthOrSentinel();
5194  if (width != -1)
5195  ++width;
5196  return SIntType::get(input.getContext(), width, inputi.isConst());
5197 }
5198 
5199 FIRRTLType NotPrimOp::inferUnaryReturnType(FIRRTLType input,
5200  std::optional<Location> loc) {
5201  auto inputi = type_dyn_cast<IntType>(input);
5202  if (!inputi)
5203  return emitInferRetTypeError(loc, "operand must have integer type");
5204  if (isa<UIntType>(inputi))
5205  return inputi;
5206  return UIntType::get(input.getContext(), inputi.getWidthOrSentinel(),
5207  inputi.isConst());
5208 }
5209 
5211  std::optional<Location> loc) {
 // The result is always a 1-bit UInt, const exactly when the input is const.
 // The input's kind and width are not checked here, and `loc` is unused.
5212  return UIntType::get(input.getContext(), 1, isConst(input));
5213 }
5214 
5215 //===----------------------------------------------------------------------===//
5216 // Other Operations
5217 //===----------------------------------------------------------------------===//
5218 
5219 LogicalResult BitsPrimOp::validateArguments(ValueRange operands,
5220  ArrayRef<NamedAttribute> attrs,
5221  Location loc) {
5222  if (operands.size() != 1 || attrs.size() != 2) {
5223  mlir::emitError(loc, "operation requires one operand and two constants");
5224  return failure();
5225  }
5226  return success();
5227 }
5228 
5229 FIRRTLType BitsPrimOp::inferReturnType(ValueRange operands,
5230  ArrayRef<NamedAttribute> attrs,
5231  std::optional<Location> loc) {
5232  auto input = operands[0].getType();
5233  auto high = getAttr<IntegerAttr>(attrs, "hi").getValue().getSExtValue();
5234  auto low = getAttr<IntegerAttr>(attrs, "lo").getValue().getSExtValue();
5235 
5236  auto inputi = type_dyn_cast<IntType>(input);
5237  if (!inputi)
5238  return emitInferRetTypeError(
5239  loc, "input type should be the int type but got ", input);
5240 
5241  // High must be >= low and both most be non-negative.
5242  if (high < low)
5243  return emitInferRetTypeError(
5244  loc, "high must be equal or greater than low, but got high = ", high,
5245  ", low = ", low);
5246 
5247  if (low < 0)
5248  return emitInferRetTypeError(loc, "low must be non-negative but got ", low);
5249 
5250  // If the input has staticly known width, check it. Both and low must be
5251  // strictly less than width.
5252  int32_t width = inputi.getWidthOrSentinel();
5253  if (width != -1 && high >= width)
5254  return emitInferRetTypeError(
5255  loc,
5256  "high must be smaller than the width of input, but got high = ", high,
5257  ", width = ", width);
5258 
5259  return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
5260 }
5261 
5262 LogicalResult impl::validateOneOperandOneConst(ValueRange operands,
5263  ArrayRef<NamedAttribute> attrs,
5264  Location loc) {
5265  if (operands.size() != 1 || attrs.size() != 1) {
5266  mlir::emitError(loc, "operation requires one operand and one constant");
5267  return failure();
5268  }
5269  return success();
5270 }
5271 
5272 FIRRTLType HeadPrimOp::inferReturnType(ValueRange operands,
5273  ArrayRef<NamedAttribute> attrs,
5274  std::optional<Location> loc) {
5275  auto input = operands[0].getType();
5276  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5277 
5278  auto inputi = type_dyn_cast<IntType>(input);
5279  if (amount < 0 || !inputi)
5280  return emitInferRetTypeError(
5281  loc, "operand must have integer type and amount must be >= 0");
5282 
5283  int32_t width = inputi.getWidthOrSentinel();
5284  if (width != -1 && amount > width)
5285  return emitInferRetTypeError(loc, "amount larger than input width");
5286 
5287  return UIntType::get(input.getContext(), amount, inputi.isConst());
5288 }
5289 
5290 LogicalResult MuxPrimOp::validateArguments(ValueRange operands,
5291  ArrayRef<NamedAttribute> attrs,
5292  Location loc) {
5293  if (operands.size() != 3 || attrs.size() != 0) {
5294  mlir::emitError(loc, "operation requires three operands and no constants");
5295  return failure();
5296  }
5297  return success();
5298 }
5299 
5300 /// Infer the result type for a multiplexer given its two operand types, which
5301 /// may be aggregates.
5302 ///
5303 /// This essentially performs a pairwise comparison of fields and elements, as
5304 /// follows:
5305 /// - Identical operands inferred to their common type
5306 /// - Integer operands inferred to the larger one if both have a known width, a
5307 /// widthless integer otherwise.
5308 /// - Vectors inferred based on the element type.
5309 /// - Bundles inferred in a pairwise fashion based on the field types.
///
/// `high` is the true-value arm and `low` the false-value arm; diagnostics
/// name them that way. Constness is kept only when `isConstCondition` holds:
/// identical arms drop all const qualifiers under a non-const condition, and
/// the outer type of a merged int/vector/bundle is const only if the
/// condition and both arms are const. Vectors must have equal element
/// counts, and bundles must agree field-by-field on count, name and flip.
/// Otherwise an error is emitted and a null type is returned.
5311  FIRRTLBaseType low,
5312  bool isConstCondition,
5313  std::optional<Location> loc) {
5314  // If the types are identical we're done.
5315  if (high == low)
5316  return isConstCondition ? low : low.getAllConstDroppedType();
5317 
5318  // The base types need to be equivalent.
5319  if (high.getTypeID() != low.getTypeID())
5320  return emitInferRetTypeError<FIRRTLBaseType>(
5321  loc, "incompatible mux operand types, true value type: ", high,
5322  ", false value type: ", low);
5323 
5324  bool outerTypeIsConst = isConstCondition && low.isConst() && high.isConst();
5325 
5326  // Two different Int types can be compatible. If either has unknown width,
5327  // then return it. If both are known but different width, then return the
5328  // larger one.
 // (The TypeID check above guarantees both arms are the same int kind. On
 // equal known widths the comparison below picks `high`.)
5329  if (type_isa<IntType>(low)) {
5330  int32_t highWidth = high.getBitWidthOrSentinel();
5331  int32_t lowWidth = low.getBitWidthOrSentinel();
5332  if (lowWidth == -1)
5333  return low.getConstType(outerTypeIsConst);
5334  if (highWidth == -1)
5335  return high.getConstType(outerTypeIsConst);
5336  return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
5337  }
5338 
5339  // Infer vector types by comparing the element types.
 // Vectors of different lengths skip this branch and the bundle branch, and
 // end at the generic "invalid mux operand types" error at the bottom.
5340  auto highVector = type_dyn_cast<FVectorType>(high);
5341  auto lowVector = type_dyn_cast<FVectorType>(low);
5342  if (highVector && lowVector &&
5343  highVector.getNumElements() == lowVector.getNumElements()) {
5344  auto inner = inferMuxReturnType(highVector.getElementTypePreservingConst(),
5345  lowVector.getElementTypePreservingConst(),
5346  isConstCondition, loc);
5347  if (!inner)
5348  return {};
5349  return FVectorType::get(inner, lowVector.getNumElements(),
5350  outerTypeIsConst);
5351  }
5352 
5353  // Infer bundle types by inferring names in a pairwise fashion.
5354  auto highBundle = type_dyn_cast<BundleType>(high);
5355  auto lowBundle = type_dyn_cast<BundleType>(low);
5356  if (highBundle && lowBundle) {
5357  auto highElements = highBundle.getElements();
5358  auto lowElements = lowBundle.getElements();
5359  size_t numElements = highElements.size();
5360 
5361  SmallVector<BundleType::BundleElement> newElements;
5362  if (numElements == lowElements.size()) {
5363  bool failed = false;
5364  for (size_t i = 0; i < numElements; ++i) {
 // A name or flip mismatch stops the walk and falls through to the
 // bundle-field error below.
5365  if (highElements[i].name != lowElements[i].name ||
5366  highElements[i].isFlip != lowElements[i].isFlip) {
5367  failed = true;
5368  break;
5369  }
 // Keep the field's name and flip from `high` and infer its type
 // recursively. A nested failure has already been diagnosed.
5370  auto element = highElements[i];
5371  element.type = inferMuxReturnType(
5372  highBundle.getElementTypePreservingConst(i),
5373  lowBundle.getElementTypePreservingConst(i), isConstCondition, loc);
5374  if (!element.type)
5375  return {};
5376  newElements.push_back(element);
5377  }
5378  if (!failed)
5379  return BundleType::get(low.getContext(), newElements, outerTypeIsConst);
5380  }
5381  return emitInferRetTypeError<FIRRTLBaseType>(
5382  loc, "incompatible mux operand bundle fields, true value type: ", high,
5383  ", false value type: ", low);
5384  }
5385 
5386  // If we arrive here the types of the two mux arms are fundamentally
5387  // incompatible.
5388  return emitInferRetTypeError<FIRRTLBaseType>(
5389  loc, "invalid mux operand types, true value type: ", high,
5390  ", false value type: ", low);
5391 }
5392 
5393 FIRRTLType MuxPrimOp::inferReturnType(ValueRange operands,
5394  ArrayRef<NamedAttribute> attrs,
5395  std::optional<Location> loc) {
5396  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5397  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5398  if (!highType || !lowType)
5399  return emitInferRetTypeError(loc, "operands must be base type");
5400  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5401  loc);
5402 }
5403 
5404 FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
5405  ArrayRef<NamedAttribute> attrs,
5406  std::optional<Location> loc) {
5407  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5408  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5409  if (!highType || !lowType)
5410  return emitInferRetTypeError(loc, "operands must be base type");
5411  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5412  loc);
5413 }
5414 
5415 FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
5416  ArrayRef<NamedAttribute> attrs,
5417  std::optional<Location> loc) {
5418  SmallVector<FIRRTLBaseType> types;
5419  FIRRTLBaseType result;
5420  for (unsigned i = 1; i < 5; i++) {
5421  types.push_back(type_dyn_cast<FIRRTLBaseType>(operands[i].getType()));
5422  if (!types.back())
5423  return emitInferRetTypeError(loc, "operands must be base type");
5424  if (result) {
5425  result = inferMuxReturnType(result, types.back(),
5426  isConst(operands[0].getType()), loc);
5427  if (!result)
5428  return result;
5429  } else {
5430  result = types.back();
5431  }
5432  }
5433  return result;
5434 }
5435 
5436 FIRRTLType PadPrimOp::inferReturnType(ValueRange operands,
5437  ArrayRef<NamedAttribute> attrs,
5438  std::optional<Location> loc) {
5439  auto input = operands[0].getType();
5440  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5441 
5442  auto inputi = type_dyn_cast<IntType>(input);
5443  if (amount < 0 || !inputi)
5444  return emitInferRetTypeError(
5445  loc, "pad input must be integer and amount must be >= 0");
5446 
5447  int32_t width = inputi.getWidthOrSentinel();
5448  if (width == -1)
5449  return inputi;
5450 
5451  width = std::max<int32_t>(width, amount);
5452  return IntType::get(input.getContext(), inputi.isSigned(), width,
5453  inputi.isConst());
5454 }
5455 
5456 FIRRTLType ShlPrimOp::inferReturnType(ValueRange operands,
5457  ArrayRef<NamedAttribute> attrs,
5458  std::optional<Location> loc) {
5459  auto input = operands[0].getType();
5460  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5461 
5462  auto inputi = type_dyn_cast<IntType>(input);
5463  if (amount < 0 || !inputi)
5464  return emitInferRetTypeError(
5465  loc, "shl input must be integer and amount must be >= 0");
5466 
5467  int32_t width = inputi.getWidthOrSentinel();
5468  if (width != -1)
5469  width += amount;
5470 
5471  return IntType::get(input.getContext(), inputi.isSigned(), width,
5472  inputi.isConst());
5473 }
5474 
5475 FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
5476  ArrayRef<NamedAttribute> attrs,
5477  std::optional<Location> loc) {
5478  auto input = operands[0].getType();
5479  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5480 
5481  auto inputi = type_dyn_cast<IntType>(input);
5482  if (amount < 0 || !inputi)
5483  return emitInferRetTypeError(
5484  loc, "shr input must be integer and amount must be >= 0");
5485 
5486  int32_t width = inputi.getWidthOrSentinel();
5487  if (width != -1) {
5488  // UInt saturates at 0 bits, SInt at 1 bit
5489  int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
5490  width = std::max<int32_t>(minWidth, width - amount);
5491  }
5492 
5493  return IntType::get(input.getContext(), inputi.isSigned(), width,
5494  inputi.isConst());
5495 }
5496 
5497 FIRRTLType TailPrimOp::inferReturnType(ValueRange operands,
5498  ArrayRef<NamedAttribute> attrs,
5499  std::optional<Location> loc) {
5500  auto input = operands[0].getType();
5501  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5502 
5503  auto inputi = type_dyn_cast<IntType>(input);
5504  if (amount < 0 || !inputi)
5505  return emitInferRetTypeError(
5506  loc, "tail input must be integer and amount must be >= 0");
5507 
5508  int32_t width = inputi.getWidthOrSentinel();
5509  if (width != -1) {
5510  if (width < amount)
5511  return emitInferRetTypeError(
5512  loc, "amount must be less than or equal operand width");
5513  width -= amount;
5514  }
5515 
5516  return IntType::get(input.getContext(), false, width, inputi.isConst());
5517 }
5518 
5519 //===----------------------------------------------------------------------===//
5520 // VerbatimExprOp
5521 //===----------------------------------------------------------------------===//
5522 
5524  function_ref<void(Value, StringRef)> setNameFn) {
 // Suggest an SSA name for the result, derived from the verbatim text. One
 // leading backtick is stripped, then the longest prefix of alphanumerics
 // and underscores is kept. For example, "`FOO(x)" suggests "FOO". Text that
 // does not start with such a prefix gets no suggested name.
5525  // If the text is macro like, then use a pretty name. We only take the
5526  // text up to a weird character (like a paren) and currently ignore
5527  // parenthesized expressions.
5528  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5529  auto name = getText();
5530  // Ignore a leading ` in macro name.
5531  if (name.starts_with("`"))
5532  name = name.drop_front();
5533  name = name.take_while(isOkCharacter);
5534  if (!name.empty())
5535  setNameFn(getResult(), name);
5536 }
5537 
5538 //===----------------------------------------------------------------------===//
5539 // VerbatimWireOp
5540 //===----------------------------------------------------------------------===//
5541 
5543  function_ref<void(Value, StringRef)> setNameFn) {
 // Suggest an SSA name for the result, derived from the verbatim text. One
 // leading backtick is stripped, then the longest prefix of alphanumerics
 // and underscores is kept. For example, "`FOO(x)" suggests "FOO". Text that
 // does not start with such a prefix gets no suggested name.
5544  // If the text is macro like, then use a pretty name. We only take the
5545  // text up to a weird character (like a paren) and currently ignore
5546  // parenthesized expressions.
5547  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5548  auto name = getText();
5549  // Ignore a leading ` in macro name.
5550  if (name.starts_with("`"))
5551  name = name.drop_front();
5552  name = name.take_while(isOkCharacter);
5553  if (!name.empty())
5554  setNameFn(getResult(), name);
5555 }
5556 
5557 //===----------------------------------------------------------------------===//
5558 // DPICallIntrinsicOp
5559 //===----------------------------------------------------------------------===//
5560 
5561 static bool isTypeAllowedForDPI(Operation *op, Type type) {
5562  return !type.walk([&](firrtl::IntType intType) -> mlir::WalkResult {
5563  auto width = intType.getWidth();
5564  if (width < 0) {
5565  op->emitError() << "unknown width is not allowed for DPI";
5566  return WalkResult::interrupt();
5567  }
5568  if (width == 1 || width == 8 || width == 16 || width == 32 ||
5569  width >= 64)
5570  return WalkResult::advance();
5571  op->emitError()
5572  << "integer types used by DPI functions must have a "
5573  "specific bit width; "
5574  "it must be equal to 1(bit), 8(byte), 16(shortint), "
5575  "32(int), 64(longint) "
5576  "or greater than 64, but got "
5577  << intType;
5578  return WalkResult::interrupt();
5579  })
5580  .wasInterrupted();
5581 }
5582 
5583 LogicalResult DPICallIntrinsicOp::verify() {
5584  if (auto inputNames = getInputNames()) {
5585  if (getInputs().size() != inputNames->size())
5586  return emitError() << "inputNames has " << inputNames->size()
5587  << " elements but there are " << getInputs().size()
5588  << " input arguments";
5589  }
5590  if (auto outputName = getOutputName())
5591  if (getNumResults() == 0)
5592  return emitError() << "output name is given but there is no result";
5593 
5594  auto checkType = [this](Type type) {
5595  return isTypeAllowedForDPI(*this, type);
5596  };
5597  return success(llvm::all_of(this->getResultTypes(), checkType) &&
5598  llvm::all_of(this->getOperandTypes(), checkType));
5599 }
5600 
// Report field-level dataflow edges for this call as (result field, operand
// field) pairs. Calls with a clock report no edges. Without a clock, every
// field index visited in an operand is recorded as feeding every visited
// field index of every result.
// NOTE(review): the result-side walk below enumerates field indices from the
// operand's own `type`, not from the result's type. If a result's type can
// differ from an operand's type, the recorded result FieldRefs may not match
// the result's actual fields -- confirm whether the result type should be
// walked instead.
5601 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
5602 DPICallIntrinsicOp::computeDataFlow() {
5603  if (getClock())
5604  return {};
5605 
5606  SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
5607 
5608  for (auto operand : getOperands()) {
5609  auto type = type_cast<FIRRTLBaseType>(operand.getType());
5610  auto baseFieldRef = getFieldRefFromValue(operand);
 // Collect a FieldRef for each field index the walker visits in the
 // operand's type (presumably its ground/leaf fields -- confirm), relative
 // to the operand's own FieldRef.
5611  SmallVector<circt::FieldRef> operandFields;
5613  type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
5614  operandFields.push_back(baseFieldRef.getSubField(dstIndex));
5615  });
5616 
5617  // Record operand -> result dependency.
5618  for (auto result : getResults())
5620  type, [&](uint64_t dstIndex, FIRRTLBaseType t, bool dstIsFlip) {
5621  for (auto field : operandFields)
5622  deps.emplace_back(circt::FieldRef(result, dstIndex), field);
5623  });
5624  }
5625  return deps;
5626 }
5627 
5628 //===----------------------------------------------------------------------===//
5629 // Conversions to/from structs in the standard dialect.
5630 //===----------------------------------------------------------------------===//
5631 
5632 LogicalResult HWStructCastOp::verify() {
5633  // We must have a bundle and a struct, with matching pairwise fields
5634  BundleType bundleType;
5635  hw::StructType structType;
5636  if ((bundleType = type_dyn_cast<BundleType>(getOperand().getType()))) {
5637  structType = dyn_cast<hw::StructType>(getType());
5638  if (!structType)
5639  return emitError("result type must be a struct");
5640  } else if ((bundleType = type_dyn_cast<BundleType>(getType()))) {
5641  structType = dyn_cast<hw::StructType>(getOperand().getType());
5642  if (!structType)
5643  return emitError("operand type must be a struct");
5644  } else {
5645  return emitError("either source or result type must be a bundle type");
5646  }
5647 
5648  auto firFields = bundleType.getElements();
5649  auto hwFields = structType.getElements();
5650  if (firFields.size() != hwFields.size())
5651  return emitError("bundle and struct have different number of fields");
5652 
5653  for (size_t findex = 0, fend = firFields.size(); findex < fend; ++findex) {
5654  if (firFields[findex].name.getValue() != hwFields[findex].name)
5655  return emitError("field names don't match '")
5656  << firFields[findex].name.getValue() << "', '"
5657  << hwFields[findex].name.getValue() << "'";
5658  int64_t firWidth =
5659  FIRRTLBaseType(firFields[findex].type).getBitWidthOrSentinel();
5660  int64_t hwWidth = hw::getBitWidth(hwFields[findex].type);
5661  if (firWidth > 0 && hwWidth > 0 && firWidth != hwWidth)
5662  return emitError("size of field '")
5663  << hwFields[findex].name.getValue() << "' don't match " << firWidth
5664  << ", " << hwWidth;
5665  }
5666 
5667  return success();
5668 }
5669 
5670 LogicalResult BitCastOp::verify() {
5671  auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true);
5672  auto resTypeBits = getBitWidth(getType());
5673  if (inTypeBits.has_value() && resTypeBits.has_value()) {
5674  // Bitwidths must match for valid bit
5675  if (*inTypeBits == *resTypeBits) {
5676  // non-'const' cannot be casted to 'const'
5677  if (containsConst(getType()) && !isConst(getOperand().getType()))
5678  return emitError("cannot cast non-'const' input type ")
5679  << getOperand().getType() << " to 'const' result type "
5680  << getType();
5681  return success();
5682  }
5683  return emitError("the bitwidth of input (")
5684  << *inTypeBits << ") and result (" << *resTypeBits
5685  << ") don't match";
5686  }
5687  if (!inTypeBits.has_value())
5688  return emitError("bitwidth cannot be determined for input operand type ")
5689  << getInput().getType();
5690  return emitError("bitwidth cannot be determined for result type ")
5691  << getType();
5692 }
5693 
5694 //===----------------------------------------------------------------------===//
5695 // Custom attr-dict Directive that Elides Annotations
5696 //===----------------------------------------------------------------------===//
5697 
5698 /// Parse an optional attribute dictionary, adding an empty 'annotations'
5699 /// attribute if not specified.
5700 static ParseResult parseElideAnnotations(OpAsmParser &parser,
5701  NamedAttrList &resultAttrs) {
5702  auto result = parser.parseOptionalAttrDict(resultAttrs);
5703  if (!resultAttrs.get("annotations"))
5704  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
5705 
5706  return result;
5707 }
5708 
5709 static void printElideAnnotations(OpAsmPrinter &p, Operation *op,
5710  DictionaryAttr attr,
5711  ArrayRef<StringRef> extraElides = {}) {
5712  SmallVector<StringRef> elidedAttrs(extraElides.begin(), extraElides.end());
5713  // Elide "annotations" if it is empty.
5714  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
5715  elidedAttrs.push_back("annotations");
5716  // Elide "nameKind".
5717  elidedAttrs.push_back("nameKind");
5718 
5719  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5720 }
5721 
5722 /// Parse an optional attribute dictionary, adding empty 'annotations' and
5723 /// 'portAnnotations' attributes if not specified.
5724 static ParseResult parseElidePortAnnotations(OpAsmParser &parser,
5725  NamedAttrList &resultAttrs) {
5726  auto result = parseElideAnnotations(parser, resultAttrs);
5727 
5728  if (!resultAttrs.get("portAnnotations")) {
5729  SmallVector<Attribute, 16> portAnnotations(
5730  parser.getNumResults(), parser.getBuilder().getArrayAttr({}));
5731  resultAttrs.append("portAnnotations",
5732  parser.getBuilder().getArrayAttr(portAnnotations));
5733  }
5734  return result;
5735 }
5736 
5737 // Elide 'annotations' and 'portAnnotations' attributes if they are empty.
5738 static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op,
5739  DictionaryAttr attr,
5740  ArrayRef<StringRef> extraElides = {}) {
5741  SmallVector<StringRef, 2> elidedAttrs(extraElides.begin(), extraElides.end());
5742 
5743  if (llvm::all_of(op->getAttrOfType<ArrayAttr>("portAnnotations"),
5744  [&](Attribute a) { return cast<ArrayAttr>(a).empty(); }))
5745  elidedAttrs.push_back("portAnnotations");
5746  printElideAnnotations(p, op, attr, elidedAttrs);
5747 }
5748 
5749 //===----------------------------------------------------------------------===//
5750 // NameKind Custom Directive
5751 //===----------------------------------------------------------------------===//
5752 
5753 static ParseResult parseNameKind(OpAsmParser &parser,
5754  firrtl::NameKindEnumAttr &result) {
5755  StringRef keyword;
5756 
5757  if (!parser.parseOptionalKeyword(&keyword,
5758  {"interesting_name", "droppable_name"})) {
5759  auto kind = symbolizeNameKindEnum(keyword);
5760  result = NameKindEnumAttr::get(parser.getContext(), kind.value());
5761  return success();
5762  }
5763 
5764  // Default is droppable name.
5765  result =
5766  NameKindEnumAttr::get(parser.getContext(), NameKindEnum::DroppableName);
5767  return success();
5768 }
5769 
5770 static void printNameKind(OpAsmPrinter &p, Operation *op,
5771  firrtl::NameKindEnumAttr attr,
5772  ArrayRef<StringRef> extraElides = {}) {
5773  if (attr.getValue() != NameKindEnum::DroppableName)
5774  p << " " << stringifyNameKindEnum(attr.getValue());
5775 }
5776 
5777 //===----------------------------------------------------------------------===//
5778 // ImplicitSSAName Custom Directive
5779 //===----------------------------------------------------------------------===//
5780 
5781 static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser,
5782  NamedAttrList &resultAttrs) {
5783  if (parseElideAnnotations(parser, resultAttrs))
5784  return failure();
5785  inferImplicitSSAName(parser, resultAttrs);
5786  return success();
5787 }
5788 
/// Printer counterpart of parseFIRRTLImplicitSSAName. It builds a list of
/// attribute names to omit, which includes the forceable marker attribute.
/// elideImplicitSSAName receives that list and presumably adds `name` when it
/// can be recovered from the SSA result name -- confirm. The dictionary is
/// then printed by printElideAnnotations, which also omits empty
/// `annotations` and `nameKind`.
5789 static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op,
5790  DictionaryAttr attrs) {
5791  SmallVector<StringRef, 4> elides;
5793  elides.push_back(Forceable::getForceableAttrName());
5794  elideImplicitSSAName(p, op, attrs, elides);
5795  printElideAnnotations(p, op, attrs, elides);
5796 }
5797 
5798 //===----------------------------------------------------------------------===//
5799 // MemOp Custom attr-dict Directive
5800 //===----------------------------------------------------------------------===//
5801 
5802 static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs) {
5803  return parseElidePortAnnotations(parser, resultAttrs);
5804 }
5805 
5806 /// Always elide "ruw" and elide "annotations" if it exists or if it is empty.
5807 static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr) {
5808  // "ruw" and "inner_sym" is always elided.
5809  printElidePortAnnotations(p, op, attr, {"ruw", "inner_sym"});
5810 }
5811 
5812 //===----------------------------------------------------------------------===//
5813 // ClassInterface custom directive
5814 //===----------------------------------------------------------------------===//
5815 
5816 static ParseResult parseClassInterface(OpAsmParser &parser, Type &result) {
5817  ClassType type;
5818  if (ClassType::parseInterface(parser, type))
5819  return failure();
5820  result = type;
5821  return success();
5822 }
5823 
5824 static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type) {
5825  type.printInterface(p);
5826 }
5827 
5828 //===----------------------------------------------------------------------===//
5829 // Miscellaneous custom elision logic.
5830 //===----------------------------------------------------------------------===//
5831 
5832 static ParseResult parseElideEmptyName(OpAsmParser &p,
5833  NamedAttrList &resultAttrs) {
5834  auto result = p.parseOptionalAttrDict(resultAttrs);
5835  if (!resultAttrs.get("name"))
5836  resultAttrs.append("name", p.getBuilder().getStringAttr(""));
5837 
5838  return result;
5839 }
5840 
5841 static void printElideEmptyName(OpAsmPrinter &p, Operation *op,
5842  DictionaryAttr attr,
5843  ArrayRef<StringRef> extraElides = {}) {
5844  SmallVector<StringRef> elides(extraElides.begin(), extraElides.end());
5845  if (op->getAttrOfType<StringAttr>("name").getValue().empty())
5846  elides.push_back("name");
5847 
5848  p.printOptionalAttrDict(op->getAttrs(), elides);
5849 }
5850 
5851 static ParseResult parsePrintfAttrs(OpAsmParser &p,
5852  NamedAttrList &resultAttrs) {
5853  return parseElideEmptyName(p, resultAttrs);
5854 }
5855 
5856 static void printPrintfAttrs(OpAsmPrinter &p, Operation *op,
5857  DictionaryAttr attr) {
5858  printElideEmptyName(p, op, attr, {"formatString"});
5859 }
5860 
5861 static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5862  return parseElideEmptyName(p, resultAttrs);
5863 }
5864 
5865 static void printStopAttrs(OpAsmPrinter &p, Operation *op,
5866  DictionaryAttr attr) {
5867  printElideEmptyName(p, op, attr, {"exitCode"});
5868 }
5869 
5870 static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5871  return parseElideEmptyName(p, resultAttrs);
5872 }
5873 
5874 static void printVerifAttrs(OpAsmPrinter &p, Operation *op,
5875  DictionaryAttr attr) {
5876  printElideEmptyName(p, op, attr, {"message"});
5877 }
5878 
5879 //===----------------------------------------------------------------------===//
5880 // Various namers.
5881 //===----------------------------------------------------------------------===//
5882 
5883 static void genericAsmResultNames(Operation *op,
5884  OpAsmSetValueNameFn setNameFn) {
5885  // Many firrtl dialect operations have an optional 'name' attribute. If
5886  // present, use it.
5887  if (op->getNumResults() == 1)
5888  if (auto nameAttr = op->getAttrOfType<StringAttr>("name"))
5889  setNameFn(op->getResult(0), nameAttr.getValue());
5890 }
5891 
5893  genericAsmResultNames(*this, setNameFn);
5894 }
5895 
5897  genericAsmResultNames(*this, setNameFn);
5898 }
5899 
5901  genericAsmResultNames(*this, setNameFn);
5902 }
5903 
5905  genericAsmResultNames(*this, setNameFn);
5906 }
5908  genericAsmResultNames(*this, setNameFn);
5909 }
5911  genericAsmResultNames(*this, setNameFn);
5912 }
5914  genericAsmResultNames(*this, setNameFn);
5915 }
5917  genericAsmResultNames(*this, setNameFn);
5918 }
5920  genericAsmResultNames(*this, setNameFn);
5921 }
5923  genericAsmResultNames(*this, setNameFn);
5924 }
5926  genericAsmResultNames(*this, setNameFn);
5927 }
5929  genericAsmResultNames(*this, setNameFn);
5930 }
5932  genericAsmResultNames(*this, setNameFn);
5933 }
5935  genericAsmResultNames(*this, setNameFn);
5936 }
5938  genericAsmResultNames(*this, setNameFn);
5939 }
5941  genericAsmResultNames(*this, setNameFn);
5942 }
5944  genericAsmResultNames(*this, setNameFn);
5945 }
5947  genericAsmResultNames(*this, setNameFn);
5948 }
5950  genericAsmResultNames(*this, setNameFn);
5951 }
5953  genericAsmResultNames(*this, setNameFn);
5954 }
5956  genericAsmResultNames(*this, setNameFn);
5957 }
5959  genericAsmResultNames(*this, setNameFn);
5960 }
5962  genericAsmResultNames(*this, setNameFn);
5963 }
5965  genericAsmResultNames(*this, setNameFn);
5966 }
5968  genericAsmResultNames(*this, setNameFn);
5969 }
5971  OpAsmSetValueNameFn setNameFn) {
5972  genericAsmResultNames(*this, setNameFn);
5973 }
5975  genericAsmResultNames(*this, setNameFn);
5976 }
5978  genericAsmResultNames(*this, setNameFn);
5979 }
5981  genericAsmResultNames(*this, setNameFn);
5982 }
5984  genericAsmResultNames(*this, setNameFn);
5985 }
5987  genericAsmResultNames(*this, setNameFn);
5988 }
5990  genericAsmResultNames(*this, setNameFn);
5991 }
5993  genericAsmResultNames(*this, setNameFn);
5994 }
5996  genericAsmResultNames(*this, setNameFn);
5997 }
5999  genericAsmResultNames(*this, setNameFn);
6000 }
6002  genericAsmResultNames(*this, setNameFn);
6003 }
6005  genericAsmResultNames(*this, setNameFn);
6006 }
6008  genericAsmResultNames(*this, setNameFn);
6009 }
6011  genericAsmResultNames(*this, setNameFn);
6012 }
6014  genericAsmResultNames(*this, setNameFn);
6015 }
6017  genericAsmResultNames(*this, setNameFn);
6018 }
6020  genericAsmResultNames(*this, setNameFn);
6021 }
6023  genericAsmResultNames(*this, setNameFn);
6024 }
6025 
6027  genericAsmResultNames(*this, setNameFn);
6028 }
6029 
6031  genericAsmResultNames(*this, setNameFn);
6032 }
6033 
6035  genericAsmResultNames(*this, setNameFn);
6036 }
6037 
6039  genericAsmResultNames(*this, setNameFn);
6040 }
6041 
6043  genericAsmResultNames(*this, setNameFn);
6044 }
6045 
6047  genericAsmResultNames(*this, setNameFn);
6048 }
6049 
6051  genericAsmResultNames(*this, setNameFn);
6052 }
6053 
6055  genericAsmResultNames(*this, setNameFn);
6056 }
6057 
6059  genericAsmResultNames(*this, setNameFn);
6060 }
6061 
6063  genericAsmResultNames(*this, setNameFn);
6064 }
6065 
6067  genericAsmResultNames(*this, setNameFn);
6068 }
6069 
6071  genericAsmResultNames(*this, setNameFn);
6072 }
6073 
6075  genericAsmResultNames(*this, setNameFn);
6076 }
6077 
6079  genericAsmResultNames(*this, setNameFn);
6080 }
6081 
6083  genericAsmResultNames(*this, setNameFn);
6084 }
6085 
6087  genericAsmResultNames(*this, setNameFn);
6088 }
6089 
6090 //===----------------------------------------------------------------------===//
6091 // RefOps
6092 //===----------------------------------------------------------------------===//
6093 
6095  genericAsmResultNames(*this, setNameFn);
6096 }
6097 
6099  genericAsmResultNames(*this, setNameFn);
6100 }
6101 
6103  genericAsmResultNames(*this, setNameFn);
6104 }
6105 
6107  genericAsmResultNames(*this, setNameFn);
6108 }
6109 
6111  genericAsmResultNames(*this, setNameFn);
6112 }
6113 
6114 FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
6115  ArrayRef<NamedAttribute> attrs,
6116  std::optional<Location> loc) {
6117  Type inType = operands[0].getType();
6118  auto inRefType = type_dyn_cast<RefType>(inType);
6119  if (!inRefType)
6120  return emitInferRetTypeError(
6121  loc, "ref.resolve operand must be ref type, not ", inType);
6122  return inRefType.getType();
6123 }
6124 
6125 FIRRTLType RefSendOp::inferReturnType(ValueRange operands,
6126  ArrayRef<NamedAttribute> attrs,
6127  std::optional<Location> loc) {
6128  Type inType = operands[0].getType();
6129  auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
6130  if (!inBaseType)
6131  return emitInferRetTypeError(
6132  loc, "ref.send operand must be base type, not ", inType);
6133  return RefType::get(inBaseType.getPassiveType());
6134 }
6135 
6136 FIRRTLType RefSubOp::inferReturnType(ValueRange operands,
6137  ArrayRef<NamedAttribute> attrs,
6138  std::optional<Location> loc) {
6139  auto refType = type_dyn_cast<RefType>(operands[0].getType());
6140  if (!refType)
6141  return emitInferRetTypeError(loc, "input must be of reference type");
6142  auto inType = refType.getType();
6143  auto fieldIdx =
6144  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
6145 
6146  // TODO: Determine ref.sub + rwprobe behavior, test.
6147  // Probably best to demote to non-rw, but that has implications
6148  // for any LowerTypes behavior being relied on.
6149  // Allow for now, as need to LowerTypes things generally.
6150  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
6151  if (fieldIdx < vectorType.getNumElements())
6152  return RefType::get(
6153  vectorType.getElementType().getConstType(
6154  vectorType.isConst() || vectorType.getElementType().isConst()),
6155  refType.getForceable(), refType.getLayer());
6156  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
6157  "' in RefType of vector type ", refType);
6158  }
6159  if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
6160  if (fieldIdx >= bundleType.getNumElements()) {
6161  return emitInferRetTypeError(loc,
6162  "subfield element index is greater than "
6163  "the number of fields in the bundle type");
6164  }
6165  return RefType::get(bundleType.getElement(fieldIdx).type.getConstType(
6166  bundleType.isConst() ||
6167  bundleType.getElement(fieldIdx).type.isConst()),
6168  refType.getForceable(), refType.getLayer());
6169  }
6170 
6171  return emitInferRetTypeError(
6172  loc, "ref.sub op requires a RefType of vector or bundle base type");
6173 }
6174 
6175 LogicalResult RefCastOp::verify() {
6176  auto srcLayers = getLayersFor(getInput());
6177  auto dstLayers = getLayersFor(getResult());
6178  SmallVector<SymbolRefAttr> missingLayers;
6179  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6180  auto diag =
6181  emitOpError("cannot discard layer requirements of input reference");
6182  auto &note = diag.attachNote();
6183  note << "discarding layer requirements: ";
6184  llvm::interleaveComma(missingLayers, note);
6185  return failure();
6186  }
6187  return success();
6188 }
6189 
6190 LogicalResult RefResolveOp::verify() {
6191  auto srcLayers = getLayersFor(getRef());
6192  auto dstLayers = getAmbientLayersAt(getOperation());
6193  SmallVector<SymbolRefAttr> missingLayers;
6194  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6195  auto diag =
6196  emitOpError("ambient layers are insufficient to resolve reference");
6197  auto &note = diag.attachNote();
6198  note << "missing layer requirements: ";
6199  interleaveComma(missingLayers, note);
6200  return failure();
6201  }
6202  return success();
6203 }
6204 
/// Verify the inner-symbol target of this rwprobe. The target must:
///  - live in the same module as this op (non-local targets are rejected),
///  - resolve in `ns`,
///  - satisfy the layer check below,
///  - have a (field) type whose passive form equals the probed type.
/// For non-port targets, the target op must also expose a target result and
/// appear before this op (or before its ancestor) in the defining block.
LogicalResult RWProbeOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
  auto targetRef = getTarget();
  // Only targets in the enclosing module are allowed.
  if (targetRef.getModule() !=
      (*this)->getParentOfType<FModuleLike>().getModuleNameAttr())
    return emitOpError() << "has non-local target";

  auto target = ns.lookup(targetRef);
  if (!target)
    return emitOpError() << "has target that cannot be resolved: " << targetRef;

  // Check that the type of the targeted field, once made passive, matches
  // the type this rwprobe refers to. `loc` is where the target is declared
  // and is used for the note.
  auto checkFinalType = [&](auto type, Location loc) -> LogicalResult {
    // Narrow `type` down to the field selected by the target's field ID.
    mlir::Type fType =
        hw::FieldIdImpl::getFinalTypeByFieldID(type, target.getField());
    // Non-base types, or a passive-type mismatch, are errors.
    auto baseType = type_dyn_cast<FIRRTLBaseType>(fType);
    if (!baseType || baseType.getPassiveType() != getType().getType()) {
      auto diag = emitOpError("has type mismatch: target resolves to ")
                  << fType << " instead of expected " << getType().getType();
      diag.attachNote(loc) << "target resolves here";
      return diag;
    }
    return success();
  };

  // Check that the layers of this probe's result are compatible with the
  // ambient layers at the target op; list any missing layers in a note.
  auto checkLayers = [&](Location loc) -> LogicalResult {
    auto dstLayers = getAmbientLayersAt(target.getOp());
    auto srcLayers = getLayersFor(getResult());
    SmallVector<SymbolRefAttr> missingLayers;
    if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
      auto diag = emitOpError("target has insufficient layer requirements");
      auto &note = diag.attachNote(loc);
      note << "target is missing layer requirements: ";
      llvm::interleaveComma(missingLayers, note);
      return failure();
    }
    return success();
  };
  // Run the layer check first; the type check runs only if it passes.
  auto checks = [&](auto type, Location loc) {
    if (failed(checkLayers(loc)))
      return failure();
    return checkFinalType(type, loc);
  };

  // Port targets: check against the port's type and location on the module.
  if (target.isPort()) {
    auto mod = cast<FModuleLike>(target.getOp());
    return checks(mod.getPortType(target.getPort()),
                  mod.getPortLocation(target.getPort()));
  }
  // Op targets: the op must expose a result that can be probed.
  hw::InnerSymbolOpInterface symOp =
      cast<hw::InnerSymbolOpInterface>(target.getOp());
  if (!symOp.getTargetResult())
    return emitOpError("has target that cannot be probed")
        .attachNote(symOp.getLoc())
        .append("target resolves here");
  // Dominance check: find the op (this op or an ancestor) that lives in the
  // block defining the target result, and require the target op to come
  // before it in that block.
  auto *ancestor =
      symOp.getTargetResult().getParentBlock()->findAncestorOpInBlock(**this);
  if (!ancestor || !symOp->isBeforeInBlock(ancestor))
    return emitOpError("is not dominated by target")
        .attachNote(symOp.getLoc())
        .append("target here");
  return checks(symOp.getTargetResult().getType(), symOp.getLoc());
}
6268 
6269 //===----------------------------------------------------------------------===//
6270 // Layer Block Operations
6271 //===----------------------------------------------------------------------===//
6272 
6273 LogicalResult LayerBlockOp::verify() {
6274  auto layerName = getLayerName();
6275  auto *parentOp = (*this)->getParentOp();
6276 
6277  // Get parent operation that isn't a when or match.
6278  while (isa<WhenOp, MatchOp>(parentOp))
6279  parentOp = parentOp->getParentOp();
6280 
6281  // Verify the correctness of the symbol reference. Only verify that this
6282  // layer block makes sense in its parent module or layer block.
6283  auto nestedReferences = layerName.getNestedReferences();
6284  if (nestedReferences.empty()) {
6285  if (!isa<FModuleOp>(parentOp)) {
6286  auto diag = emitOpError() << "has an un-nested layer symbol, but does "
6287  "not have a 'firrtl.module' op as a parent";
6288  return diag.attachNote(parentOp->getLoc())
6289  << "illegal parent op defined here";
6290  }
6291  } else {
6292  auto parentLayerBlock = dyn_cast<LayerBlockOp>(parentOp);
6293  if (!parentLayerBlock) {
6294  auto diag = emitOpError()
6295  << "has a nested layer symbol, but does not have a '"
6296  << getOperationName() << "' op as a parent'";
6297  return diag.attachNote(parentOp->getLoc())
6298  << "illegal parent op defined here";
6299  }
6300  auto parentLayerBlockName = parentLayerBlock.getLayerName();
6301  if (parentLayerBlockName.getRootReference() !=
6302  layerName.getRootReference() ||
6303  parentLayerBlockName.getNestedReferences() !=
6304  layerName.getNestedReferences().drop_back()) {
6305  auto diag = emitOpError() << "is nested under an illegal layer block";
6306  return diag.attachNote(parentLayerBlock->getLoc())
6307  << "illegal parent layer block defined here";
6308  }
6309  }
6310 
6311  // Verify the body of the region.
6312  FieldRefCache fieldRefCache;
6313  auto result = getBody(0)->walk<mlir::WalkOrder::PreOrder>(
6314  [&](Operation *op) -> WalkResult {
6315  // Skip nested layer blocks. Those will be verified separately.
6316  if (isa<LayerBlockOp>(op))
6317  return WalkResult::skip();
6318 
6319  // Check all the operands of each op to make sure that only legal things
6320  // are captured.
6321  for (auto operand : op->getOperands()) {
6322  // Any value captured from the current layer block is fine.
6323  if (auto *definingOp = operand.getDefiningOp())
6324  if (getOperation()->isAncestor(definingOp))
6325  continue;
6326 
6327  auto type = operand.getType();
6328 
6329  // Capture of a non-base type, e.g., reference, is allowed.
6330  if (isa<PropertyType>(type)) {
6331  auto diag = emitOpError() << "captures a property operand";
6332  diag.attachNote(operand.getLoc()) << "operand is defined here";
6333  diag.attachNote(op->getLoc()) << "operand is used here";
6334  return WalkResult::interrupt();
6335  }
6336  }
6337 
6338  // Ensure that the layer block does not drive any sinks outside.
6339  if (auto connect = dyn_cast<FConnectLike>(op)) {
6340  // ref.define is allowed to drive probes outside the layerblock.
6341  if (isa<RefDefineOp>(connect))
6342  return WalkResult::advance();
6343 
6344  // Verify that connects only drive values declared in the layer block.
6345  // If we see a non-passive connect destination, then verify that the
6346  // source is in the same layer block so that the source is not driven.
6347  auto dest =
6348  fieldRefCache.getFieldRefFromValue(connect.getDest()).getValue();
6349  bool passive = true;
6350  if (auto type =
6351  type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType()))
6352  passive = type.isPassive();
6353  // TODO: Improve this verifier. This is intentionally _not_ verifying
6354  // a non-passive ConnectLike because it is hugely annoying to do
6355  // so---it requires a full understanding of if the connect is driving
6356  // destination-to-source, source-to-destination, or bi-directionally
6357  // which requires deep inspection of the type. Eventually, the FIRRTL
6358  // pass pipeline will remove all flips (e.g., canonicalize connect to