CIRCT  19.0.0git
HWOps.cpp
Go to the documentation of this file.
1 //===- HWOps.cpp - Implement the HW operations ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the HW ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/HW/HWOps.h"
23 #include "circt/Support/Naming.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
30 
31 using namespace circt;
32 using namespace hw;
33 using mlir::TypedAttr;
34 
35 /// Flip a port direction.
37  switch (direction) {
44  }
45  llvm_unreachable("unknown PortDirection");
46 }
47 
48 bool hw::isValidIndexBitWidth(Value index, Value array) {
49  hw::ArrayType arrayType =
50  hw::getCanonicalType(array.getType()).dyn_cast<hw::ArrayType>();
51  assert(arrayType && "expected array type");
52  unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53  auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54  return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55  : indexWidth == requiredWidth;
56 }
57 
58 /// Return true if the specified operation is a combinational logic op.
59 bool hw::isCombinational(Operation *op) {
60  struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61  bool visitInvalidTypeOp(Operation *op) { return false; }
62  bool visitUnhandledTypeOp(Operation *op) { return true; }
63  };
64 
65  return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66  IsCombClassifier().dispatchTypeOpVisitor(op);
67 }
68 
69 static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70  // A struct extract of a struct create -> corresponding struct create operand.
71  if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72  return structCreate.getOperand(fieldIndex);
73  }
74 
75  // Extracting injected field -> corresponding field
76  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77  if (structInject.getFieldIndex() != fieldIndex)
78  return {};
79  return structInject.getNewValue();
80  }
81  return {};
82 }
83 
84 static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85  ArrayRef<Attribute> attrs) {
86  if (attrs.empty())
87  return ArrayAttr::get(context, {});
88  bool empty = true;
89  for (auto a : attrs)
90  if (a && !cast<DictionaryAttr>(a).empty()) {
91  empty = false;
92  break;
93  }
94  if (empty)
95  return ArrayAttr::get(context, {});
96  return ArrayAttr::get(context, attrs);
97 }
98 
99 /// Get a special name to use when printing the entry block arguments of the
100 /// region contained by an operation in this dialect.
101 static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102  OpAsmSetValueNameFn setNameFn) {
103  if (region.empty())
104  return;
105  // Assign port names to the bbargs.
106  auto module = cast<HWModuleOp>(region.getParentOp());
107 
108  auto *block = &region.front();
109  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110  auto name = module.getInputName(i);
111  // Let mlir deterministically convert names to valid identifiers
112  setNameFn(block->getArgument(i), name);
113  }
114 }
115 
/// Kinds of delimiters that can surround a list.
enum class Delimiter {
  None,                ///< No delimiters.
  Paren,               ///< `(` ... `)` enclosed list.
  OptionalLessGreater, ///< `<` ... `>` enclosed list, or absent entirely.
};
121 
122 /// Check parameter specified by `value` to see if it is valid according to the
123 /// module's parameters. If not, emit an error to the diagnostic provided as an
124 /// argument to the lambda 'instanceError' and return failure, otherwise return
125 /// success.
126 ///
127 /// If `disallowParamRefs` is true, then parameter references are not allowed.
129  Attribute value, ArrayAttr moduleParameters,
130  const instance_like_impl::EmitErrorFn &instanceError,
131  bool disallowParamRefs) {
132  // Literals are always ok. Their types are already known to match
133  // expectations.
134  if (value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
135  value.isa<StringAttr>() || value.isa<ParamVerbatimAttr>())
136  return success();
137 
138  // Check both subexpressions of an expression.
139  if (auto expr = value.dyn_cast<ParamExprAttr>()) {
140  for (auto op : expr.getOperands())
141  if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142  disallowParamRefs)))
143  return failure();
144  return success();
145  }
146 
147  // Parameter references need more analysis to make sure they are valid within
148  // this module.
149  if (auto parameterRef = value.dyn_cast<ParamDeclRefAttr>()) {
150  auto nameAttr = parameterRef.getName();
151 
152  // Don't allow references to parameters from the default values of a
153  // parameter list.
154  if (disallowParamRefs) {
155  instanceError([&](auto &diag) {
156  diag << "parameter " << nameAttr
157  << " cannot be used as a default value for a parameter";
158  return false;
159  });
160  return failure();
161  }
162 
163  // Find the corresponding attribute in the module.
164  for (auto param : moduleParameters) {
165  auto paramAttr = param.cast<ParamDeclAttr>();
166  if (paramAttr.getName() != nameAttr)
167  continue;
168 
169  // If the types match then the reference is ok.
170  if (paramAttr.getType() == parameterRef.getType())
171  return success();
172 
173  instanceError([&](auto &diag) {
174  diag << "parameter " << nameAttr << " used with type "
175  << parameterRef.getType() << "; should have type "
176  << paramAttr.getType();
177  return true;
178  });
179  return failure();
180  }
181 
182  instanceError([&](auto &diag) {
183  diag << "use of unknown parameter " << nameAttr;
184  return true;
185  });
186  return failure();
187  }
188 
189  instanceError([&](auto &diag) {
190  diag << "invalid parameter value " << value;
191  return false;
192  });
193  return failure();
194 }
195 
196 /// Check parameter specified by `value` to see if it is valid within the scope
197 /// of the specified module `module`. If not, emit an error at the location of
198 /// `usingOp` and return failure, otherwise return success. If `usingOp` is
199 /// null, then no diagnostic is generated.
200 ///
201 /// If `disallowParamRefs` is true, then parameter references are not allowed.
202 LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203  Operation *usingOp,
204  bool disallowParamRefs) {
206  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
207  if (usingOp) {
208  auto diag = usingOp->emitOpError();
209  if (fn(diag))
210  diag.attachNote(module->getLoc()) << "module declared here";
211  }
212  };
213 
214  return checkParameterInContext(value,
215  module->getAttrOfType<ArrayAttr>("parameters"),
216  emitError, disallowParamRefs);
217 }
218 
219 /// Return true if the specified attribute tree is made up of nodes that are
220 /// valid in a parameter expression.
221 bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222  return succeeded(checkParameterInContext(attr, module, nullptr, false));
223 }
224 
226  const ModulePortInfo &info,
227  Region &bodyRegion)
228  : info(info) {
229  inputArgs.resize(info.sizeInputs());
230  for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
231  inputIdx[info.at(i).name.str()] = i;
232  inputArgs[i] = barg;
233  }
234 
235  outputOperands.resize(info.sizeOutputs());
236  for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237  outputIdx[outputInfo.name.str()] = i;
238  }
239 }
240 
241 void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242  assert(outputOperands.size() > i && "invalid output index");
243  assert(outputOperands[i] == Value() && "output already set");
244  outputOperands[i] = v;
245 }
246 
248  assert(inputArgs.size() > i && "invalid input index");
249  return inputArgs[i];
250 }
251 Value HWModulePortAccessor::getInput(StringRef name) {
252  return getInput(inputIdx.find(name.str())->second);
253 }
254 void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255  setOutput(outputIdx.find(name.str())->second, v);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ConstantOp
260 //===----------------------------------------------------------------------===//
261 
262 void ConstantOp::print(OpAsmPrinter &p) {
263  p << " ";
264  p.printAttribute(getValueAttr());
265  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
266 }
267 
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269  IntegerAttr valueAttr;
270 
271  if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
272  parser.parseOptionalAttrDict(result.attributes))
273  return failure();
274 
275  result.addTypes(valueAttr.getType());
276  return success();
277 }
278 
279 LogicalResult ConstantOp::verify() {
280  // If the result type has a bitwidth, then the attribute must match its width.
281  if (getValue().getBitWidth() != getType().cast<IntegerType>().getWidth())
282  return emitError(
283  "hw.constant attribute bitwidth doesn't match return type");
284 
285  return success();
286 }
287 
288 /// Build a ConstantOp from an APInt, infering the result type from the
289 /// width of the APInt.
290 void ConstantOp::build(OpBuilder &builder, OperationState &result,
291  const APInt &value) {
292 
293  auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
294  auto attr = builder.getIntegerAttr(type, value);
295  return build(builder, result, type, attr);
296 }
297 
298 /// Build a ConstantOp from an APInt, infering the result type from the
299 /// width of the APInt.
300 void ConstantOp::build(OpBuilder &builder, OperationState &result,
301  IntegerAttr value) {
302  return build(builder, result, value.getType(), value);
303 }
304 
305 /// This builder allows construction of small signed integers like 0, 1, -1
306 /// matching a specified MLIR IntegerType. This shouldn't be used for general
307 /// constant folding because it only works with values that can be expressed in
308 /// an int64_t. Use APInt's instead.
309 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
310  int64_t value) {
311  auto numBits = type.cast<IntegerType>().getWidth();
312  build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true));
313 }
314 
316  function_ref<void(Value, StringRef)> setNameFn) {
317  auto intTy = getType();
318  auto intCst = getValue();
319 
320  // Sugar i1 constants with 'true' and 'false'.
321  if (intTy.cast<IntegerType>().getWidth() == 1)
322  return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
323 
324  // Otherwise, build a complex name with the value and type.
325  SmallVector<char, 32> specialNameBuffer;
326  llvm::raw_svector_ostream specialName(specialNameBuffer);
327  specialName << 'c' << intCst << '_' << intTy;
328  setNameFn(getResult(), specialName.str());
329 }
330 
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
332  assert(adaptor.getOperands().empty() && "constant has no operands");
333  return getValueAttr();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // WireOp
338 //===----------------------------------------------------------------------===//
339 
340 /// Check whether an operation has any additional attributes set beyond its
341 /// standard list of attributes returned by `getAttributeNames`.
342 template <class Op>
343 static bool hasAdditionalAttributes(Op op,
344  ArrayRef<StringRef> ignoredAttrs = {}) {
345  auto names = op.getAttributeNames();
346  llvm::SmallDenseSet<StringRef> nameSet;
347  nameSet.reserve(names.size() + ignoredAttrs.size());
348  nameSet.insert(names.begin(), names.end());
349  nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350  return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
351  return !nameSet.contains(namedAttr.getName());
352  });
353 }
354 
356  // If the wire has an optional 'name' attribute, use it.
357  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
358  if (nameAttr && !nameAttr.getValue().empty())
359  setNameFn(getResult(), nameAttr.getValue());
360 }
361 
362 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
363 
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
365  // If the wire has no additional attributes, no name, and no symbol, just
366  // forward its input.
367  if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
368  !getInnerSymAttr())
369  return getInput();
370  return {};
371 }
372 
373 LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
374  // Block if the wire has any attributes.
375  if (hasAdditionalAttributes(wire, {"sv.namehint"}))
376  return failure();
377 
378  // If the wire has a symbol, then we can't delete it.
379  if (wire.getInnerSymAttr())
380  return failure();
381 
382  // If the wire has a name or an `sv.namehint` attribute, propagate it as an
383  // `sv.namehint` to the expression.
384  if (auto *inputOp = wire.getInput().getDefiningOp())
385  if (auto name = chooseName(wire, inputOp))
386  rewriter.modifyOpInPlace(inputOp,
387  [&] { inputOp->setAttr("sv.namehint", name); });
388 
389  rewriter.replaceOp(wire, wire.getInput());
390  return success();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // AggregateConstantOp
395 //===----------------------------------------------------------------------===//
396 
397 static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
398  // If this is a type alias, get the underlying type.
399  if (auto typeAlias = type.dyn_cast<TypeAliasType>())
400  type = typeAlias.getCanonicalType();
401 
402  if (auto structType = type.dyn_cast<StructType>()) {
403  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
404  if (!arrayAttr)
405  return op->emitOpError("expected array attribute for constant of type ")
406  << type;
407  if (structType.getElements().size() != arrayAttr.size())
408  return op->emitOpError("array attribute (")
409  << arrayAttr.size() << ") has wrong size for struct constant ("
410  << structType.getElements().size() << ")";
411 
412  for (auto [attr, fieldInfo] :
413  llvm::zip(arrayAttr.getValue(), structType.getElements())) {
414  if (failed(checkAttributes(op, attr, fieldInfo.type)))
415  return failure();
416  }
417  } else if (auto arrayType = type.dyn_cast<ArrayType>()) {
418  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
419  if (!arrayAttr)
420  return op->emitOpError("expected array attribute for constant of type ")
421  << type;
422  if (arrayType.getNumElements() != arrayAttr.size())
423  return op->emitOpError("array attribute (")
424  << arrayAttr.size() << ") has wrong size for array constant ("
425  << arrayType.getNumElements() << ")";
426 
427  auto elementType = arrayType.getElementType();
428  for (auto attr : arrayAttr.getValue()) {
429  if (failed(checkAttributes(op, attr, elementType)))
430  return failure();
431  }
432  } else if (auto arrayType = type.dyn_cast<UnpackedArrayType>()) {
433  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
434  if (!arrayAttr)
435  return op->emitOpError("expected array attribute for constant of type ")
436  << type;
437  auto elementType = arrayType.getElementType();
438  if (arrayType.getNumElements() != arrayAttr.size())
439  return op->emitOpError("array attribute (")
440  << arrayAttr.size()
441  << ") has wrong size for unpacked array constant ("
442  << arrayType.getNumElements() << ")";
443 
444  for (auto attr : arrayAttr.getValue()) {
445  if (failed(checkAttributes(op, attr, elementType)))
446  return failure();
447  }
448  } else if (auto enumType = type.dyn_cast<EnumType>()) {
449  auto stringAttr = attr.dyn_cast<StringAttr>();
450  if (!stringAttr)
451  return op->emitOpError("expected string attribute for constant of type ")
452  << type;
453  } else if (auto intType = type.dyn_cast<IntegerType>()) {
454  // Check the attribute kind is correct.
455  auto intAttr = attr.dyn_cast<IntegerAttr>();
456  if (!intAttr)
457  return op->emitOpError("expected integer attribute for constant of type ")
458  << type;
459  // Check the bitwidth is correct.
460  if (intAttr.getValue().getBitWidth() != intType.getWidth())
461  return op->emitOpError("hw.constant attribute bitwidth "
462  "doesn't match return type");
463  } else {
464  return op->emitOpError("unknown element type") << type;
465  }
466  return success();
467 }
468 
469 LogicalResult AggregateConstantOp::verify() {
470  return checkAttributes(*this, getFieldsAttr(), getType());
471 }
472 
473 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
474 
475 //===----------------------------------------------------------------------===//
476 // ParamValueOp
477 //===----------------------------------------------------------------------===//
478 
479 static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
480  Type &resultType) {
481  if (p.parseType(resultType) || p.parseEqual() ||
482  p.parseAttribute(value, resultType))
483  return failure();
484  return success();
485 }
486 
487 static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
488  Type resultType) {
489  p << resultType << " = ";
490  p.printAttributeWithoutType(value);
491 }
492 
493 LogicalResult ParamValueOp::verify() {
494  // Check that the attribute expression is valid in this module.
496  getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
497 }
498 
499 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
500  assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
501  return getValueAttr();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // HWModuleOp
//===----------------------------------------------------------------------===//
507 
508 /// Return true if isAnyModule or instance.
509 bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
510  return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
511 }
512 
513 /// Return the signature for a module as a function type from the module itself
514 /// or from an hw::InstanceOp.
515 FunctionType hw::getModuleType(Operation *moduleOrInstance) {
516  return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
517  .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
518  SmallVector<Type> inputs(instance->getOperandTypes());
519  SmallVector<Type> results(instance->getResultTypes());
520  return FunctionType::get(instance->getContext(), inputs, results);
521  })
522  .Case<HWModuleLike>(
523  [](auto mod) { return mod.getHWModuleType().getFuncType(); })
524  .Default([](Operation *op) {
525  return cast<mlir::FunctionOpInterface>(op)
526  .getFunctionType()
527  .cast<FunctionType>();
528  });
529 }
530 
531 /// Return the name to use for the Verilog module that we're referencing
532 /// here. This is typically the symbol, but can be overridden with the
533 /// verilogName attribute.
534 StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
535  auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
536  if (nameAttr)
537  return nameAttr;
538 
539  return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
540 }
541 
542 template <typename ModuleTy>
543 static void
544 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
545  const ModulePortInfo &ports, ArrayAttr parameters,
546  ArrayRef<NamedAttribute> attributes, StringAttr comment) {
547  using namespace mlir::function_interface_impl;
548 
549  // Add an attribute for the name.
550  result.addAttribute(SymbolTable::getSymbolAttrName(), name);
551 
552  SmallVector<Attribute> perPortAttrs;
553  SmallVector<ModulePort> portTypes;
554 
555  for (auto elt : ports) {
556  portTypes.push_back(elt);
557  llvm::SmallVector<NamedAttribute> portAttrs;
558  if (elt.attrs)
559  llvm::copy(elt.attrs, std::back_inserter(portAttrs));
560  perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
561  }
562 
563  // Allow clients to pass in null for the parameters list.
564  if (!parameters)
565  parameters = builder.getArrayAttr({});
566 
567  // Record the argument and result types as an attribute.
568  auto type = ModuleType::get(builder.getContext(), portTypes);
569  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
570  TypeAttr::get(type));
571  result.addAttribute("per_port_attrs",
572  arrayOrEmpty(builder.getContext(), perPortAttrs));
573  result.addAttribute("parameters", parameters);
574  if (!comment)
575  comment = builder.getStringAttr("");
576  result.addAttribute("comment", comment);
577  result.addAttributes(attributes);
578  result.addRegion();
579 }
580 
581 /// Internal implementation of argument/result insertion and removal on modules.
582 static void modifyModuleArgs(
583  MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
584  ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
585  ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
586  ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
587  SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
588  SmallVector<Location> &newArgLocs, Block *body = nullptr) {
589 
590 #ifndef NDEBUG
591  // Check that the `insertArgs` and `removeArgs` indices are in ascending
592  // order.
593  assert(llvm::is_sorted(insertArgs,
594  [](auto &a, auto &b) { return a.first < b.first; }) &&
595  "insertArgs must be in ascending order");
596  assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
597  "removeArgs must be in ascending order");
598 #endif
599 
600  auto oldArgCount = oldArgTypes.size();
601  auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
602  assert((int)newArgCount >= 0);
603 
604  newArgNames.reserve(newArgCount);
605  newArgTypes.reserve(newArgCount);
606  newArgAttrs.reserve(newArgCount);
607  newArgLocs.reserve(newArgCount);
608 
609  auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
610  auto emptyDictAttr = DictionaryAttr::get(context, {});
611  auto unknownLoc = UnknownLoc::get(context);
612 
613  BitVector erasedIndices;
614  if (body)
615  erasedIndices.resize(oldArgCount + insertArgs.size());
616 
617  for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
618  // Insert new ports at this position.
619  while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
620  auto port = insertArgs[0].second;
621  if (port.dir == ModulePort::Direction::InOut &&
622  !port.type.isa<InOutType>())
623  port.type = InOutType::get(port.type);
624  auto sym = port.getSym();
625  Attribute attr =
626  (sym && !sym.empty())
627  ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
628  : emptyDictAttr;
629  newArgNames.push_back(port.name);
630  newArgTypes.push_back(port.type);
631  newArgAttrs.push_back(attr);
632  insertArgs = insertArgs.drop_front();
633  LocationAttr loc = port.loc ? port.loc : unknownLoc;
634  newArgLocs.push_back(loc);
635  if (body)
636  body->insertArgument(idx++, port.type, loc);
637  }
638  if (argIdx == oldArgCount)
639  break;
640 
641  // Migrate the old port at this position.
642  bool removed = false;
643  while (!removeArgs.empty() && removeArgs[0] == argIdx) {
644  removeArgs = removeArgs.drop_front();
645  removed = true;
646  }
647 
648  if (removed) {
649  if (body)
650  erasedIndices.set(idx);
651  } else {
652  newArgNames.push_back(oldArgNames[argIdx]);
653  newArgTypes.push_back(oldArgTypes[argIdx]);
654  newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
655  : oldArgAttrs[argIdx]);
656  newArgLocs.push_back(oldArgLocs[argIdx]);
657  }
658  }
659 
660  if (body)
661  body->eraseArguments(erasedIndices);
662 
663  assert(newArgNames.size() == newArgCount);
664  assert(newArgTypes.size() == newArgCount);
665  assert(newArgAttrs.size() == newArgCount);
666  assert(newArgLocs.size() == newArgCount);
667 }
668 
669 /// Insert and remove ports of a module. The insertion and removal indices must
670 /// be in ascending order. The indices refer to the port positions before any
671 /// insertion or removal occurs. Ports inserted at the same index will appear in
672 /// the module in the same order as they were listed in the `insert*` array.
673 ///
674 /// The operation must be any of the module-like operations.
675 ///
676 /// This is marked deprecated as it's only used from HandshakeToHW and
677 /// PortConverter and is likely broken and not currently tested. Users of this
678 /// are still written dealing with input and output ports separately, which is
679 /// an old and broken style.
680 [[deprecated]] static void
681 modifyModulePorts(Operation *op,
682  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
683  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
684  ArrayRef<unsigned> removeInputs,
685  ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
686  auto moduleOp = cast<HWModuleLike>(op);
687  auto *context = moduleOp.getContext();
688 
689  // Dig up the old argument and result data.
690  auto oldArgNames = moduleOp.getInputNames();
691  auto oldArgTypes = moduleOp.getInputTypes();
692  auto oldArgAttrs = moduleOp.getAllInputAttrs();
693  auto oldArgLocs = moduleOp.getInputLocs();
694 
695  auto oldResultNames = moduleOp.getOutputNames();
696  auto oldResultTypes = moduleOp.getOutputTypes();
697  auto oldResultAttrs = moduleOp.getAllOutputAttrs();
698  auto oldResultLocs = moduleOp.getOutputLocs();
699 
700  // Modify the ports.
701  SmallVector<Attribute> newArgNames, newResultNames;
702  SmallVector<Type> newArgTypes, newResultTypes;
703  SmallVector<Attribute> newArgAttrs, newResultAttrs;
704  SmallVector<Location> newArgLocs, newResultLocs;
705 
706  modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
707  oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
708  newArgTypes, newArgAttrs, newArgLocs, body);
709 
710  modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
711  oldResultTypes, oldResultAttrs, oldResultLocs,
712  newResultNames, newResultTypes, newResultAttrs,
713  newResultLocs);
714 
715  // Update the module operation types and attributes.
716  auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
717  auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
718  moduleOp.setHWModuleType(modty);
719  moduleOp.setAllInputAttrs(newArgAttrs);
720  moduleOp.setAllOutputAttrs(newResultAttrs);
721 
722  newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
723  moduleOp.setAllPortLocs(newArgLocs);
724 }
725 
726 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
727  StringAttr name, const ModulePortInfo &ports,
728  ArrayAttr parameters,
729  ArrayRef<NamedAttribute> attributes, StringAttr comment,
730  bool shouldEnsureTerminator) {
731  buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
732  comment);
733 
734  // Create a region and a block for the body.
735  auto *bodyRegion = result.regions[0].get();
736  Block *body = new Block();
737  bodyRegion->push_back(body);
738 
739  // Add arguments to the body block.
740  auto unknownLoc = builder.getUnknownLoc();
741  for (auto port : ports.getInputs()) {
742  auto loc = port.loc ? Location(port.loc) : unknownLoc;
743  auto type = port.type;
744  if (port.isInOut() && !type.isa<InOutType>())
745  type = InOutType::get(type);
746  body->addArgument(type, loc);
747  }
748 
749  // Add result ports attribute.
750  auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
751  SmallVector<Attribute> resultLocs;
752  for (auto port : ports.getOutputs())
753  resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
754  result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
755 
756  if (shouldEnsureTerminator)
757  HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
758 }
759 
760 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
761  StringAttr name, ArrayRef<PortInfo> ports,
762  ArrayAttr parameters,
763  ArrayRef<NamedAttribute> attributes,
764  StringAttr comment) {
765  build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
766  comment);
767 }
768 
769 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
770  StringAttr name, const ModulePortInfo &ports,
771  HWModuleBuilder modBuilder, ArrayAttr parameters,
772  ArrayRef<NamedAttribute> attributes,
773  StringAttr comment) {
774  build(builder, odsState, name, ports, parameters, attributes, comment,
775  /*shouldEnsureTerminator=*/false);
776  auto *bodyRegion = odsState.regions[0].get();
777  OpBuilder::InsertionGuard guard(builder);
778  auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
779  builder.setInsertionPointToEnd(&bodyRegion->front());
780  modBuilder(builder, accessor);
781  // Create output operands.
782  llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
783  builder.create<hw::OutputOp>(odsState.location, outputOperands);
784 }
785 
786 void HWModuleOp::modifyPorts(
787  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
788  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
789  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
790  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
791  eraseOutputs);
792 }
793 
794 /// Return the name to use for the Verilog module that we're referencing
795 /// here. This is typically the symbol, but can be overridden with the
796 /// verilogName attribute.
798  if (auto vName = getVerilogNameAttr())
799  return vName;
800 
801  return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
802 }
803 
805  if (auto vName = getVerilogNameAttr()) {
806  return vName;
807  }
808  return (*this)->getAttrOfType<StringAttr>(
809  ::mlir::SymbolTable::getSymbolAttrName());
810 }
811 
812 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
813  StringAttr name, const ModulePortInfo &ports,
814  StringRef verilogName, ArrayAttr parameters,
815  ArrayRef<NamedAttribute> attributes) {
816  buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
817  attributes, {});
818 
819  // Add the port locations.
820  LocationAttr unknownLoc = builder.getUnknownLoc();
821  SmallVector<Attribute> portLocs;
822  for (auto elt : ports)
823  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
824  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
825 
826  if (!verilogName.empty())
827  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
828 }
829 
830 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
831  StringAttr name, ArrayRef<PortInfo> ports,
832  StringRef verilogName, ArrayAttr parameters,
833  ArrayRef<NamedAttribute> attributes) {
834  build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
835  attributes);
836 }
837 
838 void HWModuleExternOp::modifyPorts(
839  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
840  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
841  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
842  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
843  eraseOutputs);
844 }
845 
846 void HWModuleExternOp::appendOutputs(
847  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
848 
849 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
850  FlatSymbolRefAttr genKind, StringAttr name,
851  const ModulePortInfo &ports,
852  StringRef verilogName, ArrayAttr parameters,
853  ArrayRef<NamedAttribute> attributes) {
854  buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
855  attributes, {});
856  // Add the port locations.
857  LocationAttr unknownLoc = builder.getUnknownLoc();
858  SmallVector<Attribute> portLocs;
859  for (auto elt : ports)
860  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
861  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
862 
863  result.addAttribute("generatorKind", genKind);
864  if (!verilogName.empty())
865  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
866 }
867 
868 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
869  FlatSymbolRefAttr genKind, StringAttr name,
870  ArrayRef<PortInfo> ports, StringRef verilogName,
871  ArrayAttr parameters,
872  ArrayRef<NamedAttribute> attributes) {
873  build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
874  parameters, attributes);
875 }
876 
877 void HWModuleGeneratedOp::modifyPorts(
878  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
879  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
880  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
881  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
882  eraseOutputs);
883 }
884 
885 void HWModuleGeneratedOp::appendOutputs(
886  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
887 
888 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
889  for (auto &argAttr : attrs)
890  if (argAttr.getName() == name)
891  return true;
892  return false;
893 }
894 
895 template <typename ModuleTy>
896 static ParseResult parseHWModuleOp(OpAsmParser &parser,
897  OperationState &result) {
898 
899  using namespace mlir::function_interface_impl;
900  auto builder = parser.getBuilder();
901  auto loc = parser.getCurrentLocation();
902 
903  // Parse the visibility attribute.
904  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
905 
906  // Parse the name as a symbol.
907  StringAttr nameAttr;
908  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
909  result.attributes))
910  return failure();
911 
912  // Parse the generator information.
913  FlatSymbolRefAttr kindAttr;
914  if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
915  if (parser.parseComma() ||
916  parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
917  return failure();
918  }
919  }
920 
921  // Parse the parameters.
922  ArrayAttr parameters;
923  if (parseOptionalParameterList(parser, parameters))
924  return failure();
925 
926  SmallVector<module_like_impl::PortParse> ports;
927  TypeAttr modType;
928  if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
929  return failure();
930 
931  // Parse the attribute dict.
932  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
933  return failure();
934 
935  if (hasAttribute("parameters", result.attributes)) {
936  parser.emitError(loc, "explicit `parameters` attributes not allowed");
937  return failure();
938  }
939 
940  result.addAttribute("parameters", parameters);
941  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
942 
943  // Convert the specified array of dictionary attrs (which may have null
944  // entries) to an ArrayAttr of dictionaries.
945  SmallVector<Attribute> attrs;
946  for (auto &port : ports)
947  attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
948  // Add the attributes to the ports.
949  auto nonEmptyAttrsFn = [](Attribute attr) {
950  return attr && !cast<DictionaryAttr>(attr).empty();
951  };
952  if (llvm::any_of(attrs, nonEmptyAttrsFn))
953  result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
954  builder.getArrayAttr(attrs));
955 
956  // Add the port locations.
957  auto unknownLoc = builder.getUnknownLoc();
958  auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
959  return attr && cast<Location>(attr) != unknownLoc;
960  };
961  SmallVector<Attribute> locs;
962  StringAttr portLocsAttrName;
963  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
964  // Plain modules only store the output port locations, as the input port
965  // locations will be stored in the basic block arguments.
966  portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
967  for (auto &port : ports)
968  if (port.direction == ModulePort::Direction::Output)
969  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
970  } else {
971  // All other modules store all port locations in a single array.
972  portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
973  for (auto &port : ports)
974  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
975  }
976  if (llvm::any_of(locs, nonEmptyLocsFn))
977  result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
978 
979  // Add the entry block arguments.
980  SmallVector<OpAsmParser::Argument, 4> entryArgs;
981  for (auto &port : ports)
982  if (port.direction != ModulePort::Direction::Output)
983  entryArgs.push_back(port);
984 
985  // Parse the optional function body.
986  auto *body = result.addRegion();
987  if (std::is_same_v<ModuleTy, HWModuleOp>) {
988  if (parser.parseRegion(*body, entryArgs))
989  return failure();
990 
991  HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
992  }
993  return success();
994 }
995 
996 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
997  return parseHWModuleOp<HWModuleOp>(parser, result);
998 }
999 
1000 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1001  OperationState &result) {
1002  return parseHWModuleOp<HWModuleExternOp>(parser, result);
1003 }
1004 
1005 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1006  OperationState &result) {
1007  return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1008 }
1009 
1010 FunctionType getHWModuleOpType(Operation *op) {
1011  if (auto mod = dyn_cast<HWModuleLike>(op))
1012  return mod.getHWModuleType().getFuncType();
1013  return cast<mlir::FunctionOpInterface>(op)
1014  .getFunctionType()
1015  .cast<FunctionType>();
1016 }
1017 
1018 template <typename ModuleTy>
1019 static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1020  p << ' ';
1021  // Print the visibility of the module.
1022  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1023  if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1024  visibilityAttrName))
1025  p << visibility.getValue() << ' ';
1026 
1027  // Print the operation and the function name.
1028  p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1029  if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1030  p << ", ";
1031  p.printSymbolName(gen.getGeneratorKind());
1032  }
1033 
1034  // Print the parameter list if present.
1035  printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1036 
1038 
1039  SmallVector<StringRef, 3> omittedAttrs;
1040  if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1041  omittedAttrs.push_back("generatorKind");
1042  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1043  omittedAttrs.push_back(mod.getResultLocsAttrName());
1044  else
1045  omittedAttrs.push_back(mod.getPortLocsAttrName());
1046  omittedAttrs.push_back(mod.getModuleTypeAttrName());
1047  omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1048  omittedAttrs.push_back(mod.getParametersAttrName());
1049  omittedAttrs.push_back(visibilityAttrName);
1050  if (auto cmt =
1051  mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1052  if (cmt.getValue().empty())
1053  omittedAttrs.push_back("comment");
1054 
1055  mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1056  omittedAttrs);
1057 }
1058 
1059 void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1060 void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1061 
1062 void HWModuleOp::print(OpAsmPrinter &p) {
1063  printModuleOp(p, *this);
1064 
1065  // Print the body if this is not an external function.
1066  Region &body = getBody();
1067  if (!body.empty()) {
1068  p << " ";
1069  p.printRegion(body, /*printEntryBlockArgs=*/false,
1070  /*printBlockTerminators=*/true);
1071  }
1072 }
1073 
1074 static LogicalResult verifyModuleCommon(HWModuleLike module) {
1075  assert(isa<HWModuleLike>(module) &&
1076  "verifier hook should only be called on modules");
1077 
1078  auto moduleType = module.getHWModuleType();
1079 
1080  auto argLocs = module.getInputLocs();
1081  if (argLocs.size() != moduleType.getNumInputs())
1082  return module->emitOpError("incorrect number of argument locations");
1083 
1084  auto resultLocs = module.getOutputLocs();
1085  if (resultLocs.size() != moduleType.getNumOutputs())
1086  return module->emitOpError("incorrect number of result locations");
1087 
1088  SmallPtrSet<Attribute, 4> paramNames;
1089 
1090  // Check parameter default values are sensible.
1091  for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1092  auto paramAttr = param.cast<ParamDeclAttr>();
1093 
1094  // Check that we don't have any redundant parameter names. These are
1095  // resolved by string name: reuse of the same name would cause ambiguities.
1096  if (!paramNames.insert(paramAttr.getName()).second)
1097  return module->emitOpError("parameter ")
1098  << paramAttr << " has the same name as a previous parameter";
1099 
1100  // Default values are allowed to be missing, check them if present.
1101  auto value = paramAttr.getValue();
1102  if (!value)
1103  continue;
1104 
1105  auto typedValue = value.dyn_cast<TypedAttr>();
1106  if (!typedValue)
1107  return module->emitOpError("parameter ")
1108  << paramAttr << " should have a typed value; has value " << value;
1109 
1110  if (typedValue.getType() != paramAttr.getType())
1111  return module->emitOpError("parameter ")
1112  << paramAttr << " should have type " << paramAttr.getType()
1113  << "; has type " << typedValue.getType();
1114 
1115  // Verify that this is a valid parameter value, disallowing parameter
1116  // references. We could allow parameters to refer to each other in the
1117  // future with lexical ordering if there is a need.
1118  if (failed(checkParameterInContext(value, module, module,
1119  /*disallowParamRefs=*/true)))
1120  return failure();
1121  }
1122  return success();
1123 }
1124 
1125 LogicalResult HWModuleOp::verify() {
1126  if (failed(verifyModuleCommon(*this)))
1127  return failure();
1128 
1129  auto type = getModuleType();
1130  auto *body = getBodyBlock();
1131 
1132  // Verify the number of block arguments.
1133  auto numInputs = type.getNumInputs();
1134  if (body->getNumArguments() != numInputs)
1135  return emitOpError("entry block must have")
1136  << numInputs << " arguments to match module signature";
1137 
1138  return success();
1139 }
1140 
1141 LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1142 
1143 std::pair<StringAttr, BlockArgument>
1144 HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1145  // Find a unique name for the wire.
1146  Namespace ns;
1147  auto ports = getPortList();
1148  for (auto port : ports)
1149  ns.newName(port.name.getValue());
1150  auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1151 
1152  Block *body = getBodyBlock();
1153 
1154  // Create a new port for the host clock.
1155  PortInfo port;
1156  port.name = nameAttr;
1157  port.dir = ModulePort::Direction::Input;
1158  port.type = ty;
1159  modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1160  body);
1161 
1162  // Add a new argument.
1163  return {nameAttr, body->getArgument(index)};
1164 }
1165 
1166 void HWModuleOp::insertOutputs(unsigned index,
1167  ArrayRef<std::pair<StringAttr, Value>> outputs) {
1168 
1169  auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1170  assert(index <= output->getNumOperands() && "invalid output index");
1171 
1172  // Rewrite the port list of the module.
1173  SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1174  for (auto &[name, value] : outputs) {
1175  PortInfo port;
1176  port.name = name;
1177  port.dir = ModulePort::Direction::Output;
1178  port.type = value.getType();
1179  indexedNewPorts.emplace_back(index, port);
1180  }
1181  modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1182  getBodyBlock());
1183 
1184  // Rewrite the output op.
1185  for (auto &[name, value] : outputs)
1186  output->insertOperands(index++, value);
1187 }
1188 
1189 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1190  return insertOutputs(getNumOutputPorts(), outputs);
1191 }
1192 
1193 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1194  mlir::OpAsmSetValueNameFn setNameFn) {
1195  getAsmBlockArgumentNamesImpl(region, setNameFn);
1196 }
1197 
1198 void HWModuleExternOp::getAsmBlockArgumentNames(
1199  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1200  getAsmBlockArgumentNamesImpl(region, setNameFn);
1201 }
1202 
1203 template <typename ModTy>
1204 static SmallVector<Location> getAllPortLocs(ModTy module) {
1205  auto locs = module.getPortLocs();
1206  if (locs) {
1207  SmallVector<Location> retval;
1208  for (auto l : *locs)
1209  retval.push_back(cast<Location>(l));
1210  // Either we have a length of 0 or the correct length
1211  assert(!locs->size() || locs->size() == module.getNumPorts());
1212  return retval;
1213  }
1214  return SmallVector<Location>(module.getNumPorts(),
1215  UnknownLoc::get(module.getContext()));
1216 }
1217 
1218 SmallVector<Location> HWModuleOp::getAllPortLocs() {
1219  SmallVector<Location> portLocs;
1220  auto resultLocs = getResultLocsAttr();
1221  unsigned inputCount = 0;
1222  auto modType = getModuleType();
1223  auto unknownLoc = UnknownLoc::get(getContext());
1224  auto *body = getBodyBlock();
1225  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1226  if (modType.isOutput(i)) {
1227  auto loc = resultLocs
1228  ? cast<Location>(
1229  resultLocs.getValue()[portLocs.size() - inputCount])
1230  : unknownLoc;
1231  portLocs.push_back(loc);
1232  } else {
1233  auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1234  portLocs.push_back(loc);
1235  ++inputCount;
1236  }
1237  }
1238  return portLocs;
1239 }
1240 
1241 SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1242  return ::getAllPortLocs(*this);
1243 }
1244 
1245 SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1246  return ::getAllPortLocs(*this);
1247 }
1248 
1249 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1250  SmallVector<Attribute> resultLocs;
1251  unsigned inputCount = 0;
1252  auto modType = getModuleType();
1253  auto *body = getBodyBlock();
1254  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1255  if (modType.isOutput(i))
1256  resultLocs.push_back(locs[i]);
1257  else
1258  body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1259  }
1260  setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1261 }
1262 
1263 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1264  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1265 }
1266 
1267 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1268  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1269 }
1270 
1271 template <typename ModTy>
1272 static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1273  auto numInputs = module.getNumInputPorts();
1274  SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1275  SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1276  auto oldType = module.getModuleType();
1277  SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1278  oldType.getPorts().end());
1279  for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1280  newPorts[i].name = cast<StringAttr>(names[i]);
1281  auto newType = ModuleType::get(module.getContext(), newPorts);
1282  module.setModuleType(newType);
1283 }
1284 
1285 void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1286  ::setAllPortNames(names, *this);
1287 }
1288 
1289 void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1290  ::setAllPortNames(names, *this);
1291 }
1292 
1293 void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1294  ::setAllPortNames(names, *this);
1295 }
1296 
1297 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1298  auto attrs = getPerPortAttrs();
1299  if (attrs && !attrs->empty())
1300  return attrs->getValue();
1301  return {};
1302 }
1303 
1304 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1305  auto attrs = getPerPortAttrs();
1306  if (attrs && !attrs->empty())
1307  return attrs->getValue();
1308  return {};
1309 }
1310 
1311 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1312  auto attrs = getPerPortAttrs();
1313  if (attrs && !attrs->empty())
1314  return attrs->getValue();
1315  return {};
1316 }
1317 
1318 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1319  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1320 }
1321 
1322 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1323  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1324 }
1325 
1326 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1327  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1328 }
1329 
1330 void HWModuleOp::removeAllPortAttrs() {
1331  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1332 }
1333 
1334 void HWModuleExternOp::removeAllPortAttrs() {
1335  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1336 }
1337 
1338 void HWModuleGeneratedOp::removeAllPortAttrs() {
1339  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1340 }
1341 
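// This probably does really unexpected stuff when you change the number of
// ports: existing per-port attributes are kept by position (inputs first, then
// outputs) and padded with empty dictionaries or truncated to the new counts,
// so they can end up attached to different ports than before.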
1343 
1344 template <typename ModTy>
1345 static void setHWModuleType(ModTy &mod, ModuleType type) {
1346  auto argAttrs = mod.getAllInputAttrs();
1347  auto resAttrs = mod.getAllOutputAttrs();
1348  mod.setModuleTypeAttr(TypeAttr::get(type));
1349  unsigned newNumArgs = type.getNumInputs();
1350  unsigned newNumResults = type.getNumOutputs();
1351 
1352  auto emptyDict = DictionaryAttr::get(mod.getContext());
1353  argAttrs.resize(newNumArgs, emptyDict);
1354  resAttrs.resize(newNumResults, emptyDict);
1355 
1356  SmallVector<Attribute> attrs;
1357  attrs.append(argAttrs.begin(), argAttrs.end());
1358  attrs.append(resAttrs.begin(), resAttrs.end());
1359 
1360  if (attrs.empty())
1361  return mod.removeAllPortAttrs();
1362  mod.setAllPortAttrs(attrs);
1363 }
1364 
1365 void HWModuleOp::setHWModuleType(ModuleType type) {
1366  return ::setHWModuleType(*this, type);
1367 }
1368 
1369 void HWModuleExternOp::setHWModuleType(ModuleType type) {
1370  return ::setHWModuleType(*this, type);
1371 }
1372 
1373 void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1374  return ::setHWModuleType(*this, type);
1375 }
1376 
1377 /// Lookup the generator for the symbol. This returns null on
1378 /// invalid IR.
1379 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1380  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1381  return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1382 }
1383 
1384 LogicalResult
1385 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1386  auto *referencedKind =
1387  symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1388 
1389  if (referencedKind == nullptr)
1390  return emitError("Cannot find generator definition '")
1391  << getGeneratorKind() << "'";
1392 
1393  if (!isa<HWGeneratorSchemaOp>(referencedKind))
1394  return emitError("Symbol resolved to '")
1395  << referencedKind->getName()
1396  << "' which is not a HWGeneratorSchemaOp";
1397 
1398  auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1399  auto paramRef = referencedKindOp.getRequiredAttrs();
1400  auto dict = (*this)->getAttrDictionary();
1401  for (auto str : paramRef) {
1402  auto strAttr = str.dyn_cast<StringAttr>();
1403  if (!strAttr)
1404  return emitError("Unknown attribute type, expected a string");
1405  if (!dict.get(strAttr.getValue()))
1406  return emitError("Missing attribute '") << strAttr.getValue() << "'";
1407  }
1408 
1409  return success();
1410 }
1411 
1412 LogicalResult HWModuleGeneratedOp::verify() {
1413  return verifyModuleCommon(*this);
1414 }
1415 
1416 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1417  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1418  getAsmBlockArgumentNamesImpl(region, setNameFn);
1419 }
1420 
1421 LogicalResult HWModuleOp::verifyBody() { return success(); }
1422 
1423 template <typename ModuleTy>
1424 static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1425  auto modTy = mod.getHWModuleType();
1426  auto emptyDict = DictionaryAttr::get(mod.getContext());
1427  SmallVector<PortInfo> retval;
1428  auto locs = mod.getAllPortLocs();
1429  for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1430  LocationAttr loc = locs[i];
1431  DictionaryAttr attrs =
1432  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1433  if (!attrs)
1434  attrs = emptyDict;
1435  retval.push_back({modTy.getPorts()[i],
1436  modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1437  : modTy.getInputIdForPortId(i),
1438  attrs, loc});
1439  }
1440  return retval;
1441 }
1442 
1443 template <typename ModuleTy>
1444 static PortInfo getPort(ModuleTy &mod, size_t idx) {
1445  auto modTy = mod.getHWModuleType();
1446  auto emptyDict = DictionaryAttr::get(mod.getContext());
1447  LocationAttr loc = mod.getPortLoc(idx);
1448  DictionaryAttr attrs =
1449  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1450  if (!attrs)
1451  attrs = emptyDict;
1452  return {modTy.getPorts()[idx],
1453  modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1454  : modTy.getInputIdForPortId(idx),
1455  attrs, loc};
1456 }
1457 
1458 //===----------------------------------------------------------------------===//
1459 // InstanceOp
1460 //===----------------------------------------------------------------------===//
1461 
1462 /// Create a instance that refers to a known module.
1463 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1464  Operation *module, StringAttr name,
1465  ArrayRef<Value> inputs, ArrayAttr parameters,
1466  InnerSymAttr innerSym) {
1467  if (!parameters)
1468  parameters = builder.getArrayAttr({});
1469 
1470  auto mod = cast<hw::HWModuleLike>(module);
1471  auto argNames = builder.getArrayAttr(mod.getInputNames());
1472  auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1473 
1474  // Try to resolve the parameterized module type. If failed, use the module's
1475  // parmeterized type. If the client doesn't fix this error, the verifier will
1476  // fail.
1477  ModuleType modType = mod.getHWModuleType();
1478  FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1479  parameters, result.location, /*emitErrors=*/false);
1480  if (succeeded(resolvedModType))
1481  modType = *resolvedModType;
1482  FunctionType funcType = resolvedModType->getFuncType();
1483  build(builder, result, funcType.getResults(), name,
1484  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1485  argNames, resultNames, parameters, innerSym);
1486 }
1487 
1488 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1489  // Inner symbols on instance operations target the op not any result.
1490  return std::nullopt;
1491 }
1492 
1493 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1495  *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1496  getResultNames(), getParameters(), symbolTable);
1497 }
1498 
1499 LogicalResult InstanceOp::verify() {
1500  auto module = (*this)->getParentOfType<HWModuleOp>();
1501  if (!module)
1502  return success();
1503 
1504  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1506  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1507  auto diag = emitOpError();
1508  if (fn(diag))
1509  diag.attachNote(module->getLoc()) << "module declared here";
1510  };
1512  getParameters(), moduleParameters, emitError);
1513 }
1514 
1515 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1516  StringAttr instanceNameAttr;
1517  InnerSymAttr innerSym;
1518  FlatSymbolRefAttr moduleNameAttr;
1519  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1520  SmallVector<Type, 1> inputsTypes, allResultTypes;
1521  ArrayAttr argNames, resultNames, parameters;
1522  auto noneType = parser.getBuilder().getType<NoneType>();
1523 
1524  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1525  result.attributes))
1526  return failure();
1527 
1528  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1529  // Parsing an optional symbol name doesn't fail, so no need to check the
1530  // result.
1531  if (parser.parseCustomAttributeWithFallback(innerSym))
1532  return failure();
1533  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1534  }
1535 
1536  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1537  if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1538  result.attributes) ||
1539  parser.getCurrentLocation(&parametersLoc) ||
1540  parseOptionalParameterList(parser, parameters) ||
1541  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1542  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1543  result.operands) ||
1544  parser.parseArrow() ||
1545  parseOutputPortList(parser, allResultTypes, resultNames) ||
1546  parser.parseOptionalAttrDict(result.attributes)) {
1547  return failure();
1548  }
1549 
1550  result.addAttribute("argNames", argNames);
1551  result.addAttribute("resultNames", resultNames);
1552  result.addAttribute("parameters", parameters);
1553  result.addTypes(allResultTypes);
1554  return success();
1555 }
1556 
1557 void InstanceOp::print(OpAsmPrinter &p) {
1558  p << ' ';
1559  p.printAttributeWithoutType(getInstanceNameAttr());
1560  if (auto attr = getInnerSymAttr()) {
1561  p << " sym ";
1562  attr.print(p);
1563  }
1564  p << ' ';
1565  p.printAttributeWithoutType(getModuleNameAttr());
1566  printOptionalParameterList(p, *this, getParameters());
1567  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1568  getArgNames());
1569  p << " -> ";
1570  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1571 
1572  p.printOptionalAttrDict(
1573  (*this)->getAttrs(),
1574  /*elidedAttrs=*/{"instanceName",
1575  InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1576  "argNames", "resultNames", "parameters"});
1577 }
1578 
1579 //===----------------------------------------------------------------------===//
1580 // InstanceChoiceOp
1581 //===----------------------------------------------------------------------===//
1582 
1583 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1584  // Inner symbols on instance operations target the op not any result.
1585  return std::nullopt;
1586 }
1587 
1588 LogicalResult
1589 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1590  for (Attribute name : getModuleNamesAttr()) {
1592  *this, name.cast<FlatSymbolRefAttr>(), getInputs(),
1593  getResultTypes(), getArgNames(), getResultNames(), getParameters(),
1594  symbolTable))) {
1595  return failure();
1596  }
1597  }
1598  return success();
1599 }
1600 
1601 LogicalResult InstanceChoiceOp::verify() {
1602  auto module = (*this)->getParentOfType<HWModuleOp>();
1603  if (!module)
1604  return success();
1605 
1606  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1608  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1609  auto diag = emitOpError();
1610  if (fn(diag))
1611  diag.attachNote(module->getLoc()) << "module declared here";
1612  };
1614  getParameters(), moduleParameters, emitError);
1615 }
1616 
1617 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1618  OperationState &result) {
1619  StringAttr optionNameAttr;
1620  StringAttr instanceNameAttr;
1621  InnerSymAttr innerSym;
1622  SmallVector<Attribute> moduleNames;
1623  SmallVector<Attribute> caseNames;
1624  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1625  SmallVector<Type, 1> inputsTypes, allResultTypes;
1626  ArrayAttr argNames, resultNames, parameters;
1627  auto noneType = parser.getBuilder().getType<NoneType>();
1628 
1629  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1630  result.attributes))
1631  return failure();
1632 
1633  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1634  // Parsing an optional symbol name doesn't fail, so no need to check the
1635  // result.
1636  if (parser.parseCustomAttributeWithFallback(innerSym))
1637  return failure();
1638  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1639  }
1640 
1641  if (parser.parseKeyword("option") ||
1642  parser.parseAttribute(optionNameAttr, noneType, "optionName",
1643  result.attributes))
1644  return failure();
1645 
1646  FlatSymbolRefAttr defaultModuleName;
1647  if (parser.parseAttribute(defaultModuleName))
1648  return failure();
1649  moduleNames.push_back(defaultModuleName);
1650 
1651  while (succeeded(parser.parseOptionalKeyword("or"))) {
1652  FlatSymbolRefAttr moduleName;
1653  StringAttr targetName;
1654  if (parser.parseAttribute(moduleName) ||
1655  parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1656  return failure();
1657  moduleNames.push_back(moduleName);
1658  caseNames.push_back(targetName);
1659  }
1660 
1661  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1662  if (parser.getCurrentLocation(&parametersLoc) ||
1663  parseOptionalParameterList(parser, parameters) ||
1664  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1665  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1666  result.operands) ||
1667  parser.parseArrow() ||
1668  parseOutputPortList(parser, allResultTypes, resultNames) ||
1669  parser.parseOptionalAttrDict(result.attributes)) {
1670  return failure();
1671  }
1672 
1673  result.addAttribute("moduleNames",
1674  ArrayAttr::get(parser.getContext(), moduleNames));
1675  result.addAttribute("caseNames",
1676  ArrayAttr::get(parser.getContext(), caseNames));
1677  result.addAttribute("argNames", argNames);
1678  result.addAttribute("resultNames", resultNames);
1679  result.addAttribute("parameters", parameters);
1680  result.addTypes(allResultTypes);
1681  return success();
1682 }
1683 
1684 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1685  p << ' ';
1686  p.printAttributeWithoutType(getInstanceNameAttr());
1687  if (auto attr = getInnerSymAttr()) {
1688  p << " sym ";
1689  attr.print(p);
1690  }
1691  p << " option " << getOptionNameAttr() << ' ';
1692 
1693  auto moduleNames = getModuleNamesAttr();
1694  auto caseNames = getCaseNamesAttr();
1695  assert(moduleNames.size() == caseNames.size() + 1);
1696 
1697  p.printAttributeWithoutType(moduleNames[0]);
1698  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1699  p << " or ";
1700  p.printAttributeWithoutType(moduleNames[i + 1]);
1701  p << " if ";
1702  p.printAttributeWithoutType(caseNames[i]);
1703  }
1704 
1705  printOptionalParameterList(p, *this, getParameters());
1706  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1707  getArgNames());
1708  p << " -> ";
1709  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1710 
1711  p.printOptionalAttrDict(
1712  (*this)->getAttrs(),
1713  /*elidedAttrs=*/{"instanceName",
1714  InnerSymbolTable::getInnerSymbolAttrName(),
1715  "moduleNames", "caseNames", "argNames", "resultNames",
1716  "parameters", "optionName"});
1717 }
1718 
1719 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1720  SmallVector<Attribute> moduleNames;
1721  for (Attribute attr : getModuleNamesAttr()) {
1722  moduleNames.push_back(attr.cast<FlatSymbolRefAttr>().getAttr());
1723  }
1724  return ArrayAttr::get(getContext(), moduleNames);
1725 }
1726 
1727 //===----------------------------------------------------------------------===//
1728 // HWOutputOp
1729 //===----------------------------------------------------------------------===//
1730 
1731 /// Verify that the num of operands and types fit the declared results.
1732 LogicalResult OutputOp::verify() {
1733  // Check that the we (hw.output) have the same number of operands as our
1734  // region has results.
1735  ModuleType modType;
1736  if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1737  modType = mod.getHWModuleType();
1738  else {
1739  emitOpError("must have a module parent");
1740  return failure();
1741  }
1742  auto modResults = modType.getOutputTypes();
1743  OperandRange outputValues = getOperands();
1744  if (modResults.size() != outputValues.size()) {
1745  emitOpError("must have same number of operands as region results.");
1746  return failure();
1747  }
1748 
1749  // Check that the types of our operands and the region's results match.
1750  for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1751  if (modResults[i] != outputValues[i].getType()) {
1752  emitOpError("output types must match module. In "
1753  "operand ")
1754  << i << ", expected " << modResults[i] << ", but got "
1755  << outputValues[i].getType() << ".";
1756  return failure();
1757  }
1758  }
1759 
1760  return success();
1761 }
1762 
1763 //===----------------------------------------------------------------------===//
1764 // Other Operations
1765 //===----------------------------------------------------------------------===//
1766 
1767 static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1768  Type &idxType) {
1769  Type type;
1770  if (p.parseType(type))
1771  return p.emitError(p.getCurrentLocation(), "Expected type");
1772  auto arrType = type_dyn_cast<ArrayType>(type);
1773  if (!arrType)
1774  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1775  srcType = type;
1776  unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1777  idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1778  return success();
1779 }
1780 
1781 static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1782  Type idxType) {
1783  p.printType(srcType);
1784 }
1785 
1786 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1787  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1788  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1789  Type elemType;
1790 
1791  if (parser.parseOperandList(operands) ||
1792  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1793  parser.parseType(elemType))
1794  return failure();
1795 
1796  if (operands.size() == 0)
1797  return parser.emitError(inputOperandsLoc,
1798  "Cannot construct an array of length 0");
1799  result.addTypes({ArrayType::get(elemType, operands.size())});
1800 
1801  for (auto operand : operands)
1802  if (parser.resolveOperand(operand, elemType, result.operands))
1803  return failure();
1804  return success();
1805 }
1806 
1807 void ArrayCreateOp::print(OpAsmPrinter &p) {
1808  p << " ";
1809  p.printOperands(getInputs());
1810  p.printOptionalAttrDict((*this)->getAttrs());
1811  p << " : " << getInputs()[0].getType();
1812 }
1813 
1814 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1815  ValueRange values) {
1816  assert(values.size() > 0 && "Cannot build array of zero elements");
1817  Type elemType = values[0].getType();
1818  assert(llvm::all_of(
1819  values,
1820  [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1821  "All values must have same type.");
1822  build(b, state, ArrayType::get(elemType, values.size()), values);
1823 }
1824 
1825 LogicalResult ArrayCreateOp::verify() {
1826  unsigned returnSize = getType().cast<ArrayType>().getNumElements();
1827  if (getInputs().size() != returnSize)
1828  return failure();
1829  return success();
1830 }
1831 
1832 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1833  if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1834  return {};
1835  return ArrayAttr::get(getContext(), adaptor.getInputs());
1836 }
1837 
1838 // Check whether an integer value is an offset from a base.
1839 bool hw::isOffset(Value base, Value index, uint64_t offset) {
1840  if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1841  if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1842  // If both values are a constant, check if index == base + offset.
1843  // To account for overflow, the addition is performed with an extra bit
1844  // and the offset is asserted to fit in the bit width of the base.
1845  auto baseValue = constBase.getValue();
1846  auto indexValue = constIndex.getValue();
1847 
1848  unsigned bits = baseValue.getBitWidth();
1849  assert(bits == indexValue.getBitWidth() && "mismatched widths");
1850 
1851  if (bits < 64 && offset >= (1ull << bits))
1852  return false;
1853 
1854  APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1855  APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1856  return baseExt + offset == indexExt;
1857  }
1858  }
1859  return false;
1860 }
1861 
1862 // Canonicalize a create of consecutive elements to a slice.
1863 static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1864  PatternRewriter &rewriter) {
1865  // Do not canonicalize create of get into a slice.
1866  auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1867  if (arrayTy.getNumElements() <= 1)
1868  return failure();
1869  auto elemTy = arrayTy.getElementType();
1870 
1871  // Check if create arguments are consecutive elements of the same array.
1872  // Attempt to break a create of gets into a sequence of consecutive intervals.
1873  struct Chunk {
1874  Value input;
1875  Value index;
1876  size_t size;
1877  };
1878  SmallVector<Chunk> chunks;
1879  for (Value value : llvm::reverse(op.getInputs())) {
1880  auto get = value.getDefiningOp<ArrayGetOp>();
1881  if (!get)
1882  return failure();
1883 
1884  Value input = get.getInput();
1885  Value index = get.getIndex();
1886  if (!chunks.empty()) {
1887  auto &c = *chunks.rbegin();
1888  if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1889  c.size++;
1890  continue;
1891  }
1892  }
1893 
1894  chunks.push_back(Chunk{input, index, 1});
1895  }
1896 
1897  // If there is a single slice, eliminate the create.
1898  if (chunks.size() == 1) {
1899  auto &chunk = chunks[0];
1900  rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1901  op.getLoc(), arrayTy, chunk.input, chunk.index));
1902  return success();
1903  }
1904 
1905  // If the number of chunks is significantly less than the number of
1906  // elements, replace the create with a concat of the identified slices.
1907  if (chunks.size() * 2 < arrayTy.getNumElements()) {
1908  SmallVector<Value> slices;
1909  for (auto &chunk : llvm::reverse(chunks)) {
1910  auto sliceTy = ArrayType::get(elemTy, chunk.size);
1911  slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1912  op.getLoc(), sliceTy, chunk.input, chunk.index));
1913  }
1914  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1915  return success();
1916  }
1917 
1918  return failure();
1919 }
1920 
1921 LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1922  PatternRewriter &rewriter) {
1923  if (succeeded(foldCreateToSlice(op, rewriter)))
1924  return success();
1925  return failure();
1926 }
1927 
1928 Value ArrayCreateOp::getUniformElement() {
1929  if (!getInputs().empty() && llvm::all_equal(getInputs()))
1930  return getInputs()[0];
1931  return {};
1932 }
1933 
1934 static std::optional<uint64_t> getUIntFromValue(Value value) {
1935  auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1936  if (!idxOp)
1937  return std::nullopt;
1938  APInt idxAttr = idxOp.getValue();
1939  if (idxAttr.getBitWidth() > 64)
1940  return std::nullopt;
1941  return idxAttr.getLimitedValue();
1942 }
1943 
1944 LogicalResult ArraySliceOp::verify() {
1945  unsigned inputSize =
1946  type_cast<ArrayType>(getInput().getType()).getNumElements();
1947  if (llvm::Log2_64_Ceil(inputSize) !=
1948  getLowIndex().getType().getIntOrFloatBitWidth())
1949  return emitOpError(
1950  "ArraySlice: index width must match clog2 of array size");
1951  return success();
1952 }
1953 
1954 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1955  // If we are slicing the entire input, then return it.
1956  if (getType() == getInput().getType())
1957  return getInput();
1958  return {};
1959 }
1960 
1961 LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1962  PatternRewriter &rewriter) {
1963  auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1964  auto elemTy = sliceTy.getElementType();
1965  uint64_t sliceSize = sliceTy.getNumElements();
1966  if (sliceSize == 0)
1967  return failure();
1968 
1969  if (sliceSize == 1) {
1970  // slice(a, n) -> create(a[n])
1971  auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1972  op.getLowIndex());
1973  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1974  get.getResult());
1975  return success();
1976  }
1977 
1978  auto offsetOpt = getUIntFromValue(op.getLowIndex());
1979  if (!offsetOpt)
1980  return failure();
1981 
1982  auto inputOp = op.getInput().getDefiningOp();
1983  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1984  // slice(slice(a, n), m) -> slice(a, n + m)
1985  if (inputSlice == op)
1986  return failure();
1987 
1988  auto inputIndex = inputSlice.getLowIndex();
1989  auto inputOffsetOpt = getUIntFromValue(inputIndex);
1990  if (!inputOffsetOpt)
1991  return failure();
1992 
1993  uint64_t offset = *offsetOpt + *inputOffsetOpt;
1994  auto lowIndex =
1995  rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1996  rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1997  inputSlice.getInput(), lowIndex);
1998  return success();
1999  }
2000 
2001  if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
2002  // slice(create(a0, a1, ..., an), m) -> create(am, ...)
2003  auto inputs = inputCreate.getInputs();
2004 
2005  uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
2006  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
2007  inputs.slice(begin, sliceSize));
2008  return success();
2009  }
2010 
2011  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2012  // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2013  SmallVector<Value> chunks;
2014  uint64_t sliceStart = *offsetOpt;
2015  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2016  // Check whether the input intersects with the slice.
2017  uint64_t inputSize =
2018  hw::type_cast<ArrayType>(input.getType()).getNumElements();
2019  if (inputSize == 0 || inputSize <= sliceStart) {
2020  sliceStart -= inputSize;
2021  continue;
2022  }
2023 
2024  // Find the indices to slice from this input by intersection.
2025  uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2026  uint64_t cutSize = cutEnd - sliceStart;
2027  assert(cutSize != 0 && "slice cannot be empty");
2028 
2029  if (cutSize == inputSize) {
2030  // The whole input fits in the slice, add it.
2031  assert(sliceStart == 0 && "invalid cut size");
2032  chunks.push_back(input);
2033  } else {
2034  // Slice the required bits from the input.
2035  unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2036  auto lowIndex = rewriter.create<ConstantOp>(
2037  op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2038  chunks.push_back(rewriter.create<ArraySliceOp>(
2039  op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
2040  }
2041 
2042  sliceStart = 0;
2043  sliceSize -= cutSize;
2044  if (sliceSize == 0)
2045  break;
2046  }
2047 
2048  assert(chunks.size() > 0 && "missing sliced items");
2049  if (chunks.size() == 1)
2050  rewriter.replaceOp(op, chunks[0]);
2051  else
2052  rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2053  op, llvm::to_vector(llvm::reverse(chunks)));
2054  return success();
2055  }
2056  return failure();
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // ArrayConcatOp
2061 //===----------------------------------------------------------------------===//
2062 
2063 static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2064  SmallVectorImpl<Type> &inputTypes,
2065  Type &resultType) {
2066  Type elemType;
2067  uint64_t resultSize = 0;
2068 
2069  auto parseElement = [&]() -> ParseResult {
2070  Type ty;
2071  if (p.parseType(ty))
2072  return failure();
2073  auto arrTy = type_dyn_cast<ArrayType>(ty);
2074  if (!arrTy)
2075  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2076  if (elemType && elemType != arrTy.getElementType())
2077  return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2078  << elemType;
2079 
2080  elemType = arrTy.getElementType();
2081  inputTypes.push_back(ty);
2082  resultSize += arrTy.getNumElements();
2083  return success();
2084  };
2085 
2086  if (p.parseCommaSeparatedList(parseElement))
2087  return failure();
2088 
2089  resultType = ArrayType::get(elemType, resultSize);
2090  return success();
2091 }
2092 
2093 static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2094  TypeRange inputTypes, Type resultType) {
2095  llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2096 }
2097 
2098 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2099  ValueRange values) {
2100  assert(!values.empty() && "Cannot build array of zero elements");
2101  ArrayType arrayTy = values[0].getType().cast<ArrayType>();
2102  Type elemTy = arrayTy.getElementType();
2103  assert(llvm::all_of(values,
2104  [elemTy](Value v) -> bool {
2105  return v.getType().isa<ArrayType>() &&
2106  v.getType().cast<ArrayType>().getElementType() ==
2107  elemTy;
2108  }) &&
2109  "All values must be of ArrayType with the same element type.");
2110 
2111  uint64_t resultSize = 0;
2112  for (Value val : values)
2113  resultSize += val.getType().cast<ArrayType>().getNumElements();
2114  build(b, state, ArrayType::get(elemTy, resultSize), values);
2115 }
2116 
2117 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2118  auto inputs = adaptor.getInputs();
2119  SmallVector<Attribute> array;
2120  for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2121  if (!inputs[i])
2122  return {};
2123  llvm::copy(inputs[i].cast<ArrayAttr>(), std::back_inserter(array));
2124  }
2125  return ArrayAttr::get(getContext(), array);
2126 }
2127 
2128 // Flatten a concatenation of array creates into a single create.
2129 static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2130  for (auto input : op.getInputs())
2131  if (!input.getDefiningOp<ArrayCreateOp>())
2132  return false;
2133 
2134  SmallVector<Value> items;
2135  for (auto input : op.getInputs()) {
2136  auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2137  for (auto item : create.getInputs())
2138  items.push_back(item);
2139  }
2140 
2141  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2142  return true;
2143 }
2144 
2145 // Merge consecutive slice expressions in a concatenation.
2146 static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2147  struct Slice {
2148  Value input;
2149  Value index;
2150  size_t size;
2151  Value op;
2152  SmallVector<Location> locs;
2153  };
2154 
2155  SmallVector<Value> items;
2156  std::optional<Slice> last;
2157  bool changed = false;
2158 
2159  auto concatenate = [&] {
2160  // If there is only one op in the slice, place it to the items list.
2161  if (!last)
2162  return;
2163  if (last->op) {
2164  items.push_back(last->op);
2165  last.reset();
2166  return;
2167  }
2168 
2169  // Otherwise, create a new slice of with the given size and place it.
2170  // In this case, the concat op is replaced, using the new argument.
2171  changed = true;
2172  auto loc = FusedLoc::get(op.getContext(), last->locs);
2173  auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2174  auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2175  items.push_back(rewriter.createOrFold<ArraySliceOp>(
2176  loc, arrayTy, last->input, last->index));
2177 
2178  last.reset();
2179  };
2180 
2181  auto append = [&](Value op, Value input, Value index, size_t size) {
2182  // If this slice is an extension of the previous one, extend the size
2183  // saved. In this case, a new slice of is created and the concatenation
2184  // operator is rewritten. Otherwise, flush the last slice.
2185  if (last) {
2186  if (last->input == input && isOffset(last->index, index, last->size)) {
2187  last->size += size;
2188  last->op = {};
2189  last->locs.push_back(op.getLoc());
2190  return;
2191  }
2192  concatenate();
2193  }
2194  last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2195  };
2196 
2197  for (auto item : llvm::reverse(op.getInputs())) {
2198  if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2199  auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2200  append(item, slice.getInput(), slice.getLowIndex(), size);
2201  continue;
2202  }
2203 
2204  if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2205  if (create.getInputs().size() == 1) {
2206  if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2207  append(item, get.getInput(), get.getIndex(), 1);
2208  continue;
2209  }
2210  }
2211  }
2212 
2213  concatenate();
2214  items.push_back(item);
2215  }
2216  concatenate();
2217 
2218  if (!changed)
2219  return false;
2220 
2221  if (items.size() == 1) {
2222  rewriter.replaceOp(op, items[0]);
2223  } else {
2224  std::reverse(items.begin(), items.end());
2225  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2226  }
2227  return true;
2228 }
2229 
2230 LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2231  PatternRewriter &rewriter) {
2232  // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2233  if (flattenConcatOp(op, rewriter))
2234  return success();
2235 
2236  // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2237  if (mergeConcatSlices(op, rewriter))
2238  return success();
2239 
2240  return success();
2241 }
2242 
2243 //===----------------------------------------------------------------------===//
2244 // EnumConstantOp
2245 //===----------------------------------------------------------------------===//
2246 
2247 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2248  // Parse a Type instead of an EnumType since the type might be a type alias.
2249  // The validity of the canonical type is checked during construction of the
2250  // EnumFieldAttr.
2251  Type type;
2252  StringRef field;
2253 
2254  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2255  if (parser.parseKeyword(&field) || parser.parseColonType(type))
2256  return failure();
2257 
2258  auto fieldAttr = EnumFieldAttr::get(
2259  loc, StringAttr::get(parser.getContext(), field), type);
2260 
2261  if (!fieldAttr)
2262  return failure();
2263 
2264  result.addAttribute("field", fieldAttr);
2265  result.addTypes(type);
2266 
2267  return success();
2268 }
2269 
2270 void EnumConstantOp::print(OpAsmPrinter &p) {
2271  p << " " << getField().getField().getValue() << " : "
2272  << getField().getType().getValue();
2273 }
2274 
2276  function_ref<void(Value, StringRef)> setNameFn) {
2277  setNameFn(getResult(), getField().getField().str());
2278 }
2279 
2280 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2281  EnumFieldAttr field) {
2282  return build(builder, odsState, field.getType().getValue(), field);
2283 }
2284 
2285 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2286  assert(adaptor.getOperands().empty() && "constant has no operands");
2287  return getFieldAttr();
2288 }
2289 
2290 LogicalResult EnumConstantOp::verify() {
2291  auto fieldAttr = getFieldAttr();
2292  auto fieldType = fieldAttr.getType().getValue();
2293  // This check ensures that we are using the exact same type, without looking
2294  // through type aliases.
2295  if (fieldType != getType())
2296  emitOpError("return type ")
2297  << getType() << " does not match attribute type " << fieldAttr;
2298  return success();
2299 }
2300 
2301 //===----------------------------------------------------------------------===//
2302 // EnumCmpOp
2303 //===----------------------------------------------------------------------===//
2304 
2305 LogicalResult EnumCmpOp::verify() {
2306  // Compare the canonical types.
2307  auto lhsType = type_cast<EnumType>(getLhs().getType());
2308  auto rhsType = type_cast<EnumType>(getRhs().getType());
2309  if (rhsType != lhsType)
2310  emitOpError("types do not match");
2311  return success();
2312 }
2313 
2314 //===----------------------------------------------------------------------===//
2315 // StructCreateOp
2316 //===----------------------------------------------------------------------===//
2317 
2318 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2319  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2320  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2321  Type declOrAliasType;
2322 
2323  if (parser.parseLParen() || parser.parseOperandList(operands) ||
2324  parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2325  parser.parseColonType(declOrAliasType))
2326  return failure();
2327 
2328  auto declType = type_dyn_cast<StructType>(declOrAliasType);
2329  if (!declType)
2330  return parser.emitError(parser.getNameLoc(),
2331  "expected !hw.struct type or alias");
2332 
2333  llvm::SmallVector<Type, 4> structInnerTypes;
2334  declType.getInnerTypes(structInnerTypes);
2335  result.addTypes(declOrAliasType);
2336 
2337  if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2338  result.operands))
2339  return failure();
2340  return success();
2341 }
2342 
2343 void StructCreateOp::print(OpAsmPrinter &printer) {
2344  printer << " (";
2345  printer.printOperands(getInput());
2346  printer << ")";
2347  printer.printOptionalAttrDict((*this)->getAttrs());
2348  printer << " : " << getType();
2349 }
2350 
2351 LogicalResult StructCreateOp::verify() {
2352  auto elements = hw::type_cast<StructType>(getType()).getElements();
2353 
2354  if (elements.size() != getInput().size())
2355  return emitOpError("structure field count mismatch");
2356 
2357  for (const auto &[field, value] : llvm::zip(elements, getInput()))
2358  if (field.type != value.getType())
2359  return emitOpError("structure field `")
2360  << field.name << "` type does not match";
2361 
2362  return success();
2363 }
2364 
2365 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2366  // struct_create(struct_explode(x)) => x
2367  if (!getInput().empty())
2368  if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2369  explodeOp && getInput() == explodeOp.getResults() &&
2370  getResult().getType() == explodeOp.getInput().getType())
2371  return explodeOp.getInput();
2372 
2373  auto inputs = adaptor.getInput();
2374  if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2375  return {};
2376  return ArrayAttr::get(getContext(), inputs);
2377 }
2378 
2379 //===----------------------------------------------------------------------===//
2380 // StructExplodeOp
2381 //===----------------------------------------------------------------------===//
2382 
2383 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2384  OperationState &result) {
2385  OpAsmParser::UnresolvedOperand operand;
2386  Type declType;
2387 
2388  if (parser.parseOperand(operand) ||
2389  parser.parseOptionalAttrDict(result.attributes) ||
2390  parser.parseColonType(declType))
2391  return failure();
2392  auto structType = type_dyn_cast<StructType>(declType);
2393  if (!structType)
2394  return parser.emitError(parser.getNameLoc(),
2395  "invalid kind of type specified");
2396 
2397  llvm::SmallVector<Type, 4> structInnerTypes;
2398  structType.getInnerTypes(structInnerTypes);
2399  result.addTypes(structInnerTypes);
2400 
2401  if (parser.resolveOperand(operand, declType, result.operands))
2402  return failure();
2403  return success();
2404 }
2405 
2406 void StructExplodeOp::print(OpAsmPrinter &printer) {
2407  printer << " ";
2408  printer.printOperand(getInput());
2409  printer.printOptionalAttrDict((*this)->getAttrs());
2410  printer << " : " << getInput().getType();
2411 }
2412 
2413 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2414  SmallVectorImpl<OpFoldResult> &results) {
2415  auto input = adaptor.getInput();
2416  if (!input)
2417  return failure();
2418  llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(results));
2419  return success();
2420 }
2421 
2422 LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2423  PatternRewriter &rewriter) {
2424  auto *inputOp = op.getInput().getDefiningOp();
2425  auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2426  auto result = failure();
2427  auto opResults = op.getResults();
2428  for (uint32_t index = 0; index < elements.size(); index++) {
2429  if (auto foldResult = foldStructExtract(inputOp, index)) {
2430  rewriter.replaceAllUsesWith(opResults[index], foldResult);
2431  result = success();
2432  }
2433  }
2434  return result;
2435 }
2436 
2438  function_ref<void(Value, StringRef)> setNameFn) {
2439  auto structType = type_cast<StructType>(getInput().getType());
2440  for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2441  setNameFn(res, field.name.str());
2442 }
2443 
2444 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2445  Value input) {
2446  StructType inputType = input.getType().dyn_cast<StructType>();
2447  assert(inputType);
2448  SmallVector<Type, 16> fieldTypes;
2449  for (auto field : inputType.getElements())
2450  fieldTypes.push_back(field.type);
2451  build(odsBuilder, odsState, fieldTypes, input);
2452 }
2453 
2454 //===----------------------------------------------------------------------===//
2455 // StructExtractOp
2456 //===----------------------------------------------------------------------===//
2457 
2458 /// Ensure an aggregate op's field index is within the bounds of
2459 /// the aggregate type and the accessed field is of 'elementType'.
2460 template <typename AggregateOp, typename AggregateType>
2461 static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2462  AggregateType aggType,
2463  Type elementType) {
2464  auto index = op.getFieldIndex();
2465  if (index >= aggType.getElements().size())
2466  return op.emitOpError() << "field index " << index
2467  << " exceeds element count of aggregate type";
2468 
2470  getCanonicalType(aggType.getElements()[index].type))
2471  return op.emitOpError()
2472  << "type " << aggType.getElements()[index].type
2473  << " of accessed field in aggregate at index " << index
2474  << " does not match expected type " << elementType;
2475 
2476  return success();
2477 }
2478 
2479 LogicalResult StructExtractOp::verify() {
2480  return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2481  *this, getInput().getType(), getType());
2482 }
2483 
2484 /// Use the same parser for both struct_extract and union_extract since the
2485 /// syntax is identical.
2486 template <typename AggregateType>
2487 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2488  OpAsmParser::UnresolvedOperand operand;
2489  StringAttr fieldName;
2490  Type declType;
2491 
2492  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2493  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2494  parser.parseOptionalAttrDict(result.attributes) ||
2495  parser.parseColonType(declType))
2496  return failure();
2497  auto aggType = type_dyn_cast<AggregateType>(declType);
2498  if (!aggType)
2499  return parser.emitError(parser.getNameLoc(),
2500  "invalid kind of type specified");
2501 
2502  auto fieldIndex = aggType.getFieldIndex(fieldName);
2503  if (!fieldIndex) {
2504  parser.emitError(parser.getNameLoc(), "field name '" +
2505  fieldName.getValue() +
2506  "' not found in aggregate type");
2507  return failure();
2508  }
2509 
2510  auto indexAttr =
2511  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2512  result.addAttribute("fieldIndex", indexAttr);
2513  Type resultType = aggType.getElements()[*fieldIndex].type;
2514  result.addTypes(resultType);
2515 
2516  if (parser.resolveOperand(operand, declType, result.operands))
2517  return failure();
2518  return success();
2519 }
2520 
2521 /// Use the same printer for both struct_extract and union_extract since the
2522 /// syntax is identical.
2523 template <typename AggType>
2524 static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2525  printer << " ";
2526  printer.printOperand(op.getInput());
2527  printer << "[\"" << op.getFieldName() << "\"]";
2528  printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2529  printer << " : " << op.getInput().getType();
2530 }
2531 
2532 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2533  OperationState &result) {
2534  return parseExtractOp<StructType>(parser, result);
2535 }
2536 
2537 void StructExtractOp::print(OpAsmPrinter &printer) {
2538  printExtractOp(printer, *this);
2539 }
2540 
2541 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2542  Value input, StructType::FieldInfo field) {
2543  auto fieldIndex =
2544  type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2545  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2546  build(builder, odsState, field.type, input, *fieldIndex);
2547 }
2548 
2549 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2550  Value input, StringAttr fieldName) {
2551  auto structType = type_cast<StructType>(input.getType());
2552  auto fieldIndex = structType.getFieldIndex(fieldName);
2553  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2554  auto resultType = structType.getElements()[*fieldIndex].type;
2555  build(builder, odsState, resultType, input, *fieldIndex);
2556 }
2557 
2558 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2559  if (auto constOperand = adaptor.getInput()) {
2560  // Fold extract from aggregate constant
2561  auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2562  return operandAttr.getValue()[getFieldIndex()];
2563  }
2564 
2565  if (auto foldResult =
2566  foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2567  return foldResult;
2568  return {};
2569 }
2570 
2571 LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2572  PatternRewriter &rewriter) {
2573  auto inputOp = op.getInput().getDefiningOp();
2574 
2575  // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2576  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2577  if (structInject.getFieldIndex() != op.getFieldIndex()) {
2578  rewriter.replaceOpWithNewOp<StructExtractOp>(
2579  op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2580  return success();
2581  }
2582  }
2583 
2584  return failure();
2585 }
2586 
2588  function_ref<void(Value, StringRef)> setNameFn) {
2589  setNameFn(getResult(), getFieldName());
2590 }
2591 
2592 //===----------------------------------------------------------------------===//
2593 // StructInjectOp
2594 //===----------------------------------------------------------------------===//
2595 
2596 void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2597  Value input, StringAttr fieldName, Value newValue) {
2598  auto structType = type_cast<StructType>(input.getType());
2599  auto fieldIndex = structType.getFieldIndex(fieldName);
2600  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2601  build(builder, odsState, input, *fieldIndex, newValue);
2602 }
2603 
2604 LogicalResult StructInjectOp::verify() {
2605  return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2606  *this, getInput().getType(), getNewValue().getType());
2607 }
2608 
2609 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2610  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2611  OpAsmParser::UnresolvedOperand operand, val;
2612  StringAttr fieldName;
2613  Type declType;
2614 
2615  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2616  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2617  parser.parseComma() || parser.parseOperand(val) ||
2618  parser.parseOptionalAttrDict(result.attributes) ||
2619  parser.parseColonType(declType))
2620  return failure();
2621  auto structType = type_dyn_cast<StructType>(declType);
2622  if (!structType)
2623  return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2624 
2625  auto fieldIndex = structType.getFieldIndex(fieldName);
2626  if (!fieldIndex) {
2627  parser.emitError(parser.getNameLoc(), "field name '" +
2628  fieldName.getValue() +
2629  "' not found in aggregate type");
2630  return failure();
2631  }
2632 
2633  auto indexAttr =
2634  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2635  result.addAttribute("fieldIndex", indexAttr);
2636  result.addTypes(declType);
2637 
2638  Type resultType = structType.getElements()[*fieldIndex].type;
2639  if (parser.resolveOperands({operand, val}, {declType, resultType},
2640  inputOperandsLoc, result.operands))
2641  return failure();
2642  return success();
2643 }
2644 
2645 void StructInjectOp::print(OpAsmPrinter &printer) {
2646  printer << " ";
2647  printer.printOperand(getInput());
2648  printer << "[\"" << getFieldName() << "\"], ";
2649  printer.printOperand(getNewValue());
2650  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2651  printer << " : " << getInput().getType();
2652 }
2653 
2654 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2655  auto input = adaptor.getInput();
2656  auto newValue = adaptor.getNewValue();
2657  if (!input || !newValue)
2658  return {};
2659  SmallVector<Attribute> array;
2660  llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(array));
2661  array[getFieldIndex()] = newValue;
2662  return ArrayAttr::get(getContext(), array);
2663 }
2664 
2665 LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2666  PatternRewriter &rewriter) {
2667  // Canonicalize multiple injects into a create op and eliminate overwrites.
2668  SmallPtrSet<Operation *, 4> injects;
2669  DenseMap<StringAttr, Value> fields;
2670 
2671  // Chase a chain of injects. Bail out if cycles are present.
2672  StructInjectOp inject = op;
2673  Value input;
2674  do {
2675  if (!injects.insert(inject).second)
2676  return failure();
2677 
2678  fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2679  input = inject.getInput();
2680  inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2681  } while (inject);
2682  assert(input && "missing input to inject chain");
2683 
2684  auto ty = hw::type_cast<StructType>(op.getType());
2685  auto elements = ty.getElements();
2686 
2687  // If the inject chain sets all fields, canonicalize to create.
2688  if (fields.size() == elements.size()) {
2689  SmallVector<Value> createFields;
2690  for (const auto &field : elements) {
2691  auto it = fields.find(field.name);
2692  assert(it != fields.end() && "missing field");
2693  createFields.push_back(it->second);
2694  }
2695  rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2696  return success();
2697  }
2698 
2699  // Nothing to canonicalize, only the original inject in the chain.
2700  if (injects.size() == fields.size())
2701  return failure();
2702 
2703  // Eliminate overwrites. The hash map contains the last write to each field.
2704  for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2705  auto it = fields.find(elements[fieldIndex].name);
2706  if (it == fields.end())
2707  continue;
2708  input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2709  it->second);
2710  }
2711 
2712  rewriter.replaceOp(op, input);
2713  return success();
2714 }
2715 
2716 //===----------------------------------------------------------------------===//
2717 // UnionCreateOp
2718 //===----------------------------------------------------------------------===//
2719 
2720 LogicalResult UnionCreateOp::verify() {
2721  return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2722  *this, getType(), getInput().getType());
2723 }
2724 
2725 void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2726  Type unionType, StringAttr fieldName, Value input) {
2727  auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2728  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2729  build(builder, odsState, unionType, *fieldIndex, input);
2730 }
2731 
2732 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2733  Type declOrAliasType;
2734  StringAttr fieldName;
2735  OpAsmParser::UnresolvedOperand input;
2736  llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2737 
2738  if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2739  parser.parseOperand(input) ||
2740  parser.parseOptionalAttrDict(result.attributes) ||
2741  parser.parseColonType(declOrAliasType))
2742  return failure();
2743 
2744  auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2745  if (!declType)
2746  return parser.emitError(parser.getNameLoc(),
2747  "expected !hw.union type or alias");
2748 
2749  auto fieldIndex = declType.getFieldIndex(fieldName);
2750  if (!fieldIndex) {
2751  parser.emitError(fieldLoc, "cannot find union field '")
2752  << fieldName.getValue() << '\'';
2753  return failure();
2754  }
2755 
2756  auto indexAttr =
2757  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2758  result.addAttribute("fieldIndex", indexAttr);
2759  Type inputType = declType.getElements()[*fieldIndex].type;
2760 
2761  if (parser.resolveOperand(input, inputType, result.operands))
2762  return failure();
2763  result.addTypes({declOrAliasType});
2764  return success();
2765 }
2766 
2767 void UnionCreateOp::print(OpAsmPrinter &printer) {
2768  printer << " \"" << getFieldName() << "\", ";
2769  printer.printOperand(getInput());
2770  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2771  printer << " : " << getType();
2772 }
2773 
2774 //===----------------------------------------------------------------------===//
2775 // UnionExtractOp
2776 //===----------------------------------------------------------------------===//
2777 
2778 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2779  return parseExtractOp<UnionType>(parser, result);
2780 }
2781 
2782 void UnionExtractOp::print(OpAsmPrinter &printer) {
2783  printExtractOp(printer, *this);
2784 }
2785 
2786 LogicalResult UnionExtractOp::inferReturnTypes(
2787  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2788  DictionaryAttr attrs, mlir::OpaqueProperties properties,
2789  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2790  auto unionElements =
2791  hw::type_cast<UnionType>((operands[0].getType())).getElements();
2792  unsigned fieldIndex =
2793  attrs.getAs<IntegerAttr>("fieldIndex").getValue().getZExtValue();
2794  if (fieldIndex >= unionElements.size()) {
2795  if (loc)
2796  mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2797  " exceeds element count of aggregate type");
2798  return failure();
2799  }
2800  results.push_back(unionElements[fieldIndex].type);
2801  return success();
2802 }
2803 
2804 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2805  Value input, StringAttr fieldName) {
2806  auto unionType = type_cast<UnionType>(input.getType());
2807  auto fieldIndex = unionType.getFieldIndex(fieldName);
2808  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2809  auto resultType = unionType.getElements()[*fieldIndex].type;
2810  build(odsBuilder, odsState, resultType, input, *fieldIndex);
2811 }
2812 
2813 //===----------------------------------------------------------------------===//
2814 // ArrayGetOp
2815 //===----------------------------------------------------------------------===//
2816 
2817 // An array_get of an array_create with a constant index can just be the
2818 // array_create operand at the constant index. If the array_create has a
2819 // single uniform value for each element, just return that value regardless of
2820 // the index. If the array is constructed from a constant by a bitcast
2821 // operation, we can fold into a constant.
2822 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2823  auto inputCst = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2824  auto indexCst = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
2825 
2826  if (inputCst) {
2827  // Constant array index.
2828  if (indexCst) {
2829  auto indexVal = indexCst.getValue();
2830  if (indexVal.getBitWidth() < 64) {
2831  auto index = indexVal.getZExtValue();
2832  return inputCst[inputCst.size() - 1 - index];
2833  }
2834  }
2835  // If all elements of the array are the same, we can return any element of
2836  // array.
2837  if (!inputCst.empty() && llvm::all_equal(inputCst))
2838  return inputCst[0];
2839  }
2840 
2841  // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2842  if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2843  auto intTy = getType().dyn_cast<IntegerType>();
2844  if (!intTy)
2845  return {};
2846  auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2847  if (!bitcastInputOp)
2848  return {};
2849  if (!indexCst)
2850  return {};
2851  auto bitcastInputCst = bitcastInputOp.getValue();
2852  // Calculate the index. Make sure to zero-extend the index value before
2853  // multiplying the element width.
2854  auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2855  getType().getIntOrFloatBitWidth();
2856  // Extract [startIdx + width - 1: startIdx].
2857  return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2858  intTy.getIntOrFloatBitWidth()));
2859  }
2860 
2861  auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2862  if (!inputCreate)
2863  return {};
2864 
2865  if (auto uniformValue = inputCreate.getUniformElement())
2866  return uniformValue;
2867 
2868  if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2869  return {};
2870 
2871  uint64_t index = indexCst.getValue().getLimitedValue();
2872  auto createInputs = inputCreate.getInputs();
2873  if (index >= createInputs.size())
2874  return {};
2875  return createInputs[createInputs.size() - index - 1];
2876 }
2877 
2878 LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2879  PatternRewriter &rewriter) {
2880  auto idxOpt = getUIntFromValue(op.getIndex());
2881  if (!idxOpt)
2882  return failure();
2883 
2884  auto *inputOp = op.getInput().getDefiningOp();
2885  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2886  // get(slice(a, n), m) -> get(a, n + m)
2887  auto offsetOp = inputSlice.getLowIndex();
2888  auto offsetOpt = getUIntFromValue(offsetOp);
2889  if (!offsetOpt)
2890  return failure();
2891 
2892  uint64_t offset = *offsetOpt + *idxOpt;
2893  auto newOffset =
2894  rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2895  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2896  newOffset);
2897  return success();
2898  }
2899 
2900  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2901  // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2902  uint64_t elemIndex = *idxOpt;
2903  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2904  size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2905  if (elemIndex >= size) {
2906  elemIndex -= size;
2907  continue;
2908  }
2909 
2910  unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2911  auto newIdxOp = rewriter.create<ConstantOp>(
2912  op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2913 
2914  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2915  return success();
2916  }
2917  return failure();
2918  }
2919 
2920  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2921  // array_get sel, (array_create (array_get const a), (array_get const b),
2922  // (array_get const, c), (array_get const, d))
2923  if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2924  if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2925  if (auto create =
2926  innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2927 
2928  SmallVector<Value> newValues;
2929  for (auto operand : create.getOperands())
2930  newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2931  op.getLoc(), operand, op.getIndex()));
2932 
2933  rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2934  op,
2935  rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2936  innerGet.getIndex());
2937  return success();
2938  }
2939  }
2940  }
2941 
2942  return failure();
2943 }
2944 
2945 //===----------------------------------------------------------------------===//
2946 // TypedeclOp
2947 //===----------------------------------------------------------------------===//
2948 
2949 StringRef TypedeclOp::getPreferredName() {
2950  return getVerilogName().value_or(getName());
2951 }
2952 
2953 Type TypedeclOp::getAliasType() {
2954  auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2955  return hw::TypeAliasType::get(
2956  SymbolRefAttr::get(parentScope.getSymNameAttr(),
2957  {FlatSymbolRefAttr::get(*this)}),
2958  getType());
2959 }
2960 
2961 //===----------------------------------------------------------------------===//
2962 // BitcastOp
2963 //===----------------------------------------------------------------------===//
2964 
2965 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2966  // Identity.
2967  // bitcast(%a) : A -> A ==> %a
2968  if (getOperand().getType() == getType())
2969  return getOperand();
2970 
2971  return {};
2972 }
2973 
2974 LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2975  // Composition.
2976  // %b = bitcast(%a) : A -> B
2977  // bitcast(%b) : B -> C
2978  // ===> bitcast(%a) : A -> C
2979  auto inputBitcast =
2980  dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2981  if (!inputBitcast)
2982  return failure();
2983  auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2984  inputBitcast.getInput());
2985  rewriter.replaceOp(op, bitcast);
2986  return success();
2987 }
2988 
2989 LogicalResult BitcastOp::verify() {
2990  if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2991  return this->emitOpError("Bitwidth of input must match result");
2992  return success();
2993 }
2994 
2995 //===----------------------------------------------------------------------===//
2996 // HierPathOp helpers.
2997 //===----------------------------------------------------------------------===//
2998 
2999 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
3000  SmallVector<Attribute, 4> newPath;
3001  bool updateMade = false;
3002  for (auto nameRef : getNamepath()) {
3003  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3004  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3005  if (ref.getModule() == moduleToDrop)
3006  updateMade = true;
3007  else
3008  newPath.push_back(ref);
3009  } else {
3010  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
3011  updateMade = true;
3012  else
3013  newPath.push_back(nameRef);
3014  }
3015  }
3016  if (updateMade)
3017  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3018  return updateMade;
3019 }
3020 
3021 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3022  SmallVector<Attribute, 4> newPath;
3023  bool updateMade = false;
3024  StringRef inlinedInstanceName = "";
3025  for (auto nameRef : getNamepath()) {
3026  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3027  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3028  if (ref.getModule() == moduleToDrop) {
3029  inlinedInstanceName = ref.getName().getValue();
3030  updateMade = true;
3031  } else if (!inlinedInstanceName.empty()) {
3032  newPath.push_back(hw::InnerRefAttr::get(
3033  ref.getModule(),
3034  StringAttr::get(getContext(), inlinedInstanceName + "_" +
3035  ref.getName().getValue())));
3036  inlinedInstanceName = "";
3037  } else
3038  newPath.push_back(ref);
3039  } else {
3040  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
3041  updateMade = true;
3042  else
3043  newPath.push_back(nameRef);
3044  }
3045  }
3046  if (updateMade)
3047  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3048  return updateMade;
3049 }
3050 
3051 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3052  SmallVector<Attribute, 4> newPath;
3053  bool updateMade = false;
3054  for (auto nameRef : getNamepath()) {
3055  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3056  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3057  if (ref.getModule() == oldMod) {
3058  newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3059  updateMade = true;
3060  } else
3061  newPath.push_back(ref);
3062  } else {
3063  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == oldMod) {
3064  newPath.push_back(FlatSymbolRefAttr::get(newMod));
3065  updateMade = true;
3066  } else
3067  newPath.push_back(nameRef);
3068  }
3069  }
3070  if (updateMade)
3071  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3072  return updateMade;
3073 }
3074 
3075 bool HierPathOp::updateModuleAndInnerRef(
3076  StringAttr oldMod, StringAttr newMod,
3077  const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3078  auto fromRef = FlatSymbolRefAttr::get(oldMod);
3079  if (oldMod == newMod)
3080  return false;
3081 
3082  auto namepathNew = getNamepath().getValue().vec();
3083  bool updateMade = false;
3084  // Break from the loop if the module is found, since it can occur only once.
3085  for (auto &element : namepathNew) {
3086  if (auto innerRef = element.dyn_cast<hw::InnerRefAttr>()) {
3087  if (innerRef.getModule() != oldMod)
3088  continue;
3089  auto symName = innerRef.getName();
3090  // Since the module got updated, the old innerRef symbol inside oldMod
3091  // should also be updated to the new symbol inside the newMod.
3092  auto to = innerSymRenameMap.find(symName);
3093  if (to != innerSymRenameMap.end())
3094  symName = to->second;
3095  updateMade = true;
3096  element = hw::InnerRefAttr::get(newMod, symName);
3097  break;
3098  }
3099  if (element != fromRef)
3100  continue;
3101 
3102  updateMade = true;
3103  element = FlatSymbolRefAttr::get(newMod);
3104  break;
3105  }
3106  if (updateMade)
3107  setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3108  return updateMade;
3109 }
3110 
3111 bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3112  SmallVector<Attribute, 4> newPath;
3113  bool updateMade = false;
3114  for (auto nameRef : getNamepath()) {
3115  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3116  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3117  if (ref.getModule() == atMod) {
3118  updateMade = true;
3119  if (includeMod)
3120  newPath.push_back(ref);
3121  } else
3122  newPath.push_back(ref);
3123  } else {
3124  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == atMod && !includeMod)
3125  updateMade = true;
3126  else
3127  newPath.push_back(nameRef);
3128  }
3129  if (updateMade)
3130  break;
3131  }
3132  if (updateMade)
3133  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3134  return updateMade;
3135 }
3136 
3137 /// Return just the module part of the namepath at a specific index.
3138 StringAttr HierPathOp::modPart(unsigned i) {
3139  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3140  .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3141  .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3142 }
3143 
3144 /// Return the root module.
3145 StringAttr HierPathOp::root() {
3146  assert(!getNamepath().empty());
3147  return modPart(0);
3148 }
3149 
3150 /// Return true if the NLA has the module in its path.
3151 bool HierPathOp::hasModule(StringAttr modName) {
3152  for (auto nameRef : getNamepath()) {
3153  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3154  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3155  if (ref.getModule() == modName)
3156  return true;
3157  } else {
3158  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == modName)
3159  return true;
3160  }
3161  }
3162  return false;
3163 }
3164 
3165 /// Return true if the NLA has the InnerSym .
3166 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3167  for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3168  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>())
3169  if (ref.getName() == symName && ref.getModule() == modName)
3170  return true;
3171 
3172  return false;
3173 }
3174 
3175 /// Return just the reference part of the namepath at a specific index. This
3176 /// will return an empty attribute if this is the leaf and the leaf is a module.
3177 StringAttr HierPathOp::refPart(unsigned i) {
3178  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3179  .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3180  .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3181 }
3182 
3183 /// Return the leaf reference. This returns an empty attribute if the leaf
3184 /// reference is a module.
3185 StringAttr HierPathOp::ref() {
3186  assert(!getNamepath().empty());
3187  return refPart(getNamepath().size() - 1);
3188 }
3189 
3190 /// Return the leaf module.
3191 StringAttr HierPathOp::leafMod() {
3192  assert(!getNamepath().empty());
3193  return modPart(getNamepath().size() - 1);
3194 }
3195 
3196 /// Returns true if this NLA targets an instance of a module (as opposed to
3197 /// an instance's port or something inside an instance).
3198 bool HierPathOp::isModule() { return !ref(); }
3199 
3200 /// Returns true if this NLA targets something inside a module (as opposed
3201 /// to a module or an instance of a module);
3202 bool HierPathOp::isComponent() { return (bool)ref(); }
3203 
3204 // Verify the HierPathOp.
3205 // 1. Iterate over the namepath.
3206 // 2. The namepath should be a valid instance path, specified either on a
3207 // module or a declaration inside a module.
3208 // 3. Each element in the namepath is an InnerRefAttr except possibly the
3209 // last element.
3210 // 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3211 // and the corresponding inner_sym on the instance.
3212 // 5. Make sure that the instance path is legal, by verifying the sequence of
3213 // instance and the expected module occurs as the next element in the path.
3214 // 6. The last element of the namepath, can be an InnerRefAttr on either a
3215 // module port or a declaration inside the module.
3216 // 7. The last element of the namepath can also be a module symbol.
3217 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3218  ArrayAttr expectedModuleNames = {};
3219  auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3220  if (!expectedModuleNames)
3221  return success();
3222  if (llvm::any_of(expectedModuleNames,
3223  [name](Attribute attr) { return attr == name; }))
3224  return success();
3225  auto diag = emitOpError() << "instance path is incorrect. Expected ";
3226  size_t n = expectedModuleNames.size();
3227  if (n != 1) {
3228  diag << "one of ";
3229  }
3230  for (size_t i = 0; i < n; ++i) {
3231  if (i != 0)
3232  diag << ((i + 1 == n) ? " or " : ", ");
3233  diag << expectedModuleNames[i].cast<StringAttr>();
3234  }
3235  diag << ". Instead found: " << name;
3236  return diag;
3237  };
3238 
3239  if (!getNamepath() || getNamepath().empty())
3240  return emitOpError() << "the instance path cannot be empty";
3241  for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3242  hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast<hw::InnerRefAttr>();
3243  if (!innerRef)
3244  return emitOpError()
3245  << "the instance path can only contain inner sym reference"
3246  << ", only the leaf can refer to a module symbol";
3247 
3248  if (failed(checkExpectedModule(innerRef.getModule())))
3249  return failure();
3250 
3251  auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3252  if (!instOp)
3253  return emitOpError() << " module: " << innerRef.getModule()
3254  << " does not contain any instance with symbol: "
3255  << innerRef.getName();
3256  expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3257  }
3258 
3259  // The instance path has been verified. Now verify the last element.
3260  auto leafRef = getNamepath()[getNamepath().size() - 1];
3261  if (auto innerRef = leafRef.dyn_cast<hw::InnerRefAttr>()) {
3262  if (!ns.lookup(innerRef)) {
3263  return emitOpError() << " operation with symbol: " << innerRef
3264  << " was not found ";
3265  }
3266  if (failed(checkExpectedModule(innerRef.getModule())))
3267  return failure();
3268  } else if (failed(checkExpectedModule(
3269  leafRef.cast<FlatSymbolRefAttr>().getAttr()))) {
3270  return failure();
3271  }
3272  return success();
3273 }
3274 
3275 void HierPathOp::print(OpAsmPrinter &p) {
3276  p << " ";
3277 
3278  // Print visibility if present.
3279  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3280  if (auto visibility =
3281  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3282  p << visibility.getValue() << ' ';
3283 
3284  p.printSymbolName(getSymName());
3285  p << " [";
3286  llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3287  if (auto ref = attr.dyn_cast<hw::InnerRefAttr>()) {
3288  p.printSymbolName(ref.getModule().getValue());
3289  p << "::";
3290  p.printSymbolName(ref.getName().getValue());
3291  } else {
3292  p.printSymbolName(attr.cast<FlatSymbolRefAttr>().getValue());
3293  }
3294  });
3295  p << "]";
3296  p.printOptionalAttrDict(
3297  (*this)->getAttrs(),
3298  {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3299 }
3300 
3301 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3302  // Parse the visibility attribute.
3303  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3304 
3305  // Parse the symbol name.
3306  StringAttr symName;
3307  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3308  result.attributes))
3309  return failure();
3310 
3311  // Parse the namepath.
3312  SmallVector<Attribute> namepath;
3313  if (parser.parseCommaSeparatedList(
3314  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3315  auto loc = parser.getCurrentLocation();
3316  SymbolRefAttr ref;
3317  if (parser.parseAttribute(ref))
3318  return failure();
3319 
3320  // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3321  auto pathLength = ref.getNestedReferences().size();
3322  if (pathLength == 0)
3323  namepath.push_back(
3324  FlatSymbolRefAttr::get(ref.getRootReference()));
3325  else if (pathLength == 1)
3326  namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3327  ref.getLeafReference()));
3328  else
3329  return parser.emitError(loc,
3330  "only one nested reference is allowed");
3331  return success();
3332  }))
3333  return failure();
3334  result.addAttribute("namepath",
3335  ArrayAttr::get(parser.getContext(), namepath));
3336 
3337  if (parser.parseOptionalAttrDict(result.attributes))
3338  return failure();
3339 
3340  return success();
3341 }
3342 
3343 //===----------------------------------------------------------------------===//
3344 // TriggeredOp
3345 //===----------------------------------------------------------------------===//
3346 
3347 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3348  EventControlAttr event, Value trigger,
3349  ValueRange inputs) {
3350  odsState.addOperands(trigger);
3351  odsState.addOperands(inputs);
3352  odsState.addAttribute(getEventAttrName(odsState.name), event);
3353  auto *r = odsState.addRegion();
3354  Block *b = new Block();
3355  r->push_back(b);
3356 
3357  llvm::SmallVector<Location> argLocs;
3358  llvm::transform(inputs, std::back_inserter(argLocs),
3359  [&](Value v) { return v.getLoc(); });
3360  b->addArguments(inputs.getTypes(), argLocs);
3361 }
3362 
3363 //===----------------------------------------------------------------------===//
3364 // TableGen generated logic.
3365 //===----------------------------------------------------------------------===//
3366 
3367 // Provide the autogenerated implementation guts for the Op classes.
3368 #define GET_OP_CLASSES
3369 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition: HWOps.cpp:1074
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition: HWOps.cpp:487
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition: HWOps.cpp:1019
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2129
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:1863
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition: HWOps.cpp:1204
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition: HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition: HWOps.cpp:1010
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
Definition: HWOps.cpp:544
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2524
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition: HWOps.cpp:681
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition: HWOps.cpp:2093
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition: HWOps.cpp:582
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition: HWOps.cpp:1767
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition: HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition: HWOps.cpp:888
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2146
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2487
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition: HWOps.cpp:1272
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition: HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition: HWOps.cpp:1345
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1424
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition: HWOps.cpp:479
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition: HWOps.cpp:397
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition: HWOps.cpp:1934
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition: HWOps.cpp:896
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition: HWOps.cpp:2461
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1444
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition: HWOps.cpp:1781
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition: HWOps.cpp:343
Delimiter
Definition: HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition: HWOps.cpp:2063
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:63
void setOutput(unsigned i, Value v)
Definition: HWOps.cpp:241
Value getInput(unsigned i)
Definition: HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition: HWOps.h:117
llvm::SmallVector< Value > inputArgs
Definition: HWOps.h:116
llvm::StringMap< unsigned > outputIdx
Definition: HWOps.h:115
llvm::StringMap< unsigned > inputIdx
Definition: HWOps.h:115
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition: HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition: HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition: HWVisitors.h:27
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:273
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:1075
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, HWModuleLike op)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1839
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition: HWOps.h:122
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
Definition: HWOps.cpp:59
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition: HWOps.cpp:36
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:515
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition: HWOps.cpp:202
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
Definition: HWOps.cpp:221
bool isValidIndexBitWidth(Value index, Value array)
Definition: HWOps.cpp:48
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:102
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition: HWOps.cpp:509
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition: HWOps.cpp:534
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
circt::hw::InOutType InOutType
Definition: SVTypes.h:25
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse an parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition: hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::StringAttr name
Definition: HWTypes.h:29