CIRCT  19.0.0git
FIRRTLOps.cpp
Go to the documentation of this file.
1 //===- FIRRTLOps.cpp - Implement the FIRRTL operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Interfaces/FunctionImplementation.h"
30 #include "llvm/ADT/BitVector.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallSet.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/FormatVariadic.h"
38 
39 using llvm::SmallDenseSet;
40 using mlir::RegionRange;
41 using namespace circt;
42 using namespace firrtl;
43 using namespace chirrtl;
44 
45 //===----------------------------------------------------------------------===//
46 // Utilities
47 //===----------------------------------------------------------------------===//
48 
49 /// Remove elements from the input array corresponding to set bits in
50 /// `indicesToDrop`, returning the elements not mentioned.
51 template <typename T>
52 static SmallVector<T>
53 removeElementsAtIndices(ArrayRef<T> input,
54  const llvm::BitVector &indicesToDrop) {
55 #ifndef NDEBUG
56  if (!input.empty()) {
57  int lastIndex = indicesToDrop.find_last();
58  if (lastIndex >= 0)
59  assert((size_t)lastIndex < input.size() && "index out of range");
60  }
61 #endif
62 
63  // If the input is empty (which is an optimization we do for certain array
64  // attributes), simply return an empty vector.
65  if (input.empty())
66  return {};
67 
68  // Copy over the live chunks.
69  size_t lastCopied = 0;
70  SmallVector<T> result;
71  result.reserve(input.size() - indicesToDrop.count());
72 
73  for (unsigned indexToDrop : indicesToDrop.set_bits()) {
74  // If we skipped over some valid elements, copy them over.
75  if (indexToDrop > lastCopied) {
76  result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
77  lastCopied = indexToDrop;
78  }
79  // Ignore this value so we don't copy it in the next iteration.
80  ++lastCopied;
81  }
82 
83  // If there are live elements at the end, copy them over.
84  if (lastCopied < input.size())
85  result.append(input.begin() + lastCopied, input.end());
86 
87  return result;
88 }
89 
90 /// Emit an error if optional location is non-null, return null of return type.
91 template <typename RetTy = FIRRTLType, typename... Args>
92 static RetTy emitInferRetTypeError(std::optional<Location> loc,
93  const Twine &message, Args &&...args) {
94  if (loc)
95  (mlir::emitError(*loc, message) << ... << std::forward<Args>(args));
96  return {};
97 }
98 
99 bool firrtl::isDuplexValue(Value val) {
100  // Block arguments are not duplex values.
101  while (Operation *op = val.getDefiningOp()) {
102  auto isDuplex =
103  TypeSwitch<Operation *, std::optional<bool>>(op)
104  .Case<SubfieldOp, SubindexOp, SubaccessOp>([&val](auto op) {
105  val = op.getInput();
106  return std::nullopt;
107  })
108  .Case<RegOp, RegResetOp, WireOp>([](auto) { return true; })
109  .Default([](auto) { return false; });
110  if (isDuplex)
111  return *isDuplex;
112  }
113  return false;
114 }
115 
116 SmallVector<std::pair<circt::FieldRef, circt::FieldRef>>
117 MemOp::computeDataFlow() {
118  // If read result has non-zero latency, then no combinational dependency
119  // exists.
120  if (getReadLatency() > 0)
121  return {};
122 
123  SmallVector<std::pair<circt::FieldRef, circt::FieldRef>> deps;
124  // Add a dependency from the enable and address fields to the data field.
125  for (auto memPort : getResults())
126  if (auto type = type_dyn_cast<BundleType>(memPort.getType())) {
127  auto enableFieldId = type.getFieldID((unsigned)ReadPortSubfield::en);
128  auto addressFieldId = type.getFieldID((unsigned)ReadPortSubfield::addr);
129  auto dataFieldId = type.getFieldID((unsigned)ReadPortSubfield::data);
130  deps.push_back({FieldRef(memPort, static_cast<unsigned>(dataFieldId)),
131  FieldRef(memPort, static_cast<unsigned>(enableFieldId))});
132  deps.push_back({{memPort, static_cast<unsigned>(dataFieldId)},
133  {memPort, static_cast<unsigned>(addressFieldId)}});
134  }
135  return deps;
136 }
137 
138 /// Return the kind of port this is given the port type from a 'mem' decl.
139 static MemOp::PortKind getMemPortKindFromType(FIRRTLType type) {
140  constexpr unsigned int addr = 1 << 0;
141  constexpr unsigned int en = 1 << 1;
142  constexpr unsigned int clk = 1 << 2;
143  constexpr unsigned int data = 1 << 3;
144  constexpr unsigned int mask = 1 << 4;
145  constexpr unsigned int rdata = 1 << 5;
146  constexpr unsigned int wdata = 1 << 6;
147  constexpr unsigned int wmask = 1 << 7;
148  constexpr unsigned int wmode = 1 << 8;
149  constexpr unsigned int def = 1 << 9;
150  // Get the kind of port based on the fields of the Bundle.
151  auto portType = type_dyn_cast<BundleType>(type);
152  if (!portType)
153  return MemOp::PortKind::Debug;
154  unsigned fields = 0;
155  // Get the kind of port based on the fields of the Bundle.
156  for (auto elem : portType.getElements()) {
157  fields |= llvm::StringSwitch<unsigned>(elem.name.getValue())
158  .Case("addr", addr)
159  .Case("en", en)
160  .Case("clk", clk)
161  .Case("data", data)
162  .Case("mask", mask)
163  .Case("rdata", rdata)
164  .Case("wdata", wdata)
165  .Case("wmask", wmask)
166  .Case("wmode", wmode)
167  .Default(def);
168  }
169  if (fields == (addr | en | clk | data))
170  return MemOp::PortKind::Read;
171  if (fields == (addr | en | clk | data | mask))
172  return MemOp::PortKind::Write;
173  if (fields == (addr | en | clk | wdata | wmask | rdata | wmode))
174  return MemOp::PortKind::ReadWrite;
175  return MemOp::PortKind::Debug;
176 }
177 
179  switch (flow) {
180  case Flow::None:
181  return Flow::None;
182  case Flow::Source:
183  return Flow::Sink;
184  case Flow::Sink:
185  return Flow::Source;
186  case Flow::Duplex:
187  return Flow::Duplex;
188  }
189  // Unreachable but silences warning
190  llvm_unreachable("Unsupported Flow type.");
191 }
192 
193 const char *toString(Flow flow) {
194  switch (flow) {
195  case Flow::None:
196  return "no flow";
197  case Flow::Source:
198  return "source flow";
199  case Flow::Sink:
200  return "sink flow";
201  case Flow::Duplex:
202  return "duplex flow";
203  }
204  // Unreachable but silences warning
205  llvm_unreachable("Unsupported Flow type.");
206 }
207 
/// Compute the flow of `val` by walking backwards through its definition chain,
/// starting from `accumulatedFlow`.
///
/// Each flipped subfield and each output-direction module port on the way
/// swaps the accumulated flow via `swapFlow`. The walk ends at:
///  - a block argument: swapped when its parent is an `FModuleLike` and the
///    port direction is `Out`, otherwise unchanged;
///  - registers, wires and `MemoryPortOp`s: always `Flow::Duplex`;
///  - instance results: unchanged for `Out` ports, swapped otherwise;
///  - `MemOp` results: `Flow::Source` for RefType results, swapped otherwise;
///  - `ObjectSubfieldOp`: see the inline comments in that case;
///  - any other op: `accumulatedFlow` unchanged.
Flow firrtl::foldFlow(Value val, Flow accumulatedFlow) {

  // Block arguments have no defining op; their flow depends on the direction
  // of the port they represent (when the parent is module-like).
  if (auto blockArg = dyn_cast<BlockArgument>(val)) {
    auto *op = val.getParentBlock()->getParentOp();
    if (auto moduleLike = dyn_cast<FModuleLike>(op)) {
      auto direction = moduleLike.getPortDirection(blockArg.getArgNumber());
      if (direction == Direction::Out)
        return swapFlow(accumulatedFlow);
    }
    return accumulatedFlow;
  }

  Operation *op = val.getDefiningOp();

  return TypeSwitch<Operation *, Flow>(op)
      // A flipped field inverts the flow seen by everything above it.
      .Case<SubfieldOp, OpenSubfieldOp>([&](auto op) {
        return foldFlow(op.getInput(), op.isFieldFlipped()
                                           ? swapFlow(accumulatedFlow)
                                           : accumulatedFlow);
      })
      // Index and ref projections do not change flow; keep walking.
      .Case<SubindexOp, SubaccessOp, OpenSubindexOp, RefSubOp>(
          [&](auto op) { return foldFlow(op.getInput(), accumulatedFlow); })
      // Registers, Wires, and behavioral memory ports are always Duplex.
      .Case<RegOp, RegResetOp, WireOp, MemoryPortOp>(
          [](auto) { return Flow::Duplex; })
      // Instance results: output ports keep the flow, input ports swap it.
      .Case<InstanceOp, InstanceChoiceOp>([&](auto inst) {
        auto resultNo = cast<OpResult>(val).getResultNumber();
        if (inst.getPortDirection(resultNo) == Direction::Out)
          return accumulatedFlow;
        return swapFlow(accumulatedFlow);
      })
      .Case<MemOp>([&](auto op) {
        // only debug ports with RefType have source flow.
        if (type_isa<RefType>(val.getType()))
          return Flow::Source;
        return swapFlow(accumulatedFlow);
      })
      .Case<ObjectSubfieldOp>([&](ObjectSubfieldOp op) {
        auto input = op.getInput();
        auto *inputOp = input.getDefiningOp();

        // We are directly accessing a port on a local declaration: its input
        // ports are Sink and its output ports are Source, regardless of
        // `accumulatedFlow`.
        if (auto objectOp = dyn_cast_or_null<ObjectOp>(inputOp)) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::Sink;
          return Flow::Source;
        }

        // We are accessing a remote object. Input ports on remote objects are
        // inaccessible, and thus have Flow::None. Walk backwards through the
        // chain of subindexes, to detect if we have indexed through an input
        // port. At the end, either we did index through an input port, or the
        // entire path was through output ports with source flow.
        while (true) {
          auto classType = input.getType();
          auto direction = classType.getElement(op.getIndex()).direction;
          if (direction == Direction::In)
            return Flow::None;

          // Step to the next ObjectSubfieldOp in the chain, if any.
          op = dyn_cast_or_null<ObjectSubfieldOp>(inputOp);
          if (op) {
            input = op.getInput();
            inputOp = input.getDefiningOp();
            continue;
          }

          return accumulatedFlow;
          // NOTE(review): the `;` after this loop is a harmless empty
          // statement.
        };
      })
      // Anything else acts like a universal source.
      .Default([&](auto) { return accumulatedFlow; });
}
282 
283 // TODO: This is doing the same walk as foldFlow. These two functions can be
284 // combined and return a (flow, kind) product.
286  Operation *op = val.getDefiningOp();
287  if (!op)
288  return DeclKind::Port;
289 
290  return TypeSwitch<Operation *, DeclKind>(op)
291  .Case<InstanceOp>([](auto) { return DeclKind::Instance; })
292  .Case<SubfieldOp, SubindexOp, SubaccessOp, OpenSubfieldOp, OpenSubindexOp,
293  RefSubOp>([](auto op) { return getDeclarationKind(op.getInput()); })
294  .Default([](auto) { return DeclKind::Other; });
295 }
296 
297 size_t firrtl::getNumPorts(Operation *op) {
298  if (auto module = dyn_cast<FModuleLike>(op))
299  return module.getNumPorts();
300  return op->getNumResults();
301 }
302 
303 /// Check whether an operation has a `DontTouch` annotation, or a symbol that
304 /// should prevent certain types of canonicalizations.
305 bool firrtl::hasDontTouch(Operation *op) {
306  return op->getAttr(hw::InnerSymbolTable::getInnerSymbolAttrName()) ||
308 }
309 
310 /// Check whether a block argument ("port") or the operation defining a value
311 /// has a `DontTouch` annotation, or a symbol that should prevent certain types
312 /// of canonicalizations.
313 bool firrtl::hasDontTouch(Value value) {
314  if (auto *op = value.getDefiningOp())
315  return hasDontTouch(op);
316  auto arg = dyn_cast<BlockArgument>(value);
317  auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
318  return (module.getPortSymbolAttr(arg.getArgNumber())) ||
319  AnnotationSet::forPort(module, arg.getArgNumber()).hasDontTouch();
320 }
321 
322 /// Get a special name to use when printing the entry block arguments of the
323 /// region contained by an operation in this dialect.
324 void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region,
325  OpAsmSetValueNameFn setNameFn) {
326  if (region.empty())
327  return;
328  auto *parentOp = op;
329  auto *block = &region.front();
330  // Check to see if the operation containing the arguments has 'firrtl.name'
331  // attributes for them. If so, use that as the name.
332  auto argAttr = parentOp->getAttrOfType<ArrayAttr>("portNames");
333  // Do not crash on invalid IR.
334  if (!argAttr || argAttr.size() != block->getNumArguments())
335  return;
336 
337  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
338  auto str = cast<StringAttr>(argAttr[i]).getValue();
339  if (!str.empty())
340  setNameFn(block->getArgument(i), str);
341  }
342 }
343 
344 /// A forward declaration for `NameKind` attribute parser.
345 static ParseResult parseNameKind(OpAsmParser &parser,
346  firrtl::NameKindEnumAttr &result);
347 
348 //===----------------------------------------------------------------------===//
349 // Layer Verification Utilities
350 //===----------------------------------------------------------------------===//
351 
352 namespace {
353 struct CompareSymbolRefAttr {
354  // True if lhs is lexicographically less than rhs.
355  bool operator()(SymbolRefAttr lhs, SymbolRefAttr rhs) const {
356  auto cmp = lhs.getRootReference().compare(rhs.getRootReference());
357  if (cmp == -1)
358  return true;
359  if (cmp == 1)
360  return false;
361  auto lhsNested = lhs.getNestedReferences();
362  auto rhsNested = rhs.getNestedReferences();
363  auto lhsNestedSize = lhsNested.size();
364  auto rhsNestedSize = rhsNested.size();
365  auto e = std::min(lhsNestedSize, rhsNestedSize);
366  for (unsigned i = 0; i < e; ++i) {
367  auto cmp = lhsNested[i].getAttr().compare(rhsNested[i].getAttr());
368  if (cmp == -1)
369  return true;
370  if (cmp == 1)
371  return false;
372  }
373  return lhsNestedSize < rhsNestedSize;
374  }
375 };
376 } // namespace
377 
379 
380 /// Get the ambient layers active at the given op.
381 static LayerSet getAmbientLayersAt(Operation *op) {
382  // Crawl through the parent ops, accumulating all ambient layers at the given
383  // operation.
384  LayerSet result;
385  for (; op != nullptr; op = op->getParentOp()) {
386  if (auto module = dyn_cast<FModuleLike>(op)) {
387  auto layers = module.getLayersAttr().getAsRange<SymbolRefAttr>();
388  result.insert(layers.begin(), layers.end());
389  break;
390  }
391  if (auto layerblock = dyn_cast<LayerBlockOp>(op)) {
392  result.insert(layerblock.getLayerName());
393  continue;
394  }
395  }
396  return result;
397 }
398 
399 /// Get the ambient layer requirements at the definition site of the value.
400 static LayerSet getAmbientLayersFor(Value value) {
401  return getAmbientLayersAt(getFieldRefFromValue(value).getDefiningOp());
402 }
403 
404 /// Get the effective layer requirements for the given value.
405 /// The effective layers for a value is the union of
406 /// - the ambient layers for the cannonical storage location.
407 /// - any explicit layer annotations in the value's type.
408 static LayerSet getLayersFor(Value value) {
409  auto result = getAmbientLayersFor(value);
410  if (auto type = dyn_cast<RefType>(value.getType()))
411  if (auto layer = type.getLayer())
412  result.insert(type.getLayer());
413  return result;
414 }
415 
416 /// Check that the source layer is compatible with the destination layer.
417 /// Either the source and destination are identical, or the source-layer
418 /// is a parent of the destination. For example `A` is compatible with `A.B.C`,
419 /// because any definition valid in `A` is also valid in `A.B.C`.
420 static bool isLayerCompatibleWith(mlir::SymbolRefAttr srcLayer,
421  mlir::SymbolRefAttr dstLayer) {
422  // A non-colored probe may be cast to any colored probe.
423  if (!srcLayer)
424  return true;
425 
426  // A colored probe cannot be cast to an uncolored probe.
427  if (!dstLayer)
428  return false;
429 
430  // Return true if the srcLayer is a prefix of the dstLayer.
431  if (srcLayer.getRootReference() != dstLayer.getRootReference())
432  return false;
433 
434  auto srcNames = srcLayer.getNestedReferences();
435  auto dstNames = dstLayer.getNestedReferences();
436  if (dstNames.size() < srcNames.size())
437  return false;
438 
439  return llvm::all_of(llvm::zip_first(srcNames, dstNames),
440  [](auto x) { return std::get<0>(x) == std::get<1>(x); });
441 }
442 
443 /// Check that the source layer is present in the destination layers.
444 static bool isLayerCompatibleWith(SymbolRefAttr srcLayer,
445  const LayerSet &dstLayers) {
446  // fast path: the required layer is directly listed in the provided layers.
447  if (dstLayers.contains(srcLayer))
448  return true;
449 
450  // Slow path: the required layer is not directly listed in the provided
451  // layers, but the layer may still be provided by a nested layer.
452  return any_of(dstLayers, [=](SymbolRefAttr dstLayer) {
453  return isLayerCompatibleWith(srcLayer, dstLayer);
454  });
455 }
456 
457 /// Check that the source layers are all present in the destination layers.
458 /// True if all source layers are present in the destination.
459 /// Outputs the set of source layers that are missing in the destination.
460 static bool isLayerSetCompatibleWith(const LayerSet &src, const LayerSet &dst,
461  SmallVectorImpl<SymbolRefAttr> &missing) {
462  for (auto srcLayer : src)
463  if (!isLayerCompatibleWith(srcLayer, dst))
464  missing.push_back(srcLayer);
465 
466  llvm::sort(missing, CompareSymbolRefAttr());
467  return missing.empty();
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // CircuitOp
472 //===----------------------------------------------------------------------===//
473 
474 void CircuitOp::build(OpBuilder &builder, OperationState &result,
475  StringAttr name, ArrayAttr annotations) {
476  // Add an attribute for the name.
477  result.addAttribute(builder.getStringAttr("name"), name);
478 
479  if (!annotations)
480  annotations = builder.getArrayAttr({});
481  result.addAttribute("annotations", annotations);
482 
483  // Create a region and a block for the body.
484  Region *bodyRegion = result.addRegion();
485  Block *body = new Block();
486  bodyRegion->push_back(body);
487 }
488 
489 static ParseResult parseCircuitOpAttrs(OpAsmParser &parser,
490  NamedAttrList &resultAttrs) {
491  auto result = parser.parseOptionalAttrDictWithKeyword(resultAttrs);
492  if (!resultAttrs.get("annotations"))
493  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
494 
495  return result;
496 }
497 
498 static void printCircuitOpAttrs(OpAsmPrinter &p, Operation *op,
499  DictionaryAttr attr) {
500  // "name" is always elided.
501  SmallVector<StringRef> elidedAttrs = {"name"};
502  // Elide "annotations" if it doesn't exist or if it is empty
503  auto annotationsAttr = op->getAttrOfType<ArrayAttr>("annotations");
504  if (annotationsAttr.empty())
505  elidedAttrs.push_back("annotations");
506 
507  p.printOptionalAttrDictWithKeyword(op->getAttrs(), elidedAttrs);
508 }
509 
/// Verify the circuit body.
///
/// Checks that:
///  - the circuit name is non-empty;
///  - the body's symbol table contains a symbol with the circuit's name;
///  - no `FExtModuleOp` has a `defname` equal to the name of an `FModuleOp`;
///  - `FExtModuleOp`s sharing a `defname` agree on port count and port names
///    and on port types. Widths are ignored when either side has parameters.
/// The first violation emits a diagnostic and returns failure.
LogicalResult CircuitOp::verifyRegions() {
  StringRef main = getName();

  // Check that the circuit has a non-empty name.
  if (main.empty()) {
    emitOpError("must have a non-empty name");
    return failure();
  }

  mlir::SymbolTable symtbl(getOperation());

  // The circuit's name must also name a symbol defined in its body.
  if (!symtbl.lookup(main))
    return emitOpError().append("Module with same name as circuit not found");

  // Store a mapping of defname to either the first external module
  // that defines it or, preferentially, the first external module
  // that defines it and has no parameters.
  llvm::DenseMap<Attribute, FExtModuleOp> defnameMap;

  // Verify one extmodule against the modules seen so far. Null ops and
  // extmodules without a `defname` trivially succeed.
  auto verifyExtModule = [&](FExtModuleOp extModule) -> LogicalResult {
    if (!extModule)
      return success();

    auto defname = extModule.getDefnameAttr();
    if (!defname)
      return success();

    // Check that this extmodule's defname does not conflict with
    // the symbol name of any module.
    if (auto collidingModule = symtbl.lookup<FModuleOp>(defname.getValue()))
      return extModule.emitOpError()
          .append("attribute 'defname' with value ", defname,
                  " conflicts with the name of another module in the circuit")
          .attachNote(collidingModule.getLoc())
          .append("previous module declared here");

    // Find an optional extmodule with a defname collision. Update
    // the defnameMap if this is the first extmodule with that
    // defname or if the current extmodule takes no parameters and
    // the collision does. The latter condition improves later
    // extmodule verification as checking against a parameterless
    // module is stricter.
    FExtModuleOp collidingExtModule;
    // `value` references the map slot (default-created if absent), so
    // assigning to it updates the map entry for this defname.
    if (auto &value = defnameMap[defname]) {
      collidingExtModule = value;
      if (!value.getParameters().empty() && extModule.getParameters().empty())
        value = extModule;
    } else {
      value = extModule;
      // Go to the next extmodule if no extmodule with the same
      // defname was found.
      return success();
    }

    // Check that the number of ports is exactly the same.
    SmallVector<PortInfo> ports = extModule.getPorts();
    SmallVector<PortInfo> collidingPorts = collidingExtModule.getPorts();

    if (ports.size() != collidingPorts.size())
      return extModule.emitOpError()
          .append("with 'defname' attribute ", defname, " has ", ports.size(),
                  " ports which is different from a previously defined "
                  "extmodule with the same 'defname' which has ",
                  collidingPorts.size(), " ports")
          .attachNote(collidingExtModule.getLoc())
          .append("previous extmodule definition occurred here");

    // Check that ports match for name and type. Since parameters
    // *might* affect widths, ignore widths if either module has
    // parameters. Note that this allows for misdetections, but
    // has zero false positives.
    for (auto p : llvm::zip(ports, collidingPorts)) {
      StringAttr aName = std::get<0>(p).name, bName = std::get<1>(p).name;
      Type aType = std::get<0>(p).type, bType = std::get<1>(p).type;

      if (aName != bName)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which does not match the name of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have name ",
                    bName)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");

      if (!extModule.getParameters().empty() ||
          !collidingExtModule.getParameters().empty()) {
        // Compare base types as widthless, others must match.
        if (auto base = type_dyn_cast<FIRRTLBaseType>(aType))
          aType = base.getWidthlessType();
        if (auto base = type_dyn_cast<FIRRTLBaseType>(bType))
          bType = base.getWidthlessType();
      }
      if (aType != bType)
        return extModule.emitOpError()
            .append("with 'defname' attribute ", defname,
                    " has a port with name ", aName,
                    " which has a different type ", aType,
                    " which does not match the type of the port in the same "
                    "position of a previously defined extmodule with the same "
                    "'defname', expected port to have type ",
                    bType)
            .attachNote(collidingExtModule.getLoc())
            .append("previous extmodule definition occurred here");
    }
    return success();
  };

  // Only top-level extmodules in the body block are checked, in body order.
  for (auto &op : *getBodyBlock()) {
    // Verify external modules.
    if (auto extModule = dyn_cast<FExtModuleOp>(op)) {
      if (verifyExtModule(extModule).failed())
        return failure();
    }
  }

  return success();
}
629 
630 Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
631 
632 //===----------------------------------------------------------------------===//
633 // FExtModuleOp and FModuleOp
634 //===----------------------------------------------------------------------===//
635 
636 static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
637  SmallVector<PortInfo> results;
638  for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
639  results.push_back({module.getPortNameAttr(i), module.getPortType(i),
640  module.getPortDirection(i), module.getPortSymbolAttr(i),
641  module.getPortLocation(i),
642  AnnotationSet::forPort(module, i)});
643  }
644  return results;
645 }
646 
647 SmallVector<PortInfo> FModuleOp::getPorts() { return ::getPortImpl(*this); }
648 
649 SmallVector<PortInfo> FExtModuleOp::getPorts() { return ::getPortImpl(*this); }
650 
651 SmallVector<PortInfo> FIntModuleOp::getPorts() { return ::getPortImpl(*this); }
652 
653 SmallVector<PortInfo> FMemModuleOp::getPorts() { return ::getPortImpl(*this); }
654 
656  if (dir == Direction::In)
658  if (dir == Direction::Out)
660  assert(0 && "invalid direction");
661  abort();
662 }
663 
664 static SmallVector<hw::PortInfo> getPortListImpl(FModuleLike module) {
665  SmallVector<hw::PortInfo> results;
666  auto aname = StringAttr::get(module.getContext(),
667  hw::HWModuleLike::getPortSymbolAttrName());
668  auto emptyDict = DictionaryAttr::get(module.getContext());
669  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
670  auto sym = module.getPortSymbolAttr(i);
671  results.push_back(
672  {{module.getPortNameAttr(i), module.getPortType(i),
673  dirFtoH(module.getPortDirection(i))},
674  i,
675  sym ? DictionaryAttr::get(
676  module.getContext(),
677  ArrayRef<mlir::NamedAttribute>{NamedAttribute{aname, sym}})
678  : emptyDict,
679  module.getPortLocation(i)});
680  }
681  return results;
682 }
683 
684 SmallVector<::circt::hw::PortInfo> FModuleOp::getPortList() {
686 }
687 
688 SmallVector<::circt::hw::PortInfo> FExtModuleOp::getPortList() {
690 }
691 
692 SmallVector<::circt::hw::PortInfo> FIntModuleOp::getPortList() {
694 }
695 
696 SmallVector<::circt::hw::PortInfo> FMemModuleOp::getPortList() {
698 }
699 
700 static hw::PortInfo getPortImpl(FModuleLike module, size_t idx) {
701  return {{module.getPortNameAttr(idx), module.getPortType(idx),
702  dirFtoH(module.getPortDirection(idx))},
703  idx,
705  module.getContext(),
706  ArrayRef<mlir::NamedAttribute>{NamedAttribute{
707  StringAttr::get(module.getContext(),
708  hw::HWModuleLike::getPortSymbolAttrName()),
709  module.getPortSymbolAttr(idx)}}),
710  module.getPortLocation(idx)};
711 }
712 
714  return ::getPortImpl(*this, idx);
715 }
716 
718  return ::getPortImpl(*this, idx);
719 }
720 
722  return ::getPortImpl(*this, idx);
723 }
724 
726  return ::getPortImpl(*this, idx);
727 }
728 
729 // Return the port with the specified name.
730 BlockArgument FModuleOp::getArgument(size_t portNumber) {
731  return getBodyBlock()->getArgument(portNumber);
732 }
733 
/// Inserts the given ports. The insertion indices are expected to be in order.
/// Insertion occurs in-order, such that ports with the same insertion index
/// appear in the module in the same order they appeared in the list.
///
/// Each `(index, port)` pair places `port` before the existing port at `index`
/// (an index >= the old port count appends). This only rewrites the per-port
/// attribute arrays (`portDirections`, `portNames`, `portTypes`,
/// `portAnnotations`, port symbols, `portLocations`). It also rewrites
/// `internalPaths` when `supportsInternalPaths` is set. Block arguments are
/// not touched here.
static void insertPorts(FModuleLike op,
                        ArrayRef<std::pair<unsigned, PortInfo>> ports,
                        bool supportsInternalPaths = false) {
  if (ports.empty())
    return;
  unsigned oldNumArgs = op.getNumPorts();
  unsigned newNumArgs = oldNumArgs + ports.size();

  // Add direction markers and names for new ports.
  auto existingDirections = op.getPortDirectionsAttr();
  ArrayRef<Attribute> existingNames = op.getPortNames();
  ArrayRef<Attribute> existingTypes = op.getPortTypes();
  ArrayRef<Attribute> existingLocs = op.getPortLocations();
  assert(existingDirections.size() == oldNumArgs);
  assert(existingNames.size() == oldNumArgs);
  assert(existingTypes.size() == oldNumArgs);
  assert(existingLocs.size() == oldNumArgs);
  // A missing `internalPaths` attribute is expanded to one empty path per
  // existing port so it can be indexed like the other arrays.
  SmallVector<Attribute> internalPaths;
  auto emptyInternalPath = InternalPathAttr::get(op.getContext());
  if (supportsInternalPaths) {
    if (auto internalPathsAttr = op->getAttrOfType<ArrayAttr>("internalPaths"))
      llvm::append_range(internalPaths, internalPathsAttr);
    else
      internalPaths.resize(oldNumArgs, emptyInternalPath);
    assert(internalPaths.size() == oldNumArgs);
  }

  SmallVector<bool> newDirections;
  SmallVector<Attribute> newNames, newTypes, newAnnos, newSyms, newLocs,
      newInternalPaths;
  newDirections.reserve(newNumArgs);
  newNames.reserve(newNumArgs);
  newTypes.reserve(newNumArgs);
  newAnnos.reserve(newNumArgs);
  newSyms.reserve(newNumArgs);
  newLocs.reserve(newNumArgs);
  newInternalPaths.reserve(newNumArgs);

  auto emptyArray = ArrayAttr::get(op.getContext(), {});

  // Copy existing ports [oldIdx, untilOldIdx) into the new arrays, advancing
  // `oldIdx`. Calls with a smaller bound than already reached are no-ops.
  unsigned oldIdx = 0;
  auto migrateOldPorts = [&](unsigned untilOldIdx) {
    while (oldIdx < oldNumArgs && oldIdx < untilOldIdx) {
      newDirections.push_back(existingDirections[oldIdx]);
      newNames.push_back(existingNames[oldIdx]);
      newTypes.push_back(existingTypes[oldIdx]);
      newAnnos.push_back(op.getAnnotationsAttrForPort(oldIdx));
      newSyms.push_back(op.getPortSymbolAttr(oldIdx));
      newLocs.push_back(existingLocs[oldIdx]);
      if (supportsInternalPaths)
        newInternalPaths.push_back(internalPaths[oldIdx]);
      ++oldIdx;
    }
  };
  // Interleave: flush old ports up to each insertion point, then append the
  // new port.
  for (auto pair : llvm::enumerate(ports)) {
    auto idx = pair.value().first;
    auto &port = pair.value().second;
    migrateOldPorts(idx);
    newDirections.push_back(direction::unGet(port.direction));
    newNames.push_back(port.name);
    newTypes.push_back(TypeAttr::get(port.type));
    auto annos = port.annotations.getArrayAttr();
    newAnnos.push_back(annos ? annos : emptyArray);
    newSyms.push_back(port.sym);
    newLocs.push_back(port.loc);
    if (supportsInternalPaths)
      newInternalPaths.push_back(emptyInternalPath);
  }
  // Copy whatever old ports remain after the last insertion point.
  migrateOldPorts(oldNumArgs);

  // The lack of *any* port annotations is represented by an empty
  // `portAnnotations` array as a shorthand.
  if (llvm::all_of(newAnnos, [](Attribute attr) {
        return cast<ArrayAttr>(attr).empty();
      }))
    newAnnos.clear();

  // Apply these changed markers.
  op->setAttr("portDirections",
              direction::packAttribute(op.getContext(), newDirections));
  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newNames));
  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newTypes));
  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newAnnos));
  op.setPortSymbols(newSyms);
  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newLocs));
  if (supportsInternalPaths) {
    // Drop if all-empty, otherwise set to new array.
    auto empty = llvm::all_of(newInternalPaths, [](Attribute attr) {
      return !cast<InternalPathAttr>(attr).getPath();
    });
    if (empty)
      op->removeAttr("internalPaths");
    else
      op->setAttr("internalPaths",
                  ArrayAttr::get(op.getContext(), newInternalPaths));
  }
}
834 
835 /// Erases the ports that have their corresponding bit set in `portIndices`.
836 static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
837  if (portIndices.none())
838  return;
839 
840  // Drop the direction markers for dead ports.
841  ArrayRef<bool> portDirections = op.getPortDirectionsAttr().asArrayRef();
842  ArrayRef<Attribute> portNames = op.getPortNames();
843  ArrayRef<Attribute> portTypes = op.getPortTypes();
844  ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
845  ArrayRef<Attribute> portSyms = op.getPortSymbols();
846  ArrayRef<Attribute> portLocs = op.getPortLocations();
847  auto numPorts = op.getNumPorts();
848  (void)numPorts;
849  assert(portDirections.size() == numPorts);
850  assert(portNames.size() == numPorts);
851  assert(portAnnos.size() == numPorts || portAnnos.empty());
852  assert(portTypes.size() == numPorts);
853  assert(portSyms.size() == numPorts || portSyms.empty());
854  assert(portLocs.size() == numPorts);
855 
856  SmallVector<bool> newPortDirections =
857  removeElementsAtIndices<bool>(portDirections, portIndices);
858  SmallVector<Attribute> newPortNames, newPortTypes, newPortAnnos, newPortSyms,
859  newPortLocs;
860  newPortNames = removeElementsAtIndices(portNames, portIndices);
861  newPortTypes = removeElementsAtIndices(portTypes, portIndices);
862  newPortAnnos = removeElementsAtIndices(portAnnos, portIndices);
863  newPortSyms = removeElementsAtIndices(portSyms, portIndices);
864  newPortLocs = removeElementsAtIndices(portLocs, portIndices);
865  op->setAttr("portDirections",
866  direction::packAttribute(op.getContext(), newPortDirections));
867  op->setAttr("portNames", ArrayAttr::get(op.getContext(), newPortNames));
868  op->setAttr("portAnnotations", ArrayAttr::get(op.getContext(), newPortAnnos));
869  op->setAttr("portTypes", ArrayAttr::get(op.getContext(), newPortTypes));
870  FModuleLike::fixupPortSymsArray(newPortSyms, op.getContext());
871  op->setAttr("portSyms", ArrayAttr::get(op.getContext(), newPortSyms));
872  op->setAttr("portLocations", ArrayAttr::get(op.getContext(), newPortLocs));
873 }
874 
875 template <typename T>
876 static void eraseInternalPaths(T op, const llvm::BitVector &portIndices) {
877  // Fixup internalPaths array.
878  auto internalPaths = op.getInternalPaths();
879  if (!internalPaths)
880  return;
881 
882  auto newPaths =
883  removeElementsAtIndices(internalPaths->getValue(), portIndices);
884 
885  // Drop if all-empty, otherwise set to new array.
886  auto empty = llvm::all_of(newPaths, [](Attribute attr) {
887  return !cast<InternalPathAttr>(attr).getPath();
888  });
889  if (empty)
890  op.removeInternalPathsAttr();
891  else
892  op.setInternalPathsAttr(ArrayAttr::get(op.getContext(), newPaths));
893 }
894 
895 void FExtModuleOp::erasePorts(const llvm::BitVector &portIndices) {
896  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
897  eraseInternalPaths(*this, portIndices);
898 }
899 
900 void FIntModuleOp::erasePorts(const llvm::BitVector &portIndices) {
901  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
902  eraseInternalPaths(*this, portIndices);
903 }
904 
905 void FMemModuleOp::erasePorts(const llvm::BitVector &portIndices) {
906  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
907 }
908 
909 void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
910  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
911  getBodyBlock()->eraseArguments(portIndices);
912 }
913 
914 /// Inserts the given ports. The insertion indices are expected to be in order.
915 /// Insertion occurs in-order, such that ports with the same insertion index
916 /// appear in the module in the same order they appeared in the list.
917 void FModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
918  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
919 
920  // Insert the block arguments.
921  auto *body = getBodyBlock();
922  for (size_t i = 0, e = ports.size(); i < e; ++i) {
923  // Block arguments are inserted one at a time, so for each argument we
924  // insert we have to increase the index by 1.
925  auto &[index, port] = ports[i];
926  body->insertArgument(index + i, port.type, port.loc);
927  }
928 }
929 
930 void FExtModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
931  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
932  /*supportsInternalPaths=*/true);
933 }
934 
935 void FIntModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
936  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports,
937  /*supportsInternalPaths=*/true);
938 }
939 
940 /// Inserts the given ports. The insertion indices are expected to be in order.
941 /// Insertion occurs in-order, such that ports with the same insertion index
942 /// appear in the module in the same order they appeared in the list.
943 void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
944  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
945 }
946 
947 static void buildModuleLike(OpBuilder &builder, OperationState &result,
948  StringAttr name, ArrayRef<PortInfo> ports,
949  ArrayAttr annotations, ArrayAttr layers,
950  bool withAnnotations, bool withLayers) {
951  // Add an attribute for the name.
952  result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), name);
953 
954  // Record the names of the arguments if present.
955  SmallVector<Direction, 4> portDirections;
956  SmallVector<Attribute, 4> portNames;
957  SmallVector<Attribute, 4> portTypes;
958  SmallVector<Attribute, 4> portAnnotations;
959  SmallVector<Attribute, 4> portSyms;
960  SmallVector<Attribute, 4> portLocs;
961  for (const auto &port : ports) {
962  portDirections.push_back(port.direction);
963  portNames.push_back(port.name);
964  portTypes.push_back(TypeAttr::get(port.type));
965  portAnnotations.push_back(port.annotations.getArrayAttr());
966  portSyms.push_back(port.sym);
967  portLocs.push_back(port.loc);
968  }
969 
970  // The lack of *any* port annotations is represented by an empty
971  // `portAnnotations` array as a shorthand.
972  if (llvm::all_of(portAnnotations, [](Attribute attr) {
973  return cast<ArrayAttr>(attr).empty();
974  }))
975  portAnnotations.clear();
976 
977  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
978  // Both attributes are added, even if the module has no ports.
979  result.addAttribute(
980  "portDirections",
981  direction::packAttribute(builder.getContext(), portDirections));
982  result.addAttribute("portNames", builder.getArrayAttr(portNames));
983  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
984  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
985  result.addAttribute("portLocations", builder.getArrayAttr(portLocs));
986 
987  if (withAnnotations) {
988  if (!annotations)
989  annotations = builder.getArrayAttr({});
990  result.addAttribute("annotations", annotations);
991  result.addAttribute("portAnnotations",
992  builder.getArrayAttr(portAnnotations));
993  }
994 
995  if (withLayers) {
996  if (!layers)
997  layers = builder.getArrayAttr({});
998  result.addAttribute("layers", layers);
999  }
1000 
1001  result.addRegion();
1002 }
1003 
1004 static void buildModule(OpBuilder &builder, OperationState &result,
1005  StringAttr name, ArrayRef<PortInfo> ports,
1006  ArrayAttr annotations, ArrayAttr layers) {
1007  buildModuleLike(builder, result, name, ports, annotations, layers,
1008  /*withAnnotations=*/true, /*withLayers=*/true);
1009 }
1010 
1011 static void buildClass(OpBuilder &builder, OperationState &result,
1012  StringAttr name, ArrayRef<PortInfo> ports) {
1013  return buildModuleLike(builder, result, name, ports, {}, {},
1014  /*withAnnotations=*/false,
1015  /*withLayers=*/false);
1016 }
1017 
1018 void FModuleOp::build(OpBuilder &builder, OperationState &result,
1019  StringAttr name, ConventionAttr convention,
1020  ArrayRef<PortInfo> ports, ArrayAttr annotations,
1021  ArrayAttr layers) {
1022  buildModule(builder, result, name, ports, annotations, layers);
1023  result.addAttribute("convention", convention);
1024 
1025  // Create a region and a block for the body.
1026  auto *bodyRegion = result.regions[0].get();
1027  Block *body = new Block();
1028  bodyRegion->push_back(body);
1029 
1030  // Add arguments to the body block.
1031  for (auto &elt : ports)
1032  body->addArgument(elt.type, elt.loc);
1033 }
1034 
1035 void FExtModuleOp::build(OpBuilder &builder, OperationState &result,
1036  StringAttr name, ConventionAttr convention,
1037  ArrayRef<PortInfo> ports, StringRef defnameAttr,
1038  ArrayAttr annotations, ArrayAttr parameters,
1039  ArrayAttr internalPaths, ArrayAttr layers) {
1040  buildModule(builder, result, name, ports, annotations, layers);
1041  result.addAttribute("convention", convention);
1042  if (!defnameAttr.empty())
1043  result.addAttribute("defname", builder.getStringAttr(defnameAttr));
1044  if (!parameters)
1045  parameters = builder.getArrayAttr({});
1046  result.addAttribute(getParametersAttrName(result.name), parameters);
1047  if (internalPaths && !internalPaths.empty())
1048  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1049 }
1050 
1051 void FIntModuleOp::build(OpBuilder &builder, OperationState &result,
1052  StringAttr name, ArrayRef<PortInfo> ports,
1053  StringRef intrinsicNameStr, ArrayAttr annotations,
1054  ArrayAttr parameters, ArrayAttr internalPaths,
1055  ArrayAttr layers) {
1056  buildModule(builder, result, name, ports, annotations, layers);
1057  result.addAttribute("intrinsic", builder.getStringAttr(intrinsicNameStr));
1058  if (!parameters)
1059  parameters = builder.getArrayAttr({});
1060  result.addAttribute(getParametersAttrName(result.name), parameters);
1061  if (internalPaths && !internalPaths.empty())
1062  result.addAttribute(getInternalPathsAttrName(result.name), internalPaths);
1063 }
1064 
1065 void FMemModuleOp::build(OpBuilder &builder, OperationState &result,
1066  StringAttr name, ArrayRef<PortInfo> ports,
1067  uint32_t numReadPorts, uint32_t numWritePorts,
1068  uint32_t numReadWritePorts, uint32_t dataWidth,
1069  uint32_t maskBits, uint32_t readLatency,
1070  uint32_t writeLatency, uint64_t depth,
1071  ArrayAttr annotations, ArrayAttr layers) {
1072  auto *context = builder.getContext();
1073  buildModule(builder, result, name, ports, annotations, layers);
1074  auto ui32Type = IntegerType::get(context, 32, IntegerType::Unsigned);
1075  auto ui64Type = IntegerType::get(context, 64, IntegerType::Unsigned);
1076  result.addAttribute("numReadPorts", IntegerAttr::get(ui32Type, numReadPorts));
1077  result.addAttribute("numWritePorts",
1078  IntegerAttr::get(ui32Type, numWritePorts));
1079  result.addAttribute("numReadWritePorts",
1080  IntegerAttr::get(ui32Type, numReadWritePorts));
1081  result.addAttribute("dataWidth", IntegerAttr::get(ui32Type, dataWidth));
1082  result.addAttribute("maskBits", IntegerAttr::get(ui32Type, maskBits));
1083  result.addAttribute("readLatency", IntegerAttr::get(ui32Type, readLatency));
1084  result.addAttribute("writeLatency", IntegerAttr::get(ui32Type, writeLatency));
1085  result.addAttribute("depth", IntegerAttr::get(ui64Type, depth));
1086  result.addAttribute("extraPorts", ArrayAttr::get(context, {}));
1087 }
1088 
1089 /// Print a list of module ports in the following form:
1090 /// in x: !firrtl.uint<1> [{class = "DontTouch}], out "_port": !firrtl.uint<2>
1091 ///
1092 /// When there is no block specified, the port names print as MLIR identifiers,
1093 /// wrapping in quotes if not legal to print as-is. When there is no block
1094 /// specified, this function always return false, indicating that there was no
1095 /// issue printing port names.
1096 ///
1097 /// If there is a block specified, then port names will be printed as SSA
1098 /// values. If there is a reason the printed SSA values can't match the true
1099 /// port name, then this function will return true. When this happens, the
1100 /// caller should print the port names as a part of the `attr-dict`.
1101 static bool
1102 printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef<bool> portDirections,
1103  ArrayRef<Attribute> portNames, ArrayRef<Attribute> portTypes,
1104  ArrayRef<Attribute> portAnnotations,
1105  ArrayRef<Attribute> portSyms, ArrayRef<Attribute> portLocs) {
1106  // When printing port names as SSA values, we can fail to print them
1107  // identically.
1108  bool printedNamesDontMatch = false;
1109 
1110  mlir::OpPrintingFlags flags;
1111 
1112  // If we are printing the ports as block arguments the op must have a first
1113  // block.
1114  SmallString<32> resultNameStr;
1115  p << '(';
1116  for (unsigned i = 0, e = portTypes.size(); i < e; ++i) {
1117  if (i > 0)
1118  p << ", ";
1119 
1120  // Print the port direction.
1121  p << direction::get(portDirections[i]) << " ";
1122 
1123  // Print the port name. If there is a valid block, we print it as a block
1124  // argument.
1125  if (block) {
1126  // Get the printed format for the argument name.
1127  resultNameStr.clear();
1128  llvm::raw_svector_ostream tmpStream(resultNameStr);
1129  p.printOperand(block->getArgument(i), tmpStream);
1130  // If the name wasn't printable in a way that agreed with portName, make
1131  // sure to print out an explicit portNames attribute.
1132  auto portName = cast<StringAttr>(portNames[i]).getValue();
1133  if (!portName.empty() && tmpStream.str().drop_front() != portName)
1134  printedNamesDontMatch = true;
1135  p << tmpStream.str();
1136  } else {
1137  p.printKeywordOrString(cast<StringAttr>(portNames[i]).getValue());
1138  }
1139 
1140  // Print the port type.
1141  p << ": ";
1142  auto portType = cast<TypeAttr>(portTypes[i]).getValue();
1143  p.printType(portType);
1144 
1145  // Print the optional port symbol.
1146  if (!portSyms.empty()) {
1147  if (!cast<hw::InnerSymAttr>(portSyms[i]).empty()) {
1148  p << " sym ";
1149  cast<hw::InnerSymAttr>(portSyms[i]).print(p);
1150  }
1151  }
1152 
1153  // Print the port specific annotations. The port annotations array will be
1154  // empty if there are none.
1155  if (!portAnnotations.empty() &&
1156  !cast<ArrayAttr>(portAnnotations[i]).empty()) {
1157  p << " ";
1158  p.printAttribute(portAnnotations[i]);
1159  }
1160 
1161  // Print the port location.
1162  // TODO: `printOptionalLocationSpecifier` will emit aliases for locations,
1163  // even if they are not printed. This will have to be fixed upstream. For
1164  // now, use what was specified on the command line.
1165  if (flags.shouldPrintDebugInfo() && !portLocs.empty())
1166  p.printOptionalLocationSpecifier(cast<LocationAttr>(portLocs[i]));
1167  }
1168 
1169  p << ')';
1170  return printedNamesDontMatch;
1171 }
1172 
1173 /// Parse a list of module ports. If port names are SSA identifiers, then this
1174 /// will populate `entryArgs`.
1175 static ParseResult
1176 parseModulePorts(OpAsmParser &parser, bool hasSSAIdentifiers,
1177  bool supportsSymbols,
1178  SmallVectorImpl<OpAsmParser::Argument> &entryArgs,
1179  SmallVectorImpl<Direction> &portDirections,
1180  SmallVectorImpl<Attribute> &portNames,
1181  SmallVectorImpl<Attribute> &portTypes,
1182  SmallVectorImpl<Attribute> &portAnnotations,
1183  SmallVectorImpl<Attribute> &portSyms,
1184  SmallVectorImpl<Attribute> &portLocs) {
1185  auto *context = parser.getContext();
1186 
1187  auto parseArgument = [&]() -> ParseResult {
1188  // Parse port direction.
1189  if (succeeded(parser.parseOptionalKeyword("out")))
1190  portDirections.push_back(Direction::Out);
1191  else if (succeeded(parser.parseKeyword("in", "or 'out'")))
1192  portDirections.push_back(Direction::In);
1193  else
1194  return failure();
1195 
1196  // This is the location or the port declaration in the IR. If there is no
1197  // other location information, we use this to point to the MLIR.
1198  llvm::SMLoc irLoc;
1199 
1200  if (hasSSAIdentifiers) {
1201  OpAsmParser::Argument arg;
1202  if (parser.parseArgument(arg))
1203  return failure();
1204  entryArgs.push_back(arg);
1205 
1206  // The name of an argument is of the form "%42" or "%id", and since
1207  // parsing succeeded, we know it always has one character.
1208  assert(arg.ssaName.name.size() > 1 && arg.ssaName.name[0] == '%' &&
1209  "Unknown MLIR name");
1210  if (isdigit(arg.ssaName.name[1]))
1211  portNames.push_back(StringAttr::get(context, ""));
1212  else
1213  portNames.push_back(
1214  StringAttr::get(context, arg.ssaName.name.drop_front()));
1215 
1216  // Store the location of the SSA name.
1217  irLoc = arg.ssaName.location;
1218 
1219  } else {
1220  // Parse the port name.
1221  irLoc = parser.getCurrentLocation();
1222  std::string portName;
1223  if (parser.parseKeywordOrString(&portName))
1224  return failure();
1225  portNames.push_back(StringAttr::get(context, portName));
1226  }
1227 
1228  // Parse the port type.
1229  Type portType;
1230  if (parser.parseColonType(portType))
1231  return failure();
1232  portTypes.push_back(TypeAttr::get(portType));
1233 
1234  if (hasSSAIdentifiers)
1235  entryArgs.back().type = portType;
1236 
1237  // Parse the optional port symbol.
1238  if (supportsSymbols) {
1239  hw::InnerSymAttr innerSymAttr;
1240  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1241  NamedAttrList dummyAttrs;
1242  if (parser.parseCustomAttributeWithFallback(
1243  innerSymAttr, ::mlir::Type{},
1245  return ::mlir::failure();
1246  }
1247  }
1248  portSyms.push_back(innerSymAttr);
1249  }
1250 
1251  // Parse the port annotations.
1252  ArrayAttr annos;
1253  auto parseResult = parser.parseOptionalAttribute(annos);
1254  if (!parseResult.has_value())
1255  annos = parser.getBuilder().getArrayAttr({});
1256  else if (failed(*parseResult))
1257  return failure();
1258  portAnnotations.push_back(annos);
1259 
1260  // Parse the optional port location.
1261  std::optional<Location> maybeLoc;
1262  if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
1263  return failure();
1264  Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc);
1265  portLocs.push_back(loc);
1266  if (hasSSAIdentifiers)
1267  entryArgs.back().sourceLoc = loc;
1268 
1269  return success();
1270  };
1271 
1272  // Parse all ports.
1273  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
1274  parseArgument);
1275 }
1276 
1277 /// Print a paramter list for a module or instance.
1278 static void printParameterList(OpAsmPrinter &p, Operation *op,
1279  ArrayAttr parameters) {
1280  if (!parameters || parameters.empty())
1281  return;
1282 
1283  p << '<';
1284  llvm::interleaveComma(parameters, p, [&](Attribute param) {
1285  auto paramAttr = cast<ParamDeclAttr>(param);
1286  p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
1287  if (auto value = paramAttr.getValue()) {
1288  p << " = ";
1289  p.printAttributeWithoutType(value);
1290  }
1291  });
1292  p << '>';
1293 }
1294 
1295 static void printFModuleLikeOp(OpAsmPrinter &p, FModuleLike op) {
1296  p << " ";
1297 
1298  // Print the visibility of the module.
1299  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1300  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1301  p << visibility.getValue() << ' ';
1302 
1303  // Print the operation and the function name.
1304  p.printSymbolName(op.getModuleName());
1305 
1306  // Print the parameter list (if non-empty).
1307  printParameterList(p, op, op->getAttrOfType<ArrayAttr>("parameters"));
1308 
1309  // Both modules and external modules have a body, but it is always empty for
1310  // external modules.
1311  Block *body = nullptr;
1312  if (!op->getRegion(0).empty())
1313  body = &op->getRegion(0).front();
1314 
1315  auto needPortNamesAttr = printModulePorts(
1316  p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1317  op.getPortAnnotations(), op.getPortSymbols(), op.getPortLocations());
1318 
1319  SmallVector<StringRef, 12> omittedAttrs = {
1320  "sym_name", "portDirections", "portTypes", "portAnnotations",
1321  "portSyms", "portLocations", "parameters", visibilityAttrName};
1322 
1323  if (op.getConvention() == Convention::Internal)
1324  omittedAttrs.push_back("convention");
1325 
1326  // We can omit the portNames if they were able to be printed as properly as
1327  // block arguments.
1328  if (!needPortNamesAttr)
1329  omittedAttrs.push_back("portNames");
1330 
1331  // If there are no annotations we can omit the empty array.
1332  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
1333  omittedAttrs.push_back("annotations");
1334 
1335  // If there are no enabled layers, then omit the empty array.
1336  if (auto layers = op->getAttrOfType<ArrayAttr>("layers"))
1337  if (layers.empty())
1338  omittedAttrs.push_back("layers");
1339 
1340  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1341 }
1342 
1343 void FExtModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1344 
1345 void FIntModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1346 
1347 void FMemModuleOp::print(OpAsmPrinter &p) { printFModuleLikeOp(p, *this); }
1348 
1349 void FModuleOp::print(OpAsmPrinter &p) {
1350  printFModuleLikeOp(p, *this);
1351 
1352  // Print the body if this is not an external function. Since this block does
1353  // not have terminators, printing the terminator actually just prints the last
1354  // operation.
1355  Region &fbody = getBody();
1356  if (!fbody.empty()) {
1357  p << " ";
1358  p.printRegion(fbody, /*printEntryBlockArgs=*/false,
1359  /*printBlockTerminators=*/true);
1360  }
1361 }
1362 
1363 /// Parse an parameter list if present.
1364 /// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>`
1365 /// parameter-decl ::= identifier `:` type
1366 /// parameter-decl ::= identifier `:` type `=` attribute
1367 ///
1368 static ParseResult
1369 parseOptionalParameters(OpAsmParser &parser,
1370  SmallVectorImpl<Attribute> &parameters) {
1371 
1372  return parser.parseCommaSeparatedList(
1373  OpAsmParser::Delimiter::OptionalLessGreater, [&]() {
1374  std::string name;
1375  Type type;
1376  Attribute value;
1377 
1378  if (parser.parseKeywordOrString(&name) || parser.parseColonType(type))
1379  return failure();
1380 
1381  // Parse the default value if present.
1382  if (succeeded(parser.parseOptionalEqual())) {
1383  if (parser.parseAttribute(value, type))
1384  return failure();
1385  }
1386 
1387  auto &builder = parser.getBuilder();
1388  parameters.push_back(ParamDeclAttr::get(
1389  builder.getContext(), builder.getStringAttr(name), type, value));
1390  return success();
1391  });
1392 }
1393 
1394 /// Shim to use with assemblyFormat, custom<ParameterList>.
1395 static ParseResult parseParameterList(OpAsmParser &parser,
1396  ArrayAttr &parameters) {
1397  SmallVector<Attribute> parseParameters;
1398  if (failed(parseOptionalParameters(parser, parseParameters)))
1399  return failure();
1400 
1401  parameters = ArrayAttr::get(parser.getContext(), parseParameters);
1402 
1403  return success();
1404 }
1405 
1406 static ParseResult parseFModuleLikeOp(OpAsmParser &parser,
1407  OperationState &result,
1408  bool hasSSAIdentifiers) {
1409  auto *context = result.getContext();
1410  auto &builder = parser.getBuilder();
1411 
1412  // Parse the visibility attribute.
1413  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1414 
1415  // Parse the name as a symbol.
1416  StringAttr nameAttr;
1417  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1418  result.attributes))
1419  return failure();
1420 
1421  // Parse optional parameters.
1422  SmallVector<Attribute, 4> parameters;
1423  if (parseOptionalParameters(parser, parameters))
1424  return failure();
1425  result.addAttribute("parameters", builder.getArrayAttr(parameters));
1426 
1427  // Parse the module ports.
1428  SmallVector<OpAsmParser::Argument> entryArgs;
1429  SmallVector<Direction, 4> portDirections;
1430  SmallVector<Attribute, 4> portNames;
1431  SmallVector<Attribute, 4> portTypes;
1432  SmallVector<Attribute, 4> portAnnotations;
1433  SmallVector<Attribute, 4> portSyms;
1434  SmallVector<Attribute, 4> portLocs;
1435  if (parseModulePorts(parser, hasSSAIdentifiers, /*supportsSymbols=*/true,
1436  entryArgs, portDirections, portNames, portTypes,
1437  portAnnotations, portSyms, portLocs))
1438  return failure();
1439 
1440  // If module attributes are present, parse them.
1441  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1442  return failure();
1443 
1444  assert(portNames.size() == portTypes.size());
1445 
1446  // Record the argument and result types as an attribute. This is necessary
1447  // for external modules.
1448 
1449  // Add port directions.
1450  if (!result.attributes.get("portDirections"))
1451  result.addAttribute("portDirections",
1452  direction::packAttribute(context, portDirections));
1453 
1454  // Add port names.
1455  if (!result.attributes.get("portNames"))
1456  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1457 
1458  // Add the port types.
1459  if (!result.attributes.get("portTypes"))
1460  result.addAttribute("portTypes", ArrayAttr::get(context, portTypes));
1461 
1462  // Add the port annotations.
1463  if (!result.attributes.get("portAnnotations")) {
1464  // If there are no portAnnotations, don't add the attribute.
1465  if (llvm::any_of(portAnnotations, [&](Attribute anno) {
1466  return !cast<ArrayAttr>(anno).empty();
1467  }))
1468  result.addAttribute("portAnnotations",
1469  ArrayAttr::get(context, portAnnotations));
1470  }
1471 
1472  // Add port symbols.
1473  if (!result.attributes.get("portSyms")) {
1474  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1475  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1476  }
1477 
1478  // Add port locations.
1479  if (!result.attributes.get("portLocations"))
1480  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1481 
1482  // The annotations attribute is always present, but not printed when empty.
1483  if (!result.attributes.get("annotations"))
1484  result.addAttribute("annotations", builder.getArrayAttr({}));
1485 
1486  // The portAnnotations attribute is always present, but not printed when
1487  // empty.
1488  if (!result.attributes.get("portAnnotations"))
1489  result.addAttribute("portAnnotations", builder.getArrayAttr({}));
1490 
1491  // Parse the optional function body.
1492  auto *body = result.addRegion();
1493 
1494  if (hasSSAIdentifiers) {
1495  if (parser.parseRegion(*body, entryArgs))
1496  return failure();
1497  if (body->empty())
1498  body->push_back(new Block());
1499  }
1500  return success();
1501 }
1502 
1503 ParseResult FModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1504  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/true))
1505  return failure();
1506  if (!result.attributes.get("convention"))
1507  result.addAttribute(
1508  "convention",
1509  ConventionAttr::get(result.getContext(), Convention::Internal));
1510  if (!result.attributes.get("layers"))
1511  result.addAttribute("layers", ArrayAttr::get(parser.getContext(), {}));
1512  return success();
1513 }
1514 
1515 ParseResult FExtModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1516  if (parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false))
1517  return failure();
1518  if (!result.attributes.get("convention"))
1519  result.addAttribute(
1520  "convention",
1521  ConventionAttr::get(result.getContext(), Convention::Internal));
1522  return success();
1523 }
1524 
1525 ParseResult FIntModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1526  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1527 }
1528 
1529 ParseResult FMemModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1530  return parseFModuleLikeOp(parser, result, /*hasSSAIdentifiers=*/false);
1531 }
1532 
1533 LogicalResult FModuleOp::verify() {
1534  // Verify the block arguments.
1535  auto *body = getBodyBlock();
1536  auto portTypes = getPortTypes();
1537  auto portLocs = getPortLocations();
1538  auto numPorts = portTypes.size();
1539 
1540  // Verify that we have the correct number of block arguments.
1541  if (body->getNumArguments() != numPorts)
1542  return emitOpError("entry block must have ")
1543  << numPorts << " arguments to match module signature";
1544 
1545  // Verify the block arguments' types and locations match our attributes.
1546  for (auto [arg, type, loc] : zip(body->getArguments(), portTypes, portLocs)) {
1547  if (arg.getType() != cast<TypeAttr>(type).getValue())
1548  return emitOpError("block argument types should match signature types");
1549  if (arg.getLoc() != cast<LocationAttr>(loc))
1550  return emitOpError(
1551  "block argument locations should match signature locations");
1552  }
1553 
1554  return success();
1555 }
1556 
1557 static LogicalResult
1558 verifyInternalPaths(FModuleLike op,
1559  std::optional<::mlir::ArrayAttr> internalPaths) {
1560  if (!internalPaths)
1561  return success();
1562 
1563  // If internal paths are present, should cover all ports.
1564  if (internalPaths->size() != op.getNumPorts())
1565  return op.emitError("module has inconsistent internal path array with ")
1566  << internalPaths->size() << " entries for " << op.getNumPorts()
1567  << " ports";
1568 
1569  // No internal paths for non-ref-type ports.
1570  for (auto [idx, path, typeattr] : llvm::enumerate(
1571  internalPaths->getAsRange<InternalPathAttr>(), op.getPortTypes())) {
1572  if (path.getPath() &&
1573  !type_isa<RefType>(cast<TypeAttr>(typeattr).getValue())) {
1574  auto diag =
1575  op.emitError("module has internal path for non-ref-type port ")
1576  << op.getPortNameAttr(idx);
1577  return diag.attachNote(op.getPortLocation(idx)) << "this port";
1578  }
1579  }
1580 
1581  return success();
1582 }
1583 
1584 LogicalResult FExtModuleOp::verify() {
1585  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1586  return failure();
1587 
1588  auto params = getParameters();
1589  if (params.empty())
1590  return success();
1591 
1592  auto checkParmValue = [&](Attribute elt) -> bool {
1593  auto param = cast<ParamDeclAttr>(elt);
1594  auto value = param.getValue();
1595  if (isa<IntegerAttr, StringAttr, FloatAttr, hw::ParamVerbatimAttr>(value))
1596  return true;
1597  emitError() << "has unknown extmodule parameter value '"
1598  << param.getName().getValue() << "' = " << value;
1599  return false;
1600  };
1601 
1602  if (!llvm::all_of(params, checkParmValue))
1603  return failure();
1604 
1605  return success();
1606 }
1607 
1608 LogicalResult FIntModuleOp::verify() {
1609  if (failed(verifyInternalPaths(*this, getInternalPaths())))
1610  return failure();
1611 
1612  auto params = getParameters();
1613  if (params.empty())
1614  return success();
1615 
1616  auto checkParmValue = [&](Attribute elt) -> bool {
1617  auto param = cast<ParamDeclAttr>(elt);
1618  auto value = param.getValue();
1619  if (isa<IntegerAttr, StringAttr, FloatAttr>(value))
1620  return true;
1621  emitError() << "has unknown intmodule parameter value '"
1622  << param.getName().getValue() << "' = " << value;
1623  return false;
1624  };
1625 
1626  if (!llvm::all_of(params, checkParmValue))
1627  return failure();
1628 
1629  return success();
1630 }
1631 
1632 static LogicalResult verifyProbeType(RefType refType, Location loc,
1633  CircuitOp circuitOp,
1634  SymbolTableCollection &symbolTable,
1635  Twine start) {
1636  auto layer = refType.getLayer();
1637  if (!layer)
1638  return success();
1639  auto *layerOp = symbolTable.lookupSymbolIn(circuitOp, layer);
1640  if (!layerOp)
1641  return emitError(loc) << start << " associated with layer '" << layer
1642  << "', but this layer was not defined";
1643  if (!isa<LayerOp>(layerOp)) {
1644  auto diag = emitError(loc)
1645  << start << " associated with layer '" << layer
1646  << "', but symbol '" << layer << "' does not refer to a '"
1647  << LayerOp::getOperationName() << "' op";
1648  return diag.attachNote(layerOp->getLoc()) << "symbol refers to this op";
1649  }
1650  return success();
1651 }
1652 
1653 static LogicalResult verifyPortSymbolUses(FModuleLike module,
1654  SymbolTableCollection &symbolTable) {
1655  // verify types in ports.
1656  auto circuitOp = module->getParentOfType<CircuitOp>();
1657  for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
1658  auto type = module.getPortType(i);
1659 
1660  if (auto refType = type_dyn_cast<RefType>(type)) {
1661  if (failed(verifyProbeType(
1662  refType, module.getPortLocation(i), circuitOp, symbolTable,
1663  Twine("probe port '") + module.getPortName(i) + "' is")))
1664  return failure();
1665  continue;
1666  }
1667 
1668  if (auto classType = dyn_cast<ClassType>(type)) {
1669  auto className = classType.getNameAttr();
1670  auto classOp = dyn_cast_or_null<ClassLike>(
1671  symbolTable.lookupSymbolIn(circuitOp, className));
1672  if (!classOp)
1673  return module.emitOpError() << "references unknown class " << className;
1674 
1675  // verify that the result type agrees with the class definition.
1676  if (failed(classOp.verifyType(classType,
1677  [&]() { return module.emitOpError(); })))
1678  return failure();
1679  continue;
1680  }
1681  }
1682 
1683  return success();
1684 }
1685 
1686 LogicalResult FModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1687  if (failed(
1688  verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable)))
1689  return failure();
1690 
1691  auto circuitOp = (*this)->getParentOfType<CircuitOp>();
1692  for (auto layer : getLayers()) {
1693  if (!symbolTable.lookupSymbolIn(circuitOp, cast<SymbolRefAttr>(layer)))
1694  return emitOpError() << "enables unknown layer '" << layer << "'";
1695  }
1696 
1697  return success();
1698 }
1699 
1700 LogicalResult
1701 FExtModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1702  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1703 }
1704 
1705 LogicalResult
1706 FIntModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1707  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1708 }
1709 
1710 LogicalResult
1711 FMemModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1712  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1713 }
1714 
1715 void FModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1716  mlir::OpAsmSetValueNameFn setNameFn) {
1717  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1718 }
1719 
1720 void FExtModuleOp::getAsmBlockArgumentNames(
1721  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1722  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1723 }
1724 
1725 void FIntModuleOp::getAsmBlockArgumentNames(
1726  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1727  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1728 }
1729 
1730 void FMemModuleOp::getAsmBlockArgumentNames(
1731  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1732  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1733 }
1734 
1735 ArrayAttr FMemModuleOp::getParameters() { return {}; }
1736 
1737 ArrayAttr FModuleOp::getParameters() { return {}; }
1738 
1739 Convention FIntModuleOp::getConvention() { return Convention::Internal; }
1740 
1741 ConventionAttr FIntModuleOp::getConventionAttr() {
1742  return ConventionAttr::get(getContext(), getConvention());
1743 }
1744 
1745 Convention FMemModuleOp::getConvention() { return Convention::Internal; }
1746 
1747 ConventionAttr FMemModuleOp::getConventionAttr() {
1748  return ConventionAttr::get(getContext(), getConvention());
1749 }
1750 
1751 //===----------------------------------------------------------------------===//
1752 // ClassLike Helpers
1753 //===----------------------------------------------------------------------===//
1754 
1756  ClassLike classOp, ClassType type,
1757  function_ref<InFlightDiagnostic()> emitError) {
1758  // This check is probably not required, but done for sanity.
1759  auto name = type.getNameAttr().getAttr();
1760  auto expectedName = classOp.getModuleNameAttr();
1761  if (name != expectedName)
1762  return emitError() << "type has wrong name, got " << name << ", expected "
1763  << expectedName;
1764 
1765  auto elements = type.getElements();
1766  auto numElements = elements.size();
1767  auto expectedNumElements = classOp.getNumPorts();
1768  if (numElements != expectedNumElements)
1769  return emitError() << "has wrong number of ports, got " << numElements
1770  << ", expected " << expectedNumElements;
1771 
1772  auto portNames = classOp.getPortNames();
1773  auto portDirections = classOp.getPortDirections();
1774  auto portTypes = classOp.getPortTypes();
1775 
1776  for (unsigned i = 0; i < numElements; ++i) {
1777  auto element = elements[i];
1778 
1779  auto name = element.name;
1780  auto expectedName = portNames[i];
1781  if (name != expectedName)
1782  return emitError() << "port #" << i << " has wrong name, got " << name
1783  << ", expected " << expectedName;
1784 
1785  auto direction = element.direction;
1786  auto expectedDirection = Direction(portDirections[i]);
1787  if (direction != expectedDirection)
1788  return emitError() << "port " << name << " has wrong direction, got "
1789  << direction::toString(direction) << ", expected "
1790  << direction::toString(expectedDirection);
1791 
1792  auto type = element.type;
1793  auto expectedType = cast<TypeAttr>(portTypes[i]).getValue();
1794  if (type != expectedType)
1795  return emitError() << "port " << name << " has wrong type, got " << type
1796  << ", expected " << expectedType;
1797  }
1798 
1799  return success();
1800 }
1801 
1802 ClassType firrtl::detail::getInstanceTypeForClassLike(ClassLike classOp) {
1803  auto n = classOp.getNumPorts();
1804  SmallVector<ClassElement> elements;
1805  elements.reserve(n);
1806  for (size_t i = 0; i < n; ++i)
1807  elements.push_back({classOp.getPortNameAttr(i), classOp.getPortType(i),
1808  classOp.getPortDirection(i)});
1809  auto name = FlatSymbolRefAttr::get(classOp.getNameAttr());
1810  return ClassType::get(name, elements);
1811 }
1812 
1813 static ParseResult parseClassLike(OpAsmParser &parser, OperationState &result,
1814  bool hasSSAIdentifiers) {
1815  auto *context = result.getContext();
1816  auto &builder = parser.getBuilder();
1817 
1818  // Parse the visibility attribute.
1819  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
1820 
1821  // Parse the name as a symbol.
1822  StringAttr nameAttr;
1823  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1824  result.attributes))
1825  return failure();
1826 
1827  // Parse the module ports.
1828  SmallVector<OpAsmParser::Argument> entryArgs;
1829  SmallVector<Direction, 4> portDirections;
1830  SmallVector<Attribute, 4> portNames;
1831  SmallVector<Attribute, 4> portTypes;
1832  SmallVector<Attribute, 4> portAnnotations;
1833  SmallVector<Attribute, 4> portSyms;
1834  SmallVector<Attribute, 4> portLocs;
1835  if (parseModulePorts(parser, hasSSAIdentifiers,
1836  /*supportsSymbols=*/false, entryArgs, portDirections,
1837  portNames, portTypes, portAnnotations, portSyms,
1838  portLocs))
1839  return failure();
1840 
1841  // Ports on ClassLike ops cannot have annotations
1842  for (auto annos : portAnnotations)
1843  if (!cast<ArrayAttr>(annos).empty())
1844  return failure();
1845 
1846  // If attributes are present, parse them.
1847  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1848  return failure();
1849 
1850  assert(portNames.size() == portTypes.size());
1851 
1852  // Record the argument and result types as an attribute. This is necessary
1853  // for external modules.
1854 
1855  // Add port directions.
1856  if (!result.attributes.get("portDirections"))
1857  result.addAttribute("portDirections",
1858  direction::packAttribute(context, portDirections));
1859 
1860  // Add port names.
1861  if (!result.attributes.get("portNames"))
1862  result.addAttribute("portNames", builder.getArrayAttr(portNames));
1863 
1864  // Add the port types.
1865  if (!result.attributes.get("portTypes"))
1866  result.addAttribute("portTypes", builder.getArrayAttr(portTypes));
1867 
1868  // Add the port symbols.
1869  if (!result.attributes.get("portSyms")) {
1870  FModuleLike::fixupPortSymsArray(portSyms, builder.getContext());
1871  result.addAttribute("portSyms", builder.getArrayAttr(portSyms));
1872  }
1873 
1874  // Add port locations.
1875  if (!result.attributes.get("portLocations"))
1876  result.addAttribute("portLocations", ArrayAttr::get(context, portLocs));
1877 
1878  // Notably missing compared to other FModuleLike, we do not track port
1879  // annotations, nor port symbols, on classes.
1880 
1881  // Add the region (unused by extclass).
1882  auto *bodyRegion = result.addRegion();
1883 
1884  if (hasSSAIdentifiers) {
1885  if (parser.parseRegion(*bodyRegion, entryArgs))
1886  return failure();
1887  if (bodyRegion->empty())
1888  bodyRegion->push_back(new Block());
1889  }
1890 
1891  return success();
1892 }
1893 
1894 static void printClassLike(OpAsmPrinter &p, ClassLike op) {
1895  p << ' ';
1896 
1897  // Print the visibility of the class.
1898  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1899  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
1900  p << visibility.getValue() << ' ';
1901 
1902  // Print the class name.
1903  p.printSymbolName(op.getName());
1904 
1905  // Both classes and external classes have a body, but it is always empty for
1906  // external classes.
1907  Region &region = op->getRegion(0);
1908  Block *body = nullptr;
1909  if (!region.empty())
1910  body = &region.front();
1911 
1912  auto needPortNamesAttr = printModulePorts(
1913  p, body, op.getPortDirectionsAttr(), op.getPortNames(), op.getPortTypes(),
1914  {}, op.getPortSymbols(), op.getPortLocations());
1915 
1916  // Print the attr-dict.
1917  SmallVector<StringRef, 8> omittedAttrs = {
1918  "sym_name", "portNames", "portTypes", "portDirections",
1919  "portSyms", "portLocations", visibilityAttrName};
1920 
1921  // We can omit the portNames if they were able to be printed as properly as
1922  // block arguments.
1923  if (!needPortNamesAttr)
1924  omittedAttrs.push_back("portNames");
1925 
1926  p.printOptionalAttrDictWithKeyword(op->getAttrs(), omittedAttrs);
1927 
1928  // print the body if it exists.
1929  if (!region.empty()) {
1930  p << " ";
1931  auto printEntryBlockArgs = false;
1932  auto printBlockTerminators = false;
1933  p.printRegion(region, printEntryBlockArgs, printBlockTerminators);
1934  }
1935 }
1936 
1937 //===----------------------------------------------------------------------===//
1938 // ClassOp
1939 //===----------------------------------------------------------------------===//
1940 
1941 void ClassOp::build(OpBuilder &builder, OperationState &result, StringAttr name,
1942  ArrayRef<PortInfo> ports) {
1943  assert(
1944  llvm::all_of(ports,
1945  [](const auto &port) { return port.annotations.empty(); }) &&
1946  "class ports may not have annotations");
1947 
1948  buildClass(builder, result, name, ports);
1949 
1950  // Create a region and a block for the body.
1951  auto *bodyRegion = result.regions[0].get();
1952  Block *body = new Block();
1953  bodyRegion->push_back(body);
1954 
1955  // Add arguments to the body block.
1956  for (auto &elt : ports)
1957  body->addArgument(elt.type, elt.loc);
1958 }
1959 
1960 void ClassOp::print(OpAsmPrinter &p) {
1961  printClassLike(p, cast<ClassLike>(getOperation()));
1962 }
1963 
1964 ParseResult ClassOp::parse(OpAsmParser &parser, OperationState &result) {
1965  auto hasSSAIdentifiers = true;
1966  return parseClassLike(parser, result, hasSSAIdentifiers);
1967 }
1968 
1969 LogicalResult ClassOp::verify() {
1970  for (auto operand : getBodyBlock()->getArguments()) {
1971  auto type = operand.getType();
1972  if (!isa<PropertyType>(type)) {
1973  emitOpError("ports on a class must be properties");
1974  return failure();
1975  }
1976  }
1977 
1978  return success();
1979 }
1980 
1981 LogicalResult
1982 ClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
1983  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
1984 }
1985 
1986 void ClassOp::getAsmBlockArgumentNames(mlir::Region &region,
1987  mlir::OpAsmSetValueNameFn setNameFn) {
1988  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
1989 }
1990 
1991 SmallVector<PortInfo> ClassOp::getPorts() {
1992  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
1993 }
1994 
1995 void ClassOp::erasePorts(const llvm::BitVector &portIndices) {
1996  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
1997  getBodyBlock()->eraseArguments(portIndices);
1998 }
1999 
2000 void ClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2001  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2002 }
2003 
2004 Convention ClassOp::getConvention() { return Convention::Internal; }
2005 
2006 ConventionAttr ClassOp::getConventionAttr() {
2007  return ConventionAttr::get(getContext(), getConvention());
2008 }
2009 
2010 ArrayAttr ClassOp::getParameters() { return {}; }
2011 
2012 ArrayAttr ClassOp::getPortAnnotationsAttr() {
2013  return ArrayAttr::get(getContext(), {});
2014 }
2015 
2016 ArrayAttr ClassOp::getLayersAttr() { return ArrayAttr::get(getContext(), {}); }
2017 
2018 ArrayRef<Attribute> ClassOp::getLayers() { return getLayersAttr(); }
2019 
2020 SmallVector<::circt::hw::PortInfo> ClassOp::getPortList() {
2021  return ::getPortListImpl(*this);
2022 }
2023 
2025  return ::getPortImpl(*this, idx);
2026 }
2027 
2028 BlockArgument ClassOp::getArgument(size_t portNumber) {
2029  return getBodyBlock()->getArgument(portNumber);
2030 }
2031 
2032 bool ClassOp::canDiscardOnUseEmpty() {
2033  // ClassOps are referenced by ClassTypes, and these uses are not
2034  // discoverable by the symbol infrastructure. Return false here to prevent
2035  // passes like symbolDCE from removing our classes.
2036  return false;
2037 }
2038 
2039 //===----------------------------------------------------------------------===//
2040 // ExtClassOp
2041 //===----------------------------------------------------------------------===//
2042 
2043 void ExtClassOp::build(OpBuilder &builder, OperationState &result,
2044  StringAttr name, ArrayRef<PortInfo> ports) {
2045  assert(
2046  llvm::all_of(ports,
2047  [](const auto &port) { return port.annotations.empty(); }) &&
2048  "class ports may not have annotations");
2049 
2050  buildClass(builder, result, name, ports);
2051 }
2052 
2053 void ExtClassOp::print(OpAsmPrinter &p) {
2054  printClassLike(p, cast<ClassLike>(getOperation()));
2055 }
2056 
2057 ParseResult ExtClassOp::parse(OpAsmParser &parser, OperationState &result) {
2058  auto hasSSAIdentifiers = false;
2059  return parseClassLike(parser, result, hasSSAIdentifiers);
2060 }
2061 
2062 LogicalResult
2063 ExtClassOp::verifySymbolUses(::mlir::SymbolTableCollection &symbolTable) {
2064  return verifyPortSymbolUses(cast<FModuleLike>(getOperation()), symbolTable);
2065 }
2066 
2067 void ExtClassOp::getAsmBlockArgumentNames(mlir::Region &region,
2068  mlir::OpAsmSetValueNameFn setNameFn) {
2069  getAsmBlockArgumentNamesImpl(getOperation(), region, setNameFn);
2070 }
2071 
2072 SmallVector<PortInfo> ExtClassOp::getPorts() {
2073  return ::getPortImpl(cast<FModuleLike>((Operation *)*this));
2074 }
2075 
2076 void ExtClassOp::erasePorts(const llvm::BitVector &portIndices) {
2077  ::erasePorts(cast<FModuleLike>((Operation *)*this), portIndices);
2078 }
2079 
2080 void ExtClassOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2081  ::insertPorts(cast<FModuleLike>((Operation *)*this), ports);
2082 }
2083 
2084 Convention ExtClassOp::getConvention() { return Convention::Internal; }
2085 
2086 ConventionAttr ExtClassOp::getConventionAttr() {
2087  return ConventionAttr::get(getContext(), getConvention());
2088 }
2089 
2090 ArrayAttr ExtClassOp::getLayersAttr() {
2091  return ArrayAttr::get(getContext(), {});
2092 }
2093 
2094 ArrayRef<Attribute> ExtClassOp::getLayers() { return getLayersAttr(); }
2095 
2096 ArrayAttr ExtClassOp::getParameters() { return {}; }
2097 
2098 ArrayAttr ExtClassOp::getPortAnnotationsAttr() {
2099  return ArrayAttr::get(getContext(), {});
2100 }
2101 
2102 SmallVector<::circt::hw::PortInfo> ExtClassOp::getPortList() {
2103  return ::getPortListImpl(*this);
2104 }
2105 
2107  return ::getPortImpl(*this, idx);
2108 }
2109 
2110 bool ExtClassOp::canDiscardOnUseEmpty() {
2111  // ClassOps are referenced by ClassTypes, and these uses are not
2112  // discovereable by the symbol infrastructure. Return false here to prevent
2113  // passes like symbolDCE from removing our classes.
2114  return false;
2115 }
2116 
2117 //===----------------------------------------------------------------------===//
2118 // InstanceOp
2119 //===----------------------------------------------------------------------===//
2120 
2121 void InstanceOp::build(
2122  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2123  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2124  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2125  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2126  ArrayRef<Attribute> layers, bool lowerToBind, StringAttr innerSym) {
2127  build(builder, result, resultTypes, moduleName, name, nameKind,
2128  portDirections, portNames, annotations, portAnnotations, layers,
2129  lowerToBind,
2130  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2131 }
2132 
2133 void InstanceOp::build(
2134  OpBuilder &builder, OperationState &result, TypeRange resultTypes,
2135  StringRef moduleName, StringRef name, NameKindEnum nameKind,
2136  ArrayRef<Direction> portDirections, ArrayRef<Attribute> portNames,
2137  ArrayRef<Attribute> annotations, ArrayRef<Attribute> portAnnotations,
2138  ArrayRef<Attribute> layers, bool lowerToBind, hw::InnerSymAttr innerSym) {
2139  result.addTypes(resultTypes);
2140  result.addAttribute("moduleName",
2141  SymbolRefAttr::get(builder.getContext(), moduleName));
2142  result.addAttribute("name", builder.getStringAttr(name));
2143  result.addAttribute(
2144  "portDirections",
2145  direction::packAttribute(builder.getContext(), portDirections));
2146  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2147  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2148  result.addAttribute("layers", builder.getArrayAttr(layers));
2149  if (lowerToBind)
2150  result.addAttribute("lowerToBind", builder.getUnitAttr());
2151  if (innerSym)
2152  result.addAttribute("inner_sym", innerSym);
2153  result.addAttribute("nameKind",
2154  NameKindEnumAttr::get(builder.getContext(), nameKind));
2155 
2156  if (portAnnotations.empty()) {
2157  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2158  builder.getArrayAttr({}));
2159  result.addAttribute("portAnnotations",
2160  builder.getArrayAttr(portAnnotationsVec));
2161  } else {
2162  assert(portAnnotations.size() == resultTypes.size());
2163  result.addAttribute("portAnnotations",
2164  builder.getArrayAttr(portAnnotations));
2165  }
2166 }
2167 
2168 void InstanceOp::build(OpBuilder &builder, OperationState &result,
2169  FModuleLike module, StringRef name,
2170  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2171  ArrayRef<Attribute> portAnnotations, bool lowerToBind,
2172  hw::InnerSymAttr innerSym) {
2173 
2174  // Gather the result types.
2175  SmallVector<Type> resultTypes;
2176  resultTypes.reserve(module.getNumPorts());
2177  llvm::transform(
2178  module.getPortTypes(), std::back_inserter(resultTypes),
2179  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2180 
2181  // Create the port annotations.
2182  ArrayAttr portAnnotationsAttr;
2183  if (portAnnotations.empty()) {
2184  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2185  resultTypes.size(), builder.getArrayAttr({})));
2186  } else {
2187  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2188  }
2189 
2190  return build(
2191  builder, result, resultTypes,
2192  SymbolRefAttr::get(builder.getContext(), module.getModuleNameAttr()),
2193  builder.getStringAttr(name),
2194  NameKindEnumAttr::get(builder.getContext(), nameKind),
2195  module.getPortDirectionsAttr(), module.getPortNamesAttr(),
2196  builder.getArrayAttr(annotations), portAnnotationsAttr,
2197  module.getLayersAttr(), lowerToBind ? builder.getUnitAttr() : UnitAttr(),
2198  innerSym);
2199 }
2200 
2201 void InstanceOp::build(OpBuilder &builder, OperationState &odsState,
2202  ArrayRef<PortInfo> ports, StringRef moduleName,
2203  StringRef name, NameKindEnum nameKind,
2204  ArrayRef<Attribute> annotations,
2205  ArrayRef<Attribute> layers, bool lowerToBind,
2206  hw::InnerSymAttr innerSym) {
2207  // Gather the result types.
2208  SmallVector<Type> newResultTypes;
2209  SmallVector<Direction> newPortDirections;
2210  SmallVector<Attribute> newPortNames;
2211  SmallVector<Attribute> newPortAnnotations;
2212  for (auto &p : ports) {
2213  newResultTypes.push_back(p.type);
2214  newPortDirections.push_back(p.direction);
2215  newPortNames.push_back(p.name);
2216  newPortAnnotations.push_back(p.annotations.getArrayAttr());
2217  }
2218 
2219  return build(builder, odsState, newResultTypes, moduleName, name, nameKind,
2220  newPortDirections, newPortNames, annotations, newPortAnnotations,
2221  layers, lowerToBind, innerSym);
2222 }
2223 
2224 LogicalResult InstanceOp::verify() {
2225  // The instance may only be instantiated under its required layers.
2226  auto ambientLayers = getAmbientLayersAt(getOperation());
2227  SmallVector<SymbolRefAttr> missingLayers;
2228  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2229  if (!isLayerCompatibleWith(layer, ambientLayers))
2230  missingLayers.push_back(layer);
2231 
2232  if (missingLayers.empty())
2233  return success();
2234 
2235  auto diag =
2236  emitOpError("ambient layers are insufficient to instantiate module");
2237  auto &note = diag.attachNote();
2238  note << "missing layer requirements: ";
2239  interleaveComma(missingLayers, note);
2240  return failure();
2241 }
2242 
2243 /// Builds a new `InstanceOp` with the ports listed in `portIndices` erased, and
2244 /// updates any users of the remaining ports to point at the new instance.
2245 InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
2246  const llvm::BitVector &portIndices) {
2247  assert(portIndices.size() >= getNumResults() &&
2248  "portIndices is not at least as large as getNumResults()");
2249 
2250  if (portIndices.none())
2251  return *this;
2252 
2253  SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
2254  SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
2255  SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
2256  direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
2257  SmallVector<Attribute> newPortNames =
2258  removeElementsAtIndices(getPortNames().getValue(), portIndices);
2259  SmallVector<Attribute> newPortAnnotations =
2260  removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);
2261 
2262  auto newOp = builder.create<InstanceOp>(
2263  getLoc(), newResultTypes, getModuleName(), getName(), getNameKind(),
2264  newPortDirections, newPortNames, getAnnotations().getValue(),
2265  newPortAnnotations, getLayers(), getLowerToBind(), getInnerSymAttr());
2266 
2267  for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
2268  oldIdx != numOldPorts; ++oldIdx) {
2269  if (portIndices.test(oldIdx)) {
2270  assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
2271  continue;
2272  }
2273  getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
2274  ++newIdx;
2275  }
2276 
2277  // Compy over "output_file" information so that this is not lost when ports
2278  // are erased.
2279  //
2280  // TODO: Other attributes may need to be copied over.
2281  if (auto outputFile = (*this)->getAttr("output_file"))
2282  newOp->setAttr("output_file", outputFile);
2283 
2284  return newOp;
2285 }
2286 
2287 ArrayAttr InstanceOp::getPortAnnotation(unsigned portIdx) {
2288  assert(portIdx < getNumResults() &&
2289  "index should be smaller than result number");
2290  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2291 }
2292 
2293 void InstanceOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2294  assert(annotations.size() == getNumResults() &&
2295  "number of annotations is not equal to result number");
2296  (*this)->setAttr("portAnnotations",
2297  ArrayAttr::get(getContext(), annotations));
2298 }
2299 
2300 InstanceOp
2301 InstanceOp::cloneAndInsertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
2302  auto portSize = ports.size();
2303  auto newPortCount = getNumResults() + portSize;
2304  SmallVector<Direction> newPortDirections;
2305  newPortDirections.reserve(newPortCount);
2306  SmallVector<Attribute> newPortNames;
2307  newPortNames.reserve(newPortCount);
2308  SmallVector<Type> newPortTypes;
2309  newPortTypes.reserve(newPortCount);
2310  SmallVector<Attribute> newPortAnnos;
2311  newPortAnnos.reserve(newPortCount);
2312 
2313  unsigned oldIndex = 0;
2314  unsigned newIndex = 0;
2315  while (oldIndex + newIndex < newPortCount) {
2316  // Check if we should insert a port here.
2317  if (newIndex < portSize && ports[newIndex].first == oldIndex) {
2318  auto &newPort = ports[newIndex].second;
2319  newPortDirections.push_back(newPort.direction);
2320  newPortNames.push_back(newPort.name);
2321  newPortTypes.push_back(newPort.type);
2322  newPortAnnos.push_back(newPort.annotations.getArrayAttr());
2323  ++newIndex;
2324  } else {
2325  // Copy the next old port.
2326  newPortDirections.push_back(getPortDirection(oldIndex));
2327  newPortNames.push_back(getPortName(oldIndex));
2328  newPortTypes.push_back(getType(oldIndex));
2329  newPortAnnos.push_back(getPortAnnotation(oldIndex));
2330  ++oldIndex;
2331  }
2332  }
2333 
2334  // Create a new instance op with the reset inserted.
2335  return OpBuilder(*this).create<InstanceOp>(
2336  getLoc(), newPortTypes, getModuleName(), getName(), getNameKind(),
2337  newPortDirections, newPortNames, getAnnotations().getValue(),
2338  newPortAnnos, getLayers(), getLowerToBind(), getInnerSymAttr());
2339 }
2340 
2341 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2342  return instance_like_impl::verifyReferencedModule(*this, symbolTable,
2343  getModuleNameAttr());
2344 }
2345 
2346 StringRef InstanceOp::getInstanceName() { return getName(); }
2347 
2348 StringAttr InstanceOp::getInstanceNameAttr() { return getNameAttr(); }
2349 
2350 void InstanceOp::print(OpAsmPrinter &p) {
2351  // Print the instance name.
2352  p << " ";
2353  p.printKeywordOrString(getName());
2354  if (auto attr = getInnerSymAttr()) {
2355  p << " sym ";
2356  p.printSymbolName(attr.getSymName());
2357  }
2358  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
2359  p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());
2360 
2361  // Print the attr-dict.
2362  SmallVector<StringRef, 10> omittedAttrs = {
2363  "moduleName", "name", "portDirections",
2364  "portNames", "portTypes", "portAnnotations",
2365  "inner_sym", "nameKind"};
2366  if (getAnnotations().empty())
2367  omittedAttrs.push_back("annotations");
2368  if (getLayers().empty())
2369  omittedAttrs.push_back("layers");
2370  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);
2371 
2372  // Print the module name.
2373  p << " ";
2374  p.printSymbolName(getModuleName());
2375 
2376  // Collect all the result types as TypeAttrs for printing.
2377  SmallVector<Attribute> portTypes;
2378  portTypes.reserve(getNumResults());
2379  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
2380  &TypeAttr::get);
2381  printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
2382  getPortNames().getValue(), portTypes,
2383  getPortAnnotations().getValue(), {}, {});
2384 }
2385 
2386 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
2387  auto *context = parser.getContext();
2388  auto &resultAttrs = result.attributes;
2389 
2390  std::string name;
2391  hw::InnerSymAttr innerSymAttr;
2392  FlatSymbolRefAttr moduleName;
2393  SmallVector<OpAsmParser::Argument> entryArgs;
2394  SmallVector<Direction, 4> portDirections;
2395  SmallVector<Attribute, 4> portNames;
2396  SmallVector<Attribute, 4> portTypes;
2397  SmallVector<Attribute, 4> portAnnotations;
2398  SmallVector<Attribute, 4> portSyms;
2399  SmallVector<Attribute, 4> portLocs;
2400  NameKindEnumAttr nameKind;
2401 
2402  if (parser.parseKeywordOrString(&name))
2403  return failure();
2404  if (succeeded(parser.parseOptionalKeyword("sym"))) {
2405  if (parser.parseCustomAttributeWithFallback(
2406  innerSymAttr, ::mlir::Type{},
2407  hw::InnerSymbolTable::getInnerSymbolAttrName(),
2408  result.attributes)) {
2409  return ::mlir::failure();
2410  }
2411  }
2412  if (parseNameKind(parser, nameKind) ||
2413  parser.parseOptionalAttrDict(result.attributes) ||
2414  parser.parseAttribute(moduleName, "moduleName", resultAttrs) ||
2415  parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2416  /*supportsSymbols=*/false, entryArgs, portDirections,
2417  portNames, portTypes, portAnnotations, portSyms,
2418  portLocs))
2419  return failure();
2420 
2421  // Add the attributes. We let attributes defined in the attr-dict override
2422  // attributes parsed out of the module signature.
2423  if (!resultAttrs.get("moduleName"))
2424  result.addAttribute("moduleName", moduleName);
2425  if (!resultAttrs.get("name"))
2426  result.addAttribute("name", StringAttr::get(context, name));
2427  result.addAttribute("nameKind", nameKind);
2428  if (!resultAttrs.get("portDirections"))
2429  result.addAttribute("portDirections",
2430  direction::packAttribute(context, portDirections));
2431  if (!resultAttrs.get("portNames"))
2432  result.addAttribute("portNames", ArrayAttr::get(context, portNames));
2433  if (!resultAttrs.get("portAnnotations"))
2434  result.addAttribute("portAnnotations",
2435  ArrayAttr::get(context, portAnnotations));
2436 
2437  // Annotations, layers, and LowerToBind are omitted in the printed format
2438  // if they are empty, empty, and false (respectively).
2439  if (!resultAttrs.get("annotations"))
2440  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
2441  if (!resultAttrs.get("layers"))
2442  resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));
2443 
2444  // Add result types.
2445  result.types.reserve(portTypes.size());
2446  llvm::transform(
2447  portTypes, std::back_inserter(result.types),
2448  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2449 
2450  return success();
2451 }
2452 
2454  StringRef base = getName();
2455  if (base.empty())
2456  base = "inst";
2457 
2458  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
2459  setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
2460  }
2461 }
2462 
2463 std::optional<size_t> InstanceOp::getTargetResultIndex() {
2464  // Inner symbols on instance operations target the op not any result.
2465  return std::nullopt;
2466 }
2467 
2468 // -----------------------------------------------------------------------------
2469 // InstanceChoiceOp
2470 // -----------------------------------------------------------------------------
2471 
2472 void InstanceChoiceOp::build(
2473  OpBuilder &builder, OperationState &result, FModuleLike defaultModule,
2474  ArrayRef<std::pair<OptionCaseOp, FModuleLike>> cases, StringRef name,
2475  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2476  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2477  // Gather the result types.
2478  SmallVector<Type> resultTypes;
2479  for (Attribute portType : defaultModule.getPortTypes())
2480  resultTypes.push_back(cast<TypeAttr>(portType).getValue());
2481 
2482  // Create the port annotations.
2483  ArrayAttr portAnnotationsAttr;
2484  if (portAnnotations.empty()) {
2485  portAnnotationsAttr = builder.getArrayAttr(SmallVector<Attribute, 16>(
2486  resultTypes.size(), builder.getArrayAttr({})));
2487  } else {
2488  portAnnotationsAttr = builder.getArrayAttr(portAnnotations);
2489  }
2490 
2491  // Gather the module & case names.
2492  SmallVector<Attribute> moduleNames, caseNames;
2493  moduleNames.push_back(SymbolRefAttr::get(defaultModule.getModuleNameAttr()));
2494  for (auto [caseOption, caseModule] : cases) {
2495  auto caseGroup = caseOption->getParentOfType<OptionOp>();
2496  caseNames.push_back(SymbolRefAttr::get(caseGroup.getSymNameAttr(),
2497  {SymbolRefAttr::get(caseOption)}));
2498  moduleNames.push_back(SymbolRefAttr::get(caseModule.getModuleNameAttr()));
2499  }
2500 
2501  return build(builder, result, resultTypes, builder.getArrayAttr(moduleNames),
2502  builder.getArrayAttr(caseNames), builder.getStringAttr(name),
2503  NameKindEnumAttr::get(builder.getContext(), nameKind),
2504  defaultModule.getPortDirectionsAttr(),
2505  defaultModule.getPortNamesAttr(),
2506  builder.getArrayAttr(annotations), portAnnotationsAttr,
2507  defaultModule.getLayersAttr(),
2508  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2509 }
2510 
2511 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
2512  return std::nullopt;
2513 }
2514 
/// Print the custom assembly form:
///   <name> [sym @sym] [<nameKind>] [attr-dict] @default
///     alternatives @option { @case -> @target, ... } <ports>
/// The name kind is omitted when it is `DroppableName`; `annotations` and
/// `layers` are omitted from the attr-dict when they are empty.
void InstanceChoiceOp::print(OpAsmPrinter &p) {
  // Print the instance name.
  p << " ";
  p.printKeywordOrString(getName());
  // NOTE(review): only the symbol name of the inner symbol is printed here,
  // while parse() accepts a full inner-symbol attribute — confirm that no
  // other inner-symbol fields can be present on this op.
  if (auto attr = getInnerSymAttr()) {
    p << " sym ";
    p.printSymbolName(attr.getSymName());
  }
  if (getNameKindAttr().getValue() != NameKindEnum::DroppableName)
    p << ' ' << stringifyNameKindEnum(getNameKindAttr().getValue());

  // Print the attr-dict, skipping everything the custom syntax already shows.
  SmallVector<StringRef, 10> omittedAttrs = {
      "moduleNames",     "caseNames", "name",
      "portDirections",  "portNames", "portTypes",
      "portAnnotations", "inner_sym", "nameKind"};
  if (getAnnotations().empty())
    omittedAttrs.push_back("annotations");
  if (getLayers().empty())
    omittedAttrs.push_back("layers");
  p.printOptionalAttrDict((*this)->getAttrs(), omittedAttrs);

  // Print the module name.
  p << ' ';

  auto moduleNames = getModuleNamesAttr();
  auto caseNames = getCaseNamesAttr();

  // moduleNames[0] is the default target.
  p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[0]).getValue());

  // All cases share one option group (checked by verifySymbolUses), so the
  // option is printed once from the root of the first case name.
  p << " alternatives ";
  p.printSymbolName(
      cast<SymbolRefAttr>(caseNames[0]).getRootReference().getValue());
  p << " { ";
  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
    if (i != 0)
      p << ", ";

    // Print "@case -> @target"; the target of case i is moduleNames[i + 1].
    auto symbol = cast<SymbolRefAttr>(caseNames[i]);
    p.printSymbolName(symbol.getNestedReferences()[0].getValue());
    p << " -> ";
    p.printSymbolName(cast<FlatSymbolRefAttr>(moduleNames[i + 1]).getValue());
  }

  p << " } ";

  // Collect all the result types as TypeAttrs for printing.
  SmallVector<Attribute> portTypes;
  portTypes.reserve(getNumResults());
  llvm::transform(getResultTypes(), std::back_inserter(portTypes),
                  &TypeAttr::get);
  printModulePorts(p, /*block=*/nullptr, getPortDirectionsAttr(),
                   getPortNames().getValue(), portTypes,
                   getPortAnnotations().getValue(), {}, {});
}
2570 
2571 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
2572  OperationState &result) {
2573  auto *context = parser.getContext();
2574  auto &resultAttrs = result.attributes;
2575 
2576  std::string name;
2577  hw::InnerSymAttr innerSymAttr;
2578  SmallVector<Attribute> moduleNames;
2579  SmallVector<Attribute> caseNames;
2580  SmallVector<OpAsmParser::Argument> entryArgs;
2581  SmallVector<Direction, 4> portDirections;
2582  SmallVector<Attribute, 4> portNames;
2583  SmallVector<Attribute, 4> portTypes;
2584  SmallVector<Attribute, 4> portAnnotations;
2585  SmallVector<Attribute, 4> portSyms;
2586  SmallVector<Attribute, 4> portLocs;
2587  NameKindEnumAttr nameKind;
2588 
2589  if (parser.parseKeywordOrString(&name))
2590  return failure();
2591  if (succeeded(parser.parseOptionalKeyword("sym"))) {
2592  if (parser.parseCustomAttributeWithFallback(
2593  innerSymAttr, Type{},
2594  hw::InnerSymbolTable::getInnerSymbolAttrName(),
2595  result.attributes)) {
2596  return failure();
2597  }
2598  }
2599  if (parseNameKind(parser, nameKind) ||
2600  parser.parseOptionalAttrDict(result.attributes))
2601  return failure();
2602 
2603  FlatSymbolRefAttr defaultModuleName;
2604  if (parser.parseAttribute(defaultModuleName))
2605  return failure();
2606  moduleNames.push_back(defaultModuleName);
2607 
2608  // alternatives { @opt::@case -> @target, ... }
2609  {
2610  FlatSymbolRefAttr optionName;
2611  if (parser.parseKeyword("alternatives") ||
2612  parser.parseAttribute(optionName) || parser.parseLBrace())
2613  return failure();
2614 
2615  FlatSymbolRefAttr moduleName;
2616  StringAttr caseName;
2617  while (succeeded(parser.parseOptionalSymbolName(caseName))) {
2618  if (parser.parseArrow() || parser.parseAttribute(moduleName))
2619  return failure();
2620  moduleNames.push_back(moduleName);
2621  caseNames.push_back(SymbolRefAttr::get(
2622  optionName.getAttr(), {FlatSymbolRefAttr::get(caseName)}));
2623  if (failed(parser.parseOptionalComma()))
2624  break;
2625  }
2626  if (parser.parseRBrace())
2627  return failure();
2628  }
2629 
2630  if (parseModulePorts(parser, /*hasSSAIdentifiers=*/false,
2631  /*supportsSymbols=*/false, entryArgs, portDirections,
2632  portNames, portTypes, portAnnotations, portSyms,
2633  portLocs))
2634  return failure();
2635 
2636  // Add the attributes. We let attributes defined in the attr-dict override
2637  // attributes parsed out of the module signature.
2638  if (!resultAttrs.get("moduleNames"))
2639  result.addAttribute("moduleNames", ArrayAttr::get(context, moduleNames));
2640  if (!resultAttrs.get("caseNames"))
2641  result.addAttribute("caseNames", ArrayAttr::get(context, caseNames));
2642  if (!resultAttrs.get("name"))
2643  result.addAttribute("name", StringAttr::get(context, name));
2644  result.addAttribute("nameKind", nameKind);
2645  if (!resultAttrs.get("portDirections"))
2646  result.addAttribute("portDirections",
2647  direction::packAttribute(context, portDirections));
2648  if (!resultAttrs.get("portNames"))
2649  result.addAttribute("portNames", ArrayAttr::get(context, portNames));
2650  if (!resultAttrs.get("portAnnotations"))
2651  result.addAttribute("portAnnotations",
2652  ArrayAttr::get(context, portAnnotations));
2653 
2654  // Annotations, layers, and LowerToBind are omitted in the printed format if
2655  // they are empty, empty, and false (respectively).
2656  if (!resultAttrs.get("annotations"))
2657  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
2658  if (!resultAttrs.get("layers"))
2659  resultAttrs.append("layers", parser.getBuilder().getArrayAttr({}));
2660 
2661  // Add result types.
2662  result.types.reserve(portTypes.size());
2663  llvm::transform(
2664  portTypes, std::back_inserter(result.types),
2665  [](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
2666 
2667  return success();
2668 }
2669 
  // Name each result "<instance name>_<port name>", using "inst" as the
  // prefix when the instance has no name.
  StringRef base = getName().empty() ? "inst" : getName();
  for (auto [result, name] : llvm::zip(getResults(), getPortNames()))
    setNameFn(result, (base + "_" + cast<StringAttr>(name).getValue()).str());
}
2675 
2676 LogicalResult InstanceChoiceOp::verify() {
2677  if (getCaseNamesAttr().empty())
2678  return emitOpError() << "must have at least one case";
2679  if (getModuleNamesAttr().size() != getCaseNamesAttr().size() + 1)
2680  return emitOpError() << "number of referenced modules does not match the "
2681  "number of options";
2682 
2683  // The modules may only be instantiated under their required layers (which
2684  // are the same for all modules).
2685  auto ambientLayers = getAmbientLayersAt(getOperation());
2686  SmallVector<SymbolRefAttr> missingLayers;
2687  for (auto layer : getLayersAttr().getAsRange<SymbolRefAttr>())
2688  if (!isLayerCompatibleWith(layer, ambientLayers))
2689  missingLayers.push_back(layer);
2690 
2691  if (missingLayers.empty())
2692  return success();
2693 
2694  auto diag =
2695  emitOpError("ambient layers are insufficient to instantiate module");
2696  auto &note = diag.attachNote();
2697  note << "missing layer requirements: ";
2698  interleaveComma(missingLayers, note);
2699  return failure();
2700 }
2701 
/// Symbol-table verification for instance_choice.
///
/// Runs a check against the symbol table for every entry of `moduleNames`.
/// Then, for every case name, it checks that:
///  - the case is in the same option group as the first case;
///  - the option exists;
///  - the option contains that case.
LogicalResult
InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto caseNames = getCaseNamesAttr();
  for (auto moduleName : getModuleNamesAttr()) {
        *this, symbolTable, cast<FlatSymbolRefAttr>(moduleName))))
      return failure();
  }

  // Every case must refer to the same option as the first one.
  auto root = cast<SymbolRefAttr>(caseNames[0]).getRootReference();
  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
    auto ref = cast<SymbolRefAttr>(caseNames[i]);
    auto refRoot = ref.getRootReference();
    if (ref.getRootReference() != root)
      return emitOpError() << "case " << ref
                           << " is not in the same option group as "
                           << caseNames[0];

    // The option (root reference) must resolve to an OptionOp...
    if (!symbolTable.lookupNearestSymbolFrom<OptionOp>(*this, refRoot))
      return emitOpError() << "option " << refRoot << " does not exist";

    // ...and the full @option::@case reference must resolve to a case in it.
    if (!symbolTable.lookupNearestSymbolFrom<OptionCaseOp>(*this, ref))
      return emitOpError() << "option " << refRoot
                           << " does not contain option case " << ref;
  }

  return success();
}
2730 
2731 FlatSymbolRefAttr
2732 InstanceChoiceOp::getTargetOrDefaultAttr(OptionCaseOp option) {
2733  auto caseNames = getCaseNamesAttr();
2734  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
2735  StringAttr caseSym = cast<SymbolRefAttr>(caseNames[i]).getLeafReference();
2736  if (caseSym == option.getSymName())
2737  return cast<FlatSymbolRefAttr>(getModuleNamesAttr()[i + 1]);
2738  }
2739  return getDefaultTargetAttr();
2740 }
2741 
2742 SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1>
2743 InstanceChoiceOp::getTargetChoices() {
2744  auto caseNames = getCaseNamesAttr();
2745  auto moduleNames = getModuleNamesAttr();
2746  SmallVector<std::pair<SymbolRefAttr, FlatSymbolRefAttr>, 1> choices;
2747  for (size_t i = 0; i < caseNames.size(); ++i) {
2748  choices.emplace_back(cast<SymbolRefAttr>(caseNames[i]),
2749  cast<FlatSymbolRefAttr>(moduleNames[i + 1]));
2750  }
2751 
2752  return choices;
2753 }
2754 
/// Drop the ports whose bits are set in `portIndices`.
///
/// Returns this op unchanged when no bit is set. Otherwise, creates a new
/// InstanceChoiceOp with the remaining ports, redirects the uses of every
/// kept result to it, and returns the new op. Removed ports must have no uses
/// (asserted). This op is not erased here; presumably the caller does that —
/// confirm at call sites.
InstanceChoiceOp
    const llvm::BitVector &portIndices) {
  assert(portIndices.size() >= getNumResults() &&
         "portIndices is not at least as large as getNumResults()");

  if (portIndices.none())
    return *this;

  // Filter every per-port array with the same index set.
  SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
      SmallVector<Type>(result_type_begin(), result_type_end()), portIndices);
  SmallVector<Direction> newPortDirections = removeElementsAtIndices<Direction>(
      direction::unpackAttribute(getPortDirectionsAttr()), portIndices);
  SmallVector<Attribute> newPortNames =
      removeElementsAtIndices(getPortNames().getValue(), portIndices);
  SmallVector<Attribute> newPortAnnotations =
      removeElementsAtIndices(getPortAnnotations().getValue(), portIndices);

  auto newOp = builder.create<InstanceChoiceOp>(
      getLoc(), newResultTypes, getModuleNames(), getCaseNames(), getName(),
      getNameKind(), direction::packAttribute(getContext(), newPortDirections),
      ArrayAttr::get(getContext(), newPortNames), getAnnotationsAttr(),
      ArrayAttr::get(getContext(), newPortAnnotations), getLayers(),
      getInnerSymAttr());

  // Walk old results; newIdx only advances for ports that are kept.
  for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
       oldIdx != numOldPorts; ++oldIdx) {
    if (portIndices.test(oldIdx)) {
      assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
      continue;
    }
    getResult(oldIdx).replaceAllUsesWith(newOp.getResult(newIdx));
    ++newIdx;
  }

  // Copy over "output_file" information so that this is not lost when ports
  // are erased.
  //
  // TODO: Other attributes may need to be copied over.
  if (auto outputFile = (*this)->getAttr("output_file"))
    newOp->setAttr("output_file", outputFile);

  return newOp;
}
2799 
2800 //===----------------------------------------------------------------------===//
2801 // MemOp
2802 //===----------------------------------------------------------------------===//
2803 
2804 void MemOp::build(OpBuilder &builder, OperationState &result,
2805  TypeRange resultTypes, uint32_t readLatency,
2806  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2807  ArrayRef<Attribute> portNames, StringRef name,
2808  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2809  ArrayRef<Attribute> portAnnotations,
2810  hw::InnerSymAttr innerSym) {
2811  result.addAttribute(
2812  "readLatency",
2813  builder.getIntegerAttr(builder.getIntegerType(32), readLatency));
2814  result.addAttribute(
2815  "writeLatency",
2816  builder.getIntegerAttr(builder.getIntegerType(32), writeLatency));
2817  result.addAttribute(
2818  "depth", builder.getIntegerAttr(builder.getIntegerType(64), depth));
2819  result.addAttribute("ruw", ::RUWAttrAttr::get(builder.getContext(), ruw));
2820  result.addAttribute("portNames", builder.getArrayAttr(portNames));
2821  result.addAttribute("name", builder.getStringAttr(name));
2822  result.addAttribute("nameKind",
2823  NameKindEnumAttr::get(builder.getContext(), nameKind));
2824  result.addAttribute("annotations", builder.getArrayAttr(annotations));
2825  if (innerSym)
2826  result.addAttribute("inner_sym", innerSym);
2827  result.addTypes(resultTypes);
2828 
2829  if (portAnnotations.empty()) {
2830  SmallVector<Attribute, 16> portAnnotationsVec(resultTypes.size(),
2831  builder.getArrayAttr({}));
2832  result.addAttribute("portAnnotations",
2833  builder.getArrayAttr(portAnnotationsVec));
2834  } else {
2835  assert(portAnnotations.size() == resultTypes.size());
2836  result.addAttribute("portAnnotations",
2837  builder.getArrayAttr(portAnnotations));
2838  }
2839 }
2840 
2841 void MemOp::build(OpBuilder &builder, OperationState &result,
2842  TypeRange resultTypes, uint32_t readLatency,
2843  uint32_t writeLatency, uint64_t depth, RUWAttr ruw,
2844  ArrayRef<Attribute> portNames, StringRef name,
2845  NameKindEnum nameKind, ArrayRef<Attribute> annotations,
2846  ArrayRef<Attribute> portAnnotations, StringAttr innerSym) {
2847  build(builder, result, resultTypes, readLatency, writeLatency, depth, ruw,
2848  portNames, name, nameKind, annotations, portAnnotations,
2849  innerSym ? hw::InnerSymAttr::get(innerSym) : hw::InnerSymAttr());
2850 }
2851 
2852 ArrayAttr MemOp::getPortAnnotation(unsigned portIdx) {
2853  assert(portIdx < getNumResults() &&
2854  "index should be smaller than result number");
2855  return cast<ArrayAttr>(getPortAnnotations()[portIdx]);
2856 }
2857 
2858 void MemOp::setAllPortAnnotations(ArrayRef<Attribute> annotations) {
2859  assert(annotations.size() == getNumResults() &&
2860  "number of annotations is not equal to result number");
2861  (*this)->setAttr("portAnnotations",
2862  ArrayAttr::get(getContext(), annotations));
2863 }
2864 
2865 // Get the number of read, write and read-write ports.
2866 void MemOp::getNumPorts(size_t &numReadPorts, size_t &numWritePorts,
2867  size_t &numReadWritePorts, size_t &numDbgsPorts) {
2868  numReadPorts = 0;
2869  numWritePorts = 0;
2870  numReadWritePorts = 0;
2871  numDbgsPorts = 0;
2872  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2873  auto portKind = getPortKind(i);
2874  if (portKind == MemOp::PortKind::Debug)
2875  ++numDbgsPorts;
2876  else if (portKind == MemOp::PortKind::Read)
2877  ++numReadPorts;
2878  else if (portKind == MemOp::PortKind::Write) {
2879  ++numWritePorts;
2880  } else
2881  ++numReadWritePorts;
2882  }
2883 }
2884 
2885 /// Verify the correctness of a MemOp.
2886 LogicalResult MemOp::verify() {
2887 
2888  // Store the port names as we find them. This lets us check quickly
2889  // for uniqueneess.
2890  llvm::SmallDenseSet<Attribute, 8> portNamesSet;
2891 
2892  // Store the previous data type. This lets us check that the data
2893  // type is consistent across all ports.
2894  FIRRTLType oldDataType;
2895 
2896  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
2897  auto portName = getPortName(i);
2898 
2899  // Get a bundle type representing this port, stripping an outer
2900  // flip if it exists. If this is not a bundle<> or
2901  // flip<bundle<>>, then this is an error.
2902  BundleType portBundleType =
2903  type_dyn_cast<BundleType>(getResult(i).getType());
2904 
2905  // Require that all port names are unique.
2906  if (!portNamesSet.insert(portName).second) {
2907  emitOpError() << "has non-unique port name " << portName;
2908  return failure();
2909  }
2910 
2911  // Determine the kind of the memory. If the kind cannot be
2912  // determined, then it's indicative of the wrong number of fields
2913  // in the type (but we don't know any more just yet).
2914 
2915  auto elt = getPortNamed(portName);
2916  if (!elt) {
2917  emitOpError() << "could not get port with name " << portName;
2918  return failure();
2919  }
2920  auto firrtlType = type_cast<FIRRTLType>(elt.getType());
2921  MemOp::PortKind portKind = getMemPortKindFromType(firrtlType);
2922 
2923  if (portKind == MemOp::PortKind::Debug &&
2924  !type_isa<RefType>(getResult(i).getType()))
2925  return emitOpError() << "has an invalid type on port " << portName
2926  << " (expected Read/Write/ReadWrite/Debug)";
2927  if (type_isa<RefType>(firrtlType) && e == 1)
2928  return emitOpError()
2929  << "cannot have only one port of debug type. Debug port can only "
2930  "exist alongside other read/write/read-write port";
2931 
2932  // Safely search for the "data" field, erroring if it can't be
2933  // found.
2934  FIRRTLBaseType dataType;
2935  if (portKind == MemOp::PortKind::Debug) {
2936  auto resType = type_cast<RefType>(getResult(i).getType());
2937  if (!(resType && type_isa<FVectorType>(resType.getType())))
2938  return emitOpError() << "debug ports must be a RefType of FVectorType";
2939  dataType = type_cast<FVectorType>(resType.getType()).getElementType();
2940  } else {
2941  auto dataTypeOption = portBundleType.getElement("data");
2942  if (!dataTypeOption && portKind == MemOp::PortKind::ReadWrite)
2943  dataTypeOption = portBundleType.getElement("wdata");
2944  if (!dataTypeOption) {
2945  emitOpError() << "has no data field on port " << portName
2946  << " (expected to see \"data\" for a read or write "
2947  "port or \"rdata\" for a read/write port)";
2948  return failure();
2949  }
2950  dataType = dataTypeOption->type;
2951  // Read data is expected to ba a flip.
2952  if (portKind == MemOp::PortKind::Read) {
2953  // FIXME error on missing bundle flip
2954  }
2955  }
2956 
2957  // Error if the data type isn't passive.
2958  if (!dataType.isPassive()) {
2959  emitOpError() << "has non-passive data type on port " << portName
2960  << " (memory types must be passive)";
2961  return failure();
2962  }
2963 
2964  // Error if the data type contains analog types.
2965  if (dataType.containsAnalog()) {
2966  emitOpError() << "has a data type that contains an analog type on port "
2967  << portName
2968  << " (memory types cannot contain analog types)";
2969  return failure();
2970  }
2971 
2972  // Check that the port type matches the kind that we determined
2973  // for this port. This catches situations of extraneous port
2974  // fields beind included or the fields being named incorrectly.
2975  FIRRTLType expectedType =
2976  getTypeForPort(getDepth(), dataType, portKind,
2977  dataType.isGround() ? getMaskBits() : 0);
2978  // Compute the original port type as portBundleType may have
2979  // stripped outer flip information.
2980  auto originalType = getResult(i).getType();
2981  if (originalType != expectedType) {
2982  StringRef portKindName;
2983  switch (portKind) {
2984  case MemOp::PortKind::Read:
2985  portKindName = "read";
2986  break;
2987  case MemOp::PortKind::Write:
2988  portKindName = "write";
2989  break;
2990  case MemOp::PortKind::ReadWrite:
2991  portKindName = "readwrite";
2992  break;
2993  case MemOp::PortKind::Debug:
2994  portKindName = "dbg";
2995  break;
2996  }
2997  emitOpError() << "has an invalid type for port " << portName
2998  << " of determined kind \"" << portKindName
2999  << "\" (expected " << expectedType << ", but got "
3000  << originalType << ")";
3001  return failure();
3002  }
3003 
3004  // Error if the type of the current port was not the same as the
3005  // last port, but skip checking the first port.
3006  if (oldDataType && oldDataType != dataType) {
3007  emitOpError() << "port " << getPortName(i)
3008  << " has a different type than port " << getPortName(i - 1)
3009  << " (expected " << oldDataType << ", but got " << dataType
3010  << ")";
3011  return failure();
3012  }
3013 
3014  oldDataType = dataType;
3015  }
3016 
3017  auto maskWidth = getMaskBits();
3018 
3019  auto dataWidth = getDataType().getBitWidthOrSentinel();
3020  if (dataWidth > 0 && maskWidth > (size_t)dataWidth)
3021  return emitOpError("the mask width cannot be greater than "
3022  "data width");
3023 
3024  if (getPortAnnotations().size() != getNumResults())
3025  return emitOpError("the number of result annotations should be "
3026  "equal to the number of results");
3027 
3028  return success();
3029 }
3030 
3031 static size_t getAddressWidth(size_t depth) {
3032  return std::max(1U, llvm::Log2_64_Ceil(depth));
3033 }
3034 
3035 size_t MemOp::getAddrBits() { return getAddressWidth(getDepth()); }
3036 
3037 FIRRTLType MemOp::getTypeForPort(uint64_t depth, FIRRTLBaseType dataType,
3038  PortKind portKind, size_t maskBits) {
3039 
3040  auto *context = dataType.getContext();
3041  if (portKind == PortKind::Debug)
3042  return RefType::get(FVectorType::get(dataType, depth));
3043  FIRRTLBaseType maskType;
3044  // maskBits not specified (==0), then get the mask type from the dataType.
3045  if (maskBits == 0)
3046  maskType = dataType.getMaskType();
3047  else
3048  maskType = UIntType::get(context, maskBits);
3049 
3050  auto getId = [&](StringRef name) -> StringAttr {
3051  return StringAttr::get(context, name);
3052  };
3053 
3054  SmallVector<BundleType::BundleElement, 7> portFields;
3055 
3056  auto addressType = UIntType::get(context, getAddressWidth(depth));
3057 
3058  portFields.push_back({getId("addr"), false, addressType});
3059  portFields.push_back({getId("en"), false, UIntType::get(context, 1)});
3060  portFields.push_back({getId("clk"), false, ClockType::get(context)});
3061 
3062  switch (portKind) {
3063  case PortKind::Read:
3064  portFields.push_back({getId("data"), true, dataType});
3065  break;
3066 
3067  case PortKind::Write:
3068  portFields.push_back({getId("data"), false, dataType});
3069  portFields.push_back({getId("mask"), false, maskType});
3070  break;
3071 
3072  case PortKind::ReadWrite:
3073  portFields.push_back({getId("rdata"), true, dataType});
3074  portFields.push_back({getId("wmode"), false, UIntType::get(context, 1)});
3075  portFields.push_back({getId("wdata"), false, dataType});
3076  portFields.push_back({getId("wmask"), false, maskType});
3077  break;
3078  default:
3079  llvm::report_fatal_error("memory port kind not handled");
3080  break;
3081  }
3082 
3083  return BundleType::get(context, portFields);
3084 }
3085 
3086 /// Return the name and kind of ports supported by this memory.
3087 SmallVector<MemOp::NamedPort> MemOp::getPorts() {
3088  SmallVector<MemOp::NamedPort> result;
3089  // Each entry in the bundle is a port.
3090  for (size_t i = 0, e = getNumResults(); i != e; ++i) {
3091  // Each port is a bundle.
3092  auto portType = type_cast<FIRRTLType>(getResult(i).getType());
3093  result.push_back({getPortName(i), getMemPortKindFromType(portType)});
3094  }
3095  return result;
3096 }
3097 
3098 /// Return the kind of the specified port.
3099 MemOp::PortKind MemOp::getPortKind(StringRef portName) {
3100  return getMemPortKindFromType(
3101  type_cast<FIRRTLType>(getPortNamed(portName).getType()));
3102 }
3103 
3104 /// Return the kind of the specified port number.
3105 MemOp::PortKind MemOp::getPortKind(size_t resultNo) {
3106  return getMemPortKindFromType(
3107  type_cast<FIRRTLType>(getResult(resultNo).getType()));
3108 }
3109 
3110 /// Return the number of bits in the mask for the memory.
3111 size_t MemOp::getMaskBits() {
3112 
3113  for (auto res : getResults()) {
3114  if (type_isa<RefType>(res.getType()))
3115  continue;
3116  auto firstPortType = type_cast<FIRRTLBaseType>(res.getType());
3117  if (getMemPortKindFromType(firstPortType) == PortKind::Read ||
3118  getMemPortKindFromType(firstPortType) == PortKind::Debug)
3119  continue;
3120 
3121  FIRRTLBaseType mType;
3122  for (auto t : type_cast<BundleType>(firstPortType.getPassiveType())) {
3123  if (t.name.getValue().contains("mask"))
3124  mType = t.type;
3125  }
3126  if (type_isa<UIntType>(mType))
3127  return mType.getBitWidthOrSentinel();
3128  }
3129  // Mask of zero bits means, either there are no write/readwrite ports or the
3130  // mask is of aggregate type.
3131  return 0;
3132 }
3133 
3134 /// Return the data-type field of the memory, the type of each element.
3135 FIRRTLBaseType MemOp::getDataType() {
3136  assert(getNumResults() != 0 && "Mems with no read/write ports are illegal");
3137 
3138  if (auto refType = type_dyn_cast<RefType>(getResult(0).getType()))
3139  return type_cast<FVectorType>(refType.getType()).getElementType();
3140  auto firstPortType = type_cast<FIRRTLBaseType>(getResult(0).getType());
3141 
3142  StringRef dataFieldName = "data";
3143  if (getMemPortKindFromType(firstPortType) == PortKind::ReadWrite)
3144  dataFieldName = "rdata";
3145 
3146  return type_cast<BundleType>(firstPortType.getPassiveType())
3147  .getElementType(dataFieldName);
3148 }
3149 
3150 StringAttr MemOp::getPortName(size_t resultNo) {
3151  return cast<StringAttr>(getPortNames()[resultNo]);
3152 }
3153 
3154 FIRRTLBaseType MemOp::getPortType(size_t resultNo) {
3155  return type_cast<FIRRTLBaseType>(getResults()[resultNo].getType());
3156 }
3157 
3158 Value MemOp::getPortNamed(StringAttr name) {
3159  auto namesArray = getPortNames();
3160  for (size_t i = 0, e = namesArray.size(); i != e; ++i) {
3161  if (namesArray[i] == name) {
3162  assert(i < getNumResults() && " names array out of sync with results");
3163  return getResult(i);
3164  }
3165  }
3166  return Value();
3167 }
3168 
// Extract all the relevant attributes from the MemOp and summarize them as a
// FirMemory: port counts, data width, depth, latencies, mask width, RUW/WUW
// behavior, write-clock grouping, module name, init, prefix and location.
  auto op = *this;
  size_t numReadPorts = 0;
  size_t numWritePorts = 0;
  size_t numReadWritePorts = 0;
  SmallVector<int32_t> writeClockIDs;

  // Count ports by kind. For each write port, also record the clock that
  // drives it: writeClockIDs gets one entry per recorded clock connection.
  // The entry holds the index of the first write port seen with the same
  // clock driver (tracked in clockToLeader), so ports sharing a clock share an
  // ID.
  for (size_t i = 0, e = op.getNumResults(); i != e; ++i) {
    auto portKind = op.getPortKind(i);
    if (portKind == MemOp::PortKind::Read)
      ++numReadPorts;
    else if (portKind == MemOp::PortKind::Write) {
      for (auto *a : op.getResult(i).getUsers()) {
        // Field index 2 is "clk" (see the field order in getTypeForPort).
        auto subfield = dyn_cast<SubfieldOp>(a);
        if (!subfield || subfield.getFieldIndex() != 2)
          continue;
        auto clockPort = a->getResult(0);
        for (auto *b : clockPort.getUsers()) {
          if (auto connect = dyn_cast<FConnectLike>(b)) {
            if (connect.getDest() == clockPort) {
              auto result =
                  clockToLeader.insert({circt::firrtl::getModuleScopedDriver(
                                            connect.getSrc(), true, true, true),
                                        numWritePorts});
              if (result.second) {
                writeClockIDs.push_back(numWritePorts);
              } else {
                writeClockIDs.push_back(result.first->second);
              }
            }
          }
        }
        // Only the first "clk" subfield of this port is examined.
        break;
      }
      ++numWritePorts;
    } else
      // NOTE(review): every other port kind, including Debug, is counted as
      // read-write here.
      ++numReadWritePorts;
  }

  // The data type must have a known bit width. Otherwise an error is emitted,
  // but the summary is still produced with a width of 0.
  size_t width = 0;
  if (auto widthV = getBitWidth(op.getDataType()))
    width = *widthV;
  else
    op.emitError("'firrtl.mem' should have simple type and known width");
  MemoryInitAttr init = op->getAttrOfType<MemoryInitAttr>("init");
  StringAttr modName;
  // An explicit "modName" attribute wins; otherwise synthesize a name that
  // encodes the memory's configuration.
  if (op->hasAttr("modName"))
    modName = op->getAttrOfType<StringAttr>("modName");
  else {
    // Encode each write clock ID as a letter: 0 -> 'a', 1 -> 'b', ...
    SmallString<8> clocks;
    for (auto a : writeClockIDs)
      clocks.append(Twine((char)(a + 'a')).str());
    SmallString<32> initStr;
    // If there is a file initialization, then come up with a decent
    // representation for this. Use the filename, but only characters
    // [a-zA-Z0-9] and the bool/hex and inline booleans.
    if (init) {
      for (auto c : init.getFilename().getValue())
        if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
            (c >= '0' && c <= '9'))
          initStr.push_back(c);
      initStr.push_back('_');
      initStr.push_back(init.getIsBinary() ? 't' : 'f');
      initStr.push_back('_');
      initStr.push_back(init.getIsInline() ? 't' : 'f');
    }
    modName = StringAttr::get(
        op->getContext(),
        llvm::formatv(
            "{0}FIRRTLMem_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}_{9}_{10}{11}{12}",
            op.getPrefix().value_or(""), numReadPorts, numWritePorts,
            numReadWritePorts, (size_t)width, op.getDepth(),
            op.getReadLatency(), op.getWriteLatency(), op.getMaskBits(),
            (unsigned)op.getRuw(), (unsigned)seq::WUW::PortOrder,
            clocks.empty() ? "" : "_" + clocks, init ? initStr.str() : ""));
  }
  // The mask is considered enabled only when it has more than one bit.
  return {numReadPorts,
          numWritePorts,
          numReadWritePorts,
          (size_t)width,
          op.getDepth(),
          op.getReadLatency(),
          op.getWriteLatency(),
          op.getMaskBits(),
          *seq::symbolizeRUW(unsigned(op.getRuw())),
          seq::WUW::PortOrder,
          writeClockIDs,
          modName,
          op.getMaskBits() > 1,
          init,
          op.getPrefixAttr(),
          op.getLoc()};
}
3264 
  // Name each result "<memory name>_<port name>", using "mem" as the prefix
  // when the memory has no name.
  StringRef base = getName();
  if (base.empty())
    base = "mem";

  for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) {
    setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str());
  }
}
3274 
3275 std::optional<size_t> MemOp::getTargetResultIndex() {
3276  // Inner symbols on memory operations target the op not any result.
3277  return std::nullopt;
3278 }
3279 
3280 // Construct name of the module which will be used for the memory definition.
3281 StringAttr FirMemory::getFirMemoryName() const { return modName; }
3282 
3283 /// Helper for naming forceable declarations (and their optional ref result).
3284 static void forceableAsmResultNames(Forceable op, StringRef name,
3285  OpAsmSetValueNameFn setNameFn) {
3286  if (name.empty())
3287  return;
3288  setNameFn(op.getDataRaw(), name);
3289  if (op.isForceable())
3290  setNameFn(op.getDataRef(), (name + "_ref").str());
3291 }
3292 
  // Name the data result (and the ref result, if forceable) after getName().
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3296 
3297 LogicalResult NodeOp::inferReturnTypes(
3298  mlir::MLIRContext *context, std::optional<mlir::Location> location,
3299  ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
3300  ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
3301  ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
3302  if (operands.empty())
3303  return failure();
3304  inferredReturnTypes.push_back(operands[0].getType());
3305  for (auto &attr : attributes)
3306  if (attr.getName() == Forceable::getForceableAttrName()) {
3307  auto forceableType =
3308  firrtl::detail::getForceableResultType(true, operands[0].getType());
3309  if (!forceableType) {
3310  if (location)
3311  ::mlir::emitError(*location, "cannot force a node of type ")
3312  << operands[0].getType();
3313  return failure();
3314  }
3315  inferredReturnTypes.push_back(forceableType);
3316  }
3317  return success();
3318 }
3319 
3320 std::optional<size_t> NodeOp::getTargetResultIndex() { return 0; }
3321 
  // Name the data result (and the ref result, if forceable) after getName().
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3325 
3326 std::optional<size_t> RegOp::getTargetResultIndex() { return 0; }
3327 
3328 LogicalResult RegResetOp::verify() {
3329  auto reset = getResetValue();
3330 
3331  FIRRTLBaseType resetType = reset.getType();
3332  FIRRTLBaseType regType = getResult().getType();
3333 
3334  // The type of the initialiser must be equivalent to the register type.
3335  if (!areTypesEquivalent(regType, resetType))
3336  return emitError("type mismatch between register ")
3337  << regType << " and reset value " << resetType;
3338 
3339  return success();
3340 }
3341 
3342 std::optional<size_t> RegResetOp::getTargetResultIndex() { return 0; }
3343 
  // Name the data result (and the ref result, if forceable) after getName().
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3347 
  // Name the data result (and the ref result, if forceable) after getName().
  return forceableAsmResultNames(*this, getName(), setNameFn);
}
3351 
3352 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
3353 
3354 LogicalResult WireOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3355  auto refType = type_dyn_cast<RefType>(getType(0));
3356  if (!refType)
3357  return success();
3358 
3359  return verifyProbeType(
3360  refType, getLoc(), getOperation()->getParentOfType<CircuitOp>(),
3361  symbolTable, Twine("'") + getOperationName() + "' op is");
3362 }
3363 
3364 void ObjectOp::build(OpBuilder &builder, OperationState &state, ClassLike klass,
3365  StringRef name) {
3366  build(builder, state, klass.getInstanceType(),
3367  StringAttr::get(builder.getContext(), name));
3368 }
3369 
3370 LogicalResult ObjectOp::verify() { return success(); }
3371 
3372 LogicalResult ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3373  auto circuitOp = getOperation()->getParentOfType<CircuitOp>();
3374  auto classType = getType();
3375  auto className = classType.getNameAttr();
3376 
3377  // verify that the class exists.
3378  auto classOp = dyn_cast_or_null<ClassLike>(
3379  symbolTable.lookupSymbolIn(circuitOp, className));
3380  if (!classOp)
3381  return emitOpError() << "references unknown class " << className;
3382 
3383  // verify that the result type agrees with the class definition.
3384  if (failed(classOp.verifyType(classType, [&]() { return emitOpError(); })))
3385  return failure();
3386 
3387  return success();
3388 }
3389 
3390 StringAttr ObjectOp::getClassNameAttr() {
3391  return getType().getNameAttr().getAttr();
3392 }
3393 
3394 StringRef ObjectOp::getClassName() { return getType().getName(); }
3395 
3396 ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
3397  auto symRef = getType().getNameAttr();
3398  return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
3399 }
3400 
3401 Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
3402  return getReferencedClass(symtbl);
3403 }
3404 
3405 StringRef ObjectOp::getInstanceName() { return getName(); }
3406 
3407 StringAttr ObjectOp::getInstanceNameAttr() { return getNameAttr(); }
3408 
3409 StringRef ObjectOp::getReferencedModuleName() { return getClassName(); }
3410 
3411 StringAttr ObjectOp::getReferencedModuleNameAttr() {
3412  return getClassNameAttr();
3413 }
3414 
3416  setNameFn(getResult(), getName());
3417 }
3418 
3419 //===----------------------------------------------------------------------===//
3420 // Statements
3421 //===----------------------------------------------------------------------===//
3422 
3423 LogicalResult AttachOp::verify() {
3424  // All known widths must match.
3425  std::optional<int32_t> commonWidth;
3426  for (auto operand : getOperands()) {
3427  auto thisWidth = type_cast<AnalogType>(operand.getType()).getWidth();
3428  if (!thisWidth)
3429  continue;
3430  if (!commonWidth) {
3431  commonWidth = thisWidth;
3432  continue;
3433  }
3434  if (commonWidth != thisWidth)
3435  return emitOpError("is inavlid as not all known operand widths match");
3436  }
3437  return success();
3438 }
3439 
3440 /// Check if the source and sink are of appropriate flow.
3441 static LogicalResult checkConnectFlow(Operation *connect) {
3442  Value dst = connect->getOperand(0);
3443  Value src = connect->getOperand(1);
3444 
3445  // TODO: Relax this to allow reads from output ports,
3446  // instance/memory input ports.
3447  auto srcFlow = foldFlow(src);
3448  if (!isValidSrc(srcFlow)) {
3449  // A sink that is a port output or instance input used as a source is okay.
3450  auto kind = getDeclarationKind(src);
3451  if (kind != DeclKind::Port && kind != DeclKind::Instance) {
3452  auto srcRef = getFieldRefFromValue(src, /*lookThroughCasts=*/true);
3453  auto [srcName, rootKnown] = getFieldName(srcRef);
3454  auto diag = emitError(connect->getLoc());
3455  diag << "connect has invalid flow: the source expression ";
3456  if (rootKnown)
3457  diag << "\"" << srcName << "\" ";
3458  diag << "has " << toString(srcFlow) << ", expected source or duplex flow";
3459  return diag.attachNote(srcRef.getLoc()) << "the source was defined here";
3460  }
3461  }
3462 
3463  auto dstFlow = foldFlow(dst);
3464  if (!isValidDst(dstFlow)) {
3465  auto dstRef = getFieldRefFromValue(dst, /*lookThroughCasts=*/true);
3466  auto [dstName, rootKnown] = getFieldName(dstRef);
3467  auto diag = emitError(connect->getLoc());
3468  diag << "connect has invalid flow: the destination expression ";
3469  if (rootKnown)
3470  diag << "\"" << dstName << "\" ";
3471  diag << "has " << toString(dstFlow) << ", expected sink or duplex flow";
3472  return diag.attachNote(dstRef.getLoc())
3473  << "the destination was defined here";
3474  }
3475  return success();
3476 }
3477 
3478 // NOLINTBEGIN(misc-no-recursion)
3479 /// Checks if the type has any 'const' leaf elements . If `isFlip` is `true`,
3480 /// the `const` leaf is not considered to be driven.
3481 static bool isConstFieldDriven(FIRRTLBaseType type, bool isFlip = false,
3482  bool outerTypeIsConst = false) {
3483  auto typeIsConst = outerTypeIsConst || type.isConst();
3484 
3485  if (typeIsConst && type.isPassive())
3486  return !isFlip;
3487 
3488  if (auto bundleType = type_dyn_cast<BundleType>(type))
3489  return llvm::any_of(bundleType.getElements(), [&](auto &element) {
3490  return isConstFieldDriven(element.type, isFlip ^ element.isFlip,
3491  typeIsConst);
3492  });
3493 
3494  if (auto vectorType = type_dyn_cast<FVectorType>(type))
3495  return isConstFieldDriven(vectorType.getElementType(), isFlip, typeIsConst);
3496 
3497  if (typeIsConst)
3498  return !isFlip;
3499  return false;
3500 }
3501 // NOLINTEND(misc-no-recursion)
3502 
3503 /// Checks that connections to 'const' destinations are not dependent on
3504 /// non-'const' conditions in when blocks.
3505 static LogicalResult checkConnectConditionality(FConnectLike connect) {
3506  auto dest = connect.getDest();
3507  auto destType = type_dyn_cast<FIRRTLBaseType>(dest.getType());
3508  auto src = connect.getSrc();
3509  auto srcType = type_dyn_cast<FIRRTLBaseType>(src.getType());
3510  if (!destType || !srcType)
3511  return success();
3512 
3513  auto destRefinedType = destType;
3514  auto srcRefinedType = srcType;
3515 
3516  /// Looks up the value's defining op until the defining op is null or a
3517  /// declaration of the value. If a SubAccessOp is encountered with a 'const'
3518  /// input, `originalFieldType` is made 'const'.
3519  auto findFieldDeclarationRefiningFieldType =
3520  [](Value value, FIRRTLBaseType &originalFieldType) -> Value {
3521  while (auto *definingOp = value.getDefiningOp()) {
3522  bool shouldContinue = true;
3523  TypeSwitch<Operation *>(definingOp)
3524  .Case<SubfieldOp, SubindexOp>([&](auto op) { value = op.getInput(); })
3525  .Case<SubaccessOp>([&](SubaccessOp op) {
3526  if (op.getInput()
3527  .getType()
3528  .base()
3529  .getElementTypePreservingConst()
3530  .isConst())
3531  originalFieldType = originalFieldType.getConstType(true);
3532  value = op.getInput();
3533  })
3534  .Default([&](Operation *) { shouldContinue = false; });
3535  if (!shouldContinue)
3536  break;
3537  }
3538  return value;
3539  };
3540 
3541  auto destDeclaration =
3542  findFieldDeclarationRefiningFieldType(dest, destRefinedType);
3543  auto srcDeclaration =
3544  findFieldDeclarationRefiningFieldType(src, srcRefinedType);
3545 
3546  auto checkConstConditionality = [&](Value value, FIRRTLBaseType type,
3547  Value declaration) -> LogicalResult {
3548  auto *declarationBlock = declaration.getParentBlock();
3549  auto *block = connect->getBlock();
3550  while (block && block != declarationBlock) {
3551  auto *parentOp = block->getParentOp();
3552 
3553  if (auto whenOp = dyn_cast<WhenOp>(parentOp);
3554  whenOp && !whenOp.getCondition().getType().isConst()) {
3555  if (type.isConst())
3556  return connect.emitOpError()
3557  << "assignment to 'const' type " << type
3558  << " is dependent on a non-'const' condition";
3559  return connect->emitOpError()
3560  << "assignment to nested 'const' member of type " << type
3561  << " is dependent on a non-'const' condition";
3562  }
3563 
3564  block = parentOp->getBlock();
3565  }
3566  return success();
3567  };
3568 
3569  auto emitSubaccessError = [&] {
3570  return connect.emitError(
3571  "assignment to non-'const' subaccess of 'const' type is disallowed");
3572  };
3573 
3574  // Check destination if it contains 'const' leaves
3575  if (destRefinedType.containsConst() && isConstFieldDriven(destRefinedType)) {
3576  // Disallow assignment to non-'const' subaccesses of 'const' types
3577  if (destType != destRefinedType)
3578  return emitSubaccessError();
3579 
3580  if (failed(checkConstConditionality(dest, destType, destDeclaration)))
3581  return failure();
3582  }
3583 
3584  // Check source if it contains 'const' 'flip' leaves
3585  if (srcRefinedType.containsConst() &&
3586  isConstFieldDriven(srcRefinedType, /*isFlip=*/true)) {
3587  // Disallow assignment to non-'const' subaccesses of 'const' types
3588  if (srcType != srcRefinedType)
3589  return emitSubaccessError();
3590  if (failed(checkConstConditionality(src, srcType, srcDeclaration)))
3591  return failure();
3592  }
3593 
3594  return success();
3595 }
3596 
3597 LogicalResult ConnectOp::verify() {
3598  auto dstType = getDest().getType();
3599  auto srcType = getSrc().getType();
3600  auto dstBaseType = type_dyn_cast<FIRRTLBaseType>(dstType);
3601  auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(srcType);
3602  if (!dstBaseType || !srcBaseType) {
3603  if (dstType != srcType)
3604  return emitError("may not connect different non-base types");
3605  } else {
3606  // Analog types cannot be connected and must be attached.
3607  if (dstBaseType.containsAnalog() || srcBaseType.containsAnalog())
3608  return emitError("analog types may not be connected");
3609 
3610  // Destination and source types must be equivalent.
3611  if (!areTypesEquivalent(dstBaseType, srcBaseType))
3612  return emitError("type mismatch between destination ")
3613  << dstBaseType << " and source " << srcBaseType;
3614 
3615  // Truncation is banned in a connection: destination bit width must be
3616  // greater than or equal to source bit width.
3617  if (!isTypeLarger(dstBaseType, srcBaseType))
3618  return emitError("destination ")
3619  << dstBaseType << " is not as wide as the source " << srcBaseType;
3620  }
3621 
3622  // Check that the flows make sense.
3623  if (failed(checkConnectFlow(*this)))
3624  return failure();
3625 
3626  if (failed(checkConnectConditionality(*this)))
3627  return failure();
3628 
3629  return success();
3630 }
3631 
3632 LogicalResult MatchingConnectOp::verify() {
3633  if (auto type = type_dyn_cast<FIRRTLType>(getDest().getType())) {
3634  auto baseType = type_cast<FIRRTLBaseType>(type);
3635 
3636  // Analog types cannot be connected and must be attached.
3637  if (baseType && baseType.containsAnalog())
3638  return emitError("analog types may not be connected");
3639 
3640  // The anonymous types of operands must be equivalent.
3641  assert(areAnonymousTypesEquivalent(cast<FIRRTLBaseType>(getSrc().getType()),
3642  baseType) &&
3643  "`SameAnonTypeOperands` trait should have already rejected "
3644  "structurally non-equivalent types");
3645  }
3646 
3647  // Check that the flows make sense.
3648  if (failed(checkConnectFlow(*this)))
3649  return failure();
3650 
3651  if (failed(checkConnectConditionality(*this)))
3652  return failure();
3653 
3654  return success();
3655 }
3656 
3657 LogicalResult RefDefineOp::verify() {
3658  // Check that the flows make sense.
3659  if (failed(checkConnectFlow(*this)))
3660  return failure();
3661 
3662  // For now, refs can't be in bundles so this is sufficient.
3663  // In the future need to ensure no other define's to same "fieldSource".
3664  // (When aggregates can have references, we can define a reference within,
3665  // but this must be unique. Checking this here may be expensive,
3666  // consider adding something to FModuleLike's to check it there instead)
3667  for (auto *user : getDest().getUsers()) {
3668  if (auto conn = dyn_cast<FConnectLike>(user);
3669  conn && conn.getDest() == getDest() && conn != *this)
3670  return emitError("destination reference cannot be reused by multiple "
3671  "operations, it can only capture a unique dataflow");
3672  }
3673 
3674  // Check "static" source/dest
3675  if (auto *op = getDest().getDefiningOp()) {
3676  // TODO: Make ref.sub only source flow?
3677  if (isa<RefSubOp>(op))
3678  return emitError(
3679  "destination reference cannot be a sub-element of a reference");
3680  if (isa<RefCastOp>(op)) // Source flow, check anyway for now.
3681  return emitError(
3682  "destination reference cannot be a cast of another reference");
3683  }
3684 
3685  // This define is only enabled when its ambient layers are active. Check
3686  // that whenever the destination's layer requirements are met, that this
3687  // op is enabled.
3688  auto ambientLayers = getAmbientLayersAt(getOperation());
3689  auto dstLayers = getLayersFor(getDest());
3690  SmallVector<SymbolRefAttr> missingLayers;
3691  if (!isLayerSetCompatibleWith(ambientLayers, dstLayers, missingLayers)) {
3692  auto diag = emitOpError("has more layer requirements than destination");
3693  auto &note = diag.attachNote();
3694  note << "additional layers required: ";
3695  interleaveComma(missingLayers, note);
3696  return failure();
3697  }
3698 
3699  return success();
3700 }
3701 
3702 LogicalResult PropAssignOp::verify() {
3703  // Check that the flows make sense.
3704  if (failed(checkConnectFlow(*this)))
3705  return failure();
3706 
3707  // Verify that there is a single value driving the destination.
3708  for (auto *user : getDest().getUsers()) {
3709  if (auto conn = dyn_cast<FConnectLike>(user);
3710  conn && conn.getDest() == getDest() && conn != *this)
3711  return emitError("destination property cannot be reused by multiple "
3712  "operations, it can only capture a unique dataflow");
3713  }
3714 
3715  return success();
3716 }
3717 
3718 void WhenOp::createElseRegion() {
3719  assert(!hasElseRegion() && "already has an else region");
3720  getElseRegion().push_back(new Block());
3721 }
3722 
3723 void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
3724  bool withElseRegion, std::function<void()> thenCtor,
3725  std::function<void()> elseCtor) {
3726  OpBuilder::InsertionGuard guard(builder);
3727  result.addOperands(condition);
3728 
3729  // Create "then" region.
3730  builder.createBlock(result.addRegion());
3731  if (thenCtor)
3732  thenCtor();
3733 
3734  // Create "else" region.
3735  Region *elseRegion = result.addRegion();
3736  if (withElseRegion) {
3737  builder.createBlock(elseRegion);
3738  if (elseCtor)
3739  elseCtor();
3740  }
3741 }
3742 
3743 //===----------------------------------------------------------------------===//
3744 // MatchOp
3745 //===----------------------------------------------------------------------===//
3746 
3747 LogicalResult MatchOp::verify() {
3748  FEnumType type = getInput().getType();
3749 
3750  // Make sure that the number of tags matches the number of regions.
3751  auto numCases = getTags().size();
3752  auto numRegions = getNumRegions();
3753  if (numRegions != numCases)
3754  return emitOpError("expected ")
3755  << numRegions << " tags but got " << numCases;
3756 
3757  auto numTags = type.getNumElements();
3758 
3759  SmallDenseSet<int64_t> seen;
3760  for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
3761  auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());
3762 
3763  // Ensure that the block has a single argument.
3764  if (region.getNumArguments() != 1)
3765  return emitOpError("region should have exactly one argument");
3766 
3767  // Make sure that it is a valid tag.
3768  if (tagIndex >= numTags)
3769  return emitOpError("the tag index ")
3770  << tagIndex << " is out of the range of valid tags in " << type;
3771 
3772  // Make sure we have not already matched this tag.
3773  auto [it, inserted] = seen.insert(tagIndex);
3774  if (!inserted)
3775  return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
3776  << " is matched more than once";
3777 
3778  // Check that the block argument type matches the tag's type.
3779  auto expectedType = type.getElementTypePreservingConst(tagIndex);
3780  auto regionType = region.getArgument(0).getType();
3781  if (regionType != expectedType)
3782  return emitOpError("region type ")
3783  << regionType << " does not match the expected type "
3784  << expectedType;
3785  }
3786 
3787  // Check that the match statement is exhaustive.
3788  for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
3789  if (!seen.contains(i))
3790  return emitOpError("missing case for tag ") << type.getElementNameAttr(i);
3791 
3792  return success();
3793 }
3794 
3795 void MatchOp::print(OpAsmPrinter &p) {
3796  auto input = getInput();
3797  FEnumType type = input.getType();
3798  auto regions = getRegions();
3799  p << " " << input << " : " << type;
3800  SmallVector<StringRef> elided = {"tags"};
3801  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
3802  p << " {";
3803  p.increaseIndent();
3804  for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
3805  p.printNewline();
3806  p << "case ";
3807  p.printKeywordOrString(
3808  type.getElementName(cast<IntegerAttr>(tag).getInt()));
3809  p << "(";
3810  p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
3811  /*omitType=*/true);
3812  p << ") ";
3813  p.printRegion(region, /*printEntryBlockArgs=*/false);
3814  }
3815  p.decreaseIndent();
3816  p.printNewline();
3817  p << "}";
3818 }
3819 
3820 ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
3821  auto *context = parser.getContext();
3822  OpAsmParser::UnresolvedOperand input;
3823  if (parser.parseOperand(input) || parser.parseColon())
3824  return failure();
3825 
3826  auto loc = parser.getCurrentLocation();
3827  Type type;
3828  if (parser.parseType(type))
3829  return failure();
3830  auto enumType = type_dyn_cast<FEnumType>(type);
3831  if (!enumType)
3832  return parser.emitError(loc, "expected enumeration type but got") << type;
3833 
3834  if (parser.resolveOperand(input, type, result.operands) ||
3835  parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3836  parser.parseLBrace())
3837  return failure();
3838 
3839  auto i32Type = IntegerType::get(context, 32);
3840  SmallVector<Attribute> tags;
3841  while (true) {
3842  // Stop parsing when we don't find another "case" keyword.
3843  if (failed(parser.parseOptionalKeyword("case")))
3844  break;
3845 
3846  // Parse the tag and region argument.
3847  auto nameLoc = parser.getCurrentLocation();
3848  std::string name;
3849  OpAsmParser::Argument arg;
3850  auto *region = result.addRegion();
3851  if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
3852  parser.parseArgument(arg) || parser.parseRParen())
3853  return failure();
3854 
3855  // Figure out the enum index of the tag.
3856  auto index = enumType.getElementIndex(name);
3857  if (!index)
3858  return parser.emitError(nameLoc, "the tag \"")
3859  << name << "\" is not a member of the enumeration " << enumType;
3860  tags.push_back(IntegerAttr::get(i32Type, *index));
3861 
3862  // Parse the region.
3863  arg.type = enumType.getElementTypePreservingConst(*index);
3864  if (parser.parseRegion(*region, arg))
3865  return failure();
3866  }
3867  result.addAttribute("tags", ArrayAttr::get(context, tags));
3868 
3869  return parser.parseRBrace();
3870 }
3871 
3872 void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
3873  ArrayAttr tags,
3874  MutableArrayRef<std::unique_ptr<Region>> regions) {
3875  result.addOperands(input);
3876  result.addAttribute("tags", tags);
3877  result.addRegions(regions);
3878 }
3879 
3880 //===----------------------------------------------------------------------===//
3881 // Expressions
3882 //===----------------------------------------------------------------------===//
3883 
3884 /// Type inference adaptor that narrows from the very generic MLIR
3885 /// `InferTypeOpInterface` to what we need in the FIRRTL dialect: just operands
3886 /// and attributes, no context or regions. Also, we only ever produce a single
3887 /// result value, so the FIRRTL-specific type inference ops directly return the
3888 /// inferred type rather than pushing into the `results` vector.
3889 LogicalResult impl::inferReturnTypes(
3890  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
3891  DictionaryAttr attrs, mlir::OpaqueProperties properties,
3892  RegionRange regions, SmallVectorImpl<Type> &results,
3893  llvm::function_ref<FIRRTLType(ValueRange, ArrayRef<NamedAttribute>,
3894  std::optional<Location>)>
3895  callback) {
3896  auto type = callback(
3897  operands, attrs ? attrs.getValue() : ArrayRef<NamedAttribute>{}, loc);
3898  if (type) {
3899  results.push_back(type);
3900  return success();
3901  }
3902  return failure();
3903 }
3904 
3905 /// Get an attribute by name from a list of named attributes. Return null if no
3906 /// attribute is found with that name.
3907 static Attribute maybeGetAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3908  for (auto attr : attrs)
3909  if (attr.getName() == name)
3910  return attr.getValue();
3911  return {};
3912 }
3913 
3914 /// Get an attribute by name from a list of named attributes. Aborts if the
3915 /// attribute does not exist.
3916 static Attribute getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3917  if (auto attr = maybeGetAttr(attrs, name))
3918  return attr;
3919  llvm::report_fatal_error("attribute '" + name + "' not found");
3920 }
3921 
3922 /// Same as above, but casts the attribute to a specific type.
3923 template <typename AttrClass>
3924 AttrClass getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
3925  return cast<AttrClass>(getAttr(attrs, name));
3926 }
3927 
3928 /// Return true if the specified operation is a firrtl expression.
3929 bool firrtl::isExpression(Operation *op) {
3930  struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
3931  bool visitInvalidExpr(Operation *op) { return false; }
3932  bool visitUnhandledExpr(Operation *op) { return true; }
3933  };
3934 
3935  return IsExprClassifier().dispatchExprVisitor(op);
3936 }
3937 
3939  // Set invalid values to have a distinct name.
3940  std::string name;
3941  if (auto ty = type_dyn_cast<IntType>(getType())) {
3942  const char *base = ty.isSigned() ? "invalid_si" : "invalid_ui";
3943  auto width = ty.getWidthOrSentinel();
3944  if (width == -1)
3945  name = base;
3946  else
3947  name = (Twine(base) + Twine(width)).str();
3948  } else if (auto ty = type_dyn_cast<AnalogType>(getType())) {
3949  auto width = ty.getWidthOrSentinel();
3950  if (width == -1)
3951  name = "invalid_analog";
3952  else
3953  name = ("invalid_analog" + Twine(width)).str();
3954  } else if (type_isa<AsyncResetType>(getType()))
3955  name = "invalid_asyncreset";
3956  else if (type_isa<ResetType>(getType()))
3957  name = "invalid_reset";
3958  else if (type_isa<ClockType>(getType()))
3959  name = "invalid_clock";
3960  else
3961  name = "invalid";
3962 
3963  setNameFn(getResult(), name);
3964 }
3965 
3966 void ConstantOp::print(OpAsmPrinter &p) {
3967  p << " ";
3968  p.printAttributeWithoutType(getValueAttr());
3969  p << " : ";
3970  p.printType(getType());
3971  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
3972 }
3973 
3974 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
3975  // Parse the constant value, without knowing its width.
3976  APInt value;
3977  auto loc = parser.getCurrentLocation();
3978  auto valueResult = parser.parseOptionalInteger(value);
3979  if (!valueResult.has_value())
3980  return parser.emitError(loc, "expected integer value");
3981 
3982  // Parse the result firrtl integer type.
3983  IntType resultType;
3984  if (failed(*valueResult) || parser.parseColonType(resultType) ||
3985  parser.parseOptionalAttrDict(result.attributes))
3986  return failure();
3987  result.addTypes(resultType);
3988 
3989  // Now that we know the width and sign of the result type, we can munge the
3990  // APInt as appropriate.
3991  if (resultType.hasWidth()) {
3992  auto width = (unsigned)resultType.getWidthOrSentinel();
3993  if (width > value.getBitWidth()) {
3994  // sext is always safe here, even for unsigned values, because the
3995  // parseOptionalInteger method will return something with a zero in the
3996  // top bits if it is a positive number.
3997  value = value.sext(width);
3998  } else if (width < value.getBitWidth()) {
3999  // The parser can return an unnecessarily wide result with leading
4000  // zeros. This isn't a problem, but truncating off bits is bad.
4001  unsigned neededBits = value.isNegative() ? value.getSignificantBits()
4002  : value.getActiveBits();
4003  if (width < neededBits)
4004  return parser.emitError(loc, "constant out of range for result type ")
4005  << resultType;
4006  value = value.trunc(width);
4007  }
4008  }
4009 
4010  auto intType = parser.getBuilder().getIntegerType(value.getBitWidth(),
4011  resultType.isSigned());
4012  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4013  result.addAttribute("value", valueAttr);
4014  return success();
4015 }
4016 
4017 LogicalResult ConstantOp::verify() {
4018  // If the result type has a bitwidth, then the attribute must match its width.
4019  IntType intType = getType();
4020  auto width = intType.getWidthOrSentinel();
4021  if (width != -1 && (int)getValue().getBitWidth() != width)
4022  return emitError(
4023  "firrtl.constant attribute bitwidth doesn't match return type");
4024 
4025  // The sign of the attribute's integer type must match our integer type sign.
4026  auto attrType = type_cast<IntegerType>(getValueAttr().getType());
4027  if (attrType.isSignless() || attrType.isSigned() != intType.isSigned())
4028  return emitError("firrtl.constant attribute has wrong sign");
4029 
4030  return success();
4031 }
4032 
4033 /// Build a ConstantOp from an APInt and a FIRRTL type, handling the attribute
4034 /// formation for the 'value' attribute.
4035 void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
4036  const APInt &value) {
4037  int32_t width = type.getWidthOrSentinel();
4038  (void)width;
4039  assert((width == -1 || (int32_t)value.getBitWidth() == width) &&
4040  "incorrect attribute bitwidth for firrtl.constant");
4041 
4042  auto attr =
4043  IntegerAttr::get(type.getContext(), APSInt(value, !type.isSigned()));
4044  return build(builder, result, type, attr);
4045 }
4046 
4047 /// Build a ConstantOp from an APSInt, handling the attribute formation for the
4048 /// 'value' attribute and inferring the FIRRTL type.
4049 void ConstantOp::build(OpBuilder &builder, OperationState &result,
4050  const APSInt &value) {
4051  auto attr = IntegerAttr::get(builder.getContext(), value);
4052  auto type =
4053  IntType::get(builder.getContext(), value.isSigned(), value.getBitWidth());
4054  return build(builder, result, type, attr);
4055 }
4056 
4058  // For constants in particular, propagate the value into the result name to
4059  // make it easier to read the IR.
4060  IntType intTy = getType();
4061  assert(intTy);
4062 
4063  // Otherwise, build a complex name with the value and type.
4064  SmallString<32> specialNameBuffer;
4065  llvm::raw_svector_ostream specialName(specialNameBuffer);
4066  specialName << 'c';
4067  getValue().print(specialName, /*isSigned:*/ intTy.isSigned());
4068 
4069  specialName << (intTy.isSigned() ? "_si" : "_ui");
4070  auto width = intTy.getWidthOrSentinel();
4071  if (width != -1)
4072  specialName << width;
4073  setNameFn(getResult(), specialName.str());
4074 }
4075 
4076 void SpecialConstantOp::print(OpAsmPrinter &p) {
4077  p << " ";
4078  // SpecialConstant uses a BoolAttr, and we want to print `true` as `1`.
4079  p << static_cast<unsigned>(getValue());
4080  p << " : ";
4081  p.printType(getType());
4082  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4083 }
4084 
4085 ParseResult SpecialConstantOp::parse(OpAsmParser &parser,
4086  OperationState &result) {
4087  // Parse the constant value. SpecialConstant uses bool attributes, but it
4088  // prints as an integer.
4089  APInt value;
4090  auto loc = parser.getCurrentLocation();
4091  auto valueResult = parser.parseOptionalInteger(value);
4092  if (!valueResult.has_value())
4093  return parser.emitError(loc, "expected integer value");
4094 
4095  // Clocks and resets can only be 0 or 1.
4096  if (value != 0 && value != 1)
4097  return parser.emitError(loc, "special constants can only be 0 or 1.");
4098 
4099  // Parse the result firrtl type.
4100  Type resultType;
4101  if (failed(*valueResult) || parser.parseColonType(resultType) ||
4102  parser.parseOptionalAttrDict(result.attributes))
4103  return failure();
4104  result.addTypes(resultType);
4105 
4106  // Create the attribute.
4107  auto valueAttr = parser.getBuilder().getBoolAttr(value == 1);
4108  result.addAttribute("value", valueAttr);
4109  return success();
4110 }
4111 
4113  SmallString<32> specialNameBuffer;
4114  llvm::raw_svector_ostream specialName(specialNameBuffer);
4115  specialName << 'c';
4116  specialName << static_cast<unsigned>(getValue());
4117  auto type = getType();
4118  if (type_isa<ClockType>(type)) {
4119  specialName << "_clock";
4120  } else if (type_isa<ResetType>(type)) {
4121  specialName << "_reset";
4122  } else if (type_isa<AsyncResetType>(type)) {
4123  specialName << "_asyncreset";
4124  }
4125  setNameFn(getResult(), specialName.str());
4126 }
4127 
4128 // Checks that an array attr representing an aggregate constant has the correct
4129 // shape. This recurses on the type.
4130 static bool checkAggConstant(Operation *op, Attribute attr,
4131  FIRRTLBaseType type) {
4132  if (type.isGround()) {
4133  if (!isa<IntegerAttr>(attr)) {
4134  op->emitOpError("Ground type is not an integer attribute");
4135  return false;
4136  }
4137  return true;
4138  }
4139  auto attrlist = dyn_cast<ArrayAttr>(attr);
4140  if (!attrlist) {
4141  op->emitOpError("expected array attribute for aggregate constant");
4142  return false;
4143  }
4144  if (auto array = type_dyn_cast<FVectorType>(type)) {
4145  if (array.getNumElements() != attrlist.size()) {
4146  op->emitOpError("array attribute (")
4147  << attrlist.size() << ") has wrong size for vector constant ("
4148  << array.getNumElements() << ")";
4149  return false;
4150  }
4151  return llvm::all_of(attrlist, [&array, op](Attribute attr) {
4152  return checkAggConstant(op, attr, array.getElementType());
4153  });
4154  }
4155  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4156  if (bundle.getNumElements() != attrlist.size()) {
4157  op->emitOpError("array attribute (")
4158  << attrlist.size() << ") has wrong size for bundle constant ("
4159  << bundle.getNumElements() << ")";
4160  return false;
4161  }
4162  for (size_t i = 0; i < bundle.getNumElements(); ++i) {
4163  if (bundle.getElement(i).isFlip) {
4164  op->emitOpError("Cannot have constant bundle type with flip");
4165  return false;
4166  }
4167  if (!checkAggConstant(op, attrlist[i], bundle.getElement(i).type))
4168  return false;
4169  }
4170  return true;
4171  }
4172  op->emitOpError("Unknown aggregate type");
4173  return false;
4174 }
4175 
4176 LogicalResult AggregateConstantOp::verify() {
4177  if (checkAggConstant(getOperation(), getFields(), getType()))
4178  return success();
4179  return failure();
4180 }
4181 
4182 Attribute AggregateConstantOp::getAttributeFromFieldID(uint64_t fieldID) {
4183  FIRRTLBaseType type = getType();
4184  Attribute value = getFields();
4185  while (fieldID != 0) {
4186  if (auto bundle = type_dyn_cast<BundleType>(type)) {
4187  auto index = bundle.getIndexForFieldID(fieldID);
4188  fieldID -= bundle.getFieldID(index);
4189  type = bundle.getElementType(index);
4190  value = cast<ArrayAttr>(value)[index];
4191  } else {
4192  auto vector = type_cast<FVectorType>(type);
4193  auto index = vector.getIndexForFieldID(fieldID);
4194  fieldID -= vector.getFieldID(index);
4195  type = vector.getElementType();
4196  value = cast<ArrayAttr>(value)[index];
4197  }
4198  }
4199  return value;
4200 }
4201 
4202 void FIntegerConstantOp::print(OpAsmPrinter &p) {
4203  p << " ";
4204  p.printAttributeWithoutType(getValueAttr());
4205  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
4206 }
4207 
4208 ParseResult FIntegerConstantOp::parse(OpAsmParser &parser,
4209  OperationState &result) {
4210  auto *context = parser.getContext();
4211  APInt value;
4212  if (parser.parseInteger(value) ||
4213  parser.parseOptionalAttrDict(result.attributes))
4214  return failure();
4215  result.addTypes(FIntegerType::get(context));
4216  auto intType =
4217  IntegerType::get(context, value.getBitWidth(), IntegerType::Signed);
4218  auto valueAttr = parser.getBuilder().getIntegerAttr(intType, value);
4219  result.addAttribute("value", valueAttr);
4220  return success();
4221 }
4222 
4223 ParseResult ListCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4224  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
4225  ListType type;
4226 
4227  if (parser.parseOperandList(operands) ||
4228  parser.parseOptionalAttrDict(result.attributes) ||
4229  parser.parseColonType(type))
4230  return failure();
4231  result.addTypes(type);
4232 
4233  return parser.resolveOperands(operands, type.getElementType(),
4234  result.operands);
4235 }
4236 
4237 void ListCreateOp::print(OpAsmPrinter &p) {
4238  p << " ";
4239  p.printOperands(getElements());
4240  p.printOptionalAttrDict((*this)->getAttrs());
4241  p << " : " << getType();
4242 }
4243 
4244 LogicalResult ListCreateOp::verify() {
4245  if (getElements().empty())
4246  return success();
4247 
4248  auto elementType = getElements().front().getType();
4249  auto listElementType = getType().getElementType();
4250  if (elementType != listElementType)
4251  return emitOpError("has elements of type ")
4252  << elementType << " instead of " << listElementType;
4253 
4254  return success();
4255 }
4256 
4257 LogicalResult BundleCreateOp::verify() {
4258  BundleType resultType = getType();
4259  if (resultType.getNumElements() != getFields().size())
4260  return emitOpError("number of fields doesn't match type");
4261  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4262  if (!areTypesConstCastable(
4263  resultType.getElementTypePreservingConst(i),
4264  type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4265  return emitOpError("type of element doesn't match bundle for field ")
4266  << resultType.getElement(i).name;
4267  // TODO: check flow
4268  return success();
4269 }
4270 
4271 LogicalResult VectorCreateOp::verify() {
4272  FVectorType resultType = getType();
4273  if (resultType.getNumElements() != getFields().size())
4274  return emitOpError("number of fields doesn't match type");
4275  auto elemTy = resultType.getElementTypePreservingConst();
4276  for (size_t i = 0; i < resultType.getNumElements(); ++i)
4277  if (!areTypesConstCastable(
4278  elemTy, type_cast<FIRRTLBaseType>(getOperand(i).getType())))
4279  return emitOpError("type of element doesn't match vector element");
4280  // TODO: check flow
4281  return success();
4282 }
4283 
4284 //===----------------------------------------------------------------------===//
4285 // FEnumCreateOp
4286 //===----------------------------------------------------------------------===//
4287 
4288 LogicalResult FEnumCreateOp::verify() {
4289  FEnumType resultType = getResult().getType();
4290  auto elementIndex = resultType.getElementIndex(getFieldName());
4291  if (!elementIndex)
4292  return emitOpError("label ")
4293  << getFieldName() << " is not a member of the enumeration type "
4294  << resultType;
4295  if (!areTypesConstCastable(
4296  resultType.getElementTypePreservingConst(*elementIndex),
4297  getInput().getType()))
4298  return emitOpError("type of element doesn't match enum element");
4299  return success();
4300 }
4301 
4302 void FEnumCreateOp::print(OpAsmPrinter &printer) {
4303  printer << ' ';
4304  printer.printKeywordOrString(getFieldName());
4305  printer << '(' << getInput() << ')';
4306  SmallVector<StringRef> elidedAttrs = {"fieldIndex"};
4307  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
4308  printer << " : ";
4309  printer.printFunctionalType(ArrayRef<Type>{getInput().getType()},
4310  ArrayRef<Type>{getResult().getType()});
4311 }
4312 
4313 ParseResult FEnumCreateOp::parse(OpAsmParser &parser, OperationState &result) {
4314  auto *context = parser.getContext();
4315 
4316  OpAsmParser::UnresolvedOperand input;
4317  std::string fieldName;
4318  mlir::FunctionType functionType;
4319  if (parser.parseKeywordOrString(&fieldName) || parser.parseLParen() ||
4320  parser.parseOperand(input) || parser.parseRParen() ||
4321  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4322  parser.parseType(functionType))
4323  return failure();
4324 
4325  if (functionType.getNumInputs() != 1)
4326  return parser.emitError(parser.getNameLoc(), "single input type required");
4327  if (functionType.getNumResults() != 1)
4328  return parser.emitError(parser.getNameLoc(), "single result type required");
4329 
4330  auto inputType = functionType.getInput(0);
4331  if (parser.resolveOperand(input, inputType, result.operands))
4332  return failure();
4333 
4334  auto outputType = functionType.getResult(0);
4335  auto enumType = type_dyn_cast<FEnumType>(outputType);
4336  if (!enumType)
4337  return parser.emitError(parser.getNameLoc(),
4338  "output must be enum type, got ")
4339  << outputType;
4340  auto fieldIndex = enumType.getElementIndex(fieldName);
4341  if (!fieldIndex)
4342  return parser.emitError(parser.getNameLoc(),
4343  "unknown field " + fieldName + " in enum type ")
4344  << enumType;
4345 
4346  result.addAttribute(
4347  "fieldIndex",
4348  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4349 
4350  result.addTypes(enumType);
4351 
4352  return success();
4353 }
4354 
4355 //===----------------------------------------------------------------------===//
4356 // IsTagOp
4357 //===----------------------------------------------------------------------===//
4358 
4359 LogicalResult IsTagOp::verify() {
4360  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4361  return emitOpError("element index is greater than the number of fields in "
4362  "the bundle type");
4363  return success();
4364 }
4365 
4366 void IsTagOp::print(::mlir::OpAsmPrinter &printer) {
4367  printer << ' ' << getInput() << ' ';
4368  printer.printKeywordOrString(getFieldName());
4369  SmallVector<::llvm::StringRef, 1> elidedAttrs = {"fieldIndex"};
4370  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4371  printer << " : " << getInput().getType();
4372 }
4373 
4374 ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
4375  auto *context = parser.getContext();
4376 
4377  OpAsmParser::UnresolvedOperand input;
4378  std::string fieldName;
4379  Type inputType;
4380  if (parser.parseOperand(input) || parser.parseKeywordOrString(&fieldName) ||
4381  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4382  parser.parseType(inputType))
4383  return failure();
4384 
4385  if (parser.resolveOperand(input, inputType, result.operands))
4386  return failure();
4387 
4388  auto enumType = type_dyn_cast<FEnumType>(inputType);
4389  if (!enumType)
4390  return parser.emitError(parser.getNameLoc(),
4391  "input must be enum type, got ")
4392  << inputType;
4393  auto fieldIndex = enumType.getElementIndex(fieldName);
4394  if (!fieldIndex)
4395  return parser.emitError(parser.getNameLoc(),
4396  "unknown field " + fieldName + " in enum type ")
4397  << enumType;
4398 
4399  result.addAttribute(
4400  "fieldIndex",
4401  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4402 
4403  result.addTypes(UIntType::get(context, 1, /*isConst=*/false));
4404 
4405  return success();
4406 }
4407 
4408 FIRRTLType IsTagOp::inferReturnType(ValueRange operands,
4409  ArrayRef<NamedAttribute> attrs,
4410  std::optional<Location> loc) {
4411  return UIntType::get(operands[0].getContext(), 1,
4412  isConst(operands[0].getType()));
4413 }
4414 
4415 template <typename OpTy>
4416 ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result) {
4417  auto *context = parser.getContext();
4418 
4419  OpAsmParser::UnresolvedOperand input;
4420  std::string fieldName;
4421  Type inputType;
4422  if (parser.parseOperand(input) || parser.parseLSquare() ||
4423  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4424  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4425  parser.parseType(inputType))
4426  return failure();
4427 
4428  if (parser.resolveOperand(input, inputType, result.operands))
4429  return failure();
4430 
4431  auto bundleType = type_dyn_cast<typename OpTy::InputType>(inputType);
4432  if (!bundleType)
4433  return parser.emitError(parser.getNameLoc(),
4434  "input must be bundle type, got ")
4435  << inputType;
4436  auto fieldIndex = bundleType.getElementIndex(fieldName);
4437  if (!fieldIndex)
4438  return parser.emitError(parser.getNameLoc(),
4439  "unknown field " + fieldName + " in bundle type ")
4440  << bundleType;
4441 
4442  result.addAttribute(
4443  "fieldIndex",
4444  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4445 
4446  SmallVector<Type> inferredReturnTypes;
4447  if (failed(OpTy::inferReturnTypes(context, result.location, result.operands,
4448  result.attributes.getDictionary(context),
4449  result.getRawProperties(), result.regions,
4450  inferredReturnTypes)))
4451  return failure();
4452  result.addTypes(inferredReturnTypes);
4453 
4454  return success();
4455 }
4456 
4457 ParseResult SubtagOp::parse(OpAsmParser &parser, OperationState &result) {
4458  auto *context = parser.getContext();
4459 
4460  OpAsmParser::UnresolvedOperand input;
4461  std::string fieldName;
4462  Type inputType;
4463  if (parser.parseOperand(input) || parser.parseLSquare() ||
4464  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4465  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4466  parser.parseType(inputType))
4467  return failure();
4468 
4469  if (parser.resolveOperand(input, inputType, result.operands))
4470  return failure();
4471 
4472  auto enumType = type_dyn_cast<FEnumType>(inputType);
4473  if (!enumType)
4474  return parser.emitError(parser.getNameLoc(),
4475  "input must be enum type, got ")
4476  << inputType;
4477  auto fieldIndex = enumType.getElementIndex(fieldName);
4478  if (!fieldIndex)
4479  return parser.emitError(parser.getNameLoc(),
4480  "unknown field " + fieldName + " in enum type ")
4481  << enumType;
4482 
4483  result.addAttribute(
4484  "fieldIndex",
4485  IntegerAttr::get(IntegerType::get(context, 32), *fieldIndex));
4486 
4487  SmallVector<Type> inferredReturnTypes;
4488  if (failed(SubtagOp::inferReturnTypes(
4489  context, result.location, result.operands,
4490  result.attributes.getDictionary(context), result.getRawProperties(),
4491  result.regions, inferredReturnTypes)))
4492  return failure();
4493  result.addTypes(inferredReturnTypes);
4494 
4495  return success();
4496 }
4497 
4498 ParseResult SubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4499  return parseSubfieldLikeOp<SubfieldOp>(parser, result);
4500 }
4501 ParseResult OpenSubfieldOp::parse(OpAsmParser &parser, OperationState &result) {
4502  return parseSubfieldLikeOp<OpenSubfieldOp>(parser, result);
4503 }
4504 
4505 template <typename OpTy>
4506 static void printSubfieldLikeOp(OpTy op, ::mlir::OpAsmPrinter &printer) {
4507  printer << ' ' << op.getInput() << '[';
4508  printer.printKeywordOrString(op.getFieldName());
4509  printer << ']';
4510  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4511  elidedAttrs.push_back("fieldIndex");
4512  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4513  printer << " : " << op.getInput().getType();
4514 }
4515 void SubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4516  return printSubfieldLikeOp<SubfieldOp>(*this, printer);
4517 }
4518 void OpenSubfieldOp::print(::mlir::OpAsmPrinter &printer) {
4519  return printSubfieldLikeOp<OpenSubfieldOp>(*this, printer);
4520 }
4521 
4522 void SubtagOp::print(::mlir::OpAsmPrinter &printer) {
4523  printer << ' ' << getInput() << '[';
4524  printer.printKeywordOrString(getFieldName());
4525  printer << ']';
4526  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
4527  elidedAttrs.push_back("fieldIndex");
4528  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4529  printer << " : " << getInput().getType();
4530 }
4531 
4532 template <typename OpTy>
4533 static LogicalResult verifySubfieldLike(OpTy op) {
4534  if (op.getFieldIndex() >=
4535  firrtl::type_cast<typename OpTy::InputType>(op.getInput().getType())
4536  .getNumElements())
4537  return op.emitOpError("subfield element index is greater than the number "
4538  "of fields in the bundle type");
4539  return success();
4540 }
4541 LogicalResult SubfieldOp::verify() {
4542  return verifySubfieldLike<SubfieldOp>(*this);
4543 }
4544 LogicalResult OpenSubfieldOp::verify() {
4545  return verifySubfieldLike<OpenSubfieldOp>(*this);
4546 }
4547 
4548 LogicalResult SubtagOp::verify() {
4549  if (getFieldIndex() >= getInput().getType().base().getNumElements())
4550  return emitOpError("subfield element index is greater than the number "
4551  "of fields in the bundle type");
4552  return success();
4553 }
4554 
/// Return true if the specified operation has a constant value. This trivially
/// checks for `firrtl.constant` and friends, but also looks through subaccesses
/// and correctly handles wires driven with only constant values.
///
/// NOTE(review): the implementation does not appear to match this description.
///  * The NodeOp/AsSIntPrimOp/AsUIntPrimOp case pushes the input's defining op
///    and then unconditionally sets `constant = false`. The loop therefore
///    exits before that op is examined, and any node or cast yields false.
///  * The WireOp/SubindexOp/SubfieldOp case walks the *users* of the result,
///    not its drivers. Any user not listed in the switch (for example a
///    connect) hits the Default case and yields false.
/// Simply dropping `constant = false` from the first case is NOT a safe fix.
/// Combined with the user walk, it would report e.g. a node of a subfield of a
/// non-constant wire as constant, and the walk could bounce between
/// node -> subfield -> node indefinitely. TODO: confirm the intended semantics
/// with callers before changing behavior.
bool firrtl::isConstant(Operation *op) {
  // Worklist of ops that need to be examined that should all be constant in
  // order for the input operation to be constant.
  SmallVector<Operation *, 8> worklist({op});

  // Mutable state indicating if this op is a constant. Assume it is a constant
  // and look for counterexamples.
  bool constant = true;

  // While we haven't found a counterexample and there are still ops in the
  // worklist, pull ops off the worklist. If it provides a counterexample, set
  // the `constant` to false (and exit on the next loop iteration). Otherwise,
  // look through the op or spawn off more ops to look at.
  while (constant && !(worklist.empty()))
    TypeSwitch<Operation *>(worklist.pop_back_val())
        .Case<NodeOp, AsSIntPrimOp, AsUIntPrimOp>([&](auto op) {
          // NOTE(review): the push below is never examined because
          // `constant` is cleared unconditionally right after it.
          if (auto definingOp = op.getInput().getDefiningOp())
            worklist.push_back(definingOp);
          constant = false;
        })
        .Case<WireOp, SubindexOp, SubfieldOp>([&](auto op) {
          // Walks users (not drivers) of the result; see the NOTE above.
          for (auto &use : op.getResult().getUses())
            worklist.push_back(use.getOwner());
        })
        // Constant-producing ops are accepted as-is.
        .Case<ConstantOp, SpecialConstantOp, AggregateConstantOp>([](auto) {})
        // Any other op is treated as a counterexample.
        .Default([&](auto) { constant = false; });

  return constant;
}
4587 
4588 /// Return true if the specified value is a constant. This trivially checks for
4589 /// `firrtl.constant` and friends, but also looks through subaccesses and
4590 /// correctly handles wires driven with only constant values.
4591 bool firrtl::isConstant(Value value) {
4592  if (auto *op = value.getDefiningOp())
4593  return isConstant(op);
4594  return false;
4595 }
4596 
4597 LogicalResult ConstCastOp::verify() {
4598  if (!areTypesConstCastable(getResult().getType(), getInput().getType()))
4599  return emitOpError() << getInput().getType()
4600  << " is not 'const'-castable to "
4601  << getResult().getType();
4602  return success();
4603 }
4604 
4605 FIRRTLType SubfieldOp::inferReturnType(ValueRange operands,
4606  ArrayRef<NamedAttribute> attrs,
4607  std::optional<Location> loc) {
4608  auto inType = type_cast<BundleType>(operands[0].getType());
4609  auto fieldIndex =
4610  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4611 
4612  if (fieldIndex >= inType.getNumElements())
4613  return emitInferRetTypeError(loc,
4614  "subfield element index is greater than the "
4615  "number of fields in the bundle type");
4616 
4617  // SubfieldOp verifier checks that the field index is valid with number of
4618  // subelements.
4619  return inType.getElementTypePreservingConst(fieldIndex);
4620 }
4621 
4622 FIRRTLType OpenSubfieldOp::inferReturnType(ValueRange operands,
4623  ArrayRef<NamedAttribute> attrs,
4624  std::optional<Location> loc) {
4625  auto inType = type_cast<OpenBundleType>(operands[0].getType());
4626  auto fieldIndex =
4627  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4628 
4629  if (fieldIndex >= inType.getNumElements())
4630  return emitInferRetTypeError(loc,
4631  "subfield element index is greater than the "
4632  "number of fields in the bundle type");
4633 
4634  // OpenSubfieldOp verifier checks that the field index is valid with number of
4635  // subelements.
4636  return inType.getElementTypePreservingConst(fieldIndex);
4637 }
4638 
4639 bool SubfieldOp::isFieldFlipped() {
4640  BundleType bundle = getInput().getType();
4641  return bundle.getElement(getFieldIndex()).isFlip;
4642 }
4643 bool OpenSubfieldOp::isFieldFlipped() {
4644  auto bundle = getInput().getType();
4645  return bundle.getElement(getFieldIndex()).isFlip;
4646 }
4647 
4648 FIRRTLType SubindexOp::inferReturnType(ValueRange operands,
4649  ArrayRef<NamedAttribute> attrs,
4650  std::optional<Location> loc) {
4651  Type inType = operands[0].getType();
4652  auto fieldIdx =
4653  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4654 
4655  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4656  if (fieldIdx < vectorType.getNumElements())
4657  return vectorType.getElementTypePreservingConst();
4658  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4659  "' in vector type ", inType);
4660  }
4661 
4662  return emitInferRetTypeError(loc, "subindex requires vector operand");
4663 }
4664 
4665 FIRRTLType OpenSubindexOp::inferReturnType(ValueRange operands,
4666  ArrayRef<NamedAttribute> attrs,
4667  std::optional<Location> loc) {
4668  Type inType = operands[0].getType();
4669  auto fieldIdx =
4670  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4671 
4672  if (auto vectorType = type_dyn_cast<OpenVectorType>(inType)) {
4673  if (fieldIdx < vectorType.getNumElements())
4674  return vectorType.getElementTypePreservingConst();
4675  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
4676  "' in vector type ", inType);
4677  }
4678 
4679  return emitInferRetTypeError(loc, "subindex requires vector operand");
4680 }
4681 
4682 FIRRTLType SubtagOp::inferReturnType(ValueRange operands,
4683  ArrayRef<NamedAttribute> attrs,
4684  std::optional<Location> loc) {
4685  auto inType = type_cast<FEnumType>(operands[0].getType());
4686  auto fieldIndex =
4687  getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
4688 
4689  if (fieldIndex >= inType.getNumElements())
4690  return emitInferRetTypeError(loc,
4691  "subtag element index is greater than the "
4692  "number of fields in the enum type");
4693 
4694  // SubtagOp verifier checks that the field index is valid with number of
4695  // subelements.
4696  auto elementType = inType.getElement(fieldIndex).type;
4697  return elementType.getConstType(elementType.isConst() || inType.isConst());
4698 }
4699 
4700 FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
4701  ArrayRef<NamedAttribute> attrs,
4702  std::optional<Location> loc) {
4703  auto inType = operands[0].getType();
4704  auto indexType = operands[1].getType();
4705 
4706  if (!type_isa<UIntType>(indexType))
4707  return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
4708  indexType);
4709 
4710  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
4711  if (isConst(indexType))
4712  return vectorType.getElementTypePreservingConst();
4713  return vectorType.getElementType().getAllConstDroppedType();
4714  }
4715 
4716  return emitInferRetTypeError(loc, "subaccess requires vector operand, not ",
4717  inType);
4718 }
4719 
4720 FIRRTLType TagExtractOp::inferReturnType(ValueRange operands,
4721  ArrayRef<NamedAttribute> attrs,
4722  std::optional<Location> loc) {
4723  auto inType = type_cast<FEnumType>(operands[0].getType());
4724  auto i = llvm::Log2_32_Ceil(inType.getNumElements());
4725  return UIntType::get(inType.getContext(), i);
4726 }
4727 
4728 ParseResult MultibitMuxOp::parse(OpAsmParser &parser, OperationState &result) {
4729  OpAsmParser::UnresolvedOperand index;
4730  SmallVector<OpAsmParser::UnresolvedOperand, 16> inputs;
4731  Type indexType, elemType;
4732 
4733  if (parser.parseOperand(index) || parser.parseComma() ||
4734  parser.parseOperandList(inputs) ||
4735  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4736  parser.parseType(indexType) || parser.parseComma() ||
4737  parser.parseType(elemType))
4738  return failure();
4739 
4740  if (parser.resolveOperand(index, indexType, result.operands))
4741  return failure();
4742 
4743  result.addTypes(elemType);
4744 
4745  return parser.resolveOperands(inputs, elemType, result.operands);
4746 }
4747 
4748 void MultibitMuxOp::print(OpAsmPrinter &p) {
4749  p << " " << getIndex() << ", ";
4750  p.printOperands(getInputs());
4751  p.printOptionalAttrDict((*this)->getAttrs());
4752  p << " : " << getIndex().getType() << ", " << getType();
4753 }
4754 
4755 FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
4756  ArrayRef<NamedAttribute> attrs,
4757  std::optional<Location> loc) {
4758  if (operands.size() < 2)
4759  return emitInferRetTypeError(loc, "at least one input is required");
4760 
4761  // Check all mux inputs have the same type.
4762  if (!llvm::all_of(operands.drop_front(2), [&](auto op) {
4763  return operands[1].getType() == op.getType();
4764  }))
4765  return emitInferRetTypeError(loc, "all inputs must have the same type");
4766 
4767  return type_cast<FIRRTLType>(operands[1].getType());
4768 }
4769 
4770 //===----------------------------------------------------------------------===//
4771 // ObjectSubfieldOp
4772 //===----------------------------------------------------------------------===//
4773 
    MLIRContext *context, std::optional<mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
  // Adapter from the multi-result inference signature to the single-type
  // ObjectSubfieldOp::inferReturnType defined in this section. `context`,
  // `properties` and `regions` are not used here. A null type signals failure.
  // Presumably inferReturnType has already reported the error at `location`.
  auto type = inferReturnType(operands, attributes.getValue(), location);
  if (!type)
    return failure();
  inferredReturnTypes.push_back(type);
  return success();
}
4784 
4785 Type ObjectSubfieldOp::inferReturnType(ValueRange operands,
4786  ArrayRef<NamedAttribute> attrs,
4787  std::optional<Location> loc) {
4788  auto classType = dyn_cast<ClassType>(operands[0].getType());
4789  if (!classType)
4790  return emitInferRetTypeError(loc, "base object is not a class");
4791 
4792  auto index = getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
4793  if (classType.getNumElements() <= index)
4794  return emitInferRetTypeError(loc, "element index is greater than the "
4795  "number of fields in the object");
4796 
4797  return classType.getElement(index).type;
4798 }
4799 
4800 void ObjectSubfieldOp::print(OpAsmPrinter &p) {
4801  auto input = getInput();
4802  auto classType = input.getType();
4803  p << ' ' << input << "[";
4804  p.printKeywordOrString(classType.getElement(getIndex()).name);
4805  p << "]";
4806  p.printOptionalAttrDict((*this)->getAttrs(), std::array{StringRef("index")});
4807  p << " : " << classType;
4808 }
4809 
4810 ParseResult ObjectSubfieldOp::parse(OpAsmParser &parser,
4811  OperationState &result) {
4812  auto *context = parser.getContext();
4813 
4814  OpAsmParser::UnresolvedOperand input;
4815  std::string fieldName;
4816  ClassType inputType;
4817  if (parser.parseOperand(input) || parser.parseLSquare() ||
4818  parser.parseKeywordOrString(&fieldName) || parser.parseRSquare() ||
4819  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
4820  parser.parseType(inputType) ||
4821  parser.resolveOperand(input, inputType, result.operands))
4822  return failure();
4823 
4824  auto index = inputType.getElementIndex(fieldName);
4825  if (!index)
4826  return parser.emitError(parser.getNameLoc(),
4827  "unknown field " + fieldName + " in class type ")
4828  << inputType;
4829  result.addAttribute("index",
4830  IntegerAttr::get(IntegerType::get(context, 32), *index));
4831 
4832  SmallVector<Type> inferredReturnTypes;
4833  if (failed(inferReturnTypes(context, result.location, result.operands,
4834  result.attributes.getDictionary(context),
4835  result.getRawProperties(), result.regions,
4836  inferredReturnTypes)))
4837  return failure();
4838  result.addTypes(inferredReturnTypes);
4839 
4840  return success();
4841 }
4842 
4843 //===----------------------------------------------------------------------===//
4844 // Binary Primitives
4845 //===----------------------------------------------------------------------===//
4846 
4847 /// If LHS and RHS are both UInt or SInt types, the return true and fill in the
4848 /// width of them if known. If unknown, return -1 for the widths.
4849 /// The constness of the result is also returned, where if both lhs and rhs are
4850 /// const, then the result is const.
4851 ///
4852 /// On failure, this reports and error and returns false. This function should
4853 /// not be used if you don't want an error reported.
4854 static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth,
4855  int32_t &rhsWidth, bool &isConstResult,
4856  std::optional<Location> loc) {
4857  // Must have two integer types with the same signedness.
4858  auto lhsi = type_dyn_cast<IntType>(lhs);
4859  auto rhsi = type_dyn_cast<IntType>(rhs);
4860  if (!lhsi || !rhsi || lhsi.isSigned() != rhsi.isSigned()) {
4861  if (loc) {
4862  if (lhsi && !rhsi)
4863  mlir::emitError(*loc, "second operand must be an integer type, not ")
4864  << rhs;
4865  else if (!lhsi && rhsi)
4866  mlir::emitError(*loc, "first operand must be an integer type, not ")
4867  << lhs;
4868  else if (!lhsi && !rhsi)
4869  mlir::emitError(*loc, "operands must be integer types, not ")
4870  << lhs << " and " << rhs;
4871  else
4872  mlir::emitError(*loc, "operand signedness must match");
4873  }
4874  return false;
4875  }
4876 
4877  lhsWidth = lhsi.getWidthOrSentinel();
4878  rhsWidth = rhsi.getWidthOrSentinel();
4879  isConstResult = lhsi.isConst() && rhsi.isConst();
4880  return true;
4881 }
4882 
4883 LogicalResult impl::verifySameOperandsIntTypeKind(Operation *op) {
4884  assert(op->getNumOperands() == 2 &&
4885  "SameOperandsIntTypeKind on non-binary op");
4886  int32_t lhsWidth, rhsWidth;
4887  bool isConstResult;
4888  return success(isSameIntTypeKind(op->getOperand(0).getType(),
4889  op->getOperand(1).getType(), lhsWidth,
4890  rhsWidth, isConstResult, op->getLoc()));
4891 }
4892 
4893 LogicalResult impl::validateBinaryOpArguments(ValueRange operands,
4894  ArrayRef<NamedAttribute> attrs,
4895  Location loc) {
4896  if (operands.size() != 2 || !attrs.empty()) {
4897  mlir::emitError(loc, "operation requires two operands and no constants");
4898  return failure();
4899  }
4900  return success();
4901 }
4902 
                                     std::optional<Location> loc) {
  // NOTE(review): the function header is outside this view. From the body,
  // this computes the result type of a width-growing binary integer op.
  int32_t lhsWidth, rhsWidth, resultWidth = -1;
  bool isConstResult = false;
  // Operands must be integers of matching signedness. Otherwise a null type
  // is returned (with a diagnostic at `loc` if one was provided).
  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
    return {};

  // One bit wider than the wider operand. If either width is unknown (-1),
  // the result width stays unknown (-1).
  if (lhsWidth != -1 && rhsWidth != -1)
    resultWidth = std::max(lhsWidth, rhsWidth) + 1;
  // Same signedness as the operands; const only if both operands are const.
  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
                      isConstResult);
}
4915 
4916 FIRRTLType MulPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4917  std::optional<Location> loc) {
4918  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4919  bool isConstResult = false;
4920  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4921  return {};
4922 
4923  if (lhsWidth != -1 && rhsWidth != -1)
4924  resultWidth = lhsWidth + rhsWidth;
4925 
4926  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4927  isConstResult);
4928 }
4929 
4930 FIRRTLType DivPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4931  std::optional<Location> loc) {
4932  int32_t lhsWidth, rhsWidth;
4933  bool isConstResult = false;
4934  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4935  return {};
4936 
4937  // For unsigned, the width is the width of the numerator on the LHS.
4938  if (type_isa<UIntType>(lhs))
4939  return UIntType::get(lhs.getContext(), lhsWidth, isConstResult);
4940 
4941  // For signed, the width is the width of the numerator on the LHS, plus 1.
4942  int32_t resultWidth = lhsWidth != -1 ? lhsWidth + 1 : -1;
4943  return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
4944 }
4945 
4946 FIRRTLType RemPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
4947  std::optional<Location> loc) {
4948  int32_t lhsWidth, rhsWidth, resultWidth = -1;
4949  bool isConstResult = false;
4950  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
4951  return {};
4952 
4953  if (lhsWidth != -1 && rhsWidth != -1)
4954  resultWidth = std::min(lhsWidth, rhsWidth);
4955  return IntType::get(lhs.getContext(), type_isa<SIntType>(lhs), resultWidth,
4956  isConstResult);
4957 }
4958 
                                      std::optional<Location> loc) {
  // NOTE(review): the function header is outside this view. Elsewhere in this
  // file it is called as impl::inferBitwiseResult.
  int32_t lhsWidth, rhsWidth, resultWidth = -1;
  bool isConstResult = false;
  // Operands must be integers of matching signedness. Otherwise a null type
  // is returned (with a diagnostic at `loc` if one was provided).
  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
    return {};

  // The result is always unsigned and as wide as the wider operand (unknown if
  // either width is unknown). If an operand already has exactly that type
  // (UInt, matching width, matching constness), that operand type is returned
  // as-is instead of constructing a new one.
  if (lhsWidth != -1 && rhsWidth != -1) {
    resultWidth = std::max(lhsWidth, rhsWidth);
    if (lhsWidth == resultWidth && lhs.isConst() == isConstResult &&
        isa<UIntType>(lhs))
      return lhs;
    if (rhsWidth == resultWidth && rhs.isConst() == isConstResult &&
        isa<UIntType>(rhs))
      return rhs;
  }
  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
}
4977 
                                          std::optional<Location> loc) {
  // NOTE(review): the function header is outside this view. The body combines
  // two vectors element by element using the bitwise inference rule.

  // Both operands must be vectors. Otherwise a null type is returned without
  // emitting a diagnostic on this path.
  if (!type_isa<FVectorType>(lhs) || !type_isa<FVectorType>(rhs))
    return {};

  auto lhsVec = type_cast<FVectorType>(lhs);
  auto rhsVec = type_cast<FVectorType>(rhs);

  // The vectors must have the same length (again rejected silently).
  if (lhsVec.getNumElements() != rhsVec.getNumElements())
    return {};

  // Combine the element types, keeping each vector's constness on its
  // elements.
  auto elemType =
      impl::inferBitwiseResult(lhsVec.getElementTypePreservingConst(),
                               rhsVec.getElementTypePreservingConst(), loc);
  if (!elemType)
    return {};
  auto elemBaseType = type_cast<FIRRTLBaseType>(elemType);
  // The result vector is const only if both input vectors and the combined
  // element type are const.
  return FVectorType::get(elemBaseType, lhsVec.getNumElements(),
                          lhsVec.isConst() && rhsVec.isConst() &&
                              elemBaseType.isConst());
}
4999 
                                        std::optional<Location> loc) {
  // NOTE(review): the function header is outside this view. The result is a
  // single unsigned bit, const only if both operands are const. The operand
  // kinds are not checked here.
  return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
}
5004 
5005 FIRRTLType CatPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5006  std::optional<Location> loc) {
5007  int32_t lhsWidth, rhsWidth, resultWidth = -1;
5008  bool isConstResult = false;
5009  if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
5010  return {};
5011 
5012  if (lhsWidth != -1 && rhsWidth != -1)
5013  resultWidth = lhsWidth + rhsWidth;
5014  return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
5015 }
5016 
5017 FIRRTLType DShlPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5018  std::optional<Location> loc) {
5019  auto lhsi = type_dyn_cast<IntType>(lhs);
5020  auto rhsui = type_dyn_cast<UIntType>(rhs);
5021  if (!rhsui || !lhsi)
5022  return emitInferRetTypeError(
5023  loc, "first operand should be integer, second unsigned int");
5024 
5025  // If the left or right has unknown result type, then the operation does
5026  // too.
5027  auto width = lhsi.getWidthOrSentinel();
5028  if (width == -1 || !rhsui.getWidth().has_value()) {
5029  width = -1;
5030  } else {
5031  auto amount = *rhsui.getWidth();
5032  if (amount >= 32)
5033  return emitInferRetTypeError(loc,
5034  "shift amount too large: second operand of "
5035  "dshl is wider than 31 bits");
5036  int64_t newWidth = (int64_t)width + ((int64_t)1 << amount) - 1;
5037  if (newWidth > INT32_MAX)
5038  return emitInferRetTypeError(
5039  loc, "shift amount too large: first operand shifted by maximum "
5040  "amount exceeds maximum width");
5041  width = newWidth;
5042  }
5043  return IntType::get(lhs.getContext(), lhsi.isSigned(), width,
5044  lhsi.isConst() && rhsui.isConst());
5045 }
5046 
5047 FIRRTLType DShlwPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5048  std::optional<Location> loc) {
5049  auto lhsi = type_dyn_cast<IntType>(lhs);
5050  auto rhsu = type_dyn_cast<UIntType>(rhs);
5051  if (!lhsi || !rhsu)
5052  return emitInferRetTypeError(
5053  loc, "first operand should be integer, second unsigned int");
5054  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5055 }
5056 
5057 FIRRTLType DShrPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
5058  std::optional<Location> loc) {
5059  auto lhsi = type_dyn_cast<IntType>(lhs);
5060  auto rhsu = type_dyn_cast<UIntType>(rhs);
5061  if (!lhsi || !rhsu)
5062  return emitInferRetTypeError(
5063  loc, "first operand should be integer, second unsigned int");
5064  return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
5065 }
5066 
5067 //===----------------------------------------------------------------------===//
5068 // Unary Primitives
5069 //===----------------------------------------------------------------------===//
5070 
5071 LogicalResult impl::validateUnaryOpArguments(ValueRange operands,
5072  ArrayRef<NamedAttribute> attrs,
5073  Location loc) {
5074  if (operands.size() != 1 || !attrs.empty()) {
5075  mlir::emitError(loc, "operation requires one operand and no constants");
5076  return failure();
5077  }
5078  return success();
5079 }
5080 
5081 FIRRTLType
5082 SizeOfIntrinsicOp::inferUnaryReturnType(FIRRTLType input,
5083  std::optional<Location> loc) {
5084  return UIntType::get(input.getContext(), 32);
5085 }
5086 
5087 FIRRTLType AsSIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5088  std::optional<Location> loc) {
5089  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5090  if (!base)
5091  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5092  int32_t width = base.getBitWidthOrSentinel();
5093  if (width == -2)
5094  return emitInferRetTypeError(loc, "operand must be a scalar type");
5095  return SIntType::get(input.getContext(), width, base.isConst());
5096 }
5097 
5098 FIRRTLType AsUIntPrimOp::inferUnaryReturnType(FIRRTLType input,
5099  std::optional<Location> loc) {
5100  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5101  if (!base)
5102  return emitInferRetTypeError(loc, "operand must be a scalar base type");
5103  int32_t width = base.getBitWidthOrSentinel();
5104  if (width == -2)
5105  return emitInferRetTypeError(loc, "operand must be a scalar type");
5106  return UIntType::get(input.getContext(), width, base.isConst());
5107 }
5108 
5109 FIRRTLType
5110 AsAsyncResetPrimOp::inferUnaryReturnType(FIRRTLType input,
5111  std::optional<Location> loc) {
5112  auto base = type_dyn_cast<FIRRTLBaseType>(input);
5113  if (!base)
5114  return emitInferRetTypeError(loc,
5115  "operand must be single bit scalar base type");
5116  int32_t width = base.getBitWidthOrSentinel();
5117  if (width == -2 || width == 0 || width > 1)
5118  return emitInferRetTypeError(loc, "operand must be single bit scalar type");
5119  return AsyncResetType::get(input.getContext(), base.isConst());
5120 }
5121 
5122 FIRRTLType AsClockPrimOp::inferUnaryReturnType(FIRRTLType input,
5123  std::optional<Location> loc) {
5124  return ClockType::get(input.getContext(), isConst(input));
5125 }
5126 
5127 FIRRTLType CvtPrimOp::inferUnaryReturnType(FIRRTLType input,
5128  std::optional<Location> loc) {
5129  if (auto uiType = type_dyn_cast<UIntType>(input)) {
5130  auto width = uiType.getWidthOrSentinel();
5131  if (width != -1)
5132  ++width;
5133  return SIntType::get(input.getContext(), width, uiType.isConst());
5134  }
5135 
5136  if (type_isa<SIntType>(input))
5137  return input;
5138 
5139  return emitInferRetTypeError(loc, "operand must have integer type");
5140 }
5141 
5142 FIRRTLType NegPrimOp::inferUnaryReturnType(FIRRTLType input,
5143  std::optional<Location> loc) {
5144  auto inputi = type_dyn_cast<IntType>(input);
5145  if (!inputi)
5146  return emitInferRetTypeError(loc, "operand must have integer type");
5147  int32_t width = inputi.getWidthOrSentinel();
5148  if (width != -1)
5149  ++width;
5150  return SIntType::get(input.getContext(), width, inputi.isConst());
5151 }
5152 
5153 FIRRTLType NotPrimOp::inferUnaryReturnType(FIRRTLType input,
5154  std::optional<Location> loc) {
5155  auto inputi = type_dyn_cast<IntType>(input);
5156  if (!inputi)
5157  return emitInferRetTypeError(loc, "operand must have integer type");
5158  if (isa<UIntType>(inputi))
5159  return inputi;
5160  return UIntType::get(input.getContext(), inputi.getWidthOrSentinel(),
5161  inputi.isConst());
5162 }
5163 
5165  std::optional<Location> loc) {
  // The result is always a 1-bit UInt whose constness follows the operand;
  // the operand type is not otherwise inspected here.
5166  return UIntType::get(input.getContext(), 1, isConst(input));
5167 }
5168 
5169 //===----------------------------------------------------------------------===//
5170 // Other Operations
5171 //===----------------------------------------------------------------------===//
5172 
5173 LogicalResult BitsPrimOp::validateArguments(ValueRange operands,
5174  ArrayRef<NamedAttribute> attrs,
5175  Location loc) {
5176  if (operands.size() != 1 || attrs.size() != 2) {
5177  mlir::emitError(loc, "operation requires one operand and two constants");
5178  return failure();
5179  }
5180  return success();
5181 }
5182 
5183 FIRRTLType BitsPrimOp::inferReturnType(ValueRange operands,
5184  ArrayRef<NamedAttribute> attrs,
5185  std::optional<Location> loc) {
5186  auto input = operands[0].getType();
5187  auto high = getAttr<IntegerAttr>(attrs, "hi").getValue().getSExtValue();
5188  auto low = getAttr<IntegerAttr>(attrs, "lo").getValue().getSExtValue();
5189 
5190  auto inputi = type_dyn_cast<IntType>(input);
5191  if (!inputi)
5192  return emitInferRetTypeError(
5193  loc, "input type should be the int type but got ", input);
5194 
5195  // High must be >= low and both most be non-negative.
5196  if (high < low)
5197  return emitInferRetTypeError(
5198  loc, "high must be equal or greater than low, but got high = ", high,
5199  ", low = ", low);
5200 
5201  if (low < 0)
5202  return emitInferRetTypeError(loc, "low must be non-negative but got ", low);
5203 
5204  // If the input has staticly known width, check it. Both and low must be
5205  // strictly less than width.
5206  int32_t width = inputi.getWidthOrSentinel();
5207  if (width != -1 && high >= width)
5208  return emitInferRetTypeError(
5209  loc,
5210  "high must be smaller than the width of input, but got high = ", high,
5211  ", width = ", width);
5212 
5213  return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
5214 }
5215 
5216 LogicalResult impl::validateOneOperandOneConst(ValueRange operands,
5217  ArrayRef<NamedAttribute> attrs,
5218  Location loc) {
5219  if (operands.size() != 1 || attrs.size() != 1) {
5220  mlir::emitError(loc, "operation requires one operand and one constant");
5221  return failure();
5222  }
5223  return success();
5224 }
5225 
5226 FIRRTLType HeadPrimOp::inferReturnType(ValueRange operands,
5227  ArrayRef<NamedAttribute> attrs,
5228  std::optional<Location> loc) {
5229  auto input = operands[0].getType();
5230  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5231 
5232  auto inputi = type_dyn_cast<IntType>(input);
5233  if (amount < 0 || !inputi)
5234  return emitInferRetTypeError(
5235  loc, "operand must have integer type and amount must be >= 0");
5236 
5237  int32_t width = inputi.getWidthOrSentinel();
5238  if (width != -1 && amount > width)
5239  return emitInferRetTypeError(loc, "amount larger than input width");
5240 
5241  return UIntType::get(input.getContext(), amount, inputi.isConst());
5242 }
5243 
5244 LogicalResult MuxPrimOp::validateArguments(ValueRange operands,
5245  ArrayRef<NamedAttribute> attrs,
5246  Location loc) {
5247  if (operands.size() != 3 || attrs.size() != 0) {
5248  mlir::emitError(loc, "operation requires three operands and no constants");
5249  return failure();
5250  }
5251  return success();
5252 }
5253 
5254 /// Infer the result type for a multiplexer given its two operand types, which
5255 /// may be aggregates.
5256 ///
5257 /// This essentially performs a pairwise comparison of fields and elements, as
5258 /// follows:
5259 /// - Identical operands inferred to their common type
5260 /// - Integer operands inferred to the larger one if both have a known width, a
5261 /// widthless integer otherwise.
5262 /// - Vectors inferred based on the element type.
5263 /// - Bundles inferred in a pairwise fashion based on the field types.
///
/// Constness: identical arms keep their constness only under a const
/// condition (otherwise all const is dropped). For differing arms, the outer
/// type is const only if the condition and both arms are const. Nested
/// elements are inferred recursively with the same condition constness.
///
/// Returns a null type if the arms are incompatible, reporting the mismatch
/// through emitInferRetTypeError (or silently when a nested inference has
/// already failed).
5265  FIRRTLBaseType low,
5266  bool isConstCondition,
5267  std::optional<Location> loc) {
5268  // If the types are identical we're done.
5269  if (high == low)
5270  return isConstCondition ? low : low.getAllConstDroppedType();
5271 
5272  // The base types need to be equivalent.
5273  if (high.getTypeID() != low.getTypeID())
5274  return emitInferRetTypeError<FIRRTLBaseType>(
5275  loc, "incompatible mux operand types, true value type: ", high,
5276  ", false value type: ", low);
5277 
5278  bool outerTypeIsConst = isConstCondition && low.isConst() && high.isConst();
5279 
5280  // Two different Int types can be compatible. If either has unknown width,
5281  // then return it. If both are known but different width, then return the
5282  // larger one.
5283  if (type_isa<IntType>(low)) {
5284  int32_t highWidth = high.getBitWidthOrSentinel();
5285  int32_t lowWidth = low.getBitWidthOrSentinel();
5286  if (lowWidth == -1)
5287  return low.getConstType(outerTypeIsConst);
5288  if (highWidth == -1)
5289  return high.getConstType(outerTypeIsConst);
5290  return (lowWidth > highWidth ? low : high).getConstType(outerTypeIsConst);
5291  }
5292 
5293  // Infer vector types by comparing the element types.
5294  auto highVector = type_dyn_cast<FVectorType>(high);
5295  auto lowVector = type_dyn_cast<FVectorType>(low);
5296  if (highVector && lowVector &&
5297  highVector.getNumElements() == lowVector.getNumElements()) {
5298  auto inner = inferMuxReturnType(highVector.getElementTypePreservingConst(),
5299  lowVector.getElementTypePreservingConst(),
5300  isConstCondition, loc);
5301  if (!inner)
5302  return {};
5303  return FVectorType::get(inner, lowVector.getNumElements(),
5304  outerTypeIsConst);
5305  }
  // Vectors of different lengths fall through to the final error below.
5306 
5307  // Infer bundle types by inferring names in a pairwise fashion.
5308  auto highBundle = type_dyn_cast<BundleType>(high);
5309  auto lowBundle = type_dyn_cast<BundleType>(low);
5310  if (highBundle && lowBundle) {
5311  auto highElements = highBundle.getElements();
5312  auto lowElements = lowBundle.getElements();
5313  size_t numElements = highElements.size();
5314 
  // Fields must agree pairwise on count, name, and flip. Any mismatch falls
  // through to the "incompatible mux operand bundle fields" error.
5315  SmallVector<BundleType::BundleElement> newElements;
5316  if (numElements == lowElements.size()) {
5317  bool failed = false;
5318  for (size_t i = 0; i < numElements; ++i) {
5319  if (highElements[i].name != lowElements[i].name ||
5320  highElements[i].isFlip != lowElements[i].isFlip) {
5321  failed = true;
5322  break;
5323  }
5324  auto element = highElements[i];
5325  element.type = inferMuxReturnType(
5326  highBundle.getElementTypePreservingConst(i),
5327  lowBundle.getElementTypePreservingConst(i), isConstCondition, loc);
5328  if (!element.type)
5329  return {};
5330  newElements.push_back(element);
5331  }
5332  if (!failed)
5333  return BundleType::get(low.getContext(), newElements, outerTypeIsConst);
5334  }
5335  return emitInferRetTypeError<FIRRTLBaseType>(
5336  loc, "incompatible mux operand bundle fields, true value type: ", high,
5337  ", false value type: ", low);
5338  }
5339 
5340  // If we arrive here the types of the two mux arms are fundamentally
5341  // incompatible.
5342  return emitInferRetTypeError<FIRRTLBaseType>(
5343  loc, "invalid mux operand types, true value type: ", high,
5344  ", false value type: ", low);
5345 }
5346 
5347 FIRRTLType MuxPrimOp::inferReturnType(ValueRange operands,
5348  ArrayRef<NamedAttribute> attrs,
5349  std::optional<Location> loc) {
5350  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5351  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5352  if (!highType || !lowType)
5353  return emitInferRetTypeError(loc, "operands must be base type");
5354  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5355  loc);
5356 }
5357 
5358 FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
5359  ArrayRef<NamedAttribute> attrs,
5360  std::optional<Location> loc) {
5361  auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
5362  auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
5363  if (!highType || !lowType)
5364  return emitInferRetTypeError(loc, "operands must be base type");
5365  return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
5366  loc);
5367 }
5368 
5369 FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
5370  ArrayRef<NamedAttribute> attrs,
5371  std::optional<Location> loc) {
5372  SmallVector<FIRRTLBaseType> types;
5373  FIRRTLBaseType result;
5374  for (unsigned i = 1; i < 5; i++) {
5375  types.push_back(type_dyn_cast<FIRRTLBaseType>(operands[i].getType()));
5376  if (!types.back())
5377  return emitInferRetTypeError(loc, "operands must be base type");
5378  if (result) {
5379  result = inferMuxReturnType(result, types.back(),
5380  isConst(operands[0].getType()), loc);
5381  if (!result)
5382  return result;
5383  } else {
5384  result = types.back();
5385  }
5386  }
5387  return result;
5388 }
5389 
5390 FIRRTLType PadPrimOp::inferReturnType(ValueRange operands,
5391  ArrayRef<NamedAttribute> attrs,
5392  std::optional<Location> loc) {
5393  auto input = operands[0].getType();
5394  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5395 
5396  auto inputi = type_dyn_cast<IntType>(input);
5397  if (amount < 0 || !inputi)
5398  return emitInferRetTypeError(
5399  loc, "pad input must be integer and amount must be >= 0");
5400 
5401  int32_t width = inputi.getWidthOrSentinel();
5402  if (width == -1)
5403  return inputi;
5404 
5405  width = std::max<int32_t>(width, amount);
5406  return IntType::get(input.getContext(), inputi.isSigned(), width,
5407  inputi.isConst());
5408 }
5409 
5410 FIRRTLType ShlPrimOp::inferReturnType(ValueRange operands,
5411  ArrayRef<NamedAttribute> attrs,
5412  std::optional<Location> loc) {
5413  auto input = operands[0].getType();
5414  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5415 
5416  auto inputi = type_dyn_cast<IntType>(input);
5417  if (amount < 0 || !inputi)
5418  return emitInferRetTypeError(
5419  loc, "shl input must be integer and amount must be >= 0");
5420 
5421  int32_t width = inputi.getWidthOrSentinel();
5422  if (width != -1)
5423  width += amount;
5424 
5425  return IntType::get(input.getContext(), inputi.isSigned(), width,
5426  inputi.isConst());
5427 }
5428 
5429 FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
5430  ArrayRef<NamedAttribute> attrs,
5431  std::optional<Location> loc) {
5432  auto input = operands[0].getType();
5433  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5434 
5435  auto inputi = type_dyn_cast<IntType>(input);
5436  if (amount < 0 || !inputi)
5437  return emitInferRetTypeError(
5438  loc, "shr input must be integer and amount must be >= 0");
5439 
5440  int32_t width = inputi.getWidthOrSentinel();
5441  if (width != -1) {
5442  // UInt saturates at 0 bits, SInt at 1 bit
5443  int32_t minWidth = inputi.isUnsigned() ? 0 : 1;
5444  width = std::max<int32_t>(minWidth, width - amount);
5445  }
5446 
5447  return IntType::get(input.getContext(), inputi.isSigned(), width,
5448  inputi.isConst());
5449 }
5450 
5451 FIRRTLType TailPrimOp::inferReturnType(ValueRange operands,
5452  ArrayRef<NamedAttribute> attrs,
5453  std::optional<Location> loc) {
5454  auto input = operands[0].getType();
5455  auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
5456 
5457  auto inputi = type_dyn_cast<IntType>(input);
5458  if (amount < 0 || !inputi)
5459  return emitInferRetTypeError(
5460  loc, "tail input must be integer and amount must be >= 0");
5461 
5462  int32_t width = inputi.getWidthOrSentinel();
5463  if (width != -1) {
5464  if (width < amount)
5465  return emitInferRetTypeError(
5466  loc, "amount must be less than or equal operand width");
5467  width -= amount;
5468  }
5469 
5470  return IntType::get(input.getContext(), false, width, inputi.isConst());
5471 }
5472 
5473 //===----------------------------------------------------------------------===//
5474 // VerbatimExprOp
5475 //===----------------------------------------------------------------------===//
5476 
5478  function_ref<void(Value, StringRef)> setNameFn) {
  // Suggests an SSA name for the result from the leading identifier
  // characters (alphanumerics and '_') of the verbatim text.
5479  // If the text is macro like, then use a pretty name. We only take the
5480  // text up to a weird character (like a paren) and currently ignore
5481  // parenthesized expressions.
5482  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5483  auto name = getText();
5484  // Ignore a leading ` in macro name.
5485  if (name.starts_with("`"))
5486  name = name.drop_front();
5487  name = name.take_while(isOkCharacter);
  // If the text starts with a non-identifier character, no name is suggested.
5488  if (!name.empty())
5489  setNameFn(getResult(), name);
5490 }
5491 
5492 //===----------------------------------------------------------------------===//
5493 // VerbatimWireOp
5494 //===----------------------------------------------------------------------===//
5495 
5497  function_ref<void(Value, StringRef)> setNameFn) {
  // Suggests an SSA name for the result from the leading identifier
  // characters (alphanumerics and '_') of the verbatim text.
5498  // If the text is macro like, then use a pretty name. We only take the
5499  // text up to a weird character (like a paren) and currently ignore
5500  // parenthesized expressions.
5501  auto isOkCharacter = [](char c) { return llvm::isAlnum(c) || c == '_'; };
5502  auto name = getText();
5503  // Ignore a leading ` in macro name.
5504  if (name.starts_with("`"))
5505  name = name.drop_front();
5506  name = name.take_while(isOkCharacter);
  // If the text starts with a non-identifier character, no name is suggested.
5507  if (!name.empty())
5508  setNameFn(getResult(), name);
5509 }
5510 
5511 //===----------------------------------------------------------------------===//
5512 // DPICallIntrinsicOp
5513 //===----------------------------------------------------------------------===//
5514 
5515 static bool isTypeAllowedForDPI(Operation *op, Type type) {
5516  return !type.walk([&](firrtl::IntType intType) -> mlir::WalkResult {
5517  auto width = intType.getWidth();
5518  if (width < 0) {
5519  op->emitError() << "unknown width is not allowed for DPI";
5520  return WalkResult::interrupt();
5521  }
5522  if (width == 1 || width == 8 || width == 16 || width == 32 ||
5523  width >= 64)
5524  return WalkResult::advance();
5525  op->emitError()
5526  << "integer types used by DPI functions must have a "
5527  "specific bit width; "
5528  "it must be equal to 1(bit), 8(byte), 16(shortint), "
5529  "32(int), 64(longint) "
5530  "or greater than 64, but got "
5531  << intType;
5532  return WalkResult::interrupt();
5533  })
5534  .wasInterrupted();
5535 }
5536 
5537 LogicalResult DPICallIntrinsicOp::verify() {
5538  auto checkType = [this](Type type) {
5539  return isTypeAllowedForDPI(*this, type);
5540  };
5541  return success(llvm::all_of(this->getResultTypes(), checkType) &&
5542  llvm::all_of(this->getOperandTypes(), checkType));
5543 }
5544 
5545 //===----------------------------------------------------------------------===//
5546 // Conversions to/from structs in the standard dialect.
5547 //===----------------------------------------------------------------------===//
5548 
5549 LogicalResult HWStructCastOp::verify() {
5550  // We must have a bundle and a struct, with matching pairwise fields
5551  BundleType bundleType;
5552  hw::StructType structType;
5553  if ((bundleType = type_dyn_cast<BundleType>(getOperand().getType()))) {
5554  structType = dyn_cast<hw::StructType>(getType());
5555  if (!structType)
5556  return emitError("result type must be a struct");
5557  } else if ((bundleType = type_dyn_cast<BundleType>(getType()))) {
5558  structType = dyn_cast<hw::StructType>(getOperand().getType());
5559  if (!structType)
5560  return emitError("operand type must be a struct");
5561  } else {
5562  return emitError("either source or result type must be a bundle type");
5563  }
5564 
5565  auto firFields = bundleType.getElements();
5566  auto hwFields = structType.getElements();
5567  if (firFields.size() != hwFields.size())
5568  return emitError("bundle and struct have different number of fields");
5569 
5570  for (size_t findex = 0, fend = firFields.size(); findex < fend; ++findex) {
5571  if (firFields[findex].name.getValue() != hwFields[findex].name)
5572  return emitError("field names don't match '")
5573  << firFields[findex].name.getValue() << "', '"
5574  << hwFields[findex].name.getValue() << "'";
5575  int64_t firWidth =
5576  FIRRTLBaseType(firFields[findex].type).getBitWidthOrSentinel();
5577  int64_t hwWidth = hw::getBitWidth(hwFields[findex].type);
5578  if (firWidth > 0 && hwWidth > 0 && firWidth != hwWidth)
5579  return emitError("size of field '")
5580  << hwFields[findex].name.getValue() << "' don't match " << firWidth
5581  << ", " << hwWidth;
5582  }
5583 
5584  return success();
5585 }
5586 
5587 LogicalResult BitCastOp::verify() {
5588  auto inTypeBits = getBitWidth(getInput().getType(), /*ignoreFlip=*/true);
5589  auto resTypeBits = getBitWidth(getType());
5590  if (inTypeBits.has_value() && resTypeBits.has_value()) {
5591  // Bitwidths must match for valid bit
5592  if (*inTypeBits == *resTypeBits) {
5593  // non-'const' cannot be casted to 'const'
5594  if (containsConst(getType()) && !isConst(getOperand().getType()))
5595  return emitError("cannot cast non-'const' input type ")
5596  << getOperand().getType() << " to 'const' result type "
5597  << getType();
5598  return success();
5599  }
5600  return emitError("the bitwidth of input (")
5601  << *inTypeBits << ") and result (" << *resTypeBits
5602  << ") don't match";
5603  }
5604  if (!inTypeBits.has_value())
5605  return emitError("bitwidth cannot be determined for input operand type ")
5606  << getInput().getType();
5607  return emitError("bitwidth cannot be determined for result type ")
5608  << getType();
5609 }
5610 
5611 //===----------------------------------------------------------------------===//
5612 // Custom attr-dict Directive that Elides Annotations
5613 //===----------------------------------------------------------------------===//
5614 
5615 /// Parse an optional attribute dictionary, adding an empty 'annotations'
5616 /// attribute if not specified.
5617 static ParseResult parseElideAnnotations(OpAsmParser &parser,
5618  NamedAttrList &resultAttrs) {
5619  auto result = parser.parseOptionalAttrDict(resultAttrs);
5620  if (!resultAttrs.get("annotations"))
5621  resultAttrs.append("annotations", parser.getBuilder().getArrayAttr({}));
5622 
5623  return result;
5624 }
5625 
5626 static void printElideAnnotations(OpAsmPrinter &p, Operation *op,
5627  DictionaryAttr attr,
5628  ArrayRef<StringRef> extraElides = {}) {
5629  SmallVector<StringRef> elidedAttrs(extraElides.begin(), extraElides.end());
5630  // Elide "annotations" if it is empty.
5631  if (op->getAttrOfType<ArrayAttr>("annotations").empty())
5632  elidedAttrs.push_back("annotations");
5633  // Elide "nameKind".
5634  elidedAttrs.push_back("nameKind");
5635 
5636  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5637 }
5638 
5639 /// Parse an optional attribute dictionary, adding empty 'annotations' and
5640 /// 'portAnnotations' attributes if not specified.
5641 static ParseResult parseElidePortAnnotations(OpAsmParser &parser,
5642  NamedAttrList &resultAttrs) {
5643  auto result = parseElideAnnotations(parser, resultAttrs);
5644 
5645  if (!resultAttrs.get("portAnnotations")) {
5646  SmallVector<Attribute, 16> portAnnotations(
5647  parser.getNumResults(), parser.getBuilder().getArrayAttr({}));
5648  resultAttrs.append("portAnnotations",
5649  parser.getBuilder().getArrayAttr(portAnnotations));
5650  }
5651  return result;
5652 }
5653 
5654 // Elide 'annotations' and 'portAnnotations' attributes if they are empty.
5655 static void printElidePortAnnotations(OpAsmPrinter &p, Operation *op,
5656  DictionaryAttr attr,
5657  ArrayRef<StringRef> extraElides = {}) {
5658  SmallVector<StringRef, 2> elidedAttrs(extraElides.begin(), extraElides.end());
5659 
5660  if (llvm::all_of(op->getAttrOfType<ArrayAttr>("portAnnotations"),
5661  [&](Attribute a) { return cast<ArrayAttr>(a).empty(); }))
5662  elidedAttrs.push_back("portAnnotations");
5663  printElideAnnotations(p, op, attr, elidedAttrs);
5664 }
5665 
5666 //===----------------------------------------------------------------------===//
5667 // NameKind Custom Directive
5668 //===----------------------------------------------------------------------===//
5669 
5670 static ParseResult parseNameKind(OpAsmParser &parser,
5671  firrtl::NameKindEnumAttr &result) {
5672  StringRef keyword;
5673 
5674  if (!parser.parseOptionalKeyword(&keyword,
5675  {"interesting_name", "droppable_name"})) {
5676  auto kind = symbolizeNameKindEnum(keyword);
5677  result = NameKindEnumAttr::get(parser.getContext(), kind.value());
5678  return success();
5679  }
5680 
5681  // Default is droppable name.
5682  result =
5683  NameKindEnumAttr::get(parser.getContext(), NameKindEnum::DroppableName);
5684  return success();
5685 }
5686 
5687 static void printNameKind(OpAsmPrinter &p, Operation *op,
5688  firrtl::NameKindEnumAttr attr,
5689  ArrayRef<StringRef> extraElides = {}) {
5690  if (attr.getValue() != NameKindEnum::DroppableName)
5691  p << " " << stringifyNameKindEnum(attr.getValue());
5692 }
5693 
5694 //===----------------------------------------------------------------------===//
5695 // ImplicitSSAName Custom Directive
5696 //===----------------------------------------------------------------------===//
5697 
5698 static ParseResult parseFIRRTLImplicitSSAName(OpAsmParser &parser,
5699  NamedAttrList &resultAttrs) {
5700  if (parseElideAnnotations(parser, resultAttrs))
5701  return failure();
5702  inferImplicitSSAName(parser, resultAttrs);
5703  return success();
5704 }
5705 
/// Print an op's attribute dictionary for the implicit-SSA-name syntax. The
/// forceable attribute is always elided, elideImplicitSSAName may add further
/// elisions (presumably the name when it matches the SSA result name, to be
/// confirmed), and printElideAnnotations handles annotation/nameKind elision.
5706 static void printFIRRTLImplicitSSAName(OpAsmPrinter &p, Operation *op,
5707  DictionaryAttr attrs) {
5708  SmallVector<StringRef, 4> elides;
  // NOTE(review): listing line 5709 is missing here; it presumably pushes
  // another always-elided attribute name. Confirm against the upstream file.
5710  elides.push_back(Forceable::getForceableAttrName());
5711  elideImplicitSSAName(p, op, attrs, elides);
5712  printElideAnnotations(p, op, attrs, elides);
5713 }
5714 
5715 //===----------------------------------------------------------------------===//
5716 // MemOp Custom attr-dict Directive
5717 //===----------------------------------------------------------------------===//
5718 
5719 static ParseResult parseMemOp(OpAsmParser &parser, NamedAttrList &resultAttrs) {
5720  return parseElidePortAnnotations(parser, resultAttrs);
5721 }
5722 
5723 /// Always elide "ruw" and elide "annotations" if it exists or if it is empty.
5724 static void printMemOp(OpAsmPrinter &p, Operation *op, DictionaryAttr attr) {
5725  // "ruw" and "inner_sym" is always elided.
5726  printElidePortAnnotations(p, op, attr, {"ruw", "inner_sym"});
5727 }
5728 
5729 //===----------------------------------------------------------------------===//
5730 // ClassInterface custom directive
5731 //===----------------------------------------------------------------------===//
5732 
5733 static ParseResult parseClassInterface(OpAsmParser &parser, Type &result) {
5734  ClassType type;
5735  if (ClassType::parseInterface(parser, type))
5736  return failure();
5737  result = type;
5738  return success();
5739 }
5740 
5741 static void printClassInterface(OpAsmPrinter &p, Operation *, ClassType type) {
5742  type.printInterface(p);
5743 }
5744 
5745 //===----------------------------------------------------------------------===//
5746 // Miscellaneous custom elision logic.
5747 //===----------------------------------------------------------------------===//
5748 
5749 static ParseResult parseElideEmptyName(OpAsmParser &p,
5750  NamedAttrList &resultAttrs) {
5751  auto result = p.parseOptionalAttrDict(resultAttrs);
5752  if (!resultAttrs.get("name"))
5753  resultAttrs.append("name", p.getBuilder().getStringAttr(""));
5754 
5755  return result;
5756 }
5757 
5758 static void printElideEmptyName(OpAsmPrinter &p, Operation *op,
5759  DictionaryAttr attr,
5760  ArrayRef<StringRef> extraElides = {}) {
5761  SmallVector<StringRef> elides(extraElides.begin(), extraElides.end());
5762  if (op->getAttrOfType<StringAttr>("name").getValue().empty())
5763  elides.push_back("name");
5764 
5765  p.printOptionalAttrDict(op->getAttrs(), elides);
5766 }
5767 
5768 static ParseResult parsePrintfAttrs(OpAsmParser &p,
5769  NamedAttrList &resultAttrs) {
5770  return parseElideEmptyName(p, resultAttrs);
5771 }
5772 
5773 static void printPrintfAttrs(OpAsmPrinter &p, Operation *op,
5774  DictionaryAttr attr) {
5775  printElideEmptyName(p, op, attr, {"formatString"});
5776 }
5777 
5778 static ParseResult parseStopAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5779  return parseElideEmptyName(p, resultAttrs);
5780 }
5781 
5782 static void printStopAttrs(OpAsmPrinter &p, Operation *op,
5783  DictionaryAttr attr) {
5784  printElideEmptyName(p, op, attr, {"exitCode"});
5785 }
5786 
5787 static ParseResult parseVerifAttrs(OpAsmParser &p, NamedAttrList &resultAttrs) {
5788  return parseElideEmptyName(p, resultAttrs);
5789 }
5790 
5791 static void printVerifAttrs(OpAsmPrinter &p, Operation *op,
5792  DictionaryAttr attr) {
5793  printElideEmptyName(p, op, attr, {"message"});
5794 }
5795 
5796 //===----------------------------------------------------------------------===//
5797 // Various namers.
5798 //===----------------------------------------------------------------------===//
5799 
5800 static void genericAsmResultNames(Operation *op,
5801  OpAsmSetValueNameFn setNameFn) {
5802  // Many firrtl dialect operations have an optional 'name' attribute. If
5803  // present, use it.
5804  if (op->getNumResults() == 1)
5805  if (auto nameAttr = op->getAttrOfType<StringAttr>("name"))
5806  setNameFn(op->getResult(0), nameAttr.getValue());
5807 }
5808 
  // Name the single result after the op's 'name' attribute, if present.
