CIRCT  18.0.0git
HWOps.cpp
Go to the documentation of this file.
1 //===- HWOps.cpp - Implement the HW operations ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the HW ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/HW/HWOps.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Interfaces/FunctionImplementation.h"
26 #include "llvm/ADT/BitVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/StringSet.h"
29 
30 using namespace circt;
31 using namespace hw;
32 using mlir::TypedAttr;
33 
34 /// Flip a port direction.
36  switch (direction) {
43  }
44  llvm_unreachable("unknown PortDirection");
45 }
46 
47 bool hw::isValidIndexBitWidth(Value index, Value array) {
48  hw::ArrayType arrayType =
49  hw::getCanonicalType(array.getType()).dyn_cast<hw::ArrayType>();
50  assert(arrayType && "expected array type");
51  unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
52  auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
53  return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
54  : indexWidth == requiredWidth;
55 }
56 
57 /// Return true if the specified operation is a combinational logic op.
58 bool hw::isCombinational(Operation *op) {
59  struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
60  bool visitInvalidTypeOp(Operation *op) { return false; }
61  bool visitUnhandledTypeOp(Operation *op) { return true; }
62  };
63 
64  return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
65  IsCombClassifier().dispatchTypeOpVisitor(op);
66 }
67 
68 static Value foldStructExtract(Operation *inputOp, StringRef field) {
69  // A struct extract of a struct create -> corresponding struct create operand.
70  if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
71  auto ty = type_cast<StructType>(structCreate.getResult().getType());
72  if (auto idx = ty.getFieldIndex(field))
73  return structCreate.getOperand(*idx);
74  return {};
75  }
76  // Extracting injected field -> corresponding field
77  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
78  if (structInject.getField() != field)
79  return {};
80  return structInject.getNewValue();
81  }
82  return {};
83 }
84 
85 static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
86  ArrayRef<Attribute> attrs) {
87  if (attrs.empty())
88  return ArrayAttr::get(context, {});
89  bool empty = true;
90  for (auto a : attrs)
91  if (a && !cast<DictionaryAttr>(a).empty()) {
92  empty = false;
93  break;
94  }
95  if (empty)
96  return ArrayAttr::get(context, {});
97  return ArrayAttr::get(context, attrs);
98 }
99 
100 /// Get a special name to use when printing the entry block arguments of the
101 /// region contained by an operation in this dialect.
102 static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
103  OpAsmSetValueNameFn setNameFn) {
104  if (region.empty())
105  return;
106  // Assign port names to the bbargs.
107  // auto *module = region.getParentOp();
108  auto module = cast<HWModuleOp>(region.getParentOp());
109 
110  auto *block = &region.front();
111  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
112  auto name = module.getInputName(i);
113  if (!name.empty())
114  setNameFn(block->getArgument(i), name);
115  }
116 }
117 
/// How a list is bracketed in textual form (presumably consumed by the custom
/// parsers/printers elsewhere in this file).
enum class Delimiter {
  /// No surrounding delimiter.
  None,
  /// Enclosed in parentheses: `( ... )`.
  Paren,
  /// Enclosed in angle brackets `< ... >`, or omitted entirely.
  OptionalLessGreater,
};
123 
124 /// Check parameter specified by `value` to see if it is valid according to the
125 /// module's parameters. If not, emit an error to the diagnostic provided as an
126 /// argument to the lambda 'instanceError' and return failure, otherwise return
127 /// success.
128 ///
129 /// If `disallowParamRefs` is true, then parameter references are not allowed.
131  Attribute value, ArrayAttr moduleParameters,
132  const instance_like_impl::EmitErrorFn &instanceError,
133  bool disallowParamRefs) {
134  // Literals are always ok. Their types are already known to match
135  // expectations.
136  if (value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
137  value.isa<StringAttr>() || value.isa<ParamVerbatimAttr>())
138  return success();
139 
140  // Check both subexpressions of an expression.
141  if (auto expr = value.dyn_cast<ParamExprAttr>()) {
142  for (auto op : expr.getOperands())
143  if (failed(checkParameterInContext(op, moduleParameters, instanceError,
144  disallowParamRefs)))
145  return failure();
146  return success();
147  }
148 
149  // Parameter references need more analysis to make sure they are valid within
150  // this module.
151  if (auto parameterRef = value.dyn_cast<ParamDeclRefAttr>()) {
152  auto nameAttr = parameterRef.getName();
153 
154  // Don't allow references to parameters from the default values of a
155  // parameter list.
156  if (disallowParamRefs) {
157  instanceError([&](auto &diag) {
158  diag << "parameter " << nameAttr
159  << " cannot be used as a default value for a parameter";
160  return false;
161  });
162  return failure();
163  }
164 
165  // Find the corresponding attribute in the module.
166  for (auto param : moduleParameters) {
167  auto paramAttr = param.cast<ParamDeclAttr>();
168  if (paramAttr.getName() != nameAttr)
169  continue;
170 
171  // If the types match then the reference is ok.
172  if (paramAttr.getType() == parameterRef.getType())
173  return success();
174 
175  instanceError([&](auto &diag) {
176  diag << "parameter " << nameAttr << " used with type "
177  << parameterRef.getType() << "; should have type "
178  << paramAttr.getType();
179  return true;
180  });
181  return failure();
182  }
183 
184  instanceError([&](auto &diag) {
185  diag << "use of unknown parameter " << nameAttr;
186  return true;
187  });
188  return failure();
189  }
190 
191  instanceError([&](auto &diag) {
192  diag << "invalid parameter value " << value;
193  return false;
194  });
195  return failure();
196 }
197 
198 /// Check parameter specified by `value` to see if it is valid within the scope
199 /// of the specified module `module`. If not, emit an error at the location of
200 /// `usingOp` and return failure, otherwise return success. If `usingOp` is
201 /// null, then no diagnostic is generated.
202 ///
203 /// If `disallowParamRefs` is true, then parameter references are not allowed.
204 LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
205  Operation *usingOp,
206  bool disallowParamRefs) {
208  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
209  if (usingOp) {
210  auto diag = usingOp->emitOpError();
211  if (fn(diag))
212  diag.attachNote(module->getLoc()) << "module declared here";
213  }
214  };
215 
216  return checkParameterInContext(value,
217  module->getAttrOfType<ArrayAttr>("parameters"),
218  emitError, disallowParamRefs);
219 }
220 
221 /// Return true if the specified attribute tree is made up of nodes that are
222 /// valid in a parameter expression.
223 bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
224  return succeeded(checkParameterInContext(attr, module, nullptr, false));
225 }
226 
228  const ModulePortInfo &info,
229  Region &bodyRegion)
230  : info(info) {
231  inputArgs.resize(info.sizeInputs());
232  for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
233  inputIdx[info.at(i).name.str()] = i;
234  inputArgs[i] = barg;
235  }
236 
237  outputOperands.resize(info.sizeOutputs());
238  for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
239  outputIdx[outputInfo.name.str()] = i;
240  }
241 }
242 
243 void HWModulePortAccessor::setOutput(unsigned i, Value v) {
244  assert(outputOperands.size() > i && "invalid output index");
245  assert(outputOperands[i] == Value() && "output already set");
246  outputOperands[i] = v;
247 }
248 
250  assert(inputArgs.size() > i && "invalid input index");
251  return inputArgs[i];
252 }
253 Value HWModulePortAccessor::getInput(StringRef name) {
254  return getInput(inputIdx.find(name.str())->second);
255 }
256 void HWModulePortAccessor::setOutput(StringRef name, Value v) {
257  setOutput(outputIdx.find(name.str())->second, v);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConstantOp
262 //===----------------------------------------------------------------------===//
263 
264 void ConstantOp::print(OpAsmPrinter &p) {
265  p << " ";
266  p.printAttribute(getValueAttr());
267  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
268 }
269 
270 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
271  IntegerAttr valueAttr;
272 
273  if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
274  parser.parseOptionalAttrDict(result.attributes))
275  return failure();
276 
277  result.addTypes(valueAttr.getType());
278  return success();
279 }
280 
281 LogicalResult ConstantOp::verify() {
282  // If the result type has a bitwidth, then the attribute must match its width.
283  if (getValue().getBitWidth() != getType().cast<IntegerType>().getWidth())
284  return emitError(
285  "hw.constant attribute bitwidth doesn't match return type");
286 
287  return success();
288 }
289 
290 /// Build a ConstantOp from an APInt, infering the result type from the
291 /// width of the APInt.
292 void ConstantOp::build(OpBuilder &builder, OperationState &result,
293  const APInt &value) {
294 
295  auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
296  auto attr = builder.getIntegerAttr(type, value);
297  return build(builder, result, type, attr);
298 }
299 
300 /// Build a ConstantOp from an APInt, infering the result type from the
301 /// width of the APInt.
302 void ConstantOp::build(OpBuilder &builder, OperationState &result,
303  IntegerAttr value) {
304  return build(builder, result, value.getType(), value);
305 }
306 
307 /// This builder allows construction of small signed integers like 0, 1, -1
308 /// matching a specified MLIR IntegerType. This shouldn't be used for general
309 /// constant folding because it only works with values that can be expressed in
310 /// an int64_t. Use APInt's instead.
311 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
312  int64_t value) {
313  auto numBits = type.cast<IntegerType>().getWidth();
314  build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true));
315 }
316 
318  function_ref<void(Value, StringRef)> setNameFn) {
319  auto intTy = getType();
320  auto intCst = getValue();
321 
322  // Sugar i1 constants with 'true' and 'false'.
323  if (intTy.cast<IntegerType>().getWidth() == 1)
324  return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
325 
326  // Otherwise, build a complex name with the value and type.
327  SmallVector<char, 32> specialNameBuffer;
328  llvm::raw_svector_ostream specialName(specialNameBuffer);
329  specialName << 'c' << intCst << '_' << intTy;
330  setNameFn(getResult(), specialName.str());
331 }
332 
333 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
334  assert(adaptor.getOperands().empty() && "constant has no operands");
335  return getValueAttr();
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // WireOp
340 //===----------------------------------------------------------------------===//
341 
342 /// Check whether an operation has any additional attributes set beyond its
343 /// standard list of attributes returned by `getAttributeNames`.
344 template <class Op>
345 static bool hasAdditionalAttributes(Op op,
346  ArrayRef<StringRef> ignoredAttrs = {}) {
347  auto names = op.getAttributeNames();
348  llvm::SmallDenseSet<StringRef> nameSet;
349  nameSet.reserve(names.size() + ignoredAttrs.size());
350  nameSet.insert(names.begin(), names.end());
351  nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
352  return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
353  return !nameSet.contains(namedAttr.getName());
354  });
355 }
356 
358  // If the wire has an optional 'name' attribute, use it.
359  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
360  if (!nameAttr.getValue().empty())
361  setNameFn(getResult(), nameAttr.getValue());
362 }
363 
364 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
365 
366 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
367  // If the wire has no additional attributes, no name, and no symbol, just
368  // forward its input.
369  if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
370  !getInnerSymAttr())
371  return getInput();
372  return {};
373 }
374 
375 LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
376  // Block if the wire has any attributes.
377  if (hasAdditionalAttributes(wire, {"sv.namehint"}))
378  return failure();
379 
380  // If the wire has a symbol, then we can't delete it.
381  if (wire.getInnerSymAttr())
382  return failure();
383 
384  // If the wire has a name or an `sv.namehint` attribute, propagate it as an
385  // `sv.namehint` to the expression.
386  if (auto *inputOp = wire.getInput().getDefiningOp()) {
387  auto name = wire.getNameAttr();
388  if (!name || name.getValue().empty())
389  name = wire->getAttrOfType<StringAttr>("sv.namehint");
390  if (name)
391  rewriter.updateRootInPlace(
392  inputOp, [&] { inputOp->setAttr("sv.namehint", name); });
393  }
394 
395  rewriter.replaceOp(wire, wire.getInput());
396  return success();
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // AggregateConstantOp
401 //===----------------------------------------------------------------------===//
402 
403 static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
404  // If this is a type alias, get the underlying type.
405  if (auto typeAlias = type.dyn_cast<TypeAliasType>())
406  type = typeAlias.getCanonicalType();
407 
408  if (auto structType = type.dyn_cast<StructType>()) {
409  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
410  if (!arrayAttr)
411  return op->emitOpError("expected array attribute for constant of type ")
412  << type;
413  for (auto [attr, fieldInfo] :
414  llvm::zip(arrayAttr.getValue(), structType.getElements())) {
415  if (failed(checkAttributes(op, attr, fieldInfo.type)))
416  return failure();
417  }
418  } else if (auto arrayType = type.dyn_cast<ArrayType>()) {
419  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
420  if (!arrayAttr)
421  return op->emitOpError("expected array attribute for constant of type ")
422  << type;
423  auto elementType = arrayType.getElementType();
424  for (auto attr : arrayAttr.getValue()) {
425  if (failed(checkAttributes(op, attr, elementType)))
426  return failure();
427  }
428  } else if (auto arrayType = type.dyn_cast<UnpackedArrayType>()) {
429  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
430  if (!arrayAttr)
431  return op->emitOpError("expected array attribute for constant of type ")
432  << type;
433  auto elementType = arrayType.getElementType();
434  for (auto attr : arrayAttr.getValue()) {
435  if (failed(checkAttributes(op, attr, elementType)))
436  return failure();
437  }
438  } else if (auto enumType = type.dyn_cast<EnumType>()) {
439  auto stringAttr = attr.dyn_cast<StringAttr>();
440  if (!stringAttr)
441  return op->emitOpError("expected string attribute for constant of type ")
442  << type;
443  } else if (auto intType = type.dyn_cast<IntegerType>()) {
444  // Check the attribute kind is correct.
445  auto intAttr = attr.dyn_cast<IntegerAttr>();
446  if (!intAttr)
447  return op->emitOpError("expected integer attribute for constant of type ")
448  << type;
449  // Check the bitwidth is correct.
450  if (intAttr.getValue().getBitWidth() != intType.getWidth())
451  return op->emitOpError("hw.constant attribute bitwidth "
452  "doesn't match return type");
453  } else {
454  return op->emitOpError("unknown element type") << type;
455  }
456  return success();
457 }
458 
459 LogicalResult AggregateConstantOp::verify() {
460  return checkAttributes(*this, getFieldsAttr(), getType());
461 }
462 
463 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
464 
465 //===----------------------------------------------------------------------===//
466 // ParamValueOp
467 //===----------------------------------------------------------------------===//
468 
469 static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
470  Type &resultType) {
471  if (p.parseType(resultType) || p.parseEqual() ||
472  p.parseAttribute(value, resultType))
473  return failure();
474  return success();
475 }
476 
477 static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
478  Type resultType) {
479  p << resultType << " = ";
480  p.printAttributeWithoutType(value);
481 }
482 
483 LogicalResult ParamValueOp::verify() {
484  // Check that the attribute expression is valid in this module.
486  getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
487 }
488 
489 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
490  assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
491  return getValueAttr();
492 }
493 
494 //===----------------------------------------------------------------------===//
495 // HWModuleOp
//===----------------------------------------------------------------------===//
497 
498 /// Return true if isAnyModule or instance.
499 bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
500  return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
501 }
502 
503 /// Return the signature for a module as a function type from the module itself
504 /// or from an hw::InstanceOp.
505 FunctionType hw::getModuleType(Operation *moduleOrInstance) {
506  if (auto instance = dyn_cast<InstanceOp>(moduleOrInstance)) {
507  SmallVector<Type> inputs(instance->getOperandTypes());
508  SmallVector<Type> results(instance->getResultTypes());
509  return FunctionType::get(instance->getContext(), inputs, results);
510  }
511 
512  if (auto mod = dyn_cast<HWTestModuleOp>(moduleOrInstance))
513  return mod.getModuleType().getFuncType();
514 
515  if (auto mod = dyn_cast<HWModuleLike>(moduleOrInstance))
516  return mod.getHWModuleType().getFuncType();
517 
518  return cast<mlir::FunctionOpInterface>(moduleOrInstance)
519  .getFunctionType()
520  .cast<FunctionType>();
521 }
522 
523 /// Return the name to use for the Verilog module that we're referencing
524 /// here. This is typically the symbol, but can be overridden with the
525 /// verilogName attribute.
526 StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
527  auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
528  if (nameAttr)
529  return nameAttr;
530 
531  return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
532 }
533 
534 // Flag for parsing different module types
536 
537 template <typename ModuleTy>
538 static void
539 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
540  const ModulePortInfo &ports, ArrayAttr parameters,
541  ArrayRef<NamedAttribute> attributes, StringAttr comment) {
542  using namespace mlir::function_interface_impl;
543  LocationAttr unknownLoc = builder.getUnknownLoc();
544 
545  // Add an attribute for the name.
546  result.addAttribute(SymbolTable::getSymbolAttrName(), name);
547 
548  SmallVector<Attribute> argNames, resultNames;
549  SmallVector<Type, 4> argTypes, resultTypes;
550  SmallVector<Attribute> portAttrs;
551  SmallVector<Attribute> portLocs;
552  SmallVector<ModulePort> portTypes;
553  auto exportPortIdent = StringAttr::get(builder.getContext(), "hw.exportPort");
554 
555  for (auto elt : ports.getInputs()) {
556  portTypes.push_back(elt);
557  if (elt.dir == ModulePort::Direction::InOut &&
558  !elt.type.isa<hw::InOutType>())
559  elt.type = hw::InOutType::get(elt.type);
560  argTypes.push_back(elt.type);
561  argNames.push_back(elt.name);
562  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
563  Attribute attr;
564  if (elt.sym && !elt.sym.empty())
565  attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}});
566  else
567  attr = builder.getDictionaryAttr({});
568  portAttrs.push_back(attr);
569  }
570 
571  for (auto elt : ports.getOutputs()) {
572  portTypes.push_back(elt);
573  resultTypes.push_back(elt.type);
574  resultNames.push_back(elt.name);
575  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
576  Attribute attr;
577  if (elt.sym && !elt.sym.empty())
578  attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}});
579  else
580  attr = builder.getDictionaryAttr({});
581  portAttrs.push_back(attr);
582  }
583 
584  // Allow clients to pass in null for the parameters list.
585  if (!parameters)
586  parameters = builder.getArrayAttr({});
587 
588  // Record the argument and result types as an attribute.
589  auto type = ModuleType::get(builder.getContext(), portTypes);
590  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
591  TypeAttr::get(type));
592  result.addAttribute("portLocs", builder.getArrayAttr(portLocs));
593  result.addAttribute("per_port_attrs",
594  arrayOrEmpty(builder.getContext(), portAttrs));
595  result.addAttribute("parameters", parameters);
596  if (!comment)
597  comment = builder.getStringAttr("");
598  result.addAttribute("comment", comment);
599  result.addAttributes(attributes);
600  result.addRegion();
601 }
602 
/// Internal implementation of argument/result insertion and removal on modules.
///
/// Produces the new per-port name, type, attribute and location lists
/// (`newArg*`) from the old ones (`oldArg*`):
///  - `insertArgs` holds (position, port) pairs. Each port is inserted before
///    the old port at `position`; a position equal to the old port count
///    appends at the end. Ports sharing a position keep their listed order.
///  - `removeArgs` holds indices of old ports to drop.
/// Both lists index the OLD port list and must be sorted in ascending order
/// (checked by assertions in debug builds).
///
/// For inserted ports:
///  - an InOut port has its type wrapped in `hw.inout` unless it already is
///    one;
///  - a port with a non-empty inner symbol gets `{hw.exportPort = sym}` as
///    its attribute dictionary, and every other port gets an empty one.
/// Surviving old ports keep their attributes, or get an empty dictionary
/// when `oldArgAttrs` is empty.
///
/// If `body` is non-null, its block arguments are edited in step: inserted
/// ports become new block arguments and removed ports' arguments are erased.
static void modifyModuleArgs(
    MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
    ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
    ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
    ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
    SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
    SmallVector<Location> &newArgLocs, Block *body = nullptr) {

#ifndef NDEBUG
  // Check that the `insertArgs` and `removeArgs` indices are in ascending
  // order.
  assert(llvm::is_sorted(insertArgs,
                         [](auto &a, auto &b) { return a.first < b.first; }) &&
         "insertArgs must be in ascending order");
  assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
         "removeArgs must be in ascending order");
#endif

  auto oldArgCount = oldArgTypes.size();
  auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
  // NOTE(review): newArgCount is unsigned; this can only catch an underflow
  // (more removals than ports) through the narrowing cast to int.
  assert((int)newArgCount >= 0);

  newArgNames.reserve(newArgCount);
  newArgTypes.reserve(newArgCount);
  newArgAttrs.reserve(newArgCount);
  newArgLocs.reserve(newArgCount);

  auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
  auto emptyDictAttr = DictionaryAttr::get(context, {});
  auto unknownLoc = UnknownLoc::get(context);

  // Positions (in the body's argument list after insertions) of block
  // arguments to erase once the walk is done; only sized when editing a body.
  BitVector erasedIndices;
  if (body)
    erasedIndices.resize(oldArgCount + insertArgs.size());

  // `argIdx` walks the old ports; `idx` is the matching position in the
  // body's argument list, which also advances past every inserted argument.
  // The extra iteration with argIdx == oldArgCount handles appends.
  for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
    // Insert new ports at this position.
    while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
      auto port = insertArgs[0].second;
      if (port.dir == ModulePort::Direction::InOut &&
          !port.type.isa<InOutType>())
        port.type = InOutType::get(port.type);
      Attribute attr =
          (port.sym && !port.sym.empty())
              ? DictionaryAttr::get(context, {{exportPortAttrName, port.sym}})
              : emptyDictAttr;
      newArgNames.push_back(port.name);
      newArgTypes.push_back(port.type);
      newArgAttrs.push_back(attr);
      insertArgs = insertArgs.drop_front();
      LocationAttr loc = port.loc ? port.loc : unknownLoc;
      newArgLocs.push_back(loc);
      if (body)
        body->insertArgument(idx++, port.type, loc);
    }
    if (argIdx == oldArgCount)
      break;

    // Migrate the old port at this position.
    bool removed = false;
    while (!removeArgs.empty() && removeArgs[0] == argIdx) {
      removeArgs = removeArgs.drop_front();
      removed = true;
    }

    if (removed) {
      // Defer erasure so that `idx` stays valid for later insertions.
      if (body)
        erasedIndices.set(idx);
    } else {
      newArgNames.push_back(oldArgNames[argIdx]);
      newArgTypes.push_back(oldArgTypes[argIdx]);
      newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
                                                : oldArgAttrs[argIdx]);
      newArgLocs.push_back(oldArgLocs[argIdx]);
    }
  }

  if (body)
    body->eraseArguments(erasedIndices);

  assert(newArgNames.size() == newArgCount);
  assert(newArgTypes.size() == newArgCount);
  assert(newArgAttrs.size() == newArgCount);
  assert(newArgLocs.size() == newArgCount);
}
689 
690 /// Insert and remove ports of a module. The insertion and removal indices must
691 /// be in ascending order. The indices refer to the port positions before any
692 /// insertion or removal occurs. Ports inserted at the same index will appear in
693 /// the module in the same order as they were listed in the `insert*` array.
694 ///
695 /// The operation must be any of the module-like operations.
697  Operation *op, ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
698  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
699  ArrayRef<unsigned> removeInputs, ArrayRef<unsigned> removeOutputs,
700  Block *body) {
701  auto moduleOp = cast<HWModuleLike>(op);
702  auto *context = moduleOp.getContext();
703 
704  // Dig up the old argument and result data.
705  auto oldArgNames = moduleOp.getInputNames();
706  auto oldArgTypes = moduleOp.getInputTypes();
707  auto oldArgAttrs = moduleOp.getAllInputAttrs();
708  auto oldArgLocs = moduleOp.getInputLocs();
709 
710  auto oldResultNames = moduleOp.getOutputNames();
711  auto oldResultTypes = moduleOp.getOutputTypes();
712  auto oldResultAttrs = moduleOp.getAllOutputAttrs();
713  auto oldResultLocs = moduleOp.getOutputLocs();
714 
715  // Modify the ports.
716  SmallVector<Attribute> newArgNames, newResultNames;
717  SmallVector<Type> newArgTypes, newResultTypes;
718  SmallVector<Attribute> newArgAttrs, newResultAttrs;
719  SmallVector<Location> newArgLocs, newResultLocs;
720 
721  modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
722  oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
723  newArgTypes, newArgAttrs, newArgLocs, body);
724 
725  modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
726  oldResultTypes, oldResultAttrs, oldResultLocs,
727  newResultNames, newResultTypes, newResultAttrs,
728  newResultLocs);
729 
730  // Update the module operation types and attributes.
731  auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
732  auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
733  moduleOp.setHWModuleType(modty);
734  moduleOp.setAllInputAttrs(newArgAttrs);
735  moduleOp.setAllOutputAttrs(newResultAttrs);
736 
737  newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
738  moduleOp.setAllPortLocs(newArgLocs);
739 }
740 
741 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
742  StringAttr name, const ModulePortInfo &ports,
743  ArrayAttr parameters,
744  ArrayRef<NamedAttribute> attributes, StringAttr comment,
745  bool shouldEnsureTerminator) {
746  buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
747  comment);
748 
749  // Create a region and a block for the body.
750  auto *bodyRegion = result.regions[0].get();
751  Block *body = new Block();
752  bodyRegion->push_back(body);
753 
754  // Add arguments to the body block.
755  auto unknownLoc = builder.getUnknownLoc();
756  for (auto port : ports.getInputs()) {
757  auto loc = port.loc ? Location(port.loc) : unknownLoc;
758  auto type = port.type;
759  if (port.isInOut() && !type.isa<InOutType>())
760  type = InOutType::get(type);
761  body->addArgument(type, loc);
762  }
763 
764  if (shouldEnsureTerminator)
765  HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
766 }
767 
768 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
769  StringAttr name, ArrayRef<PortInfo> ports,
770  ArrayAttr parameters,
771  ArrayRef<NamedAttribute> attributes,
772  StringAttr comment) {
773  build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
774  comment);
775 }
776 
777 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
778  StringAttr name, const ModulePortInfo &ports,
779  HWModuleBuilder modBuilder, ArrayAttr parameters,
780  ArrayRef<NamedAttribute> attributes,
781  StringAttr comment) {
782  build(builder, odsState, name, ports, parameters, attributes, comment,
783  /*shouldEnsureTerminator=*/false);
784  auto *bodyRegion = odsState.regions[0].get();
785  OpBuilder::InsertionGuard guard(builder);
786  auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
787  builder.setInsertionPointToEnd(&bodyRegion->front());
788  modBuilder(builder, accessor);
789  // Create output operands.
790  llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
791  builder.create<hw::OutputOp>(odsState.location, outputOperands);
792 }
793 
794 void HWModuleOp::modifyPorts(
795  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
796  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
797  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
798  hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
799  eraseOutputs);
800 }
801 
802 /// Return the name to use for the Verilog module that we're referencing
803 /// here. This is typically the symbol, but can be overridden with the
804 /// verilogName attribute.
806  if (auto vName = getVerilogNameAttr())
807  return vName;
808 
809  return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
810 }
811 
813  if (auto vName = getVerilogNameAttr()) {
814  return vName;
815  }
816  return (*this)->getAttrOfType<StringAttr>(
817  ::mlir::SymbolTable::getSymbolAttrName());
818 }
819 
820 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
821  StringAttr name, const ModulePortInfo &ports,
822  StringRef verilogName, ArrayAttr parameters,
823  ArrayRef<NamedAttribute> attributes) {
824  buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
825  attributes, {});
826 
827  if (!verilogName.empty())
828  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
829 }
830 
831 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
832  StringAttr name, ArrayRef<PortInfo> ports,
833  StringRef verilogName, ArrayAttr parameters,
834  ArrayRef<NamedAttribute> attributes) {
835  build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
836  attributes);
837 }
838 
839 void HWModuleExternOp::modifyPorts(
840  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
841  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
842  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
843  hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
844  eraseOutputs);
845 }
846 
847 void HWModuleExternOp::appendOutputs(
848  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
849 
850 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
851  FlatSymbolRefAttr genKind, StringAttr name,
852  const ModulePortInfo &ports,
853  StringRef verilogName, ArrayAttr parameters,
854  ArrayRef<NamedAttribute> attributes) {
855  buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
856  attributes, {});
857  result.addAttribute("generatorKind", genKind);
858  if (!verilogName.empty())
859  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
860 }
861 
862 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
863  FlatSymbolRefAttr genKind, StringAttr name,
864  ArrayRef<PortInfo> ports, StringRef verilogName,
865  ArrayAttr parameters,
866  ArrayRef<NamedAttribute> attributes) {
867  build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
868  parameters, attributes);
869 }
870 
871 void HWModuleGeneratedOp::modifyPorts(
872  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
873  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
874  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
875  hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
876  eraseOutputs);
877 }
878 
879 void HWModuleGeneratedOp::appendOutputs(
880  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
881 
882 static InnerSymAttr extractSym(DictionaryAttr attrs) {
883  if (attrs)
884  if (auto symRef = attrs.get("hw.exportPort"))
885  return symRef.cast<InnerSymAttr>();
886  return {};
887 }
888 
889 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
890  for (auto &argAttr : attrs)
891  if (argAttr.getName() == name)
892  return true;
893  return false;
894 }
895 
896 static Attribute getAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
897  for (auto &argAttr : attrs)
898  if (argAttr.getName() == name)
899  return argAttr.getValue();
900  return {};
901 }
902 
903 static void addArgAndResultAttrsHW(Builder &builder, OperationState &result,
904  ArrayRef<DictionaryAttr> argAttrs,
905  ArrayRef<DictionaryAttr> resultAttrs) {
906  // Add the attributes to the function arguments.
907  result.addAttribute(
908  "per_port_attrs",
909  arrayOrEmpty(builder.getContext(),
910  llvm::to_vector(
911  llvm::concat<const Attribute>(argAttrs, resultAttrs))));
912 }
913 
914 static void addArgAndResultAttrsHW(Builder &builder, OperationState &result,
915  ArrayRef<OpAsmParser::Argument> args,
916  ArrayRef<DictionaryAttr> resultAttrs) {
917  SmallVector<DictionaryAttr> argAttrs;
918  for (const auto &arg : args)
919  argAttrs.push_back(arg.attrs);
920  addArgAndResultAttrsHW(builder, result, argAttrs, resultAttrs);
921 }
922 
// Shared parser for the module ops (plain, extern and generated).
// `modKind` selects the optional pieces:
//   - GenMod additionally expects `, @generatorKind` after the symbol name.
//   - Only PlainMod parses a body region. A region is still added for the
//     other kinds, but it is left empty.
// Port names and types are folded into the module type attribute; port
// locations go into `portLocs` and per-port dictionaries into
// `per_port_attrs`.
923 template <typename ModuleTy>
924 static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result,
925  ExternModKind modKind = PlainMod) {
926 
927  using namespace mlir::function_interface_impl;
928  auto loc = parser.getCurrentLocation();
929 
930  // Parse the visibility attribute.
931  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
932 
933  // Parse the name as a symbol.
934  StringAttr nameAttr;
935  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
936  result.attributes))
937  return failure();
938 
939  // Parse the generator information.
940  FlatSymbolRefAttr kindAttr;
941  if (modKind == GenMod) {
942  if (parser.parseComma() ||
943  parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
944  return failure();
945  }
946  }
947 
948  // Parse the parameters.
949  ArrayAttr parameters;
950  if (parseOptionalParameterList(parser, parameters))
951  return failure();
952 
953  // Parse the function signature.
954  bool isVariadic = false;
955  SmallVector<OpAsmParser::Argument, 4> entryArgs;
956  SmallVector<Attribute> argNames;
957  SmallVector<Attribute> argLocs;
958  SmallVector<Attribute> resultNames;
959  SmallVector<DictionaryAttr> resultAttrs;
960  SmallVector<Attribute> resultLocs;
961  TypeAttr functionType;
// NOTE(review): the line that opens the signature-parsing call (original line
// 962) is missing from this rendering. The surviving arguments suggest it
// fills every vector above plus `functionType`, and returns failure on error.
// Confirm against the original source.
963  parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
964  resultAttrs, resultLocs, functionType)))
965  return failure();
966 
967  // Parse the attribute dict.
968  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
969  return failure();
970 
// `resultNames` and `parameters` come from the custom syntax, so supplying
// them again in the attribute dictionary is rejected.
971  if (hasAttribute("resultNames", result.attributes) ||
972  hasAttribute("parameters", result.attributes)) {
973  parser.emitError(
974  loc, "explicit `resultNames` / `parameters` attributes not allowed");
975  return failure();
976  }
977 
// `portLocs` is one flat array: input locations followed by output locations.
978  SmallVector<Attribute> portLocs;
979  portLocs.append(argLocs.begin(), argLocs.end());
980  portLocs.append(resultLocs.begin(), resultLocs.end());
981 
982  auto *context = result.getContext();
983  // Prefer an explicit `argNames` attribute over the SSA argument names.
// NOTE(review): a non-ArrayAttr `argNames` trips the cast below instead of
// producing a parse diagnostic.
984  auto attr = getAttribute("argNames", result.attributes);
985  auto modType = detail::fnToMod(
986  cast<FunctionType>(functionType.getValue()),
987  attr ? cast<ArrayAttr>(attr).getValue() : argNames, resultNames);
// The names now live in the module type, so the attribute itself is dropped.
988  result.attributes.erase("argNames");
989  result.addAttribute("portLocs", ArrayAttr::get(context, portLocs));
990  result.addAttribute("parameters", parameters);
// Always materialize `comment`, defaulting to the empty string.
991  if (!hasAttribute("comment", result.attributes))
992  result.addAttribute("comment", StringAttr::get(context, ""));
993  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
994  TypeAttr::get(modType));
995 
996  // Add the attributes to the function arguments.
997  addArgAndResultAttrsHW(parser.getBuilder(), result, entryArgs, resultAttrs);
998 
999  // Parse the optional function body.
// Every module kind gets a region; only PlainMod fills it from the input.
1000  auto *body = result.addRegion();
1001  if (modKind == PlainMod) {
1002  if (parser.parseRegion(*body, entryArgs))
1003  return failure();
1004 
// Presumably inserts the implicit terminator when the parsed body omits it.
1005  HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1006  }
1007  return success();
1008 }
1009 
1010 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1011  return parseHWModuleOp<HWModuleOp>(parser, result);
1012 }
1013 
1014 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1015  OperationState &result) {
1016  return parseHWModuleOp<HWModuleExternOp>(parser, result, ExternMod);
1017 }
1018 
1019 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1020  OperationState &result) {
1021  return parseHWModuleOp<HWModuleGeneratedOp>(parser, result, GenMod);
1022 }
1023 
1024 FunctionType getHWModuleOpType(Operation *op) {
1025  if (auto mod = dyn_cast<HWModuleLike>(op))
1026  return mod.getHWModuleType().getFuncType();
1027  return cast<mlir::FunctionOpInterface>(op)
1028  .getFunctionType()
1029  .cast<FunctionType>();
1030 }
1031 
// Shared printer for the module ops. Emits, in order:
//   - the optional visibility;
//   - the symbol name, plus `, @generatorKind` for generated modules;
//   - the parameter list;
//   - the port signature;
//   - any remaining attributes via printOptionalAttrDictWithKeyword.
// Attributes already expressed by the custom syntax are elided from that
// dictionary.
1032 template <typename ModuleTy>
1033 static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1034  using namespace mlir::function_interface_impl;
1035 
1036  FunctionType fnType = mod.getHWModuleType().getFuncType();
1037  auto argTypes = fnType.getInputs();
1038  auto resultTypes = fnType.getResults();
1039 
1040  p << ' ';
1041 
1042  // Print the visibility of the module.
1043  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1044  if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1045  visibilityAttrName))
1046  p << visibility.getValue() << ' ';
1047 
1048  // Print the symbol name, followed by `, @kind` for generated modules.
1049  p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1050  if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1051  p << ", ";
1052  p.printSymbolName(gen.getGeneratorKind());
1053  }
1054 
1055  // Print the parameter list if present.
// NOTE(review): the line that opens this call (original line 1056) is missing
// from this rendering. Given its arguments, it presumably prints the optional
// parameter list; confirm against the original source.
1057  p, mod.getOperation(),
1058  mod.getOperation()->template getAttrOfType<ArrayAttr>("parameters"));
1059 
// `needArgNamesAttr` is passed to the signature printer. When it comes back
// true, `argNames` is kept in the printed attribute dictionary below;
// otherwise it is elided.
1060  bool needArgNamesAttr = false;
1061  module_like_impl::printModuleSignature(p, mod.getOperation(), argTypes,
1062  /*isVariadic=*/false, resultTypes,
1063  needArgNamesAttr);
1064 
1065  SmallVector<StringRef, 3> omittedAttrs;
1066  if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1067  omittedAttrs.push_back("generatorKind");
1068  if (!needArgNamesAttr)
1069  omittedAttrs.push_back("argNames");
1070  omittedAttrs.push_back("portLocs");
1071  omittedAttrs.push_back(
1072  ModuleTy::getModuleTypeAttrName(mod.getOperation()->getName()));
1073  omittedAttrs.push_back("per_port_attrs");
1074  omittedAttrs.push_back("resultNames");
1075  omittedAttrs.push_back("parameters");
1076  omittedAttrs.push_back(visibilityAttrName);
1077  omittedAttrs.push_back(SymbolTable::getSymbolAttrName());
// An empty `comment` is elided.
// NOTE(review): this assumes a `comment` StringAttr is always present. The
// parser always adds one, but an op built without it would call getValue()
// on a null StringAttr here.
1078  if (mod.getOperation()
1079  ->template getAttrOfType<StringAttr>("comment")
1080  .getValue()
1081  .empty())
1082  omittedAttrs.push_back("comment");
1083  // Synthesize `argNames` from the module's input names and append it to the
// printed attributes. It is elided above unless needArgNamesAttr was set.
1084  auto attrs = mod->getAttrs();
1085  SmallVector<NamedAttribute> realAttrs(attrs.begin(), attrs.end());
1086  realAttrs.push_back(
1087  NamedAttribute(StringAttr::get(mod.getContext(), "argNames"),
1088  ArrayAttr::get(mod.getContext(), mod.getInputNames())));
1089  p.printOptionalAttrDictWithKeyword(realAttrs, omittedAttrs);
1090 }
1091 
1092 void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1093 void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1094 
1095 void HWModuleOp::print(OpAsmPrinter &p) {
1096  printModuleOp(p, *this);
1097 
1098  // Print the body if this is not an external function.
1099  Region &body = getBody();
1100  if (!body.empty()) {
1101  p << " ";
1102  p.printRegion(body, /*printEntryBlockArgs=*/false,
1103  /*printBlockTerminators=*/true);
1104  }
1105 }
1106 
1107 static LogicalResult verifyModuleCommon(HWModuleLike module) {
1108  assert(isa<HWModuleLike>(module) &&
1109  "verifier hook should only be called on modules");
1110 
1111  auto moduleType = module.getHWModuleType();
1112 
1113  auto argLocs = module.getInputLocs();
1114  if (argLocs.size() != moduleType.getNumInputs())
1115  return module->emitOpError("incorrect number of argument locations");
1116 
1117  auto resultLocs = module.getOutputLocs();
1118  if (resultLocs.size() != moduleType.getNumOutputs())
1119  return module->emitOpError("incorrect number of result locations");
1120 
1121  SmallPtrSet<Attribute, 4> paramNames;
1122 
1123  // Check parameter default values are sensible.
1124  for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1125  auto paramAttr = param.cast<ParamDeclAttr>();
1126 
1127  // Check that we don't have any redundant parameter names. These are
1128  // resolved by string name: reuse of the same name would cause ambiguities.
1129  if (!paramNames.insert(paramAttr.getName()).second)
1130  return module->emitOpError("parameter ")
1131  << paramAttr << " has the same name as a previous parameter";
1132 
1133  // Default values are allowed to be missing, check them if present.
1134  auto value = paramAttr.getValue();
1135  if (!value)
1136  continue;
1137 
1138  auto typedValue = value.dyn_cast<TypedAttr>();
1139  if (!typedValue)
1140  return module->emitOpError("parameter ")
1141  << paramAttr << " should have a typed value; has value " << value;
1142 
1143  if (typedValue.getType() != paramAttr.getType())
1144  return module->emitOpError("parameter ")
1145  << paramAttr << " should have type " << paramAttr.getType()
1146  << "; has type " << typedValue.getType();
1147 
1148  // Verify that this is a valid parameter value, disallowing parameter
1149  // references. We could allow parameters to refer to each other in the
1150  // future with lexical ordering if there is a need.
1151  if (failed(checkParameterInContext(value, module, module,
1152  /*disallowParamRefs=*/true)))
1153  return failure();
1154  }
1155  return success();
1156 }
1157 
1158 LogicalResult HWModuleOp::verify() {
1159  if (failed(verifyModuleCommon(*this)))
1160  return failure();
1161 
1162  auto type = getModuleType();
1163  auto *body = getBodyBlock();
1164 
1165  // Verify the number of block arguments.
1166  auto numInputs = type.getNumInputs();
1167  if (body->getNumArguments() != numInputs)
1168  return emitOpError("entry block must have")
1169  << numInputs << " arguments to match module signature";
1170 
1171  // Verify that the block arguments match the op's attributes.
1172  for (auto [arg, type, loc] : llvm::zip(getBodyBlock()->getArguments(),
1173  getInputTypes(), getInputLocs())) {
1174  if (arg.getType() != type)
1175  return emitOpError("block argument types should match signature types");
1176  if (arg.getLoc() != loc.cast<LocationAttr>())
1177  return emitOpError(
1178  "block argument locations should match signature locations");
1179  }
1180 
1181  return success();
1182 }
1183 
1184 LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1185 
1186 std::pair<StringAttr, BlockArgument>
1187 HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1188  // Find a unique name for the wire.
1189  Namespace ns;
1190  auto ports = getPortList();
1191  for (auto port : ports)
1192  ns.newName(port.name.getValue());
1193  auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1194 
1195  Block *body = getBodyBlock();
1196 
1197  // Create a new port for the host clock.
1198  PortInfo port;
1199  port.name = nameAttr;
1200  port.dir = ModulePort::Direction::Input;
1201  port.type = ty;
1202  hw::modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {},
1203  {}, body);
1204 
1205  // Add a new argument.
1206  return {nameAttr, body->getArgument(index)};
1207 }
1208 
1209 void HWModuleOp::insertOutputs(unsigned index,
1210  ArrayRef<std::pair<StringAttr, Value>> outputs) {
1211 
1212  auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1213  assert(index <= output->getNumOperands() && "invalid output index");
1214 
1215  // Rewrite the port list of the module.
1216  SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1217  for (auto &[name, value] : outputs) {
1218  PortInfo port;
1219  port.name = name;
1220  port.dir = ModulePort::Direction::Output;
1221  port.type = value.getType();
1222  indexedNewPorts.emplace_back(index, port);
1223  }
1224  hw::modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1225  getBodyBlock());
1226 
1227  // Rewrite the output op.
1228  for (auto &[name, value] : outputs)
1229  output->insertOperands(index++, value);
1230 }
1231 
1232 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1233  return insertOutputs(getNumOutputPorts(), outputs);
1234 }
1235 
1236 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1237  mlir::OpAsmSetValueNameFn setNameFn) {
1238  getAsmBlockArgumentNamesImpl(region, setNameFn);
1239 }
1240 
1241 void HWModuleExternOp::getAsmBlockArgumentNames(
1242  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1243  getAsmBlockArgumentNamesImpl(region, setNameFn);
1244 }
1245 
1246 template <typename ModTy>
1247 static SmallVector<Location> getAllPortLocs(ModTy module) {
1248  SmallVector<Location> retval;
1249  auto locs = module.getPortLocs();
1250  for (auto l : locs)
1251  retval.push_back(cast<Location>(l));
1252  assert(!locs.size() || locs.size() == module.getNumPorts());
1253  return retval;
1254 }
1255 
1256 SmallVector<Location> HWModuleOp::getAllPortLocs() {
1257  return ::getAllPortLocs(*this);
1258 }
1259 
1260 SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1261  return ::getAllPortLocs(*this);
1262 }
1263 
1264 SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1265  return ::getAllPortLocs(*this);
1266 }
1267 
1268 template <typename ModTy>
1269 static void setAllPortLocs(ArrayRef<Location> locs, ModTy module) {
1270  std::vector<Attribute> nLocs(locs.begin(), locs.end());
1271  module.setPortLocsAttr(ArrayAttr::get(module.getContext(), nLocs));
1272 }
1273 
1274 void HWModuleOp::setAllPortLocs(ArrayRef<Location> locs) {
1275  ::setAllPortLocs(locs, *this);
1276 }
1277 
1278 void HWModuleExternOp::setAllPortLocs(ArrayRef<Location> locs) {
1279  ::setAllPortLocs(locs, *this);
1280 }
1281 
1282 void HWModuleGeneratedOp::setAllPortLocs(ArrayRef<Location> locs) {
1283  ::setAllPortLocs(locs, *this);
1284 }
1285 
1286 template <typename ModTy>
1287 static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1288  auto numInputs = module.getNumInputPorts();
1289  SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1290  SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1291  auto oldType = module.getModuleType();
1292  SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1293  oldType.getPorts().end());
1294  for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1295  newPorts[i].name = cast<StringAttr>(names[i]);
1296  auto newType = ModuleType::get(module.getContext(), newPorts);
1297  module.setModuleType(newType);
1298 }
1299 
1300 void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1301  ::setAllPortNames(names, *this);
1302 }
1303 
1304 void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1305  ::setAllPortNames(names, *this);
1306 }
1307 
1308 void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1309  ::setAllPortNames(names, *this);
1310 }
1311 
1312 SmallVector<Attribute> HWModuleOp::getAllPortAttrs() {
1313  auto attrs = getPerPortAttrs();
1314  if (attrs && !attrs.empty())
1315  return {attrs.getValue().begin(), attrs.getValue().end()};
1316  return SmallVector<Attribute>(getNumPorts());
1317 }
1318 
1319 SmallVector<Attribute> HWModuleExternOp::getAllPortAttrs() {
1320  auto attrs = getPerPortAttrs();
1321  if (attrs && !attrs.empty())
1322  return {attrs.getValue().begin(), attrs.getValue().end()};
1323  return SmallVector<Attribute>(getNumPorts());
1324 }
1325 
1326 SmallVector<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1327  auto attrs = getPerPortAttrs();
1328  if (attrs && !attrs.empty())
1329  return {attrs.getValue().begin(), attrs.getValue().end()};
1330  return SmallVector<Attribute>(getNumPorts());
1331 }
1332 
1333 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1334  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1335 }
1336 
1337 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1338  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1339 }
1340 
1341 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1342  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1343 }
1344 
1345 void HWModuleOp::removeAllPortAttrs() {
1346  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1347 }
1348 
1349 void HWModuleExternOp::removeAllPortAttrs() {
1350  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1351 }
1352 
1353 void HWModuleGeneratedOp::removeAllPortAttrs() {
1354  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1355 }
1356 
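1357 // This probably does really unexpected stuff when you change the number of
// inputs or outputs: existing per-port attributes are kept by position
// (truncated, or padded with empty dictionaries) rather than being remapped to
// the ports they originally described.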
1358 
1359 template <typename ModTy>
1360 static void setHWModuleType(ModTy &mod, ModuleType type) {
1361  auto argAttrs = mod.getAllInputAttrs();
1362  auto resAttrs = mod.getAllOutputAttrs();
1363  mod.setModuleTypeAttr(TypeAttr::get(type));
1364  unsigned newNumArgs = type.getNumInputs();
1365  unsigned newNumResults = type.getNumOutputs();
1366 
1367  auto emptyDict = DictionaryAttr::get(mod.getContext());
1368  argAttrs.resize(newNumArgs, emptyDict);
1369  resAttrs.resize(newNumResults, emptyDict);
1370 
1371  SmallVector<Attribute> attrs;
1372  attrs.append(argAttrs.begin(), argAttrs.end());
1373  attrs.append(resAttrs.begin(), resAttrs.end());
1374 
1375  if (attrs.empty())
1376  return mod.removeAllPortAttrs();
1377  mod.setAllPortAttrs(attrs);
1378 }
1379 
1380 void HWModuleOp::setHWModuleType(ModuleType type) {
1381  return ::setHWModuleType(*this, type);
1382 }
1383 
1384 void HWModuleExternOp::setHWModuleType(ModuleType type) {
1385  return ::setHWModuleType(*this, type);
1386 }
1387 
1388 void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1389  return ::setHWModuleType(*this, type);
1390 }
1391 
1392 /// Lookup the generator for the symbol. This returns null on
1393 /// invalid IR.
1394 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1395  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1396  return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1397 }
1398 
1399 LogicalResult
1400 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1401  auto *referencedKind =
1402  symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1403 
1404  if (referencedKind == nullptr)
1405  return emitError("Cannot find generator definition '")
1406  << getGeneratorKind() << "'";
1407 
1408  if (!isa<HWGeneratorSchemaOp>(referencedKind))
1409  return emitError("Symbol resolved to '")
1410  << referencedKind->getName()
1411  << "' which is not a HWGeneratorSchemaOp";
1412 
1413  auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1414  auto paramRef = referencedKindOp.getRequiredAttrs();
1415  auto dict = (*this)->getAttrDictionary();
1416  for (auto str : paramRef) {
1417  auto strAttr = str.dyn_cast<StringAttr>();
1418  if (!strAttr)
1419  return emitError("Unknown attribute type, expected a string");
1420  if (!dict.get(strAttr.getValue()))
1421  return emitError("Missing attribute '") << strAttr.getValue() << "'";
1422  }
1423 
1424  return success();
1425 }
1426 
1427 LogicalResult HWModuleGeneratedOp::verify() {
1428  return verifyModuleCommon(*this);
1429 }
1430 
1431 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1432  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1433  getAsmBlockArgumentNamesImpl(region, setNameFn);
1434 }
1435 
1436 LogicalResult HWModuleOp::verifyBody() { return success(); }
1437 
1438 template <typename ModuleTy>
1439 static ModulePortInfo getPortList(ModuleTy &mod) {
1440  auto modTy = mod.getHWModuleType();
1441  auto emptyDict = DictionaryAttr::get(mod.getContext());
1442  SmallVector<PortInfo> retval;
1443  for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1444  LocationAttr loc = mod.getPortLoc(i);
1445  DictionaryAttr attrs =
1446  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1447  if (!attrs)
1448  attrs = emptyDict;
1449  retval.push_back({modTy.getPorts()[i],
1450  modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1451  : modTy.getInputIdForPortId(i),
1452  extractSym(attrs), attrs, loc});
1453  }
1454  return ModulePortInfo(retval);
1455 }
1456 
1457 //===----------------------------------------------------------------------===//
1458 // InstanceOp
1459 //===----------------------------------------------------------------------===//
1460 
1461 /// Create a instance that refers to a known module.
1462 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1463  Operation *module, StringAttr name,
1464  ArrayRef<Value> inputs, ArrayAttr parameters,
1465  InnerSymAttr innerSym) {
1466  if (!parameters)
1467  parameters = builder.getArrayAttr({});
1468 
1469  auto mod = cast<hw::HWModuleLike>(module);
1470  auto argNames = builder.getArrayAttr(mod.getInputNames());
1471  auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1472  FunctionType modType = mod.getHWModuleType().getFuncType();
1473  build(builder, result, modType.getResults(), name,
1474  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1475  argNames, resultNames, parameters, innerSym);
1476 }
1477 
1478 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1479  // Inner symbols on instance operations target the op not any result.
1480  return std::nullopt;
1481 }
1482 
1483 Operation *InstanceOp::getReferencedModuleSlow() {
1484  return getReferencedModuleCached(/*cache=*/nullptr);
1485 }
1486 
// Verify this instance's reference to its module symbol.
1487 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// NOTE(review): the line that opens the delegated call (original line 1488) is
// missing from this rendering. The arguments below suggest a shared verifier
// that compares this instance's operands, result types, port names and
// parameters against the referenced module. Confirm against the source.
1489  *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1490  getResultNames(), getParameters(), symbolTable);
1491 }
1492 
// Verify the instance's parameters in the context of the enclosing hw.module's
// `parameters`, when there is one. Diagnostics produced here can carry a note
// pointing at the module's declaration.
1493 LogicalResult InstanceOp::verify() {
1494  auto module = (*this)->getParentOfType<HWModuleOp>();
// Outside an hw.module there are no enclosing parameters to check against.
1495  if (!module)
1496  return success();
1497 
1498  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
// NOTE(review): the line declaring this diagnostic-emitter lambda (original
// line 1499) is missing from this rendering; it is presumably `emitError`,
// which is passed below. The "module declared here" note is attached only
// when the callback returns true.
1500  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1501  auto diag = emitOpError();
1502  if (fn(diag))
1503  diag.attachNote(module->getLoc()) << "module declared here";
1504  };
// NOTE(review): the line opening the checking call (original line 1505) is
// missing from this rendering; only its trailing arguments survive.
1506  getParameters(), moduleParameters, emitError);
1507 }
1508 
1509 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1510  StringAttr instanceNameAttr;
1511  InnerSymAttr innerSym;
1512  FlatSymbolRefAttr moduleNameAttr;
1513  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1514  SmallVector<Type, 1> inputsTypes, allResultTypes;
1515  ArrayAttr argNames, resultNames, parameters;
1516  auto noneType = parser.getBuilder().getType<NoneType>();
1517 
1518  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1519  result.attributes))
1520  return failure();
1521 
1522  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1523  // Parsing an optional symbol name doesn't fail, so no need to check the
1524  // result.
1525  if (parser.parseCustomAttributeWithFallback(innerSym))
1526  return failure();
1527  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1528  }
1529 
1530  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1531  if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1532  result.attributes) ||
1533  parser.getCurrentLocation(&parametersLoc) ||
1534  parseOptionalParameterList(parser, parameters) ||
1535  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1536  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1537  result.operands) ||
1538  parser.parseArrow() ||
1539  parseOutputPortList(parser, allResultTypes, resultNames) ||
1540  parser.parseOptionalAttrDict(result.attributes)) {
1541  return failure();
1542  }
1543 
1544  result.addAttribute("argNames", argNames);
1545  result.addAttribute("resultNames", resultNames);
1546  result.addAttribute("parameters", parameters);
1547  result.addTypes(allResultTypes);
1548  return success();
1549 }
1550 
1551 void InstanceOp::print(OpAsmPrinter &p) {
1552  p << ' ';
1553  p.printAttributeWithoutType(getInstanceNameAttr());
1554  if (auto attr = getInnerSymAttr()) {
1555  p << " sym ";
1556  attr.print(p);
1557  }
1558  p << ' ';
1559  p.printAttributeWithoutType(getModuleNameAttr());
1560  printOptionalParameterList(p, *this, getParameters());
1561  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1562  getArgNames());
1563  p << " -> ";
1564  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1565 
1566  p.printOptionalAttrDict(
1567  (*this)->getAttrs(),
1568  /*elidedAttrs=*/{"instanceName",
1569  InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1570  "argNames", "resultNames", "parameters"});
1571 }
1572 
1573 /// Return the name of the specified input port or null if it cannot be
1574 /// determined.
1575 StringAttr InstanceOp::getArgumentName(size_t idx) {
1576  return instance_like_impl::getName(getArgNames(), idx);
1577 }
1578 
1579 /// Return the name of the specified result or null if it cannot be
1580 /// determined.
1581 StringAttr InstanceOp::getResultName(size_t idx) {
1582  return instance_like_impl::getName(getResultNames(), idx);
1583 }
1584 
1585 /// Change the name of the specified input port.
1586 void InstanceOp::setArgumentName(size_t i, StringAttr name) {
1587  setInputNames(instance_like_impl::updateName(getArgNames(), i, name));
1588 }
1589 
1590 /// Change the name of the specified output port.
1591 void InstanceOp::setResultName(size_t i, StringAttr name) {
1592  setOutputNames(instance_like_impl::updateName(getResultNames(), i, name));
1593 }
1594 
1595 /// Suggest a name for each result value based on the saved result names
1596 /// attribute.
// NOTE(review): the signature and the opening of the delegated call (original
// lines 1597-1598) are missing from this rendering; only the trailing
// arguments (the result names and the result values) survive.
1599  getResultNames(), getResults());
1600 }
1601 
// Reconstruct this instance's port list from its own attributes.
//  - Input ports come from `argNames`, the input types of getModuleType(*this)
//    and the optional `argLocs`.
//  - Output ports come from `resultNames`, the result types and the optional
//    `resultLocs`.
// Every port gets an empty attribute dictionary and no inner symbol. When the
// location arrays are absent, the locations are left null.
1602 ModulePortInfo InstanceOp::getPortList() {
1603  SmallVector<PortInfo> inputs, outputs;
1604  auto emptyDict = DictionaryAttr::get(getContext());
1605  auto argNames = (*this)->getAttrOfType<ArrayAttr>("argNames");
1606  auto argTypes = getModuleType(*this).getInputs();
1607  auto argLocs = (*this)->getAttrOfType<ArrayAttr>("argLocs");
1608  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
1609  auto type = argTypes[i];
1610  auto direction = ModulePort::Direction::Input;
1611 
// An inout-typed operand is reported as an InOut port of its element type.
1612  if (auto inout = type.dyn_cast<InOutType>()) {
1613  type = inout.getElementType();
1614  direction = ModulePort::Direction::InOut;
1615  }
1616 
1617  LocationAttr loc;
1618  if (argLocs)
1619  loc = argLocs[i].cast<LocationAttr>();
1620  inputs.push_back({{argNames[i].cast<StringAttr>(), type, direction},
1621  i,
1622  {},
1623  emptyDict,
1624  loc});
1625  }
1626 
1627  auto resultNames = (*this)->getAttrOfType<ArrayAttr>("resultNames");
1628  auto resultTypes = getModuleType(*this).getResults();
1629  auto resultLocs = (*this)->getAttrOfType<ArrayAttr>("resultLocs");
1630  for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
1631  LocationAttr loc;
1632  if (resultLocs)
1633  loc = resultLocs[i].cast<LocationAttr>();
1634  outputs.push_back({{resultNames[i].cast<StringAttr>(), resultTypes[i],
// NOTE(review): original line 1635 (presumably the output port direction) is
// missing from this rendering; confirm against the source.
1636  i,
1637  {},
1638  emptyDict,
1639  loc});
1640  }
1641  return ModulePortInfo(inputs, outputs);
1642 }
1643 
1644 size_t InstanceOp::getNumPorts() {
1645  return getNumInputPorts() + getNumOutputPorts();
1646 }
1647 
1648 size_t InstanceOp::getNumInputPorts() { return getNumOperands(); }
1649 
1650 size_t InstanceOp::getNumOutputPorts() { return getNumResults(); }
1651 
1652 size_t InstanceOp::getPortIdForInputId(size_t idx) { return idx; }
1653 
1654 size_t InstanceOp::getPortIdForOutputId(size_t idx) {
1655  return idx + getNumInputPorts();
1656 }
1657 
1658 void InstanceOp::getValues(SmallVectorImpl<Value> &values,
1659  const ModulePortInfo &mpi) {
1660  size_t inputPort = 0, resultPort = 0;
1661  values.resize(mpi.size());
1662  auto results = getResults();
1663  auto inputs = getInputs();
1664  for (auto [idx, port] : llvm::enumerate(mpi))
1665  if (mpi.at(idx).isOutput())
1666  values[idx] = results[resultPort++];
1667  else
1668  values[idx] = inputs[inputPort++];
1669 }
1670 
1671 //===----------------------------------------------------------------------===//
1672 // HWOutputOp
1673 //===----------------------------------------------------------------------===//
1674 
1675 /// Verify that the num of operands and types fit the declared results.
1676 LogicalResult OutputOp::verify() {
1677  // Check that the we (hw.output) have the same number of operands as our
1678  // region has results.
1679  ModuleType modType;
1680  if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1681  modType = mod.getHWModuleType();
1682  else if (auto mod = dyn_cast<HWTestModuleOp>((*this)->getParentOp()))
1683  modType = mod.getModuleType();
1684  else {
1685  emitOpError("must have a module parent");
1686  return failure();
1687  }
1688  auto modResults = modType.getOutputTypes();
1689  OperandRange outputValues = getOperands();
1690  if (modResults.size() != outputValues.size()) {
1691  emitOpError("must have same number of operands as region results.");
1692  return failure();
1693  }
1694 
1695  // Check that the types of our operands and the region's results match.
1696  for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1697  if (modResults[i] != outputValues[i].getType()) {
1698  emitOpError("output types must match module. In "
1699  "operand ")
1700  << i << ", expected " << modResults[i] << ", but got "
1701  << outputValues[i].getType() << ".";
1702  return failure();
1703  }
1704  }
1705 
1706  return success();
1707 }
1708 
1709 //===----------------------------------------------------------------------===//
1710 // Other Operations
1711 //===----------------------------------------------------------------------===//
1712 
1713 static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1714  Type &idxType) {
1715  Type type;
1716  if (p.parseType(type))
1717  return p.emitError(p.getCurrentLocation(), "Expected type");
1718  auto arrType = type_dyn_cast<ArrayType>(type);
1719  if (!arrType)
1720  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1721  srcType = type;
1722  unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1723  idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1724  return success();
1725 }
1726 
1727 static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1728  Type idxType) {
1729  p.printType(srcType);
1730 }
1731 
1732 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1733  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1734  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1735  Type elemType;
1736 
1737  if (parser.parseOperandList(operands) ||
1738  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1739  parser.parseType(elemType))
1740  return failure();
1741 
1742  if (operands.size() == 0)
1743  return parser.emitError(inputOperandsLoc,
1744  "Cannot construct an array of length 0");
1745  result.addTypes({ArrayType::get(elemType, operands.size())});
1746 
1747  for (auto operand : operands)
1748  if (parser.resolveOperand(operand, elemType, result.operands))
1749  return failure();
1750  return success();
1751 }
1752 
1753 void ArrayCreateOp::print(OpAsmPrinter &p) {
1754  p << " ";
1755  p.printOperands(getInputs());
1756  p.printOptionalAttrDict((*this)->getAttrs());
1757  p << " : " << getInputs()[0].getType();
1758 }
1759 
1760 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1761  ValueRange values) {
1762  assert(values.size() > 0 && "Cannot build array of zero elements");
1763  Type elemType = values[0].getType();
1764  assert(llvm::all_of(
1765  values,
1766  [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1767  "All values must have same type.");
1768  build(b, state, ArrayType::get(elemType, values.size()), values);
1769 }
1770 
1771 LogicalResult ArrayCreateOp::verify() {
1772  unsigned returnSize = getType().cast<ArrayType>().getNumElements();
1773  if (getInputs().size() != returnSize)
1774  return failure();
1775  return success();
1776 }
1777 
1778 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1779  if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1780  return {};
1781  return ArrayAttr::get(getContext(), adaptor.getInputs());
1782 }
1783 
1784 // Check whether an integer value is an offset from a base.
1785 bool hw::isOffset(Value base, Value index, uint64_t offset) {
1786  if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1787  if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1788  // If both values are a constant, check if index == base + offset.
1789  // To account for overflow, the addition is performed with an extra bit
1790  // and the offset is asserted to fit in the bit width of the base.
1791  auto baseValue = constBase.getValue();
1792  auto indexValue = constIndex.getValue();
1793 
1794  unsigned bits = baseValue.getBitWidth();
1795  assert(bits == indexValue.getBitWidth() && "mismatched widths");
1796 
1797  if (bits < 64 && offset >= (1ull << bits))
1798  return false;
1799 
1800  APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1801  APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1802  return baseExt + offset == indexExt;
1803  }
1804  }
1805  return false;
1806 }
1807 
1808 // Canonicalize a create of consecutive elements to a slice.
1809 static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1810  PatternRewriter &rewriter) {
1811  // Do not canonicalize create of get into a slice.
1812  auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1813  if (arrayTy.getNumElements() <= 1)
1814  return failure();
1815  auto elemTy = arrayTy.getElementType();
1816 
1817  // Check if create arguments are consecutive elements of the same array.
1818  // Attempt to break a create of gets into a sequence of consecutive intervals.
1819  struct Chunk {
1820  Value input;
1821  Value index;
1822  size_t size;
1823  };
1824  SmallVector<Chunk> chunks;
1825  for (Value value : llvm::reverse(op.getInputs())) {
1826  auto get = value.getDefiningOp<ArrayGetOp>();
1827  if (!get)
1828  return failure();
1829 
1830  Value input = get.getInput();
1831  Value index = get.getIndex();
1832  if (!chunks.empty()) {
1833  auto &c = *chunks.rbegin();
1834  if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1835  c.size++;
1836  continue;
1837  }
1838  }
1839 
1840  chunks.push_back(Chunk{input, index, 1});
1841  }
1842 
1843  // If there is a single slice, eliminate the create.
1844  if (chunks.size() == 1) {
1845  auto &chunk = chunks[0];
1846  rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1847  op.getLoc(), arrayTy, chunk.input, chunk.index));
1848  return success();
1849  }
1850 
1851  // If the number of chunks is significantly less than the number of
1852  // elements, replace the create with a concat of the identified slices.
1853  if (chunks.size() * 2 < arrayTy.getNumElements()) {
1854  SmallVector<Value> slices;
1855  for (auto &chunk : llvm::reverse(chunks)) {
1856  auto sliceTy = ArrayType::get(elemTy, chunk.size);
1857  slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1858  op.getLoc(), sliceTy, chunk.input, chunk.index));
1859  }
1860  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1861  return success();
1862  }
1863 
1864  return failure();
1865 }
1866 
1867 LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1868  PatternRewriter &rewriter) {
1869  if (succeeded(foldCreateToSlice(op, rewriter)))
1870  return success();
1871  return failure();
1872 }
1873 
1874 Value ArrayCreateOp::getUniformElement() {
1875  if (!getInputs().empty() && llvm::all_equal(getInputs()))
1876  return getInputs()[0];
1877  return {};
1878 }
1879 
1880 static std::optional<uint64_t> getUIntFromValue(Value value) {
1881  auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1882  if (!idxOp)
1883  return std::nullopt;
1884  APInt idxAttr = idxOp.getValue();
1885  if (idxAttr.getBitWidth() > 64)
1886  return std::nullopt;
1887  return idxAttr.getLimitedValue();
1888 }
1889 
1890 LogicalResult ArraySliceOp::verify() {
1891  unsigned inputSize =
1892  type_cast<ArrayType>(getInput().getType()).getNumElements();
1893  if (llvm::Log2_64_Ceil(inputSize) !=
1894  getLowIndex().getType().getIntOrFloatBitWidth())
1895  return emitOpError(
1896  "ArraySlice: index width must match clog2 of array size");
1897  return success();
1898 }
1899 
1900 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1901  // If we are slicing the entire input, then return it.
1902  if (getType() == getInput().getType())
1903  return getInput();
1904  return {};
1905 }
1906 
1907 LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1908  PatternRewriter &rewriter) {
1909  auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1910  auto elemTy = sliceTy.getElementType();
1911  uint64_t sliceSize = sliceTy.getNumElements();
1912  if (sliceSize == 0)
1913  return failure();
1914 
1915  if (sliceSize == 1) {
1916  // slice(a, n) -> create(a[n])
1917  auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1918  op.getLowIndex());
1919  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1920  get.getResult());
1921  return success();
1922  }
1923 
1924  auto offsetOpt = getUIntFromValue(op.getLowIndex());
1925  if (!offsetOpt)
1926  return failure();
1927 
1928  auto inputOp = op.getInput().getDefiningOp();
1929  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1930  // slice(slice(a, n), m) -> slice(a, n + m)
1931  if (inputSlice == op)
1932  return failure();
1933 
1934  auto inputIndex = inputSlice.getLowIndex();
1935  auto inputOffsetOpt = getUIntFromValue(inputIndex);
1936  if (!inputOffsetOpt)
1937  return failure();
1938 
1939  uint64_t offset = *offsetOpt + *inputOffsetOpt;
1940  auto lowIndex =
1941  rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1942  rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1943  inputSlice.getInput(), lowIndex);
1944  return success();
1945  }
1946 
1947  if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1948  // slice(create(a0, a1, ..., an), m) -> create(am, ...)
1949  auto inputs = inputCreate.getInputs();
1950 
1951  uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1952  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1953  inputs.slice(begin, sliceSize));
1954  return success();
1955  }
1956 
1957  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
1958  // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
1959  SmallVector<Value> chunks;
1960  uint64_t sliceStart = *offsetOpt;
1961  for (auto input : llvm::reverse(inputConcat.getInputs())) {
1962  // Check whether the input intersects with the slice.
1963  uint64_t inputSize =
1964  hw::type_cast<ArrayType>(input.getType()).getNumElements();
1965  if (inputSize == 0 || inputSize <= sliceStart) {
1966  sliceStart -= inputSize;
1967  continue;
1968  }
1969 
1970  // Find the indices to slice from this input by intersection.
1971  uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
1972  uint64_t cutSize = cutEnd - sliceStart;
1973  assert(cutSize != 0 && "slice cannot be empty");
1974 
1975  if (cutSize == inputSize) {
1976  // The whole input fits in the slice, add it.
1977  assert(sliceStart == 0 && "invalid cut size");
1978  chunks.push_back(input);
1979  } else {
1980  // Slice the required bits from the input.
1981  unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
1982  auto lowIndex = rewriter.create<ConstantOp>(
1983  op.getLoc(), rewriter.getIntegerType(width), sliceStart);
1984  chunks.push_back(rewriter.create<ArraySliceOp>(
1985  op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
1986  }
1987 
1988  sliceStart = 0;
1989  sliceSize -= cutSize;
1990  if (sliceSize == 0)
1991  break;
1992  }
1993 
1994  assert(chunks.size() > 0 && "missing sliced items");
1995  if (chunks.size() == 1)
1996  rewriter.replaceOp(op, chunks[0]);
1997  else
1998  rewriter.replaceOpWithNewOp<ArrayConcatOp>(
1999  op, llvm::to_vector(llvm::reverse(chunks)));
2000  return success();
2001  }
2002  return failure();
2003 }
2004 
2005 //===----------------------------------------------------------------------===//
2006 // ArrayConcatOp
2007 //===----------------------------------------------------------------------===//
2008 
2009 static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2010  SmallVectorImpl<Type> &inputTypes,
2011  Type &resultType) {
2012  Type elemType;
2013  uint64_t resultSize = 0;
2014 
2015  auto parseElement = [&]() -> ParseResult {
2016  Type ty;
2017  if (p.parseType(ty))
2018  return failure();
2019  auto arrTy = type_dyn_cast<ArrayType>(ty);
2020  if (!arrTy)
2021  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2022  if (elemType && elemType != arrTy.getElementType())
2023  return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2024  << elemType;
2025 
2026  elemType = arrTy.getElementType();
2027  inputTypes.push_back(ty);
2028  resultSize += arrTy.getNumElements();
2029  return success();
2030  };
2031 
2032  if (p.parseCommaSeparatedList(parseElement))
2033  return failure();
2034 
2035  resultType = ArrayType::get(elemType, resultSize);
2036  return success();
2037 }
2038 
2039 static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2040  TypeRange inputTypes, Type resultType) {
2041  llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2042 }
2043 
2044 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2045  ValueRange values) {
2046  assert(!values.empty() && "Cannot build array of zero elements");
2047  ArrayType arrayTy = values[0].getType().cast<ArrayType>();
2048  Type elemTy = arrayTy.getElementType();
2049  assert(llvm::all_of(values,
2050  [elemTy](Value v) -> bool {
2051  return v.getType().isa<ArrayType>() &&
2052  v.getType().cast<ArrayType>().getElementType() ==
2053  elemTy;
2054  }) &&
2055  "All values must be of ArrayType with the same element type.");
2056 
2057  uint64_t resultSize = 0;
2058  for (Value val : values)
2059  resultSize += val.getType().cast<ArrayType>().getNumElements();
2060  build(b, state, ArrayType::get(elemTy, resultSize), values);
2061 }
2062 
2063 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2064  auto inputs = adaptor.getInputs();
2065  SmallVector<Attribute> array;
2066  for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2067  if (!inputs[i])
2068  return {};
2069  llvm::copy(inputs[i].cast<ArrayAttr>(), std::back_inserter(array));
2070  }
2071  return ArrayAttr::get(getContext(), array);
2072 }
2073 
2074 // Flatten a concatenation of array creates into a single create.
2075 static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2076  for (auto input : op.getInputs())
2077  if (!input.getDefiningOp<ArrayCreateOp>())
2078  return false;
2079 
2080  SmallVector<Value> items;
2081  for (auto input : op.getInputs()) {
2082  auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2083  for (auto item : create.getInputs())
2084  items.push_back(item);
2085  }
2086 
2087  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2088  return true;
2089 }
2090 
2091 // Merge consecutive slice expressions in a concatenation.
2092 static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2093  struct Slice {
2094  Value input;
2095  Value index;
2096  size_t size;
2097  Value op;
2098  SmallVector<Location> locs;
2099  };
2100 
2101  SmallVector<Value> items;
2102  std::optional<Slice> last;
2103  bool changed = false;
2104 
2105  auto concatenate = [&] {
2106  // If there is only one op in the slice, place it to the items list.
2107  if (!last)
2108  return;
2109  if (last->op) {
2110  items.push_back(last->op);
2111  last.reset();
2112  return;
2113  }
2114 
2115  // Otherwise, create a new slice of with the given size and place it.
2116  // In this case, the concat op is replaced, using the new argument.
2117  changed = true;
2118  auto loc = FusedLoc::get(op.getContext(), last->locs);
2119  auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2120  auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2121  items.push_back(rewriter.createOrFold<ArraySliceOp>(
2122  loc, arrayTy, last->input, last->index));
2123 
2124  last.reset();
2125  };
2126 
2127  auto append = [&](Value op, Value input, Value index, size_t size) {
2128  // If this slice is an extension of the previous one, extend the size
2129  // saved. In this case, a new slice of is created and the concatenation
2130  // operator is rewritten. Otherwise, flush the last slice.
2131  if (last) {
2132  if (last->input == input && isOffset(last->index, index, last->size)) {
2133  last->size += size;
2134  last->op = {};
2135  last->locs.push_back(op.getLoc());
2136  return;
2137  }
2138  concatenate();
2139  }
2140  last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2141  };
2142 
2143  for (auto item : llvm::reverse(op.getInputs())) {
2144  if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2145  auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2146  append(item, slice.getInput(), slice.getLowIndex(), size);
2147  continue;
2148  }
2149 
2150  if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2151  if (create.getInputs().size() == 1) {
2152  if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2153  append(item, get.getInput(), get.getIndex(), 1);
2154  continue;
2155  }
2156  }
2157  }
2158 
2159  concatenate();
2160  items.push_back(item);
2161  }
2162  concatenate();
2163 
2164  if (!changed)
2165  return false;
2166 
2167  if (items.size() == 1) {
2168  rewriter.replaceOp(op, items[0]);
2169  } else {
2170  std::reverse(items.begin(), items.end());
2171  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2172  }
2173  return true;
2174 }
2175 
2176 LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2177  PatternRewriter &rewriter) {
2178  // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2179  if (flattenConcatOp(op, rewriter))
2180  return success();
2181 
2182  // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2183  if (mergeConcatSlices(op, rewriter))
2184  return success();
2185 
2186  return success();
2187 }
2188 
2189 //===----------------------------------------------------------------------===//
2190 // EnumConstantOp
2191 //===----------------------------------------------------------------------===//
2192 
2193 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2194  // Parse a Type instead of an EnumType since the type might be a type alias.
2195  // The validity of the canonical type is checked during construction of the
2196  // EnumFieldAttr.
2197  Type type;
2198  StringRef field;
2199 
2200  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2201  if (parser.parseKeyword(&field) || parser.parseColonType(type))
2202  return failure();
2203 
2204  auto fieldAttr = EnumFieldAttr::get(
2205  loc, StringAttr::get(parser.getContext(), field), type);
2206 
2207  if (!fieldAttr)
2208  return failure();
2209 
2210  result.addAttribute("field", fieldAttr);
2211  result.addTypes(type);
2212 
2213  return success();
2214 }
2215 
2216 void EnumConstantOp::print(OpAsmPrinter &p) {
2217  p << " " << getField().getField().getValue() << " : "
2218  << getField().getType().getValue();
2219 }
2220 
2222  function_ref<void(Value, StringRef)> setNameFn) {
2223  setNameFn(getResult(), getField().getField().str());
2224 }
2225 
2226 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2227  EnumFieldAttr field) {
2228  return build(builder, odsState, field.getType().getValue(), field);
2229 }
2230 
2231 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2232  assert(adaptor.getOperands().empty() && "constant has no operands");
2233  return getFieldAttr();
2234 }
2235 
2236 LogicalResult EnumConstantOp::verify() {
2237  auto fieldAttr = getFieldAttr();
2238  auto fieldType = fieldAttr.getType().getValue();
2239  // This check ensures that we are using the exact same type, without looking
2240  // through type aliases.
2241  if (fieldType != getType())
2242  emitOpError("return type ")
2243  << getType() << " does not match attribute type " << fieldAttr;
2244  return success();
2245 }
2246 
2247 //===----------------------------------------------------------------------===//
2248 // EnumCmpOp
2249 //===----------------------------------------------------------------------===//
2250 
2251 LogicalResult EnumCmpOp::verify() {
2252  // Compare the canonical types.
2253  auto lhsType = type_cast<EnumType>(getLhs().getType());
2254  auto rhsType = type_cast<EnumType>(getRhs().getType());
2255  if (rhsType != lhsType)
2256  emitOpError("types do not match");
2257  return success();
2258 }
2259 
2260 //===----------------------------------------------------------------------===//
2261 // StructCreateOp
2262 //===----------------------------------------------------------------------===//
2263 
2264 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2265  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2266  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2267  Type declOrAliasType;
2268 
2269  if (parser.parseLParen() || parser.parseOperandList(operands) ||
2270  parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2271  parser.parseColonType(declOrAliasType))
2272  return failure();
2273 
2274  auto declType = type_dyn_cast<StructType>(declOrAliasType);
2275  if (!declType)
2276  return parser.emitError(parser.getNameLoc(),
2277  "expected !hw.struct type or alias");
2278 
2279  llvm::SmallVector<Type, 4> structInnerTypes;
2280  declType.getInnerTypes(structInnerTypes);
2281  result.addTypes(declOrAliasType);
2282 
2283  if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2284  result.operands))
2285  return failure();
2286  return success();
2287 }
2288 
2289 void StructCreateOp::print(OpAsmPrinter &printer) {
2290  printer << " (";
2291  printer.printOperands(getInput());
2292  printer << ")";
2293  printer.printOptionalAttrDict((*this)->getAttrs());
2294  printer << " : " << getType();
2295 }
2296 
2297 LogicalResult StructCreateOp::verify() {
2298  auto elements = hw::type_cast<StructType>(getType()).getElements();
2299 
2300  if (elements.size() != getInput().size())
2301  return emitOpError("structure field count mismatch");
2302 
2303  for (const auto &[field, value] : llvm::zip(elements, getInput()))
2304  if (field.type != value.getType())
2305  return emitOpError("structure field `")
2306  << field.name << "` type does not match";
2307 
2308  return success();
2309 }
2310 
2311 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2312  // struct_create(struct_explode(x)) => x
2313  if (!getInput().empty())
2314  if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2315  explodeOp && getInput() == explodeOp.getResults() &&
2316  getResult().getType() == explodeOp.getInput().getType())
2317  return explodeOp.getInput();
2318 
2319  auto inputs = adaptor.getInput();
2320  if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2321  return {};
2322  return ArrayAttr::get(getContext(), inputs);
2323 }
2324 
2325 //===----------------------------------------------------------------------===//
2326 // StructExplodeOp
2327 //===----------------------------------------------------------------------===//
2328 
2329 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2330  OperationState &result) {
2331  OpAsmParser::UnresolvedOperand operand;
2332  Type declType;
2333 
2334  if (parser.parseOperand(operand) ||
2335  parser.parseOptionalAttrDict(result.attributes) ||
2336  parser.parseColonType(declType))
2337  return failure();
2338  auto structType = type_dyn_cast<StructType>(declType);
2339  if (!structType)
2340  return parser.emitError(parser.getNameLoc(),
2341  "invalid kind of type specified");
2342 
2343  llvm::SmallVector<Type, 4> structInnerTypes;
2344  structType.getInnerTypes(structInnerTypes);
2345  result.addTypes(structInnerTypes);
2346 
2347  if (parser.resolveOperand(operand, declType, result.operands))
2348  return failure();
2349  return success();
2350 }
2351 
2352 void StructExplodeOp::print(OpAsmPrinter &printer) {
2353  printer << " ";
2354  printer.printOperand(getInput());
2355  printer.printOptionalAttrDict((*this)->getAttrs());
2356  printer << " : " << getInput().getType();
2357 }
2358 
2359 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2360  SmallVectorImpl<OpFoldResult> &results) {
2361  auto input = adaptor.getInput();
2362  if (!input)
2363  return failure();
2364  llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(results));
2365  return success();
2366 }
2367 
2368 LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2369  PatternRewriter &rewriter) {
2370  auto *inputOp = op.getInput().getDefiningOp();
2371  auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2372  auto result = failure();
2373  for (auto [element, res] : llvm::zip(elements, op.getResults())) {
2374  if (auto foldResult = foldStructExtract(inputOp, element.name.str())) {
2375  rewriter.replaceAllUsesWith(res, foldResult);
2376  result = success();
2377  }
2378  }
2379  return result;
2380 }
2381 
2383  function_ref<void(Value, StringRef)> setNameFn) {
2384  auto structType = type_cast<StructType>(getInput().getType());
2385  for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2386  setNameFn(res, field.name.str());
2387 }
2388 
2389 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2390  Value input) {
2391  StructType inputType = input.getType().dyn_cast<StructType>();
2392  assert(inputType);
2393  SmallVector<Type, 16> fieldTypes;
2394  for (auto field : inputType.getElements())
2395  fieldTypes.push_back(field.type);
2396  build(odsBuilder, odsState, fieldTypes, input);
2397 }
2398 
2399 //===----------------------------------------------------------------------===//
2400 // StructExtractOp
2401 //===----------------------------------------------------------------------===//
2402 
2403 /// Use the same parser for both struct_extract and union_extract since the
2404 /// syntax is identical.
2405 template <typename AggregateType>
2406 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2407  OpAsmParser::UnresolvedOperand operand;
2408  StringAttr fieldName;
2409  Type declType;
2410 
2411  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2412  parser.parseAttribute(fieldName, "field", result.attributes) ||
2413  parser.parseRSquare() ||
2414  parser.parseOptionalAttrDict(result.attributes) ||
2415  parser.parseColonType(declType))
2416  return failure();
2417  auto aggType = type_dyn_cast<AggregateType>(declType);
2418  if (!aggType)
2419  return parser.emitError(parser.getNameLoc(),
2420  "invalid kind of type specified");
2421 
2422  Type resultType = aggType.getFieldType(fieldName.getValue());
2423  if (!resultType) {
2424  parser.emitError(parser.getNameLoc(), "invalid field name specified");
2425  return failure();
2426  }
2427  result.addTypes(resultType);
2428 
2429  if (parser.resolveOperand(operand, declType, result.operands))
2430  return failure();
2431  return success();
2432 }
2433 
2434 /// Use the same printer for both struct_extract and union_extract since the
2435 /// syntax is identical.
2436 template <typename AggType>
2437 static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2438  printer << " ";
2439  printer.printOperand(op.getInput());
2440  printer << "[\"" << op.getField() << "\"]";
2441  printer.printOptionalAttrDict(op->getAttrs(), {"field"});
2442  printer << " : " << op.getInput().getType();
2443 }
2444 
2445 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2446  OperationState &result) {
2447  return parseExtractOp<StructType>(parser, result);
2448 }
2449 
2450 void StructExtractOp::print(OpAsmPrinter &printer) {
2451  printExtractOp(printer, *this);
2452 }
2453 
2454 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2455  Value input, StructType::FieldInfo field) {
2456  build(builder, odsState, field.type, input, field.name);
2457 }
2458 
2459 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2460  Value input, StringAttr fieldAttr) {
2461  auto structType = type_cast<StructType>(input.getType());
2462  auto resultType = structType.getFieldType(fieldAttr);
2463  build(builder, odsState, resultType, input, fieldAttr);
2464 }
2465 
2466 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2467  if (auto constOperand = adaptor.getInput()) {
2468  // Fold extract from aggregate constant
2469  auto operandType = type_cast<StructType>(getOperand().getType());
2470  auto fieldIdx = operandType.getFieldIndex(getField());
2471  auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2472  return operandAttr.getValue()[*fieldIdx];
2473  }
2474 
2475  if (auto foldResult =
2476  foldStructExtract(getInput().getDefiningOp(), getField()))
2477  return foldResult;
2478  return {};
2479 }
2480 
2481 LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2482  PatternRewriter &rewriter) {
2483  auto inputOp = op.getInput().getDefiningOp();
2484 
2485  // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2486  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2487  if (structInject.getField() != op.getField()) {
2488  rewriter.replaceOpWithNewOp<StructExtractOp>(
2489  op, op.getType(), structInject.getInput(), op.getField());
2490  return success();
2491  }
2492  }
2493 
2494  return failure();
2495 }
2496 
2498  function_ref<void(Value, StringRef)> setNameFn) {
2499  auto structType = type_cast<StructType>(getInput().getType());
2500  for (auto field : structType.getElements()) {
2501  if (field.name == getField()) {
2502  setNameFn(getResult(), field.name.str());
2503  return;
2504  }
2505  }
2506 }
2507 
2508 //===----------------------------------------------------------------------===//
2509 // StructInjectOp
2510 //===----------------------------------------------------------------------===//
2511 
2512 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2513  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2514  OpAsmParser::UnresolvedOperand operand, val;
2515  StringAttr fieldName;
2516  Type declType;
2517 
2518  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2519  parser.parseAttribute(fieldName, "field", result.attributes) ||
2520  parser.parseRSquare() || parser.parseComma() ||
2521  parser.parseOperand(val) ||
2522  parser.parseOptionalAttrDict(result.attributes) ||
2523  parser.parseColonType(declType))
2524  return failure();
2525  auto structType = type_dyn_cast<StructType>(declType);
2526  if (!structType)
2527  return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2528 
2529  Type resultType = structType.getFieldType(fieldName.getValue());
2530  if (!resultType) {
2531  parser.emitError(inputOperandsLoc, "invalid field name specified");
2532  return failure();
2533  }
2534  result.addTypes(declType);
2535 
2536  if (parser.resolveOperands({operand, val}, {declType, resultType},
2537  inputOperandsLoc, result.operands))
2538  return failure();
2539  return success();
2540 }
2541 
2542 void StructInjectOp::print(OpAsmPrinter &printer) {
2543  printer << " ";
2544  printer.printOperand(getInput());
2545  printer << "[\"" << getField() << "\"], ";
2546  printer.printOperand(getNewValue());
2547  printer.printOptionalAttrDict((*this)->getAttrs(), {"field"});
2548  printer << " : " << getInput().getType();
2549 }
2550 
2551 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2552  auto input = adaptor.getInput();
2553  auto newValue = adaptor.getNewValue();
2554  if (!input || !newValue)
2555  return {};
2556  SmallVector<Attribute> array;
2557  llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(array));
2558  StructType structType = getInput().getType();
2559  auto index = *structType.getFieldIndex(getField());
2560  array[index] = newValue;
2561  return ArrayAttr::get(getContext(), array);
2562 }
2563 
2564 LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2565  PatternRewriter &rewriter) {
2566  // Canonicalize multiple injects into a create op and eliminate overwrites.
2567  SmallPtrSet<Operation *, 4> injects;
2568  DenseMap<StringAttr, Value> fields;
2569 
2570  // Chase a chain of injects. Bail out if cycles are present.
2571  StructInjectOp inject = op;
2572  Value input;
2573  do {
2574  if (!injects.insert(inject).second)
2575  return failure();
2576 
2577  fields.try_emplace(inject.getFieldAttr(), inject.getNewValue());
2578  input = inject.getInput();
2579  inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2580  } while (inject);
2581  assert(input && "missing input to inject chain");
2582 
2583  auto ty = hw::type_cast<StructType>(op.getType());
2584  auto elements = ty.getElements();
2585 
2586  // If the inject chain sets all fields, canonicalize to create.
2587  if (fields.size() == elements.size()) {
2588  SmallVector<Value> createFields;
2589  for (const auto &field : elements) {
2590  auto it = fields.find(field.name);
2591  assert(it != fields.end() && "missing field");
2592  createFields.push_back(it->second);
2593  }
2594  rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2595  return success();
2596  }
2597 
2598  // Nothing to canonicalize, only the original inject in the chain.
2599  if (injects.size() == fields.size())
2600  return failure();
2601 
2602  // Eliminate overwrites. The hash map contains the last write to each field.
2603  for (const auto &field : elements) {
2604  auto it = fields.find(field.name);
2605  if (it == fields.end())
2606  continue;
2607  input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, field.name,
2608  it->second);
2609  }
2610 
2611  rewriter.replaceOp(op, input);
2612  return success();
2613 }
2614 
2615 //===----------------------------------------------------------------------===//
2616 // UnionCreateOp
2617 //===----------------------------------------------------------------------===//
2618 
2619 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2620  Type declOrAliasType;
2621  StringAttr field;
2622  OpAsmParser::UnresolvedOperand input;
2623  llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2624 
2625  if (parser.parseAttribute(field, "field", result.attributes) ||
2626  parser.parseComma() || parser.parseOperand(input) ||
2627  parser.parseOptionalAttrDict(result.attributes) ||
2628  parser.parseColonType(declOrAliasType))
2629  return failure();
2630 
2631  auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2632  if (!declType)
2633  return parser.emitError(parser.getNameLoc(),
2634  "expected !hw.union type or alias");
2635 
2636  Type inputType = declType.getFieldType(field.getValue());
2637  if (!inputType) {
2638  parser.emitError(fieldLoc, "cannot find union field '")
2639  << field.getValue() << '\'';
2640  return failure();
2641  }
2642 
2643  if (parser.resolveOperand(input, inputType, result.operands))
2644  return failure();
2645  result.addTypes({declOrAliasType});
2646  return success();
2647 }
2648 
2649 void UnionCreateOp::print(OpAsmPrinter &printer) {
2650  printer << " \"" << getField() << "\", ";
2651  printer.printOperand(getInput());
2652  printer.printOptionalAttrDict((*this)->getAttrs(), {"field"});
2653  printer << " : " << getType();
2654 }
2655 
2656 //===----------------------------------------------------------------------===//
2657 // UnionExtractOp
2658 //===----------------------------------------------------------------------===//
2659 
2660 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2661  return parseExtractOp<UnionType>(parser, result);
2662 }
2663 
2664 void UnionExtractOp::print(OpAsmPrinter &printer) {
2665  printExtractOp(printer, *this);
2666 }
2667 
2668 LogicalResult UnionExtractOp::inferReturnTypes(
2669  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2670  DictionaryAttr attrs, mlir::OpaqueProperties properties,
2671  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2672  results.push_back(cast<UnionType>(getCanonicalType(operands[0].getType()))
2673  .getFieldType(attrs.getAs<StringAttr>("field")));
2674  return success();
2675 }
2676 
2677 //===----------------------------------------------------------------------===//
2678 // ArrayGetOp
2679 //===----------------------------------------------------------------------===//
2680 
2681 // An array_get of an array_create with a constant index can just be the
2682 // array_create operand at the constant index. If the array_create has a
2683 // single uniform value for each element, just return that value regardless of
2684 // the index. If the array is constructed from a constant by a bitcast
2685 // operation, we can fold into a constant.
2686 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2687  auto inputCst = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2688  auto indexCst = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
2689 
2690  if (inputCst) {
2691  // Constant array index.
2692  if (indexCst) {
2693  auto indexVal = indexCst.getValue();
2694  if (indexVal.getBitWidth() < 64) {
2695  auto index = indexVal.getZExtValue();
2696  return inputCst[inputCst.size() - 1 - index];
2697  }
2698  }
2699  // If all elements of the array are the same, we can return any element of
2700  // array.
2701  if (!inputCst.empty() && llvm::all_equal(inputCst))
2702  return inputCst[0];
2703  }
2704 
2705  // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2706  if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2707  auto intTy = getType().dyn_cast<IntegerType>();
2708  if (!intTy)
2709  return {};
2710  auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2711  if (!bitcastInputOp)
2712  return {};
2713  if (!indexCst)
2714  return {};
2715  auto bitcastInputCst = bitcastInputOp.getValue();
2716  // Calculate the index. Make sure to zero-extend the index value before
2717  // multiplying the element width.
2718  auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2719  getType().getIntOrFloatBitWidth();
2720  // Extract [startIdx + width - 1: startIdx].
2721  return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2722  intTy.getIntOrFloatBitWidth()));
2723  }
2724 
2725  auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2726  if (!inputCreate)
2727  return {};
2728 
2729  if (auto uniformValue = inputCreate.getUniformElement())
2730  return uniformValue;
2731 
2732  if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2733  return {};
2734 
2735  uint64_t index = indexCst.getValue().getLimitedValue();
2736  auto createInputs = inputCreate.getInputs();
2737  if (index >= createInputs.size())
2738  return {};
2739  return createInputs[createInputs.size() - index - 1];
2740 }
2741 
2742 LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2743  PatternRewriter &rewriter) {
2744  auto idxOpt = getUIntFromValue(op.getIndex());
2745  if (!idxOpt)
2746  return failure();
2747 
2748  auto *inputOp = op.getInput().getDefiningOp();
2749  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2750  // get(slice(a, n), m) -> get(a, n + m)
2751  auto offsetOp = inputSlice.getLowIndex();
2752  auto offsetOpt = getUIntFromValue(offsetOp);
2753  if (!offsetOpt)
2754  return failure();
2755 
2756  uint64_t offset = *offsetOpt + *idxOpt;
2757  auto newOffset =
2758  rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2759  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2760  newOffset);
2761  return success();
2762  }
2763 
2764  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2765  // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2766  uint64_t elemIndex = *idxOpt;
2767  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2768  size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2769  if (elemIndex >= size) {
2770  elemIndex -= size;
2771  continue;
2772  }
2773 
2774  unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2775  auto newIdxOp = rewriter.create<ConstantOp>(
2776  op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2777 
2778  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2779  return success();
2780  }
2781  return failure();
2782  }
2783 
2784  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2785  // array_get sel, (array_create (array_get const a), (array_get const b),
2786  // (array_get const, c), (array_get const, d))
2787  if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2788  if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2789  if (auto create =
2790  innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2791 
2792  SmallVector<Value> newValues;
2793  for (auto operand : create.getOperands())
2794  newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2795  op.getLoc(), operand, op.getIndex()));
2796 
2797  rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2798  op,
2799  rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2800  innerGet.getIndex());
2801  return success();
2802  }
2803  }
2804  }
2805 
2806  return failure();
2807 }
2808 
2809 //===----------------------------------------------------------------------===//
2810 // TypedeclOp
2811 //===----------------------------------------------------------------------===//
2812 
2813 StringRef TypedeclOp::getPreferredName() {
2814  return getVerilogName().value_or(getName());
2815 }
2816 
2817 Type TypedeclOp::getAliasType() {
2818  auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2819  return hw::TypeAliasType::get(
2820  SymbolRefAttr::get(parentScope.getSymNameAttr(),
2821  {FlatSymbolRefAttr::get(*this)}),
2822  getType());
2823 }
2824 
2825 //===----------------------------------------------------------------------===//
2826 // BitcastOp
2827 //===----------------------------------------------------------------------===//
2828 
2829 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2830  // Identity.
2831  // bitcast(%a) : A -> A ==> %a
2832  if (getOperand().getType() == getType())
2833  return getOperand();
2834 
2835  return {};
2836 }
2837 
2838 LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2839  // Composition.
2840  // %b = bitcast(%a) : A -> B
2841  // bitcast(%b) : B -> C
2842  // ===> bitcast(%a) : A -> C
2843  auto inputBitcast =
2844  dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2845  if (!inputBitcast)
2846  return failure();
2847  auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2848  inputBitcast.getInput());
2849  rewriter.replaceOp(op, bitcast);
2850  return success();
2851 }
2852 
2853 LogicalResult BitcastOp::verify() {
2854  if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2855  return this->emitOpError("Bitwidth of input must match result");
2856  return success();
2857 }
2858 
2859 //===----------------------------------------------------------------------===//
2860 // HierPathOp helpers.
2861 //===----------------------------------------------------------------------===//
2862 
2863 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2864  SmallVector<Attribute, 4> newPath;
2865  bool updateMade = false;
2866  for (auto nameRef : getNamepath()) {
2867  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2868  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
2869  if (ref.getModule() == moduleToDrop)
2870  updateMade = true;
2871  else
2872  newPath.push_back(ref);
2873  } else {
2874  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
2875  updateMade = true;
2876  else
2877  newPath.push_back(nameRef);
2878  }
2879  }
2880  if (updateMade)
2881  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
2882  return updateMade;
2883 }
2884 
2885 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
2886  SmallVector<Attribute, 4> newPath;
2887  bool updateMade = false;
2888  StringRef inlinedInstanceName = "";
2889  for (auto nameRef : getNamepath()) {
2890  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2891  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
2892  if (ref.getModule() == moduleToDrop) {
2893  inlinedInstanceName = ref.getName().getValue();
2894  updateMade = true;
2895  } else if (!inlinedInstanceName.empty()) {
2896  newPath.push_back(hw::InnerRefAttr::get(
2897  ref.getModule(),
2898  StringAttr::get(getContext(), inlinedInstanceName + "_" +
2899  ref.getName().getValue())));
2900  inlinedInstanceName = "";
2901  } else
2902  newPath.push_back(ref);
2903  } else {
2904  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
2905  updateMade = true;
2906  else
2907  newPath.push_back(nameRef);
2908  }
2909  }
2910  if (updateMade)
2911  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
2912  return updateMade;
2913 }
2914 
2915 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
2916  SmallVector<Attribute, 4> newPath;
2917  bool updateMade = false;
2918  for (auto nameRef : getNamepath()) {
2919  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2920  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
2921  if (ref.getModule() == oldMod) {
2922  newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
2923  updateMade = true;
2924  } else
2925  newPath.push_back(ref);
2926  } else {
2927  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == oldMod) {
2928  newPath.push_back(FlatSymbolRefAttr::get(newMod));
2929  updateMade = true;
2930  } else
2931  newPath.push_back(nameRef);
2932  }
2933  }
2934  if (updateMade)
2935  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
2936  return updateMade;
2937 }
2938 
2939 bool HierPathOp::updateModuleAndInnerRef(
2940  StringAttr oldMod, StringAttr newMod,
2941  const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
2942  auto fromRef = FlatSymbolRefAttr::get(oldMod);
2943  if (oldMod == newMod)
2944  return false;
2945 
2946  auto namepathNew = getNamepath().getValue().vec();
2947  bool updateMade = false;
2948  // Break from the loop if the module is found, since it can occur only once.
2949  for (auto &element : namepathNew) {
2950  if (auto innerRef = element.dyn_cast<hw::InnerRefAttr>()) {
2951  if (innerRef.getModule() != oldMod)
2952  continue;
2953  auto symName = innerRef.getName();
2954  // Since the module got updated, the old innerRef symbol inside oldMod
2955  // should also be updated to the new symbol inside the newMod.
2956  auto to = innerSymRenameMap.find(symName);
2957  if (to != innerSymRenameMap.end())
2958  symName = to->second;
2959  updateMade = true;
2960  element = hw::InnerRefAttr::get(newMod, symName);
2961  break;
2962  }
2963  if (element != fromRef)
2964  continue;
2965 
2966  updateMade = true;
2967  element = FlatSymbolRefAttr::get(newMod);
2968  break;
2969  }
2970  if (updateMade)
2971  setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
2972  return updateMade;
2973 }
2974 
2975 bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
2976  SmallVector<Attribute, 4> newPath;
2977  bool updateMade = false;
2978  for (auto nameRef : getNamepath()) {
2979  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2980  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
2981  if (ref.getModule() == atMod) {
2982  updateMade = true;
2983  if (includeMod)
2984  newPath.push_back(ref);
2985  } else
2986  newPath.push_back(ref);
2987  } else {
2988  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == atMod && !includeMod)
2989  updateMade = true;
2990  else
2991  newPath.push_back(nameRef);
2992  }
2993  if (updateMade)
2994  break;
2995  }
2996  if (updateMade)
2997  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
2998  return updateMade;
2999 }
3000 
3001 /// Return just the module part of the namepath at a specific index.
3002 StringAttr HierPathOp::modPart(unsigned i) {
3003  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3004  .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3005  .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3006 }
3007 
3008 /// Return the root module.
3009 StringAttr HierPathOp::root() {
3010  assert(!getNamepath().empty());
3011  return modPart(0);
3012 }
3013 
3014 /// Return true if the NLA has the module in its path.
3015 bool HierPathOp::hasModule(StringAttr modName) {
3016  for (auto nameRef : getNamepath()) {
3017  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3018  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3019  if (ref.getModule() == modName)
3020  return true;
3021  } else {
3022  if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == modName)
3023  return true;
3024  }
3025  }
3026  return false;
3027 }
3028 
3029 /// Return true if the NLA has the InnerSym .
3030 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3031  for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3032  if (auto ref = nameRef.dyn_cast<hw::InnerRefAttr>())
3033  if (ref.getName() == symName && ref.getModule() == modName)
3034  return true;
3035 
3036  return false;
3037 }
3038 
3039 /// Return just the reference part of the namepath at a specific index. This
3040 /// will return an empty attribute if this is the leaf and the leaf is a module.
3041 StringAttr HierPathOp::refPart(unsigned i) {
3042  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3043  .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3044  .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3045 }
3046 
3047 /// Return the leaf reference. This returns an empty attribute if the leaf
3048 /// reference is a module.
3049 StringAttr HierPathOp::ref() {
3050  assert(!getNamepath().empty());
3051  return refPart(getNamepath().size() - 1);
3052 }
3053 
3054 /// Return the leaf module.
3055 StringAttr HierPathOp::leafMod() {
3056  assert(!getNamepath().empty());
3057  return modPart(getNamepath().size() - 1);
3058 }
3059 
3060 /// Returns true if this NLA targets an instance of a module (as opposed to
3061 /// an instance's port or something inside an instance).
3062 bool HierPathOp::isModule() { return !ref(); }
3063 
3064 /// Returns true if this NLA targets something inside a module (as opposed
3065 /// to a module or an instance of a module);
3066 bool HierPathOp::isComponent() { return (bool)ref(); }
3067 
3068 // Verify the HierPathOp.
3069 // 1. Iterate over the namepath.
3070 // 2. The namepath should be a valid instance path, specified either on a
3071 // module or a declaration inside a module.
3072 // 3. Each element in the namepath is an InnerRefAttr except possibly the
3073 // last element.
3074 // 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3075 // and the corresponding inner_sym on the instance.
3076 // 5. Make sure that the instance path is legal, by verifying the sequence of
3077 // instance and the expected module occurs as the next element in the path.
3078 // 6. The last element of the namepath, can be an InnerRefAttr on either a
3079 // module port or a declaration inside the module.
3080 // 7. The last element of the namepath can also be a module symbol.
3081 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3082  StringAttr expectedModuleName = {};
3083  if (!getNamepath() || getNamepath().empty())
3084  return emitOpError() << "the instance path cannot be empty";
3085  for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3086  hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast<hw::InnerRefAttr>();
3087  if (!innerRef)
3088  return emitOpError()
3089  << "the instance path can only contain inner sym reference"
3090  << ", only the leaf can refer to a module symbol";
3091 
3092  if (expectedModuleName && expectedModuleName != innerRef.getModule())
3093  return emitOpError() << "instance path is incorrect. Expected module: "
3094  << expectedModuleName
3095  << " instead found: " << innerRef.getModule();
3096  auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3097  if (!instOp)
3098  return emitOpError() << " module: " << innerRef.getModule()
3099  << " does not contain any instance with symbol: "
3100  << innerRef.getName();
3101  expectedModuleName = instOp.getReferencedModuleNameAttr();
3102  }
3103  // The instance path has been verified. Now verify the last element.
3104  auto leafRef = getNamepath()[getNamepath().size() - 1];
3105  if (auto innerRef = leafRef.dyn_cast<hw::InnerRefAttr>()) {
3106  if (!ns.lookup(innerRef)) {
3107  return emitOpError() << " operation with symbol: " << innerRef
3108  << " was not found ";
3109  }
3110  if (expectedModuleName && expectedModuleName != innerRef.getModule())
3111  return emitOpError() << "instance path is incorrect. Expected module: "
3112  << expectedModuleName
3113  << " instead found: " << innerRef.getModule();
3114  } else if (expectedModuleName &&
3115  expectedModuleName !=
3116  leafRef.cast<FlatSymbolRefAttr>().getAttr()) {
3117  // This is the case when the nla is applied to a module.
3118  return emitOpError() << "instance path is incorrect. Expected module: "
3119  << expectedModuleName << " instead found: "
3120  << leafRef.cast<FlatSymbolRefAttr>().getAttr();
3121  }
3122  return success();
3123 }
3124 
3125 void HierPathOp::print(OpAsmPrinter &p) {
3126  p << " ";
3127 
3128  // Print visibility if present.
3129  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3130  if (auto visibility =
3131  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3132  p << visibility.getValue() << ' ';
3133 
3134  p.printSymbolName(getSymName());
3135  p << " [";
3136  llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3137  if (auto ref = attr.dyn_cast<hw::InnerRefAttr>()) {
3138  p.printSymbolName(ref.getModule().getValue());
3139  p << "::";
3140  p.printSymbolName(ref.getName().getValue());
3141  } else {
3142  p.printSymbolName(attr.cast<FlatSymbolRefAttr>().getValue());
3143  }
3144  });
3145  p << "]";
3146  p.printOptionalAttrDict(
3147  (*this)->getAttrs(),
3148  {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3149 }
3150 
3151 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3152  // Parse the visibility attribute.
3153  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3154 
3155  // Parse the symbol name.
3156  StringAttr symName;
3157  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3158  result.attributes))
3159  return failure();
3160 
3161  // Parse the namepath.
3162  SmallVector<Attribute> namepath;
3163  if (parser.parseCommaSeparatedList(
3164  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3165  auto loc = parser.getCurrentLocation();
3166  SymbolRefAttr ref;
3167  if (parser.parseAttribute(ref))
3168  return failure();
3169 
3170  // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3171  auto pathLength = ref.getNestedReferences().size();
3172  if (pathLength == 0)
3173  namepath.push_back(
3174  FlatSymbolRefAttr::get(ref.getRootReference()));
3175  else if (pathLength == 1)
3176  namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3177  ref.getLeafReference()));
3178  else
3179  return parser.emitError(loc,
3180  "only one nested reference is allowed");
3181  return success();
3182  }))
3183  return failure();
3184  result.addAttribute("namepath",
3185  ArrayAttr::get(parser.getContext(), namepath));
3186 
3187  if (parser.parseOptionalAttrDict(result.attributes))
3188  return failure();
3189 
3190  return success();
3191 }
3192 
3193 //===----------------------------------------------------------------------===//
3194 // TriggeredOp
3195 //===----------------------------------------------------------------------===//
3196 
3197 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3198  EventControlAttr event, Value trigger,
3199  ValueRange inputs) {
3200  odsState.addOperands(trigger);
3201  odsState.addOperands(inputs);
3202  odsState.addAttribute(getEventAttrName(odsState.name), event);
3203  auto *r = odsState.addRegion();
3204  Block *b = new Block();
3205  r->push_back(b);
3206 
3207  llvm::SmallVector<Location> argLocs;
3208  llvm::transform(inputs, std::back_inserter(argLocs),
3209  [&](Value v) { return v.getLoc(); });
3210  b->addArguments(inputs.getTypes(), argLocs);
3211 }
3212 
3213 //===----------------------------------------------------------------------===//
3214 // Temporary test module
3215 //===----------------------------------------------------------------------===//
3216 
3217 static void
3218 addPortAttrsAndLocs(Builder &builder, OperationState &result,
3219  SmallVectorImpl<module_like_impl::PortParse> &ports,
3220  StringAttr portAttrsName, StringAttr portLocsName) {
3221  auto unknownLoc = builder.getUnknownLoc();
3222  auto nonEmptyAttrsFn = [](Attribute attr) {
3223  return attr && !cast<DictionaryAttr>(attr).empty();
3224  };
3225  auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
3226  return attr && cast<Location>(attr) != unknownLoc;
3227  };
3228 
3229  // Convert the specified array of dictionary attrs (which may have null
3230  // entries) to an ArrayAttr of dictionaries.
3231  SmallVector<Attribute> attrs;
3232  SmallVector<Attribute> locs;
3233  for (auto &port : ports) {
3234  attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
3235  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
3236  }
3237 
3238  // Add the attributes to the ports.
3239  if (llvm::any_of(attrs, nonEmptyAttrsFn))
3240  result.addAttribute(portAttrsName, builder.getArrayAttr(attrs));
3241 
3242  if (llvm::any_of(locs, nonEmptyLocsFn))
3243  result.addAttribute(portLocsName, builder.getArrayAttr(locs));
3244 }
3245 
3246 void HWTestModuleOp::print(OpAsmPrinter &p) {
3247  p << ' ';
3248  // Print the visibility of the module.
3249  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3250  if (auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
3251  p << visibility.getValue() << ' ';
3252 
3253  // Print the operation and the function name.
3254  p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
3255 
3256  // Print the parameter list if present.
3257  printOptionalParameterList(p, *this, getParameters());
3258 
3260  SmallVector<StringRef, 3> omittedAttrs;
3261  omittedAttrs.push_back(getPortLocsAttrName());
3262  omittedAttrs.push_back(getModuleTypeAttrName());
3263  omittedAttrs.push_back(getPortAttrsAttrName());
3264  omittedAttrs.push_back(getParametersAttrName());
3265  omittedAttrs.push_back(visibilityAttrName);
3266  if (auto cmt = (*this)->getAttrOfType<StringAttr>("comment"))
3267  if (cmt.getValue().empty())
3268  omittedAttrs.push_back("comment");
3269 
3270  mlir::function_interface_impl::printFunctionAttributes(p, *this,
3271  omittedAttrs);
3272 
3273  // Print the body if this is not an external function.
3274  Region &body = getBody();
3275  if (!body.empty()) {
3276  p << " ";
3277  p.printRegion(body, /*printEntryBlockArgs=*/false,
3278  /*printBlockTerminators=*/true);
3279  }
3280 }
3281 
3282 ParseResult HWTestModuleOp::parse(OpAsmParser &parser, OperationState &result) {
3283  auto loc = parser.getCurrentLocation();
3284 
3285  // Parse the visibility attribute.
3286  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3287 
3288  // Parse the name as a symbol.
3289  StringAttr nameAttr;
3290  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3291  result.attributes))
3292  return failure();
3293 
3294  // Parse the parameters.
3295  ArrayAttr parameters;
3296  if (parseOptionalParameterList(parser, parameters))
3297  return failure();
3298 
3299  SmallVector<module_like_impl::PortParse> ports;
3300  TypeAttr modType;
3301  if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
3302  return failure();
3303 
3304  // Parse the attribute dict.
3305  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
3306  return failure();
3307 
3308  if (hasAttribute("parameters", result.attributes)) {
3309  parser.emitError(loc, "explicit `parameters` attributes not allowed");
3310  return failure();
3311  }
3312 
3313  result.addAttribute("parameters", parameters);
3314  result.addAttribute(getModuleTypeAttrName(result.name), modType);
3315  addPortAttrsAndLocs(parser.getBuilder(), result, ports,
3316  getPortAttrsAttrName(result.name),
3317  getPortLocsAttrName(result.name));
3318 
3319  SmallVector<OpAsmParser::Argument, 4> entryArgs;
3320  for (auto &port : ports)
3321  if (port.direction != ModulePort::Direction::Output)
3322  entryArgs.push_back(port);
3323 
3324  // Parse the optional function body.
3325  auto *body = result.addRegion();
3326  if (parser.parseRegion(*body, entryArgs))
3327  return failure();
3328 
3329  HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
3330 
3331  return success();
3332 }
3333 
3334 void HWTestModuleOp::getAsmBlockArgumentNames(
3335  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
3336  if (region.empty())
3337  return;
3338  // Assign port names to the bbargs.
3339  auto *block = &region.front();
3340  auto mt = getModuleType();
3341  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
3342  auto name = mt.getInputName(i);
3343  if (!name.empty())
3344  setNameFn(block->getArgument(i), name);
3345  }
3346 }
3347 
3348 ModulePortInfo HWTestModuleOp::getPortList() {
3349  SmallVector<PortInfo> ports;
3350  auto refPorts = getModuleType().getPorts();
3351  for (auto [i, port] : enumerate(refPorts)) {
3352  auto loc = getPortLocs() ? cast<LocationAttr>((*getPortLocs())[i])
3353  : LocationAttr();
3354  auto attr = getPortAttrs() ? cast<DictionaryAttr>((*getPortAttrs())[i])
3355  : DictionaryAttr();
3356  InnerSymAttr sym = {};
3357  ports.push_back({{port}, i, sym, attr, loc});
3358  }
3359  return ModulePortInfo(ports);
3360 }
3361 
3362 size_t HWTestModuleOp::getNumPorts() { return getModuleType().getNumPorts(); }
3363 size_t HWTestModuleOp::getNumInputPorts() {
3364  return getModuleType().getNumInputs();
3365 }
3366 size_t HWTestModuleOp::getNumOutputPorts() {
3367  return getModuleType().getNumOutputs();
3368 }
3369 
3370 size_t HWTestModuleOp::getPortIdForInputId(size_t idx) {
3371  return getModuleType().getPortIdForInputId(idx);
3372 }
3373 
3374 size_t HWTestModuleOp::getPortIdForOutputId(size_t idx) {
3375  return getModuleType().getPortIdForOutputId(idx);
3376 }
3377 
3378 hw::InnerSymAttr HWTestModuleOp::getPortSymbolAttr(size_t portIndex) {
3379  auto pa = getPortAttrs();
3380  if (pa)
3381  return cast<hw::InnerSymAttr>((*pa)[portIndex]);
3382  return nullptr;
3383 }
3384 
3385 //===----------------------------------------------------------------------===//
3386 // TableGen generated logic.
3387 //===----------------------------------------------------------------------===//
3388 
3389 // Provide the autogenerated implementation guts for the Op classes.
3390 #define GET_OP_CLASSES
3391 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:25
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3381
int32_t width
Definition: FIRRTL.cpp:27
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition: HWOps.cpp:1107
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition: HWOps.cpp:477
static void addArgAndResultAttrsHW(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Definition: HWOps.cpp:903
ExternModKind
Definition: HWOps.cpp:535
@ PlainMod
Definition: HWOps.cpp:535
@ ExternMod
Definition: HWOps.cpp:535
@ GenMod
Definition: HWOps.cpp:535
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition: HWOps.cpp:1033
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2075
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:1809
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition: HWOps.cpp:1247
static void addPortAttrsAndLocs(Builder &builder, OperationState &result, SmallVectorImpl< module_like_impl::PortParse > &ports, StringAttr portAttrsName, StringAttr portLocsName)
Definition: HWOps.cpp:3218
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition: HWOps.cpp:85
FunctionType getHWModuleOpType(Operation *op)
Definition: HWOps.cpp:1024
static Attribute getAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition: HWOps.cpp:896
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
Definition: HWOps.cpp:539
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2437
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition: HWOps.cpp:2039
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition: HWOps.cpp:604
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition: HWOps.cpp:1713
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition: HWOps.cpp:889
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2092
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2406
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition: HWOps.cpp:1287
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition: HWOps.cpp:102
static ModulePortInfo getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1439
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition: HWOps.cpp:1360
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition: HWOps.cpp:469
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition: HWOps.cpp:403
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition: HWOps.cpp:1880
static Value foldStructExtract(Operation *inputOp, StringRef field)
Definition: HWOps.cpp:68
static InnerSymAttr extractSym(DictionaryAttr attrs)
Definition: HWOps.cpp:882
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result, ExternModKind modKind=PlainMod)
Definition: HWOps.cpp:924
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition: HWOps.cpp:1727
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition: HWOps.cpp:345
Delimiter
Definition: HWOps.cpp:118
@ OptionalLessGreater
static void setAllPortLocs(ArrayRef< Location > locs, ModTy module)
Definition: HWOps.cpp:1269
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition: HWOps.cpp:2009
@ Input
Definition: HW.h:32
@ Output
Definition: HW.h:32
@ InOut
Definition: HW.h:32
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
static size_t bits(::capnp::schema::Type::Reader type)
Return the number of bits used by a Capnp type.
Definition: Schema.cpp:121
static int64_t size(hw::ArrayType mType, capnp::schema::Field::Reader cField)
Returns the expected size of an array (capnp list) in 64-bit words.
Definition: Schema.cpp:193
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:59
void setOutput(unsigned i, Value v)
Definition: HWOps.cpp:243
Value getInput(unsigned i)
Definition: HWOps.cpp:249
llvm::SmallVector< Value > outputOperands
Definition: HWOps.h:128
llvm::SmallVector< Value > inputArgs
Definition: HWOps.h:127
llvm::StringMap< unsigned > outputIdx
Definition: HWOps.h:126
llvm::StringMap< unsigned > inputIdx
Definition: HWOps.h:126
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition: HWOps.cpp:227
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition: HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition: HWVisitors.h:27
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:271
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:999
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ArrayAttr updateName(ArrayAttr oldNames, size_t i, StringAttr name)
Change the name at the specified index of the.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
void printModuleSignatureNew(OpAsmPrinter &p, Operation *op)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1785
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition: HWOps.h:133
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
Definition: HWOps.cpp:58
void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition: HWOps.cpp:696
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition: HWOps.cpp:35
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:505
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition: HWOps.cpp:204
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
Definition: HWOps.cpp:223
bool isValidIndexBitWidth(Value index, Value array)
Definition: HWOps.cpp:47
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:100
bool isAnyModuleOrInstance(Operation *module)
Return true if isAnyModule or instance.
Definition: HWOps.cpp:499
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition: HWOps.cpp:526
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:39
circt::hw::InOutType InOutType
Definition: SVTypes.h:25
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse an parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition: hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::StringAttr name
Definition: HWTypes.h:29