CIRCT  19.0.0git
HWOps.cpp
Go to the documentation of this file.
1 //===- HWOps.cpp - Implement the HW operations ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the HW ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/HW/HWOps.h"
23 #include "circt/Support/Naming.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
30 
31 using namespace circt;
32 using namespace hw;
33 using mlir::TypedAttr;
34 
35 /// Flip a port direction.
// NOTE(review): the embedded listing numbers jump from 35 to 37 and from 37 to
// 44, so the line that opens this definition and the contents of the switch
// are not present in this listing. Only the switch on `direction` and the
// trailing `llvm_unreachable` are visible here.
37  switch (direction) {
44  }
45  llvm_unreachable("unknown PortDirection");
46 }
47 
48 bool hw::isValidIndexBitWidth(Value index, Value array) {
49  hw::ArrayType arrayType =
50  dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51  assert(arrayType && "expected array type");
52  unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53  auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54  return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55  : indexWidth == requiredWidth;
56 }
57 
58 /// Return true if the specified operation is a combinational logic op.
59 bool hw::isCombinational(Operation *op) {
60  struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61  bool visitInvalidTypeOp(Operation *op) { return false; }
62  bool visitUnhandledTypeOp(Operation *op) { return true; }
63  };
64 
65  return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66  IsCombClassifier().dispatchTypeOpVisitor(op);
67 }
68 
69 static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70  // A struct extract of a struct create -> corresponding struct create operand.
71  if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72  return structCreate.getOperand(fieldIndex);
73  }
74 
75  // Extracting injected field -> corresponding field
76  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77  if (structInject.getFieldIndex() != fieldIndex)
78  return {};
79  return structInject.getNewValue();
80  }
81  return {};
82 }
83 
84 static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85  ArrayRef<Attribute> attrs) {
86  if (attrs.empty())
87  return ArrayAttr::get(context, {});
88  bool empty = true;
89  for (auto a : attrs)
90  if (a && !cast<DictionaryAttr>(a).empty()) {
91  empty = false;
92  break;
93  }
94  if (empty)
95  return ArrayAttr::get(context, {});
96  return ArrayAttr::get(context, attrs);
97 }
98 
99 /// Get a special name to use when printing the entry block arguments of the
100 /// region contained by an operation in this dialect.
101 static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102  OpAsmSetValueNameFn setNameFn) {
103  if (region.empty())
104  return;
105  // Assign port names to the bbargs.
106  auto module = cast<HWModuleOp>(region.getParentOp());
107 
108  auto *block = &region.front();
109  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110  auto name = module.getInputName(i);
111  // Let mlir deterministically convert names to valid identifiers
112  setNameFn(block->getArgument(i), name);
113  }
114 }
115 
// Kinds of delimiters that can enclose a list.
116 enum class Delimiter {
117  None,
118  Paren, // List enclosed in parentheses: ( ... ).
119  OptionalLessGreater, // List enclosed in angle brackets < ... >, or absent.
120 };
121 
122 /// Check parameter specified by `value` to see if it is valid according to the
123 /// module's parameters. If not, emit an error to the diagnostic provided as an
124 /// argument to the lambda 'instanceError' and return failure, otherwise return
125 /// success.
126 ///
127 /// If `disallowParamRefs` is true, then parameter references are not allowed.
//
// Accepted forms, in the order they are tested below: integer, float, string
// and verbatim literals; `ParamExprAttr` (every operand is checked
// recursively); and `ParamDeclRefAttr` (must name an entry of
// `moduleParameters` that has the same type). Any other attribute is reported
// as an invalid parameter value.
//
// Each callback handed to `instanceError` returns a bool. Its meaning is
// defined by whoever supplies `instanceError`; the Operation-based overload in
// this file attaches a "module declared here" note when it is true.
//
// NOTE(review): the embedded listing numbering jumps from 127 to 129, so the
// line that opens this definition is not present in this listing.
129  Attribute value, ArrayAttr moduleParameters,
130  const instance_like_impl::EmitErrorFn &instanceError,
131  bool disallowParamRefs) {
132  // Literals are always ok. Their types are already known to match
133  // expectations.
134  if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135  isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
136  return success();
137
138  // Check every operand of an expression; the first failing operand fails
139  // the whole expression.
139  if (auto expr = dyn_cast<ParamExprAttr>(value)) {
140  for (auto op : expr.getOperands())
141  if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142  disallowParamRefs)))
143  return failure();
144  return success();
145  }
146
147  // Parameter references need more analysis to make sure they are valid within
148  // this module.
149  if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150  auto nameAttr = parameterRef.getName();
151
152  // Don't allow references to parameters from the default values of a
153  // parameter list.
154  if (disallowParamRefs) {
155  instanceError([&](auto &diag) {
156  diag << "parameter " << nameAttr
157  << " cannot be used as a default value for a parameter";
158  return false;
159  });
160  return failure();
161  }
162
163  // Find the corresponding attribute in the module. Only the first entry
163  // whose name matches is considered.
164  for (auto param : moduleParameters) {
165  auto paramAttr = cast<ParamDeclAttr>(param);
166  if (paramAttr.getName() != nameAttr)
167  continue;
168
169  // If the types match then the reference is ok.
170  if (paramAttr.getType() == parameterRef.getType())
171  return success();
172
173  instanceError([&](auto &diag) {
174  diag << "parameter " << nameAttr << " used with type "
175  << parameterRef.getType() << "; should have type "
176  << paramAttr.getType();
177  return true;
178  });
179  return failure();
180  }
181
//  No declared parameter carries this name.
182  instanceError([&](auto &diag) {
183  diag << "use of unknown parameter " << nameAttr;
184  return true;
185  });
186  return failure();
187  }
188
//  None of the supported attribute kinds matched.
189  instanceError([&](auto &diag) {
190  diag << "invalid parameter value " << value;
191  return false;
192  });
193  return failure();
194 }
195 
196 /// Check parameter specified by `value` to see if it is valid within the scope
197 /// of the specified module `module`. If not, emit an error at the location of
198 /// `usingOp` and return failure, otherwise return success. If `usingOp` is
199 /// null, then no diagnostic is generated.
200 ///
201 /// If `disallowParamRefs` is true, then parameter references are not allowed.
//
// The module's parameter list is read from its "parameters" array attribute
// and the actual checking is delegated to the overload that takes that list.
202 LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203  Operation *usingOp,
204  bool disallowParamRefs) {
// NOTE(review): the embedded listing numbering jumps from 204 to 206; the line
// that declares `emitError` (used in the call below) and introduces the
// following lambda is not present in this listing.
206  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
//  Without a using op there is nowhere to report, so `fn` is never invoked.
207  if (usingOp) {
208  auto diag = usingOp->emitOpError();
//  `fn` fills in the message; a true result requests a note that points at
//  the module's location.
209  if (fn(diag))
210  diag.attachNote(module->getLoc()) << "module declared here";
211  }
212  };
213
214  return checkParameterInContext(value,
215  module->getAttrOfType<ArrayAttr>("parameters"),
216  emitError, disallowParamRefs);
217 }
218 
219 /// Return true if the specified attribute tree is made up of nodes that are
220 /// valid in a parameter expression.
221 bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222  return succeeded(checkParameterInContext(attr, module, nullptr, false));
223 }
224 
// Set up the name-to-index maps and value storage for a module's ports.
// NOTE(review): the embedded listing numbering jumps from 224 to 226; the line
// that opens this definition is not present in this listing.
226  const ModulePortInfo &info,
227  Region &bodyRegion)
228  : info(info) {
//  Record every body-region argument as an input, keyed by the name of the
//  port at the same index in `info`.
//  NOTE(review): `info.at(i)` is indexed with the block-argument index, which
//  presumes the first ports of `info` are exactly the inputs -- confirm.
229  inputArgs.resize(info.sizeInputs());
230  for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
231  inputIdx[info.at(i).name.str()] = i;
232  inputArgs[i] = barg;
233  }
234
//  Outputs start out as default-constructed Values; only their names are
//  indexed here.
235  outputOperands.resize(info.sizeOutputs());
236  for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237  outputIdx[outputInfo.name.str()] = i;
238  }
239 }
240 
241 void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242  assert(outputOperands.size() > i && "invalid output index");
243  assert(outputOperands[i] == Value() && "output already set");
244  outputOperands[i] = v;
245 }
246 
// Return the stored input value at index `i`; the index is range-checked by
// assertion only.
// NOTE(review): the embedded listing numbering jumps from 246 to 248; the line
// that opens this definition is not present in this listing.
248  assert(inputArgs.size() > i && "invalid input index");
249  return inputArgs[i];
250 }
251 Value HWModulePortAccessor::getInput(StringRef name) {
252  return getInput(inputIdx.find(name.str())->second);
253 }
254 void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255  setOutput(outputIdx.find(name.str())->second, v);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ConstantOp
260 //===----------------------------------------------------------------------===//
261 
262 void ConstantOp::print(OpAsmPrinter &p) {
263  p << " ";
264  p.printAttribute(getValueAttr());
265  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
266 }
267 
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269  IntegerAttr valueAttr;
270 
271  if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
272  parser.parseOptionalAttrDict(result.attributes))
273  return failure();
274 
275  result.addTypes(valueAttr.getType());
276  return success();
277 }
278 
279 LogicalResult ConstantOp::verify() {
280  // If the result type has a bitwidth, then the attribute must match its width.
281  if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
282  return emitError(
283  "hw.constant attribute bitwidth doesn't match return type");
284 
285  return success();
286 }
287 
288 /// Build a ConstantOp from an APInt, infering the result type from the
289 /// width of the APInt.
290 void ConstantOp::build(OpBuilder &builder, OperationState &result,
291  const APInt &value) {
292 
293  auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
294  auto attr = builder.getIntegerAttr(type, value);
295  return build(builder, result, type, attr);
296 }
297 
298 /// Build a ConstantOp from an APInt, infering the result type from the
299 /// width of the APInt.
300 void ConstantOp::build(OpBuilder &builder, OperationState &result,
301  IntegerAttr value) {
302  return build(builder, result, value.getType(), value);
303 }
304 
305 /// This builder allows construction of small signed integers like 0, 1, -1
306 /// matching a specified MLIR IntegerType. This shouldn't be used for general
307 /// constant folding because it only works with values that can be expressed in
308 /// an int64_t. Use APInt's instead.
309 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
310  int64_t value) {
311  auto numBits = cast<IntegerType>(type).getWidth();
312  build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true));
313 }
314 
// Suggest a printed name for the constant's result: "true"/"false" for
// one-bit constants, otherwise 'c' followed by the value, an underscore and
// the type.
// NOTE(review): the embedded listing numbering jumps from 314 to 316; the line
// that opens this definition is not present in this listing.
316  function_ref<void(Value, StringRef)> setNameFn) {
317  auto intTy = getType();
318  auto intCst = getValue();
319
320  // Sugar i1 constants with 'true' and 'false'.
321  if (cast<IntegerType>(intTy).getWidth() == 1)
322  return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
323
324  // Otherwise, build a complex name with the value and type.
325  SmallVector<char, 32> specialNameBuffer;
326  llvm::raw_svector_ostream specialName(specialNameBuffer);
327  specialName << 'c' << intCst << '_' << intTy;
328  setNameFn(getResult(), specialName.str());
329 }
330 
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
332  assert(adaptor.getOperands().empty() && "constant has no operands");
333  return getValueAttr();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // WireOp
338 //===----------------------------------------------------------------------===//
339 
340 /// Check whether an operation has any additional attributes set beyond its
341 /// standard list of attributes returned by `getAttributeNames`.
342 template <class Op>
343 static bool hasAdditionalAttributes(Op op,
344  ArrayRef<StringRef> ignoredAttrs = {}) {
345  auto names = op.getAttributeNames();
346  llvm::SmallDenseSet<StringRef> nameSet;
347  nameSet.reserve(names.size() + ignoredAttrs.size());
348  nameSet.insert(names.begin(), names.end());
349  nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350  return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
351  return !nameSet.contains(namedAttr.getName());
352  });
353 }
354 
// Suggest the wire's "name" attribute as the printed name of its result when
// that attribute is present and non-empty; otherwise suggest nothing.
// NOTE(review): the embedded listing numbering jumps from 354 to 356; the line
// that opens this definition is not present in this listing.
356  // If the wire has an optional 'name' attribute, use it.
357  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
358  if (nameAttr && !nameAttr.getValue().empty())
359  setNameFn(getResult(), nameAttr.getValue());
360 }
361 
362 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
363 
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
365  // If the wire has no additional attributes, no name, and no symbol, just
366  // forward its input.
367  if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
368  !getInnerSymAttr())
369  return getInput();
370  return {};
371 }
372 
373 LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
374  // Block if the wire has any attributes.
375  if (hasAdditionalAttributes(wire, {"sv.namehint"}))
376  return failure();
377 
378  // If the wire has a symbol, then we can't delete it.
379  if (wire.getInnerSymAttr())
380  return failure();
381 
382  // If the wire has a name or an `sv.namehint` attribute, propagate it as an
383  // `sv.namehint` to the expression.
384  if (auto *inputOp = wire.getInput().getDefiningOp())
385  if (auto name = chooseName(wire, inputOp))
386  rewriter.modifyOpInPlace(inputOp,
387  [&] { inputOp->setAttr("sv.namehint", name); });
388 
389  rewriter.replaceOp(wire, wire.getInput());
390  return success();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // AggregateConstantOp
395 //===----------------------------------------------------------------------===//
396 
397 static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
398  // If this is a type alias, get the underlying type.
399  if (auto typeAlias = dyn_cast<TypeAliasType>(type))
400  type = typeAlias.getCanonicalType();
401 
402  if (auto structType = dyn_cast<StructType>(type)) {
403  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
404  if (!arrayAttr)
405  return op->emitOpError("expected array attribute for constant of type ")
406  << type;
407  if (structType.getElements().size() != arrayAttr.size())
408  return op->emitOpError("array attribute (")
409  << arrayAttr.size() << ") has wrong size for struct constant ("
410  << structType.getElements().size() << ")";
411 
412  for (auto [attr, fieldInfo] :
413  llvm::zip(arrayAttr.getValue(), structType.getElements())) {
414  if (failed(checkAttributes(op, attr, fieldInfo.type)))
415  return failure();
416  }
417  } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
418  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
419  if (!arrayAttr)
420  return op->emitOpError("expected array attribute for constant of type ")
421  << type;
422  if (arrayType.getNumElements() != arrayAttr.size())
423  return op->emitOpError("array attribute (")
424  << arrayAttr.size() << ") has wrong size for array constant ("
425  << arrayType.getNumElements() << ")";
426 
427  auto elementType = arrayType.getElementType();
428  for (auto attr : arrayAttr.getValue()) {
429  if (failed(checkAttributes(op, attr, elementType)))
430  return failure();
431  }
432  } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
433  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
434  if (!arrayAttr)
435  return op->emitOpError("expected array attribute for constant of type ")
436  << type;
437  auto elementType = arrayType.getElementType();
438  if (arrayType.getNumElements() != arrayAttr.size())
439  return op->emitOpError("array attribute (")
440  << arrayAttr.size()
441  << ") has wrong size for unpacked array constant ("
442  << arrayType.getNumElements() << ")";
443 
444  for (auto attr : arrayAttr.getValue()) {
445  if (failed(checkAttributes(op, attr, elementType)))
446  return failure();
447  }
448  } else if (auto enumType = dyn_cast<EnumType>(type)) {
449  auto stringAttr = dyn_cast<StringAttr>(attr);
450  if (!stringAttr)
451  return op->emitOpError("expected string attribute for constant of type ")
452  << type;
453  } else if (auto intType = dyn_cast<IntegerType>(type)) {
454  // Check the attribute kind is correct.
455  auto intAttr = dyn_cast<IntegerAttr>(attr);
456  if (!intAttr)
457  return op->emitOpError("expected integer attribute for constant of type ")
458  << type;
459  // Check the bitwidth is correct.
460  if (intAttr.getValue().getBitWidth() != intType.getWidth())
461  return op->emitOpError("hw.constant attribute bitwidth "
462  "doesn't match return type");
463  } else {
464  return op->emitOpError("unknown element type") << type;
465  }
466  return success();
467 }
468 
469 LogicalResult AggregateConstantOp::verify() {
470  return checkAttributes(*this, getFieldsAttr(), getType());
471 }
472 
473 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
474 
475 //===----------------------------------------------------------------------===//
476 // ParamValueOp
477 //===----------------------------------------------------------------------===//
478 
479 static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
480  Type &resultType) {
481  if (p.parseType(resultType) || p.parseEqual() ||
482  p.parseAttribute(value, resultType))
483  return failure();
484  return success();
485 }
486 
487 static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
488  Type resultType) {
489  p << resultType << " = ";
490  p.printAttributeWithoutType(value);
491 }
492 
// Verify the op's value attribute against its enclosing hw::HWModuleOp.
493 LogicalResult ParamValueOp::verify() {
494  // Check that the attribute expression is valid in this module.
// NOTE(review): the embedded listing numbering jumps from 494 to 496; the line
// that begins the call whose arguments follow is not present in this listing.
// The visible arguments are the value attribute, the enclosing hw::HWModuleOp
// and this op.
496  getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
497 }
498 
499 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
500  assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
501  return getValueAttr();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // HWModuleOp
506 //===----------------------------------------------------------------------===//
507 
508 /// Return true if isAnyModule or instance.
509 bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
510  return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
511 }
512 
513 /// Return the signature for a module as a function type from the module itself
514 /// or from an hw::InstanceOp.
515 FunctionType hw::getModuleType(Operation *moduleOrInstance) {
516  return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
517  .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
518  SmallVector<Type> inputs(instance->getOperandTypes());
519  SmallVector<Type> results(instance->getResultTypes());
520  return FunctionType::get(instance->getContext(), inputs, results);
521  })
522  .Case<HWModuleLike>(
523  [](auto mod) { return mod.getHWModuleType().getFuncType(); })
524  .Default([](Operation *op) {
525  return cast<FunctionType>(
526  cast<mlir::FunctionOpInterface>(op).getFunctionType());
527  });
528 }
529 
530 /// Return the name to use for the Verilog module that we're referencing
531 /// here. This is typically the symbol, but can be overridden with the
532 /// verilogName attribute.
533 StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
534  auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
535  if (nameAttr)
536  return nameAttr;
537 
538  return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
539 }
540 
541 template <typename ModuleTy>
542 static void
543 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
544  const ModulePortInfo &ports, ArrayAttr parameters,
545  ArrayRef<NamedAttribute> attributes, StringAttr comment) {
546  using namespace mlir::function_interface_impl;
547 
548  // Add an attribute for the name.
549  result.addAttribute(SymbolTable::getSymbolAttrName(), name);
550 
551  SmallVector<Attribute> perPortAttrs;
552  SmallVector<ModulePort> portTypes;
553 
554  for (auto elt : ports) {
555  portTypes.push_back(elt);
556  llvm::SmallVector<NamedAttribute> portAttrs;
557  if (elt.attrs)
558  llvm::copy(elt.attrs, std::back_inserter(portAttrs));
559  perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
560  }
561 
562  // Allow clients to pass in null for the parameters list.
563  if (!parameters)
564  parameters = builder.getArrayAttr({});
565 
566  // Record the argument and result types as an attribute.
567  auto type = ModuleType::get(builder.getContext(), portTypes);
568  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
569  TypeAttr::get(type));
570  result.addAttribute("per_port_attrs",
571  arrayOrEmpty(builder.getContext(), perPortAttrs));
572  result.addAttribute("parameters", parameters);
573  if (!comment)
574  comment = builder.getStringAttr("");
575  result.addAttribute("comment", comment);
576  result.addAttributes(attributes);
577  result.addRegion();
578 }
579 
580 /// Internal implementation of argument/result insertion and removal on modules.
581 static void modifyModuleArgs(
582  MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
583  ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
584  ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
585  ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
586  SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
587  SmallVector<Location> &newArgLocs, Block *body = nullptr) {
588 
589 #ifndef NDEBUG
590  // Check that the `insertArgs` and `removeArgs` indices are in ascending
591  // order.
592  assert(llvm::is_sorted(insertArgs,
593  [](auto &a, auto &b) { return a.first < b.first; }) &&
594  "insertArgs must be in ascending order");
595  assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
596  "removeArgs must be in ascending order");
597 #endif
598 
599  auto oldArgCount = oldArgTypes.size();
600  auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
601  assert((int)newArgCount >= 0);
602 
603  newArgNames.reserve(newArgCount);
604  newArgTypes.reserve(newArgCount);
605  newArgAttrs.reserve(newArgCount);
606  newArgLocs.reserve(newArgCount);
607 
608  auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
609  auto emptyDictAttr = DictionaryAttr::get(context, {});
610  auto unknownLoc = UnknownLoc::get(context);
611 
612  BitVector erasedIndices;
613  if (body)
614  erasedIndices.resize(oldArgCount + insertArgs.size());
615 
616  for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
617  // Insert new ports at this position.
618  while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
619  auto port = insertArgs[0].second;
620  if (port.dir == ModulePort::Direction::InOut &&
621  !isa<InOutType>(port.type))
622  port.type = InOutType::get(port.type);
623  auto sym = port.getSym();
624  Attribute attr =
625  (sym && !sym.empty())
626  ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
627  : emptyDictAttr;
628  newArgNames.push_back(port.name);
629  newArgTypes.push_back(port.type);
630  newArgAttrs.push_back(attr);
631  insertArgs = insertArgs.drop_front();
632  LocationAttr loc = port.loc ? port.loc : unknownLoc;
633  newArgLocs.push_back(loc);
634  if (body)
635  body->insertArgument(idx++, port.type, loc);
636  }
637  if (argIdx == oldArgCount)
638  break;
639 
640  // Migrate the old port at this position.
641  bool removed = false;
642  while (!removeArgs.empty() && removeArgs[0] == argIdx) {
643  removeArgs = removeArgs.drop_front();
644  removed = true;
645  }
646 
647  if (removed) {
648  if (body)
649  erasedIndices.set(idx);
650  } else {
651  newArgNames.push_back(oldArgNames[argIdx]);
652  newArgTypes.push_back(oldArgTypes[argIdx]);
653  newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
654  : oldArgAttrs[argIdx]);
655  newArgLocs.push_back(oldArgLocs[argIdx]);
656  }
657  }
658 
659  if (body)
660  body->eraseArguments(erasedIndices);
661 
662  assert(newArgNames.size() == newArgCount);
663  assert(newArgTypes.size() == newArgCount);
664  assert(newArgAttrs.size() == newArgCount);
665  assert(newArgLocs.size() == newArgCount);
666 }
667 
668 /// Insert and remove ports of a module. The insertion and removal indices must
669 /// be in ascending order. The indices refer to the port positions before any
670 /// insertion or removal occurs. Ports inserted at the same index will appear in
671 /// the module in the same order as they were listed in the `insert*` array.
672 ///
673 /// The operation must be any of the module-like operations.
674 ///
675 /// This is marked deprecated as it's only used from HandshakeToHW and
676 /// PortConverter and is likely broken and not currently tested. Users of this
677 /// are still written dealing with input and output ports separately, which is
678 /// an old and broken style.
679 [[deprecated]] static void
680 modifyModulePorts(Operation *op,
681  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
682  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
683  ArrayRef<unsigned> removeInputs,
684  ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
685  auto moduleOp = cast<HWModuleLike>(op);
686  auto *context = moduleOp.getContext();
687 
688  // Dig up the old argument and result data.
689  auto oldArgNames = moduleOp.getInputNames();
690  auto oldArgTypes = moduleOp.getInputTypes();
691  auto oldArgAttrs = moduleOp.getAllInputAttrs();
692  auto oldArgLocs = moduleOp.getInputLocs();
693 
694  auto oldResultNames = moduleOp.getOutputNames();
695  auto oldResultTypes = moduleOp.getOutputTypes();
696  auto oldResultAttrs = moduleOp.getAllOutputAttrs();
697  auto oldResultLocs = moduleOp.getOutputLocs();
698 
699  // Modify the ports.
700  SmallVector<Attribute> newArgNames, newResultNames;
701  SmallVector<Type> newArgTypes, newResultTypes;
702  SmallVector<Attribute> newArgAttrs, newResultAttrs;
703  SmallVector<Location> newArgLocs, newResultLocs;
704 
705  modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
706  oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
707  newArgTypes, newArgAttrs, newArgLocs, body);
708 
709  modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
710  oldResultTypes, oldResultAttrs, oldResultLocs,
711  newResultNames, newResultTypes, newResultAttrs,
712  newResultLocs);
713 
714  // Update the module operation types and attributes.
715  auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
716  auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
717  moduleOp.setHWModuleType(modty);
718  moduleOp.setAllInputAttrs(newArgAttrs);
719  moduleOp.setAllOutputAttrs(newResultAttrs);
720 
721  newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
722  moduleOp.setAllPortLocs(newArgLocs);
723 }
724 
725 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
726  StringAttr name, const ModulePortInfo &ports,
727  ArrayAttr parameters,
728  ArrayRef<NamedAttribute> attributes, StringAttr comment,
729  bool shouldEnsureTerminator) {
730  buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
731  comment);
732 
733  // Create a region and a block for the body.
734  auto *bodyRegion = result.regions[0].get();
735  Block *body = new Block();
736  bodyRegion->push_back(body);
737 
738  // Add arguments to the body block.
739  auto unknownLoc = builder.getUnknownLoc();
740  for (auto port : ports.getInputs()) {
741  auto loc = port.loc ? Location(port.loc) : unknownLoc;
742  auto type = port.type;
743  if (port.isInOut() && !isa<InOutType>(type))
744  type = InOutType::get(type);
745  body->addArgument(type, loc);
746  }
747 
748  // Add result ports attribute.
749  auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
750  SmallVector<Attribute> resultLocs;
751  for (auto port : ports.getOutputs())
752  resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
753  result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
754 
755  if (shouldEnsureTerminator)
756  HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
757 }
758 
759 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
760  StringAttr name, ArrayRef<PortInfo> ports,
761  ArrayAttr parameters,
762  ArrayRef<NamedAttribute> attributes,
763  StringAttr comment) {
764  build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
765  comment);
766 }
767 
768 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
769  StringAttr name, const ModulePortInfo &ports,
770  HWModuleBuilder modBuilder, ArrayAttr parameters,
771  ArrayRef<NamedAttribute> attributes,
772  StringAttr comment) {
773  build(builder, odsState, name, ports, parameters, attributes, comment,
774  /*shouldEnsureTerminator=*/false);
775  auto *bodyRegion = odsState.regions[0].get();
776  OpBuilder::InsertionGuard guard(builder);
777  auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
778  builder.setInsertionPointToEnd(&bodyRegion->front());
779  modBuilder(builder, accessor);
780  // Create output operands.
781  llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
782  builder.create<hw::OutputOp>(odsState.location, outputOperands);
783 }
784 
785 void HWModuleOp::modifyPorts(
786  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
787  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
788  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
789  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
790  eraseOutputs);
791 }
792 
793 /// Return the name to use for the Verilog module that we're referencing
794 /// here. This is typically the symbol, but can be overridden with the
795 /// verilogName attribute.
// NOTE(review): the embedded listing numbering jumps from 795 to 797; the line
// that opens this definition is not present in this listing.
797  if (auto vName = getVerilogNameAttr())
798  return vName;
799
//  No override present: fall back to the op's symbol name attribute.
800  return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
801 }
802 
// Return the verilogName attribute when it is set, otherwise the op's symbol
// name attribute.
// NOTE(review): the embedded listing numbering jumps from 802 to 804; the line
// that opens this definition is not present in this listing.
804  if (auto vName = getVerilogNameAttr()) {
805  return vName;
806  }
807  return (*this)->getAttrOfType<StringAttr>(
808  ::mlir::SymbolTable::getSymbolAttrName());
809 }
810 
811 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
812  StringAttr name, const ModulePortInfo &ports,
813  StringRef verilogName, ArrayAttr parameters,
814  ArrayRef<NamedAttribute> attributes) {
815  buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
816  attributes, {});
817 
818  // Add the port locations.
819  LocationAttr unknownLoc = builder.getUnknownLoc();
820  SmallVector<Attribute> portLocs;
821  for (auto elt : ports)
822  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
823  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
824 
825  if (!verilogName.empty())
826  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
827 }
828 
829 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
830  StringAttr name, ArrayRef<PortInfo> ports,
831  StringRef verilogName, ArrayAttr parameters,
832  ArrayRef<NamedAttribute> attributes) {
833  build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
834  attributes);
835 }
836 
837 void HWModuleExternOp::modifyPorts(
838  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
839  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
840  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
841  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
842  eraseOutputs);
843 }
844 
845 void HWModuleExternOp::appendOutputs(
846  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
847 
848 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
849  FlatSymbolRefAttr genKind, StringAttr name,
850  const ModulePortInfo &ports,
851  StringRef verilogName, ArrayAttr parameters,
852  ArrayRef<NamedAttribute> attributes) {
853  buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
854  attributes, {});
855  // Add the port locations.
856  LocationAttr unknownLoc = builder.getUnknownLoc();
857  SmallVector<Attribute> portLocs;
858  for (auto elt : ports)
859  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
860  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
861 
862  result.addAttribute("generatorKind", genKind);
863  if (!verilogName.empty())
864  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
865 }
866 
867 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
868  FlatSymbolRefAttr genKind, StringAttr name,
869  ArrayRef<PortInfo> ports, StringRef verilogName,
870  ArrayAttr parameters,
871  ArrayRef<NamedAttribute> attributes) {
872  build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
873  parameters, attributes);
874 }
875 
876 void HWModuleGeneratedOp::modifyPorts(
877  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
878  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
879  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
880  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
881  eraseOutputs);
882 }
883 
884 void HWModuleGeneratedOp::appendOutputs(
885  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
886 
887 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
888  for (auto &argAttr : attrs)
889  if (argAttr.getName() == name)
890  return true;
891  return false;
892 }
893 
894 template <typename ModuleTy>
895 static ParseResult parseHWModuleOp(OpAsmParser &parser,
896  OperationState &result) {
897 
898  using namespace mlir::function_interface_impl;
899  auto builder = parser.getBuilder();
900  auto loc = parser.getCurrentLocation();
901 
902  // Parse the visibility attribute.
903  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
904 
905  // Parse the name as a symbol.
906  StringAttr nameAttr;
907  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
908  result.attributes))
909  return failure();
910 
911  // Parse the generator information.
912  FlatSymbolRefAttr kindAttr;
913  if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
914  if (parser.parseComma() ||
915  parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
916  return failure();
917  }
918  }
919 
920  // Parse the parameters.
921  ArrayAttr parameters;
922  if (parseOptionalParameterList(parser, parameters))
923  return failure();
924 
925  SmallVector<module_like_impl::PortParse> ports;
926  TypeAttr modType;
927  if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
928  return failure();
929 
930  // Parse the attribute dict.
931  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
932  return failure();
933 
934  if (hasAttribute("parameters", result.attributes)) {
935  parser.emitError(loc, "explicit `parameters` attributes not allowed");
936  return failure();
937  }
938 
939  result.addAttribute("parameters", parameters);
940  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
941 
942  // Convert the specified array of dictionary attrs (which may have null
943  // entries) to an ArrayAttr of dictionaries.
944  SmallVector<Attribute> attrs;
945  for (auto &port : ports)
946  attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
947  // Add the attributes to the ports.
948  auto nonEmptyAttrsFn = [](Attribute attr) {
949  return attr && !cast<DictionaryAttr>(attr).empty();
950  };
951  if (llvm::any_of(attrs, nonEmptyAttrsFn))
952  result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
953  builder.getArrayAttr(attrs));
954 
955  // Add the port locations.
956  auto unknownLoc = builder.getUnknownLoc();
957  auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
958  return attr && cast<Location>(attr) != unknownLoc;
959  };
960  SmallVector<Attribute> locs;
961  StringAttr portLocsAttrName;
962  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
963  // Plain modules only store the output port locations, as the input port
964  // locations will be stored in the basic block arguments.
965  portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
966  for (auto &port : ports)
967  if (port.direction == ModulePort::Direction::Output)
968  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
969  } else {
970  // All other modules store all port locations in a single array.
971  portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
972  for (auto &port : ports)
973  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
974  }
975  if (llvm::any_of(locs, nonEmptyLocsFn))
976  result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
977 
978  // Add the entry block arguments.
979  SmallVector<OpAsmParser::Argument, 4> entryArgs;
980  for (auto &port : ports)
981  if (port.direction != ModulePort::Direction::Output)
982  entryArgs.push_back(port);
983 
984  // Parse the optional function body.
985  auto *body = result.addRegion();
986  if (std::is_same_v<ModuleTy, HWModuleOp>) {
987  if (parser.parseRegion(*body, entryArgs))
988  return failure();
989 
990  HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
991  }
992  return success();
993 }
994 
995 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
996  return parseHWModuleOp<HWModuleOp>(parser, result);
997 }
998 
999 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1000  OperationState &result) {
1001  return parseHWModuleOp<HWModuleExternOp>(parser, result);
1002 }
1003 
1004 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1005  OperationState &result) {
1006  return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1007 }
1008 
1009 FunctionType getHWModuleOpType(Operation *op) {
1010  if (auto mod = dyn_cast<HWModuleLike>(op))
1011  return mod.getHWModuleType().getFuncType();
1012  return cast<FunctionType>(
1013  cast<mlir::FunctionOpInterface>(op).getFunctionType());
1014 }
1015 
// Shared printer for the module ops in this file. Emits, in order: the
// optional visibility keyword, the symbol name, the generator kind (only for
// HWModuleGeneratedOp), the optional parameter list and finally the attribute
// dictionary with the entries collected in `omittedAttrs` left out.
// NOTE(review): listing line 1035 is absent from this extract (between the
// parameter list and `omittedAttrs`); presumably it prints the module
// signature -- confirm against the upstream file.
1016 template <typename ModuleTy>
1017 static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1018  p << ' ';
1019  // Print the visibility of the module.
1020  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1021  if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1022  visibilityAttrName))
1023  p << visibility.getValue() << ' ';
1024 
1025  // Print the operation and the function name.
1026  p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1027  if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1028  p << ", ";
1029  p.printSymbolName(gen.getGeneratorKind());
1030  }
1031 
1032  // Print the parameter list if present.
1033  printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1034 
1036 
// Attributes that must not show up in the trailing attribute dictionary: the
// generator kind, the port/result locations, the module type, the per-port
// attributes, the parameters and the visibility.
1037  SmallVector<StringRef, 3> omittedAttrs;
1038  if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1039  omittedAttrs.push_back("generatorKind");
1040  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1041  omittedAttrs.push_back(mod.getResultLocsAttrName());
1042  else
1043  omittedAttrs.push_back(mod.getPortLocsAttrName());
1044  omittedAttrs.push_back(mod.getModuleTypeAttrName());
1045  omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1046  omittedAttrs.push_back(mod.getParametersAttrName());
1047  omittedAttrs.push_back(visibilityAttrName);
// A "comment" attribute is hidden only when its string value is empty.
1048  if (auto cmt =
1049  mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1050  if (cmt.getValue().empty())
1051  omittedAttrs.push_back("comment");
1052 
1053  mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1054  omittedAttrs);
1055 }
1056 
1057 void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1058 void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1059 
1060 void HWModuleOp::print(OpAsmPrinter &p) {
1061  printModuleOp(p, *this);
1062 
1063  // Print the body if this is not an external function.
1064  Region &body = getBody();
1065  if (!body.empty()) {
1066  p << " ";
1067  p.printRegion(body, /*printEntryBlockArgs=*/false,
1068  /*printBlockTerminators=*/true);
1069  }
1070 }
1071 
1072 static LogicalResult verifyModuleCommon(HWModuleLike module) {
1073  assert(isa<HWModuleLike>(module) &&
1074  "verifier hook should only be called on modules");
1075 
1076  SmallPtrSet<Attribute, 4> paramNames;
1077 
1078  // Check parameter default values are sensible.
1079  for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1080  auto paramAttr = cast<ParamDeclAttr>(param);
1081 
1082  // Check that we don't have any redundant parameter names. These are
1083  // resolved by string name: reuse of the same name would cause ambiguities.
1084  if (!paramNames.insert(paramAttr.getName()).second)
1085  return module->emitOpError("parameter ")
1086  << paramAttr << " has the same name as a previous parameter";
1087 
1088  // Default values are allowed to be missing, check them if present.
1089  auto value = paramAttr.getValue();
1090  if (!value)
1091  continue;
1092 
1093  auto typedValue = dyn_cast<TypedAttr>(value);
1094  if (!typedValue)
1095  return module->emitOpError("parameter ")
1096  << paramAttr << " should have a typed value; has value " << value;
1097 
1098  if (typedValue.getType() != paramAttr.getType())
1099  return module->emitOpError("parameter ")
1100  << paramAttr << " should have type " << paramAttr.getType()
1101  << "; has type " << typedValue.getType();
1102 
1103  // Verify that this is a valid parameter value, disallowing parameter
1104  // references. We could allow parameters to refer to each other in the
1105  // future with lexical ordering if there is a need.
1106  if (failed(checkParameterInContext(value, module, module,
1107  /*disallowParamRefs=*/true)))
1108  return failure();
1109  }
1110  return success();
1111 }
1112 
1113 LogicalResult HWModuleOp::verify() {
1114  if (failed(verifyModuleCommon(*this)))
1115  return failure();
1116 
1117  auto type = getModuleType();
1118  auto *body = getBodyBlock();
1119 
1120  // Verify the number of block arguments.
1121  auto numInputs = type.getNumInputs();
1122  if (body->getNumArguments() != numInputs)
1123  return emitOpError("entry block must have")
1124  << numInputs << " arguments to match module signature";
1125 
1126  return success();
1127 }
1128 
1129 LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1130 
// Insert a new input port of type `ty` at input position `index`.
// The requested `name` is uniquified against all existing port names, so the
// name actually used may differ. Returns that name together with the entry
// block argument found at `index` after the ports were modified.
1131 std::pair<StringAttr, BlockArgument>
1132 HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1133  // Seed a namespace with every existing port name, then derive a name for
1134  Namespace ns;
1135  auto ports = getPortList();
1136  for (auto port : ports)
1137  ns.newName(port.name.getValue());
// the new port that does not collide with any of them.
1138  auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1139 
1140  Block *body = getBodyBlock();
1141 
1142  // Describe the port to insert.
1143  PortInfo port;
1144  port.name = nameAttr;
// NOTE(review): listing line 1145 is absent from this extract; presumably it
// sets the port direction -- confirm against the upstream file.
1146  port.type = ty;
1147  modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1148  body);
1149 
1150  // Hand back the block argument now sitting at `index`.
1151  return {nameAttr, body->getArgument(index)};
1152 }
1153 
// Insert new output ports starting at output position `index`. Each entry of
// `outputs` pairs the port name with the value to drive it; the port type is
// taken from the value. Both the module's port list and the operands of the
// terminating OutputOp are updated.
1154 void HWModuleOp::insertOutputs(unsigned index,
1155  ArrayRef<std::pair<StringAttr, Value>> outputs) {
1156 
1157  auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1158  assert(index <= output->getNumOperands() && "invalid output index");
1159 
1160  // Rewrite the port list of the module.
// Every new port is registered at the same `index`.
1161  SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1162  for (auto &[name, value] : outputs) {
1163  PortInfo port;
1164  port.name = name;
// NOTE(review): listing line 1165 is absent from this extract; presumably it
// sets the port direction -- confirm against the upstream file.
1166  port.type = value.getType();
1167  indexedNewPorts.emplace_back(index, port);
1168  }
1169  modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1170  getBodyBlock());
1171 
1172  // Rewrite the output op.
// `index` advances per operand so the values keep the order of `outputs`.
1173  for (auto &[name, value] : outputs)
1174  output->insertOperands(index++, value);
1175 }
1176 
1177 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1178  return insertOutputs(getNumOutputPorts(), outputs);
1179 }
1180 
1181 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1182  mlir::OpAsmSetValueNameFn setNameFn) {
1183  getAsmBlockArgumentNamesImpl(region, setNameFn);
1184 }
1185 
1186 void HWModuleExternOp::getAsmBlockArgumentNames(
1187  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1188  getAsmBlockArgumentNamesImpl(region, setNameFn);
1189 }
1190 
1191 template <typename ModTy>
1192 static SmallVector<Location> getAllPortLocs(ModTy module) {
1193  auto locs = module.getPortLocs();
1194  if (locs) {
1195  SmallVector<Location> retval;
1196  retval.reserve(locs->size());
1197  for (auto l : *locs)
1198  retval.push_back(cast<Location>(l));
1199  // Either we have a length of 0 or the correct length
1200  assert(!locs->size() || locs->size() == module.getNumPorts());
1201  return retval;
1202  }
1203  return SmallVector<Location>(module.getNumPorts(),
1204  UnknownLoc::get(module.getContext()));
1205 }
1206 
1207 SmallVector<Location> HWModuleOp::getAllPortLocs() {
1208  SmallVector<Location> portLocs;
1209  portLocs.reserve(getNumPorts());
1210  auto resultLocs = getResultLocsAttr();
1211  unsigned inputCount = 0;
1212  auto modType = getModuleType();
1213  auto unknownLoc = UnknownLoc::get(getContext());
1214  auto *body = getBodyBlock();
1215  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1216  if (modType.isOutput(i)) {
1217  auto loc = resultLocs
1218  ? cast<Location>(
1219  resultLocs.getValue()[portLocs.size() - inputCount])
1220  : unknownLoc;
1221  portLocs.push_back(loc);
1222  } else {
1223  auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1224  portLocs.push_back(loc);
1225  ++inputCount;
1226  }
1227  }
1228  return portLocs;
1229 }
1230 
1231 SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1232  return ::getAllPortLocs(*this);
1233 }
1234 
1235 SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1236  return ::getAllPortLocs(*this);
1237 }
1238 
1239 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1240  SmallVector<Attribute> resultLocs;
1241  unsigned inputCount = 0;
1242  auto modType = getModuleType();
1243  auto *body = getBodyBlock();
1244  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1245  if (modType.isOutput(i))
1246  resultLocs.push_back(locs[i]);
1247  else
1248  body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1249  }
1250  setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1251 }
1252 
1253 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1254  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1255 }
1256 
1257 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1258  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1259 }
1260 
1261 template <typename ModTy>
1262 static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1263  auto numInputs = module.getNumInputPorts();
1264  SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1265  SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1266  auto oldType = module.getModuleType();
1267  SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1268  oldType.getPorts().end());
1269  for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1270  newPorts[i].name = cast<StringAttr>(names[i]);
1271  auto newType = ModuleType::get(module.getContext(), newPorts);
1272  module.setModuleType(newType);
1273 }
1274 
1275 void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1276  ::setAllPortNames(names, *this);
1277 }
1278 
1279 void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1280  ::setAllPortNames(names, *this);
1281 }
1282 
1283 void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1284  ::setAllPortNames(names, *this);
1285 }
1286 
1287 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1288  auto attrs = getPerPortAttrs();
1289  if (attrs && !attrs->empty())
1290  return attrs->getValue();
1291  return {};
1292 }
1293 
1294 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1295  auto attrs = getPerPortAttrs();
1296  if (attrs && !attrs->empty())
1297  return attrs->getValue();
1298  return {};
1299 }
1300 
1301 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1302  auto attrs = getPerPortAttrs();
1303  if (attrs && !attrs->empty())
1304  return attrs->getValue();
1305  return {};
1306 }
1307 
1308 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1309  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1310 }
1311 
1312 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1313  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1314 }
1315 
1316 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1317  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1318 }
1319 
1320 void HWModuleOp::removeAllPortAttrs() {
1321  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1322 }
1323 
1324 void HWModuleExternOp::removeAllPortAttrs() {
1325  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1326 }
1327 
1328 void HWModuleGeneratedOp::removeAllPortAttrs() {
1329  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1330 }
1331 
1332 // This probably does really unexpected stuff when you change the number of
1333 
1334 template <typename ModTy>
1335 static void setHWModuleType(ModTy &mod, ModuleType type) {
1336  auto argAttrs = mod.getAllInputAttrs();
1337  auto resAttrs = mod.getAllOutputAttrs();
1338  mod.setModuleTypeAttr(TypeAttr::get(type));
1339  unsigned newNumArgs = type.getNumInputs();
1340  unsigned newNumResults = type.getNumOutputs();
1341 
1342  auto emptyDict = DictionaryAttr::get(mod.getContext());
1343  argAttrs.resize(newNumArgs, emptyDict);
1344  resAttrs.resize(newNumResults, emptyDict);
1345 
1346  SmallVector<Attribute> attrs;
1347  attrs.append(argAttrs.begin(), argAttrs.end());
1348  attrs.append(resAttrs.begin(), resAttrs.end());
1349 
1350  if (attrs.empty())
1351  return mod.removeAllPortAttrs();
1352  mod.setAllPortAttrs(attrs);
1353 }
1354 
1355 void HWModuleOp::setHWModuleType(ModuleType type) {
1356  return ::setHWModuleType(*this, type);
1357 }
1358 
1359 void HWModuleExternOp::setHWModuleType(ModuleType type) {
1360  return ::setHWModuleType(*this, type);
1361 }
1362 
1363 void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1364  return ::setHWModuleType(*this, type);
1365 }
1366 
1367 /// Lookup the generator for the symbol. This returns null on
1368 /// invalid IR.
1369 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1370  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1371  return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1372 }
1373 
1374 LogicalResult
1375 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1376  auto *referencedKind =
1377  symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1378 
1379  if (referencedKind == nullptr)
1380  return emitError("Cannot find generator definition '")
1381  << getGeneratorKind() << "'";
1382 
1383  if (!isa<HWGeneratorSchemaOp>(referencedKind))
1384  return emitError("Symbol resolved to '")
1385  << referencedKind->getName()
1386  << "' which is not a HWGeneratorSchemaOp";
1387 
1388  auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1389  auto paramRef = referencedKindOp.getRequiredAttrs();
1390  auto dict = (*this)->getAttrDictionary();
1391  for (auto str : paramRef) {
1392  auto strAttr = dyn_cast<StringAttr>(str);
1393  if (!strAttr)
1394  return emitError("Unknown attribute type, expected a string");
1395  if (!dict.get(strAttr.getValue()))
1396  return emitError("Missing attribute '") << strAttr.getValue() << "'";
1397  }
1398 
1399  return success();
1400 }
1401 
1402 LogicalResult HWModuleGeneratedOp::verify() {
1403  return verifyModuleCommon(*this);
1404 }
1405 
1406 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1407  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1408  getAsmBlockArgumentNamesImpl(region, setNameFn);
1409 }
1410 
1411 LogicalResult HWModuleOp::verifyBody() { return success(); }
1412 
1413 template <typename ModuleTy>
1414 static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1415  auto modTy = mod.getHWModuleType();
1416  auto emptyDict = DictionaryAttr::get(mod.getContext());
1417  SmallVector<PortInfo> retval;
1418  auto locs = mod.getAllPortLocs();
1419  for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1420  LocationAttr loc = locs[i];
1421  DictionaryAttr attrs =
1422  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1423  if (!attrs)
1424  attrs = emptyDict;
1425  retval.push_back({modTy.getPorts()[i],
1426  modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1427  : modTy.getInputIdForPortId(i),
1428  attrs, loc});
1429  }
1430  return retval;
1431 }
1432 
1433 template <typename ModuleTy>
1434 static PortInfo getPort(ModuleTy &mod, size_t idx) {
1435  auto modTy = mod.getHWModuleType();
1436  auto emptyDict = DictionaryAttr::get(mod.getContext());
1437  LocationAttr loc = mod.getPortLoc(idx);
1438  DictionaryAttr attrs =
1439  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1440  if (!attrs)
1441  attrs = emptyDict;
1442  return {modTy.getPorts()[idx],
1443  modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1444  : modTy.getInputIdForPortId(idx),
1445  attrs, loc};
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // InstanceOp
1450 //===----------------------------------------------------------------------===//
1451 
1452 /// Create a instance that refers to a known module.
1453 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1454  Operation *module, StringAttr name,
1455  ArrayRef<Value> inputs, ArrayAttr parameters,
1456  InnerSymAttr innerSym) {
1457  if (!parameters)
1458  parameters = builder.getArrayAttr({});
1459 
1460  auto mod = cast<hw::HWModuleLike>(module);
1461  auto argNames = builder.getArrayAttr(mod.getInputNames());
1462  auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1463 
1464  // Try to resolve the parameterized module type. If failed, use the module's
1465  // parmeterized type. If the client doesn't fix this error, the verifier will
1466  // fail.
1467  ModuleType modType = mod.getHWModuleType();
1468  FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1469  parameters, result.location, /*emitErrors=*/false);
1470  if (succeeded(resolvedModType))
1471  modType = *resolvedModType;
1472  FunctionType funcType = resolvedModType->getFuncType();
1473  build(builder, result, funcType.getResults(), name,
1474  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1475  argNames, resultNames, parameters, innerSym);
1476 }
1477 
1478 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1479  // Inner symbols on instance operations target the op not any result.
1480  return std::nullopt;
1481 }
1482 
// Symbol-use verification for an instance. The visible argument list hands
// the module name, inputs, result types, argument/result names and parameters
// together with `symbolTable` to a callee.
// NOTE(review): listing line 1484, which holds the head of that call, is
// absent from this extract -- confirm the callee against the upstream file.
1483 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1485  *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1486  getResultNames(), getParameters(), symbolTable);
1487 }
1488 
// Verify the instance against the parameters of the enclosing HWModuleOp.
// Instances that are not nested inside an HWModuleOp are accepted as-is.
// NOTE(review): listing lines 1495 and 1501 are absent from this extract.
// Judging by the surrounding code they declare the `emitError` callback and
// start the final verification call -- confirm against the upstream file.
1489 LogicalResult InstanceOp::verify() {
1490  auto module = (*this)->getParentOfType<HWModuleOp>();
1491  if (!module)
1492  return success();
1493 
1494  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
// The callback opens an op error, lets `fn` fill it in and, when `fn` returns
// true, attaches a note pointing at the enclosing module.
1496  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1497  auto diag = emitOpError();
1498  if (fn(diag))
1499  diag.attachNote(module->getLoc()) << "module declared here";
1500  };
1502  getParameters(), moduleParameters, emitError);
1503 }
1504 
1505 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1506  StringAttr instanceNameAttr;
1507  InnerSymAttr innerSym;
1508  FlatSymbolRefAttr moduleNameAttr;
1509  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1510  SmallVector<Type, 1> inputsTypes, allResultTypes;
1511  ArrayAttr argNames, resultNames, parameters;
1512  auto noneType = parser.getBuilder().getType<NoneType>();
1513 
1514  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1515  result.attributes))
1516  return failure();
1517 
1518  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1519  // Parsing an optional symbol name doesn't fail, so no need to check the
1520  // result.
1521  if (parser.parseCustomAttributeWithFallback(innerSym))
1522  return failure();
1523  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1524  }
1525 
1526  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1527  if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1528  result.attributes) ||
1529  parser.getCurrentLocation(&parametersLoc) ||
1530  parseOptionalParameterList(parser, parameters) ||
1531  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1532  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1533  result.operands) ||
1534  parser.parseArrow() ||
1535  parseOutputPortList(parser, allResultTypes, resultNames) ||
1536  parser.parseOptionalAttrDict(result.attributes)) {
1537  return failure();
1538  }
1539 
1540  result.addAttribute("argNames", argNames);
1541  result.addAttribute("resultNames", resultNames);
1542  result.addAttribute("parameters", parameters);
1543  result.addTypes(allResultTypes);
1544  return success();
1545 }
1546 
1547 void InstanceOp::print(OpAsmPrinter &p) {
1548  p << ' ';
1549  p.printAttributeWithoutType(getInstanceNameAttr());
1550  if (auto attr = getInnerSymAttr()) {
1551  p << " sym ";
1552  attr.print(p);
1553  }
1554  p << ' ';
1555  p.printAttributeWithoutType(getModuleNameAttr());
1556  printOptionalParameterList(p, *this, getParameters());
1557  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1558  getArgNames());
1559  p << " -> ";
1560  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1561 
1562  p.printOptionalAttrDict(
1563  (*this)->getAttrs(),
1564  /*elidedAttrs=*/{"instanceName",
1565  InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1566  "argNames", "resultNames", "parameters"});
1567 }
1568 
1569 //===----------------------------------------------------------------------===//
1570 // InstanceChoiceOp
1571 //===----------------------------------------------------------------------===//
1572 
1573 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1574  // Inner symbols on instance operations target the op not any result.
1575  return std::nullopt;
1576 }
1577 
// Symbol-use verification for an instance choice: the check is repeated for
// every entry of the module-names array, and the first failure aborts.
// NOTE(review): listing line 1581, which holds the head of the condition and
// the callee name, is absent from this extract -- confirm against the
// upstream file.
1578 LogicalResult
1579 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1580  for (Attribute name : getModuleNamesAttr()) {
1582  *this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1583  getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1584  return failure();
1585  }
1586  }
1587  return success();
1588 }
1589 
// Verify the instance choice against the parameters of the enclosing
// HWModuleOp. Ops that are not nested inside an HWModuleOp are accepted
// as-is.
// NOTE(review): listing lines 1596 and 1602 are absent from this extract.
// Judging by the surrounding code they declare the `emitError` callback and
// start the final verification call -- confirm against the upstream file.
1590 LogicalResult InstanceChoiceOp::verify() {
1591  auto module = (*this)->getParentOfType<HWModuleOp>();
1592  if (!module)
1593  return success();
1594 
1595  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
// The callback opens an op error, lets `fn` fill it in and, when `fn` returns
// true, attaches a note pointing at the enclosing module.
1597  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1598  auto diag = emitOpError();
1599  if (fn(diag))
1600  diag.attachNote(module->getLoc()) << "module declared here";
1601  };
1603  getParameters(), moduleParameters, emitError);
1604 }
1605 
1606 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1607  OperationState &result) {
1608  StringAttr optionNameAttr;
1609  StringAttr instanceNameAttr;
1610  InnerSymAttr innerSym;
1611  SmallVector<Attribute> moduleNames;
1612  SmallVector<Attribute> caseNames;
1613  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1614  SmallVector<Type, 1> inputsTypes, allResultTypes;
1615  ArrayAttr argNames, resultNames, parameters;
1616  auto noneType = parser.getBuilder().getType<NoneType>();
1617 
1618  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1619  result.attributes))
1620  return failure();
1621 
1622  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1623  // Parsing an optional symbol name doesn't fail, so no need to check the
1624  // result.
1625  if (parser.parseCustomAttributeWithFallback(innerSym))
1626  return failure();
1627  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1628  }
1629 
1630  if (parser.parseKeyword("option") ||
1631  parser.parseAttribute(optionNameAttr, noneType, "optionName",
1632  result.attributes))
1633  return failure();
1634 
1635  FlatSymbolRefAttr defaultModuleName;
1636  if (parser.parseAttribute(defaultModuleName))
1637  return failure();
1638  moduleNames.push_back(defaultModuleName);
1639 
1640  while (succeeded(parser.parseOptionalKeyword("or"))) {
1641  FlatSymbolRefAttr moduleName;
1642  StringAttr targetName;
1643  if (parser.parseAttribute(moduleName) ||
1644  parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1645  return failure();
1646  moduleNames.push_back(moduleName);
1647  caseNames.push_back(targetName);
1648  }
1649 
1650  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1651  if (parser.getCurrentLocation(&parametersLoc) ||
1652  parseOptionalParameterList(parser, parameters) ||
1653  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1654  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1655  result.operands) ||
1656  parser.parseArrow() ||
1657  parseOutputPortList(parser, allResultTypes, resultNames) ||
1658  parser.parseOptionalAttrDict(result.attributes)) {
1659  return failure();
1660  }
1661 
1662  result.addAttribute("moduleNames",
1663  ArrayAttr::get(parser.getContext(), moduleNames));
1664  result.addAttribute("caseNames",
1665  ArrayAttr::get(parser.getContext(), caseNames));
1666  result.addAttribute("argNames", argNames);
1667  result.addAttribute("resultNames", resultNames);
1668  result.addAttribute("parameters", parameters);
1669  result.addTypes(allResultTypes);
1670  return success();
1671 }
1672 
1673 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1674  p << ' ';
1675  p.printAttributeWithoutType(getInstanceNameAttr());
1676  if (auto attr = getInnerSymAttr()) {
1677  p << " sym ";
1678  attr.print(p);
1679  }
1680  p << " option " << getOptionNameAttr() << ' ';
1681 
1682  auto moduleNames = getModuleNamesAttr();
1683  auto caseNames = getCaseNamesAttr();
1684  assert(moduleNames.size() == caseNames.size() + 1);
1685 
1686  p.printAttributeWithoutType(moduleNames[0]);
1687  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1688  p << " or ";
1689  p.printAttributeWithoutType(moduleNames[i + 1]);
1690  p << " if ";
1691  p.printAttributeWithoutType(caseNames[i]);
1692  }
1693 
1694  printOptionalParameterList(p, *this, getParameters());
1695  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1696  getArgNames());
1697  p << " -> ";
1698  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1699 
1700  p.printOptionalAttrDict(
1701  (*this)->getAttrs(),
1702  /*elidedAttrs=*/{"instanceName",
1703  InnerSymbolTable::getInnerSymbolAttrName(),
1704  "moduleNames", "caseNames", "argNames", "resultNames",
1705  "parameters", "optionName"});
1706 }
1707 
1708 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1709  SmallVector<Attribute> moduleNames;
1710  for (Attribute attr : getModuleNamesAttr()) {
1711  moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1712  }
1713  return ArrayAttr::get(getContext(), moduleNames);
1714 }
1715 
1716 //===----------------------------------------------------------------------===//
1717 // HWOutputOp
1718 //===----------------------------------------------------------------------===//
1719 
1720 /// Verify that the num of operands and types fit the declared results.
1721 LogicalResult OutputOp::verify() {
1722  // Check that the we (hw.output) have the same number of operands as our
1723  // region has results.
1724  ModuleType modType;
1725  if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1726  modType = mod.getHWModuleType();
1727  else {
1728  emitOpError("must have a module parent");
1729  return failure();
1730  }
1731  auto modResults = modType.getOutputTypes();
1732  OperandRange outputValues = getOperands();
1733  if (modResults.size() != outputValues.size()) {
1734  emitOpError("must have same number of operands as region results.");
1735  return failure();
1736  }
1737 
1738  // Check that the types of our operands and the region's results match.
1739  for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1740  if (modResults[i] != outputValues[i].getType()) {
1741  emitOpError("output types must match module. In "
1742  "operand ")
1743  << i << ", expected " << modResults[i] << ", but got "
1744  << outputValues[i].getType() << ".";
1745  return failure();
1746  }
1747  }
1748 
1749  return success();
1750 }
1751 
1752 //===----------------------------------------------------------------------===//
1753 // Other Operations
1754 //===----------------------------------------------------------------------===//
1755 
1756 static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1757  Type &idxType) {
1758  Type type;
1759  if (p.parseType(type))
1760  return p.emitError(p.getCurrentLocation(), "Expected type");
1761  auto arrType = type_dyn_cast<ArrayType>(type);
1762  if (!arrType)
1763  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1764  srcType = type;
1765  unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1766  idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1767  return success();
1768 }
1769 
1770 static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1771  Type idxType) {
1772  p.printType(srcType);
1773 }
1774 
1775 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1776  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1777  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1778  Type elemType;
1779 
1780  if (parser.parseOperandList(operands) ||
1781  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1782  parser.parseType(elemType))
1783  return failure();
1784 
1785  if (operands.size() == 0)
1786  return parser.emitError(inputOperandsLoc,
1787  "Cannot construct an array of length 0");
1788  result.addTypes({ArrayType::get(elemType, operands.size())});
1789 
1790  for (auto operand : operands)
1791  if (parser.resolveOperand(operand, elemType, result.operands))
1792  return failure();
1793  return success();
1794 }
1795 
1796 void ArrayCreateOp::print(OpAsmPrinter &p) {
1797  p << " ";
1798  p.printOperands(getInputs());
1799  p.printOptionalAttrDict((*this)->getAttrs());
1800  p << " : " << getInputs()[0].getType();
1801 }
1802 
1803 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1804  ValueRange values) {
1805  assert(values.size() > 0 && "Cannot build array of zero elements");
1806  Type elemType = values[0].getType();
1807  assert(llvm::all_of(
1808  values,
1809  [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1810  "All values must have same type.");
1811  build(b, state, ArrayType::get(elemType, values.size()), values);
1812 }
1813 
1814 LogicalResult ArrayCreateOp::verify() {
1815  unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1816  if (getInputs().size() != returnSize)
1817  return failure();
1818  return success();
1819 }
1820 
1821 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1822  if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1823  return {};
1824  return ArrayAttr::get(getContext(), adaptor.getInputs());
1825 }
1826 
1827 // Check whether an integer value is an offset from a base.
1828 bool hw::isOffset(Value base, Value index, uint64_t offset) {
1829  if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1830  if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1831  // If both values are a constant, check if index == base + offset.
1832  // To account for overflow, the addition is performed with an extra bit
1833  // and the offset is asserted to fit in the bit width of the base.
1834  auto baseValue = constBase.getValue();
1835  auto indexValue = constIndex.getValue();
1836 
1837  unsigned bits = baseValue.getBitWidth();
1838  assert(bits == indexValue.getBitWidth() && "mismatched widths");
1839 
1840  if (bits < 64 && offset >= (1ull << bits))
1841  return false;
1842 
1843  APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1844  APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1845  return baseExt + offset == indexExt;
1846  }
1847  }
1848  return false;
1849 }
1850 
1851 // Canonicalize a create of consecutive elements to a slice.
1852 static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1853  PatternRewriter &rewriter) {
1854  // Do not canonicalize create of get into a slice.
1855  auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1856  if (arrayTy.getNumElements() <= 1)
1857  return failure();
1858  auto elemTy = arrayTy.getElementType();
1859 
1860  // Check if create arguments are consecutive elements of the same array.
1861  // Attempt to break a create of gets into a sequence of consecutive intervals.
1862  struct Chunk {
1863  Value input;
1864  Value index;
1865  size_t size;
1866  };
1867  SmallVector<Chunk> chunks;
1868  for (Value value : llvm::reverse(op.getInputs())) {
1869  auto get = value.getDefiningOp<ArrayGetOp>();
1870  if (!get)
1871  return failure();
1872 
1873  Value input = get.getInput();
1874  Value index = get.getIndex();
1875  if (!chunks.empty()) {
1876  auto &c = *chunks.rbegin();
1877  if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1878  c.size++;
1879  continue;
1880  }
1881  }
1882 
1883  chunks.push_back(Chunk{input, index, 1});
1884  }
1885 
1886  // If there is a single slice, eliminate the create.
1887  if (chunks.size() == 1) {
1888  auto &chunk = chunks[0];
1889  rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1890  op.getLoc(), arrayTy, chunk.input, chunk.index));
1891  return success();
1892  }
1893 
1894  // If the number of chunks is significantly less than the number of
1895  // elements, replace the create with a concat of the identified slices.
1896  if (chunks.size() * 2 < arrayTy.getNumElements()) {
1897  SmallVector<Value> slices;
1898  for (auto &chunk : llvm::reverse(chunks)) {
1899  auto sliceTy = ArrayType::get(elemTy, chunk.size);
1900  slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1901  op.getLoc(), sliceTy, chunk.input, chunk.index));
1902  }
1903  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1904  return success();
1905  }
1906 
1907  return failure();
1908 }
1909 
1910 LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1911  PatternRewriter &rewriter) {
1912  if (succeeded(foldCreateToSlice(op, rewriter)))
1913  return success();
1914  return failure();
1915 }
1916 
1917 Value ArrayCreateOp::getUniformElement() {
1918  if (!getInputs().empty() && llvm::all_equal(getInputs()))
1919  return getInputs()[0];
1920  return {};
1921 }
1922 
1923 static std::optional<uint64_t> getUIntFromValue(Value value) {
1924  auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1925  if (!idxOp)
1926  return std::nullopt;
1927  APInt idxAttr = idxOp.getValue();
1928  if (idxAttr.getBitWidth() > 64)
1929  return std::nullopt;
1930  return idxAttr.getLimitedValue();
1931 }
1932 
1933 LogicalResult ArraySliceOp::verify() {
1934  unsigned inputSize =
1935  type_cast<ArrayType>(getInput().getType()).getNumElements();
1936  if (llvm::Log2_64_Ceil(inputSize) !=
1937  getLowIndex().getType().getIntOrFloatBitWidth())
1938  return emitOpError(
1939  "ArraySlice: index width must match clog2 of array size");
1940  return success();
1941 }
1942 
1943 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1944  // If we are slicing the entire input, then return it.
1945  if (getType() == getInput().getType())
1946  return getInput();
1947  return {};
1948 }
1949 
1950 LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1951  PatternRewriter &rewriter) {
1952  auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1953  auto elemTy = sliceTy.getElementType();
1954  uint64_t sliceSize = sliceTy.getNumElements();
1955  if (sliceSize == 0)
1956  return failure();
1957 
1958  if (sliceSize == 1) {
1959  // slice(a, n) -> create(a[n])
1960  auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1961  op.getLowIndex());
1962  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1963  get.getResult());
1964  return success();
1965  }
1966 
1967  auto offsetOpt = getUIntFromValue(op.getLowIndex());
1968  if (!offsetOpt)
1969  return failure();
1970 
1971  auto inputOp = op.getInput().getDefiningOp();
1972  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1973  // slice(slice(a, n), m) -> slice(a, n + m)
1974  if (inputSlice == op)
1975  return failure();
1976 
1977  auto inputIndex = inputSlice.getLowIndex();
1978  auto inputOffsetOpt = getUIntFromValue(inputIndex);
1979  if (!inputOffsetOpt)
1980  return failure();
1981 
1982  uint64_t offset = *offsetOpt + *inputOffsetOpt;
1983  auto lowIndex =
1984  rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1985  rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1986  inputSlice.getInput(), lowIndex);
1987  return success();
1988  }
1989 
1990  if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1991  // slice(create(a0, a1, ..., an), m) -> create(am, ...)
1992  auto inputs = inputCreate.getInputs();
1993 
1994  uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1995  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1996  inputs.slice(begin, sliceSize));
1997  return success();
1998  }
1999 
2000  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2001  // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2002  SmallVector<Value> chunks;
2003  uint64_t sliceStart = *offsetOpt;
2004  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2005  // Check whether the input intersects with the slice.
2006  uint64_t inputSize =
2007  hw::type_cast<ArrayType>(input.getType()).getNumElements();
2008  if (inputSize == 0 || inputSize <= sliceStart) {
2009  sliceStart -= inputSize;
2010  continue;
2011  }
2012 
2013  // Find the indices to slice from this input by intersection.
2014  uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2015  uint64_t cutSize = cutEnd - sliceStart;
2016  assert(cutSize != 0 && "slice cannot be empty");
2017 
2018  if (cutSize == inputSize) {
2019  // The whole input fits in the slice, add it.
2020  assert(sliceStart == 0 && "invalid cut size");
2021  chunks.push_back(input);
2022  } else {
2023  // Slice the required bits from the input.
2024  unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2025  auto lowIndex = rewriter.create<ConstantOp>(
2026  op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2027  chunks.push_back(rewriter.create<ArraySliceOp>(
2028  op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
2029  }
2030 
2031  sliceStart = 0;
2032  sliceSize -= cutSize;
2033  if (sliceSize == 0)
2034  break;
2035  }
2036 
2037  assert(chunks.size() > 0 && "missing sliced items");
2038  if (chunks.size() == 1)
2039  rewriter.replaceOp(op, chunks[0]);
2040  else
2041  rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2042  op, llvm::to_vector(llvm::reverse(chunks)));
2043  return success();
2044  }
2045  return failure();
2046 }
2047 
2048 //===----------------------------------------------------------------------===//
2049 // ArrayConcatOp
2050 //===----------------------------------------------------------------------===//
2051 
2052 static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2053  SmallVectorImpl<Type> &inputTypes,
2054  Type &resultType) {
2055  Type elemType;
2056  uint64_t resultSize = 0;
2057 
2058  auto parseElement = [&]() -> ParseResult {
2059  Type ty;
2060  if (p.parseType(ty))
2061  return failure();
2062  auto arrTy = type_dyn_cast<ArrayType>(ty);
2063  if (!arrTy)
2064  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2065  if (elemType && elemType != arrTy.getElementType())
2066  return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2067  << elemType;
2068 
2069  elemType = arrTy.getElementType();
2070  inputTypes.push_back(ty);
2071  resultSize += arrTy.getNumElements();
2072  return success();
2073  };
2074 
2075  if (p.parseCommaSeparatedList(parseElement))
2076  return failure();
2077 
2078  resultType = ArrayType::get(elemType, resultSize);
2079  return success();
2080 }
2081 
2082 static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2083  TypeRange inputTypes, Type resultType) {
2084  llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2085 }
2086 
2087 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2088  ValueRange values) {
2089  assert(!values.empty() && "Cannot build array of zero elements");
2090  ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2091  Type elemTy = arrayTy.getElementType();
2092  assert(llvm::all_of(values,
2093  [elemTy](Value v) -> bool {
2094  return isa<ArrayType>(v.getType()) &&
2095  cast<ArrayType>(v.getType()).getElementType() ==
2096  elemTy;
2097  }) &&
2098  "All values must be of ArrayType with the same element type.");
2099 
2100  uint64_t resultSize = 0;
2101  for (Value val : values)
2102  resultSize += cast<ArrayType>(val.getType()).getNumElements();
2103  build(b, state, ArrayType::get(elemTy, resultSize), values);
2104 }
2105 
2106 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2107  auto inputs = adaptor.getInputs();
2108  SmallVector<Attribute> array;
2109  for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2110  if (!inputs[i])
2111  return {};
2112  llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2113  }
2114  return ArrayAttr::get(getContext(), array);
2115 }
2116 
2117 // Flatten a concatenation of array creates into a single create.
2118 static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2119  for (auto input : op.getInputs())
2120  if (!input.getDefiningOp<ArrayCreateOp>())
2121  return false;
2122 
2123  SmallVector<Value> items;
2124  for (auto input : op.getInputs()) {
2125  auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2126  for (auto item : create.getInputs())
2127  items.push_back(item);
2128  }
2129 
2130  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2131  return true;
2132 }
2133 
2134 // Merge consecutive slice expressions in a concatenation.
2135 static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2136  struct Slice {
2137  Value input;
2138  Value index;
2139  size_t size;
2140  Value op;
2141  SmallVector<Location> locs;
2142  };
2143 
2144  SmallVector<Value> items;
2145  std::optional<Slice> last;
2146  bool changed = false;
2147 
2148  auto concatenate = [&] {
2149  // If there is only one op in the slice, place it to the items list.
2150  if (!last)
2151  return;
2152  if (last->op) {
2153  items.push_back(last->op);
2154  last.reset();
2155  return;
2156  }
2157 
2158  // Otherwise, create a new slice of with the given size and place it.
2159  // In this case, the concat op is replaced, using the new argument.
2160  changed = true;
2161  auto loc = FusedLoc::get(op.getContext(), last->locs);
2162  auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2163  auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2164  items.push_back(rewriter.createOrFold<ArraySliceOp>(
2165  loc, arrayTy, last->input, last->index));
2166 
2167  last.reset();
2168  };
2169 
2170  auto append = [&](Value op, Value input, Value index, size_t size) {
2171  // If this slice is an extension of the previous one, extend the size
2172  // saved. In this case, a new slice of is created and the concatenation
2173  // operator is rewritten. Otherwise, flush the last slice.
2174  if (last) {
2175  if (last->input == input && isOffset(last->index, index, last->size)) {
2176  last->size += size;
2177  last->op = {};
2178  last->locs.push_back(op.getLoc());
2179  return;
2180  }
2181  concatenate();
2182  }
2183  last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2184  };
2185 
2186  for (auto item : llvm::reverse(op.getInputs())) {
2187  if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2188  auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2189  append(item, slice.getInput(), slice.getLowIndex(), size);
2190  continue;
2191  }
2192 
2193  if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2194  if (create.getInputs().size() == 1) {
2195  if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2196  append(item, get.getInput(), get.getIndex(), 1);
2197  continue;
2198  }
2199  }
2200  }
2201 
2202  concatenate();
2203  items.push_back(item);
2204  }
2205  concatenate();
2206 
2207  if (!changed)
2208  return false;
2209 
2210  if (items.size() == 1) {
2211  rewriter.replaceOp(op, items[0]);
2212  } else {
2213  std::reverse(items.begin(), items.end());
2214  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2215  }
2216  return true;
2217 }
2218 
2219 LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2220  PatternRewriter &rewriter) {
2221  // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2222  if (flattenConcatOp(op, rewriter))
2223  return success();
2224 
2225  // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2226  if (mergeConcatSlices(op, rewriter))
2227  return success();
2228 
2229  return failure();
2230 }
2231 
2232 //===----------------------------------------------------------------------===//
2233 // EnumConstantOp
2234 //===----------------------------------------------------------------------===//
2235 
2236 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2237  // Parse a Type instead of an EnumType since the type might be a type alias.
2238  // The validity of the canonical type is checked during construction of the
2239  // EnumFieldAttr.
2240  Type type;
2241  StringRef field;
2242 
2243  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2244  if (parser.parseKeyword(&field) || parser.parseColonType(type))
2245  return failure();
2246 
2247  auto fieldAttr = EnumFieldAttr::get(
2248  loc, StringAttr::get(parser.getContext(), field), type);
2249 
2250  if (!fieldAttr)
2251  return failure();
2252 
2253  result.addAttribute("field", fieldAttr);
2254  result.addTypes(type);
2255 
2256  return success();
2257 }
2258 
2259 void EnumConstantOp::print(OpAsmPrinter &p) {
2260  p << " " << getField().getField().getValue() << " : "
2261  << getField().getType().getValue();
2262 }
2263 
2265  function_ref<void(Value, StringRef)> setNameFn) {
2266  setNameFn(getResult(), getField().getField().str());
2267 }
2268 
2269 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2270  EnumFieldAttr field) {
2271  return build(builder, odsState, field.getType().getValue(), field);
2272 }
2273 
2274 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2275  assert(adaptor.getOperands().empty() && "constant has no operands");
2276  return getFieldAttr();
2277 }
2278 
2279 LogicalResult EnumConstantOp::verify() {
2280  auto fieldAttr = getFieldAttr();
2281  auto fieldType = fieldAttr.getType().getValue();
2282  // This check ensures that we are using the exact same type, without looking
2283  // through type aliases.
2284  if (fieldType != getType())
2285  emitOpError("return type ")
2286  << getType() << " does not match attribute type " << fieldAttr;
2287  return success();
2288 }
2289 
2290 //===----------------------------------------------------------------------===//
2291 // EnumCmpOp
2292 //===----------------------------------------------------------------------===//
2293 
2294 LogicalResult EnumCmpOp::verify() {
2295  // Compare the canonical types.
2296  auto lhsType = type_cast<EnumType>(getLhs().getType());
2297  auto rhsType = type_cast<EnumType>(getRhs().getType());
2298  if (rhsType != lhsType)
2299  emitOpError("types do not match");
2300  return success();
2301 }
2302 
2303 //===----------------------------------------------------------------------===//
2304 // StructCreateOp
2305 //===----------------------------------------------------------------------===//
2306 
2307 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2308  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2309  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2310  Type declOrAliasType;
2311 
2312  if (parser.parseLParen() || parser.parseOperandList(operands) ||
2313  parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2314  parser.parseColonType(declOrAliasType))
2315  return failure();
2316 
2317  auto declType = type_dyn_cast<StructType>(declOrAliasType);
2318  if (!declType)
2319  return parser.emitError(parser.getNameLoc(),
2320  "expected !hw.struct type or alias");
2321 
2322  llvm::SmallVector<Type, 4> structInnerTypes;
2323  declType.getInnerTypes(structInnerTypes);
2324  result.addTypes(declOrAliasType);
2325 
2326  if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2327  result.operands))
2328  return failure();
2329  return success();
2330 }
2331 
2332 void StructCreateOp::print(OpAsmPrinter &printer) {
2333  printer << " (";
2334  printer.printOperands(getInput());
2335  printer << ")";
2336  printer.printOptionalAttrDict((*this)->getAttrs());
2337  printer << " : " << getType();
2338 }
2339 
2340 LogicalResult StructCreateOp::verify() {
2341  auto elements = hw::type_cast<StructType>(getType()).getElements();
2342 
2343  if (elements.size() != getInput().size())
2344  return emitOpError("structure field count mismatch");
2345 
2346  for (const auto &[field, value] : llvm::zip(elements, getInput()))
2347  if (field.type != value.getType())
2348  return emitOpError("structure field `")
2349  << field.name << "` type does not match";
2350 
2351  return success();
2352 }
2353 
2354 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2355  // struct_create(struct_explode(x)) => x
2356  if (!getInput().empty())
2357  if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2358  explodeOp && getInput() == explodeOp.getResults() &&
2359  getResult().getType() == explodeOp.getInput().getType())
2360  return explodeOp.getInput();
2361 
2362  auto inputs = adaptor.getInput();
2363  if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2364  return {};
2365  return ArrayAttr::get(getContext(), inputs);
2366 }
2367 
2368 //===----------------------------------------------------------------------===//
2369 // StructExplodeOp
2370 //===----------------------------------------------------------------------===//
2371 
2372 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2373  OperationState &result) {
2374  OpAsmParser::UnresolvedOperand operand;
2375  Type declType;
2376 
2377  if (parser.parseOperand(operand) ||
2378  parser.parseOptionalAttrDict(result.attributes) ||
2379  parser.parseColonType(declType))
2380  return failure();
2381  auto structType = type_dyn_cast<StructType>(declType);
2382  if (!structType)
2383  return parser.emitError(parser.getNameLoc(),
2384  "invalid kind of type specified");
2385 
2386  llvm::SmallVector<Type, 4> structInnerTypes;
2387  structType.getInnerTypes(structInnerTypes);
2388  result.addTypes(structInnerTypes);
2389 
2390  if (parser.resolveOperand(operand, declType, result.operands))
2391  return failure();
2392  return success();
2393 }
2394 
2395 void StructExplodeOp::print(OpAsmPrinter &printer) {
2396  printer << " ";
2397  printer.printOperand(getInput());
2398  printer.printOptionalAttrDict((*this)->getAttrs());
2399  printer << " : " << getInput().getType();
2400 }
2401 
2402 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2403  SmallVectorImpl<OpFoldResult> &results) {
2404  auto input = adaptor.getInput();
2405  if (!input)
2406  return failure();
2407  llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2408  return success();
2409 }
2410 
2411 LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2412  PatternRewriter &rewriter) {
2413  auto *inputOp = op.getInput().getDefiningOp();
2414  auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2415  auto result = failure();
2416  auto opResults = op.getResults();
2417  for (uint32_t index = 0; index < elements.size(); index++) {
2418  if (auto foldResult = foldStructExtract(inputOp, index)) {
2419  rewriter.replaceAllUsesWith(opResults[index], foldResult);
2420  result = success();
2421  }
2422  }
2423  return result;
2424 }
2425 
2427  function_ref<void(Value, StringRef)> setNameFn) {
2428  auto structType = type_cast<StructType>(getInput().getType());
2429  for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2430  setNameFn(res, field.name.str());
2431 }
2432 
2433 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2434  Value input) {
2435  StructType inputType = dyn_cast<StructType>(input.getType());
2436  assert(inputType);
2437  SmallVector<Type, 16> fieldTypes;
2438  for (auto field : inputType.getElements())
2439  fieldTypes.push_back(field.type);
2440  build(odsBuilder, odsState, fieldTypes, input);
2441 }
2442 
2443 //===----------------------------------------------------------------------===//
2444 // StructExtractOp
2445 //===----------------------------------------------------------------------===//
2446 
2447 /// Ensure an aggregate op's field index is within the bounds of
2448 /// the aggregate type and the accessed field is of 'elementType'.
2449 template <typename AggregateOp, typename AggregateType>
2450 static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2451  AggregateType aggType,
2452  Type elementType) {
2453  auto index = op.getFieldIndex();
2454  if (index >= aggType.getElements().size())
2455  return op.emitOpError() << "field index " << index
2456  << " exceeds element count of aggregate type";
2457 
2459  getCanonicalType(aggType.getElements()[index].type))
2460  return op.emitOpError()
2461  << "type " << aggType.getElements()[index].type
2462  << " of accessed field in aggregate at index " << index
2463  << " does not match expected type " << elementType;
2464 
2465  return success();
2466 }
2467 
2468 LogicalResult StructExtractOp::verify() {
2469  return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2470  *this, getInput().getType(), getType());
2471 }
2472 
2473 /// Use the same parser for both struct_extract and union_extract since the
2474 /// syntax is identical.
2475 template <typename AggregateType>
2476 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2477  OpAsmParser::UnresolvedOperand operand;
2478  StringAttr fieldName;
2479  Type declType;
2480 
2481  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2482  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2483  parser.parseOptionalAttrDict(result.attributes) ||
2484  parser.parseColonType(declType))
2485  return failure();
2486  auto aggType = type_dyn_cast<AggregateType>(declType);
2487  if (!aggType)
2488  return parser.emitError(parser.getNameLoc(),
2489  "invalid kind of type specified");
2490 
2491  auto fieldIndex = aggType.getFieldIndex(fieldName);
2492  if (!fieldIndex) {
2493  parser.emitError(parser.getNameLoc(), "field name '" +
2494  fieldName.getValue() +
2495  "' not found in aggregate type");
2496  return failure();
2497  }
2498 
2499  auto indexAttr =
2500  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2501  result.addAttribute("fieldIndex", indexAttr);
2502  Type resultType = aggType.getElements()[*fieldIndex].type;
2503  result.addTypes(resultType);
2504 
2505  if (parser.resolveOperand(operand, declType, result.operands))
2506  return failure();
2507  return success();
2508 }
2509 
2510 /// Use the same printer for both struct_extract and union_extract since the
2511 /// syntax is identical.
2512 template <typename AggType>
2513 static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2514  printer << " ";
2515  printer.printOperand(op.getInput());
2516  printer << "[\"" << op.getFieldName() << "\"]";
2517  printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2518  printer << " : " << op.getInput().getType();
2519 }
2520 
2521 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2522  OperationState &result) {
2523  return parseExtractOp<StructType>(parser, result);
2524 }
2525 
2526 void StructExtractOp::print(OpAsmPrinter &printer) {
2527  printExtractOp(printer, *this);
2528 }
2529 
2530 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2531  Value input, StructType::FieldInfo field) {
2532  auto fieldIndex =
2533  type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2534  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2535  build(builder, odsState, field.type, input, *fieldIndex);
2536 }
2537 
2538 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2539  Value input, StringAttr fieldName) {
2540  auto structType = type_cast<StructType>(input.getType());
2541  auto fieldIndex = structType.getFieldIndex(fieldName);
2542  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2543  auto resultType = structType.getElements()[*fieldIndex].type;
2544  build(builder, odsState, resultType, input, *fieldIndex);
2545 }
2546 
2547 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2548  if (auto constOperand = adaptor.getInput()) {
2549  // Fold extract from aggregate constant
2550  auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2551  return operandAttr.getValue()[getFieldIndex()];
2552  }
2553 
2554  if (auto foldResult =
2555  foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2556  return foldResult;
2557  return {};
2558 }
2559 
2560 LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2561  PatternRewriter &rewriter) {
2562  auto inputOp = op.getInput().getDefiningOp();
2563 
2564  // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2565  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2566  if (structInject.getFieldIndex() != op.getFieldIndex()) {
2567  rewriter.replaceOpWithNewOp<StructExtractOp>(
2568  op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2569  return success();
2570  }
2571  }
2572 
2573  return failure();
2574 }
2575 
2577  function_ref<void(Value, StringRef)> setNameFn) {
2578  setNameFn(getResult(), getFieldName());
2579 }
2580 
2581 //===----------------------------------------------------------------------===//
2582 // StructInjectOp
2583 //===----------------------------------------------------------------------===//
2584 
2585 void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2586  Value input, StringAttr fieldName, Value newValue) {
2587  auto structType = type_cast<StructType>(input.getType());
2588  auto fieldIndex = structType.getFieldIndex(fieldName);
2589  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2590  build(builder, odsState, input, *fieldIndex, newValue);
2591 }
2592 
2593 LogicalResult StructInjectOp::verify() {
2594  return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2595  *this, getInput().getType(), getNewValue().getType());
2596 }
2597 
2598 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2599  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2600  OpAsmParser::UnresolvedOperand operand, val;
2601  StringAttr fieldName;
2602  Type declType;
2603 
2604  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2605  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2606  parser.parseComma() || parser.parseOperand(val) ||
2607  parser.parseOptionalAttrDict(result.attributes) ||
2608  parser.parseColonType(declType))
2609  return failure();
2610  auto structType = type_dyn_cast<StructType>(declType);
2611  if (!structType)
2612  return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2613 
2614  auto fieldIndex = structType.getFieldIndex(fieldName);
2615  if (!fieldIndex) {
2616  parser.emitError(parser.getNameLoc(), "field name '" +
2617  fieldName.getValue() +
2618  "' not found in aggregate type");
2619  return failure();
2620  }
2621 
2622  auto indexAttr =
2623  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2624  result.addAttribute("fieldIndex", indexAttr);
2625  result.addTypes(declType);
2626 
2627  Type resultType = structType.getElements()[*fieldIndex].type;
2628  if (parser.resolveOperands({operand, val}, {declType, resultType},
2629  inputOperandsLoc, result.operands))
2630  return failure();
2631  return success();
2632 }
2633 
2634 void StructInjectOp::print(OpAsmPrinter &printer) {
2635  printer << " ";
2636  printer.printOperand(getInput());
2637  printer << "[\"" << getFieldName() << "\"], ";
2638  printer.printOperand(getNewValue());
2639  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2640  printer << " : " << getInput().getType();
2641 }
2642 
2643 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2644  auto input = adaptor.getInput();
2645  auto newValue = adaptor.getNewValue();
2646  if (!input || !newValue)
2647  return {};
2648  SmallVector<Attribute> array;
2649  llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2650  array[getFieldIndex()] = newValue;
2651  return ArrayAttr::get(getContext(), array);
2652 }
2653 
2654 LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2655  PatternRewriter &rewriter) {
2656  // Canonicalize multiple injects into a create op and eliminate overwrites.
2657  SmallPtrSet<Operation *, 4> injects;
2658  DenseMap<StringAttr, Value> fields;
2659 
2660  // Chase a chain of injects. Bail out if cycles are present.
2661  StructInjectOp inject = op;
2662  Value input;
2663  do {
2664  if (!injects.insert(inject).second)
2665  return failure();
2666 
2667  fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2668  input = inject.getInput();
2669  inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2670  } while (inject);
2671  assert(input && "missing input to inject chain");
2672 
2673  auto ty = hw::type_cast<StructType>(op.getType());
2674  auto elements = ty.getElements();
2675 
2676  // If the inject chain sets all fields, canonicalize to create.
2677  if (fields.size() == elements.size()) {
2678  SmallVector<Value> createFields;
2679  for (const auto &field : elements) {
2680  auto it = fields.find(field.name);
2681  assert(it != fields.end() && "missing field");
2682  createFields.push_back(it->second);
2683  }
2684  rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2685  return success();
2686  }
2687 
2688  // Nothing to canonicalize, only the original inject in the chain.
2689  if (injects.size() == fields.size())
2690  return failure();
2691 
2692  // Eliminate overwrites. The hash map contains the last write to each field.
2693  for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2694  auto it = fields.find(elements[fieldIndex].name);
2695  if (it == fields.end())
2696  continue;
2697  input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2698  it->second);
2699  }
2700 
2701  rewriter.replaceOp(op, input);
2702  return success();
2703 }
2704 
2705 //===----------------------------------------------------------------------===//
2706 // UnionCreateOp
2707 //===----------------------------------------------------------------------===//
2708 
2709 LogicalResult UnionCreateOp::verify() {
2710  return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2711  *this, getType(), getInput().getType());
2712 }
2713 
2714 void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2715  Type unionType, StringAttr fieldName, Value input) {
2716  auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2717  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2718  build(builder, odsState, unionType, *fieldIndex, input);
2719 }
2720 
2721 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2722  Type declOrAliasType;
2723  StringAttr fieldName;
2724  OpAsmParser::UnresolvedOperand input;
2725  llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2726 
2727  if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2728  parser.parseOperand(input) ||
2729  parser.parseOptionalAttrDict(result.attributes) ||
2730  parser.parseColonType(declOrAliasType))
2731  return failure();
2732 
2733  auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2734  if (!declType)
2735  return parser.emitError(parser.getNameLoc(),
2736  "expected !hw.union type or alias");
2737 
2738  auto fieldIndex = declType.getFieldIndex(fieldName);
2739  if (!fieldIndex) {
2740  parser.emitError(fieldLoc, "cannot find union field '")
2741  << fieldName.getValue() << '\'';
2742  return failure();
2743  }
2744 
2745  auto indexAttr =
2746  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2747  result.addAttribute("fieldIndex", indexAttr);
2748  Type inputType = declType.getElements()[*fieldIndex].type;
2749 
2750  if (parser.resolveOperand(input, inputType, result.operands))
2751  return failure();
2752  result.addTypes({declOrAliasType});
2753  return success();
2754 }
2755 
2756 void UnionCreateOp::print(OpAsmPrinter &printer) {
2757  printer << " \"" << getFieldName() << "\", ";
2758  printer.printOperand(getInput());
2759  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2760  printer << " : " << getType();
2761 }
2762 
2763 //===----------------------------------------------------------------------===//
2764 // UnionExtractOp
2765 //===----------------------------------------------------------------------===//
2766 
2767 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2768  return parseExtractOp<UnionType>(parser, result);
2769 }
2770 
2771 void UnionExtractOp::print(OpAsmPrinter &printer) {
2772  printExtractOp(printer, *this);
2773 }
2774 
2775 LogicalResult UnionExtractOp::inferReturnTypes(
2776  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2777  DictionaryAttr attrs, mlir::OpaqueProperties properties,
2778  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2779  auto unionElements =
2780  hw::type_cast<UnionType>((operands[0].getType())).getElements();
2781  unsigned fieldIndex =
2782  attrs.getAs<IntegerAttr>("fieldIndex").getValue().getZExtValue();
2783  if (fieldIndex >= unionElements.size()) {
2784  if (loc)
2785  mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2786  " exceeds element count of aggregate type");
2787  return failure();
2788  }
2789  results.push_back(unionElements[fieldIndex].type);
2790  return success();
2791 }
2792 
2793 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2794  Value input, StringAttr fieldName) {
2795  auto unionType = type_cast<UnionType>(input.getType());
2796  auto fieldIndex = unionType.getFieldIndex(fieldName);
2797  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2798  auto resultType = unionType.getElements()[*fieldIndex].type;
2799  build(odsBuilder, odsState, resultType, input, *fieldIndex);
2800 }
2801 
2802 //===----------------------------------------------------------------------===//
2803 // ArrayGetOp
2804 //===----------------------------------------------------------------------===//
2805 
2806 // An array_get of an array_create with a constant index can just be the
2807 // array_create operand at the constant index. If the array_create has a
2808 // single uniform value for each element, just return that value regardless of
2809 // the index. If the array is constructed from a constant by a bitcast
2810 // operation, we can fold into a constant.
2811 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2812  auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2813  auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2814 
2815  if (inputCst) {
2816  // Constant array index.
2817  if (indexCst) {
2818  auto indexVal = indexCst.getValue();
2819  if (indexVal.getBitWidth() < 64) {
2820  auto index = indexVal.getZExtValue();
2821  return inputCst[inputCst.size() - 1 - index];
2822  }
2823  }
2824  // If all elements of the array are the same, we can return any element of
2825  // array.
2826  if (!inputCst.empty() && llvm::all_equal(inputCst))
2827  return inputCst[0];
2828  }
2829 
2830  // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2831  if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2832  auto intTy = dyn_cast<IntegerType>(getType());
2833  if (!intTy)
2834  return {};
2835  auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2836  if (!bitcastInputOp)
2837  return {};
2838  if (!indexCst)
2839  return {};
2840  auto bitcastInputCst = bitcastInputOp.getValue();
2841  // Calculate the index. Make sure to zero-extend the index value before
2842  // multiplying the element width.
2843  auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2844  getType().getIntOrFloatBitWidth();
2845  // Extract [startIdx + width - 1: startIdx].
2846  return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2847  intTy.getIntOrFloatBitWidth()));
2848  }
2849 
2850  auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2851  if (!inputCreate)
2852  return {};
2853 
2854  if (auto uniformValue = inputCreate.getUniformElement())
2855  return uniformValue;
2856 
2857  if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2858  return {};
2859 
2860  uint64_t index = indexCst.getValue().getLimitedValue();
2861  auto createInputs = inputCreate.getInputs();
2862  if (index >= createInputs.size())
2863  return {};
2864  return createInputs[createInputs.size() - index - 1];
2865 }
2866 
2867 LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2868  PatternRewriter &rewriter) {
2869  auto idxOpt = getUIntFromValue(op.getIndex());
2870  if (!idxOpt)
2871  return failure();
2872 
2873  auto *inputOp = op.getInput().getDefiningOp();
2874  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2875  // get(slice(a, n), m) -> get(a, n + m)
2876  auto offsetOp = inputSlice.getLowIndex();
2877  auto offsetOpt = getUIntFromValue(offsetOp);
2878  if (!offsetOpt)
2879  return failure();
2880 
2881  uint64_t offset = *offsetOpt + *idxOpt;
2882  auto newOffset =
2883  rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2884  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2885  newOffset);
2886  return success();
2887  }
2888 
2889  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2890  // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2891  uint64_t elemIndex = *idxOpt;
2892  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2893  size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2894  if (elemIndex >= size) {
2895  elemIndex -= size;
2896  continue;
2897  }
2898 
2899  unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2900  auto newIdxOp = rewriter.create<ConstantOp>(
2901  op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2902 
2903  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2904  return success();
2905  }
2906  return failure();
2907  }
2908 
2909  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2910  // array_get sel, (array_create (array_get const a), (array_get const b),
2911  // (array_get const, c), (array_get const, d))
2912  if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2913  if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2914  if (auto create =
2915  innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2916 
2917  SmallVector<Value> newValues;
2918  for (auto operand : create.getOperands())
2919  newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2920  op.getLoc(), operand, op.getIndex()));
2921 
2922  rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2923  op,
2924  rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2925  innerGet.getIndex());
2926  return success();
2927  }
2928  }
2929  }
2930 
2931  return failure();
2932 }
2933 
2934 //===----------------------------------------------------------------------===//
2935 // TypedeclOp
2936 //===----------------------------------------------------------------------===//
2937 
2938 StringRef TypedeclOp::getPreferredName() {
2939  return getVerilogName().value_or(getName());
2940 }
2941 
2942 Type TypedeclOp::getAliasType() {
2943  auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2944  return hw::TypeAliasType::get(
2945  SymbolRefAttr::get(parentScope.getSymNameAttr(),
2946  {FlatSymbolRefAttr::get(*this)}),
2947  getType());
2948 }
2949 
2950 //===----------------------------------------------------------------------===//
2951 // BitcastOp
2952 //===----------------------------------------------------------------------===//
2953 
2954 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2955  // Identity.
2956  // bitcast(%a) : A -> A ==> %a
2957  if (getOperand().getType() == getType())
2958  return getOperand();
2959 
2960  return {};
2961 }
2962 
2963 LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2964  // Composition.
2965  // %b = bitcast(%a) : A -> B
2966  // bitcast(%b) : B -> C
2967  // ===> bitcast(%a) : A -> C
2968  auto inputBitcast =
2969  dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2970  if (!inputBitcast)
2971  return failure();
2972  auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2973  inputBitcast.getInput());
2974  rewriter.replaceOp(op, bitcast);
2975  return success();
2976 }
2977 
2978 LogicalResult BitcastOp::verify() {
2979  if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2980  return this->emitOpError("Bitwidth of input must match result");
2981  return success();
2982 }
2983 
2984 //===----------------------------------------------------------------------===//
2985 // HierPathOp helpers.
2986 //===----------------------------------------------------------------------===//
2987 
2988 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2989  SmallVector<Attribute, 4> newPath;
2990  bool updateMade = false;
2991  for (auto nameRef : getNamepath()) {
2992  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2993  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2994  if (ref.getModule() == moduleToDrop)
2995  updateMade = true;
2996  else
2997  newPath.push_back(ref);
2998  } else {
2999  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3000  updateMade = true;
3001  else
3002  newPath.push_back(nameRef);
3003  }
3004  }
3005  if (updateMade)
3006  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3007  return updateMade;
3008 }
3009 
3010 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3011  SmallVector<Attribute, 4> newPath;
3012  bool updateMade = false;
3013  StringRef inlinedInstanceName = "";
3014  for (auto nameRef : getNamepath()) {
3015  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3016  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3017  if (ref.getModule() == moduleToDrop) {
3018  inlinedInstanceName = ref.getName().getValue();
3019  updateMade = true;
3020  } else if (!inlinedInstanceName.empty()) {
3021  newPath.push_back(hw::InnerRefAttr::get(
3022  ref.getModule(),
3023  StringAttr::get(getContext(), inlinedInstanceName + "_" +
3024  ref.getName().getValue())));
3025  inlinedInstanceName = "";
3026  } else
3027  newPath.push_back(ref);
3028  } else {
3029  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3030  updateMade = true;
3031  else
3032  newPath.push_back(nameRef);
3033  }
3034  }
3035  if (updateMade)
3036  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3037  return updateMade;
3038 }
3039 
3040 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3041  SmallVector<Attribute, 4> newPath;
3042  bool updateMade = false;
3043  for (auto nameRef : getNamepath()) {
3044  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3045  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3046  if (ref.getModule() == oldMod) {
3047  newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3048  updateMade = true;
3049  } else
3050  newPath.push_back(ref);
3051  } else {
3052  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3053  newPath.push_back(FlatSymbolRefAttr::get(newMod));
3054  updateMade = true;
3055  } else
3056  newPath.push_back(nameRef);
3057  }
3058  }
3059  if (updateMade)
3060  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3061  return updateMade;
3062 }
3063 
3064 bool HierPathOp::updateModuleAndInnerRef(
3065  StringAttr oldMod, StringAttr newMod,
3066  const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3067  auto fromRef = FlatSymbolRefAttr::get(oldMod);
3068  if (oldMod == newMod)
3069  return false;
3070 
3071  auto namepathNew = getNamepath().getValue().vec();
3072  bool updateMade = false;
3073  // Break from the loop if the module is found, since it can occur only once.
3074  for (auto &element : namepathNew) {
3075  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3076  if (innerRef.getModule() != oldMod)
3077  continue;
3078  auto symName = innerRef.getName();
3079  // Since the module got updated, the old innerRef symbol inside oldMod
3080  // should also be updated to the new symbol inside the newMod.
3081  auto to = innerSymRenameMap.find(symName);
3082  if (to != innerSymRenameMap.end())
3083  symName = to->second;
3084  updateMade = true;
3085  element = hw::InnerRefAttr::get(newMod, symName);
3086  break;
3087  }
3088  if (element != fromRef)
3089  continue;
3090 
3091  updateMade = true;
3092  element = FlatSymbolRefAttr::get(newMod);
3093  break;
3094  }
3095  if (updateMade)
3096  setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3097  return updateMade;
3098 }
3099 
3100 bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3101  SmallVector<Attribute, 4> newPath;
3102  bool updateMade = false;
3103  for (auto nameRef : getNamepath()) {
3104  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3105  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3106  if (ref.getModule() == atMod) {
3107  updateMade = true;
3108  if (includeMod)
3109  newPath.push_back(ref);
3110  } else
3111  newPath.push_back(ref);
3112  } else {
3113  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3114  updateMade = true;
3115  else
3116  newPath.push_back(nameRef);
3117  }
3118  if (updateMade)
3119  break;
3120  }
3121  if (updateMade)
3122  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3123  return updateMade;
3124 }
3125 
3126 /// Return just the module part of the namepath at a specific index.
3127 StringAttr HierPathOp::modPart(unsigned i) {
3128  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3129  .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3130  .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3131 }
3132 
3133 /// Return the root module.
3134 StringAttr HierPathOp::root() {
3135  assert(!getNamepath().empty());
3136  return modPart(0);
3137 }
3138 
3139 /// Return true if the NLA has the module in its path.
3140 bool HierPathOp::hasModule(StringAttr modName) {
3141  for (auto nameRef : getNamepath()) {
3142  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3143  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3144  if (ref.getModule() == modName)
3145  return true;
3146  } else {
3147  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3148  return true;
3149  }
3150  }
3151  return false;
3152 }
3153 
3154 /// Return true if the NLA has the InnerSym .
3155 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3156  for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3157  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3158  if (ref.getName() == symName && ref.getModule() == modName)
3159  return true;
3160 
3161  return false;
3162 }
3163 
3164 /// Return just the reference part of the namepath at a specific index. This
3165 /// will return an empty attribute if this is the leaf and the leaf is a module.
3166 StringAttr HierPathOp::refPart(unsigned i) {
3167  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3168  .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3169  .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3170 }
3171 
3172 /// Return the leaf reference. This returns an empty attribute if the leaf
3173 /// reference is a module.
3174 StringAttr HierPathOp::ref() {
3175  assert(!getNamepath().empty());
3176  return refPart(getNamepath().size() - 1);
3177 }
3178 
3179 /// Return the leaf module.
3180 StringAttr HierPathOp::leafMod() {
3181  assert(!getNamepath().empty());
3182  return modPart(getNamepath().size() - 1);
3183 }
3184 
3185 /// Returns true if this NLA targets an instance of a module (as opposed to
3186 /// an instance's port or something inside an instance).
3187 bool HierPathOp::isModule() { return !ref(); }
3188 
3189 /// Returns true if this NLA targets something inside a module (as opposed
3190 /// to a module or an instance of a module);
3191 bool HierPathOp::isComponent() { return (bool)ref(); }
3192 
3193 // Verify the HierPathOp.
3194 // 1. Iterate over the namepath.
3195 // 2. The namepath should be a valid instance path, specified either on a
3196 // module or a declaration inside a module.
3197 // 3. Each element in the namepath is an InnerRefAttr except possibly the
3198 // last element.
3199 // 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3200 // and the corresponding inner_sym on the instance.
3201 // 5. Make sure that the instance path is legal, by verifying the sequence of
3202 // instance and the expected module occurs as the next element in the path.
3203 // 6. The last element of the namepath, can be an InnerRefAttr on either a
3204 // module port or a declaration inside the module.
3205 // 7. The last element of the namepath can also be a module symbol.
3206 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3207  ArrayAttr expectedModuleNames = {};
3208  auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3209  if (!expectedModuleNames)
3210  return success();
3211  if (llvm::any_of(expectedModuleNames,
3212  [name](Attribute attr) { return attr == name; }))
3213  return success();
3214  auto diag = emitOpError() << "instance path is incorrect. Expected ";
3215  size_t n = expectedModuleNames.size();
3216  if (n != 1) {
3217  diag << "one of ";
3218  }
3219  for (size_t i = 0; i < n; ++i) {
3220  if (i != 0)
3221  diag << ((i + 1 == n) ? " or " : ", ");
3222  diag << cast<StringAttr>(expectedModuleNames[i]);
3223  }
3224  diag << ". Instead found: " << name;
3225  return diag;
3226  };
3227 
3228  if (!getNamepath() || getNamepath().empty())
3229  return emitOpError() << "the instance path cannot be empty";
3230  for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3231  hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3232  if (!innerRef)
3233  return emitOpError()
3234  << "the instance path can only contain inner sym reference"
3235  << ", only the leaf can refer to a module symbol";
3236 
3237  if (failed(checkExpectedModule(innerRef.getModule())))
3238  return failure();
3239 
3240  auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3241  if (!instOp)
3242  return emitOpError() << " module: " << innerRef.getModule()
3243  << " does not contain any instance with symbol: "
3244  << innerRef.getName();
3245  expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3246  }
3247 
3248  // The instance path has been verified. Now verify the last element.
3249  auto leafRef = getNamepath()[getNamepath().size() - 1];
3250  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3251  if (!ns.lookup(innerRef)) {
3252  return emitOpError() << " operation with symbol: " << innerRef
3253  << " was not found ";
3254  }
3255  if (failed(checkExpectedModule(innerRef.getModule())))
3256  return failure();
3257  } else if (failed(checkExpectedModule(
3258  cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3259  return failure();
3260  }
3261  return success();
3262 }
3263 
3264 void HierPathOp::print(OpAsmPrinter &p) {
3265  p << " ";
3266 
3267  // Print visibility if present.
3268  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3269  if (auto visibility =
3270  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3271  p << visibility.getValue() << ' ';
3272 
3273  p.printSymbolName(getSymName());
3274  p << " [";
3275  llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3276  if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3277  p.printSymbolName(ref.getModule().getValue());
3278  p << "::";
3279  p.printSymbolName(ref.getName().getValue());
3280  } else {
3281  p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3282  }
3283  });
3284  p << "]";
3285  p.printOptionalAttrDict(
3286  (*this)->getAttrs(),
3287  {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3288 }
3289 
3290 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3291  // Parse the visibility attribute.
3292  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3293 
3294  // Parse the symbol name.
3295  StringAttr symName;
3296  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3297  result.attributes))
3298  return failure();
3299 
3300  // Parse the namepath.
3301  SmallVector<Attribute> namepath;
3302  if (parser.parseCommaSeparatedList(
3303  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3304  auto loc = parser.getCurrentLocation();
3305  SymbolRefAttr ref;
3306  if (parser.parseAttribute(ref))
3307  return failure();
3308 
3309  // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3310  auto pathLength = ref.getNestedReferences().size();
3311  if (pathLength == 0)
3312  namepath.push_back(
3313  FlatSymbolRefAttr::get(ref.getRootReference()));
3314  else if (pathLength == 1)
3315  namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3316  ref.getLeafReference()));
3317  else
3318  return parser.emitError(loc,
3319  "only one nested reference is allowed");
3320  return success();
3321  }))
3322  return failure();
3323  result.addAttribute("namepath",
3324  ArrayAttr::get(parser.getContext(), namepath));
3325 
3326  if (parser.parseOptionalAttrDict(result.attributes))
3327  return failure();
3328 
3329  return success();
3330 }
3331 
3332 //===----------------------------------------------------------------------===//
3333 // TriggeredOp
3334 //===----------------------------------------------------------------------===//
3335 
3336 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3337  EventControlAttr event, Value trigger,
3338  ValueRange inputs) {
3339  odsState.addOperands(trigger);
3340  odsState.addOperands(inputs);
3341  odsState.addAttribute(getEventAttrName(odsState.name), event);
3342  auto *r = odsState.addRegion();
3343  Block *b = new Block();
3344  r->push_back(b);
3345 
3346  llvm::SmallVector<Location> argLocs;
3347  llvm::transform(inputs, std::back_inserter(argLocs),
3348  [&](Value v) { return v.getLoc(); });
3349  b->addArguments(inputs.getTypes(), argLocs);
3350 }
3351 
3352 //===----------------------------------------------------------------------===//
3353 // TableGen generated logic.
3354 //===----------------------------------------------------------------------===//
3355 
3356 // Provide the autogenerated implementation guts for the Op classes.
3357 #define GET_OP_CLASSES
3358 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3916
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition: HWOps.cpp:1072
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition: HWOps.cpp:487
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition: HWOps.cpp:1017
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2118
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:1852
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition: HWOps.cpp:1192
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition: HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition: HWOps.cpp:1009
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
Definition: HWOps.cpp:543
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2513
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition: HWOps.cpp:680
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition: HWOps.cpp:2082
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition: HWOps.cpp:581
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition: HWOps.cpp:1756
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition: HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition: HWOps.cpp:887
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2135
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2476
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition: HWOps.cpp:1262
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition: HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition: HWOps.cpp:1335
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1414
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition: HWOps.cpp:479
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition: HWOps.cpp:397
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition: HWOps.cpp:1923
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition: HWOps.cpp:895
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition: HWOps.cpp:2450
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1434
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition: HWOps.cpp:1770
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition: HWOps.cpp:343
Delimiter
Definition: HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition: HWOps.cpp:2052
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
Builder builder
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:72
void setOutput(unsigned i, Value v)
Definition: HWOps.cpp:241
Value getInput(unsigned i)
Definition: HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition: HWOps.h:118
llvm::SmallVector< Value > inputArgs
Definition: HWOps.h:117
llvm::StringMap< unsigned > outputIdx
Definition: HWOps.h:116
llvm::StringMap< unsigned > inputIdx
Definition: HWOps.h:116
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition: HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition: HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition: HWVisitors.h:27
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.)
Definition: FIRRTLOps.cpp:297
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:1037
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1828
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition: HWOps.h:123
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
Definition: HWOps.cpp:59
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition: HWOps.cpp:36
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:515
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition: HWOps.cpp:202
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
Definition: HWOps.cpp:221
bool isValidIndexBitWidth(Value index, Value array)
Definition: HWOps.cpp:48
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:102
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition: HWOps.cpp:509
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition: HWOps.cpp:533
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition: hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
mlir::Type type
Definition: HWTypes.h:30
mlir::StringAttr name
Definition: HWTypes.h:29
This holds the name, type, direction of a module's ports.