CIRCT  19.0.0git
FIRRTLOps.cpp
Go to the documentation of this file.
1 //===- FIRRTLOps.cpp - Implement the FIRRTL operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implement the FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Interfaces/FunctionImplementation.h"
30 #include "llvm/ADT/BitVector.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallSet.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/FormatVariadic.h"
38 
39 using llvm::SmallDenseSet;
40 using mlir::RegionRange;
41 using namespace circt;
42 using namespace firrtl;
43 using namespace chirrtl;
44 
45 //===----------------------------------------------------------------------===//
46 // Utilities
47 //===----------------------------------------------------------------------===//
48 
49 /// Remove elements from the input array corresponding to set bits in
50 /// `indicesToDrop`, returning the elements not mentioned.
51 template <typename T>
52 static SmallVector<T>
53 removeElementsAtIndices(ArrayRef<T> input,
54  const llvm::BitVector &indicesToDrop) {
55 #ifndef NDEBUG
56  if (!input.empty()) {
57  int lastIndex = indicesToDrop.find_last();
58  if (lastIndex >= 0)
59  assert((size_t)lastIndex < input.size() && "index out of range");
60  }
61 #endif
62 
63  // If the input is empty (which is an optimization we do for certain array
64  // attributes), simply return an empty vector.
65  if (input.empty())
66  return {};
67 
68  // Copy over the live chunks.
69  size_t lastCopied = 0;
70  SmallVector<T> result;
71  result.reserve(input.size() - indicesToDrop.count());
72 
73  for (unsigned indexToDrop : indicesToDrop.set_bits()) {
74  // If we skipped over some valid elements, copy them over.
75  if (indexToDrop > lastCopied) {
76  result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
77  lastCopied = indexToDrop;
78  }
79  // Ignore this value so we don't copy it in the next iteration.
80  ++lastCopied;
81  }
82 
83  // If there are live elements at the end, copy them over.
84  if (lastCopied < input.size())
85  result.append(input.begin() + lastCopied, input.end());
86 
87  return result;
88 }
89 
90 /// Emit an error if optional location is non-null, return null of return type.
91 template <typename RetTy = FIRRTLType, typename... Args>
92 static RetTy emitInferRetTypeError(std::optional<Location> loc,
93  const Twine &message, Args &&...args) {
94  if (loc)
95  (mlir::emitError(*loc, message) << ... << std::forward<Args>(args));
96  return {};
97 }
98 
99 bool firrtl::isDuplexValue(Value val) {
100  // Block arguments are not duplex values.
101  while (Operation *op = val.getDefiningOp()) {
102  auto isDuplex =
103  TypeSwitch<Operation *, std::optional<bool>>(op)
104  .Case<SubfieldOp, SubindexOp, SubaccessOp>([&val](auto op) {
105  val = op.getInput();
106  return std::nullopt;
107  })
108  .Case<RegOp, RegResetOp, WireOp>([](auto) { return true; })
109  .Default([](auto) { return false; });
110  if (isDuplex)
111  return *isDuplex;
112  }
113  return false;
114 }
115 
116 /// Return the kind of port this is given the port type from a 'mem' decl.
117 static MemOp::PortKind getMemPortKindFromType(FIRRTLType type) {
118  constexpr unsigned int addr = 1 << 0;
119  constexpr unsigned int en = 1 << 1;
120  constexpr unsigned int clk = 1 << 2;
121  constexpr unsigned int data = 1 << 3;
122  constexpr unsigned int mask = 1 << 4;
123  constexpr unsigned int rdata = 1 << 5;
124  constexpr unsigned int wdata = 1 << 6;
125  constexpr unsigned int wmask = 1 << 7;
126  constexpr unsigned int wmode = 1 << 8;
127  constexpr unsigned int def = 1 << 9;
128  // Get the kind of port based on the fields of the Bundle.
129  auto portType = type_dyn_cast<BundleType>(type);
130  if (!portType)
131  return MemOp::PortKind::Debug;
132  unsigned fields = 0;
133  // Get the kind of port based on the fields of the Bundle.
134  for (auto elem : portType.getElements()) {
135  fields |= llvm::StringSwitch<unsigned>(elem.name.getValue())
136  .Case("addr", addr)
137  .Case("en", en)
138  .Case("clk", clk)
139  .Case("data", data)
140  .Case("mask", mask)
141  .Case("rdata", rdata)
142  .Case("wdata", wdata)
143  .Case("wmask", wmask)
144  .Case("wmode", wmode)
145  .Default(def);
146  }
147  if (fields == (addr | en | clk | data))
148  return MemOp::PortKind::Read;
149  if (fields == (addr | en | clk | data | mask))
150  return MemOp::PortKind::Write;
151  if (fields == (addr | en | clk | wdata | wmask | rdata | wmode))
152  return MemOp::PortKind::ReadWrite;
153  return MemOp::PortKind::Debug;
154 }
155 
// Reverse a flow: source and sink are exchanged, while "no flow" and duplex
// map to themselves.
157  switch (flow) {
158  case Flow::None:
159  return Flow::None;
160  case Flow::Source:
161  return Flow::Sink;
162  case Flow::Sink:
163  return Flow::Source;
164  case Flow::Duplex:
165  return Flow::Duplex;
166  }
167 }
168 
169 constexpr const char *toString(Flow flow) {
170  switch (flow) {
171  case Flow::None:
172  return "no flow";
173  case Flow::Source:
174  return "source flow";
175  case Flow::Sink:
176  return "sink flow";
177  case Flow::Duplex:
178  return "duplex flow";
179  }
180 }
181 
182 Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {
183  auto swap = [&accumulatedFlow]() -> Flow {
184  return swapFlow(accumulatedFlow);
185  };
186 
187  if (auto blockArg = dyn_cast<BlockArgument>(val)) {
188  auto *op = val.getParentBlock()->getParentOp();
189  if (auto moduleLike = dyn_cast<FModuleLike>(op)) {
190  auto direction = moduleLike.getPortDirection(blockArg.getArgNumber());
191  if (direction == Direction::Out)
192  return swap();
193  }
194  return accumulatedFlow;
195  }
196 
197  Operation *op = val.getDefiningOp();
198 
199  return TypeSwitch<Operation *, Flow>(op)
200  .Case<SubfieldOp, OpenSubfieldOp>([&](auto op) {
201  return foldFlow(op.getInput(),
202  op.isFieldFlipped() ? swap() : accumulatedFlow);
203  })
204  .Case<SubindexOp, SubaccessOp, OpenSubindexOp, RefSubOp>(
205  [&](auto op) { return foldFlow(op.getInput(), accumulatedFlow); })
206  // Registers, Wires, and behavioral memory ports are always Duplex.
207  .Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
208  [](auto) { return Flow::Duplex; })
209  .Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
210  auto resultNo = cast<OpResult>(val).getResultNumber();
211  if (inst.getPortDirection(resultNo) == Direction::Out)
212  return accumulatedFlow;
213  return swap();
214  })
215  .Case<MemOp>([&](auto op) {
216  // only debug ports with RefType have source flow.
217  if (type_isa<RefType>(val.getType()))
218  return Flow::Source;
219  return swap();
220  })
221  .Case<ObjectSubfieldOp>([&](ObjectSubfieldOp op) {
222  auto input = op.getInput();
223  auto *inputOp = input.getDefiningOp();
224 
225  // We are directly accessing a port on a local declaration.
226  if (auto objectOp = dyn_cast_or_null<ObjectOp>(inputOp)) {
227  auto classType = input.getType();
228  auto direction = classType.getElement(op.getIndex()).direction;
229  if (direction == Direction::In)
230  return Flow::Sink;
231  return Flow::Source;
232  }
233 
234  // We are accessing a remote object. Input ports on remote objects are
235  // inaccessible, and thus have Flow::None. Walk backwards through the
236  // chain of subindexes, to detect if we have indexed through an input
237  // port. At the end, either we did index through an input port, or the
238  // entire path was through output ports with source flow.
239  while (true) {
240  auto classType = input.getType();
241  auto direction = classType.getElement(op.getIndex()).direction;
242  if (direction == Direction::In)
243  return Flow::None;
244 
245  op = dyn_cast_or_null<ObjectSubfieldOp>(inputOp);
246  if (op) {
247  input = op.getInput();
248  inputOp = input.getDefiningOp();
249  continue;
250  }
251 
252  return accumulatedFlow;
253  };
254  })
255  // Anything else acts like a universal source.
256  .Default([&](auto) { return accumulatedFlow; });
257 }
258 
// Classify the declaration that ultimately defines a value, looking through
// sub-element accesses to their input.
259 // TODO: This is doing the same walk as foldFlow. These two functions can be
260 // combined and return a (flow, kind) product.
262  Operation *op = val.getDefiningOp();
// Values without a defining op (block arguments) are classified as ports.
263  if (!op)
264  return DeclKind::Port;
265 
266  return TypeSwitch<Operation *, DeclKind>(op)
267  .Case<InstanceOp>([](auto) { return DeclKind::Instance; })
// Sub-element accesses defer to whatever their input is declared by.
268  .Case<SubfieldOp, SubindexOp, SubaccessOp, OpenSubfieldOp, OpenSubindexOp,
269  RefSubOp>([](auto op) { return getDeclarationKind(op.getInput()); })
270  .Default([](auto) { return DeclKind::Other; });
271 }
272 
273 size_t firrtl::getNumPorts(Operation *op) {
274  if (auto module = dyn_cast<FModuleLike>(op))
275  return module.getNumPorts();
276  return op->getNumResults();
277 }
278 
279 /// Check whether an operation has a `DontTouch` annotation, or a symbol that
280 /// should prevent certain types of canonicalizations.
281 bool firrtl::hasDontTouch(Operation *op) {
// An inner symbol on the operation is sufficient on its own.
282  return op->getAttr(hw::InnerSymbolTable::getInnerSymbolAttrName()) ||
284 }
286 /// Check whether a block argument ("port") or the operation defining a value
287 /// has a `DontTouch` annotation, or a symbol that should prevent certain types
288 /// of canonicalizations.
289 bool firrtl::hasDontTouch(Value value) {
290  if (auto *op = value.getDefiningOp())
291  return hasDontTouch(op);
292  auto arg = dyn_cast<BlockArgument>(value);
293  auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
294  return (module.getPortSymbolAttr(arg.getArgNumber())) ||
295  AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
296 }
297 
298 /// Get a special name to use when printing the entry block arguments of the
299 /// region contained by an operation in this dialect.
300 void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region,
301  OpAsmSetValueNameFn setNameFn) {
302  if (region.empty())
303  return;
304  auto *parentOp = op;
305  auto *block = &region.front();
306  // Check to see if the operation containing the arguments has 'firrtl.name'
307  // attributes for them. If so, use that as the name.
308  auto argAttr = parentOp->getAttrOfType<ArrayAttr>("portNames");
309  // Do not crash on invalid IR.
310  if (!argAttr || argAttr.size() != block->getNumArguments())
311  return;
312 
313  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
314  auto str = cast<StringAttr>(argAttr[i]).getValue();
315  if (!str.empty())
316  setNameFn(block->getArgument(i), str);
317  }
318 }
319 
320 /// A forward declaration for `NameKind` attribute parser.
321 static ParseResult parseNameKind(OpAsmParser &parser,
322  firrtl::NameKindEnumAttr &result);
323 
324 //===----------------------------------------------------------------------===//
325 // Layer Verification Utilities
326 //===----------------------------------------------------------------------===//
327 
328 namespace {
329 struct CompareSymbolRefAttr {
330  // True if lhs is lexicographically less than rhs.
331  bool operator()(SymbolRefAttr lhs, SymbolRefAttr rhs) const {
332  auto cmp = lhs.getRootReference().compare(rhs.getRootReference());
333  if (cmp == -1)
334  return true;
335  if (cmp == 1)
336  return false;
337  auto lhsNested = lhs.getNestedReferences();
338  auto rhsNested = rhs.getNestedReferences();
339  auto lhsNestedSize = lhsNested.size();
340  auto rhsNestedSize = rhsNested.size();
341  auto e = std::min(lhsNestedSize, rhsNestedSize);
342  for (unsigned i = 0; i < e; ++i) {
343  auto cmp = lhsNested[i].getAttr().compare(rhsNested[i].getAttr());
344  if (cmp == -1)
345  return true;
346  if (cmp == 1)
347  return false;
348  }
349  return lhsNestedSize < rhsNestedSize;
350  }
351 };
352 } // namespace
353 
355 
356 /// Get the ambient layers active at the given op.
357 static LayerSet getAmbientLayersAt(Operation *op) {
358  // Crawl through the parent ops, accumulating all ambient layers at the given
359  // operation.
360  LayerSet result;
361  for (; op != nullptr; op = op->getParentOp()) {
362  if (auto module = dyn_cast<FModuleLike>(op)) {
363  auto layers = module.getLayersAttr().getAsRange<SymbolRefAttr>();
364  result.insert(layers.begin(), layers.end());
365  break;
366  }
367  if (auto layerblock = dyn_cast<LayerBlockOp>(op)) {
368  result.insert(layerblock.getLayerName());
369  continue;
370  }
371  }
372  return result;
373 }
374 
375 /// Get the ambient layer requirements at the definition site of the value.
376 static LayerSet getAmbientLayersFor(Value value) {
377  return getAmbientLayersAt(getFieldRefFromValue(value).getDefiningOp());
378 }
379 
380 /// Get the effective layer requirements for the given value.
381 /// The effective layers for a value is the union of
382 /// - the ambient layers for the cannonical storage location.
383 /// - any explicit layer annotations in the value's type.
384 static LayerSet getLayersFor(Value value) {
385  auto result = getAmbientLayersFor(value);
386  if (auto type = dyn_cast<RefType>(value.getType()))
387  if (auto layer = type.getLayer())
388  result.insert(type.getLayer());
389  return result;
390 }
391 
392 /// Check that the source layer is compatible with the destination layer.
393 /// Either the source and destination are identical, or the source-layer
394 /// is a parent of the destination. For example `A` is compatible with `A.B.C`,
395 /// because any definition valid in `A` is also valid in `A.B.C`.
396 static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer,
397  mlir::SymbolRefAttr dstLayer) {
398  // A non-colored probe may be cast to any colored probe.
399  if (!srcLayer)
400  return true;
401 
402  // A colored probe cannot be cast to an uncolored probe.
403  if (!dstLayer)
404  return false;
405 
406  // Return true if the srcLayer is a prefix of the dstLayer.
407  if (srcLayer.getRootReference() != dstLayer.getRootReference())
408  return false;
409 
410  auto srcNames = srcLayer.getNestedReferences();
411  auto dstNames = dstLayer.getNestedReferences();
412  if (dstNames.size() < srcNames.size())
413  return false;
414 
415  return llvm::all_of(llvm::zip_first(srcNames, dstNames),
416  [](auto x) { return std::get<0>(x) == std::get<1>(x); });
417 }
418 
419 /// Check that the source layer is present in the destination layers.
420 static bool isLayerCompatibleWith(SymbolRefAttr srcLayer,
421  const LayerSet &dstLayers) {
422  // fast path: the required layer is directly listed in the provided layers.
423  if (dstLayers.contains(srcLayer))
424  return true;
425 
426  // Slow path: the required layer is not directly listed in the provided
427  // layers, but the layer may still be provided by a nested layer.
428  return any_of(dstLayers, [=](SymbolRefAttr dstLayer) {
429  return isLayerCompatibleWith(srcLayer, dstLayer);
430  });
431 }
432 
433 /// Check that the source layers are all present in the destination layers.
434 /// True if all source layers are present in the destination.
435 /// Outputs the set of source layers that are missing in the destination.
436 static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst,
437  SmallVectorImpl<SymbolRefAttr> &missing) {
438  for (auto srcLayer : src)
439  if (!isLayerCompatibleWith(srcLayer, dst))
440  missing.push_back(srcLayer);
441 
442  llvm::sort(missing, CompareSymbolRefAttr());
443  return missing.empty();
444 }
445 
446 //===----------------------------------------------------------------------===//
447 // CircuitOp
448 //===----------------------------------------------------------------------===//
449 
450 void CircuitOp::build(OpBuilder &builder, OperationState &result,
451  StringAttr name, ArrayAttr annotations) {
452  // Add an attribute for the name.
453  result.addAttribute(builder.getStringAttr("name"), name);
454 
455  if (!annotations)
456  annotations = builder.getArrayAttr({});
457  result.addAttribute("annotations", annotations);
458 
459  // Create a region and a block for the body.
460  Region *bodyRegion = result.addRegion();
461  Block *body = new Block();
462  bodyRegion->push_back(body);
463 }
464 
465 static ParseResult parseCircuitOpAttrs(OpAsmParser &parser,
466  NamedAttrList &resultAttrs) {
467  auto result = parser.parseOptionalAttrDictWithKeyword(resultAttrs);
468  if (!resultAttrs.get("annotations"))
469  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
470 
471  return result;
472 }
473 
474 static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op,
475  DictionaryAttr attr) {
476  // "name" is always elided.
477  SmallVector<StringRef> elidedAttrs = {"name"};
478  // Elide "annotations" if it doesn't exist or if it is empty
479  auto annotationsAttr = op->getAttrOfType<ArrayAttr>("annotations");
480  if (annotationsAttr.empty())
481  elidedAttrs.push_back("annotations");
482 
483  p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
484 }
485 
486 LogicalResult CircuitOp::verifyRegions() {
487  StringRef main = getName();
488 
489  // Check that the circuit has a non-empty name.
490  if (main.empty()) {
491  emitOpError("must have a non-empty name");
492  return failure();
493  }
494 
495  mlir::SymbolTable symtbl(getOperation());
496 
497  // Store a mapping of defname to either the first external module
498  // that defines it or, preferentially, the first external module
499  // that defines it and has no parameters.
500  llvm::DenseMap<Attribute, FExtModuleOp> defnameMap;
501 
502  auto verifyExtModule = [&](FExtModuleOp extModule) -> LogicalResult {
503  if (!extModule)
504  return success();
505 
506  auto defname = extModule.getDefnameAttr();
507  if (!defname)
508  return success();
509 
510  // Check that this extmodule's defname does not conflict with
511  // the symbol name of any module.
512  if (auto collidingModule = symtbl.lookup<FModuleOp>(defname.getValue()))
513  return extModule.emitOpError()
514  .append("attribute 'defname' with value ", defname,
515  " conflicts with the name of another module in the circuit")
516  .attachNote(collidingModule.getLoc())
517  .append("previous module declared here");
518 
519  // Find an optional extmodule with a defname collision. Update
520  // the defnameMap if this is the first extmodule with that
521  // defname or if the current extmodule takes no parameters and
522  // the collision does. The latter condition improves later
523  // extmodule verification as checking against a parameterless
524  // module is stricter.
525  FExtModuleOp collidingExtModule;
526  if (auto &value = defnameMap[defname]) {
527  collidingExtModule = value;
528  if (!value.getParameters().empty() && extModule.getParameters().empty())
529  value = extModule;
530  } else {
531  value = extModule;
532  // Go to the next extmodule if no extmodule with the same
533  // defname was found.
534  return success();
535  }
536 
537  // Check that the number of ports is exactly the same.
538  SmallVector<PortInfo> ports = extModule.getPorts();
539  SmallVector<PortInfo> collidingPorts = collidingExtModule.getPorts();
540 
541  if (ports.size() != collidingPorts.size())
542  return extModule.emitOpError()
543  .append("with 'defname' attribute ", defname, " has ", ports.size(),
544  " ports which is different from a previously defined "
545  "extmodule with the same 'defname' which has ",
546  collidingPorts.size(), " ports")
547  .attachNote(collidingExtModule.getLoc())
548  .append("previous extmodule definition occurred here");
549 
550  // Check that ports match for name and type. Since parameters
551  // *might* affect widths, ignore widths if either module has
552  // parameters. Note that this allows for misdetections, but
553  // has zero false positives.
554  for (auto p : llvm::zip(ports, collidingPorts)) {
555  StringAttr aName = std::get<0>(p).name, bName = std::get<1>(p).name;
556  Type aType = std::get<0>(p).type, bType = std::get<1>(p).type;
557 
558  if (aName != bName)
559  return extModule.emitOpError()
560  .append("with 'defname' attribute ", defname,
561  " has a port with name ", aName,
562  " which does not match the name of the port in the same "
563  "position of a previously defined extmodule with the same "
564  "'defname', expected port to have name ",
565  bName)
566  .attachNote(collidingExtModule.getLoc())
567  .append("previous extmodule definition occurred here");
568 
569  if (!extModule.getParameters().empty() ||
570  !collidingExtModule.getParameters().empty()) {
571  // Compare base types as widthless, others must match.
572  if (auto base = type_dyn_cast<FIRRTLBaseType>(aType))
573  aType = base.getWidthlessType();
574  if (auto base = type_dyn_cast<FIRRTLBaseType>(bType))
575  bType = base.getWidthlessType();
576  }
577  if (aType != bType)
578  return extModule.emitOpError()
579  .append("with 'defname' attribute ", defname,
580  " has a port with name ", aName,
581  " which has a different type ", aType,
582  " which does not match the type of the port in the same "
583  "position of a previously defined extmodule with the same "
584  "'defname', expected port to have type ",
585  bType)
586  .attachNote(collidingExtModule.getLoc())
587  .append("previous extmodule definition occurred here");
588  }
589  return success();
590  };
591 
592  for (auto &op : *getBodyBlock()) {
593  // Verify external modules.
594  if (auto extModule = dyn_cast<FExtModuleOp>(op)) {
595  if (verifyExtModule(extModule).failed())
596  return failure();
597  }
598  }
599 
600  return success();
601 }
602 
603 Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
604 
605 //===----------------------------------------------------------------------===//
606 // FExtModuleOp and FModuleOp
607 //===----------------------------------------------------------------------===//
608 
609 static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
610  SmallVector<PortInfo> results;
611  for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
612  results.push_back({module.getPortNameAttr(i), module.getPortType(i),
613  module.getPortDirection(i), module.getPortSymbolAttr(i),
614  module.getPortLocation(i),
615  AnnotationSet::forPort(module, i)});
616  }
617  return results;
618 }
619 
620 SmallVector<PortInfo> FModuleOp::getPorts() { return ::getPortImpl(*this); }
621 
622 SmallVector<PortInfo> FExtModuleOp::getPorts() { return ::getPortImpl(*this); }
623 
624 SmallVector<PortInfo> FIntModuleOp::getPorts() { return ::getPortImpl(*this); }
625 
626 SmallVector<PortInfo> FMemModuleOp::getPorts() { return ::getPortImpl(*this); }
627 
// Map a FIRRTL port direction onto its hw counterpart; any direction other
// than In or Out is a fatal error.
629  if (dir == Direction::In)
631  if (dir == Direction::Out)
633  assert(0 && "invalid direction");
634  abort();
635 }
636 
637 static SmallVector<hw::PortInfo> getPortListImpl(FModuleLike module) {
638  SmallVector<hw::PortInfo> results;
639  auto aname = StringAttr::get(module.getContext(),
640  hw::HWModuleLike::getPortSymbolAttrName());
641  auto emptyDict = DictionaryAttr::get(module.getContext());
642  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
643  auto sym = module.getPortSymbolAttr(i);
644  results.push_back(
645  {{module.getPortNameAttr(i), module.getPortType(i),
646  dirFtoH(module.getPortDirection(i))},
647  i,
648  sym ? DictionaryAttr::get(
649  module.getContext(),
650  ArrayRef<mlir::NamedAttribute>{NamedAttribute{aname, sym}})
651  : emptyDict,
652  module.getPortLocation(i)});
653  }
654  return results;
655 }
656 
// Return this module's ports as hw::PortInfo entries.
657 SmallVector<::circt::hw::PortInfo> FModuleOp::getPortList() {
659 }
660 
// Return this module's ports as hw::PortInfo entries.
661 SmallVector<::circt::hw::PortInfo> FExtModuleOp::getPortList() {
663 }
664 
// Return this module's ports as hw::PortInfo entries.
665 SmallVector<::circt::hw::PortInfo> FIntModuleOp::getPortList() {
667 }
668 
// Return this module's ports as hw::PortInfo entries.
669 SmallVector<::circt::hw::PortInfo> FMemModuleOp::getPortList() {
671 }
672 
// Build the hw::PortInfo for port `idx` of `module`.
// NOTE(review): unlike getPortListImpl, this appears to attach the port symbol
// attribute without checking whether the port has a symbol; the line that
// opens the dictionary is not visible here -- confirm behavior for symbol-less
// ports.
673 static hw::PortInfo getPortImpl(FModuleLike module, size_t idx) {
674  return {{module.getPortNameAttr(idx), module.getPortType(idx),
675  dirFtoH(module.getPortDirection(idx))},
676  idx,
678  module.getContext(),
679  ArrayRef<mlir::NamedAttribute>{NamedAttribute{
680  StringAttr::get(module.getContext(),
681  hw::HWModuleLike::getPortSymbolAttrName()),
682  module.getPortSymbolAttr(idx)}}),
683  module.getPortLocation(idx)};
684 }
685 
// Forward to the FModuleLike-generic single-port helper.
687  return ::getPortImpl(*this, idx);
688 }
689 
// Forward to the FModuleLike-generic single-port helper.
691  return ::getPortImpl(*this, idx);
692 }
693 
// Forward to the FModuleLike-generic single-port helper.
695  return ::getPortImpl(*this, idx);
696 }
697 
// Forward to the FModuleLike-generic single-port helper.
699  return ::getPortImpl(*this, idx);
700 }
701 
702 // Return the port with the specified name.
703 BlockArgument FModuleOp::getArgument(size_t portNumber) {
704  return getBodyBlock()->getArgument(portNumber);
705 }
706 
707 /// Inserts the given ports. The insertion indices are expected to be in order.
708 /// Insertion occurs in-order, such that ports with the same insertion index
709 /// appear in the module in the same order they appeared in the list.
710 static void insertPorts(FModuleLike op,
711  ArrayRef<std::pair<unsigned, PortInfo>> ports,
712  bool supportsInternalPaths = false) {
713  if (ports.empty())
714  return;
715  unsigned oldNumArgs = op.getNumPorts();
716  unsigned newNumArgs = oldNumArgs + ports.size();
717 
718  // Add direction markers and names for new ports.
719  SmallVector<Direction> existingDirections =
720  direction::unpackAttribute(op.getPortDirectionsAttr());
721  ArrayRef<Attribute> existingNames = op.getPortNames();
722  ArrayRef<Attribute> existingTypes = op.getPortTypes();
723  ArrayRef<Attribute> existingLocs = op.getPortLocations();
724  assert(existingDirections.size() == oldNumArgs);
725  assert(existingNames.size() == oldNumArgs);
726  assert(existingTypes.size() == oldNumArgs);
727  assert(existingLocs.size() == oldNumArgs);
728  SmallVector<Attribute> internalPaths;
729  auto emptyInternalPath = InternalPathAttr::get(op.getContext());
730  if (supportsInternalPaths) {
731  if (auto internalPathsAttr = op->getAttrOfType<ArrayAttr>("internalPaths"))
732  llvm::append_range(internalPaths, internalPathsAttr);
733  else
734  internalPaths.resize(oldNumArgs, emptyInternalPath);
735  assert(internalPaths.size() == oldNumArgs);
736  }
737 
738  SmallVector<Direction> newDirections;
739  SmallVector<Attribute> newNames, newTypes, newAnnos, newSyms, newLocs,
740  newInternalPaths;
741  newDirections.reserve(newNumArgs);
742  newNames.reserve(newNumArgs);
743  newTypes.reserve(newNumArgs);
744  newAnnos.reserve(newNumArgs);
745  newSyms.reserve(newNumArgs);
746  newLocs.reserve(newNumArgs);
747  newInternalPaths.reserve(newNumArgs);
748 
749  auto emptyArray = ArrayAttr::get(op.getContext(), {});
750 
751  unsigned oldIdx = 0;
752  auto migrateOldPorts = [&](unsigned untilOldIdx) {
753  while (oldIdx < oldNumArgs && oldIdx < untilOldIdx) {
754  newDirections.push_back(existingDirections[oldIdx]);
755  newNames.push_back(existingNames[oldIdx]);
756  newTypes.push_back(existingTypes[oldIdx]);
757  newAnnos.push_back(op.getAnnotationsAttrForPort(oldIdx));
758  newSyms.push_back(op.getPortSymbolAttr(oldIdx));
759  newLocs.push_back(existingLocs[oldIdx]);
760  if (supportsInternalPaths)
761  newInternalPaths.push_back(internalPaths[oldIdx]);
762  ++oldIdx;
763  }
764  };
765  for (auto pair : llvm::enumerate(ports)) {
766  auto idx = pair.value().first;
767  auto &port = pair.value().second;
768  migrateOldPorts(idx);
769  newDirections.push_back(port.direction);
770  newNames.push_back(port.name);
771  newTypes.push_back(TypeAttr::get(port.type));
772  auto annos = port.annotations.getArrayAttr();
773  newAnnos.push_back(annos ? annos : emptyArray);
774  newSyms.push_back(port.sym);
775  newLocs.push_back(port.loc);
776  if (supportsInternalPaths)
777  newInternalPaths.push_back(emptyInternalPath);
778  }
779  migrateOldPorts(oldNumArgs);
780 
781  // The lack of *any* port annotations is represented by an empty
782  // `portAnnotations` array as a shorthand.
783  if (llvm::all_of(newAnnos, [](Attribute attr) {
784  return cast<ArrayAttr>(attr).empty();
785  }))
786  newAnnos.clear();
787 
788  // Apply these changed markers.
789  op->setAttr("portDirections",
790  direction::packAttribute(op.getContext(), newDirections));
791  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newNames));
792  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newTypes));
793  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newAnnos));
794  op.setPortSymbols(newSyms);
795  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newLocs));
796  if (supportsInternalPaths) {
797  // Drop if all-empty, otherwise set to new array.
798  auto empty = llvm::all_of(newInternalPaths, [](Attribute attr) {
799  return !cast<InternalPathAttr>(attr).getPath();
800  });
801  if (empty)
802  op->removeAttr("internalPaths");
803  else
804  op->setAttr("internalPaths",
805  ArrayAttr::get(op.getContext(), newInternalPaths));
806  }
807 }
808 
809 /// Erases the ports that have their corresponding bit set in `portIndices`.
810 static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
811  if (portIndices.none())
812  return;
813 
814  // Drop the direction markers for dead ports.
815  SmallVector<Direction> portDirections =
816  direction::unpackAttribute(op.getPortDirectionsAttr());
817  ArrayRef<Attribute> portNames = op.getPortNames();
818  ArrayRef<Attribute> portTypes = op.getPortTypes();
819  ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
820  ArrayRef<Attribute> portSyms = op.getPortSymbols();
821  ArrayRef<Attribute> portLocs = op.getPortLocations();
822  auto numPorts = op.getNumPorts();
823  (void)numPorts;
824  assert(portDirections.size() == numPorts);
825  assert(portNames.size() == numPorts);
826  assert(portAnnos.size() == numPorts || portAnnos.empty());
827  assert(portTypes.size() == numPorts);
828  assert(portSyms.size() == numPorts || portSyms.empty());
829  assert(portLocs.size() == numPorts);
830 
831  SmallVector<Direction> newPortDirections =
832  removeElementsAtIndices<Direction>(portDirections, portIndices);
833  SmallVector<Attribute> newPortNames, newPortTypes, newPortAnnos, newPortSyms,
834  newPortLocs;
835  newPortNames = removeElementsAtIndices(portNames, portIndices);
836  newPortTypes = removeElementsAtIndices(portTypes, portIndices);
837  newPortAnnos = removeElementsAtIndices(portAnnos, portIndices);
838  newPortSyms = removeElementsAtIndices(portSyms, portIndices);
839  newPortLocs = removeElementsAtIndices(portLocs, portIndices);
840  op->setAttr("portDirections",
841  direction::packAttribute(op.getContext(), newPortDirections));
842  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newPortNames));
843  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newPortAnnos));
844  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newPortTypes));
845  FModuleLike::fixupPortSymsArray(newPortSyms, op.getContext());
846  op->setAttr("portSyms", ArrayAttr::get(op.getContext(), newPortSyms));
847  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newPortLocs));
848 }
849 
850 template <typename T>
851 static void eraseInternalPaths(T op, const llvm::BitVector &portIndices) {
852  // Fixup internalPaths array.
853  auto internalPaths = op.getInternalPaths();
854  if (!internalPaths)
855  return;
856 
857  auto newPaths =
858  removeElementsAtIndices(internalPaths->getValue(), portIndices);
859 
860  // Drop if all-empty, otherwise set to new array.
861  auto empty = llvm::all_of(newPaths, [](Attribute attr) {
862  return !cast<InternalPathAttr>(attr).getPath();
863  });
864  if (empty)
865  op.removeInternalPathsAttr();
866  else
867  op.setInternalPathsAttr(ArrayAttr::get(op.getContext(), newPaths));
868 }
869 
870 void FExtModuleOp::erasePorts(const llvm::BitVector &portIndices) {
871  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
872  eraseInternalPaths(*this, portIndices);
873 }
874 
875 void FIntModuleOp::erasePorts(const llvm::BitVector &portIndices) {
876  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
877  eraseInternalPaths(*this, portIndices);
878 }
879 
880 void FMemModuleOp::erasePorts(const llvm::BitVector &portIndices) {
881  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
882 }
883 
884 void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
885  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
886  getBodyBlock()->eraseArguments(portIndices);
887 }
888 
889 /// Inserts the given ports. The insertion indices are expected to be in order.
890 /// Insertion occurs in-order, such that ports with the same insertion index
891 /// appear in the module in the same order they appeared in the list.
892 void FModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
893  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
894 
895  // Insert the block arguments.
896  auto *body = getBodyBlock();
897  for (size_t i = 0, e = ports.size(); i < e; ++i) {
898  // Block arguments are inserted one at a time, so for each argument we
899  // insert we have to increase the index by 1.
900  auto &[index, port] = ports[i];
901  body->insertArgument(index + i, port.type, port.loc);
902  }
903 }
904 
905 void FExtModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
906  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
907  /*supportsInternalPaths=*/true);
908 }
909 
910 void FIntModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
911  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
912  /*supportsInternalPaths=*/true);
913 }
914 
915 /// Inserts the given ports. The insertion indices are expected to be in order.
916 /// Insertion occurs in-order, such that ports with the same insertion index
917 /// appear in the module in the same order they appeared in the list.
918 void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
919  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
920 }
921 
922 static void buildModuleLike(OpBuilder &builder, OperationState &result,
923  StringAttr name, ArrayRef<PortInfo> ports,
924  ArrayAttr annotations, ArrayAttr layers,
925  bool withAnnotations, bool withLayers) {
926  // Add an attribute for the name.
927  result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), name);
928 
929  // Record the names of the arguments if present.
930  SmallVector<Direction, 4> portDirections;
931  SmallVector<Attribute, 4> portNames;
932  SmallVector<Attribute, 4> portTypes;
933  SmallVector<Attribute, 4> portAnnotations;
934  SmallVector<Attribute, 4> portSyms;
935  SmallVector<Attribute, 4> portLocs;
936  for (const auto &port : ports) {
937  portDirections.push_back(port.direction);
938  portNames.push_back(port.name);
939  portTypes.push_back(TypeAttr::get(port.type));
940  portAnnotations.push_back(port.annotations.getArrayAttr());
941  portSyms.push_back(port.sym);
942  portLocs.push_back(port.loc);
943  }
944 
945  // The lack of *any* port annotations is represented by an empty
946  // `portAnnotations` array as a shorthand.
947  if (llvm::all_of(portAnnotations, [](Attribute attr) {
948  return cast<ArrayAttr>(attr).empty();
949  }))
950  portAnnotations.clear();
951 
952  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
953  // Both attributes are added, even if the module has no ports.
954  result.addAttribute(
955  "portDirections",
956  direction::packAttribute(builder.getContext(), portDirections));
957  result.addAttribute("portNames", builder.getArrayAttr(portNames));
958  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
959  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
960  result.addAttribute("portLocations", builder.getArrayAttr(portLocs));
961 
962  if (withAnnotations) {
963  if (!annotations)
964  annotations = builder.getArrayAttr({});
965  result.addAttribute("annotations", annotations);
966  result.addAttribute("portAnnotations",
967  builder.getArrayAttr(portAnnotations));
968  }
969 
970  if (withLayers) {
971  if (!layers)
972  layers = builder.getArrayAttr({});
973  result.addAttribute("layers", layers);
974  }
975 
976  result.addRegion();
977 }
978 
979 static void buildModule(OpBuilder &builder, OperationState &result,
980  StringAttr name, ArrayRef<PortInfo> ports,
981  ArrayAttr annotations, ArrayAttr layers) {
982  buildModuleLike(builder, result, name, ports, annotations, layers,
983  /*withAnnotations=*/true, /*withLayers=*/true);
984 }
985 
986 static void buildClass(OpBuilder &builder, OperationState &result,
987  StringAttr name, ArrayRef<PortInfo> ports) {
988  return buildModuleLike(builder, result, name, ports, {}, {},
989  /*withAnnotations=*/false,
990  /*withLayers=*/false);
991 }
992 
993 void FModuleOp::build(OpBuilder &builder, OperationState &result,
994  StringAttr name, ConventionAttr convention,
995  ArrayRef<PortInfo> ports, ArrayAttr annotations,
996  ArrayAttr layers) {
997  buildModule(builder, result, name, ports, annotations, layers);
998  result.addAttribute("convention", convention);
999 
1000  // Create a region and a block for the body.
1001  auto *bodyRegion = result.regions[0].get();
1002  Block *body = new Block();
1003  bodyRegion->push_back(body);
1004 
1005  // Add arguments to the body block.
1006  for (auto &elt : ports)
1007  body->addArgument(elt.type, elt.loc);
1008 }
1009 
1010 void FExtModuleOp::build(OpBuilder &builder, OperationState &result,
1011  StringAttr name, ConventionAttr convention,
1012  ArrayRef<PortInfo> ports, StringRef defnameAttr,
1013  ArrayAttr annotations, ArrayAttr parameters,
1014  ArrayAttr internalPaths, ArrayAttr layers) {
1015  buildModule(builder, result, name, ports, annotations, layers);
1016  result.addAttribute("convention", convention);
1017  if (!defnameAttr.empty())
1018  result.addAttribute("defname", builder.getStringAttr(defnameAttr));
1019  if (!parameters)
1020  parameters = builder.getArrayAttr({});
1021  result.addAttribute(getParametersAttrName(result.name), parameters);
1022  if (internalPaths && !internalPaths.empty())
1023  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1024 }
1025 
1026 void FIntModuleOp::build(OpBuilder &builder, OperationState &result,
1027  StringAttr name, ArrayRef<PortInfo> ports,
1028  StringRef intrinsicNameAttr, ArrayAttr annotations,
1029  ArrayAttr parameters, ArrayAttr internalPaths,
1030  ArrayAttr layers) {
1031  buildModule(builder, result, name, ports, annotations, layers);
1032  result.addAttribute("intrinsic", builder.getStringAttr(intrinsicNameAttr));
1033  if (!parameters)
1034  parameters = builder.getArrayAttr({});
1035  result.addAttribute(getParametersAttrName(result.name), parameters);
1036  if (internalPaths && !internalPaths.empty())
1037  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1038 }
1039 
1040 void FMemModuleOp::build(OpBuilder &builder, OperationState &result,
1041  StringAttr name, ArrayRef<PortInfo> ports,
1042  uint32_t numReadPorts, uint32_t numWritePorts,
1043  uint32_t numReadWritePorts, uint32_t dataWidth,
1044  uint32_t maskBits, uint32_t readLatency,
1045  uint32_t writeLatency, uint64_t depth,
1046  ArrayAttr annotations, ArrayAttr layers) {
1047  auto *context = builder.getContext();
1048  buildModule(builder, result, name, ports, annotations, layers);
1049  auto ui32Type = IntegerType::get(context, 32, IntegerType::Unsigned);
1050  auto ui64Type = IntegerType::get(context, 64, IntegerType::Unsigned);
1051  result.addAttribute("numReadPorts", IntegerAttr::get(ui32Type, numReadPorts));
1052  result.addAttribute("numWritePorts",
1053  IntegerAttr::get(ui32Type, numWritePorts));
1054  result.addAttribute("numReadWritePorts",
1055  IntegerAttr::get(ui32Type, numReadWritePorts));
1056  result.addAttribute("dataWidth", IntegerAttr::get(ui32Type, dataWidth));
1057  result.addAttribute("maskBits", IntegerAttr::get(ui32Type, maskBits));
1058  result.addAttribute("readLatency", IntegerAttr::get(ui32Type, readLatency));
1059  result.addAttribute("writeLatency", IntegerAttr::get(ui32Type, writeLatency));
1060  result.addAttribute("depth", IntegerAttr::get(ui64Type, depth));
1061  result.addAttribute("extraPorts", ArrayAttr::get(context, {}));
1062 }
1063 
1064 /// Print a list of module ports in the following form:
1065 /// in x: !firrtl.uint<1> [{class = "DontTouch}], out "_port": !firrtl.uint<2>
1066 ///
1067 /// When there is no block specified, the port names print as MLIR identifiers,
1068 /// wrapping in quotes if not legal to print as-is. When there is no block
1069 /// specified, this function always return false, indicating that there was no
1070 /// issue printing port names.
1071 ///
1072 /// If there is a block specified, then port names will be printed as SSA
1073 /// values. If there is a reason the printed SSA values can't match the true
1074 /// port name, then this function will return true. When this happens, the
1075 /// caller should print the port names as a part of the `attr-dict`.
1076 static bool printModulePorts(OpAsmPrinter &p, Block *block,
1077  ArrayRef<Direction> portDirections,
1078  ArrayRef<Attribute> portNames,
1079  ArrayRef<Attribute> portTypes,
1080  ArrayRef<Attribute> portAnnotations,
1081  ArrayRef<Attribute> portSyms,
1082  ArrayRef<Attribute> portLocs) {
1083  // When printing port names as SSA values, we can fail to print them
1084  // identically.
1085  bool printedNamesDontMatch = false;
1086 
1087  mlir::OpPrintingFlags flags;
1088 
1089  // If we are printing the ports as block arguments the op must have a first
1090  // block.
1091  SmallString<32> resultNameStr;
1092  p << '(';
1093  for (unsigned i = 0, e = portTypes.size(); i < e; ++i) {
1094  if (i > 0)
1095  p << ", ";
1096 
1097  // Print the port direction.
1098  p << portDirections[i] << " ";
1099 
1100  // Print the port name. If there is a valid block, we print it as a block
1101  // argument.
1102  if (block) {
1103  // Get the printed format for the argument name.
1104  resultNameStr.clear();
1105  llvm::raw_svector_ostream tmpStream(resultNameStr);
1106  p.printOperand(block->getArgument(i), tmpStream);
1107  // If the name wasn't printable in a way that agreed with portName, make
1108  // sure to print out an explicit portNames attribute.
1109  auto portName = cast<StringAttr>(portNames[i]).getValue();
1110  if (!portName.empty() && tmpStream.str().drop_front() != portName)
1111  printedNamesDontMatch = true;
1112  p << tmpStream.str();
1113  } else {
1114  p.printKeywordOrString(cast<StringAttr>(portNames[i]).getValue());
1115  }
1116 
1117  // Print the port type.
1118  p << ": ";
1119  auto portType = cast<TypeAttr>(portTypes[i]).getValue();
1120  p.printType(portType);
1121 
1122  // Print the optional port symbol.
1123  if (!portSyms.empty()) {
1124  if (!portSyms[i].cast<hw::InnerSymAttr>().empty()) {
1125  p << " sym ";
1126  portSyms[i].cast<hw::InnerSymAttr>().print(p);
1127  }
1128  }
1129 
1130  // Print the port specific annotations. The port annotations array will be
1131  // empty if there are none.
1132  if (!portAnnotations.empty() &&
1133  !cast<ArrayAttr>(portAnnotations[i]).empty()) {
1134  p << " ";
1135  p.printAttribute(portAnnotations[i]);
1136  }
1137 
1138  // Print the port location.
1139  // TODO: `printOptionalLocationSpecifier` will emit aliases for locations,
1140  // even if they are not printed. This will have to be fixed upstream. For
1141  // now, use what was specified on the command line.
1142  if (flags.shouldPrintDebugInfo() && !portLocs.empty())
1143  p.printOptionalLocationSpecifier(cast<LocationAttr>(portLocs[i]));
1144  }
1145 
1146  p << ')';
1147  return printedNamesDontMatch;
1148 }
1149 
1150 /// Parse a list of module ports. If port names are SSA identifiers, then this
1151 /// will populate `entryArgs`.
1152 static ParseResult
1153 parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers,
1154  bool supportsSymbols,
1155  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
1156  SmallVectorImpl<Direction> &portDirections,
1157  SmallVectorImpl<Attribute> &portNames,
1158  SmallVectorImpl<Attribute> &portTypes,
1159  SmallVectorImpl<Attribute> &portAnnotations,
1160  SmallVectorImpl<Attribute> &portSyms,
1161  SmallVectorImpl<Attribute> &portLocs) {
1162  auto *context = parser.getContext();
1163 
1164  auto parseArgument = [&]() -> ParseResult {
1165  // Parse port direction.
1166  if (succeeded(parser.parseOptionalKeyword("out")))
1167  portDirections.push_back(Direction::Out);
1168  else if (succeeded(parser.parseKeyword("in", "or 'out'")))
1169  portDirections.push_back(Direction::In);
1170  else
1171  return failure();
1172 
1173  // This is the location or the port declaration in the IR. If there is no
1174  // other location information, we use this to point to the MLIR.
1175  llvm::SMLoc irLoc;
1176 
1177  if (hasSSAIdentifiers) {
1178  OpAsmParser::Argument arg;
1179  if (parser.parseArgument(arg))
1180  return failure();
1181  entryArgs.push_back(arg);
1182 
1183  // The name of an argument is of the form "%42" or "%id", and since
1184  // parsing succeeded, we know it always has one character.
1185  assert(arg.ssaName.name.size() > 1 && arg.ssaName.name[0] == '%' &&
1186  "Unknown MLIR name");
1187  if (isdigit(arg.ssaName.name[1]))
1188  portNames.push_back(StringAttr::get(context, ""));
1189  else
1190  portNames.push_back(
1191  StringAttr::get(context, arg.ssaName.name.drop_front()));
1192 
1193  // Store the location of the SSA name.
1194  irLoc = arg.ssaName.location;
1195 
1196  } else {
1197  // Parse the port name.
1198  irLoc = parser.getCurrentLocation();
1199  std::string portName;
1200  if (parser.parseKeywordOrString(&portName))
1201  return failure();
1202  portNames.push_back(StringAttr::get(context, portName));
1203  }
1204 
1205  // Parse the port type.
1206  Type portType;
1207  if (parser.parseColonType(portType))
1208  return failure();
1209  portTypes.push_back(TypeAttr::get(portType));
1210 
1211  if (hasSSAIdentifiers)
1212  entryArgs.back().type = portType;
1213 
1214  // Parse the optional port symbol.
1215  if (supportsSymbols) {
1216  hw::InnerSymAttr innerSymAttr;
1217  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1218  NamedAttrList dummyAttrs;
1219  if (parser.parseCustomAttributeWithFallback(
1220  innerSymAttr, ::mlir::Type{},
1222  return ::mlir::failure();
1223  }
1224  }
1225  portSyms.push_back(innerSymAttr);
1226  }
1227 
1228  // Parse the port annotations.
1229  ArrayAttr annos;
1230  auto parseResult = parser.parseOptionalAttribute(annos);
1231  if (!parseResult.has_value())
1232  annos = parser.getBuilder().getArrayAttr({});
1233  else if (failed(*parseResult))
1234  return failure();
1235  portAnnotations.push_back(annos);
1236 
1237  // Parse the optional port location.
1238  std::optional<Location> maybeLoc;
1239  if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
1240  return failure();
1241  Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc);
1242  portLocs.push_back(loc);
1243  if (hasSSAIdentifiers)
1244  entryArgs.back().sourceLoc = loc;
1245 
1246  return success();
1247  };
1248 
1249  // Parse all ports.
1250  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1251  parseArgument);
1252 }
1253 
1254 /// Print a paramter list for a module or instance.
1255 static void printParameterList(ArrayAttr parameters, OpAsmPrinter &p) {
1256  if (!parameters || parameters.empty())
1257  return;
1258 
1259  p << '<';
1260  llvm::interleaveComma(parameters, p, [&](Attribute param) {
1261  auto paramAttr = cast<ParamDeclAttr>(param);
1262  p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
1263  if (auto value = paramAttr.getValue()) {
1264  p << " = ";
1265  p.printAttributeWithoutType(value);
1266  }
1267  });
1268  p << '>';
1269 }
1270 
1271 static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
1272  p << " ";
1273 
1274  // Print the visibility of the module.
1275  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1276  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1277  p << visibility.getValue() << ' ';
1278 
1279  // Print the operation and the function name.
1280  p.printSymbolName(op.getModuleName());
1281 
1282  // Print the parameter list (if non-empty).
1283  printParameterList(op->getAttrOfType<ArrayAttr>("parameters"), p);
1284 
1285  // Both modules and external modules have a body, but it is always empty for
1286  // external modules.
1287  Block *body = nullptr;
1288  if (!op->getRegion(0).empty())
1289  body = &op->getRegion(0).front();
1290 
1291  auto portDirections = direction::unpackAttribute(op.getPortDirectionsAttr());
1292 
1293  auto needPortNamesAttr = printModulePorts(
1294  p, body, portDirections, op.getPortNames(), op.getPortTypes(),
1295  op.getPortAnnotations(), op.getPortSymbols(), op.getPortLocations());
1296 
1297  SmallVector<StringRef, 12> omittedAttrs = {
1298  "sym_name", "portDirections", "portTypes", "portAnnotations",
1299  "portSyms", "portLocations", "parameters", visibilityAttrName};
1300 
1301  if (op.getConvention() == Convention::Internal)
1302  omittedAttrs.push_back("convention");
1303 
1304  // We can omit the portNames if they were able to be printed as properly as
1305  // block arguments.
1306  if (!needPortNamesAttr)
1307  omittedAttrs.push_back("portNames");
1308 
1309  // If there are no annotations we can omit the empty array.
1310  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
1311  omittedAttrs.push_back("annotations");
1312 
1313  // If there are no enabled layers, then omit the empty array.
1314  if (auto layers = op->getAttrOfType<ArrayAttr>("layers"))
1315  if (layers.empty())
1316  omittedAttrs.push_back("layers");
1317 
1318  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1319 }
1320 
1321 void FExtModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1322 
1323 void FIntModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1324 
1325 void FMemModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1326 
1327 void FModuleOp::print(OpAsmPrinter &p) {
1328  printFModuleLikeOp(p, *this);
1329 
1330  // Print the body if this is not an external function. Since this block does
1331  // not have terminators, printing the terminator actually just prints the last
1332  // operation.
1333  Region &fbody = getBody();
1334  if (!fbody.empty()) {
1335  p << " ";
1336  p.printRegion(fbody, /*printEntryBlockArgs=*/false,
1337  /*printBlockTerminators=*/true);
1338  }
1339 }
1340 
1341 /// Parse an parameter list if present.
1342 /// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>`
1343 /// parameter-decl ::= identifier `:` type
1344 /// parameter-decl ::= identifier `:` type `=` attribute
1345 ///
1346 static ParseResult
1347 parseOptionalParameters(OpAsmParser &parser,
1348  SmallVectorImpl<Attribute> &parameters) {
1349 
1350  return parser.parseCommaSeparatedList(
1351  OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1352  std::string name;
1353  Type type;
1354  Attribute value;
1355 
1356  if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1357  return failure();
1358 
1359  // Parse the default value if present.
1360  if (succeeded(parser.parseOptionalEqual())) {
1361  if (parser.parseAttribute(value, type))
1362  return failure();
1363  }
1364 
1365  auto &builder = parser.getBuilder();
1366  parameters.push_back(ParamDeclAttr::get(
1367  builder.getContext(), builder.getStringAttr(name), type, value));
1368  return success();
1369  });
1370 }
1371 
1372 static ParseResult parseFModuleLikeOp(OpAsmParser &parser,
1373  OperationState &result,
1374  bool hasSSAIdentifiers) {
1375  auto *context = result.getContext();
1376  auto &builder = parser.getBuilder();
1377 
1378  // Parse the visibility attribute.
1379  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1380 
1381  // Parse the name as a symbol.
1382  StringAttr nameAttr;
1383  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1384  result.attributes))
1385  return failure();
1386 
1387  // Parse optional parameters.
1388  SmallVector<Attribute, 4> parameters;
1389  if (parseOptionalParameters(parser, parameters))
1390  return failure();
1391  result.addAttribute("parameters", builder.getArrayAttr(parameters));
1392 
1393  // Parse the module ports.
1394  SmallVector<OpAsmParser::Argument> entryArgs;
1395  SmallVector<Direction, 4> portDirections;
1396  SmallVector<Attribute, 4> portNames;
1397  SmallVector<Attribute, 4> portTypes;
1398  SmallVector<Attribute, 4> portAnnotations;
1399  SmallVector<Attribute, 4> portSyms;
1400  SmallVector<Attribute, 4> portLocs;
1401  if (parseModulePorts(parser, hasSSAIdentifiers, /*supportsSymbols=*/true,
1402  entryArgs, portDirections, portNames, portTypes,
1403  portAnnotations, portSyms, portLocs))
1404  return failure();
1405 
1406  // If module attributes are present, parse them.
1407  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1408  return failure();
1409 
1410  assert(portNames.size() == portTypes.size());
1411 
1412  // Record the argument and result types as an attribute. This is necessary
1413  // for external modules.
1414 
1415  // Add port directions.
1416  if (!result.attributes.get("portDirections"))
1417  result.addAttribute("portDirections",
1418  direction::packAttribute(context, portDirections));
1419 
1420  // Add port names.
1421  if (!result.attributes.get("portNames"))
1422  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1423 
1424  // Add the port types.
1425  if (!result.attributes.get("portTypes"))
1426  result.addAttribute("portTypes", ArrayAttr::get(context, portTypes));
1427 
1428  // Add the port annotations.
1429  if (!result.attributes.get("portAnnotations")) {
1430  // If there are no portAnnotations, don't add the attribute.
1431  if (llvm::any_of(portAnnotations, [&](Attribute anno) {
1432  return !cast<ArrayAttr>(anno).empty();
1433  }))
1434  result.addAttribute("portAnnotations",
1435  ArrayAttr::get(context, portAnnotations));
1436  }
1437 
1438  // Add port symbols.
1439  if (!result.attributes.get("portSyms")) {
1440  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1441  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1442  }
1443 
1444  // Add port locations.
1445  if (!result.attributes.get("portLocations"))
1446  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1447 
1448  // The annotations attribute is always present, but not printed when empty.
1449  if (!result.attributes.get("annotations"))
1450  result.addAttribute("annotations", builder.getArrayAttr({}));
1451 
1452  // The portAnnotations attribute is always present, but not printed when
1453  // empty.
1454  if (!result.attributes.get("portAnnotations"))
1455  result.addAttribute("portAnnotations", builder.getArrayAttr({}));
1456 
1457  // Parse the optional function body.
1458  auto *body = result.addRegion();
1459 
1460  if (hasSSAIdentifiers) {
1461  if (parser.parseRegion(*body, entryArgs))
1462  return failure();
1463  if (body->empty())
1464  body->push_back(new Block());
1465  }
1466  return success();
1467 }
1468 
1469 ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1470  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/true))
1471  return failure();
1472  if (!result.attributes.get("convention"))
1473  result.addAttribute(
1474  "convention",
1475  ConventionAttr::get(result.getContext(), Convention::Internal));
1476  if (!result.attributes.get("layers"))
1477  result.addAttribute("layers", ArrayAttr::get(parser.getContext(), {}));
1478  return success();
1479 }
1480 
1481 ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1482  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false))
1483  return failure();
1484  if (!result.attributes.get("convention"))
1485  result.addAttribute(
1486  "convention",
1487  ConventionAttr::get(result.getContext(), Convention::Internal));
1488  return success();
1489 }
1490 
1491 ParseResult FIntModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1492  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1493 }
1494 
1495 ParseResult FMemModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1496  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1497 }
1498 
1499 LogicalResult FModuleOp::verify() {
1500  // Verify the block arguments.
1501  auto *body = getBodyBlock();
1502  auto portTypes = getPortTypes();
1503  auto portLocs = getPortLocations();
1504  auto numPorts = portTypes.size();
1505 
1506  // Verify that we have the correct number of block arguments.
1507  if (body->getNumArguments() != numPorts)
1508  return emitOpError("entry block must have ")
1509  << numPorts << " arguments to match module signature";
1510 
1511  // Verify the block arguments' types and locations match our attributes.
1512  for (auto [arg, type, loc] : zip(body->getArguments(), portTypes, portLocs)) {
1513  if (arg.getType() != cast<TypeAttr>(type).getValue())
1514  return emitOpError("block argument types should match signature types");
1515  if (arg.getLoc() != cast<LocationAttr>(loc))
1516  return emitOpError(
1517  "block argument locations should match signature locations");
1518  }
1519 
1520  return success();
1521 }
1522 
1523 static LogicalResult
1524 verifyInternalPaths(FModuleLike op,
1525  std::optional<::mlir::ArrayAttr> internalPaths) {
1526  if (!internalPaths)
1527  return success();
1528 
1529  // If internal paths are present, should cover all ports.
1530  if (internalPaths->size() != op.getNumPorts())
1531  return op.emitError("module has inconsistent internal path array with ")
1532  << internalPaths->size() << " entries for " << op.getNumPorts()
1533  << " ports";
1534 
1535  // No internal paths for non-ref-type ports.
1536  for (auto [idx, path, typeattr] : llvm::enumerate(
1537  internalPaths->getAsRange<InternalPathAttr>(), op.getPortTypes())) {
1538  if (path.getPath() &&
1539  !type_isa<RefType>(cast<TypeAttr>(typeattr).getValue())) {
1540  auto diag =
1541  op.emitError("module has internal path for non-ref-type port ")
1542  << op.getPortNameAttr(idx);
1543  return diag.attachNote(op.getPortLocation(idx)) << "this port";
1544  }
1545  }
1546 
1547  return success();
1548 }
1549 
1550 LogicalResult FExtModuleOp::verify() {
1551  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1552  return failure();
1553 
1554  auto params = getParameters();
1555  if (params.empty())
1556  return success();
1557 
1558  auto checkParmValue = [&](Attribute elt) -> bool {
1559  auto param = cast<ParamDeclAttr>(elt);
1560  auto value = param.getValue();
1561  if (isa<IntegerAttr, StringAttr, FloatAttr, hw::ParamVerbatimAttr>(value))
1562  return true;
1563  emitError() << "has unknown extmodule parameter value '"
1564  << param.getName().getValue() << "' = " << value;
1565  return false;
1566  };
1567 
1568  if (!llvm::all_of(params, checkParmValue))
1569  return failure();
1570 
1571  return success();
1572 }
1573 
1574 LogicalResult FIntModuleOp::verify() {
1575  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1576  return failure();
1577 
1578  auto params = getParameters();
1579  if (params.empty())
1580  return success();
1581 
1582  auto checkParmValue = [&](Attribute elt) -> bool {
1583  auto param = cast<ParamDeclAttr>(elt);
1584  auto value = param.getValue();
1585  if (isa<IntegerAttr, StringAttr, FloatAttr>(value))
1586  return true;
1587  emitError() << "has unknown intmodule parameter value '"
1588  << param.getName().getValue() << "' = " << value;
1589  return false;
1590  };
1591 
1592  if (!llvm::all_of(params, checkParmValue))
1593  return failure();
1594 
1595  return success();
1596 }
1597 
1598 static LogicalResult verifyProbeType(RefType refType, Location loc,
1599  CircuitOp circuitOp,
1600  SymbolTableCollection &symbolTable,
1601  Twine start) {
1602  auto layer = refType.getLayer();
1603  if (!layer)
1604  return success();
1605  auto *layerOp = symbolTable.lookupSymbolIn(circuitOp, layer);
1606  if (!layerOp)
1607  return emitError(loc) << start << " associated with layer '" << layer
1608  << "', but this layer was not defined";
1609  if (!isa<LayerOp>(layerOp)) {
1610  auto diag = emitError(loc)
1611  << start << " associated with layer '" << layer
1612  << "', but symbol '" << layer << "' does not refer to a '"
1613  << LayerOp::getOperationName() << "' op";
1614  return diag.attachNote(layerOp->getLoc()) << "symbol refers to this op";
1615  }
1616  return success();
1617 }
1618 
1619 static LogicalResult verifyPortSymbolUses(FModuleLike module,
1620  SymbolTableCollection &symbolTable) {
1621  // verify types in ports.
1622  auto circuitOp = module->getParentOfType<CircuitOp>();
1623  for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
1624  auto type = module.getPortType(i);
1625 
1626  if (auto refType = type_dyn_cast<RefType>(type)) {
1627  if (failed(verifyProbeType(
1628  refType, module.getPortLocation(i), circuitOp, symbolTable,
1629  Twine("probe port '") + module.getPortName(i) + "' is")))
1630  return failure();
1631  continue;
1632  }
1633 
1634  if (auto classType = dyn_cast<ClassType>(type)) {
1635  auto className = classType.getNameAttr();
1636  auto classOp = dyn_cast_or_null<ClassLike>(
1637  symbolTable.lookupSymbolIn(circuitOp, className));
1638  if (!classOp)
1639  return module.emitOpError() << "references unknown class " << className;
1640 
1641  // verify that the result type agrees with the class definition.
1642  if (failed(classOp.verifyType(classType,
1643  [&]() { return module.emitOpError(); })))
1644  return failure();
1645  continue;
1646  }
1647  }
1648 
1649  return success();
1650 }
1651 
1652 LogicalResult FModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1653  if (failed(
1654  verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable)))
1655  return failure();
1656 
1657  auto circuitOp = (*this)->getParentOfType<CircuitOp>();
1658  for (auto layer : getLayers()) {
1659  if (!symbolTable.lookupSymbolIn(circuitOp, cast<SymbolRefAttr>(layer)))
1660  return emitOpError() << "enables unknown layer '" << layer << "'";
1661  }
1662 
1663  return success();
1664 }
1665 
1666 LogicalResult
1667 FExtModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1668  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1669 }
1670 
1671 LogicalResult
1672 FIntModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1673  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1674 }
1675 
1676 LogicalResult
1677 FMemModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1678  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1679 }
1680 
1681 void FModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1682  mlir::OpAsmSetValueNameFn setNameFn) {
1683  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1684 }
1685 
1686 void FExtModuleOp::getAsmBlockArgumentNames(
1687  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1688  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1689 }
1690 
1691 void FIntModuleOp::getAsmBlockArgumentNames(
1692  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1693  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1694 }
1695 
1696 void FMemModuleOp::getAsmBlockArgumentNames(
1697  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1698  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1699 }
1700 
1701 ArrayAttr FMemModuleOp::getParameters() { return {}; }
1702 
1703 ArrayAttr FModuleOp::getParameters() { return {}; }
1704 
1705 Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1706 
1707 ConventionAttr FIntModuleOp::getConventionAttr() {
1708  return ConventionAttr::get(getContext(), getConvention());
1709 }
1710 
1711 Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1712 
1713 ConventionAttr FMemModuleOp::getConventionAttr() {
1714  return ConventionAttr::get(getContext(), getConvention());
1715 }
1716 
1717 //===----------------------------------------------------------------------===//
1718 // ClassLike Helpers
1719 //===----------------------------------------------------------------------===//
1720 
1722  ClassLike classOp, ClassType type,
1723  function_ref<InFlightDiagnostic()> emitError) {
1724  // This check is probably not required, but done for sanity.
1725  auto name = type.getNameAttr().getAttr();
1726  auto expectedName = classOp.getModuleNameAttr();
1727  if (name != expectedName)
1728  return emitError() << "type has wrong name, got " << name << ", expected "
1729  << expectedName;
1730 
1731  auto elements = type.getElements();
1732  auto numElements = elements.size();
1733  auto expectedNumElements = classOp.getNumPorts();
1734  if (numElements != expectedNumElements)
1735  return emitError() << "has wrong number of ports, got " << numElements
1736  << ", expected " << expectedNumElements;
1737 
1738  auto portNames = classOp.getPortNames();
1739  auto portDirections = classOp.getPortDirections();
1740  auto portTypes = classOp.getPortTypes();
1741 
1742  for (unsigned i = 0; i < numElements; ++i) {
1743  auto element = elements[i];
1744 
1745  auto name = element.name;
1746  auto expectedName = portNames[i];
1747  if (name != expectedName)
1748  return emitError() << "port #" << i << " has wrong name, got " << name
1749  << ", expected " << expectedName;
1750 
1751  auto direction = element.direction;
1752  auto expectedDirection = Direction(portDirections[i]);
1753  if (direction != expectedDirection)
1754  return emitError() << "port " << name << " has wrong direction, got "
1755  << direction::toString(direction) << ", expected "
1756  << direction::toString(expectedDirection);
1757 
1758  auto type = element.type;
1759  auto expectedType = cast<TypeAttr>(portTypes[i]).getValue();
1760  if (type != expectedType)
1761  return emitError() << "port " << name << " has wrong type, got " << type
1762  << ", expected " << expectedType;
1763  }
1764 
1765  return success();
1766 }
1767 
1768 ClassType firrtl::detail::getInstanceTypeForClassLike(ClassLike classOp) {
1769  auto n = classOp.getNumPorts();
1770  SmallVector<ClassElement> elements;
1771  elements.reserve(n);
1772  for (size_t i = 0; i < n; ++i)
1773  elements.push_back({classOp.getPortNameAttr(i), classOp.getPortType(i),
1774  classOp.getPortDirection(i)});
1775  auto name = FlatSymbolRefAttr::get(classOp.getNameAttr());
1776  return ClassType::get(name, elements);
1777 }
1778 
1779 static ParseResult parseClassLike(OpAsmParser &parser, OperationState &result,
1780  bool hasSSAIdentifiers) {
1781  auto *context = result.getContext();
1782  auto &builder = parser.getBuilder();
1783 
1784  // Parse the visibility attribute.
1785  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1786 
1787  // Parse the name as a symbol.
1788  StringAttr nameAttr;
1789  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1790  result.attributes))
1791  return failure();
1792 
1793  // Parse the module ports.
1794  SmallVector<OpAsmParser::Argument> entryArgs;
1795  SmallVector<Direction, 4> portDirections;
1796  SmallVector<Attribute, 4> portNames;
1797  SmallVector<Attribute, 4> portTypes;
1798  SmallVector<Attribute, 4> portAnnotations;
1799  SmallVector<Attribute, 4> portSyms;
1800  SmallVector<Attribute, 4> portLocs;
1801  if (parseModulePorts(parser, hasSSAIdentifiers,
1802  /*supportsSymbols=*/false, entryArgs, portDirections,
1803  portNames, portTypes, portAnnotations, portSyms,
1804  portLocs))
1805  return failure();
1806 
1807  // Ports on ClassLike ops cannot have annotations
1808  for (auto annos : portAnnotations)
1809  if (!cast<ArrayAttr>(annos).empty())
1810  return failure();
1811 
1812  // If attributes are present, parse them.
1813  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1814  return failure();
1815 
1816  assert(portNames.size() == portTypes.size());
1817 
1818  // Record the argument and result types as an attribute. This is necessary
1819  // for external modules.
1820 
1821  // Add port directions.
1822  if (!result.attributes.get("portDirections"))
1823  result.addAttribute("portDirections",
1824  direction::packAttribute(context, portDirections));
1825 
1826  // Add port names.
1827  if (!result.attributes.get("portNames"))
1828  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1829 
1830  // Add the port types.
1831  if (!result.attributes.get("portTypes"))
1832  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
1833 
1834  // Add the port symbols.
1835  if (!result.attributes.get("portSyms")) {
1836  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1837  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1838  }
1839 
1840  // Add port locations.
1841  if (!result.attributes.get("portLocations"))
1842  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1843 
1844  // Notably missing compared to other FModuleLike, we do not track port
1845  // annotations, nor port symbols, on classes.
1846 
1847  // Add the region (unused by extclass).
1848  auto *bodyRegion = result.addRegion();
1849 
1850  if (hasSSAIdentifiers) {
1851  if (parser.parseRegion(*bodyRegion, entryArgs))
1852  return failure();
1853  if (bodyRegion->empty())
1854  bodyRegion->push_back(new Block());
1855  }
1856 
1857  return success();
1858 }
1859 
1860 static void printClassLike(OpAsmPrinter &p, ClassLike op) {
1861  p << ' ';
1862 
1863  // Print the visibility of the class.
1864  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1865  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1866  p << visibility.getValue() << ' ';
1867 
1868  // Print the class name.
1869  p.printSymbolName(op.getName());
1870 
1871  // Both classes and external classes have a body, but it is always empty for
1872  // external classes.
1873  Region &region = op->getRegion(0);
1874  Block *body = nullptr;
1875  if (!region.empty())
1876  body = &region.front();
1877 
1878  auto portDirections = direction::unpackAttribute(op.getPortDirectionsAttr());
1879 
1880  auto needPortNamesAttr = printModulePorts(
1881  p, body, portDirections, op.getPortNames(), op.getPortTypes(), {},
1882  op.getPortSymbols(), op.getPortLocations());
1883 
1884  // Print the attr-dict.
1885  SmallVector<StringRef, 8> omittedAttrs = {
1886  "sym_name", "portNames", "portTypes", "portDirections",
1887  "portSyms", "portLocations", visibilityAttrName};
1888 
1889  // We can omit the portNames if they were able to be printed as properly as
1890  // block arguments.
1891  if (!needPortNamesAttr)
1892  omittedAttrs.push_back("portNames");
1893 
1894  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1895 
1896  // print the body if it exists.
1897  if (!region.empty()) {
1898  p << " ";
1899  auto printEntryBlockArgs = false;
1900  auto printBlockTerminators = false;
1901  p.printRegion(region, printEntryBlockArgs, printBlockTerminators);
1902  }
1903 }
1904 
1905 //===----------------------------------------------------------------------===//
1906 // ClassOp
1907 //===----------------------------------------------------------------------===//
1908 
1909 void ClassOp::build(OpBuilder &builder, OperationState &result, StringAttr name,
1910  ArrayRef<PortInfo> ports) {
1911  assert(
1912  llvm::all_of(ports,
1913  [](const auto &port) { return port.annotations.empty(); }) &&
1914  "class ports may not have annotations");
1915 
1916  buildClass(builder, result, name, ports);
1917 
1918  // Create a region and a block for the body.
1919  auto *bodyRegion = result.regions[0].get();
1920  Block *body = new Block();
1921  bodyRegion->push_back(body);
1922 
1923  // Add arguments to the body block.
1924  for (auto &elt : ports)
1925  body->addArgument(elt.type, elt.loc);
1926 }
1927 
1928 void ClassOp::print(OpAsmPrinter &p) {
1929  printClassLike(p, cast<ClassLike>(getOperation()));
1930 }
1931 
1932 ParseResult ClassOp::parse(OpAsmParser &parser, OperationState &result) {
1933  auto hasSSAIdentifiers = true;
1934  return parseClassLike(parser, result, hasSSAIdentifiers);
1935 }
1936 
1937 LogicalResult ClassOp::verify() {
1938  for (auto operand : getBodyBlock()->getArguments()) {
1939  auto type = operand.getType();
1940  if (!isa<PropertyType>(type)) {
1941  emitOpError("ports on a class must be properties");
1942  return failure();
1943  }
1944  }
1945 
1946  return success();
1947 }
1948 
1949 LogicalResult
1950 ClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
1951  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1952 }
1953 
1954 void ClassOp::getAsmBlockArgumentNames(mlir::Region &region,
1955  mlir::OpAsmSetValueNameFn setNameFn) {
1956  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1957 }
1958 
1959 SmallVector<PortInfo> ClassOp::getPorts() {
1960  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
1961 }
1962 
1963 void ClassOp::erasePorts(const llvm::BitVector &portIndices) {
1964  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1965  getBodyBlock()->eraseArguments(portIndices);
1966 }
1967 
1968 void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
1969  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
1970 }
1971 
1972 Convention ClassOp::getConvention() { return Convention::Internal; }
1973 
1974 ConventionAttr ClassOp::getConventionAttr() {
1975  return ConventionAttr::get(getContext(), getConvention());
1976 }
1977 
1978 ArrayAttr ClassOp::getParameters() { return {}; }
1979 
1980 ArrayAttr ClassOp::getPortAnnotationsAttr() {
1981  return ArrayAttr::get(getContext(), {});
1982 }
1983 
1984 ArrayAttr ClassOp::getLayersAttr() { return ArrayAttr::get(getContext(), {}); }
1985 
1986 ArrayRef<Attribute> ClassOp::getLayers() { return getLayersAttr(); }
1987 
1988 SmallVector<::circt::hw::PortInfo> ClassOp::getPortList() {
1989  return ::getPortListImpl(*this);
1990 }
1991 
1993  return ::getPortImpl(*this, idx);
1994 }
1995 
1996 BlockArgument ClassOp::getArgument(size_t portNumber) {
1997  return getBodyBlock()->getArgument(portNumber);
1998 }
1999 
2000 bool ClassOp::canDiscardOnUseEmpty() {
2001  // ClassOps are referenced by ClassTypes, and these uses are not
2002  // discoverable by the symbol infrastructure. Return false here to prevent
2003  // passes like symbolDCE from removing our classes.
2004  return false;
2005 }
2006 
2007 //===----------------------------------------------------------------------===//
2008 // ExtClassOp
2009 //===----------------------------------------------------------------------===//
2010 
2011 void ExtClassOp::build(OpBuilder &builder, OperationState &result,
2012  StringAttr name, ArrayRef<PortInfo> ports) {
2013  assert(
2014  llvm::all_of(ports,
2015  [](const auto &port) { return port.annotations.empty(); }) &&
2016  "class ports may not have annotations");
2017 
2018  buildClass(builder, result, name, ports);
2019 }
2020 
2021 void ExtClassOp::print(OpAsmPrinter &p) {
2022  printClassLike(p, cast<ClassLike>(getOperation()));
2023 }
2024 
2025 ParseResult ExtClassOp::parse(OpAsmParser &parser, OperationState &result) {
2026  auto hasSSAIdentifiers = false;
2027  return parseClassLike(parser, result, hasSSAIdentifiers);
2028 }
2029 
2030 LogicalResult
2031 ExtClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2032  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2033 }
2034 
2035 void ExtClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2036  mlir::OpAsmSetValueNameFn setNameFn) {
2037  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2038 }
2039 
2040 SmallVector<PortInfo> ExtClassOp::getPorts() {
2041  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2042 }
2043 
2044 void ExtClassOp::erasePorts(const llvm::BitVector &portIndices) {
2045  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2046 }
2047 
2048 void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2049  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2050 }
2051 
2052 Convention ExtClassOp::getConvention() { return Convention::Internal; }
2053 
2054 ConventionAttr ExtClassOp::getConventionAttr() {
2055  return ConventionAttr::get(getContext(), getConvention());
2056 }
2057 
2058 ArrayAttr ExtClassOp::getLayersAttr() {
2059  return ArrayAttr::get(getContext(), {});
2060 }
2061 
2062 ArrayRef<Attribute> ExtClassOp::getLayers() { return getLayersAttr(); }
2063 
2064 ArrayAttr ExtClassOp::getParameters() { return {}; }
2065 
2066 ArrayAttr ExtClassOp::getPortAnnotationsAttr() {
2067  return ArrayAttr::get(getContext(), {});
2068 }
2069 
2070 SmallVector<::circt::hw::PortInfo> ExtClassOp::getPortList() {
2071  return ::getPortListImpl(*this);
2072 }
2073 
2075  return ::getPortImpl(*this, idx);
2076 }
2077 
2078 bool ExtClassOp::canDiscardOnUseEmpty() {
2079  // ClassOps are referenced by ClassTypes, and these uses are not
2080  // discovereable by the symbol infrastructure. Return false here to prevent
2081  // passes like symbolDCE from removing our classes.
2082  return false;
2083 }
2084 
2085 //===----------------------------------------------------------------------===//
2086 // InstanceOp
2087 //===----------------------------------------------------------------------===//
2088 
2089 SmallVector<::circt::hw::PortInfo> InstanceOp::getPortList() {
2090  auto circuit = (*this)->getParentOfType<CircuitOp>();
2091  if (!circuit)
2092  llvm::report_fatal_error("instance op not in circuit");
2093  return circuit.lookupSymbol<hw::PortList>(getModuleNameAttr()).getPortList();
2094 }
2095 
2096 void InstanceOp::build(
2097  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2098  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2099  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2100  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2101  ArrayRef<Attribute> layers, bool lowerToBind, StringAttr innerSym) {
2102  build(builder, result, resultTypes, moduleName, name, nameKind,
2103  portDirections, portNames, annotations, portAnnotations, layers,
2104  lowerToBind,
2105  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2106 }
2107 
2108 void InstanceOp::build(
2109  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2110  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2111  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2112  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2113  ArrayRef<Attribute> layers, bool lowerToBind, hw::InnerSymAttr innerSym) {
2114  result.addTypes(resultTypes);
2115  result.addAttribute("moduleName",
2116  SymbolRefAttr::get(builder.getContext(), moduleName));
2117  result.addAttribute("name", builder.getStringAttr(name));
2118  result.addAttribute(
2119  "portDirections",
2120  direction::packAttribute(builder.getContext(), portDirections));
2121  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2122  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2123  result.addAttribute("layers", builder.getArrayAttr(layers));
2124  if (lowerToBind)
2125  result.addAttribute("lowerToBind", builder.getUnitAttr());
2126  if (innerSym)
2127  result.addAttribute("inner_sym", innerSym);
2128  result.addAttribute("nameKind",
2129  NameKindEnumAttr::get(builder.getContext(), nameKind));
2130 
2131  if (portAnnotations.empty()) {
2132  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2133  builder.getArrayAttr({}));
2134  result.addAttribute("portAnnotations",
2135  builder.getArrayAttr(portAnnotationsVec));
2136  } else {
2137  assert(portAnnotations.size() == resultTypes.size());
2138  result.addAttribute("portAnnotations",
2139  builder.getArrayAttr(portAnnotations));
2140  }
2141 }
2142 
2143 void InstanceOp::build(OpBuilder &builder, OperationState &result,
2144  FModuleLike module, StringRef name,
2145  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2146  ArrayRef<Attribute> portAnnotations, bool lowerToBind,
2147  hw::InnerSymAttr innerSym) {
2148 
2149  // Gather the result types.
2150  SmallVector<Type> resultTypes;
2151  resultTypes.reserve(module.getNumPorts());
2152  llvm::transform(
2153  module.getPortTypes(), std::back_inserter(resultTypes),
2154  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2155 
2156  // Create the port annotations.
2157  ArrayAttr portAnnotationsAttr;
2158  if (portAnnotations.empty()) {
2159  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2160  resultTypes.size(), builder.getArrayAttr({})));
2161  } else {
2162  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2163  }
2164 
2165  return build(
2166  builder, result, resultTypes,
2167  SymbolRefAttr::get(builder.getContext(), module.getModuleNameAttr()),
2168  builder.getStringAttr(name),
2169  NameKindEnumAttr::get(builder.getContext(), nameKind),
2170  module.getPortDirectionsAttr(), module.getPortNamesAttr(),
2171  builder.getArrayAttr(annotations), portAnnotationsAttr,
2172  module.getLayersAttr(), lowerToBind ? builder.getUnitAttr() : UnitAttr(),
2173  innerSym);
2174 }
2175 
2176 void InstanceOp::build(OpBuilder &builder, OperationState &odsState,
2177  ArrayRef<PortInfo> ports, StringRef moduleName,
2178  StringRef name, NameKindEnum nameKind,
2179  ArrayRef<Attribute> annotations,
2180  ArrayRef<Attribute> layers, bool lowerToBind,
2181  hw::InnerSymAttr innerSym) {
2182  // Gather the result types.
2183  SmallVector<Type> newResultTypes;
2184  SmallVector<Direction> newPortDirections;
2185  SmallVector<Attribute> newPortNames;
2186  SmallVector<Attribute> newPortAnnotations;
2187  for (auto &p : ports) {
2188  newResultTypes.push_back(p.type);
2189  newPortDirections.push_back(p.direction);
2190  newPortNames.push_back(p.name);
2191  newPortAnnotations.push_back(p.annotations.getArrayAttr());
2192  }
2193 
2194  return build(builder, odsState, newResultTypes, moduleName, name, nameKind,
2195  newPortDirections, newPortNames, annotations, newPortAnnotations,
2196  layers, lowerToBind, innerSym);
2197 }
2198 
2199 LogicalResult InstanceOp::verify() {
2200  // The instance may only be instantiated under its required layers.
2201  auto ambientLayers = getAmbientLayersAt(getOperation());
2202  SmallVector<SymbolRefAttr> missingLayers;
2203  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2204  if (!isLayerCompatibleWith(layer, ambientLayers))
2205  missingLayers.push_back(layer);
2206 
2207  if (missingLayers.empty())
2208  return success();
2209 
2210  auto diag =
2211  emitOpError("ambient layers are insufficient to instantiate module");
2212  auto &note = diag.attachNote();
2213  note << "missing layer requirements: ";
2214  interleaveComma(missingLayers, note);
2215  return failure();
2216 }
2217 
2218 /// Builds a new `InstanceOp` with the ports listed in `portIndices` erased, and
2219 /// updates any users of the remaining ports to point at the new instance.
2220 InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
2221  const llvm::BitVector &portIndices) {
2222  assert(portIndices.size() >= getNumResults() &&
2223  "portIndices is not at least as large as getNumResults()");
2224 
2225  if (portIndices.none())
2226  return *this;
2227 
2228  SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2229  SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2230  SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2231  direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2232  SmallVector<Attribute> newPortNames =
2233  removeElementsAtIndices(getPortNames().getValue(), portIndices);
2234  SmallVector<Attribute> newPortAnnotations =
2235  removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2236 
2237  auto newOp = builder.create<InstanceOp>(
2238  getLoc(), newResultTypes, getModuleName(), getName(), getNameKind(),
2239  newPortDirections, newPortNames, getAnnotations().getValue(),
2240  newPortAnnotations, getLayers(), getLowerToBind(), getInnerSymAttr());
2241 
2242  for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2243  oldIdx != numOldPorts; ++oldIdx) {
2244  if (portIndices.test(oldIdx)) {
2245  assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2246  continue;
2247  }
2248  getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2249  ++newIdx;
2250  }
2251 
2252  // Compy over "output_file" information so that this is not lost when ports
2253  // are erased.
2254  //
2255  // TODO: Other attributes may need to be copied over.
2256  if (auto outputFile = (*this)->getAttr("output_file"))
2257  newOp->setAttr("output_file", outputFile);
2258 
2259  return newOp;
2260 }
2261 
2262 ArrayAttr InstanceOp::getPortAnnotation(unsigned portIdx) {
2263  assert(portIdx < getNumResults() &&
2264  "index should be smaller than result number");
2265  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2266 }
2267 
2268 void InstanceOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2269  assert(annotations.size() == getNumResults() &&
2270  "number of annotations is not equal to result number");
2271  (*this)->setAttr("portAnnotations",
2272  ArrayAttr::get(getContext(), annotations));
2273 }
2274 
2275 InstanceOp
2276 InstanceOp::cloneAndInsertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2277  auto portSize = ports.size();
2278  auto newPortCount = getNumResults() + portSize;
2279  SmallVector<Direction> newPortDirections;
2280  newPortDirections.reserve(newPortCount);
2281  SmallVector<Attribute> newPortNames;
2282  newPortNames.reserve(newPortCount);
2283  SmallVector<Type> newPortTypes;
2284  newPortTypes.reserve(newPortCount);
2285  SmallVector<Attribute> newPortAnnos;
2286  newPortAnnos.reserve(newPortCount);
2287 
2288  unsigned oldIndex = 0;
2289  unsigned newIndex = 0;
2290  while (oldIndex + newIndex < newPortCount) {
2291  // Check if we should insert a port here.
2292  if (newIndex < portSize && ports[newIndex].first == oldIndex) {
2293  auto &newPort = ports[newIndex].second;
2294  newPortDirections.push_back(newPort.direction);
2295  newPortNames.push_back(newPort.name);
2296  newPortTypes.push_back(newPort.type);
2297  newPortAnnos.push_back(newPort.annotations.getArrayAttr());
2298  ++newIndex;
2299  } else {
2300  // Copy the next old port.
2301  newPortDirections.push_back(getPortDirection(oldIndex));
2302  newPortNames.push_back(getPortName(oldIndex));
2303  newPortTypes.push_back(getType(oldIndex));
2304  newPortAnnos.push_back(getPortAnnotation(oldIndex));
2305  ++oldIndex;
2306  }
2307  }
2308 
2309  // Create a new instance op with the reset inserted.
2310  return OpBuilder(*this).create<InstanceOp>(
2311  getLoc(), newPortTypes, getModuleName(), getName(), getNameKind(),
2312  newPortDirections, newPortNames, getAnnotations().getValue(),
2313  newPortAnnos, getLayers(), getLowerToBind(), getInnerSymAttr());
2314 }
2315 
2316 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2317  return instance_like_impl::verifyReferencedModule(*this, symbolTable,
2318  getModuleNameAttr());
2319 }
2320 
2321 StringRef InstanceOp::getInstanceName() { return getName(); }
2322 
2323 StringAttr InstanceOp::getInstanceNameAttr() { return getNameAttr(); }
2324 
2325 void InstanceOp::print(OpAsmPrinter &p) {
2326  // Print the instance name.
2327  p << " ";
2328  p.printKeywordOrString(getName());
2329  if (auto attr = getInnerSymAttr()) {
2330  p << " sym ";
2331  p.printSymbolName(attr.getSymName());
2332  }
2333  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2334  p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2335 
2336  // Print the attr-dict.
2337  SmallVector<StringRef, 10> omittedAttrs = {
2338  "moduleName", "name", "portDirections",
2339  "portNames", "portTypes", "portAnnotations",
2340  "inner_sym", "nameKind"};
2341  if (getAnnotations().empty())
2342  omittedAttrs.push_back("annotations");
2343  if (getLayers().empty())
2344  omittedAttrs.push_back("layers");
2345  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2346 
2347  // Print the module name.
2348  p << " ";
2349  p.printSymbolName(getModuleName());
2350 
2351  // Collect all the result types as TypeAttrs for printing.
2352  SmallVector<Attribute> portTypes;
2353  portTypes.reserve(getNumResults());
2354  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2355  &TypeAttr::get);
2356  auto portDirections = direction::unpackAttribute(getPortDirectionsAttr());
2357  printModulePorts(p, /*block=*/nullptr, portDirections,
2358  getPortNames().getValue(), portTypes,
2359  getPortAnnotations().getValue(), {}, {});
2360 }
2361 
2362 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
2363  auto *context = parser.getContext();
2364  auto &resultAttrs = result.attributes;
2365 
2366  std::string name;
2367  hw::InnerSymAttr innerSymAttr;
2368  FlatSymbolRefAttr moduleName;
2369  SmallVector<OpAsmParser::Argument> entryArgs;
2370  SmallVector<Direction, 4> portDirections;
2371  SmallVector<Attribute, 4> portNames;
2372  SmallVector<Attribute, 4> portTypes;
2373  SmallVector<Attribute, 4> portAnnotations;
2374  SmallVector<Attribute, 4> portSyms;
2375  SmallVector<Attribute, 4> portLocs;
2376  NameKindEnumAttr nameKind;
2377 
2378  if (parser.parseKeywordOrString(&name))
2379  return failure();
2380  if (succeeded(parser.parseOptionalKeyword("sym"))) {
2381  if (parser.parseCustomAttributeWithFallback(
2382  innerSymAttr, ::mlir::Type{},
2383  hw::InnerSymbolTable::getInnerSymbolAttrName(),
2384  result.attributes)) {
2385  return ::mlir::failure();
2386  }
2387  }
2388  if (parseNameKind(parser, nameKind) ||
2389  parser.parseOptionalAttrDict(result.attributes) ||
2390  parser.parseAttribute(moduleName, "moduleName", resultAttrs) ||
2391  parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2392  /*supportsSymbols=*/false, entryArgs, portDirections,
2393  portNames, portTypes, portAnnotations, portSyms,
2394  portLocs))
2395  return failure();
2396 
2397  // Add the attributes. We let attributes defined in the attr-dict override
2398  // attributes parsed out of the module signature.
2399  if (!resultAttrs.get("moduleName"))
2400  result.addAttribute("moduleName", moduleName);
2401  if (!resultAttrs.get("name"))
2402  result.addAttribute("name", StringAttr::get(context, name));
2403  result.addAttribute("nameKind", nameKind);
2404  if (!resultAttrs.get("portDirections"))
2405  result.addAttribute("portDirections",
2406  direction::packAttribute(context, portDirections));
2407  if (!resultAttrs.get("portNames"))
2408  result.addAttribute("portNames", ArrayAttr::get(context, portNames));
2409  if (!resultAttrs.get("portAnnotations"))
2410  result.addAttribute("portAnnotations",
2411  ArrayAttr::get(context, portAnnotations));
2412 
2413  // Annotations, layers, and LowerToBind are omitted in the printed format
2414  // if they are empty, empty, and false (respectively).
2415  if (!resultAttrs.get("annotations"))
2416  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
2417  if (!resultAttrs.get("layers"))
2418  resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));
2419 
2420  // Add result types.
2421  result.types.reserve(portTypes.size());
2422  llvm::transform(
2423  portTypes, std::back_inserter(result.types),
2424  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2425 
2426  return success();
2427 }
2428 
2430  StringRef base = getName();
2431  if (base.empty())
2432  base = "inst";
2433 
2434  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
2435  setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
2436  }
2437 }
2438 
2439 std::optional<size_t> InstanceOp::getTargetResultIndex() {
2440  // Inner symbols on instance operations target the op not any result.
2441  return std::nullopt;
2442 }
2443 
2444 // -----------------------------------------------------------------------------
2445 // InstanceChoiceOp
2446 // -----------------------------------------------------------------------------
2447 
2448 void InstanceChoiceOp::build(
2449  OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
2450  ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
2451  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2452  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2453  // Gather the result types.
2454  SmallVector<Type> resultTypes;
2455  for (Attribute portType : defaultModule.getPortTypes())
2456  resultTypes.push_back(cast<TypeAttr>(portType).getValue());
2457 
2458  // Create the port annotations.
2459  ArrayAttr portAnnotationsAttr;
2460  if (portAnnotations.empty()) {
2461  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2462  resultTypes.size(), builder.getArrayAttr({})));
2463  } else {
2464  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2465  }
2466 
2467  // Gather the module & case names.
2468  SmallVector<Attribute> moduleNames, caseNames;
2469  moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
2470  for (auto [caseOption, caseModule] : cases) {
2471  auto caseGroup = caseOption->getParentOfType<OptionOp>();
2472  caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
2473  {SymbolRefAttr::get(caseOption)}));
2474  moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
2475  }
2476 
2477  return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
2478  builder.getArrayAttr(caseNames), builder.getStringAttr(name),
2479  NameKindEnumAttr::get(builder.getContext(), nameKind),
2480  defaultModule.getPortDirectionsAttr(),
2481  defaultModule.getPortNamesAttr(),
2482  builder.getArrayAttr(annotations), portAnnotationsAttr,
2483  defaultModule.getLayersAttr(),
2484  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2485 }
2486 
2487 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
2488  return std::nullopt;
2489 }
2490 
2491 void InstanceChoiceOp::print(OpAsmPrinter &p) {
2492  // Print the instance name.
2493  p << " ";
2494  p.printKeywordOrString(getName());
2495  if (auto attr = getInnerSymAttr()) {
2496  p << " sym ";
2497  p.printSymbolName(attr.getSymName());
2498  }
2499  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2500  p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2501 
2502  // Print the attr-dict.
2503  SmallVector<StringRef, 10> omittedAttrs = {
2504  "moduleNames", "caseNames", "name",
2505  "portDirections", "portNames", "portTypes",
2506  "portAnnotations", "inner_sym", "nameKind"};
2507  if (getAnnotations().empty())
2508  omittedAttrs.push_back("annotations");
2509  if (getLayers().empty())
2510  omittedAttrs.push_back("layers");
2511  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2512 
2513  // Print the module name.
2514  p << ' ';
2515 
2516  auto moduleNames = getModuleNamesAttr();
2517  auto caseNames = getCaseNamesAttr();
2518 
2519  p.printSymbolName(moduleNames[0].cast<FlatSymbolRefAttr>().getValue());
2520 
2521  p << " alternatives ";
2522  p.printSymbolName(
2523  caseNames[0].cast<SymbolRefAttr>().getRootReference().getValue());
2524  p << " { ";
2525  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2526  if (i != 0)
2527  p << ", ";
2528 
2529  auto symbol = caseNames[i].cast<SymbolRefAttr>();
2530  p.printSymbolName(symbol.getNestedReferences()[0].getValue());
2531  p << " -> ";
2532  p.printSymbolName(moduleNames[i + 1].cast<FlatSymbolRefAttr>().getValue());
2533  }
2534 
2535  p << " } ";
2536 
2537  // Collect all the result types as TypeAttrs for printing.
2538  SmallVector<Attribute> portTypes;
2539  portTypes.reserve(getNumResults());
2540  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2541  &TypeAttr::get);
2542  auto portDirections = direction::unpackAttribute(getPortDirectionsAttr());
2543  printModulePorts(p, /*block=*/nullptr, portDirections,
2544  getPortNames().getValue(), portTypes,
2545  getPortAnnotations().getValue(), {}, {});
2546 }
2547 
2548 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
2549  OperationState &result) {
2550  auto *context = parser.getContext();
2551  auto &resultAttrs = result.attributes;
2552 
2553  std::string name;
2554  hw::InnerSymAttr innerSymAttr;
2555  SmallVector<Attribute> moduleNames;
2556  SmallVector<Attribute> caseNames;
2557  SmallVector<OpAsmParser::Argument> entryArgs;
2558  SmallVector<Direction, 4> portDirections;
2559  SmallVector<Attribute, 4> portNames;
2560  SmallVector<Attribute, 4> portTypes;
2561  SmallVector<Attribute, 4> portAnnotations;
2562  SmallVector<Attribute, 4> portSyms;
2563  SmallVector<Attribute, 4> portLocs;
2564  NameKindEnumAttr nameKind;
2565 
2566  if (parser.parseKeywordOrString(&name))
2567  return failure();
2568  if (succeeded(parser.parseOptionalKeyword("sym"))) {
2569  if (parser.parseCustomAttributeWithFallback(
2570  innerSymAttr, Type{},
2571  hw::InnerSymbolTable::getInnerSymbolAttrName(),
2572  result.attributes)) {
2573  return failure();
2574  }
2575  }
2576  if (parseNameKind(parser, nameKind) ||
2577  parser.parseOptionalAttrDict(result.attributes))
2578  return failure();
2579 
2580  FlatSymbolRefAttr defaultModuleName;
2581  if (parser.parseAttribute(defaultModuleName))
2582  return failure();
2583  moduleNames.push_back(defaultModuleName);
2584 
2585  // alternatives { @opt::@case -> @target, ... }
2586  {
2587  FlatSymbolRefAttr optionName;
2588  if (parser.parseKeyword("alternatives") ||
2589  parser.parseAttribute(optionName) || parser.parseLBrace())
2590  return failure();
2591 
2592  FlatSymbolRefAttr moduleName;
2593  StringAttr caseName;
2594  while (succeeded(parser.parseOptionalSymbolName(caseName))) {
2595  if (parser.parseArrow() || parser.parseAttribute(moduleName))
2596  return failure();
2597  moduleNames.push_back(moduleName);
2598  caseNames.push_back(SymbolRefAttr::get(
2599  optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
2600  if (failed(parser.parseOptionalComma()))
2601  break;
2602  }
2603  if (parser.parseRBrace())
2604  return failure();
2605  }
2606 
2607  if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2608  /*supportsSymbols=*/false, entryArgs, portDirections,
2609  portNames, portTypes, portAnnotations, portSyms,
2610  portLocs))
2611  return failure();
2612 
2613  // Add the attributes. We let attributes defined in the attr-dict override
2614  // attributes parsed out of the module signature.
2615  if (!resultAttrs.get("moduleNames"))
2616  result.addAttribute("moduleNames", ArrayAttr::get(context, moduleNames));
2617  if (!resultAttrs.get("caseNames"))
2618  result.addAttribute("caseNames", ArrayAttr::get(context, caseNames));
2619  if (!resultAttrs.get("name"))
2620  result.addAttribute("name", StringAttr::get(context, name));
2621  result.addAttribute("nameKind", nameKind);
2622  if (!resultAttrs.get("portDirections"))
2623  result.addAttribute("portDirections",
2624  direction::packAttribute(context, portDirections));
2625  if (!resultAttrs.get("portNames"))
2626  result.addAttribute("portNames", ArrayAttr::get(context, portNames));
2627  if (!resultAttrs.get("portAnnotations"))
2628  result.addAttribute("portAnnotations",
2629  ArrayAttr::get(context, portAnnotations));
2630 
2631  // Annotations, layers, and LowerToBind are omitted in the printed format if
2632  // they are empty, empty, and false (respectively).
2633  if (!resultAttrs.get("annotations"))
2634  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
2635  if (!resultAttrs.get("layers"))
2636  resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));
2637 
2638  // Add result types.
2639  result.types.reserve(portTypes.size());
2640  llvm::transform(
2641  portTypes, std::back_inserter(result.types),
2642  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2643 
2644  return success();
2645 }
2646 
  // Name each result "<instance name>_<port name>", using "inst" as the base
  // when the instance has an empty name. Results and port names are paired
  // positionally.
  StringRef base = getName().empty() ? "inst" : getName();
  for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
    setNameFn(result, (base + "_" + name.cast<StringAttr>().getValue()).str());
}
2652 
2653 LogicalResult InstanceChoiceOp::verify() {
2654  if (getCaseNamesAttr().empty())
2655  return emitOpError() << "must have at least one case";
2656  if (getModuleNamesAttr().size() != getCaseNamesAttr().size() + 1)
2657  return emitOpError() << "number of referenced modules does not match the "
2658  "number of options";
2659 
2660  // The modules may only be instantiated under their required layers (which
2661  // are the same for all modules).
2662  auto ambientLayers = getAmbientLayersAt(getOperation());
2663  SmallVector<SymbolRefAttr> missingLayers;
2664  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2665  if (!isLayerCompatibleWith(layer, ambientLayers))
2666  missingLayers.push_back(layer);
2667 
2668  if (missingLayers.empty())
2669  return success();
2670 
2671  auto diag =
2672  emitOpError("ambient layers are insufficient to instantiate module");
2673  auto &note = diag.attachNote();
2674  note << "missing layer requirements: ";
2675  interleaveComma(missingLayers, note);
2676  return failure();
2677 }
2678 
/// Verify the symbols referenced by this instance choice: every entry of
/// `moduleNames` must resolve to a module, and every entry of `caseNames`
/// must be a case of one and the same existing option.
LogicalResult
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto caseNames = getCaseNamesAttr();
  // Each referenced module (the default and every alternative) must resolve.
  for (auto moduleName : getModuleNamesAttr()) {
        *this, symbolTable, moduleName.cast<FlatSymbolRefAttr>())))
      return failure();
  }

  // NOTE(review): caseNames[0] is accessed without a size check; presumably
  // safe because verify() rejects an empty case list before symbol uses are
  // checked — confirm the verification order.
  auto root = caseNames[0].cast<SymbolRefAttr>().getRootReference();
  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
    auto ref = caseNames[i].cast<SymbolRefAttr>();
    auto refRoot = ref.getRootReference();
    // Every case must belong to the same option as the first one.
    if (ref.getRootReference() != root)
      return emitOpError() << "case " << ref
                           << " is not in the same option group as "
                           << caseNames[0];

    // The option itself must exist...
    if (!symbolTable.lookupNearestSymbolFrom<OptionOp>(*this, refRoot))
      return emitOpError() << "option " << refRoot << " does not exist";

    // ...and it must contain the referenced case.
    if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
      return emitOpError() << "option " << refRoot
                           << " does not contain option case " << ref;
  }

  return success();
}
2707 
2708 FlatSymbolRefAttr
2709 InstanceChoiceOp::getTargetOrDefaultAttr(OptionCaseOp option) {
2710  auto caseNames = getCaseNamesAttr();
2711  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2712  StringAttr caseSym = caseNames[i].cast<SymbolRefAttr>().getLeafReference();
2713  if (caseSym == option.getSymName())
2714  return getModuleNamesAttr()[i + 1].cast<FlatSymbolRefAttr>();
2715  }
2716  return getDefaultTargetAttr();
2717 }
2718 
2719 SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1>
2720 InstanceChoiceOp::getTargetChoices() {
2721  auto caseNames = getCaseNamesAttr();
2722  auto moduleNames = getModuleNamesAttr();
2723  SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1> choices;
2724  for (size_t i = 0; i < caseNames.size(); ++i) {
2725  choices.emplace_back(caseNames[i].cast<SymbolRefAttr>(),
2726  moduleNames[i + 1].cast<FlatSymbolRefAttr>());
2727  }
2728 
2729  return choices;
2730 }
2731 
2732 //===----------------------------------------------------------------------===//
2733 // MemOp
2734 //===----------------------------------------------------------------------===//
2735 
2736 void MemOp::build(OpBuilder &builder, OperationState &result,
2737  TypeRange resultTypes, uint32_t readLatency,
2738  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2739  ArrayRef<Attribute> portNames, StringRef name,
2740  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2741  ArrayRef<Attribute> portAnnotations,
2742  hw::InnerSymAttr innerSym) {
2743  result.addAttribute(
2744  "readLatency",
2745  builder.getIntegerAttr(builder.getIntegerType(32), readLatency));
2746  result.addAttribute(
2747  "writeLatency",
2748  builder.getIntegerAttr(builder.getIntegerType(32), writeLatency));
2749  result.addAttribute(
2750  "depth", builder.getIntegerAttr(builder.getIntegerType(64), depth));
2751  result.addAttribute("ruw", ::RUWAttrAttr::get(builder.getContext(), ruw));
2752  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2753  result.addAttribute("name", builder.getStringAttr(name));
2754  result.addAttribute("nameKind",
2755  NameKindEnumAttr::get(builder.getContext(), nameKind));
2756  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2757  if (innerSym)
2758  result.addAttribute("inner_sym", innerSym);
2759  result.addTypes(resultTypes);
2760 
2761  if (portAnnotations.empty()) {
2762  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2763  builder.getArrayAttr({}));
2764  result.addAttribute("portAnnotations",
2765  builder.getArrayAttr(portAnnotationsVec));
2766  } else {
2767  assert(portAnnotations.size() == resultTypes.size());
2768  result.addAttribute("portAnnotations",
2769  builder.getArrayAttr(portAnnotations));
2770  }
2771 }
2772 
2773 void MemOp::build(OpBuilder &builder, OperationState &result,
2774  TypeRange resultTypes, uint32_t readLatency,
2775  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2776  ArrayRef<Attribute> portNames, StringRef name,
2777  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2778  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2779  build(builder, result, resultTypes, readLatency, writeLatency, depth, ruw,
2780  portNames, name, nameKind, annotations, portAnnotations,
2781  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2782 }
2783 
2784 ArrayAttr MemOp::getPortAnnotation(unsigned portIdx) {
2785  assert(portIdx < getNumResults() &&
2786  "index should be smaller than result number");
2787  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2788 }
2789 
2790 void MemOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2791  assert(annotations.size() == getNumResults() &&
2792  "number of annotations is not equal to result number");
2793  (*this)->setAttr("portAnnotations",
2794  ArrayAttr::get(getContext(), annotations));
2795 }
2796 
2797 // Get the number of read, write and read-write ports.
2798 void MemOp::getNumPorts(size_t &numReadPorts, size_t &numWritePorts,
2799  size_t &numReadWritePorts, size_t &numDbgsPorts) {
2800  numReadPorts = 0;
2801  numWritePorts = 0;
2802  numReadWritePorts = 0;
2803  numDbgsPorts = 0;
2804  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2805  auto portKind = getPortKind(i);
2806  if (portKind == MemOp::PortKind::Debug)
2807  ++numDbgsPorts;
2808  else if (portKind == MemOp::PortKind::Read)
2809  ++numReadPorts;
2810  else if (portKind == MemOp::PortKind::Write) {
2811  ++numWritePorts;
2812  } else
2813  ++numReadWritePorts;
2814  }
2815 }
2816 
2817 /// Verify the correctness of a MemOp.
2818 LogicalResult MemOp::verify() {
2819 
2820  // Store the port names as we find them. This lets us check quickly
2821  // for uniqueneess.
2822  llvm::SmallDenseSet<Attribute, 8> portNamesSet;
2823 
2824  // Store the previous data type. This lets us check that the data
2825  // type is consistent across all ports.
2826  FIRRTLType oldDataType;
2827 
2828  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2829  auto portName = getPortName(i);
2830 
2831  // Get a bundle type representing this port, stripping an outer
2832  // flip if it exists. If this is not a bundle<> or
2833  // flip<bundle<>>, then this is an error.
2834  BundleType portBundleType =
2835  type_dyn_cast<BundleType>(getResult(i).getType());
2836 
2837  // Require that all port names are unique.
2838  if (!portNamesSet.insert(portName).second) {
2839  emitOpError() << "has non-unique port name " << portName;
2840  return failure();
2841  }
2842 
2843  // Determine the kind of the memory. If the kind cannot be
2844  // determined, then it's indicative of the wrong number of fields
2845  // in the type (but we don't know any more just yet).
2846 
2847  auto elt = getPortNamed(portName);
2848  if (!elt) {
2849  emitOpError() << "could not get port with name " << portName;
2850  return failure();
2851  }
2852  auto firrtlType = type_cast<FIRRTLType>(elt.getType());
2853  MemOp::PortKind portKind = getMemPortKindFromType(firrtlType);
2854 
2855  if (portKind == MemOp::PortKind::Debug &&
2856  !type_isa<RefType>(getResult(i).getType()))
2857  return emitOpError() << "has an invalid type on port " << portName
2858  << " (expected Read/Write/ReadWrite/Debug)";
2859  if (type_isa<RefType>(firrtlType) && e == 1)
2860  return emitOpError()
2861  << "cannot have only one port of debug type. Debug port can only "
2862  "exist alongside other read/write/read-write port";
2863 
2864  // Safely search for the "data" field, erroring if it can't be
2865  // found.
2866  FIRRTLBaseType dataType;
2867  if (portKind == MemOp::PortKind::Debug) {
2868  auto resType = type_cast<RefType>(getResult(i).getType());
2869  if (!(resType && type_isa<FVectorType>(resType.getType())))
2870  return emitOpError() << "debug ports must be a RefType of FVectorType";
2871  dataType = type_cast<FVectorType>(resType.getType()).getElementType();
2872  } else {
2873  auto dataTypeOption = portBundleType.getElement("data");
2874  if (!dataTypeOption && portKind == MemOp::PortKind::ReadWrite)
2875  dataTypeOption = portBundleType.getElement("wdata");
2876  if (!dataTypeOption) {
2877  emitOpError() << "has no data field on port " << portName
2878  << " (expected to see \"data\" for a read or write "
2879  "port or \"rdata\" for a read/write port)";
2880  return failure();
2881  }
2882  dataType = dataTypeOption->type;
2883  // Read data is expected to ba a flip.
2884  if (portKind == MemOp::PortKind::Read) {
2885  // FIXME error on missing bundle flip
2886  }
2887  }
2888 
2889  // Error if the data type isn't passive.
2890  if (!dataType.isPassive()) {
2891  emitOpError() << "has non-passive data type on port " << portName
2892  << " (memory types must be passive)";
2893  return failure();
2894  }
2895 
2896  // Error if the data type contains analog types.
2897  if (dataType.containsAnalog()) {
2898  emitOpError() << "has a data type that contains an analog type on port "
2899  << portName
2900  << " (memory types cannot contain analog types)";
2901  return failure();
2902  }
2903 
2904  // Check that the port type matches the kind that we determined
2905  // for this port. This catches situations of extraneous port
2906  // fields beind included or the fields being named incorrectly.
2907  FIRRTLType expectedType =
2908  getTypeForPort(getDepth(), dataType, portKind,
2909  dataType.isGround() ? getMaskBits() : 0);
2910  // Compute the original port type as portBundleType may have
2911  // stripped outer flip information.
2912  auto originalType = getResult(i).getType();
2913  if (originalType != expectedType) {
2914  StringRef portKindName;
2915  switch (portKind) {
2916  case MemOp::PortKind::Read:
2917  portKindName = "read";
2918  break;
2919  case MemOp::PortKind::Write:
2920  portKindName = "write";
2921  break;
2922  case MemOp::PortKind::ReadWrite:
2923  portKindName = "readwrite";
2924  break;
2925  case MemOp::PortKind::Debug:
2926  portKindName = "dbg";
2927  break;
2928  }
2929  emitOpError() << "has an invalid type for port " << portName
2930  << " of determined kind \"" << portKindName
2931  << "\" (expected " << expectedType << ", but got "
2932  << originalType << ")";
2933  return failure();
2934  }
2935 
2936  // Error if the type of the current port was not the same as the
2937  // last port, but skip checking the first port.
2938  if (oldDataType && oldDataType != dataType) {
2939  emitOpError() << "port " << getPortName(i)
2940  << " has a different type than port " << getPortName(i - 1)
2941  << " (expected " << oldDataType << ", but got " << dataType
2942  << ")";
2943  return failure();
2944  }
2945 
2946  oldDataType = dataType;
2947  }
2948 
2949  auto maskWidth = getMaskBits();
2950 
2951  auto dataWidth = getDataType().getBitWidthOrSentinel();
2952  if (dataWidth > 0 && maskWidth > (size_t)dataWidth)
2953  return emitOpError("the mask width cannot be greater than "
2954  "data width");
2955 
2956  if (getPortAnnotations().size() != getNumResults())
2957  return emitOpError("the number of result annotations should be "
2958  "equal to the number of results");
2959 
2960  return success();
2961 }
2962 
2963 static size_t getAddressWidth(size_t depth) {
2964  return std::max(1U, llvm::Log2_64_Ceil(depth));
2965 }
2966 
2967 size_t MemOp::getAddrBits() { return getAddressWidth(getDepth()); }
2968 
2969 FIRRTLType MemOp::getTypeForPort(uint64_t depth, FIRRTLBaseType dataType,
2970  PortKind portKind, size_t maskBits) {
2971 
2972  auto *context = dataType.getContext();
2973  if (portKind == PortKind::Debug)
2974  return RefType::get(FVectorType::get(dataType, depth));
2975  FIRRTLBaseType maskType;
2976  // maskBits not specified (==0), then get the mask type from the dataType.
2977  if (maskBits == 0)
2978  maskType = dataType.getMaskType();
2979  else
2980  maskType = UIntType::get(context, maskBits);
2981 
2982  auto getId = [&](StringRef name) -> StringAttr {
2983  return StringAttr::get(context, name);
2984  };
2985 
2986  SmallVector<BundleType::BundleElement, 7> portFields;
2987 
2988  auto addressType = UIntType::get(context, getAddressWidth(depth));
2989 
2990  portFields.push_back({getId("addr"), false, addressType});
2991  portFields.push_back({getId("en"), false, UIntType::get(context, 1)});
2992  portFields.push_back({getId("clk"), false, ClockType::get(context)});
2993 
2994  switch (portKind) {
2995  case PortKind::Read:
2996  portFields.push_back({getId("data"), true, dataType});
2997  break;
2998 
2999  case PortKind::Write:
3000  portFields.push_back({getId("data"), false, dataType});
3001  portFields.push_back({getId("mask"), false, maskType});
3002  break;
3003 
3004  case PortKind::ReadWrite:
3005  portFields.push_back({getId("rdata"), true, dataType});
3006  portFields.push_back({getId("wmode"), false, UIntType::get(context, 1)});
3007  portFields.push_back({getId("wdata"), false, dataType});
3008  portFields.push_back({getId("wmask"), false, maskType});
3009  break;
3010  default:
3011  llvm::report_fatal_error("memory port kind not handled");
3012  break;
3013  }
3014 
3015  return BundleType::get(context, portFields);
3016 }
3017 
3018 /// Return the name and kind of ports supported by this memory.
3019 SmallVector<MemOp::NamedPort> MemOp::getPorts() {
3020  SmallVector<MemOp::NamedPort> result;
3021  // Each entry in the bundle is a port.
3022  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3023  // Each port is a bundle.
3024  auto portType = type_cast<FIRRTLType>(getResult(i).getType());
3025  result.push_back({getPortName(i), getMemPortKindFromType(portType)});
3026  }
3027  return result;
3028 }
3029 
3030 /// Return the kind of the specified port.
3031 MemOp::PortKind MemOp::getPortKind(StringRef portName) {
3032  return getMemPortKindFromType(
3033  type_cast<FIRRTLType>(getPortNamed(portName).getType()));
3034 }
3035 
3036 /// Return the kind of the specified port number.
3037 MemOp::PortKind MemOp::getPortKind(size_t resultNo) {
3038  return getMemPortKindFromType(
3039  type_cast<FIRRTLType>(getResult(resultNo).getType()));
3040 }
3041 
3042 /// Return the number of bits in the mask for the memory.
3043 size_t MemOp::getMaskBits() {
3044 
3045  for (auto res : getResults()) {
3046  if (type_isa<RefType>(res.getType()))
3047  continue;
3048  auto firstPortType = type_cast<FIRRTLBaseType>(res.getType());
3049  if (getMemPortKindFromType(firstPortType) == PortKind::Read ||
3050  getMemPortKindFromType(firstPortType) == PortKind::Debug)
3051  continue;
3052 
3053  FIRRTLBaseType mType;
3054  for (auto t : type_cast<BundleType>(firstPortType.getPassiveType())) {
3055  if (t.name.getValue().contains("mask"))
3056  mType = t.type;
3057  }
3058  if (type_isa<UIntType>(mType))
3059  return mType.getBitWidthOrSentinel();
3060  }
3061  // Mask of zero bits means, either there are no write/readwrite ports or the
3062  // mask is of aggregate type.
3063  return 0;
3064 }
3065 
3066 /// Return the data-type field of the memory, the type of each element.
3067 FIRRTLBaseType MemOp::getDataType() {
3068  assert(getNumResults() != 0 && "Mems with no read/write ports are illegal");
3069 
3070  if (auto refType = type_dyn_cast<RefType>(getResult(0).getType()))
3071  return type_cast<FVectorType>(refType.getType()).getElementType();
3072  auto firstPortType = type_cast<FIRRTLBaseType>(getResult(0).getType());
3073 
3074  StringRef dataFieldName = "data";
3075  if (getMemPortKindFromType(firstPortType) == PortKind::ReadWrite)
3076  dataFieldName = "rdata";
3077 
3078  return type_cast<BundleType>(firstPortType.getPassiveType())
3079  .getElementType(dataFieldName);
3080 }
3081 
3082 StringAttr MemOp::getPortName(size_t resultNo) {
3083  return cast<StringAttr>(getPortNames()[resultNo]);
3084 }
3085 
3086 FIRRTLBaseType MemOp::getPortType(size_t resultNo) {
3087  return type_cast<FIRRTLBaseType>(getResults()[resultNo].getType());
3088 }
3089 
3090 Value MemOp::getPortNamed(StringAttr name) {
3091  auto namesArray = getPortNames();
3092  for (size_t i = 0, e = namesArray.size(); i != e; ++i) {
3093  if (namesArray[i] == name) {
3094  assert(i < getNumResults() && " names array out of sync with results");
3095  return getResult(i);
3096  }
3097  }
3098  return Value();
3099 }
3100 
// Extract all the relevant attributes from the MemOp and return the FirMemory.
//
// Ports are counted by kind, write ports are grouped by the driver of their
// clock (see writeClockIDs), and a module name is taken from the "modName"
// attribute or synthesized from the configuration.
// NOTE(review): the function header and the declaration of `clockToLeader`
// are not visible in this listing; below it is used as a map keyed by a
// clock's module-scoped driver, holding the index of the first write port
// seen with that driver.
  auto op = *this;
  size_t numReadPorts = 0;
  size_t numWritePorts = 0;
  size_t numReadWritePorts = 0;
  SmallVector<int32_t> writeClockIDs;

  for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
    auto portKind = op.getPortKind(i);
    if (portKind == MemOp::PortKind::Read)
      ++numReadPorts;
    else if (portKind == MemOp::PortKind::Write) {
      // Find the subfield access of field index 2 of this port ("clk" in the
      // bundle built by getTypeForPort) and the connect(s) that drive it.
      for (auto *a : op.getResult(i).getUsers()) {
        auto subfield = dyn_cast<SubfieldOp>(a);
        if (!subfield || subfield.getFieldIndex() != 2)
          continue;
        auto clockPort = a->getResult(0);
        for (auto *b : clockPort.getUsers()) {
          if (auto connect = dyn_cast<FConnectLike>(b)) {
            if (connect.getDest() == clockPort) {
              // Write ports whose clocks trace back to the same driver share
              // an ID: the index of the first write port seen with it.
              auto result =
                  clockToLeader.insert({circt::firrtl::getModuleScopedDriver(
                                            connect.getSrc(), true, true, true),
                                        numWritePorts});
              if (result.second) {
                writeClockIDs.push_back(numWritePorts);
              } else {
                writeClockIDs.push_back(result.first->second);
              }
            }
          }
        }
        // Only the first clock subfield access of the port is considered.
        break;
      }
      // NOTE(review): an entry is pushed per matching connect, so a write
      // port whose clock is never connected contributes no entry (and one
      // connected twice contributes two); writeClockIDs may therefore not
      // line up with numWritePorts — confirm FirMemory consumers expect this.
      ++numWritePorts;
    } else
      ++numReadWritePorts;
  }

  // Total bit width of the data type. An unknown width is reported as an
  // error, and `width` stays 0.
  size_t width = 0;
  if (auto widthV = getBitWidth(op.getDataType()))
    width = *widthV;
  else
    op.emitError("'firrtl.mem' should have simple type and known width");
  MemoryInitAttr init = op->getAttrOfType<MemoryInitAttr>("init");
  StringAttr modName;
  // Prefer an explicit "modName"; otherwise synthesize one that encodes the
  // memory configuration.
  if (op->hasAttr("modName"))
    modName = op->getAttrOfType<StringAttr>("modName");
  else {
    // Encode each write-clock ID as a letter: 0 -> 'a', 1 -> 'b', ...
    // NOTE(review): IDs of 26 or more map past 'z' to non-letter characters.
    SmallString<8> clocks;
    for (auto a : writeClockIDs)
      clocks.append(Twine((char)(a + 'a')).str());
    SmallString<32> initStr;
    // If there is a file initialization, then come up with a decent
    // representation for this. Use the filename, but only characters
    // [a-zA-Z0-9] and the bool/hex and inline booleans.
    if (init) {
      for (auto c : init.getFilename().getValue())
        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
            (c >= '0' && c <= '9'))
          initStr.push_back(c);
      initStr.push_back('_');
      initStr.push_back(init.getIsBinary() ? 't' : 'f');
      initStr.push_back('_');
      initStr.push_back(init.getIsInline() ? 't' : 'f');
    }
    // Name layout: <prefix>FIRRTLMem_<read>_<write>_<readwrite>_<width>_
    // <depth>_<readLatency>_<writeLatency>_<maskBits>_<ruw>_<wuw>
    // [_<clocks>]<initStr>.
    modName = StringAttr::get(
        op->getContext(),
        llvm::formatv(
            "{0}FIRRTLMem_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}{11}{12}",
            op.getPrefix().value_or(""), numReadPorts, numWritePorts,
            numReadWritePorts, (size_t)width, op.getDepth(),
            op.getReadLatency(), op.getWriteLatency(), op.getMaskBits(),
            (unsigned)op.getRuw(), (unsigned)seq::WUW::PortOrder,
            clocks.empty() ? "" : "_" + clocks, init ? initStr.str() : ""));
  }
  return {numReadPorts,
          numWritePorts,
          numReadWritePorts,
          (size_t)width,
          op.getDepth(),
          op.getReadLatency(),
          op.getWriteLatency(),
          op.getMaskBits(),
          *seq::symbolizeRUW(unsigned(op.getRuw())),
          seq::WUW::PortOrder,
          writeClockIDs,
          modName,
          op.getMaskBits() > 1,
          init,
          op.getPrefixAttr(),
          op.getLoc()};
}
3196 
  // Name each port result "<memory name>_<port name>", using "mem" as the
  // base when the memory has an empty name.
  StringRef base = getName();
  if (base.empty())
    base = "mem";

  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
    setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
  }
}
3206 
3207 std::optional<size_t> MemOp::getTargetResultIndex() {
3208  // Inner symbols on memory operations target the op not any result.
3209  return std::nullopt;
3210 }
3211 
3212 // Construct name of the module which will be used for the memory definition.
3213 StringAttr FirMemory::getFirMemoryName() const { return modName; }
3214 
3215 /// Helper for naming forceable declarations (and their optional ref result).
3216 static void forceableAsmResultNames(Forceable op, StringRef name,
3217  OpAsmSetValueNameFn setNameFn) {
3218  if (name.empty())
3219  return;
3220  setNameFn(op.getDataRaw(), name);
3221  if (op.isForceable())
3222  setNameFn(op.getDataRef(), (name + "_ref").str());
3223 }
3224 
  // Name the data result (and the ref result, if forceable) after the op.
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3228 
3229 LogicalResult NodeOp::inferReturnTypes(
3230  mlir::MLIRContext *context, std::optional<mlir::Location> location,
3231  ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
3232  ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
3233  ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
3234  if (operands.empty())
3235  return failure();
3236  inferredReturnTypes.push_back(operands[0].getType());
3237  for (auto &attr : attributes)
3238  if (attr.getName() == Forceable::getForceableAttrName()) {
3239  auto forceableType =
3240  firrtl::detail::getForceableResultType(true, operands[0].getType());
3241  if (!forceableType) {
3242  if (location)
3243  ::mlir::emitError(*location, "cannot force a node of type ")
3244  << operands[0].getType();
3245  return failure();
3246  }
3247  inferredReturnTypes.push_back(forceableType);
3248  }
3249  return success();
3250 }
3251 
3252 std::optional<size_t> NodeOp::getTargetResultIndex() { return 0; }
3253 
  // Name the data result (and the ref result, if forceable) after the op.
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3257 
3258 std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
3259 
3260 LogicalResult RegResetOp::verify() {
3261  auto reset = getResetValue();
3262 
3263  FIRRTLBaseType resetType = reset.getType();
3264  FIRRTLBaseType regType = getResult().getType();
3265 
3266  // The type of the initialiser must be equivalent to the register type.
3267  if (!areTypesEquivalent(regType, resetType))
3268  return emitError("type mismatch between register ")
3269  << regType << " and reset value " << resetType;
3270 
3271  return success();
3272 }
3273 
3274 std::optional<size_t> RegResetOp::getTargetResultIndex() { return 0; }
3275 
  // Name the data result (and the ref result, if forceable) after the op.
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3279 
  // Name the data result (and the ref result, if forceable) after the op.
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3283 
3284 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
3285 
3286 LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3287  auto refType = type_dyn_cast<RefType>(getType(0));
3288  if (!refType)
3289  return success();
3290 
3291  return verifyProbeType(
3292  refType, getLoc(), getOperation()->getParentOfType<CircuitOp>(),
3293  symbolTable, Twine("'") + getOperationName() + "' op is");
3294 }
3295 
3296 void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass,
3297  StringRef name) {
3298  build(builder, state, klass.getInstanceType(),
3299  StringAttr::get(builder.getContext(), name));
3300 }
3301 
3302 LogicalResult ObjectOp::verify() { return success(); }
3303 
3304 LogicalResult ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3305  auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
3306  auto classType = getType();
3307  auto className = classType.getNameAttr();
3308 
3309  // verify that the class exists.
3310  auto classOp = dyn_cast_or_null<ClassLike>(
3311  symbolTable.lookupSymbolIn(circuitOp, className));
3312  if (!classOp)
3313  return emitOpError() << "references unknown class " << className;
3314 
3315  // verify that the result type agrees with the class definition.
3316  if (failed(classOp.verifyType(classType, [&]() { return emitOpError(); })))
3317  return failure();
3318 
3319  return success();
3320 }
3321 
3322 StringAttr ObjectOp::getClassNameAttr() {
3323  return getType().getNameAttr().getAttr();
3324 }
3325 
3326 StringRef ObjectOp::getClassName() { return getType().getName(); }
3327 
3328 ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
3329  auto symRef = getType().getNameAttr();
3330  return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
3331 }
3332 
3333 Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
3334  return getReferencedClass(symtbl);
3335 }
3336 
3337 StringRef ObjectOp::getInstanceName() { return getName(); }
3338 
3339 StringAttr ObjectOp::getInstanceNameAttr() { return getNameAttr(); }
3340 
3341 StringRef ObjectOp::getReferencedModuleName() { return getClassName(); }
3342 
3343 StringAttr ObjectOp::getReferencedModuleNameAttr() {
3344  return getClassNameAttr();
3345 }
3346 
  // Name the single result after the op's `name`.
  setNameFn(getResult(), getName());
}
3350 
3351 //===----------------------------------------------------------------------===//
3352 // Statements
3353 //===----------------------------------------------------------------------===//
3354 
3355 LogicalResult AttachOp::verify() {
3356  // All known widths must match.
3357  std::optional<int32_t> commonWidth;
3358  for (auto operand : getOperands()) {
3359  auto thisWidth = type_cast<AnalogType>(operand.getType()).getWidth();
3360  if (!thisWidth)
3361  continue;
3362  if (!commonWidth) {
3363  commonWidth = thisWidth;
3364  continue;
3365  }
3366  if (commonWidth != thisWidth)
3367  return emitOpError("is inavlid as not all known operand widths match");
3368  }
3369  return success();
3370 }
3371 
3372 /// Check if the source and sink are of appropriate flow.
3373 static LogicalResult checkConnectFlow(Operation *connect) {
3374  Value dst = connect->getOperand(0);
3375  Value src = connect->getOperand(1);
3376 
3377  // TODO: Relax this to allow reads from output ports,
3378  // instance/memory input ports.
3379  auto srcFlow = foldFlow(src);
3380  if (!isValidSrc(srcFlow)) {
3381  // A sink that is a port output or instance input used as a source is okay.
3382  auto kind = getDeclarationKind(src);
3383  if (kind != DeclKind::Port && kind != DeclKind::Instance) {
3384  auto srcRef = getFieldRefFromValue(src, /*lookThroughCasts=*/true);
3385  auto [srcName, rootKnown] = getFieldName(srcRef);
3386  auto diag = emitError(connect->getLoc());
3387  diag << "connect has invalid flow: the source expression ";
3388  if (rootKnown)
3389  diag << "\"" << srcName << "\" ";
3390  diag << "has " << toString(srcFlow) << ", expected source or duplex flow";
3391  return diag.attachNote(srcRef.getLoc()) << "the source was defined here";
3392  }
3393  }
3394 
3395  auto dstFlow = foldFlow(dst);
3396  if (!isValidDst(dstFlow)) {
3397  auto dstRef = getFieldRefFromValue(dst, /*lookThroughCasts=*/true);
3398  auto [dstName, rootKnown] = getFieldName(dstRef);
3399  auto diag = emitError(connect->getLoc());
3400  diag << "connect has invalid flow: the destination expression ";
3401  if (rootKnown)
3402  diag << "\"" << dstName << "\" ";
3403  diag << "has " << toString(dstFlow) << ", expected sink or duplex flow";
3404  return diag.attachNote(dstRef.getLoc())
3405  << "the destination was defined here";
3406  }
3407  return success();
3408 }
3409 
3410 // NOLINTBEGIN(misc-no-recursion)
3411 /// Checks if the type has any 'const' leaf elements . If `isFlip` is `true`,
3412 /// the `const` leaf is not considered to be driven.
3413 static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip = false,
3414  bool outerTypeIsConst = false) {
3415  auto typeIsConst = outerTypeIsConst || type.isConst();
3416 
3417  if (typeIsConst && type.isPassive())
3418  return !isFlip;
3419 
3420  if (auto bundleType = type_dyn_cast<BundleType>(type))
3421  return llvm::any_of(bundleType.getElements(), [&](auto &element) {
3422  return isConstFieldDriven(element.type, isFlip ^ element.isFlip,
3423  typeIsConst);
3424  });
3425 
3426  if (auto vectorType = type_dyn_cast<FVectorType>(type))
3427  return isConstFieldDriven(vectorType.getElementType(), isFlip, typeIsConst);
3428 
3429  if (typeIsConst)
3430  return !isFlip;
3431  return false;
3432 }
3433 // NOLINTEND(misc-no-recursion)
3434 
3435 /// Checks that connections to 'const' destinations are not dependent on
3436 /// non-'const' conditions in when blocks.
3437 static LogicalResult checkConnectConditionality(FConnectLike connect) {
3438  auto dest = connect.getDest();
3439  auto destType = type_dyn_cast<FIRRTLBaseType>(dest.getType());
3440  auto src = connect.getSrc();
3441  auto srcType = type_dyn_cast<FIRRTLBaseType>(src.getType());
3442  if (!destType || !srcType)
3443  return success();
3444 
3445  auto destRefinedType = destType;
3446  auto srcRefinedType = srcType;
3447 
3448  /// Looks up the value's defining op until the defining op is null or a
3449  /// declaration of the value. If a SubAccessOp is encountered with a 'const'
3450  /// input, `originalFieldType` is made 'const'.
3451  auto findFieldDeclarationRefiningFieldType =
3452  [](Value value, FIRRTLBaseType &originalFieldType) -> Value {
3453  while (auto *definingOp = value.getDefiningOp()) {
3454  bool shouldContinue = true;
3455  TypeSwitch<Operation *>(definingOp)
3456  .Case<SubfieldOp, SubindexOp>([&](auto op) { value = op.getInput(); })
3457  .Case<SubaccessOp>([&](SubaccessOp op) {
3458  if (op.getInput()
3459  .getType()
3460  .base()
3461  .getElementTypePreservingConst()
3462  .isConst())
3463  originalFieldType = originalFieldType.getConstType(true);
3464  value = op.getInput();
3465  })
3466  .Default([&](Operation *) { shouldContinue = false; });
3467  if (!shouldContinue)
3468  break;
3469  }
3470  return value;
3471  };
3472 
3473  auto destDeclaration =
3474  findFieldDeclarationRefiningFieldType(dest, destRefinedType);
3475  auto srcDeclaration =
3476  findFieldDeclarationRefiningFieldType(src, srcRefinedType);
3477 
3478  auto checkConstConditionality = [&](Value value, FIRRTLBaseType type,
3479  Value declaration) -> LogicalResult {
3480  auto *declarationBlock = declaration.getParentBlock();
3481  auto *block = connect->getBlock();
3482  while (block && block != declarationBlock) {
3483  auto *parentOp = block->getParentOp();
3484 
3485  if (auto whenOp = dyn_cast<WhenOp>(parentOp);
3486  whenOp && !whenOp.getCondition().getType().isConst()) {
3487  if (type.isConst())
3488  return connect.emitOpError()
3489  << "assignment to 'const' type " << type
3490  << " is dependent on a non-'const' condition";
3491  return connect->emitOpError()
3492  << "assignment to nested 'const' member of type " << type
3493  << " is dependent on a non-'const' condition";
3494  }
3495 
3496  block = parentOp->getBlock();
3497  }
3498  return success();
3499  };
3500 
3501  auto emitSubaccessError = [&] {
3502  return connect.emitError(
3503  "assignment to non-'const' subaccess of 'const' type is disallowed");
3504  };
3505 
3506  // Check destination if it contains 'const' leaves
3507  if (destRefinedType.containsConst() && isConstFieldDriven(destRefinedType)) {
3508  // Disallow assignment to non-'const' subaccesses of 'const' types
3509  if (destType != destRefinedType)
3510  return emitSubaccessError();
3511 
3512  if (failed(checkConstConditionality(dest, destType, destDeclaration)))
3513  return failure();
3514  }
3515 
3516  // Check source if it contains 'const' 'flip' leaves
3517  if (srcRefinedType.containsConst() &&
3518  isConstFieldDriven(srcRefinedType, /*isFlip=*/true)) {
3519  // Disallow assignment to non-'const' subaccesses of 'const' types
3520  if (srcType != srcRefinedType)
3521  return emitSubaccessError();
3522  if (failed(checkConstConditionality(src, srcType, srcDeclaration)))
3523  return failure();
3524  }
3525 
3526  return success();
3527 }
3528 
3529 LogicalResult ConnectOp::verify() {
3530  auto dstType = getDest().getType();
3531  auto srcType = getSrc().getType();
3532  auto dstBaseType = type_dyn_cast<FIRRTLBaseType>(dstType);
3533  auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(srcType);
3534  if (!dstBaseType || !srcBaseType) {
3535  if (dstType != srcType)
3536  return emitError("may not connect different non-base types");
3537  } else {
3538  // Analog types cannot be connected and must be attached.
3539  if (dstBaseType.containsAnalog() || srcBaseType.containsAnalog())
3540  return emitError("analog types may not be connected");
3541 
3542  // Destination and source types must be equivalent.
3543  if (!areTypesEquivalent(dstBaseType, srcBaseType))
3544  return emitError("type mismatch between destination ")
3545  << dstBaseType << " and source " << srcBaseType;
3546 
3547  // Truncation is banned in a connection: destination bit width must be
3548  // greater than or equal to source bit width.
3549  if (!isTypeLarger(dstBaseType, srcBaseType))
3550  return emitError("destination ")
3551  << dstBaseType << " is not as wide as the source " << srcBaseType;
3552  }
3553 
3554  // Check that the flows make sense.
3555  if (failed(checkConnectFlow(*this)))
3556  return failure();
3557 
3558  if (failed(checkConnectConditionality(*this)))
3559  return failure();
3560 
3561  return success();
3562 }
3563 
3564 LogicalResult StrictConnectOp::verify() {
3565  if (auto type = type_dyn_cast<FIRRTLType>(getDest().getType())) {
3566  auto baseType = type_cast<FIRRTLBaseType>(type);
3567 
3568  // Analog types cannot be connected and must be attached.
3569  if (baseType && baseType.containsAnalog())
3570  return emitError("analog types may not be connected");
3571 
3572  // The anonymous types of operands must be equivalent.
3573  assert(areAnonymousTypesEquivalent(cast<FIRRTLBaseType>(getSrc().getType()),
3574  baseType) &&
3575  "`SameAnonTypeOperands` trait should have already rejected "
3576  "structurally non-equivalent types");
3577  }
3578 
3579  // Check that the flows make sense.
3580  if (failed(checkConnectFlow(*this)))
3581  return failure();
3582 
3583  if (failed(checkConnectConditionality(*this)))
3584  return failure();
3585 
3586  return success();
3587 }
3588 
3589 LogicalResult RefDefineOp::verify() {
3590  // Check that the flows make sense.
3591  if (failed(checkConnectFlow(*this)))
3592  return failure();
3593 
3594  // For now, refs can't be in bundles so this is sufficient.
3595  // In the future need to ensure no other define's to same "fieldSource".
3596  // (When aggregates can have references, we can define a reference within,
3597  // but this must be unique. Checking this here may be expensive,
3598  // consider adding something to FModuleLike's to check it there instead)
3599  for (auto *user : getDest().getUsers()) {
3600  if (auto conn = dyn_cast<FConnectLike>(user);
3601  conn && conn.getDest() == getDest() && conn != *this)
3602  return emitError("destination reference cannot be reused by multiple "
3603  "operations, it can only capture a unique dataflow");
3604  }
3605 
3606  // Check "static" source/dest
3607  if (auto *op = getDest().getDefiningOp()) {
3608  // TODO: Make ref.sub only source flow?
3609  if (isa<RefSubOp>(op))
3610  return emitError(
3611  "destination reference cannot be a sub-element of a reference");
3612  if (isa<RefCastOp>(op)) // Source flow, check anyway for now.
3613  return emitError(
3614  "destination reference cannot be a cast of another reference");
3615  }
3616 
3617  // This define is only enabled when its ambient layers are active. Check
3618  // that whenever the destination's layer requirements are met, that this
3619  // op is enabled.
3620  auto ambientLayers = getAmbientLayersAt(getOperation());
3621  auto dstLayers = getLayersFor(getDest());
3622  SmallVector<SymbolRefAttr> missingLayers;
3623  if (!isLayerSetCompatibleWith(ambientLayers, dstLayers, missingLayers)) {
3624  auto diag = emitOpError("has more layer requirements than destination");
3625  auto &note = diag.attachNote();
3626  note << "additional layers required: ";
3627  interleaveComma(missingLayers, note);
3628  return failure();
3629  }
3630 
3631  return success();
3632 }
3633 
3634 LogicalResult PropAssignOp::verify() {
3635  // Check that the flows make sense.
3636  if (failed(checkConnectFlow(*this)))
3637  return failure();
3638 
3639  // Verify that there is a single value driving the destination.
3640  for (auto *user : getDest().getUsers()) {
3641  if (auto conn = dyn_cast<FConnectLike>(user);
3642  conn && conn.getDest() == getDest() && conn != *this)
3643  return emitError("destination property cannot be reused by multiple "
3644  "operations, it can only capture a unique dataflow");
3645  }
3646 
3647  return success();
3648 }
3649 
3650 void WhenOp::createElseRegion() {
3651  assert(!hasElseRegion() && "already has an else region");
3652  getElseRegion().push_back(new Block());
3653 }
3654 
3655 void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
3656  bool withElseRegion, std::function<void()> thenCtor,
3657  std::function<void()> elseCtor) {
3658  OpBuilder::InsertionGuard guard(builder);
3659  result.addOperands(condition);
3660 
3661  // Create "then" region.
3662  builder.createBlock(result.addRegion());
3663  if (thenCtor)
3664  thenCtor();
3665 
3666  // Create "else" region.
3667  Region *elseRegion = result.addRegion();
3668  if (withElseRegion) {
3669  builder.createBlock(elseRegion);
3670  if (elseCtor)
3671  elseCtor();
3672  }
3673 }
3674 
3675 //===----------------------------------------------------------------------===//
3676 // MatchOp
3677 //===----------------------------------------------------------------------===//
3678 
3679 LogicalResult MatchOp::verify() {
3680  FEnumType type = getInput().getType();
3681 
3682  // Make sure that the number of tags matches the number of regions.
3683  auto numCases = getTags().size();
3684  auto numRegions = getNumRegions();
3685  if (numRegions != numCases)
3686  return emitOpError("expected ")
3687  << numRegions << " tags but got " << numCases;
3688 
3689  auto numTags = type.getNumElements();
3690 
3691  SmallDenseSet<int64_t> seen;
3692  for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
3693  auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());
3694 
3695  // Ensure that the block has a single argument.
3696  if (region.getNumArguments() != 1)
3697  return emitOpError("region should have exactly one argument");
3698 
3699  // Make sure that it is a valid tag.
3700  if (tagIndex >= numTags)
3701  return emitOpError("the tag index ")
3702  << tagIndex << " is out of the range of valid tags in " << type;
3703 
3704  // Make sure we have not already matched this tag.
3705  auto [it, inserted] = seen.insert(tagIndex);
3706  if (!inserted)
3707  return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
3708  << " is matched more than once";
3709 
3710  // Check that the block argument type matches the tag's type.
3711  auto expectedType = type.getElementTypePreservingConst(tagIndex);
3712  auto regionType = region.getArgument(0).getType();
3713  if (regionType != expectedType)
3714  return emitOpError("region type ")
3715  << regionType << " does not match the expected type "
3716  << expectedType;
3717  }
3718 
3719  // Check that the match statement is exhaustive.
3720  for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
3721  if (!seen.contains(i))
3722  return emitOpError("missing case for tag ") << type.getElementNameAttr(i);
3723 
3724  return success();
3725 }
3726 
3727 void MatchOp::print(OpAsmPrinter &p) {
3728  auto input = getInput();
3729  FEnumType type = input.getType();
3730  auto regions = getRegions();
3731  p << " " << input << " : " << type;
3732  SmallVector<StringRef> elided = {"tags"};
3733  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
3734  p << " {";
3735  p.increaseIndent();
3736  for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
3737  p.printNewline();
3738  p << "case ";
3739  p.printKeywordOrString(
3740  type.getElementName(cast<IntegerAttr>(tag).getInt()));
3741  p << "(";
3742  p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
3743  /*omitType=*/true);
3744  p << ") ";
3745  p.printRegion(region, /*printEntryBlockArgs=*/false);
3746  }
3747  p.decreaseIndent();
3748  p.printNewline();
3749  p << "}";
3750 }
3751 
3752 ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
3753  auto *context = parser.getContext();
3754  OpAsmParser::UnresolvedOperand input;
3755  if (parser.parseOperand(input) || parser.parseColon())
3756  return failure();
3757 
3758  auto loc = parser.getCurrentLocation();
3759  Type type;
3760  if (parser.parseType(type))
3761  return failure();
3762  auto enumType = type_dyn_cast<FEnumType>(type);
3763  if (!enumType)
3764  return parser.emitError(loc, "expected enumeration type but got") << type;
3765 
3766  if (parser.resolveOperand(input, type, result.operands) ||
3767  parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3768  parser.parseLBrace())
3769  return failure();
3770 
3771  auto i32Type = IntegerType::get(context, 32);
3772  SmallVector<Attribute> tags;
3773  while (true) {
3774  // Stop parsing when we don't find another "case" keyword.
3775  if (failed(parser.parseOptionalKeyword("case")))
3776  break;
3777 
3778  // Parse the tag and region argument.
3779  auto nameLoc = parser.getCurrentLocation();
3780  std::string name;
3781  OpAsmParser::Argument arg;
3782  auto *region = result.addRegion();
3783  if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
3784  parser.parseArgument(arg) || parser.parseRParen())
3785  return failure();
3786 
3787  // Figure out the enum index of the tag.
3788  auto index = enumType.getElementIndex(name);
3789  if (!index)
3790  return parser.emitError(nameLoc, "the tag \"")
3791  << name << "\" is not a member of the enumeration " << enumType;
3792  tags.push_back(IntegerAttr::get(i32Type, *index));
3793 
3794  // Parse the region.
3795  arg.type = enumType.getElementTypePreservingConst(*index);
3796  if (parser.parseRegion(*region, arg))
3797  return failure();
3798  }
3799  result.addAttribute("tags", ArrayAttr::get(context, tags));
3800 
3801  return parser.parseRBrace();
3802 }
3803 
3804 void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
3805  ArrayAttr tags,
3806  MutableArrayRef<std::unique_ptr<Region>> regions) {
3807  result.addOperands(input);
3808  result.addAttribute("tags", tags);
3809  result.addRegions(regions);
3810 }
3811 
3812 //===----------------------------------------------------------------------===//
3813 // Expressions
3814 //===----------------------------------------------------------------------===//
3815 
3816 /// Type inference adaptor that narrows from the very generic MLIR
3817 /// `InferTypeOpInterface` to what we need in the FIRRTL dialect: just operands
3818 /// and attributes, no context or regions. Also, we only ever produce a single
3819 /// result value, so the FIRRTL-specific type inference ops directly return the
3820 /// inferred type rather than pushing into the `results` vector.
3821 LogicalResult impl::inferReturnTypes(
3822  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
3823  DictionaryAttr attrs, mlir::OpaqueProperties properties,
3824  RegionRange regions, SmallVectorImpl<Type> &results,
3825  llvm::function_ref<FIRRTLType(ValueRange, ArrayRef<NamedAttribute>,
3826  std::optional<Location>)>
3827  callback) {
3828  auto type = callback(
3829  operands, attrs ? attrs.getValue() : ArrayRef<NamedAttribute>{}, loc);
3830  if (type) {
3831  results.push_back(type);
3832  return success();
3833  }
3834  return failure();
3835 }
3836 
3837 /// Get an attribute by name from a list of named attributes. Return null if no
3838 /// attribute is found with that name.
3839 static Attribute maybeGetAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3840  for (auto attr : attrs)
3841  if (attr.getName() == name)
3842  return attr.getValue();
3843  return {};
3844 }
3845 
3846 /// Get an attribute by name from a list of named attributes. Aborts if the
3847 /// attribute does not exist.
3848 static Attribute getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3849  if (auto attr = maybeGetAttr(attrs, name))
3850  return attr;
3851  llvm::report_fatal_error("attribute '" + name + "' not found");
3852 }
3853 
3854 /// Same as above, but casts the attribute to a specific type.
3855 template <typename AttrClass>
3856 AttrClass getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3857  return cast<AttrClass>(getAttr(attrs, name));
3858 }
3859 
3860 /// Return true if the specified operation is a firrtl expression.
3861 bool firrtl::isExpression(Operation *op) {
3862  struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
3863  bool visitInvalidExpr(Operation *op) { return false; }
3864  bool visitUnhandledExpr(Operation *op) { return true; }
3865  };
3866 
3867  return IsExprClassifier().dispatchExprVisitor(op);
3868 }
3869 
3871  // Set invalid values to have a distinct name.
3872  std::string name;
3873  if (auto ty = type_dyn_cast<IntType>(getType())) {
3874  const char *base = ty.isSigned() ? "invalid_si" : "invalid_ui";
3875  auto width = ty.getWidthOrSentinel();
3876  if (width == -1)
3877  name = base;
3878  else
3879  name = (Twine(base) + Twine(width)).str();
3880  } else if (auto ty = type_dyn_cast<AnalogType>(getType())) {
3881  auto width = ty.getWidthOrSentinel();
3882  if (width == -1)
3883  name = "invalid_analog";
3884  else
3885  name = ("invalid_analog" + Twine(width)).str();
3886  } else if (type_isa<AsyncResetType>(getType()))
3887  name = "invalid_asyncreset";
3888  else if (type_isa<ResetType>(getType()))
3889  name = "invalid_reset";
3890  else if (type_isa<ClockType>(getType()))
3891  name = "invalid_clock";
3892  else
3893  name = "invalid";
3894 
3895  setNameFn(getResult(), name);
3896 }
3897 
3898 void ConstantOp::print(OpAsmPrinter &p) {
3899  p << " ";
3900  p.printAttributeWithoutType(getValueAttr());
3901  p << " : ";
3902  p.printType(getType());
3903  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
3904 }
3905 
3906 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
3907  // Parse the constant value, without knowing its width.
3908  APInt value;
3909  auto loc = parser.getCurrentLocation();
3910  auto valueResult = parser.parseOptionalInteger(value);
3911  if (!valueResult.has_value())
3912  return parser.emitError(loc, "expected integer value");
3913 
3914  // Parse the result firrtl integer type.
3915  IntType resultType;
3916  if (failed(*valueResult) || parser.parseColonType(resultType) ||
3917  parser.parseOptionalAttrDict(result.attributes))
3918  return failure();
3919  result.addTypes(resultType);
3920 
3921  // Now that we know the width and sign of the result type, we can munge the
3922  // APInt as appropriate.
3923  if (resultType.hasWidth()) {
3924  auto width = (unsigned)resultType.getWidthOrSentinel();
3925  if (width > value.getBitWidth()) {
3926  // sext is always safe here, even for unsigned values, because the
3927  // parseOptionalInteger method will return something with a zero in the
3928  // top bits if it is a positive number.
3929  value = value.sext(width);
3930  } else if (width < value.getBitWidth()) {
3931  // The parser can return an unnecessarily wide result with leading
3932  // zeros. This isn't a problem, but truncating off bits is bad.
3933  if (value.getNumSignBits() < value.getBitWidth() - width)
3934  return parser.emitError(loc, "constant too large for result type ")
3935  << resultType;
3936  value = value.trunc(width);
3937  }
3938  }
3939 
3940  auto intType = parser.getBuilder().getIntegerType(value.getBitWidth(),
3941  resultType.isSigned());
3942  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
3943  result.addAttribute("value", valueAttr);
3944  return success();
3945 }
3946 
3947 LogicalResult ConstantOp::verify() {
3948  // If the result type has a bitwidth, then the attribute must match its width.
3949  IntType intType = getType();
3950  auto width = intType.getWidthOrSentinel();
3951  if (width != -1 && (int)getValue().getBitWidth() != width)
3952  return emitError(
3953  "firrtl.constant attribute bitwidth doesn't match return type");
3954 
3955  // The sign of the attribute's integer type must match our integer type sign.
3956  auto attrType = type_cast<IntegerType>(getValueAttr().getType());
3957  if (attrType.isSignless() || attrType.isSigned() != intType.isSigned())
3958  return emitError("firrtl.constant attribute has wrong sign");
3959 
3960  return success();
3961 }
3962 
3963 /// Build a ConstantOp from an APInt and a FIRRTL type, handling the attribute
3964 /// formation for the 'value' attribute.
3965 void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
3966  const APInt &value) {
3967  int32_t width = type.getWidthOrSentinel();
3968  (void)width;
3969  assert((width == -1 || (int32_t)value.getBitWidth() == width) &&
3970  "incorrect attribute bitwidth for firrtl.constant");
3971 
3972  auto attr =
3973  IntegerAttr::get(type.getContext(), APSInt(value, !type.isSigned()));
3974  return build(builder, result, type, attr);
3975 }
3976 
3977 /// Build a ConstantOp from an APSInt, handling the attribute formation for the
3978 /// 'value' attribute and inferring the FIRRTL type.
3979 void ConstantOp::build(OpBuilder &builder, OperationState &result,
3980  const APSInt &value) {
3981  auto attr = IntegerAttr::get(builder.getContext(), value);
3982  auto type =
3983  IntType::get(builder.getContext(), value.isSigned(), value.getBitWidth());
3984  return build(builder, result, type, attr);
3985 }
3986 
3988  // For constants in particular, propagate the value into the result name to
3989  // make it easier to read the IR.
3990  IntType intTy = getType();
3991  assert(intTy);
3992 
3993  // Otherwise, build a complex name with the value and type.
3994  SmallString<32> specialNameBuffer;
3995  llvm::raw_svector_ostream specialName(specialNameBuffer);
3996  specialName << 'c';
3997  getValue().print(specialName, /*isSigned:*/ intTy.isSigned());
3998 
3999  specialName << (intTy.isSigned() ? "_si" : "_ui");
4000  auto width = intTy.getWidthOrSentinel();
4001  if (width != -1)
4002  specialName << width;
4003  setNameFn(getResult(), specialName.str());
4004 }
4005 
4006 void SpecialConstantOp::print(OpAsmPrinter &p) {
4007  p << " ";
4008  // SpecialConstant uses a BoolAttr, and we want to print `true` as `1`.
4009  p << static_cast<unsigned>(getValue());
4010  p << " : ";
4011  p.printType(getType());
4012  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4013 }
4014 
4015 ParseResult SpecialConstantOp::parse(OpAsmParser &parser,
4016  OperationState &result) {
4017  // Parse the constant value. SpecialConstant uses bool attributes, but it
4018  // prints as an integer.
4019  APInt value;
4020  auto loc = parser.getCurrentLocation();
4021  auto valueResult = parser.parseOptionalInteger(value);
4022  if (!valueResult.has_value())
4023  return parser.emitError(loc, "expected integer value");
4024 
4025  // Clocks and resets can only be 0 or 1.
4026  if (value != 0 && value != 1)
4027  return parser.emitError(loc, "special constants can only be 0 or 1.");
4028 
4029  // Parse the result firrtl type.
4030  Type resultType;
4031  if (failed(*valueResult) || parser.parseColonType(resultType) ||
4032  parser.parseOptionalAttrDict(result.attributes))
4033  return failure();
4034  result.addTypes(resultType);
4035 
4036  // Create the attribute.
4037  auto valueAttr = parser.getBuilder().getBoolAttr(value == 1);
4038  result.addAttribute("value", valueAttr);
4039  return success();
4040 }
4041 
4043  SmallString<32> specialNameBuffer;
4044  llvm::raw_svector_ostream specialName(specialNameBuffer);
4045  specialName << 'c';
4046  specialName << static_cast<unsigned>(getValue());
4047  auto type = getType();
4048  if (type_isa<ClockType>(type)) {
4049  specialName << "_clock";
4050  } else if (type_isa<ResetType>(type)) {
4051  specialName << "_reset";
4052  } else if (type_isa<AsyncResetType>(type)) {
4053  specialName << "_asyncreset";
4054  }
4055  setNameFn(getResult(), specialName.str());
4056 }
4057 
4058 // Checks that an array attr representing an aggregate constant has the correct
4059 // shape. This recurses on the type.
4060 static bool checkAggConstant(Operation *op, Attribute attr,
4061  FIRRTLBaseType type) {
4062  if (type.isGround()) {
4063  if (!isa<IntegerAttr>(attr)) {
4064  op->emitOpError("Ground type is not an integer attribute");
4065  return false;
4066  }
4067  return true;
4068  }
4069  auto attrlist = dyn_cast<ArrayAttr>(attr);
4070  if (!attrlist) {
4071  op->emitOpError("expected array attribute for aggregate constant");
4072  return false;
4073  }
4074  if (auto array = type_dyn_cast<FVectorType>(type)) {
4075  if (array.getNumElements() != attrlist.size()) {
4076  op->emitOpError("array attribute (")
4077  << attrlist.size() << ") has wrong size for vector constant ("
4078  << array.getNumElements() << ")";
4079  return false;
4080  }
4081  return llvm::all_of(attrlist, [&array, op](Attribute attr) {
4082  return checkAggConstant(op, attr, array.getElementType());
4083  });
4084  }
4085  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4086  if (bundle.getNumElements() != attrlist.size()) {
4087  op->emitOpError("array attribute (")
4088  << attrlist.size() << ") has wrong size for bundle constant ("
4089  << bundle.getNumElements() << ")";
4090  return false;
4091  }
4092  for (size_t i = 0; i < bundle.getNumElements(); ++i) {
4093  if (bundle.getElement(i).isFlip) {
4094  op->emitOpError("Cannot have constant bundle type with flip");
4095  return false;
4096  }
4097  if (!checkAggConstant(op, attrlist[i], bundle.getElement(i).type))
4098  return false;
4099  }
4100  return true;
4101  }
4102  op->emitOpError("Unknown aggregate type");
4103  return false;
4104 }
4105 
4106 LogicalResult AggregateConstantOp::verify() {
4107  if (checkAggConstant(getOperation(), getFields(), getType()))
4108  return success();
4109  return failure();
4110 }
4111 
4112 Attribute AggregateConstantOp::getAttributeFromFieldID(uint64_t fieldID) {
4113  FIRRTLBaseType type = getType();
4114  Attribute value = getFields();
4115  while (fieldID != 0) {
4116  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4117  auto index = bundle.getIndexForFieldID(fieldID);
4118  fieldID -= bundle.getFieldID(index);
4119  type = bundle.getElementType(index);
4120  value = cast<ArrayAttr>(value)[index];
4121  } else {
4122  auto vector = type_cast<FVectorType>(type);
4123  auto index = vector.getIndexForFieldID(fieldID);
4124  fieldID -= vector.getFieldID(index);
4125  type = vector.getElementType();
4126  value = cast<ArrayAttr>(value)[index];
4127  }
4128  }
4129  return value;
4130 }
4131 
4132 void FIntegerConstantOp::print(OpAsmPrinter &p) {
4133  p << " ";
4134  p.printAttributeWithoutType(getValueAttr());
4135  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4136 }
4137 
4138 ParseResult FIntegerConstantOp::parse(OpAsmParser &parser,
4139  OperationState &result) {
4140  auto *context = parser.getContext();
4141  APInt value;
4142  if (parser.parseInteger(value) ||
4143  parser.parseOptionalAttrDict(result.attributes))
4144  return failure();
4145  result.addTypes(FIntegerType::get(context));
4146  auto intType =
4147  IntegerType::get(context, value.getBitWidth(), IntegerType::Signed);
4148  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4149  result.addAttribute("value", valueAttr);
4150  return success();
4151 }
4152 
4153 ParseResult ListCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4154  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
4155  ListType type;
4156 
4157  if (parser.parseOperandList(operands) ||
4158  parser.parseOptionalAttrDict(result.attributes) ||
4159  parser.parseColonType(type))
4160  return failure();
4161  result.addTypes(type);
4162 
4163  return parser.resolveOperands(operands, type.getElementType(),
4164  result.operands);
4165 }
4166 
4167 void ListCreateOp::print(OpAsmPrinter &p) {
4168  p << " ";
4169  p.printOperands(getElements());
4170  p.printOptionalAttrDict((*this)->getAttrs());
4171  p << " : " << getType();
4172 }
4173 
4174 LogicalResult ListCreateOp::verify() {
4175  if (getElements().empty())
4176  return success();
4177 
4178  auto elementType = getElements().front().getType();
4179  auto listElementType = getType().getElementType();
4180  if (elementType != listElementType)
4181  return emitOpError("has elements of type ")
4182  << elementType << " instead of " << listElementType;
4183 
4184  return success();
4185 }
4186 
4187 LogicalResult BundleCreateOp::verify() {
4188  BundleType resultType = getType();
4189  if (resultType.getNumElements() != getFields().size())
4190  return emitOpError("number of fields doesn't match type");
4191  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4192  if (!areTypesConstCastable(
4193  resultType.getElementTypePreservingConst(i),
4194  type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4195  return emitOpError("type of element doesn't match bundle for field ")
4196  << resultType.getElement(i).name;
4197  // TODO: check flow
4198  return success();
4199 }
4200 
4201 LogicalResult VectorCreateOp::verify() {
4202  FVectorType resultType = getType();
4203  if (resultType.getNumElements() != getFields().size())
4204  return emitOpError("number of fields doesn't match type");
4205  auto elemTy = resultType.getElementTypePreservingConst();
4206  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4207  if (!areTypesConstCastable(
4208  elemTy, type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4209  return emitOpError("type of element doesn't match vector element");
4210  // TODO: check flow
4211  return success();
4212 }
4213 
4214 //===----------------------------------------------------------------------===//
4215 // FEnumCreateOp
4216 //===----------------------------------------------------------------------===//
4217 
4218 LogicalResult FEnumCreateOp::verify() {
4219  FEnumType resultType = getResult().getType();
4220  auto elementIndex = resultType.getElementIndex(getFieldName());
4221  if (!elementIndex)
4222  return emitOpError("label ")
4223  << getFieldName() << " is not a member of the enumeration type "
4224  << resultType;
4225  if (!areTypesConstCastable(
4226  resultType.getElementTypePreservingConst(*elementIndex),
4227  getInput().getType()))
4228  return emitOpError("type of element doesn't match enum element");
4229  return success();
4230 }
4231 
4232 void FEnumCreateOp::print(OpAsmPrinter &printer) {
4233  printer << ' ';
4234  printer.printKeywordOrString(getFieldName());
4235  printer << '(' << getInput() << ')';
4236  SmallVector<StringRef> elidedAttrs = {"fieldIndex"};
4237  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
4238  printer << " : ";
4239  printer.printFunctionalType(ArrayRef<Type>{getInput().getType()},
4240  ArrayRef<Type>{getResult().getType()});
4241 }
4242 
4243 ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4244  auto *context = parser.getContext();
4245 
4246  OpAsmParser::UnresolvedOperand input;
4247  std::string fieldName;
4248  mlir::FunctionType functionType;
4249  if (parser.parseKeywordOrString(&fieldName) || parser.parseLParen() ||
4250  parser.parseOperand(input) || parser.parseRParen() ||
4251  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4252  parser.parseType(functionType))
4253  return failure();
4254 
4255  if (functionType.getNumInputs() != 1)
4256  return parser.emitError(parser.getNameLoc(), "single input type required");
4257  if (functionType.getNumResults() != 1)
4258  return parser.emitError(parser.getNameLoc(), "single result type required");
4259 
4260  auto inputType = functionType.getInput(0);
4261  if (parser.resolveOperand(input, inputType, result.operands))
4262  return failure();
4263 
4264  auto outputType = functionType.getResult(0);
4265  auto enumType = type_dyn_cast<FEnumType>(outputType);
4266  if (!enumType)
4267  return parser.emitError(parser.getNameLoc(),
4268  "output must be enum type, got ")
4269  << outputType;
4270  auto fieldIndex = enumType.getElementIndex(fieldName);
4271  if (!fieldIndex)
4272  return parser.emitError(parser.getNameLoc(),
4273  "unknown field " + fieldName + " in enum type ")
4274  << enumType;
4275 
4276  result.addAttribute(
4277  "fieldIndex",
4278  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4279 
4280  result.addTypes(enumType);
4281 
4282  return success();
4283 }
4284 
4285 //===----------------------------------------------------------------------===//
4286 // IsTagOp
4287 //===----------------------------------------------------------------------===//
4288 
4289 LogicalResult IsTagOp::verify() {
4290  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4291  return emitOpError("element index is greater than the number of fields in "
4292  "the bundle type");
4293  return success();
4294 }
4295 
4296 void IsTagOp::print(::mlir::OpAsmPrinter &printer) {
4297  printer << ' ' << getInput() << ' ';
4298  printer.printKeywordOrString(getFieldName());
4299  SmallVector<::llvm::StringRef, 1> elidedAttrs = {"fieldIndex"};
4300  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4301  printer << " : " << getInput().getType();
4302 }
4303 
4304 ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
4305  auto *context = parser.getContext();
4306 
4307  OpAsmParser::UnresolvedOperand input;
4308  std::string fieldName;
4309  Type inputType;
4310  if (parser.parseOperand(input) || parser.parseKeywordOrString(&fieldName) ||
4311  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4312  parser.parseType(inputType))
4313  return failure();
4314 
4315  if (parser.resolveOperand(input, inputType, result.operands))
4316  return failure();
4317 
4318  auto enumType = type_dyn_cast<FEnumType>(inputType);
4319  if (!enumType)
4320  return parser.emitError(parser.getNameLoc(),
4321  "input must be enum type, got ")
4322  << inputType;
4323  auto fieldIndex = enumType.getElementIndex(fieldName);
4324  if (!fieldIndex)
4325  return parser.emitError(parser.getNameLoc(),
4326  "unknown field " + fieldName + " in enum type ")
4327  << enumType;
4328 
4329  result.addAttribute(
4330  "fieldIndex",
4331  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4332 
4333  result.addTypes(UIntType::get(context, 1, /*isConst=*/false));
4334 
4335  return success();
4336 }
4337 
4338 FIRRTLType IsTagOp::inferReturnType(ValueRange operands,
4339  ArrayRef<NamedAttribute> attrs,
4340  std::optional<Location> loc) {
4341  return UIntType::get(operands[0].getContext(), 1,
4342  isConst(operands[0].getType()));
4343 }
4344 
4345 template <typename OpTy>
4346 ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) {
4347  auto *context = parser.getContext();
4348 
4349  OpAsmParser::UnresolvedOperand input;
4350  std::string fieldName;
4351  Type inputType;
4352  if (parser.parseOperand(input) || parser.parseLSquare() ||
4353  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4354  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4355  parser.parseType(inputType))
4356  return failure();
4357 
4358  if (parser.resolveOperand(input, inputType, result.operands))
4359  return failure();
4360 
4361  auto bundleType = type_dyn_cast<typename OpTy::InputType>(inputType);
4362  if (!bundleType)
4363  return parser.emitError(parser.getNameLoc(),
4364  "input must be bundle type, got ")
4365  << inputType;
4366  auto fieldIndex = bundleType.getElementIndex(fieldName);
4367  if (!fieldIndex)
4368  return parser.emitError(parser.getNameLoc(),
4369  "unknown field " + fieldName + " in bundle type ")
4370  << bundleType;
4371 
4372  result.addAttribute(
4373  "fieldIndex",
4374  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4375 
4376  SmallVector<Type> inferredReturnTypes;
4377  if (failed(OpTy::inferReturnTypes(context, result.location, result.operands,
4378  result.attributes.getDictionary(context),
4379  result.getRawProperties(), result.regions,
4380  inferredReturnTypes)))
4381  return failure();
4382  result.addTypes(inferredReturnTypes);
4383 
4384  return success();
4385 }
4386 
4387 ParseResult SubtagOp::parse(OpAsmParser &parser, OperationState &result) {
4388  auto *context = parser.getContext();
4389 
4390  OpAsmParser::UnresolvedOperand input;
4391  std::string fieldName;
4392  Type inputType;
4393  if (parser.parseOperand(input) || parser.parseLSquare() ||
4394  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4395  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4396  parser.parseType(inputType))
4397  return failure();
4398 
4399  if (parser.resolveOperand(input, inputType, result.operands))
4400  return failure();
4401 
4402  auto enumType = type_dyn_cast<FEnumType>(inputType);
4403  if (!enumType)
4404  return parser.emitError(parser.getNameLoc(),
4405  "input must be enum type, got ")
4406  << inputType;
4407  auto fieldIndex = enumType.getElementIndex(fieldName);
4408  if (!fieldIndex)
4409  return parser.emitError(parser.getNameLoc(),
4410  "unknown field " + fieldName + " in enum type ")
4411  << enumType;
4412 
4413  result.addAttribute(
4414  "fieldIndex",
4415  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4416 
4417  SmallVector<Type> inferredReturnTypes;
4418  if (failed(SubtagOp::inferReturnTypes(
4419  context, result.location, result.operands,
4420  result.attributes.getDictionary(context), result.getRawProperties(),
4421  result.regions, inferredReturnTypes)))
4422  return failure();
4423  result.addTypes(inferredReturnTypes);
4424 
4425  return success();
4426 }
4427 
4428 ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4429  return parseSubfieldLikeOp<SubfieldOp>(parser, result);
4430 }
4431 ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4432  return parseSubfieldLikeOp<OpenSubfieldOp>(parser, result);
4433 }
4434 
4435 template <typename OpTy>
4436 static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) {
4437  printer << ' ' << op.getInput() << '[';
4438  printer.printKeywordOrString(op.getFieldName());
4439  printer << ']';
4440  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4441  elidedAttrs.push_back("fieldIndex");
4442  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4443  printer << " : " << op.getInput().getType();
4444 }
4445 void SubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4446  return printSubfieldLikeOp<SubfieldOp>(*this, printer);
4447 }
4448 void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4449  return printSubfieldLikeOp<OpenSubfieldOp>(*this, printer);
4450 }
4451 
4452 void SubtagOp::print(::mlir::OpAsmPrinter &printer) {
4453  printer << ' ' << getInput() << '[';
4454  printer.printKeywordOrString(getFieldName());
4455  printer << ']';
4456  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4457  elidedAttrs.push_back("fieldIndex");
4458  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4459  printer << " : " << getInput().getType();
4460 }
4461 
4462 template <typename OpTy>
4463 static LogicalResult verifySubfieldLike(OpTy op) {
4464  if (op.getFieldIndex() >=
4465  firrtl::type_cast<typename OpTy::InputType>(op.getInput().getType())
4466  .getNumElements())
4467  return op.emitOpError("subfield element index is greater than the number "
4468  "of fields in the bundle type");
4469  return success();
4470 }
4471 LogicalResult SubfieldOp::verify() {
4472  return verifySubfieldLike<SubfieldOp>(*this);
4473 }
4474 LogicalResult OpenSubfieldOp::verify() {
4475  return verifySubfieldLike<OpenSubfieldOp>(*this);
4476 }
4477 
4478 LogicalResult SubtagOp::verify() {
4479  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4480  return emitOpError("subfield element index is greater than the number "
4481  "of fields in the bundle type");
4482  return success();
4483 }
4484 
/// Return true if the specified operation is considered to have a constant
/// value, using a worklist walk:
///  - `ConstantOp`, `SpecialConstantOp` and `AggregateConstantOp` are accepted.
///  - `WireOp`, `SubindexOp` and `SubfieldOp` are not judged directly; every
///    *user* of their result is pushed and must in turn be accepted. A wire
///    with no uses is therefore reported as constant.
///  - Any other op kind (including connect-like users of a wire) is rejected
///    by the `Default` case.
///
/// NOTE(review): the `NodeOp`/`AsSIntPrimOp`/`AsUIntPrimOp` case pushes the
/// input's defining op but then unconditionally sets `constant = false`, so
/// these ops are always reported non-constant and the pushed op is never
/// examined. Also, because the wire case walks users (not drivers), this does
/// not actually prove that a wire is driven only by constants. Confirm the
/// intended semantics before relying on this for wires or nodes.
bool firrtl::isConstant(Operation *op) {
  // Worklist of ops that need to be examined that should all be constant in
  // order for the input operation to be constant.
  SmallVector<Operation *, 8> worklist({op});

  // Mutable state indicating if this op is a constant. Assume it is a constant
  // and look for counterexamples.
  bool constant = true;

  // While we haven't found a counterexample and there are still ops in the
  // worklist, pull ops off the worklist. If it provides a counterexample, set
  // the `constant` to false (and exit on the next loop iteration). Otherwise,
  // look through the op or spawn off more ops to look at.
  while (constant && !(worklist.empty()))
    TypeSwitch<Operation *>(worklist.pop_back_val())
        .Case<NodeOp, AsSIntPrimOp, AsUIntPrimOp>([&](auto op) {
          // NOTE(review): `constant` is cleared regardless of the push below,
          // so the loop exits before the pushed defining op is visited.
          if (auto definingOp = op.getInput().getDefiningOp())
            worklist.push_back(definingOp);
          constant = false;
        })
        .Case<WireOp, SubindexOp, SubfieldOp>([&](auto op) {
          // Defer the decision to every op that uses this result.
          for (auto &use : op.getResult().getUses())
            worklist.push_back(use.getOwner());
        })
        .Case<ConstantOp, SpecialConstantOp, AggregateConstantOp>([](auto) {})
        .Default([&](auto) { constant = false; });

  return constant;
}
4517 
4518 /// Return true if the specified value is a constant. This trivially checks for
4519 /// `firrtl.constant` and friends, but also looks through subaccesses and
4520 /// correctly handles wires driven with only constant values.
4521 bool firrtl::isConstant(Value value) {
4522  if (auto *op = value.getDefiningOp())
4523  return isConstant(op);
4524  return false;
4525 }
4526 
4527 LogicalResult ConstCastOp::verify() {
4528  if (!areTypesConstCastable(getResult().getType(), getInput().getType()))
4529  return emitOpError() << getInput().getType()
4530  << " is not 'const'-castable to "
4531  << getResult().getType();
4532  return success();
4533 }
4534 
4535 FIRRTLType SubfieldOp::inferReturnType(ValueRange operands,
4536  ArrayRef<NamedAttribute> attrs,
4537  std::optional<Location> loc) {
4538  auto inType = type_cast<BundleType>(operands[0].getType());
4539  auto fieldIndex =
4540  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4541 
4542  if (fieldIndex >= inType.getNumElements())
4543  return emitInferRetTypeError(loc,
4544  "subfield element index is greater than the "
4545  "number of fields in the bundle type");
4546 
4547  // SubfieldOp verifier checks that the field index is valid with number of
4548  // subelements.
4549  return inType.getElementTypePreservingConst(fieldIndex);
4550 }
4551 
4552 FIRRTLType OpenSubfieldOp::inferReturnType(ValueRange operands,
4553  ArrayRef<NamedAttribute> attrs,
4554  std::optional<Location> loc) {
4555  auto inType = type_cast<OpenBundleType>(operands[0].getType());
4556  auto fieldIndex =
4557  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4558 
4559  if (fieldIndex >= inType.getNumElements())
4560  return emitInferRetTypeError(loc,
4561  "subfield element index is greater than the "
4562  "number of fields in the bundle type");
4563 
4564  // OpenSubfieldOp verifier checks that the field index is valid with number of
4565  // subelements.
4566  return inType.getElementTypePreservingConst(fieldIndex);
4567 }
4568 
4569 bool SubfieldOp::isFieldFlipped() {
4570  BundleType bundle = getInput().getType();
4571  return bundle.getElement(getFieldIndex()).isFlip;
4572 }
4573 bool OpenSubfieldOp::isFieldFlipped() {
4574  auto bundle = getInput().getType();
4575  return bundle.getElement(getFieldIndex()).isFlip;
4576 }
4577 
4578 FIRRTLType SubindexOp::inferReturnType(ValueRange operands,
4579  ArrayRef<NamedAttribute> attrs,
4580  std::optional<Location> loc) {
4581  Type inType = operands[0].getType();
4582  auto fieldIdx =
4583  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4584 
4585  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4586  if (fieldIdx < vectorType.getNumElements())
4587  return vectorType.getElementTypePreservingConst();
4588  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4589  "' in vector type ", inType);
4590  }
4591 
4592  return emitInferRetTypeError(loc, "subindex requires vector operand");
4593 }
4594 
4595 FIRRTLType OpenSubindexOp::inferReturnType(ValueRange operands,
4596  ArrayRef<NamedAttribute> attrs,
4597  std::optional<Location> loc) {
4598  Type inType = operands[0].getType();
4599  auto fieldIdx =
4600  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4601 
4602  if (auto vectorType = type_dyn_cast<OpenVectorType>(inType)) {
4603  if (fieldIdx < vectorType.getNumElements())
4604  return vectorType.getElementTypePreservingConst();
4605  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4606  "' in vector type ", inType);
4607  }
4608 
4609  return emitInferRetTypeError(loc, "subindex requires vector operand");
4610 }
4611 
4612 FIRRTLType SubtagOp::inferReturnType(ValueRange operands,
4613  ArrayRef<NamedAttribute> attrs,
4614  std::optional<Location> loc) {
4615  auto inType = type_cast<FEnumType>(operands[0].getType());
4616  auto fieldIndex =
4617  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4618 
4619  if (fieldIndex >= inType.getNumElements())
4620  return emitInferRetTypeError(loc,
4621  "subtag element index is greater than the "
4622  "number of fields in the enum type");
4623 
4624  // SubtagOp verifier checks that the field index is valid with number of
4625  // subelements.
4626  auto elementType = inType.getElement(fieldIndex).type;
4627  return elementType.getConstType(elementType.isConst() || inType.isConst());
4628 }
4629 
4630 FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
4631  ArrayRef<NamedAttribute> attrs,
4632  std::optional<Location> loc) {
4633  auto inType = operands[0].getType();
4634  auto indexType = operands[1].getType();
4635 
4636  if (!type_isa<UIntType>(indexType))
4637  return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
4638  indexType);
4639 
4640  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4641  if (isConst(indexType))
4642  return vectorType.getElementTypePreservingConst();
4643  return vectorType.getElementType().getAllConstDroppedType();
4644  }
4645 
4646  return emitInferRetTypeError(loc, "subaccess requires vector operand, not ",
4647  inType);
4648 }
4649 
4650 FIRRTLType TagExtractOp::inferReturnType(ValueRange operands,
4651  ArrayRef<NamedAttribute> attrs,
4652  std::optional<Location> loc) {
4653  auto inType = type_cast<FEnumType>(operands[0].getType());
4654  auto i = llvm::Log2_32_Ceil(inType.getNumElements());
4655  return UIntType::get(inType.getContext(), i);
4656 }
4657 
4658 ParseResult MultibitMuxOp::parse(OpAsmParser &parser, OperationState &result) {
4659  OpAsmParser::UnresolvedOperand index;
4660  SmallVector<OpAsmParser::UnresolvedOperand, 16> inputs;
4661  Type indexType, elemType;
4662 
4663  if (parser.parseOperand(index) || parser.parseComma() ||
4664  parser.parseOperandList(inputs) ||
4665  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4666  parser.parseType(indexType) || parser.parseComma() ||
4667  parser.parseType(elemType))
4668  return failure();
4669 
4670  if (parser.resolveOperand(index, indexType, result.operands))
4671  return failure();
4672 
4673  result.addTypes(elemType);
4674 
4675  return parser.resolveOperands(inputs, elemType, result.operands);
4676 }
4677 
4678 void MultibitMuxOp::print(OpAsmPrinter &p) {
4679  p << " " << getIndex() << ", ";
4680  p.printOperands(getInputs());
4681  p.printOptionalAttrDict((*this)->getAttrs());
4682  p << " : " << getIndex().getType() << ", " << getType();
4683 }
4684 
4685 FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
4686  ArrayRef<NamedAttribute> attrs,
4687  std::optional<Location> loc) {
4688  if (operands.size() < 2)
4689  return emitInferRetTypeError(loc, "at least one input is required");
4690 
4691  // Check all mux inputs have the same type.
4692  if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
4693  return operands[1].getType() == op.getType();
4694  }))
4695  return emitInferRetTypeError(loc, "all inputs must have the same type");
4696 
4697  return type_cast<FIRRTLType>(operands[1].getType());
4698 }
4699 
4700 //===----------------------------------------------------------------------===//
4701 // ObjectSubfieldOp
4702 //===----------------------------------------------------------------------===//
4703 
4705  MLIRContext *context, std::optional<mlir::Location> location,
4706  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4707  RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
4708  auto type = inferReturnType(operands, attributes.getValue(), location);
4709  if (!type)
4710  return failure();
4711  inferredReturnTypes.push_back(type);
4712  return success();
4713 }
4714 
4715 Type ObjectSubfieldOp::inferReturnType(ValueRange operands,
4716  ArrayRef<NamedAttribute> attrs,
4717  std::optional<Location> loc) {
4718  auto classType = dyn_cast<ClassType>(operands[0].getType());
4719  if (!classType)
4720  return emitInferRetTypeError(loc, "base object is not a class");
4721 
4722  auto index = getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4723  if (classType.getNumElements() <= index)
4724  return emitInferRetTypeError(loc, "element index is greater than the "
4725  "number of fields in the object");
4726 
4727  return classType.getElement(index).type;
4728 }
4729 
4730 void ObjectSubfieldOp::print(OpAsmPrinter &p) {
4731  auto input = getInput();
4732  auto classType = input.getType();
4733  p << ' ' << input << "[";
4734  p.printKeywordOrString(classType.getElement(getIndex()).name);
4735  p << "]";
4736  p.printOptionalAttrDict((*this)->getAttrs(), std::array{StringRef("index")});
4737  p << " : " << classType;
4738 }
4739 
4740 ParseResult ObjectSubfieldOp::parse(OpAsmParser &parser,
4741  OperationState &result) {
4742  auto *context = parser.getContext();
4743 
4744  OpAsmParser::UnresolvedOperand input;
4745  std::string fieldName;
4746  ClassType inputType;
4747  if (parser.parseOperand(input) || parser.parseLSquare() ||
4748  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4749  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4750  parser.parseType(inputType) ||
4751  parser.resolveOperand(input, inputType, result.operands))
4752  return failure();
4753 
4754  auto index = inputType.getElementIndex(fieldName);
4755  if (!index)
4756  return parser.emitError(parser.getNameLoc(),
4757  "unknown field " + fieldName + " in class type ")
4758  << inputType;
4759  result.addAttribute("index",
4760  IntegerAttr::get(IntegerType::get(context, 32), *index));
4761 
4762  SmallVector<Type> inferredReturnTypes;
4763  if (failed(inferReturnTypes(context, result.location, result.operands,
4764  result.attributes.getDictionary(context),
4765  result.getRawProperties(), result.regions,
4766  inferredReturnTypes)))
4767  return failure();
4768  result.addTypes(inferredReturnTypes);
4769 
4770  return success();
4771 }
4772 
4773 //===----------------------------------------------------------------------===//
4774 // Binary Primitives
4775 //===----------------------------------------------------------------------===//
4776 
4777 /// If LHS and RHS are both UInt or SInt types, the return true and fill in the
4778 /// width of them if known. If unknown, return -1 for the widths.
4779 /// The constness of the result is also returned, where if both lhs and rhs are
4780 /// const, then the result is const.
4781 ///
4782 /// On failure, this reports and error and returns false. This function should
4783 /// not be used if you don't want an error reported.
4784 static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth,
4785  int32_t &rhsWidth, bool &isConstResult,
4786  std::optional<Location> loc) {
4787  // Must have two integer types with the same signedness.
4788  auto lhsi = type_dyn_cast<IntType>(lhs);
4789  auto rhsi = type_dyn_cast<IntType>(rhs);
4790  if (!lhsi || !rhsi || lhsi.isSigned() != rhsi.isSigned()) {
4791  if (loc) {
4792  if (lhsi && !rhsi)
4793  mlir::emitError(*loc, "second operand must be an integer type, not ")
4794  << rhs;
4795  else if (!lhsi && rhsi)
4796  mlir::emitError(*loc, "first operand must be an integer type, not ")
4797  << lhs;
4798  else if (!lhsi && !rhsi)
4799  mlir::emitError(*loc, "operands must be integer types, not ")
4800  << lhs << " and " << rhs;
4801  else
4802  mlir::emitError(*loc, "operand signedness must match");
4803  }
4804  return false;
4805  }
4806 
4807  lhsWidth = lhsi.getWidthOrSentinel();
4808  rhsWidth = rhsi.getWidthOrSentinel();
4809  isConstResult = lhsi.isConst() && rhsi.isConst();
4810  return true;
4811 }
4812 
4813 LogicalResult impl::verifySameOperandsIntTypeKind(Operation *op) {
4814  assert(op->getNumOperands() == 2 &&
4815  "SameOperandsIntTypeKind on non-binary op");
4816  int32_t lhsWidth, rhsWidth;
4817  bool isConstResult;
4818  return success(isSameIntTypeKind(op->getOperand(0).getType(),
4819  op->getOperand(1).getType(), lhsWidth,
4820  rhsWidth, isConstResult, op->getLoc()));
4821 }
4822 
4823 LogicalResult impl::validateBinaryOpArguments(ValueRange operands,
4824  ArrayRef<NamedAttribute> attrs,
4825  Location loc) {
4826  if (operands.size() != 2 || !attrs.empty()) {
4827  mlir::emitError(loc, "operation requires two operands and no constants");
4828  return failure();
4829  }
4830  return success();
4831 }
4832 
4834  std::optional<Location> loc) {
4835  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4836  bool isConstResult = false;
4837  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4838  return {};
4839 
4840  if (lhsWidth != -1 && rhsWidth != -1)
4841  resultWidth = std::max(lhsWidth, rhsWidth) + 1;
4842  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4843  isConstResult);
4844 }
4845 
4846 FIRRTLType MulPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4847  std::optional<Location> loc) {
4848  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4849  bool isConstResult = false;
4850  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4851  return {};
4852 
4853  if (lhsWidth != -1 && rhsWidth != -1)
4854  resultWidth = lhsWidth + rhsWidth;
4855 
4856  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4857  isConstResult);
4858 }
4859 
4860 FIRRTLType DivPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4861  std::optional<Location> loc) {
4862  int32_t lhsWidth, rhsWidth;
4863  bool isConstResult = false;
4864  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4865  return {};
4866 
4867  // For unsigned, the width is the width of the numerator on the LHS.
4868  if (type_isa<UIntType>(lhs))
4869  return UIntType::get(lhs.getContext(), lhsWidth, isConstResult);
4870 
4871  // For signed, the width is the width of the numerator on the LHS, plus 1.
4872  int32_t resultWidth = lhsWidth != -1 ? lhsWidth + 1 : -1;
4873  return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
4874 }
4875 
4876 FIRRTLType RemPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4877  std::optional<Location> loc) {
4878  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4879  bool isConstResult = false;
4880  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4881  return {};
4882 
4883  if (lhsWidth != -1 && rhsWidth != -1)
4884  resultWidth = std::min(lhsWidth, rhsWidth);
4885  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4886  isConstResult);
4887 }
4888 
4890  std::optional<Location> loc) {
4891  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4892  bool isConstResult = false;
4893  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4894  return {};
4895 
4896  if (lhsWidth != -1 && rhsWidth != -1)
4897  resultWidth = std::max(lhsWidth, rhsWidth);
4898  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
4899 }
4900 
4902  std::optional<Location> loc) {
4903  if (!type_isa<FVectorType>(lhs) || !type_isa<FVectorType>(rhs))
4904  return {};
4905 
4906  auto lhsVec = type_cast<FVectorType>(lhs);
4907  auto rhsVec = type_cast<FVectorType>(rhs);
4908 
4909  if (lhsVec.getNumElements() != rhsVec.getNumElements())
4910  return {};
4911 
4912  auto elemType =
4913  impl::inferBitwiseResult(lhsVec.getElementTypePreservingConst(),
4914  rhsVec.getElementTypePreservingConst(), loc);
4915  if (!elemType)
4916  return {};
4917  auto elemBaseType = type_cast<FIRRTLBaseType>(elemType);
4918  return FVectorType::get(elemBaseType, lhsVec.getNumElements(),
4919  lhsVec.isConst() && rhsVec.isConst() &&
4920  elemBaseType.isConst());
4921 }
4922 
4924  std::optional<Location> loc) {
4925  return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
4926 }
4927 
4928 FIRRTLType CatPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4929  std::optional<Location> loc) {
4930  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4931  bool isConstResult = false;
4932  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4933  return {};
4934 
4935  if (lhsWidth != -1 && rhsWidth != -1)
4936  resultWidth = lhsWidth + rhsWidth;
4937  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
4938 }
4939 
4940 FIRRTLType DShlPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4941  std::optional<Location> loc) {
4942  auto lhsi = type_dyn_cast<IntType>(lhs);
4943  auto rhsui = type_dyn_cast<UIntType>(rhs);
4944  if (!rhsui || !lhsi)
4945  return emitInferRetTypeError(
4946  loc, "first operand should be integer, second unsigned int");
4947 
4948  // If the left or right has unknown result type, then the operation does
4949  // too.
4950  auto width = lhsi.getWidthOrSentinel();
4951  if (width == -1 || !rhsui.getWidth().has_value()) {
4952  width = -1;
4953  } else {
4954  auto amount = *rhsui.getWidth();
4955  if (amount >= 32)
4956  return emitInferRetTypeError(loc,
4957  "shift amount too large: second operand of "
4958  "dshl is wider than 31 bits");
4959  int64_t newWidth = (int64_t)width + ((int64_t)1 << amount) - 1;
4960  if (newWidth > INT32_MAX)
4961  return emitInferRetTypeError(
4962  loc, "shift amount too large: first operand shifted by maximum "
4963  "amount exceeds maximum width");
4964  width = newWidth;
4965  }
4966  return IntType::get(lhs.getContext(), lhsi.isSigned(), width,
4967  lhsi.isConst() && rhsui.isConst());
4968 }
4969 
4970 FIRRTLType DShlwPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4971  std::optional<Location> loc) {
4972  auto lhsi = type_dyn_cast<IntType>(lhs);
4973  auto rhsu = type_dyn_cast<UIntType>(rhs);
4974  if (!lhsi || !rhsu)
4975  return emitInferRetTypeError(
4976  loc, "first operand should be integer, second unsigned int");
4977  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
4978 }
4979 
4980 FIRRTLType DShrPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4981  std::optional<Location> loc) {
4982  auto lhsi = type_dyn_cast<IntType>(lhs);
4983  auto rhsu = type_dyn_cast<UIntType>(rhs);
4984  if (!lhsi || !rhsu)
4985  return emitInferRetTypeError(
4986  loc, "first operand should be integer, second unsigned int");
4987  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
4988 }
4989 
4990 //===----------------------------------------------------------------------===//
4991 // Unary Primitives
4992 //===----------------------------------------------------------------------===//
4993 
4994 LogicalResult impl::validateUnaryOpArguments(ValueRange operands,
4995  ArrayRef<NamedAttribute> attrs,
4996  Location loc) {
4997  if (operands.size() != 1 || !attrs.empty()) {
4998  mlir::emitError(loc, "operation requires one operand and no constants");
4999  return failure();
5000  }
5001  return success();
5002 }
5003 
5004 FIRRTLType
5005 SizeOfIntrinsicOp::inferUnaryReturnType(FIRRTLType input,
5006  std::optional<Location> loc) {
5007  return UIntType::get(input.getContext(), 32);
5008 }
5009 
5010 FIRRTLType AsSIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5011  std::optional<Location> loc) {
5012  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5013  if (!base)
5014  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5015  int32_t width = base.getBitWidthOrSentinel();
5016  if (width == -2)
5017  return emitInferRetTypeError(loc, "operand must be a scalar type");
5018  return SIntType::get(input.getContext(), width, base.isConst());
5019 }
5020 
5021 FIRRTLType AsUIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5022  std::optional<Location> loc) {
5023  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5024  if (!base)
5025  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5026  int32_t width = base.getBitWidthOrSentinel();
5027  if (width == -2)
5028  return emitInferRetTypeError(loc, "operand must be a scalar type");
5029  return UIntType::get(input.getContext(), width, base.isConst());
5030 }
5031 
5032 FIRRTLType
5033 AsAsyncResetPrimOp::inferUnaryReturnType(FIRRTLType input,
5034  std::optional<Location> loc) {
5035  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5036  if (!base)
5037  return emitInferRetTypeError(loc,
5038  "operand must be single bit scalar base type");
5039  int32_t width = base.getBitWidthOrSentinel();
5040  if (width == -2 || width == 0 || width > 1)
5041  return emitInferRetTypeError(loc, "operand must be single bit scalar type");
5042  return AsyncResetType::get(input.getContext(), base.isConst());
5043 }
5044 
5045 FIRRTLType AsClockPrimOp::inferUnaryReturnType(FIRRTLType input,
5046  std::optional<Location> loc) {
5047  return ClockType::get(input.getContext(), isConst(input));
5048 }
5049 
5050 FIRRTLType CvtPrimOp::inferUnaryReturnType(FIRRTLType input,
5051  std::optional<Location> loc) {
5052  if (auto uiType = type_dyn_cast<UIntType>(input)) {
5053  auto width = uiType.getWidthOrSentinel();
5054  if (width != -1)
5055  ++width;
5056  return SIntType::get(input.getContext(), width, uiType.isConst());
5057  }
5058 
5059  if (type_isa<SIntType>(input))
5060  return input;
5061 
5062  return emitInferRetTypeError(loc, "operand must have integer type");
5063 }
5064 
5065 FIRRTLType NegPrimOp::inferUnaryReturnType(FIRRTLType input,
5066  std::optional<Location> loc) {
5067  auto inputi = type_dyn_cast<IntType>(input);
5068  if (!inputi)
5069  return emitInferRetTypeError(loc, "operand must have integer type");
5070  int32_t width = inputi.getWidthOrSentinel();
5071  if (width != -1)
5072  ++width;
5073  return SIntType::get(input.getContext(), width, inputi.isConst());
5074 }
5075 
5076 FIRRTLType NotPrimOp::inferUnaryReturnType(FIRRTLType input,
5077  std::optional<Location> loc) {
5078  auto inputi = type_dyn_cast<IntType>(input);
5079  if (!inputi)
5080  return emitInferRetTypeError(loc, "operand must have integer type");
5081  return UIntType::get(input.getContext(), inputi.getWidthOrSentinel(),
5082  inputi.isConst());
5083 }
5084 
5086  std::optional<Location> loc) {
5087  return UIntType::get(input.getContext(), 1, isConst(input));
5088 }
5089 
5090 //===----------------------------------------------------------------------===//
5091 // Other Operations
5092 //===----------------------------------------------------------------------===//
5093 
5094 LogicalResult BitsPrimOp::validateArguments(ValueRange operands,
5095  ArrayRef<NamedAttribute> attrs,
5096  Location loc) {
5097  if (operands.size() != 1 || attrs.size() != 2) {
5098  mlir::emitError(loc, "operation requires one operand and two constants");
5099  return failure();
5100  }
5101  return success();
5102 }
5103 
5104 FIRRTLType BitsPrimOp::inferReturnType(ValueRange operands,
5105  ArrayRef<NamedAttribute> attrs,
5106  std::optional<Location> loc) {
5107  auto input = operands[0].getType();
5108  auto high = getAttr<IntegerAttr>(attrs, "hi").getValue().getSExtValue();
5109  auto low = getAttr<IntegerAttr>(attrs, "lo").getValue().getSExtValue();
5110 
5111  auto inputi = type_dyn_cast<IntType>(input);
5112  if (!inputi)
5113  return emitInferRetTypeError(
5114  loc, "input type should be the int type but got ", input);
5115 
5116  // High must be >= low and both most be non-negative.
5117  if (high < low)
5118  return emitInferRetTypeError(
5119  loc, "high must be equal or greater than low, but got high = ", high,
5120  ", low = ", low);
5121 
5122  if (low < 0)
5123  return emitInferRetTypeError(loc, "low must be non-negative but got ", low);
5124 
5125  // If the input has staticly known width, check it. Both and low must be
5126  // strictly less than width.
5127  int32_t width = inputi.getWidthOrSentinel();
5128  if (width != -1 && high >= width)
5129  return emitInferRetTypeError(
5130  loc,
5131  "high must be smaller than the width of input, but got high = ", high,
5132  ", width = ", width);
5133 
5134  return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
5135 }
5136 
5137 LogicalResult impl::validateOneOperandOneConst(ValueRange operands,
5138  ArrayRef<NamedAttribute> attrs,
5139  Location loc) {
5140  if (operands.size() != 1 || attrs.size() != 1) {
5141  mlir::emitError(loc, "operation requires one operand and one constant");
5142  return failure();
5143  }
5144  return success();
5145 }
5146 
5147 FIRRTLType HeadPrimOp::inferReturnType(ValueRange operands,
5148  ArrayRef<NamedAttribute> attrs,
5149  std::optional<Location> loc) {
5150  auto input = operands[0].getType();
5151  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5152 
5153  auto inputi = type_dyn_cast<IntType>(input);
5154  if (amount < 0 || !inputi)
5155  return emitInferRetTypeError(
5156  loc, "operand must have integer type and amount must be >= 0");
5157 
5158  int32_t width = inputi.getWidthOrSentinel();
5159  if (width != -1 && amount > width)
5160  return emitInferRetTypeError(loc, "amount larger than input width");
5161 
5162  return UIntType::get(input.getContext(), amount, inputi.isConst());
5163 }
5164 
5165 LogicalResult MuxPrimOp::validateArguments(ValueRange operands,
5166  ArrayRef<NamedAttribute> attrs,
5167  Location loc) {
5168  if (operands.size() != 3 || attrs.size() != 0) {
5169  mlir::emitError(loc, "operation requires three operands and no constants");
5170  return failure();
5171  }
5172  return success();
5173 }
5174 
5175 /// Infer the result type for a multiplexer given its two operand types, which
5176 /// may be aggregates.
5177 ///
5178 /// This essentially performs a pairwise comparison of fields and elements, as
5179 /// follows:
5180 /// - Identical operands inferred to their common type
5181 /// - Integer operands inferred to the larger one if both have a known width, a
5182 /// widthless integer otherwise.
5183 /// - Vectors inferred based on the element type.
5184 /// - Bundles inferred in a pairwise fashion based on the field types.
5186  FIRRTLBaseType low,
5187  bool isConstCondition,
5188  std::optional<Location> loc) {
5189  // If the types are identical we're done.
5190  if (high == low)
5191  return isConstCondition ? low : low.getAllConstDroppedType();
5192 
5193  // The base types need to be equivalent.
5194  if (high.getTypeID() != low.getTypeID())
5195  return emitInferRetTypeError<FIRRTLBaseType>(
5196  loc, "incompatible mux operand types, true value type: ", high,
5197  ", false value type: ", low);
5198 
5199  bool outerTypeIsConst = isConstCondition && low.isConst() && high.isConst();
5200 
5201  // Two different Int types can be compatible. If either has unknown width,
5202  // then return it. If both are known but different width, then return the
5203  // larger one.
5204  if (type_isa<IntType>(low)) {
5205  int32_t highWidth = high.getBitWidthOrSentinel();
5206  int32_t lowWidth = low.getBitWidthOrSentinel();
5207  if (lowWidth == -1)
5208  return low.getConstType(outerTypeIsConst);
5209  if (highWidth == -1)
5210  return high.getConstType(outerTypeIsConst);
5211  return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
5212  }
5213 
5214  // Infer vector types by comparing the element types.
5215  auto highVector = type_dyn_cast<FVectorType>(high);
5216  auto lowVector = type_dyn_cast<FVectorType>(low);
5217  if (highVector && lowVector &&
5218  highVector.getNumElements() == lowVector.getNumElements()) {
5219  auto inner = inferMuxReturnType(highVector.getElementTypePreservingConst(),
5220  lowVector.getElementTypePreservingConst(),
5221  isConstCondition, loc);
5222  if (!inner)
5223  return {};
5224  return FVectorType::get(inner, lowVector.getNumElements(),
5225  outerTypeIsConst);
5226  }
5227 
5228  // Infer bundle types by inferring names in a pairwise fashion.
5229  auto highBundle = type_dyn_cast<BundleType>(high);
5230  auto lowBundle = type_dyn_cast<BundleType>(low);
5231  if (highBundle && lowBundle) {
5232  auto highElements = highBundle.getElements();
5233  auto lowElements = lowBundle.getElements();
5234  size_t numElements = highElements.size();
5235 
5236  SmallVector<BundleType::BundleElement> newElements;
5237  if (numElements == lowElements.size()) {
5238  bool failed = false;
5239  for (size_t i = 0; i < numElements; ++i) {
5240  if (highElements[i].name != lowElements[i].name ||
5241  highElements[i].isFlip != lowElements[i].isFlip) {
5242  failed = true;
5243  break;
5244  }
5245  auto element = highElements[i];
5246  element.type = inferMuxReturnType(
5247  highBundle.getElementTypePreservingConst(i),
5248  lowBundle.getElementTypePreservingConst(i), isConstCondition, loc);
5249  if (!element.type)
5250  return {};
5251  newElements.push_back(element);
5252  }
5253  if (!failed)
5254  return BundleType::get(low.getContext(), newElements, outerTypeIsConst);
5255  }
5256  return emitInferRetTypeError<FIRRTLBaseType>(
5257  loc, "incompatible mux operand bundle fields, true value type: ", high,
5258  ", false value type: ", low);
5259  }
5260 
5261  // If we arrive here the types of the two mux arms are fundamentally
5262  // incompatible.
5263  return emitInferRetTypeError<FIRRTLBaseType>(
5264  loc, "invalid mux operand types, true value type: ", high,
5265  ", false value type: ", low);
5266 }
5267 
5268 FIRRTLType MuxPrimOp::inferReturnType(ValueRange operands,
5269  ArrayRef<NamedAttribute> attrs,
5270  std::optional<Location> loc) {
5271  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5272  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5273  if (!highType || !lowType)
5274  return emitInferRetTypeError(loc, "operands must be base type");
5275  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5276  loc);
5277 }
5278 
5279 FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
5280  ArrayRef<NamedAttribute> attrs,
5281  std::optional<Location> loc) {
5282  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5283  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5284  if (!highType || !lowType)
5285  return emitInferRetTypeError(loc, "operands must be base type");
5286  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5287  loc);
5288 }
5289 
5290 FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
5291  ArrayRef<NamedAttribute> attrs,
5292  std::optional<Location> loc) {
5293  SmallVector<FIRRTLBaseType> types;
5294  FIRRTLBaseType result;
5295  for (unsigned i = 1; i < 5; i++) {
5296  types.push_back(type_dyn_cast<FIRRTLBaseType>(operands[i].getType()));
5297  if (!types.back())
5298  return emitInferRetTypeError(loc, "operands must be base type");
5299  if (result) {
5300  result = inferMuxReturnType(result, types.back(),
5301  isConst(operands[0].getType()), loc);
5302  if (!result)
5303  return result;
5304  } else {
5305  result = types.back();
5306  }
5307  }
5308  return result;
5309 }
5310 
5311 FIRRTLType PadPrimOp::inferReturnType(ValueRange operands,
5312  ArrayRef<NamedAttribute> attrs,
5313  std::optional<Location> loc) {
5314  auto input = operands[0].getType();
5315  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5316 
5317  auto inputi = type_dyn_cast<IntType>(input);
5318  if (amount < 0 || !inputi)
5319  return emitInferRetTypeError(
5320  loc, "pad input must be integer and amount must be >= 0");
5321 
5322  int32_t width = inputi.getWidthOrSentinel();
5323  if (width == -1)
5324  return inputi;
5325 
5326  width = std::max<int32_t>(width, amount);
5327  return IntType::get(input.getContext(), inputi.isSigned(), width,
5328  inputi.isConst());
5329 }
5330 
5331 FIRRTLType ShlPrimOp::inferReturnType(ValueRange operands,
5332  ArrayRef<NamedAttribute> attrs,
5333  std::optional<Location> loc) {
5334  auto input = operands[0].getType();
5335  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5336 
5337  auto inputi = type_dyn_cast<IntType>(input);
5338  if (amount < 0 || !inputi)
5339  return emitInferRetTypeError(
5340  loc, "shl input must be integer and amount must be >= 0");
5341 
5342  int32_t width = inputi.getWidthOrSentinel();
5343  if (width != -1)
5344  width += amount;
5345 
5346  return IntType::get(input.getContext(), inputi.isSigned(), width,
5347  inputi.isConst());
5348 }
5349 
5350 FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
5351  ArrayRef<NamedAttribute> attrs,
5352  std::optional<Location> loc) {
5353  auto input = operands[0].getType();
5354  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5355 
5356  auto inputi = type_dyn_cast<IntType>(input);
5357  if (amount < 0 || !inputi)
5358  return emitInferRetTypeError(
5359  loc, "shr input must be integer and amount must be >= 0");
5360 
5361  int32_t width = inputi.getWidthOrSentinel();
5362  if (width != -1) {
5363  // UInt saturates at 0 bits, SInt at 1 bit
5364  int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
5365  width = std::max<int32_t>(minWidth, width - amount);
5366  }
5367 
5368  return IntType::get(input.getContext(), inputi.isSigned(), width,
5369  inputi.isConst());
5370 }
5371 
5372 FIRRTLType TailPrimOp::inferReturnType(ValueRange operands,
5373  ArrayRef<NamedAttribute> attrs,
5374  std::optional<Location> loc) {
5375  auto input = operands[0].getType();
5376  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5377 
5378  auto inputi = type_dyn_cast<IntType>(input);
5379  if (amount < 0 || !inputi)
5380  return emitInferRetTypeError(
5381  loc, "tail input must be integer and amount must be >= 0");
5382 
5383  int32_t width = inputi.getWidthOrSentinel();
5384  if (width != -1) {
5385  if (width < amount)
5386  return emitInferRetTypeError(
5387  loc, "amount must be less than or equal operand width");
5388  width -= amount;
5389  }
5390 
5391  return IntType::get(input.getContext(), false, width, inputi.isConst());
5392 }
5393 
5394 //===----------------------------------------------------------------------===//
5395 // VerbatimExprOp
5396 //===----------------------------------------------------------------------===//
5397 
5399  function_ref<void(Value, StringRef)> setNameFn) {
5400  // If the text is macro like, then use a pretty name. We only take the
5401  // text up to a weird character (like a paren) and currently ignore
5402  // parenthesized expressions.
5403  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5404  auto name = getText();
5405  // Ignore a leading ` in macro name.
5406  if (name.starts_with("`"))
5407  name = name.drop_front();
5408  name = name.take_while(isOkCharacter);
5409  if (!name.empty())
5410  setNameFn(getResult(), name);
5411 }
5412 
5413 //===----------------------------------------------------------------------===//
5414 // VerbatimWireOp
5415 //===----------------------------------------------------------------------===//
5416 
5418  function_ref<void(Value, StringRef)> setNameFn) {
5419  // If the text is macro like, then use a pretty name. We only take the
5420  // text up to a weird character (like a paren) and currently ignore
5421  // parenthesized expressions.
5422  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5423  auto name = getText();
5424  // Ignore a leading ` in macro name.
5425  if (name.starts_with("`"))
5426  name = name.drop_front();
5427  name = name.take_while(isOkCharacter);
5428  if (!name.empty())
5429  setNameFn(getResult(), name);
5430 }
5431 
5432 //===----------------------------------------------------------------------===//
5433 // Conversions to/from structs in the standard dialect.
5434 //===----------------------------------------------------------------------===//
5435 
5436 LogicalResult HWStructCastOp::verify() {
5437  // We must have a bundle and a struct, with matching pairwise fields
5438  BundleType bundleType;
5439  hw::StructType structType;
5440  if ((bundleType = type_dyn_cast<BundleType>(getOperand().getType()))) {
5441  structType = getType().dyn_cast<hw::StructType>();
5442  if (!structType)
5443  return emitError("result type must be a struct");
5444  } else if ((bundleType = type_dyn_cast<BundleType>(getType()))) {
5445  structType = getOperand().getType().dyn_cast<hw::StructType>();
5446  if (!structType)
5447  return emitError("operand type must be a struct");
5448  } else {
5449  return emitError("either source or result type must be a bundle type");
5450  }
5451 
5452  auto firFields = bundleType.getElements();
5453  auto hwFields = structType.getElements();
5454  if (firFields.size() != hwFields.size())
5455  return emitError("bundle and struct have different number of fields");
5456 
5457  for (size_t findex = 0, fend = firFields.size(); findex < fend; ++findex) {
5458  if (firFields[findex].name.getValue() != hwFields[findex].name)
5459  return emitError("field names don't match '")
5460  << firFields[findex].name.getValue() << "', '"
5461  << hwFields[findex].name.getValue() << "'";
5462  int64_t firWidth =
5463  FIRRTLBaseType(firFields[findex].type).getBitWidthOrSentinel();
5464  int64_t hwWidth = hw::getBitWidth(hwFields[findex].type);
5465  if (firWidth > 0 && hwWidth > 0 && firWidth != hwWidth)
5466  return emitError("size of field '")
5467  << hwFields[findex].name.getValue() << "' don't match " << firWidth
5468  << ", " << hwWidth;
5469  }
5470 
5471  return success();
5472 }
5473 
5474 LogicalResult BitCastOp::verify() {
5475  auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true);
5476  auto resTypeBits = getBitWidth(getType());
5477  if (inTypeBits.has_value() && resTypeBits.has_value()) {
5478  // Bitwidths must match for valid bit
5479  if (*inTypeBits == *resTypeBits) {
5480  // non-'const' cannot be casted to 'const'
5481  if (containsConst(getType()) && !isConst(getOperand().getType()))
5482  return emitError("cannot cast non-'const' input type ")
5483  << getOperand().getType() << " to 'const' result type "
5484  << getType();
5485  return success();
5486  }
5487  return emitError("the bitwidth of input (")
5488  << *inTypeBits << ") and result (" << *resTypeBits
5489  << ") don't match";
5490  }
5491  if (!inTypeBits.has_value())
5492  return emitError("bitwidth cannot be determined for input operand type ")
5493  << getInput().getType();
5494  return emitError("bitwidth cannot be determined for result type ")
5495  << getType();
5496 }
5497 
5498 //===----------------------------------------------------------------------===//
5499 // Custom attr-dict Directive that Elides Annotations
5500 //===----------------------------------------------------------------------===//
5501 
5502 /// Parse an optional attribute dictionary, adding an empty 'annotations'
5503 /// attribute if not specified.
5504 static ParseResult parseElideAnnotations(OpAsmParser &parser,
5505  NamedAttrList &resultAttrs) {
5506  auto result = parser.parseOptionalAttrDict(resultAttrs);
5507  if (!resultAttrs.get("annotations"))
5508  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
5509 
5510  return result;
5511 }
5512 
5513 static void printElideAnnotations(OpAsmPrinter &p, Operation *op,
5514  DictionaryAttr attr,
5515  ArrayRef<StringRef> extraElides = {}) {
5516  SmallVector<StringRef> elidedAttrs(extraElides.begin(), extraElides.end());
5517  // Elide "annotations" if it is empty.
5518  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
5519  elidedAttrs.push_back("annotations");
5520  // Elide "nameKind".
5521  elidedAttrs.push_back("nameKind");
5522 
5523  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5524 }
5525 
5526 /// Parse an optional attribute dictionary, adding empty 'annotations' and
5527 /// 'portAnnotations' attributes if not specified.
5528 static ParseResult parseElidePortAnnotations(OpAsmParser &parser,
5529  NamedAttrList &resultAttrs) {
5530  auto result = parseElideAnnotations(parser, resultAttrs);
5531 
5532  if (!resultAttrs.get("portAnnotations")) {
5533  SmallVector<Attribute, 16> portAnnotations(
5534  parser.getNumResults(), parser.getBuilder().getArrayAttr({}));
5535  resultAttrs.append("portAnnotations",
5536  parser.getBuilder().getArrayAttr(portAnnotations));
5537  }
5538  return result;
5539 }
5540 
5541 // Elide 'annotations' and 'portAnnotations' attributes if they are empty.
5542 static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op,
5543  DictionaryAttr attr,
5544  ArrayRef<StringRef> extraElides = {}) {
5545  SmallVector<StringRef, 2> elidedAttrs(extraElides.begin(), extraElides.end());
5546 
5547  if (llvm::all_of(op->getAttrOfType<ArrayAttr>("portAnnotations"),
5548  [&](Attribute a) { return cast<ArrayAttr>(a).empty(); }))
5549  elidedAttrs.push_back("portAnnotations");
5550  printElideAnnotations(p, op, attr, elidedAttrs);
5551 }
5552 
5553 //===----------------------------------------------------------------------===//
5554 // NameKind Custom Directive
5555 //===----------------------------------------------------------------------===//
5556 
5557 static ParseResult parseNameKind(OpAsmParser &parser,
5558  firrtl::NameKindEnumAttr &result) {
5559  StringRef keyword;
5560 
5561  if (!parser.parseOptionalKeyword(&keyword,
5562  {"interesting_name", "droppable_name"})) {
5563  auto kind = symbolizeNameKindEnum(keyword);
5564  result = NameKindEnumAttr::get(parser.getContext(), kind.value());
5565  return success();
5566  }
5567 
5568  // Default is droppable name.
5569  result =
5570  NameKindEnumAttr::get(parser.getContext(), NameKindEnum::DroppableName);
5571  return success();
5572 }
5573 
5574 static void printNameKind(OpAsmPrinter &p, Operation *op,
5575  firrtl::NameKindEnumAttr attr,
5576  ArrayRef<StringRef> extraElides = {}) {
5577  if (attr.getValue() != NameKindEnum::DroppableName)
5578  p << " " << stringifyNameKindEnum(attr.getValue());
5579 }
5580 
5581 //===----------------------------------------------------------------------===//
5582 // ImplicitSSAName Custom Directive
5583 //===----------------------------------------------------------------------===//
5584 
5585 static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser,
5586  NamedAttrList &resultAttrs) {
5587  if (parseElideAnnotations(parser, resultAttrs))
5588  return failure();
5589  inferImplicitSSAName(parser, resultAttrs);
5590  return success();
5591 }
5592 
5593 static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op,
5594  DictionaryAttr attrs) {
5595  SmallVector<StringRef, 4> elides;
5597  elides.push_back(Forceable::getForceableAttrName());
5598  elideImplicitSSAName(p, op, attrs, elides);
5599  printElideAnnotations(p, op, attrs, elides);
5600 }
5601 
5602 //===----------------------------------------------------------------------===//
5603 // MemOp Custom attr-dict Directive
5604 //===----------------------------------------------------------------------===//
5605 
5606 static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs) {
5607  return parseElidePortAnnotations(parser, resultAttrs);
5608 }
5609 
5610 /// Always elide "ruw" and elide "annotations" if it exists or if it is empty.
5611 static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr) {
5612  // "ruw" and "inner_sym" is always elided.
5613  printElidePortAnnotations(p, op, attr, {"ruw", "inner_sym"});
5614 }
5615 
5616 //===----------------------------------------------------------------------===//
5617 // ClassInterface custom directive
5618 //===----------------------------------------------------------------------===//
5619 
5620 static ParseResult parseClassInterface(OpAsmParser &parser, Type &result) {
5621  ClassType type;
5622  if (ClassType::parseInterface(parser, type))
5623  return failure();
5624  result = type;
5625  return success();
5626 }
5627 
5628 static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type) {
5629  type.printInterface(p);
5630 }
5631 
5632 //===----------------------------------------------------------------------===//
5633 // Miscellaneous custom elision logic.
5634 //===----------------------------------------------------------------------===//
5635 
5636 static ParseResult parseElideEmptyName(OpAsmParser &p,
5637  NamedAttrList &resultAttrs) {
5638  auto result = p.parseOptionalAttrDict(resultAttrs);
5639  if (!resultAttrs.get("name"))
5640  resultAttrs.append("name", p.getBuilder().getStringAttr(""));
5641 
5642  return result;
5643 }
5644 
5645 static void printElideEmptyName(OpAsmPrinter &p, Operation *op,
5646  DictionaryAttr attr,
5647  ArrayRef<StringRef> extraElides = {}) {
5648  SmallVector<StringRef> elides(extraElides.begin(), extraElides.end());
5649  if (op->getAttrOfType<StringAttr>("name").getValue().empty())
5650  elides.push_back("name");
5651 
5652  p.printOptionalAttrDict(op->getAttrs(), elides);
5653 }
5654 
5655 static ParseResult parsePrintfAttrs(OpAsmParser &p,
5656  NamedAttrList &resultAttrs) {
5657  return parseElideEmptyName(p, resultAttrs);
5658 }
5659 
5660 static void printPrintfAttrs(OpAsmPrinter &p, Operation *op,
5661  DictionaryAttr attr) {
5662  printElideEmptyName(p, op, attr, {"formatString"});
5663 }
5664 
5665 static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5666  return parseElideEmptyName(p, resultAttrs);
5667 }
5668 
5669 static void printStopAttrs(OpAsmPrinter &p, Operation *op,
5670  DictionaryAttr attr) {
5671  printElideEmptyName(p, op, attr, {"exitCode"});
5672 }
5673 
5674 static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5675  return parseElideEmptyName(p, resultAttrs);
5676 }
5677 
5678 static void printVerifAttrs(OpAsmPrinter &p, Operation *op,
5679  DictionaryAttr attr) {
5680  printElideEmptyName(p, op, attr, {"message"});
5681 }
5682 
5683 //===----------------------------------------------------------------------===//
5684 // Various namers.
5685 //===----------------------------------------------------------------------===//
5686 
5687 static void genericAsmResultNames(Operation *op,
5688  OpAsmSetValueNameFn setNameFn) {
5689  // Many firrtl dialect operations have an optional 'name' attribute. If
5690  // present, use it.
5691  if (op->getNumResults() == 1)
5692  if (auto nameAttr = op->getAttrOfType<StringAttr>("name"))
5693  setNameFn(op->getResult(0), nameAttr.getValue());
5694 }
5695 
5697  genericAsmResultNames(*this, setNameFn);
5698 }
5699 
5701  genericAsmResultNames(*this, setNameFn);
5702 }
5703 
5705  genericAsmResultNames(*this, setNameFn);
5706 }
5707 
5709  genericAsmResultNames(*this, setNameFn);
5710 }
5712  genericAsmResultNames(*this, setNameFn);
5713 }
5715  genericAsmResultNames(*this, setNameFn);
5716 }
5718  genericAsmResultNames(*this, setNameFn);
5719 }
5721  genericAsmResultNames(*this, setNameFn);
5722 }
5724  genericAsmResultNames(*this, setNameFn);
5725 }
5727  genericAsmResultNames(*this, setNameFn);
5728 }
5730  genericAsmResultNames(*this, setNameFn);
5731 }
5733  genericAsmResultNames(*this, setNameFn);
5734 }
5736  genericAsmResultNames(*this, setNameFn);
5737 }
5739  genericAsmResultNames(*this, setNameFn);
5740 }
5742  genericAsmResultNames(*this, setNameFn);
5743 }
5745  genericAsmResultNames(*this, setNameFn);
5746 }
5748  genericAsmResultNames(*this, setNameFn);
5749 }
5751  genericAsmResultNames(*this, setNameFn);
5752 }
5754  genericAsmResultNames(*this, setNameFn);
5755 }
5757  genericAsmResultNames(*this, setNameFn);
5758 }
5760  genericAsmResultNames(*this, setNameFn);
5761 }
5763  genericAsmResultNames(*this, setNameFn);
5764 }
5766  genericAsmResultNames(*this, setNameFn);
5767 }
5769  genericAsmResultNames(*this, setNameFn);
5770 }
5772  OpAsmSetValueNameFn setNameFn) {
5773  genericAsmResultNames(*this, setNameFn);
5774 }
5776  genericAsmResultNames(*this, setNameFn);
5777 }
5779  genericAsmResultNames(*this, setNameFn);
5780 }
5782  genericAsmResultNames(*this, setNameFn);
5783 }
5785  genericAsmResultNames(*this, setNameFn);
5786 }
5788  genericAsmResultNames(*this, setNameFn);
5789 }
5791  genericAsmResultNames(*this, setNameFn);
5792 }
5794  genericAsmResultNames(*this, setNameFn);
5795 }
5797  genericAsmResultNames(*this, setNameFn);
5798 }
5800  genericAsmResultNames(*this, setNameFn);
5801 }
5803  genericAsmResultNames(*this, setNameFn);
5804 }
5806  genericAsmResultNames(*this, setNameFn);
5807 }
5809  genericAsmResultNames(*this, setNameFn);
5810 }
5812  genericAsmResultNames(*this, setNameFn);
5813 }
5815  genericAsmResultNames(*this, setNameFn);
5816 }
5818  genericAsmResultNames(*this, setNameFn);
5819 }
5821  genericAsmResultNames(*this, setNameFn);
5822 }
5824  genericAsmResultNames(*this, setNameFn);
5825 }
5826 
5828  genericAsmResultNames(*this, setNameFn);
5829 }
5830 
5832  genericAsmResultNames(*this, setNameFn);
5833 }
5834 
5836  genericAsmResultNames(*this, setNameFn);
5837 }
5839  genericAsmResultNames(*this, setNameFn);
5840 }
5841 
5843  genericAsmResultNames(*this, setNameFn);
5844 }
5845 
5847  genericAsmResultNames(*this, setNameFn);
5848 }
5849 
5851  genericAsmResultNames(*this, setNameFn);
5852 }
5853 
5855  genericAsmResultNames(*this, setNameFn);
5856 }
5857 
5859  genericAsmResultNames(*this, setNameFn);
5860 }
5861 
5863  genericAsmResultNames(*this, setNameFn);
5864 }
5865 
5867  genericAsmResultNames(*this, setNameFn);
5868 }
5869 
5871  genericAsmResultNames(*this, setNameFn);
5872 }
5873 
5875  genericAsmResultNames(*this, setNameFn);
5876 }
5877 
5879  genericAsmResultNames(*this, setNameFn);
5880 }
5881 
5883  genericAsmResultNames(*this, setNameFn);
5884 }
5885 
5887  genericAsmResultNames(*this, setNameFn);
5888 }
5889 
5890 //===----------------------------------------------------------------------===//
5891 // RefOps
5892 //===----------------------------------------------------------------------===//
5893 
5895  genericAsmResultNames(*this, setNameFn);
5896 }
5897 
5899  genericAsmResultNames(*this, setNameFn);
5900 }
5901 
5903  genericAsmResultNames(*this, setNameFn);
5904 }
5905 
5907  genericAsmResultNames(*this, setNameFn);
5908 }
5909 
5911  genericAsmResultNames(*this, setNameFn);
5912 }
5913 
5914 FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
5915  ArrayRef<NamedAttribute> attrs,
5916  std::optional<Location> loc) {
5917  Type inType = operands[0].getType();
5918  auto inRefType = type_dyn_cast<RefType>(inType);
5919  if (!inRefType)
5920  return emitInferRetTypeError(
5921  loc, "ref.resolve operand must be ref type, not ", inType);
5922  return inRefType.getType();
5923 }
5924 
5925 FIRRTLType RefSendOp::inferReturnType(ValueRange operands,
5926  ArrayRef<NamedAttribute> attrs,
5927  std::optional<Location> loc) {
5928  Type inType = operands[0].getType();
5929  auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
5930  if (!inBaseType)
5931  return emitInferRetTypeError(
5932  loc, "ref.send operand must be base type, not ", inType);
5933  return RefType::get(inBaseType.getPassiveType());
5934 }
5935 
5936 FIRRTLType RefSubOp::inferReturnType(ValueRange operands,
5937  ArrayRef<NamedAttribute> attrs,
5938  std::optional<Location> loc) {
5939  auto refType = type_dyn_cast<RefType>(operands[0].getType());
5940  if (!refType)
5941  return emitInferRetTypeError(loc, "input must be of reference type");
5942  auto inType = refType.getType();
5943  auto fieldIdx =
5944  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
5945 
5946  // TODO: Determine ref.sub + rwprobe behavior, test.
5947  // Probably best to demote to non-rw, but that has implications
5948  // for any LowerTypes behavior being relied on.
5949  // Allow for now, as need to LowerTypes things generally.
5950  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
5951  if (fieldIdx < vectorType.getNumElements())
5952  return RefType::get(
5953  vectorType.getElementType().getConstType(
5954  vectorType.isConst() || vectorType.getElementType().isConst()),
5955  refType.getForceable(), refType.getLayer());
5956  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
5957  "' in RefType of vector type ", refType);
5958  }
5959  if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
5960  if (fieldIdx >= bundleType.getNumElements()) {
5961  return emitInferRetTypeError(loc,
5962  "subfield element index is greater than "
5963  "the number of fields in the bundle type");
5964  }
5965  return RefType::get(bundleType.getElement(fieldIdx).type.getConstType(
5966  bundleType.isConst() ||
5967  bundleType.getElement(fieldIdx).type.isConst()),
5968  refType.getForceable(), refType.getLayer());
5969  }
5970 
5971  return emitInferRetTypeError(
5972  loc, "ref.sub op requires a RefType of vector or bundle base type");
5973 }
5974 
5975 LogicalResult RefCastOp::verify() {
5976  auto srcLayers = getLayersFor(getInput());
5977  auto dstLayers = getLayersFor(getResult());
5978  SmallVector<SymbolRefAttr> missingLayers;
5979  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
5980  auto diag =
5981  emitOpError("cannot discard layer requirements of input reference");
5982  auto &note = diag.attachNote();
5983  note << "discarding layer requirements: ";
5984  llvm::interleaveComma(missingLayers, note);
5985  return failure();
5986  }
5987  return success();
5988 }
5989 
5990 LogicalResult RefResolveOp::verify() {
5991  auto srcLayers = getLayersFor(getRef());
5992  auto dstLayers = getAmbientLayersAt(getOperation());
5993  SmallVector<SymbolRefAttr> missingLayers;
5994  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
5995  auto diag =
5996  emitOpError("ambient layers are insufficient to resolve reference");
5997  auto &note = diag.attachNote();
5998  note << "missing layer requirements: ";
5999  interleaveComma(missingLayers, note);
6000  return failure();
6001  }
6002  return success();
6003 }
6004 
6005 LogicalResult RWProbeOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
6006  auto targetRef = getTarget();
6007  if (targetRef.getModule() !=
6008  (*this)->getParentOfType<FModuleLike>().getModuleNameAttr())
6009  return emitOpError() << "has non-local target";
6010 
6011  auto target = ns.lookup(targetRef);
6012  if (!target)
6013  return emitOpError() << "has target that cannot be resolved: " << targetRef;
6014 
6015  auto checkFinalType = [&](auto type, Location loc) -> LogicalResult {
6016  // Determine final type.
6017  mlir::Type fType =
6018  hw::FieldIdImpl::getFinalTypeByFieldID(type, target.getField());
6019  // Check.
6020  auto baseType = type_dyn_cast<FIRRTLBaseType>(fType);
6021  if (!baseType || baseType.getPassiveType() != getType().getType()) {
6022  auto diag = emitOpError("has type mismatch: target resolves to ")
6023  << fType << " instead of expected " << getType().getType();
6024  diag.attachNote(loc) << "target resolves here";
6025  return diag;
6026  }
6027  return success();
6028  };
6029  if (target.isPort()) {
6030  auto mod = cast<FModuleLike>(target.getOp());
6031  return checkFinalType(mod.getPortType(target.getPort()),
6032  mod.getPortLocation(target.getPort()));
6033  }
6034  hw::InnerSymbolOpInterface symOp =
6035  cast<hw::InnerSymbolOpInterface>(target.getOp());
6036  if (!symOp.getTargetResult())
6037  return emitOpError("has target that cannot be probed")
6038  .attachNote(symOp.getLoc())
6039  .append("target resolves here");
6040  auto *ancestor =
6041  symOp.getTargetResult().getParentBlock()->findAncestorOpInBlock(**this);
6042  if (!ancestor || !symOp->isBeforeInBlock(ancestor))
6043  return emitOpError("is not dominated by target")
6044  .attachNote(symOp.getLoc())
6045  .append("target here");
6046  return checkFinalType(symOp.getTargetResult().getType(), symOp.getLoc());
6047 }
6048 
6049 //===----------------------------------------------------------------------===//
6050 // Layer Block Operations
6051 //===----------------------------------------------------------------------===//
6052 
6053 LogicalResult LayerBlockOp::verify() {
6054  auto layerName = getLayerName();
6055  auto *parentOp = (*this)->getParentOp();
6056 
6057  // Verify the correctness of the symbol reference. Only verify that this
6058  // layer block makes sense in its parent module or layer block.
6059  auto nestedReferences = layerName.getNestedReferences();
6060  if (nestedReferences.empty()) {
6061  if (!isa<FModuleOp>(parentOp)) {
6062  auto diag = emitOpError() << "has an un-nested layer symbol, but does "
6063  "not have a 'firrtl.module' op as a parent";
6064  return diag.attachNote(parentOp->getLoc())
6065  << "illegal parent op defined here";
6066  }
6067  } else {
6068  auto parentLayerBlock = dyn_cast<LayerBlockOp>(parentOp);
6069  if (!parentLayerBlock) {
6070  auto diag = emitOpError()
6071  << "has a nested layer symbol, but does not have a '"
6072  << getOperationName() << "' op as a parent'";
6073  return diag.attachNote(parentOp->getLoc())
6074  << "illegal parent op defined here";
6075  }
6076  auto parentLayerBlockName = parentLayerBlock.getLayerName();
6077  if (parentLayerBlockName.getRootReference() !=
6078  layerName.getRootReference() ||
6079  parentLayerBlockName.getNestedReferences() !=
6080  layerName.getNestedReferences().drop_back()) {
6081  auto diag = emitOpError() << "is nested under an illegal layer block";
6082  return diag.attachNote(parentLayerBlock->getLoc())
6083  << "illegal parent layer block defined here";
6084  }
6085  }
6086 
6087  // Verify the body of the region.
6088  auto result = getBody(0)->walk<mlir::WalkOrder::PreOrder>(
6089  [&](Operation *op) -> WalkResult {
6090  // Skip nested layer blocks. Those will be verified separately.
6091  if (isa<LayerBlockOp>(op))
6092  return WalkResult::skip();
6093 
6094  // Check all the operands of each op to make sure that only legal things
6095  // are captured.
6096  for (auto operand : op->getOperands()) {
6097  // Any value captured from the current layer block is fine.
6098  if (auto *definingOp = operand.getDefiningOp())
6099  if (getOperation()->isAncestor(definingOp))
6100  continue;
6101 
6102  auto type = operand.getType();
6103 
6104  // Capture of a non-base type, e.g., reference, is allowed.
6105  if (isa<PropertyType>(type)) {
6106  auto diag = emitOpError() << "captures a property operand";
6107  diag.attachNote(operand.getLoc()) << "operand is defined here";
6108  diag.attachNote(op->getLoc()) << "operand is used here";
6109  return WalkResult::interrupt();
6110  }
6111 
6112  // Capturing a non-passive type is illegal.
6113  if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
6114  if (!baseType.isPassive()) {
6115  auto diag = emitOpError()
6116  << "captures an operand which is not a passive type";
6117  diag.attachNote(operand.getLoc()) << "operand is defined here";
6118  diag.attachNote(op->getLoc()) << "operand is used here";
6119  return WalkResult::interrupt();
6120  }
6121  }
6122  }
6123 
6124  // Ensure that the layer block does not drive any sinks outside.
6125  if (auto connect = dyn_cast<FConnectLike>(op)) {
6126  // ref.define is allowed to drive probes outside the layerblock.
6127  if (isa<RefDefineOp>(connect))
6128  return WalkResult::advance();
6129 
6130  // We can drive any destination inside the current layerblock.
6131  auto dest = getFieldRefFromValue(connect.getDest()).getValue();
6132  if (auto *destOp = dest.getDefiningOp())
6133  if (getOperation()->isAncestor(destOp))
6134  return WalkResult::advance();
6135 
6136  auto diag =
6137  connect.emitOpError()
6138  << "connects to a destination which is defined outside its "
6139  "enclosing layer block";
6140  diag.attachNote(getLoc()) << "enclosing layer block is defined here";
6141  diag.attachNote(dest.getLoc()) << "destination is defined here";
6142  return WalkResult::interrupt();
6143  }
6144 
6145  return WalkResult::advance();
6146  });
6147 
6148  return failure(result.wasInterrupted());
6149 }
6150 
6151 LogicalResult
6152 LayerBlockOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6153  auto layerOp =
6154  symbolTable.lookupNearestSymbolFrom<LayerOp>(*this, getLayerNameAttr());
6155  if (!layerOp) {
6156  return emitOpError("invalid symbol reference");
6157  }
6158 
6159  return success();
6160 }
6161 
6162 //===----------------------------------------------------------------------===//
6163 // TblGen Generated Logic.
6164 //===----------------------------------------------------------------------===//
6165 
6166 // Provide the autogenerated implementation guts for the Op classes.
6167 #define GET_OP_CLASSES
6168 #include "circt/Dialect/FIRRTL/FIRRTL.cpp.inc"
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition: CHIRRTL.cpp:30
MlirType elementType
Definition: CHIRRTL.cpp:29
#define isdigit(x)
Definition: FIRLexer.cpp:26
static LogicalResult verifyProbeType(RefType refType, Location loc, CircuitOp circuitOp, SymbolTableCollection &symbolTable, Twine start)
Definition: FIRRTLOps.cpp:1598
static SmallVector< T > removeElementsAtIndices(ArrayRef< T > input, const llvm::BitVector &indicesToDrop)
Remove elements from the input array corresponding to set bits in indicesToDrop, returning the elemen...
Definition: FIRRTLOps.cpp:53
static void printStopAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
Definition: FIRRTLOps.cpp:5669
static LayerSet getLayersFor(Value value)
Get the effective layer requirements for the given value.
Definition: FIRRTLOps.cpp:384
ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result)
Definition: FIRRTLOps.cpp:4346
static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth, int32_t &rhsWidth, bool &isConstResult, std::optional< Location > loc)
If LHS and RHS are both UInt or SInt types, then return true and fill in their widths if known.
Definition: FIRRTLOps.cpp:4784
static LogicalResult verifySubfieldLike(OpTy op)
Definition: FIRRTLOps.cpp:4463
static void eraseInternalPaths(T op, const llvm::BitVector &portIndices)
Definition: FIRRTLOps.cpp:851
static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip=false, bool outerTypeIsConst=false)
Checks if the type has any 'const' leaf elements.
Definition: FIRRTLOps.cpp:3413
static ParseResult parsePrintfAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
Definition: FIRRTLOps.cpp:5655
static RetTy emitInferRetTypeError(std::optional< Location > loc, const Twine &message, Args &&...args)
Emit an error if optional location is non-null, return null of return type.
Definition: FIRRTLOps.cpp:92
static ParseResult parseFModuleLikeOp(OpAsmParser &parser, OperationState &result, bool hasSSAIdentifiers)
Definition: FIRRTLOps.cpp:1372
static bool printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef< Direction > portDirections, ArrayRef< Attribute > portNames, ArrayRef< Attribute > portTypes, ArrayRef< Attribute > portAnnotations, ArrayRef< Attribute > portSyms, ArrayRef< Attribute > portLocs)
Print a list of module ports in the following form: in x: !firrtl.uint<1> [{class = "DontTouch"}],...
Definition: FIRRTLOps.cpp:1076
static LogicalResult checkConnectConditionality(FConnectLike connect)
Checks that connections to 'const' destinations are not dependent on non-'const' conditions in when b...
Definition: FIRRTLOps.cpp:3437
static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices)
Erases the ports that have their corresponding bit set in portIndices.
Definition: FIRRTLOps.cpp:810
static ParseResult parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers, bool supportsSymbols, SmallVectorImpl< OpAsmParser::Argument > &entryArgs, SmallVectorImpl< Direction > &portDirections, SmallVectorImpl< Attribute > &portNames, SmallVectorImpl< Attribute > &portTypes, SmallVectorImpl< Attribute > &portAnnotations, SmallVectorImpl< Attribute > &portSyms, SmallVectorImpl< Attribute > &portLocs)
Parse a list of module ports.
Definition: FIRRTLOps.cpp:1153
static ParseResult parseClassInterface(OpAsmParser &parser, Type &result)
Definition: FIRRTLOps.cpp:5620
static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op, DictionaryAttr attr, ArrayRef< StringRef > extraElides={})
Definition: FIRRTLOps.cpp:5542
static void printNameKind(OpAsmPrinter &p, Operation *op, firrtl::NameKindEnumAttr attr, ArrayRef< StringRef > extraElides={})
Definition: FIRRTLOps.cpp:5574
static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs)
Definition: FIRRTLOps.cpp:5665
static ParseResult parseNameKind(OpAsmParser &parser, firrtl::NameKindEnumAttr &result)
A forward declaration for NameKind attribute parser.
Definition: FIRRTLOps.cpp:5557
static void printParameterList(ArrayAttr parameters, OpAsmPrinter &p)
Print a parameter list for a module or instance.
Definition: FIRRTLOps.cpp:1255
static size_t getAddressWidth(size_t depth)
Definition: FIRRTLOps.cpp:2963
static void forceableAsmResultNames(Forceable op, StringRef name, OpAsmSetValueNameFn setNameFn)
Helper for naming forceable declarations (and their optional ref result).
Definition: FIRRTLOps.cpp:3216
static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer)
Definition: FIRRTLOps.cpp:4436
SmallSet< SymbolRefAttr, 4, CompareSymbolRefAttr > LayerSet
Definition: FIRRTLOps.cpp:354
static bool checkAggConstant(Operation *op, Attribute attr, FIRRTLBaseType type)
Definition: FIRRTLOps.cpp:4060
static void printClassLike(OpAsmPrinter &p, ClassLike op)
Definition: FIRRTLOps.cpp:1860
static hw::ModulePort::Direction dirFtoH(Direction dir)
Definition: FIRRTLOps.cpp:628
static MemOp::PortKind getMemPortKindFromType(FIRRTLType type)
Return the kind of port this is given the port type from a 'mem' decl.
Definition: FIRRTLOps.cpp:117
static LogicalResult verifyInternalPaths(FModuleLike op, std::optional<::mlir::ArrayAttr > internalPaths)
Definition: FIRRTLOps.cpp:1524
static void genericAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn)
Definition: FIRRTLOps.cpp:5687
static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type)
Definition: FIRRTLOps.cpp:5628
static void printPrintfAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
Definition: FIRRTLOps.cpp:5660
static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst, SmallVectorImpl< SymbolRefAttr > &missing)
Check that the source layers are all present in the destination layers.
Definition: FIRRTLOps.cpp:436
static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer, mlir::SymbolRefAttr dstLayer)
Check that the source layer is compatible with the destination layer.
Definition: FIRRTLOps.cpp:396
static LayerSet getAmbientLayersFor(Value value)
Get the ambient layer requirements at the definition site of the value.
Definition: FIRRTLOps.cpp:376
static LayerSet getAmbientLayersAt(Operation *op)
Get the ambient layers active at the given op.
Definition: FIRRTLOps.cpp:357
static void insertPorts(FModuleLike op, ArrayRef< std::pair< unsigned, PortInfo >> ports, bool supportsInternalPaths=false)
Inserts the given ports.
Definition: FIRRTLOps.cpp:710
static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op, DictionaryAttr attrs)
Definition: FIRRTLOps.cpp:5593
static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser, NamedAttrList &resultAttrs)
Definition: FIRRTLOps.cpp:5585
static FIRRTLBaseType inferMuxReturnType(FIRRTLBaseType high, FIRRTLBaseType low, bool isConstCondition, std::optional< Location > loc)
Infer the result type for a multiplexer given its two operand types, which may be aggregates.
Definition: FIRRTLOps.cpp:5185
static ParseResult parseCircuitOpAttrs(OpAsmParser &parser, NamedAttrList &resultAttrs)
Definition: FIRRTLOps.cpp:465
static SmallVector< PortInfo > getPortImpl(FModuleLike module)
Definition: FIRRTLOps.cpp:609
constexpr const char * toString(Flow flow)
Definition: FIRRTLOps.cpp:169
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition: FIRRTLOps.cpp:300
static void printElideAnnotations(OpAsmPrinter &p, Operation *op, DictionaryAttr attr, ArrayRef< StringRef > extraElides={})
Definition: FIRRTLOps.cpp:5513
static SmallVector< hw::PortInfo > getPortListImpl(FModuleLike module)
Definition: FIRRTLOps.cpp:637
static ParseResult parseElidePortAnnotations(OpAsmParser &parser, NamedAttrList &resultAttrs)
Parse an optional attribute dictionary, adding empty 'annotations' and 'portAnnotations' attributes i...
Definition: FIRRTLOps.cpp:5528