5810  genericAsmResultNames(*this, setNameFn);
5811 }
5812 
  // Name the single result after the op's 'name' attribute, if present.
5814  genericAsmResultNames(*this, setNameFn);
5815 }
5816 
  // Name the single result after the op's 'name' attribute, if present.
5818  genericAsmResultNames(*this, setNameFn);
5819 }
5820 
  // NOTE(review): the declaration lines of the methods below are missing from
  // this listing. Each body forwards to genericAsmResultNames, which names a
  // single result after the op's 'name' attribute when present.
5822  genericAsmResultNames(*this, setNameFn);
5823 }
5825  genericAsmResultNames(*this, setNameFn);
5826 }
5828  genericAsmResultNames(*this, setNameFn);
5829 }
5831  genericAsmResultNames(*this, setNameFn);
5832 }
5834  genericAsmResultNames(*this, setNameFn);
5835 }
5837  genericAsmResultNames(*this, setNameFn);
5838 }
5840  genericAsmResultNames(*this, setNameFn);
5841 }
5843  genericAsmResultNames(*this, setNameFn);
5844 }
5846  genericAsmResultNames(*this, setNameFn);
5847 }
5849  genericAsmResultNames(*this, setNameFn);
5850 }
5852  genericAsmResultNames(*this, setNameFn);
5853 }
5855  genericAsmResultNames(*this, setNameFn);
5856 }
5858  genericAsmResultNames(*this, setNameFn);
5859 }
5861  genericAsmResultNames(*this, setNameFn);
5862 }
5864  genericAsmResultNames(*this, setNameFn);
5865 }
5867  genericAsmResultNames(*this, setNameFn);
5868 }
5870  genericAsmResultNames(*this, setNameFn);
5871 }
5873  genericAsmResultNames(*this, setNameFn);
5874 }
5876  genericAsmResultNames(*this, setNameFn);
5877 }
5879  genericAsmResultNames(*this, setNameFn);
5880 }
5882  genericAsmResultNames(*this, setNameFn);
5883 }
5885  genericAsmResultNames(*this, setNameFn);
5886 }
5888  OpAsmSetValueNameFn setNameFn) {
5889  genericAsmResultNames(*this, setNameFn);
5890 }
5892  genericAsmResultNames(*this, setNameFn);
5893 }
5895  genericAsmResultNames(*this, setNameFn);
5896 }
5898  genericAsmResultNames(*this, setNameFn);
5899 }
5901  genericAsmResultNames(*this, setNameFn);
5902 }
5904  genericAsmResultNames(*this, setNameFn);
5905 }
5907  genericAsmResultNames(*this, setNameFn);
5908 }
5910  genericAsmResultNames(*this, setNameFn);
5911 }
5913  genericAsmResultNames(*this, setNameFn);
5914 }
5916  genericAsmResultNames(*this, setNameFn);
5917 }
5919  genericAsmResultNames(*this, setNameFn);
5920 }
5922  genericAsmResultNames(*this, setNameFn);
5923 }
5925  genericAsmResultNames(*this, setNameFn);
5926 }
5928  genericAsmResultNames(*this, setNameFn);
5929 }
5931  genericAsmResultNames(*this, setNameFn);
5932 }
5934  genericAsmResultNames(*this, setNameFn);
5935 }
5937  genericAsmResultNames(*this, setNameFn);
5938 }
5940  genericAsmResultNames(*this, setNameFn);
5941 }
5942 
5944  genericAsmResultNames(*this, setNameFn);
5945 }
5946 
5948  genericAsmResultNames(*this, setNameFn);
5949 }
5950 
5952  genericAsmResultNames(*this, setNameFn);
5953 }
5954 
5956  genericAsmResultNames(*this, setNameFn);
5957 }
5958 
5960  genericAsmResultNames(*this, setNameFn);
5961 }
5962 
5964  genericAsmResultNames(*this, setNameFn);
5965 }
5966 
5968  genericAsmResultNames(*this, setNameFn);
5969 }
5970 
5972  genericAsmResultNames(*this, setNameFn);
5973 }
5974 
5976  genericAsmResultNames(*this, setNameFn);
5977 }
5978 
5980  genericAsmResultNames(*this, setNameFn);
5981 }
5982 
5984  genericAsmResultNames(*this, setNameFn);
5985 }
5986 
5988  genericAsmResultNames(*this, setNameFn);
5989 }
5990 
5992  genericAsmResultNames(*this, setNameFn);
5993 }
5994 
5996  genericAsmResultNames(*this, setNameFn);
5997 }
5998 
6000  genericAsmResultNames(*this, setNameFn);
6001 }
6002 
6004  genericAsmResultNames(*this, setNameFn);
6005 }
6006 
6007 //===----------------------------------------------------------------------===//
6008 // RefOps
6009 //===----------------------------------------------------------------------===//
6010 
6012  genericAsmResultNames(*this, setNameFn);
6013 }
6014 
6016  genericAsmResultNames(*this, setNameFn);
6017 }
6018 
6020  genericAsmResultNames(*this, setNameFn);
6021 }
6022 
6024  genericAsmResultNames(*this, setNameFn);
6025 }
6026 
6028  genericAsmResultNames(*this, setNameFn);
6029 }
6030 
6031 FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
6032  ArrayRef<NamedAttribute> attrs,
6033  std::optional<Location> loc) {
6034  Type inType = operands[0].getType();
6035  auto inRefType = type_dyn_cast<RefType>(inType);
6036  if (!inRefType)
6037  return emitInferRetTypeError(
6038  loc, "ref.resolve operand must be ref type, not ", inType);
6039  return inRefType.getType();
6040 }
6041 
6042 FIRRTLType RefSendOp::inferReturnType(ValueRange operands,
6043  ArrayRef<NamedAttribute> attrs,
6044  std::optional<Location> loc) {
6045  Type inType = operands[0].getType();
6046  auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
6047  if (!inBaseType)
6048  return emitInferRetTypeError(
6049  loc, "ref.send operand must be base type, not ", inType);
6050  return RefType::get(inBaseType.getPassiveType());
6051 }
6052 
6053 FIRRTLType RefSubOp::inferReturnType(ValueRange operands,
6054  ArrayRef<NamedAttribute> attrs,
6055  std::optional<Location> loc) {
6056  auto refType = type_dyn_cast<RefType>(operands[0].getType());
6057  if (!refType)
6058  return emitInferRetTypeError(loc, "input must be of reference type");
6059  auto inType = refType.getType();
6060  auto fieldIdx =
6061  getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
6062 
6063  // TODO: Determine ref.sub + rwprobe behavior, test.
6064  // Probably best to demote to non-rw, but that has implications
6065  // for any LowerTypes behavior being relied on.
6066  // Allow for now, as need to LowerTypes things generally.
6067  if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
6068  if (fieldIdx < vectorType.getNumElements())
6069  return RefType::get(
6070  vectorType.getElementType().getConstType(
6071  vectorType.isConst() || vectorType.getElementType().isConst()),
6072  refType.getForceable(), refType.getLayer());
6073  return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
6074  "' in RefType of vector type ", refType);
6075  }
6076  if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
6077  if (fieldIdx >= bundleType.getNumElements()) {
6078  return emitInferRetTypeError(loc,
6079  "subfield element index is greater than "
6080  "the number of fields in the bundle type");
6081  }
6082  return RefType::get(bundleType.getElement(fieldIdx).type.getConstType(
6083  bundleType.isConst() ||
6084  bundleType.getElement(fieldIdx).type.isConst()),
6085  refType.getForceable(), refType.getLayer());
6086  }
6087 
6088  return emitInferRetTypeError(
6089  loc, "ref.sub op requires a RefType of vector or bundle base type");
6090 }
6091 
6092 LogicalResult RefCastOp::verify() {
6093  auto srcLayers = getLayersFor(getInput());
6094  auto dstLayers = getLayersFor(getResult());
6095  SmallVector<SymbolRefAttr> missingLayers;
6096  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6097  auto diag =
6098  emitOpError("cannot discard layer requirements of input reference");
6099  auto &note = diag.attachNote();
6100  note << "discarding layer requirements: ";
6101  llvm::interleaveComma(missingLayers, note);
6102  return failure();
6103  }
6104  return success();
6105 }
6106 
6107 LogicalResult RefResolveOp::verify() {
6108  auto srcLayers = getLayersFor(getRef());
6109  auto dstLayers = getAmbientLayersAt(getOperation());
6110  SmallVector<SymbolRefAttr> missingLayers;
6111  if (!isLayerSetCompatibleWith(srcLayers, dstLayers, missingLayers)) {
6112  auto diag =
6113  emitOpError("ambient layers are insufficient to resolve reference");
6114  auto &note = diag.attachNote();
6115  note << "missing layer requirements: ";
6116  interleaveComma(missingLayers, note);
6117  return failure();
6118  }
6119  return success();
6120 }
6121 
/// Verify the inner symbol reference targeted by this rwprobe op.
///
/// The target must name the module that encloses this op and must resolve
/// through `ns`. A port target has its port type checked. Any other target
/// must implement `hw::InnerSymbolOpInterface` and expose a target result.
/// That target op must appear before this op in the target result's block,
/// possibly before an ancestor of this op. The target result's type is then
/// checked.
///
/// The type check selects the field given by the target's field ID. It
/// requires that field to be a FIRRTL base type whose passive form equals
/// `getType().getType()`, the inner type of this op's reference result.
LogicalResult RWProbeOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
  auto targetRef = getTarget();
  // Only targets inside the enclosing module are accepted.
  if (targetRef.getModule() !=
      (*this)->getParentOfType<FModuleLike>().getModuleNameAttr())
    return emitOpError() << "has non-local target";

  auto target = ns.lookup(targetRef);
  if (!target)
    return emitOpError() << "has target that cannot be resolved: " << targetRef;

  // Check that the field of `type` selected by the target's field ID is a
  // base type whose passive form matches the probed type. `loc` is used for
  // the note pointing at the target.
  auto checkFinalType = [&](auto type, Location loc) -> LogicalResult {
    // Determine final type.
    mlir::Type fType =
        hw::FieldIdImpl::getFinalTypeByFieldID(type, target.getField());
    // Check.
    auto baseType = type_dyn_cast<FIRRTLBaseType>(fType);
    if (!baseType || baseType.getPassiveType() != getType().getType()) {
      auto diag = emitOpError("has type mismatch: target resolves to ")
                  << fType << " instead of expected " << getType().getType();
      diag.attachNote(loc) << "target resolves here";
      return diag;
    }
    return success();
  };
  // Port target: check the port's type, reporting at the port's location.
  if (target.isPort()) {
    auto mod = cast<FModuleLike>(target.getOp());
    return checkFinalType(mod.getPortType(target.getPort()),
                          mod.getPortLocation(target.getPort()));
  }
  // Non-port target: the op must provide a result that can be probed.
  hw::InnerSymbolOpInterface symOp =
      cast<hw::InnerSymbolOpInterface>(target.getOp());
  if (!symOp.getTargetResult())
    return emitOpError("has target that cannot be probed")
        .attachNote(symOp.getLoc())
        .append("target resolves here");
  // Find the op in the target result's block that is (or contains) this op.
  // The target op must come strictly before it.
  // NOTE(review): isBeforeInBlock requires symOp to live in that same block.
  // That holds if the target result is one of symOp's own results; confirm
  // this for all InnerSymbolOpInterface implementers.
  auto *ancestor =
      symOp.getTargetResult().getParentBlock()->findAncestorOpInBlock(**this);
  if (!ancestor || !symOp->isBeforeInBlock(ancestor))
    return emitOpError("is not dominated by target")
        .attachNote(symOp.getLoc())
        .append("target here");
  return checkFinalType(symOp.getTargetResult().getType(), symOp.getLoc());
}
6165 
6166 //===----------------------------------------------------------------------===//
6167 // Layer Block Operations
6168 //===----------------------------------------------------------------------===//
6169 
6170 LogicalResult LayerBlockOp::verify() {
6171  auto layerName = getLayerName();
6172  auto *parentOp = (*this)->getParentOp();
6173 
6174  // Verify the correctness of the symbol reference. Only verify that this
6175  // layer block makes sense in its parent module or layer block.
6176  auto nestedReferences = layerName.getNestedReferences();
6177  if (nestedReferences.empty()) {
6178  if (!isa<FModuleOp>(parentOp)) {
6179  auto diag = emitOpError() << "has an un-nested layer symbol, but does "
6180  "not have a 'firrtl.module' op as a parent";
6181  return diag.attachNote(parentOp->getLoc())
6182  << "illegal parent op defined here";
6183  }
6184  } else {
6185  auto parentLayerBlock = dyn_cast<LayerBlockOp>(parentOp);
6186  if (!parentLayerBlock) {
6187  auto diag = emitOpError()
6188  << "has a nested layer symbol, but does not have a '"
6189  << getOperationName() << "' op as a parent'";
6190  return diag.attachNote(parentOp->getLoc())
6191  << "illegal parent op defined here";
6192  }
6193  auto parentLayerBlockName = parentLayerBlock.getLayerName();
6194  if (parentLayerBlockName.getRootReference() !=
6195  layerName.getRootReference() ||
6196  parentLayerBlockName.getNestedReferences() !=
6197  layerName.getNestedReferences().drop_back()) {
6198  auto diag = emitOpError() << "is nested under an illegal layer block";
6199  return diag.attachNote(parentLayerBlock->getLoc())
6200  << "illegal parent layer block defined here";
6201  }
6202  }
6203 
6204  // Verify the body of the region.
6205  auto result = getBody(0)->walk<mlir::WalkOrder::PreOrder>(
6206  [&](Operation *op) -> WalkResult {
6207  // Skip nested layer blocks. Those will be verified separately.
6208  if (isa<LayerBlockOp>(op))
6209  return WalkResult::skip();
6210 
6211  // Check all the operands of each op to make sure that only legal things
6212  // are captured.
6213  for (auto operand : op->getOperands()) {
6214  // Any value captured from the current layer block is fine.
6215  if (auto *definingOp = operand.getDefiningOp())
6216  if (getOperation()->isAncestor(definingOp))
6217  continue;
6218 
6219  auto type = operand.getType();
6220 
6221  // Capture of a non-base type, e.g., reference, is allowed.
6222  if (isa<PropertyType>(type)) {
6223  auto diag = emitOpError() << "captures a property operand";
6224  diag.attachNote(operand.getLoc()) << "operand is defined here";
6225  diag.attachNote(op->getLoc()) << "operand is used here";
6226  return WalkResult::interrupt();
6227  }
6228 
6229  // Capturing a non-passive type is illegal.
6230  if (auto baseType = type_dyn_cast<FIRRTLBaseType>(type)) {
6231  if (!baseType.isPassive()) {
6232  auto diag = emitOpError()
6233  << "captures an operand which is not a passive type";
6234  diag.attachNote(operand.getLoc()) << "operand is defined here";
6235  diag.attachNote(op->getLoc()) << "operand is used here";
6236  return WalkResult::interrupt();
6237  }
6238  }
6239  }
6240 
6241  // Ensure that the layer block does not drive any sinks outside.
6242  if (auto connect = dyn_cast<FConnectLike>(op)) {
6243  // ref.define is allowed to drive probes outside the layerblock.
6244  if (isa<RefDefineOp>(connect))
6245  return WalkResult::advance();
6246 
6247  // We can drive any destination inside the current layerblock.
6248  auto dest = getFieldRefFromValue(connect.getDest()).getValue();
6249  if (auto *destOp = dest.getDefiningOp())
6250  if (getOperation()->isAncestor(destOp))
6251  return WalkResult::advance();
6252 
6253  auto diag =
6254  connect.emitOpError()
6255  << "connects to a destination which is defined outside its "
6256  "enclosing layer block";
6257  diag.attachNote(getLoc()) << "enclosing layer block is defined here";
6258  diag.attachNote(dest.getLoc()) << "destination is defined here";
6259  return WalkResult::interrupt();
6260  }
6261 
6262  return WalkResult::advance();
6263  });
6264 
6265  return failure(result.wasInterrupted());
6266 }
6267 
6268 LogicalResult
6269 LayerBlockOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
6270  auto layerOp =
6271  symbolTable.lookupNearestSymbolFrom<LayerOp>(*this, getLayerNameAttr());
6272  if (!layerOp) {
6273  return emitOpError("invalid symbol reference");
6274  }
6275 
6276  return success();
6277 }
6278 
6279 //===----------------------------------------------------------------------===//
6280 // TblGen Generated Logic.
6281 //===----------------------------------------------------------------------===//
6282 
6283 // Provide the autogenerated implementation guts for the Op classes.
6284 #define GET_OP_CLASSES
6285 #include "circt/Dialect/FIRRTL/FIRRTL.cpp.inc"
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition: CHIRRTL.cpp:30
MlirType elementType
Definition: CHIRRTL.cpp:29
#define isdigit(x)
Definition: FIRLexer.cpp:26
static bool printModulePorts(OpAsmPrinter &p, Block *block, ArrayRef< bool > portDirections, ArrayRef< Attribute > portNames, ArrayRef< Attribute > portTypes, ArrayRef< Attribute > portAnnotations, ArrayRef< Attribute > portSyms, ArrayRef< Attribute > portLocs)
Print a list of module ports in the following form: in x: !firrtl.uint<1> [{class = "DontTouch}],...
Definition: FIRRTLOps.cpp:1102
static LogicalResult verifyProbeType(RefType refType, Location loc, CircuitOp circuitOp, SymbolTableCollection &symbolTable, Twine start)
Definition: FIRRTLOps.cpp:1632
static SmallVector< T > removeElementsAtIndices(ArrayRef< T > input, const llvm::BitVector &indicesToDrop)
Remove elements from the input array corresponding to set bits in indicesToDrop, returning the elemen...
Definition: FIRRTLOps.cpp:53
static void printStopAttrs(OpAsmPrinter &p, Operation *op, DictionaryAttr attr)
Definition: FIRRTLOps.cpp:5782
static LayerSet getLayersFor(Value value)
Get the effective layer requirements for the given value.
Definition: FIRRTLOps.cpp:408
ParseResult parseSubfieldLikeOp(OpAsmParser &parser, OperationState &result)
Definition: FIRRTLOps.cpp:4416
static bool isSameIntTypeKind(Type lhs, Type rhs, int32_t &lhsWidth, int32_t &rhsWidth, bool &isConstResult, std::optional< Location > loc)
If LHS and RHS are both UInt or SInt types, the return true and fill in the width of them if known.
Definition: FIRRTLOps.cpp:4854
static LogicalResult verifySubfieldLike(OpTy op)