CIRCT  20.0.0git
HWOps.cpp
Go to the documentation of this file.
1 //===- HWOps.cpp - Implement the HW operations ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implement the HW ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/HW/HWOps.h"
23 #include "circt/Support/Naming.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
30 
31 using namespace circt;
32 using namespace hw;
33 using mlir::TypedAttr;
34 
35 /// Flip a port direction.
37  switch (direction) {
44  }
45  llvm_unreachable("unknown PortDirection");
46 }
47 
48 bool hw::isValidIndexBitWidth(Value index, Value array) {
49  hw::ArrayType arrayType =
50  dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51  assert(arrayType && "expected array type");
52  unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53  auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54  return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55  : indexWidth == requiredWidth;
56 }
57 
58 /// Return true if the specified operation is a combinational logic op.
59 bool hw::isCombinational(Operation *op) {
60  struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61  bool visitInvalidTypeOp(Operation *op) { return false; }
62  bool visitUnhandledTypeOp(Operation *op) { return true; }
63  };
64 
65  return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66  IsCombClassifier().dispatchTypeOpVisitor(op);
67 }
68 
69 static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70  // A struct extract of a struct create -> corresponding struct create operand.
71  if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72  return structCreate.getOperand(fieldIndex);
73  }
74 
75  // Extracting injected field -> corresponding field
76  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77  if (structInject.getFieldIndex() != fieldIndex)
78  return {};
79  return structInject.getNewValue();
80  }
81  return {};
82 }
83 
84 static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85  ArrayRef<Attribute> attrs) {
86  if (attrs.empty())
87  return ArrayAttr::get(context, {});
88  bool empty = true;
89  for (auto a : attrs)
90  if (a && !cast<DictionaryAttr>(a).empty()) {
91  empty = false;
92  break;
93  }
94  if (empty)
95  return ArrayAttr::get(context, {});
96  return ArrayAttr::get(context, attrs);
97 }
98 
99 /// Get a special name to use when printing the entry block arguments of the
100 /// region contained by an operation in this dialect.
101 static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102  OpAsmSetValueNameFn setNameFn) {
103  if (region.empty())
104  return;
105  // Assign port names to the bbargs.
106  auto module = cast<HWModuleOp>(region.getParentOp());
107 
108  auto *block = &region.front();
109  for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110  auto name = module.getInputName(i);
111  // Let mlir deterministically convert names to valid identifiers
112  setNameFn(block->getArgument(i), name);
113  }
114 }
115 
/// How a list is delimited in the textual syntax.
enum class Delimiter {
  None,                ///< No enclosing delimiters.
  Paren,               ///< List enclosed in `(` `)`.
  OptionalLessGreater, ///< List enclosed in `<` `>`, or absent entirely.
};
121 
122 /// Check parameter specified by `value` to see if it is valid according to the
123 /// module's parameters. If not, emit an error to the diagnostic provided as an
124 /// argument to the lambda 'instanceError' and return failure, otherwise return
125 /// success.
126 ///
127 /// If `disallowParamRefs` is true, then parameter references are not allowed.
129  Attribute value, ArrayAttr moduleParameters,
130  const instance_like_impl::EmitErrorFn &instanceError,
131  bool disallowParamRefs) {
132  // Literals are always ok. Their types are already known to match
133  // expectations.
134  if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135  isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
136  return success();
137 
138  // Check both subexpressions of an expression.
139  if (auto expr = dyn_cast<ParamExprAttr>(value)) {
140  for (auto op : expr.getOperands())
141  if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142  disallowParamRefs)))
143  return failure();
144  return success();
145  }
146 
147  // Parameter references need more analysis to make sure they are valid within
148  // this module.
149  if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150  auto nameAttr = parameterRef.getName();
151 
152  // Don't allow references to parameters from the default values of a
153  // parameter list.
154  if (disallowParamRefs) {
155  instanceError([&](auto &diag) {
156  diag << "parameter " << nameAttr
157  << " cannot be used as a default value for a parameter";
158  return false;
159  });
160  return failure();
161  }
162 
163  // Find the corresponding attribute in the module.
164  for (auto param : moduleParameters) {
165  auto paramAttr = cast<ParamDeclAttr>(param);
166  if (paramAttr.getName() != nameAttr)
167  continue;
168 
169  // If the types match then the reference is ok.
170  if (paramAttr.getType() == parameterRef.getType())
171  return success();
172 
173  instanceError([&](auto &diag) {
174  diag << "parameter " << nameAttr << " used with type "
175  << parameterRef.getType() << "; should have type "
176  << paramAttr.getType();
177  return true;
178  });
179  return failure();
180  }
181 
182  instanceError([&](auto &diag) {
183  diag << "use of unknown parameter " << nameAttr;
184  return true;
185  });
186  return failure();
187  }
188 
189  instanceError([&](auto &diag) {
190  diag << "invalid parameter value " << value;
191  return false;
192  });
193  return failure();
194 }
195 
196 /// Check parameter specified by `value` to see if it is valid within the scope
197 /// of the specified module `module`. If not, emit an error at the location of
198 /// `usingOp` and return failure, otherwise return success. If `usingOp` is
199 /// null, then no diagnostic is generated.
200 ///
201 /// If `disallowParamRefs` is true, then parameter references are not allowed.
202 LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203  Operation *usingOp,
204  bool disallowParamRefs) {
206  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
207  if (usingOp) {
208  auto diag = usingOp->emitOpError();
209  if (fn(diag))
210  diag.attachNote(module->getLoc()) << "module declared here";
211  }
212  };
213 
214  return checkParameterInContext(value,
215  module->getAttrOfType<ArrayAttr>("parameters"),
216  emitError, disallowParamRefs);
217 }
218 
219 /// Return true if the specified attribute tree is made up of nodes that are
220 /// valid in a parameter expression.
221 bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222  return succeeded(checkParameterInContext(attr, module, nullptr, false));
223 }
224 
226  const ModulePortInfo &info,
227  Region &bodyRegion)
228  : info(info) {
229  inputArgs.resize(info.sizeInputs());
230  for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
231  inputIdx[info.at(i).name.str()] = i;
232  inputArgs[i] = barg;
233  }
234 
235  outputOperands.resize(info.sizeOutputs());
236  for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237  outputIdx[outputInfo.name.str()] = i;
238  }
239 }
240 
241 void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242  assert(outputOperands.size() > i && "invalid output index");
243  assert(outputOperands[i] == Value() && "output already set");
244  outputOperands[i] = v;
245 }
246 
248  assert(inputArgs.size() > i && "invalid input index");
249  return inputArgs[i];
250 }
251 Value HWModulePortAccessor::getInput(StringRef name) {
252  return getInput(inputIdx.find(name.str())->second);
253 }
254 void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255  setOutput(outputIdx.find(name.str())->second, v);
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // ConstantOp
260 //===----------------------------------------------------------------------===//
261 
262 void ConstantOp::print(OpAsmPrinter &p) {
263  p << " ";
264  p.printAttribute(getValueAttr());
265  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
266 }
267 
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269  IntegerAttr valueAttr;
270 
271  if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
272  parser.parseOptionalAttrDict(result.attributes))
273  return failure();
274 
275  result.addTypes(valueAttr.getType());
276  return success();
277 }
278 
279 LogicalResult ConstantOp::verify() {
280  // If the result type has a bitwidth, then the attribute must match its width.
281  if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
282  return emitError(
283  "hw.constant attribute bitwidth doesn't match return type");
284 
285  return success();
286 }
287 
288 /// Build a ConstantOp from an APInt, infering the result type from the
289 /// width of the APInt.
290 void ConstantOp::build(OpBuilder &builder, OperationState &result,
291  const APInt &value) {
292 
293  auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
294  auto attr = builder.getIntegerAttr(type, value);
295  return build(builder, result, type, attr);
296 }
297 
298 /// Build a ConstantOp from an APInt, infering the result type from the
299 /// width of the APInt.
300 void ConstantOp::build(OpBuilder &builder, OperationState &result,
301  IntegerAttr value) {
302  return build(builder, result, value.getType(), value);
303 }
304 
305 /// This builder allows construction of small signed integers like 0, 1, -1
306 /// matching a specified MLIR IntegerType. This shouldn't be used for general
307 /// constant folding because it only works with values that can be expressed in
308 /// an int64_t. Use APInt's instead.
309 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
310  int64_t value) {
311  auto numBits = cast<IntegerType>(type).getWidth();
312  build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true));
313 }
314 
316  function_ref<void(Value, StringRef)> setNameFn) {
317  auto intTy = getType();
318  auto intCst = getValue();
319 
320  // Sugar i1 constants with 'true' and 'false'.
321  if (cast<IntegerType>(intTy).getWidth() == 1)
322  return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
323 
324  // Otherwise, build a complex name with the value and type.
325  SmallVector<char, 32> specialNameBuffer;
326  llvm::raw_svector_ostream specialName(specialNameBuffer);
327  specialName << 'c' << intCst << '_' << intTy;
328  setNameFn(getResult(), specialName.str());
329 }
330 
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
332  assert(adaptor.getOperands().empty() && "constant has no operands");
333  return getValueAttr();
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // WireOp
338 //===----------------------------------------------------------------------===//
339 
340 /// Check whether an operation has any additional attributes set beyond its
341 /// standard list of attributes returned by `getAttributeNames`.
342 template <class Op>
343 static bool hasAdditionalAttributes(Op op,
344  ArrayRef<StringRef> ignoredAttrs = {}) {
345  auto names = op.getAttributeNames();
346  llvm::SmallDenseSet<StringRef> nameSet;
347  nameSet.reserve(names.size() + ignoredAttrs.size());
348  nameSet.insert(names.begin(), names.end());
349  nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350  return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
351  return !nameSet.contains(namedAttr.getName());
352  });
353 }
354 
356  // If the wire has an optional 'name' attribute, use it.
357  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
358  if (nameAttr && !nameAttr.getValue().empty())
359  setNameFn(getResult(), nameAttr.getValue());
360 }
361 
362 std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
363 
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
365  // If the wire has no additional attributes, no name, and no symbol, just
366  // forward its input.
367  if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
368  !getInnerSymAttr())
369  return getInput();
370  return {};
371 }
372 
373 LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
374  // Block if the wire has any attributes.
375  if (hasAdditionalAttributes(wire, {"sv.namehint"}))
376  return failure();
377 
378  // If the wire has a symbol, then we can't delete it.
379  if (wire.getInnerSymAttr())
380  return failure();
381 
382  // If the wire has a name or an `sv.namehint` attribute, propagate it as an
383  // `sv.namehint` to the expression.
384  if (auto *inputOp = wire.getInput().getDefiningOp())
385  if (auto name = chooseName(wire, inputOp))
386  rewriter.modifyOpInPlace(inputOp,
387  [&] { inputOp->setAttr("sv.namehint", name); });
388 
389  rewriter.replaceOp(wire, wire.getInput());
390  return success();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // AggregateConstantOp
395 //===----------------------------------------------------------------------===//
396 
397 static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
398  // If this is a type alias, get the underlying type.
399  if (auto typeAlias = dyn_cast<TypeAliasType>(type))
400  type = typeAlias.getCanonicalType();
401 
402  if (auto structType = dyn_cast<StructType>(type)) {
403  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
404  if (!arrayAttr)
405  return op->emitOpError("expected array attribute for constant of type ")
406  << type;
407  if (structType.getElements().size() != arrayAttr.size())
408  return op->emitOpError("array attribute (")
409  << arrayAttr.size() << ") has wrong size for struct constant ("
410  << structType.getElements().size() << ")";
411 
412  for (auto [attr, fieldInfo] :
413  llvm::zip(arrayAttr.getValue(), structType.getElements())) {
414  if (failed(checkAttributes(op, attr, fieldInfo.type)))
415  return failure();
416  }
417  } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
418  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
419  if (!arrayAttr)
420  return op->emitOpError("expected array attribute for constant of type ")
421  << type;
422  if (arrayType.getNumElements() != arrayAttr.size())
423  return op->emitOpError("array attribute (")
424  << arrayAttr.size() << ") has wrong size for array constant ("
425  << arrayType.getNumElements() << ")";
426 
427  auto elementType = arrayType.getElementType();
428  for (auto attr : arrayAttr.getValue()) {
429  if (failed(checkAttributes(op, attr, elementType)))
430  return failure();
431  }
432  } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
433  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
434  if (!arrayAttr)
435  return op->emitOpError("expected array attribute for constant of type ")
436  << type;
437  auto elementType = arrayType.getElementType();
438  if (arrayType.getNumElements() != arrayAttr.size())
439  return op->emitOpError("array attribute (")
440  << arrayAttr.size()
441  << ") has wrong size for unpacked array constant ("
442  << arrayType.getNumElements() << ")";
443 
444  for (auto attr : arrayAttr.getValue()) {
445  if (failed(checkAttributes(op, attr, elementType)))
446  return failure();
447  }
448  } else if (auto enumType = dyn_cast<EnumType>(type)) {
449  auto stringAttr = dyn_cast<StringAttr>(attr);
450  if (!stringAttr)
451  return op->emitOpError("expected string attribute for constant of type ")
452  << type;
453  } else if (auto intType = dyn_cast<IntegerType>(type)) {
454  // Check the attribute kind is correct.
455  auto intAttr = dyn_cast<IntegerAttr>(attr);
456  if (!intAttr)
457  return op->emitOpError("expected integer attribute for constant of type ")
458  << type;
459  // Check the bitwidth is correct.
460  if (intAttr.getValue().getBitWidth() != intType.getWidth())
461  return op->emitOpError("hw.constant attribute bitwidth "
462  "doesn't match return type");
463  } else if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
464  if (typedAttr.getType() != type)
465  return op->emitOpError("typed attr doesn't match the return type ")
466  << type;
467  } else {
468  return op->emitOpError("unknown element type ") << type;
469  }
470  return success();
471 }
472 
473 LogicalResult AggregateConstantOp::verify() {
474  return checkAttributes(*this, getFieldsAttr(), getType());
475 }
476 
477 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
478 
479 //===----------------------------------------------------------------------===//
480 // ParamValueOp
481 //===----------------------------------------------------------------------===//
482 
483 static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
484  Type &resultType) {
485  if (p.parseType(resultType) || p.parseEqual() ||
486  p.parseAttribute(value, resultType))
487  return failure();
488  return success();
489 }
490 
491 static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
492  Type resultType) {
493  p << resultType << " = ";
494  p.printAttributeWithoutType(value);
495 }
496 
497 LogicalResult ParamValueOp::verify() {
498  // Check that the attribute expression is valid in this module.
500  getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
501 }
502 
503 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
504  assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
505  return getValueAttr();
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // HWModuleOp
510 //===----------------------------------------------------------------------===//
511 
512 /// Return true if isAnyModule or instance.
513 bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
514  return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
515 }
516 
517 /// Return the signature for a module as a function type from the module itself
518 /// or from an hw::InstanceOp.
519 FunctionType hw::getModuleType(Operation *moduleOrInstance) {
520  return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
521  .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
522  SmallVector<Type> inputs(instance->getOperandTypes());
523  SmallVector<Type> results(instance->getResultTypes());
524  return FunctionType::get(instance->getContext(), inputs, results);
525  })
526  .Case<HWModuleLike>(
527  [](auto mod) { return mod.getHWModuleType().getFuncType(); })
528  .Default([](Operation *op) {
529  return cast<FunctionType>(
530  cast<mlir::FunctionOpInterface>(op).getFunctionType());
531  });
532 }
533 
534 /// Return the name to use for the Verilog module that we're referencing
535 /// here. This is typically the symbol, but can be overridden with the
536 /// verilogName attribute.
537 StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
538  auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
539  if (nameAttr)
540  return nameAttr;
541 
542  return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
543 }
544 
545 template <typename ModuleTy>
546 static void
547 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
548  const ModulePortInfo &ports, ArrayAttr parameters,
549  ArrayRef<NamedAttribute> attributes, StringAttr comment) {
550  using namespace mlir::function_interface_impl;
551 
552  // Add an attribute for the name.
553  result.addAttribute(SymbolTable::getSymbolAttrName(), name);
554 
555  SmallVector<Attribute> perPortAttrs;
556  SmallVector<ModulePort> portTypes;
557 
558  for (auto elt : ports) {
559  portTypes.push_back(elt);
560  llvm::SmallVector<NamedAttribute> portAttrs;
561  if (elt.attrs)
562  llvm::copy(elt.attrs, std::back_inserter(portAttrs));
563  perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
564  }
565 
566  // Allow clients to pass in null for the parameters list.
567  if (!parameters)
568  parameters = builder.getArrayAttr({});
569 
570  // Record the argument and result types as an attribute.
571  auto type = ModuleType::get(builder.getContext(), portTypes);
572  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
573  TypeAttr::get(type));
574  result.addAttribute("per_port_attrs",
575  arrayOrEmpty(builder.getContext(), perPortAttrs));
576  result.addAttribute("parameters", parameters);
577  if (!comment)
578  comment = builder.getStringAttr("");
579  result.addAttribute("comment", comment);
580  result.addAttributes(attributes);
581  result.addRegion();
582 }
583 
584 /// Internal implementation of argument/result insertion and removal on modules.
585 static void modifyModuleArgs(
586  MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
587  ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
588  ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
589  ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
590  SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
591  SmallVector<Location> &newArgLocs, Block *body = nullptr) {
592 
593 #ifndef NDEBUG
594  // Check that the `insertArgs` and `removeArgs` indices are in ascending
595  // order.
596  assert(llvm::is_sorted(insertArgs,
597  [](auto &a, auto &b) { return a.first < b.first; }) &&
598  "insertArgs must be in ascending order");
599  assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
600  "removeArgs must be in ascending order");
601 #endif
602 
603  auto oldArgCount = oldArgTypes.size();
604  auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
605  assert((int)newArgCount >= 0);
606 
607  newArgNames.reserve(newArgCount);
608  newArgTypes.reserve(newArgCount);
609  newArgAttrs.reserve(newArgCount);
610  newArgLocs.reserve(newArgCount);
611 
612  auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
613  auto emptyDictAttr = DictionaryAttr::get(context, {});
614  auto unknownLoc = UnknownLoc::get(context);
615 
616  BitVector erasedIndices;
617  if (body)
618  erasedIndices.resize(oldArgCount + insertArgs.size());
619 
620  for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
621  // Insert new ports at this position.
622  while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
623  auto port = insertArgs[0].second;
624  if (port.dir == ModulePort::Direction::InOut &&
625  !isa<InOutType>(port.type))
626  port.type = InOutType::get(port.type);
627  auto sym = port.getSym();
628  Attribute attr =
629  (sym && !sym.empty())
630  ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
631  : emptyDictAttr;
632  newArgNames.push_back(port.name);
633  newArgTypes.push_back(port.type);
634  newArgAttrs.push_back(attr);
635  insertArgs = insertArgs.drop_front();
636  LocationAttr loc = port.loc ? port.loc : unknownLoc;
637  newArgLocs.push_back(loc);
638  if (body)
639  body->insertArgument(idx++, port.type, loc);
640  }
641  if (argIdx == oldArgCount)
642  break;
643 
644  // Migrate the old port at this position.
645  bool removed = false;
646  while (!removeArgs.empty() && removeArgs[0] == argIdx) {
647  removeArgs = removeArgs.drop_front();
648  removed = true;
649  }
650 
651  if (removed) {
652  if (body)
653  erasedIndices.set(idx);
654  } else {
655  newArgNames.push_back(oldArgNames[argIdx]);
656  newArgTypes.push_back(oldArgTypes[argIdx]);
657  newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
658  : oldArgAttrs[argIdx]);
659  newArgLocs.push_back(oldArgLocs[argIdx]);
660  }
661  }
662 
663  if (body)
664  body->eraseArguments(erasedIndices);
665 
666  assert(newArgNames.size() == newArgCount);
667  assert(newArgTypes.size() == newArgCount);
668  assert(newArgAttrs.size() == newArgCount);
669  assert(newArgLocs.size() == newArgCount);
670 }
671 
672 /// Insert and remove ports of a module. The insertion and removal indices must
673 /// be in ascending order. The indices refer to the port positions before any
674 /// insertion or removal occurs. Ports inserted at the same index will appear in
675 /// the module in the same order as they were listed in the `insert*` array.
676 ///
677 /// The operation must be any of the module-like operations.
678 ///
679 /// This is marked deprecated as it's only used from HandshakeToHW and
680 /// PortConverter and is likely broken and not currently tested. Users of this
681 /// are still written dealing with input and output ports separately, which is
682 /// an old and broken style.
683 [[deprecated]] static void
684 modifyModulePorts(Operation *op,
685  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
686  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
687  ArrayRef<unsigned> removeInputs,
688  ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
689  auto moduleOp = cast<HWModuleLike>(op);
690  auto *context = moduleOp.getContext();
691 
692  // Dig up the old argument and result data.
693  auto oldArgNames = moduleOp.getInputNames();
694  auto oldArgTypes = moduleOp.getInputTypes();
695  auto oldArgAttrs = moduleOp.getAllInputAttrs();
696  auto oldArgLocs = moduleOp.getInputLocs();
697 
698  auto oldResultNames = moduleOp.getOutputNames();
699  auto oldResultTypes = moduleOp.getOutputTypes();
700  auto oldResultAttrs = moduleOp.getAllOutputAttrs();
701  auto oldResultLocs = moduleOp.getOutputLocs();
702 
703  // Modify the ports.
704  SmallVector<Attribute> newArgNames, newResultNames;
705  SmallVector<Type> newArgTypes, newResultTypes;
706  SmallVector<Attribute> newArgAttrs, newResultAttrs;
707  SmallVector<Location> newArgLocs, newResultLocs;
708 
709  modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
710  oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
711  newArgTypes, newArgAttrs, newArgLocs, body);
712 
713  modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
714  oldResultTypes, oldResultAttrs, oldResultLocs,
715  newResultNames, newResultTypes, newResultAttrs,
716  newResultLocs);
717 
718  // Update the module operation types and attributes.
719  auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
720  auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
721  moduleOp.setHWModuleType(modty);
722  moduleOp.setAllInputAttrs(newArgAttrs);
723  moduleOp.setAllOutputAttrs(newResultAttrs);
724 
725  newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
726  moduleOp.setAllPortLocs(newArgLocs);
727 }
728 
729 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
730  StringAttr name, const ModulePortInfo &ports,
731  ArrayAttr parameters,
732  ArrayRef<NamedAttribute> attributes, StringAttr comment,
733  bool shouldEnsureTerminator) {
734  buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
735  comment);
736 
737  // Create a region and a block for the body.
738  auto *bodyRegion = result.regions[0].get();
739  Block *body = new Block();
740  bodyRegion->push_back(body);
741 
742  // Add arguments to the body block.
743  auto unknownLoc = builder.getUnknownLoc();
744  for (auto port : ports.getInputs()) {
745  auto loc = port.loc ? Location(port.loc) : unknownLoc;
746  auto type = port.type;
747  if (port.isInOut() && !isa<InOutType>(type))
748  type = InOutType::get(type);
749  body->addArgument(type, loc);
750  }
751 
752  // Add result ports attribute.
753  auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
754  SmallVector<Attribute> resultLocs;
755  for (auto port : ports.getOutputs())
756  resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
757  result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
758 
759  if (shouldEnsureTerminator)
760  HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
761 }
762 
763 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
764  StringAttr name, ArrayRef<PortInfo> ports,
765  ArrayAttr parameters,
766  ArrayRef<NamedAttribute> attributes,
767  StringAttr comment) {
768  build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
769  comment);
770 }
771 
772 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
773  StringAttr name, const ModulePortInfo &ports,
774  HWModuleBuilder modBuilder, ArrayAttr parameters,
775  ArrayRef<NamedAttribute> attributes,
776  StringAttr comment) {
777  build(builder, odsState, name, ports, parameters, attributes, comment,
778  /*shouldEnsureTerminator=*/false);
779  auto *bodyRegion = odsState.regions[0].get();
780  OpBuilder::InsertionGuard guard(builder);
781  auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
782  builder.setInsertionPointToEnd(&bodyRegion->front());
783  modBuilder(builder, accessor);
784  // Create output operands.
785  llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
786  builder.create<hw::OutputOp>(odsState.location, outputOperands);
787 }
788 
789 void HWModuleOp::modifyPorts(
790  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
791  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
792  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
793  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
794  eraseOutputs);
795 }
796 
797 /// Return the name to use for the Verilog module that we're referencing
798 /// here. This is typically the symbol, but can be overridden with the
799 /// verilogName attribute.
801  if (auto vName = getVerilogNameAttr())
802  return vName;
803 
804  return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
805 }
806 
808  if (auto vName = getVerilogNameAttr()) {
809  return vName;
810  }
811  return (*this)->getAttrOfType<StringAttr>(
812  ::mlir::SymbolTable::getSymbolAttrName());
813 }
814 
815 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
816  StringAttr name, const ModulePortInfo &ports,
817  StringRef verilogName, ArrayAttr parameters,
818  ArrayRef<NamedAttribute> attributes) {
819  buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
820  attributes, {});
821 
822  // Add the port locations.
823  LocationAttr unknownLoc = builder.getUnknownLoc();
824  SmallVector<Attribute> portLocs;
825  for (auto elt : ports)
826  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
827  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
828 
829  if (!verilogName.empty())
830  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
831 }
832 
833 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
834  StringAttr name, ArrayRef<PortInfo> ports,
835  StringRef verilogName, ArrayAttr parameters,
836  ArrayRef<NamedAttribute> attributes) {
837  build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
838  attributes);
839 }
840 
841 void HWModuleExternOp::modifyPorts(
842  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
843  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
844  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
845  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
846  eraseOutputs);
847 }
848 
849 void HWModuleExternOp::appendOutputs(
850  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
851 
852 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
853  FlatSymbolRefAttr genKind, StringAttr name,
854  const ModulePortInfo &ports,
855  StringRef verilogName, ArrayAttr parameters,
856  ArrayRef<NamedAttribute> attributes) {
857  buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
858  attributes, {});
859  // Add the port locations.
860  LocationAttr unknownLoc = builder.getUnknownLoc();
861  SmallVector<Attribute> portLocs;
862  for (auto elt : ports)
863  portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
864  result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
865 
866  result.addAttribute("generatorKind", genKind);
867  if (!verilogName.empty())
868  result.addAttribute("verilogName", builder.getStringAttr(verilogName));
869 }
870 
871 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
872  FlatSymbolRefAttr genKind, StringAttr name,
873  ArrayRef<PortInfo> ports, StringRef verilogName,
874  ArrayAttr parameters,
875  ArrayRef<NamedAttribute> attributes) {
876  build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
877  parameters, attributes);
878 }
879 
880 void HWModuleGeneratedOp::modifyPorts(
881  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
882  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
883  ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
884  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
885  eraseOutputs);
886 }
887 
888 void HWModuleGeneratedOp::appendOutputs(
889  ArrayRef<std::pair<StringAttr, Value>> outputs) {}
890 
891 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
892  for (auto &argAttr : attrs)
893  if (argAttr.getName() == name)
894  return true;
895  return false;
896 }
897 
898 template <typename ModuleTy>
899 static ParseResult parseHWModuleOp(OpAsmParser &parser,
900  OperationState &result) {
901 
902  using namespace mlir::function_interface_impl;
903  auto builder = parser.getBuilder();
904  auto loc = parser.getCurrentLocation();
905 
906  // Parse the visibility attribute.
907  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
908 
909  // Parse the name as a symbol.
910  StringAttr nameAttr;
911  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
912  result.attributes))
913  return failure();
914 
915  // Parse the generator information.
916  FlatSymbolRefAttr kindAttr;
917  if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
918  if (parser.parseComma() ||
919  parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
920  return failure();
921  }
922  }
923 
924  // Parse the parameters.
925  ArrayAttr parameters;
926  if (parseOptionalParameterList(parser, parameters))
927  return failure();
928 
929  SmallVector<module_like_impl::PortParse> ports;
930  TypeAttr modType;
931  if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
932  return failure();
933 
934  // Parse the attribute dict.
935  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
936  return failure();
937 
938  if (hasAttribute("parameters", result.attributes)) {
939  parser.emitError(loc, "explicit `parameters` attributes not allowed");
940  return failure();
941  }
942 
943  result.addAttribute("parameters", parameters);
944  result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
945 
946  // Convert the specified array of dictionary attrs (which may have null
947  // entries) to an ArrayAttr of dictionaries.
948  SmallVector<Attribute> attrs;
949  for (auto &port : ports)
950  attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
951  // Add the attributes to the ports.
952  auto nonEmptyAttrsFn = [](Attribute attr) {
953  return attr && !cast<DictionaryAttr>(attr).empty();
954  };
955  if (llvm::any_of(attrs, nonEmptyAttrsFn))
956  result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
957  builder.getArrayAttr(attrs));
958 
959  // Add the port locations.
960  auto unknownLoc = builder.getUnknownLoc();
961  auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
962  return attr && cast<Location>(attr) != unknownLoc;
963  };
964  SmallVector<Attribute> locs;
965  StringAttr portLocsAttrName;
966  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
967  // Plain modules only store the output port locations, as the input port
968  // locations will be stored in the basic block arguments.
969  portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
970  for (auto &port : ports)
971  if (port.direction == ModulePort::Direction::Output)
972  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
973  } else {
974  // All other modules store all port locations in a single array.
975  portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
976  for (auto &port : ports)
977  locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
978  }
979  if (llvm::any_of(locs, nonEmptyLocsFn))
980  result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
981 
982  // Add the entry block arguments.
983  SmallVector<OpAsmParser::Argument, 4> entryArgs;
984  for (auto &port : ports)
985  if (port.direction != ModulePort::Direction::Output)
986  entryArgs.push_back(port);
987 
988  // Parse the optional function body.
989  auto *body = result.addRegion();
990  if (std::is_same_v<ModuleTy, HWModuleOp>) {
991  if (parser.parseRegion(*body, entryArgs))
992  return failure();
993 
994  HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
995  }
996  return success();
997 }
998 
999 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1000  return parseHWModuleOp<HWModuleOp>(parser, result);
1001 }
1002 
1003 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1004  OperationState &result) {
1005  return parseHWModuleOp<HWModuleExternOp>(parser, result);
1006 }
1007 
1008 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1009  OperationState &result) {
1010  return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1011 }
1012 
1013 FunctionType getHWModuleOpType(Operation *op) {
1014  if (auto mod = dyn_cast<HWModuleLike>(op))
1015  return mod.getHWModuleType().getFuncType();
1016  return cast<FunctionType>(
1017  cast<mlir::FunctionOpInterface>(op).getFunctionType());
1018 }
1019 
1020 template <typename ModuleTy>
1021 static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1022  p << ' ';
1023  // Print the visibility of the module.
1024  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1025  if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1026  visibilityAttrName))
1027  p << visibility.getValue() << ' ';
1028 
1029  // Print the operation and the function name.
1030  p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1031  if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1032  p << ", ";
1033  p.printSymbolName(gen.getGeneratorKind());
1034  }
1035 
1036  // Print the parameter list if present.
1037  printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1038 
1040 
1041  SmallVector<StringRef, 3> omittedAttrs;
1042  if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1043  omittedAttrs.push_back("generatorKind");
1044  if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1045  omittedAttrs.push_back(mod.getResultLocsAttrName());
1046  else
1047  omittedAttrs.push_back(mod.getPortLocsAttrName());
1048  omittedAttrs.push_back(mod.getModuleTypeAttrName());
1049  omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1050  omittedAttrs.push_back(mod.getParametersAttrName());
1051  omittedAttrs.push_back(visibilityAttrName);
1052  if (auto cmt =
1053  mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1054  if (cmt.getValue().empty())
1055  omittedAttrs.push_back("comment");
1056 
1057  mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1058  omittedAttrs);
1059 }
1060 
1061 void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1062 void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1063 
1064 void HWModuleOp::print(OpAsmPrinter &p) {
1065  printModuleOp(p, *this);
1066 
1067  // Print the body if this is not an external function.
1068  Region &body = getBody();
1069  if (!body.empty()) {
1070  p << " ";
1071  p.printRegion(body, /*printEntryBlockArgs=*/false,
1072  /*printBlockTerminators=*/true);
1073  }
1074 }
1075 
1076 static LogicalResult verifyModuleCommon(HWModuleLike module) {
1077  assert(isa<HWModuleLike>(module) &&
1078  "verifier hook should only be called on modules");
1079 
1080  SmallPtrSet<Attribute, 4> paramNames;
1081 
1082  // Check parameter default values are sensible.
1083  for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1084  auto paramAttr = cast<ParamDeclAttr>(param);
1085 
1086  // Check that we don't have any redundant parameter names. These are
1087  // resolved by string name: reuse of the same name would cause ambiguities.
1088  if (!paramNames.insert(paramAttr.getName()).second)
1089  return module->emitOpError("parameter ")
1090  << paramAttr << " has the same name as a previous parameter";
1091 
1092  // Default values are allowed to be missing, check them if present.
1093  auto value = paramAttr.getValue();
1094  if (!value)
1095  continue;
1096 
1097  auto typedValue = dyn_cast<TypedAttr>(value);
1098  if (!typedValue)
1099  return module->emitOpError("parameter ")
1100  << paramAttr << " should have a typed value; has value " << value;
1101 
1102  if (typedValue.getType() != paramAttr.getType())
1103  return module->emitOpError("parameter ")
1104  << paramAttr << " should have type " << paramAttr.getType()
1105  << "; has type " << typedValue.getType();
1106 
1107  // Verify that this is a valid parameter value, disallowing parameter
1108  // references. We could allow parameters to refer to each other in the
1109  // future with lexical ordering if there is a need.
1110  if (failed(checkParameterInContext(value, module, module,
1111  /*disallowParamRefs=*/true)))
1112  return failure();
1113  }
1114  return success();
1115 }
1116 
1117 LogicalResult HWModuleOp::verify() {
1118  if (failed(verifyModuleCommon(*this)))
1119  return failure();
1120 
1121  auto type = getModuleType();
1122  auto *body = getBodyBlock();
1123 
1124  // Verify the number of block arguments.
1125  auto numInputs = type.getNumInputs();
1126  if (body->getNumArguments() != numInputs)
1127  return emitOpError("entry block must have")
1128  << numInputs << " arguments to match module signature";
1129 
1130  return success();
1131 }
1132 
1133 LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1134 
1135 std::pair<StringAttr, BlockArgument>
1136 HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1137  // Find a unique name for the wire.
1138  Namespace ns;
1139  auto ports = getPortList();
1140  for (auto port : ports)
1141  ns.newName(port.name.getValue());
1142  auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1143 
1144  Block *body = getBodyBlock();
1145 
1146  // Create a new port for the host clock.
1147  PortInfo port;
1148  port.name = nameAttr;
1150  port.type = ty;
1151  modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1152  body);
1153 
1154  // Add a new argument.
1155  return {nameAttr, body->getArgument(index)};
1156 }
1157 
1158 void HWModuleOp::insertOutputs(unsigned index,
1159  ArrayRef<std::pair<StringAttr, Value>> outputs) {
1160 
1161  auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1162  assert(index <= output->getNumOperands() && "invalid output index");
1163 
1164  // Rewrite the port list of the module.
1165  SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1166  for (auto &[name, value] : outputs) {
1167  PortInfo port;
1168  port.name = name;
1170  port.type = value.getType();
1171  indexedNewPorts.emplace_back(index, port);
1172  }
1173  modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1174  getBodyBlock());
1175 
1176  // Rewrite the output op.
1177  for (auto &[name, value] : outputs)
1178  output->insertOperands(index++, value);
1179 }
1180 
1181 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1182  return insertOutputs(getNumOutputPorts(), outputs);
1183 }
1184 
1185 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1186  mlir::OpAsmSetValueNameFn setNameFn) {
1187  getAsmBlockArgumentNamesImpl(region, setNameFn);
1188 }
1189 
1190 void HWModuleExternOp::getAsmBlockArgumentNames(
1191  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1192  getAsmBlockArgumentNamesImpl(region, setNameFn);
1193 }
1194 
1195 template <typename ModTy>
1196 static SmallVector<Location> getAllPortLocs(ModTy module) {
1197  auto locs = module.getPortLocs();
1198  if (locs) {
1199  SmallVector<Location> retval;
1200  retval.reserve(locs->size());
1201  for (auto l : *locs)
1202  retval.push_back(cast<Location>(l));
1203  // Either we have a length of 0 or the correct length
1204  assert(!locs->size() || locs->size() == module.getNumPorts());
1205  return retval;
1206  }
1207  return SmallVector<Location>(module.getNumPorts(),
1208  UnknownLoc::get(module.getContext()));
1209 }
1210 
1211 SmallVector<Location> HWModuleOp::getAllPortLocs() {
1212  SmallVector<Location> portLocs;
1213  portLocs.reserve(getNumPorts());
1214  auto resultLocs = getResultLocsAttr();
1215  unsigned inputCount = 0;
1216  auto modType = getModuleType();
1217  auto unknownLoc = UnknownLoc::get(getContext());
1218  auto *body = getBodyBlock();
1219  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1220  if (modType.isOutput(i)) {
1221  auto loc = resultLocs
1222  ? cast<Location>(
1223  resultLocs.getValue()[portLocs.size() - inputCount])
1224  : unknownLoc;
1225  portLocs.push_back(loc);
1226  } else {
1227  auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1228  portLocs.push_back(loc);
1229  ++inputCount;
1230  }
1231  }
1232  return portLocs;
1233 }
1234 
1235 SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1236  return ::getAllPortLocs(*this);
1237 }
1238 
1239 SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1240  return ::getAllPortLocs(*this);
1241 }
1242 
1243 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1244  SmallVector<Attribute> resultLocs;
1245  unsigned inputCount = 0;
1246  auto modType = getModuleType();
1247  auto *body = getBodyBlock();
1248  for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1249  if (modType.isOutput(i))
1250  resultLocs.push_back(locs[i]);
1251  else
1252  body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1253  }
1254  setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1255 }
1256 
1257 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1258  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1259 }
1260 
1261 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1262  setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1263 }
1264 
1265 template <typename ModTy>
1266 static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1267  auto numInputs = module.getNumInputPorts();
1268  SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1269  SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1270  auto oldType = module.getModuleType();
1271  SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1272  oldType.getPorts().end());
1273  for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1274  newPorts[i].name = cast<StringAttr>(names[i]);
1275  auto newType = ModuleType::get(module.getContext(), newPorts);
1276  module.setModuleType(newType);
1277 }
1278 
1279 void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1280  ::setAllPortNames(names, *this);
1281 }
1282 
1283 void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1284  ::setAllPortNames(names, *this);
1285 }
1286 
1287 void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1288  ::setAllPortNames(names, *this);
1289 }
1290 
1291 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1292  auto attrs = getPerPortAttrs();
1293  if (attrs && !attrs->empty())
1294  return attrs->getValue();
1295  return {};
1296 }
1297 
1298 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1299  auto attrs = getPerPortAttrs();
1300  if (attrs && !attrs->empty())
1301  return attrs->getValue();
1302  return {};
1303 }
1304 
1305 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1306  auto attrs = getPerPortAttrs();
1307  if (attrs && !attrs->empty())
1308  return attrs->getValue();
1309  return {};
1310 }
1311 
1312 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1313  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1314 }
1315 
1316 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1317  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1318 }
1319 
1320 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1321  setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1322 }
1323 
1324 void HWModuleOp::removeAllPortAttrs() {
1325  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1326 }
1327 
1328 void HWModuleExternOp::removeAllPortAttrs() {
1329  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1330 }
1331 
1332 void HWModuleGeneratedOp::removeAllPortAttrs() {
1333  setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1334 }
1335 
1336 // This probably does really unexpected stuff when you change the number of
1337 
1338 template <typename ModTy>
1339 static void setHWModuleType(ModTy &mod, ModuleType type) {
1340  auto argAttrs = mod.getAllInputAttrs();
1341  auto resAttrs = mod.getAllOutputAttrs();
1342  mod.setModuleTypeAttr(TypeAttr::get(type));
1343  unsigned newNumArgs = type.getNumInputs();
1344  unsigned newNumResults = type.getNumOutputs();
1345 
1346  auto emptyDict = DictionaryAttr::get(mod.getContext());
1347  argAttrs.resize(newNumArgs, emptyDict);
1348  resAttrs.resize(newNumResults, emptyDict);
1349 
1350  SmallVector<Attribute> attrs;
1351  attrs.append(argAttrs.begin(), argAttrs.end());
1352  attrs.append(resAttrs.begin(), resAttrs.end());
1353 
1354  if (attrs.empty())
1355  return mod.removeAllPortAttrs();
1356  mod.setAllPortAttrs(attrs);
1357 }
1358 
1359 void HWModuleOp::setHWModuleType(ModuleType type) {
1360  return ::setHWModuleType(*this, type);
1361 }
1362 
1363 void HWModuleExternOp::setHWModuleType(ModuleType type) {
1364  return ::setHWModuleType(*this, type);
1365 }
1366 
1367 void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1368  return ::setHWModuleType(*this, type);
1369 }
1370 
1371 /// Lookup the generator for the symbol. This returns null on
1372 /// invalid IR.
1373 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1374  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1375  return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1376 }
1377 
1378 LogicalResult
1379 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1380  auto *referencedKind =
1381  symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1382 
1383  if (referencedKind == nullptr)
1384  return emitError("Cannot find generator definition '")
1385  << getGeneratorKind() << "'";
1386 
1387  if (!isa<HWGeneratorSchemaOp>(referencedKind))
1388  return emitError("Symbol resolved to '")
1389  << referencedKind->getName()
1390  << "' which is not a HWGeneratorSchemaOp";
1391 
1392  auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1393  auto paramRef = referencedKindOp.getRequiredAttrs();
1394  auto dict = (*this)->getAttrDictionary();
1395  for (auto str : paramRef) {
1396  auto strAttr = dyn_cast<StringAttr>(str);
1397  if (!strAttr)
1398  return emitError("Unknown attribute type, expected a string");
1399  if (!dict.get(strAttr.getValue()))
1400  return emitError("Missing attribute '") << strAttr.getValue() << "'";
1401  }
1402 
1403  return success();
1404 }
1405 
1406 LogicalResult HWModuleGeneratedOp::verify() {
1407  return verifyModuleCommon(*this);
1408 }
1409 
1410 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1411  mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1412  getAsmBlockArgumentNamesImpl(region, setNameFn);
1413 }
1414 
1415 LogicalResult HWModuleOp::verifyBody() { return success(); }
1416 
1417 template <typename ModuleTy>
1418 static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1419  auto modTy = mod.getHWModuleType();
1420  auto emptyDict = DictionaryAttr::get(mod.getContext());
1421  SmallVector<PortInfo> retval;
1422  auto locs = mod.getAllPortLocs();
1423  for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1424  LocationAttr loc = locs[i];
1425  DictionaryAttr attrs =
1426  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1427  if (!attrs)
1428  attrs = emptyDict;
1429  retval.push_back({modTy.getPorts()[i],
1430  modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1431  : modTy.getInputIdForPortId(i),
1432  attrs, loc});
1433  }
1434  return retval;
1435 }
1436 
1437 template <typename ModuleTy>
1438 static PortInfo getPort(ModuleTy &mod, size_t idx) {
1439  auto modTy = mod.getHWModuleType();
1440  auto emptyDict = DictionaryAttr::get(mod.getContext());
1441  LocationAttr loc = mod.getPortLoc(idx);
1442  DictionaryAttr attrs =
1443  dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1444  if (!attrs)
1445  attrs = emptyDict;
1446  return {modTy.getPorts()[idx],
1447  modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1448  : modTy.getInputIdForPortId(idx),
1449  attrs, loc};
1450 }
1451 
1452 //===----------------------------------------------------------------------===//
1453 // InstanceOp
1454 //===----------------------------------------------------------------------===//
1455 
1456 /// Create a instance that refers to a known module.
1457 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1458  Operation *module, StringAttr name,
1459  ArrayRef<Value> inputs, ArrayAttr parameters,
1460  InnerSymAttr innerSym) {
1461  if (!parameters)
1462  parameters = builder.getArrayAttr({});
1463 
1464  auto mod = cast<hw::HWModuleLike>(module);
1465  auto argNames = builder.getArrayAttr(mod.getInputNames());
1466  auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1467 
1468  // Try to resolve the parameterized module type. If failed, use the module's
1469  // parmeterized type. If the client doesn't fix this error, the verifier will
1470  // fail.
1471  ModuleType modType = mod.getHWModuleType();
1472  FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1473  parameters, result.location, /*emitErrors=*/false);
1474  if (succeeded(resolvedModType))
1475  modType = *resolvedModType;
1476  FunctionType funcType = resolvedModType->getFuncType();
1477  build(builder, result, funcType.getResults(), name,
1478  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1479  argNames, resultNames, parameters, innerSym);
1480 }
1481 
1482 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1483  // Inner symbols on instance operations target the op not any result.
1484  return std::nullopt;
1485 }
1486 
1487 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1489  *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1490  getResultNames(), getParameters(), symbolTable);
1491 }
1492 
1493 LogicalResult InstanceOp::verify() {
1494  auto module = (*this)->getParentOfType<HWModuleOp>();
1495  if (!module)
1496  return success();
1497 
1498  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1500  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1501  auto diag = emitOpError();
1502  if (fn(diag))
1503  diag.attachNote(module->getLoc()) << "module declared here";
1504  };
1506  getParameters(), moduleParameters, emitError);
1507 }
1508 
1509 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1510  StringAttr instanceNameAttr;
1511  InnerSymAttr innerSym;
1512  FlatSymbolRefAttr moduleNameAttr;
1513  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1514  SmallVector<Type, 1> inputsTypes, allResultTypes;
1515  ArrayAttr argNames, resultNames, parameters;
1516  auto noneType = parser.getBuilder().getType<NoneType>();
1517 
1518  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1519  result.attributes))
1520  return failure();
1521 
1522  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1523  // Parsing an optional symbol name doesn't fail, so no need to check the
1524  // result.
1525  if (parser.parseCustomAttributeWithFallback(innerSym))
1526  return failure();
1527  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1528  }
1529 
1530  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1531  if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1532  result.attributes) ||
1533  parser.getCurrentLocation(&parametersLoc) ||
1534  parseOptionalParameterList(parser, parameters) ||
1535  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1536  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1537  result.operands) ||
1538  parser.parseArrow() ||
1539  parseOutputPortList(parser, allResultTypes, resultNames) ||
1540  parser.parseOptionalAttrDict(result.attributes)) {
1541  return failure();
1542  }
1543 
1544  result.addAttribute("argNames", argNames);
1545  result.addAttribute("resultNames", resultNames);
1546  result.addAttribute("parameters", parameters);
1547  result.addTypes(allResultTypes);
1548  return success();
1549 }
1550 
1551 void InstanceOp::print(OpAsmPrinter &p) {
1552  p << ' ';
1553  p.printAttributeWithoutType(getInstanceNameAttr());
1554  if (auto attr = getInnerSymAttr()) {
1555  p << " sym ";
1556  attr.print(p);
1557  }
1558  p << ' ';
1559  p.printAttributeWithoutType(getModuleNameAttr());
1560  printOptionalParameterList(p, *this, getParameters());
1561  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1562  getArgNames());
1563  p << " -> ";
1564  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1565 
1566  p.printOptionalAttrDict(
1567  (*this)->getAttrs(),
1568  /*elidedAttrs=*/{"instanceName",
1569  InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1570  "argNames", "resultNames", "parameters"});
1571 }
1572 
1573 //===----------------------------------------------------------------------===//
1574 // InstanceChoiceOp
1575 //===----------------------------------------------------------------------===//
1576 
1577 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1578  // Inner symbols on instance operations target the op not any result.
1579  return std::nullopt;
1580 }
1581 
1582 LogicalResult
1583 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1584  for (Attribute name : getModuleNamesAttr()) {
1586  *this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1587  getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1588  return failure();
1589  }
1590  }
1591  return success();
1592 }
1593 
1594 LogicalResult InstanceChoiceOp::verify() {
1595  auto module = (*this)->getParentOfType<HWModuleOp>();
1596  if (!module)
1597  return success();
1598 
1599  auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1601  [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1602  auto diag = emitOpError();
1603  if (fn(diag))
1604  diag.attachNote(module->getLoc()) << "module declared here";
1605  };
1607  getParameters(), moduleParameters, emitError);
1608 }
1609 
1610 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1611  OperationState &result) {
1612  StringAttr optionNameAttr;
1613  StringAttr instanceNameAttr;
1614  InnerSymAttr innerSym;
1615  SmallVector<Attribute> moduleNames;
1616  SmallVector<Attribute> caseNames;
1617  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1618  SmallVector<Type, 1> inputsTypes, allResultTypes;
1619  ArrayAttr argNames, resultNames, parameters;
1620  auto noneType = parser.getBuilder().getType<NoneType>();
1621 
1622  if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1623  result.attributes))
1624  return failure();
1625 
1626  if (succeeded(parser.parseOptionalKeyword("sym"))) {
1627  // Parsing an optional symbol name doesn't fail, so no need to check the
1628  // result.
1629  if (parser.parseCustomAttributeWithFallback(innerSym))
1630  return failure();
1631  result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1632  }
1633 
1634  if (parser.parseKeyword("option") ||
1635  parser.parseAttribute(optionNameAttr, noneType, "optionName",
1636  result.attributes))
1637  return failure();
1638 
1639  FlatSymbolRefAttr defaultModuleName;
1640  if (parser.parseAttribute(defaultModuleName))
1641  return failure();
1642  moduleNames.push_back(defaultModuleName);
1643 
1644  while (succeeded(parser.parseOptionalKeyword("or"))) {
1645  FlatSymbolRefAttr moduleName;
1646  StringAttr targetName;
1647  if (parser.parseAttribute(moduleName) ||
1648  parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1649  return failure();
1650  moduleNames.push_back(moduleName);
1651  caseNames.push_back(targetName);
1652  }
1653 
1654  llvm::SMLoc parametersLoc, inputsOperandsLoc;
1655  if (parser.getCurrentLocation(&parametersLoc) ||
1656  parseOptionalParameterList(parser, parameters) ||
1657  parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1658  parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1659  result.operands) ||
1660  parser.parseArrow() ||
1661  parseOutputPortList(parser, allResultTypes, resultNames) ||
1662  parser.parseOptionalAttrDict(result.attributes)) {
1663  return failure();
1664  }
1665 
1666  result.addAttribute("moduleNames",
1667  ArrayAttr::get(parser.getContext(), moduleNames));
1668  result.addAttribute("caseNames",
1669  ArrayAttr::get(parser.getContext(), caseNames));
1670  result.addAttribute("argNames", argNames);
1671  result.addAttribute("resultNames", resultNames);
1672  result.addAttribute("parameters", parameters);
1673  result.addTypes(allResultTypes);
1674  return success();
1675 }
1676 
1677 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1678  p << ' ';
1679  p.printAttributeWithoutType(getInstanceNameAttr());
1680  if (auto attr = getInnerSymAttr()) {
1681  p << " sym ";
1682  attr.print(p);
1683  }
1684  p << " option " << getOptionNameAttr() << ' ';
1685 
1686  auto moduleNames = getModuleNamesAttr();
1687  auto caseNames = getCaseNamesAttr();
1688  assert(moduleNames.size() == caseNames.size() + 1);
1689 
1690  p.printAttributeWithoutType(moduleNames[0]);
1691  for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1692  p << " or ";
1693  p.printAttributeWithoutType(moduleNames[i + 1]);
1694  p << " if ";
1695  p.printAttributeWithoutType(caseNames[i]);
1696  }
1697 
1698  printOptionalParameterList(p, *this, getParameters());
1699  printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1700  getArgNames());
1701  p << " -> ";
1702  printOutputPortList(p, *this, getResultTypes(), getResultNames());
1703 
1704  p.printOptionalAttrDict(
1705  (*this)->getAttrs(),
1706  /*elidedAttrs=*/{"instanceName",
1707  InnerSymbolTable::getInnerSymbolAttrName(),
1708  "moduleNames", "caseNames", "argNames", "resultNames",
1709  "parameters", "optionName"});
1710 }
1711 
1712 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1713  SmallVector<Attribute> moduleNames;
1714  for (Attribute attr : getModuleNamesAttr()) {
1715  moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1716  }
1717  return ArrayAttr::get(getContext(), moduleNames);
1718 }
1719 
1720 //===----------------------------------------------------------------------===//
1721 // HWOutputOp
1722 //===----------------------------------------------------------------------===//
1723 
1724 /// Verify that the num of operands and types fit the declared results.
1725 LogicalResult OutputOp::verify() {
1726  // Check that the we (hw.output) have the same number of operands as our
1727  // region has results.
1728  ModuleType modType;
1729  if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1730  modType = mod.getHWModuleType();
1731  else {
1732  emitOpError("must have a module parent");
1733  return failure();
1734  }
1735  auto modResults = modType.getOutputTypes();
1736  OperandRange outputValues = getOperands();
1737  if (modResults.size() != outputValues.size()) {
1738  emitOpError("must have same number of operands as region results.");
1739  return failure();
1740  }
1741 
1742  // Check that the types of our operands and the region's results match.
1743  for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1744  if (modResults[i] != outputValues[i].getType()) {
1745  emitOpError("output types must match module. In "
1746  "operand ")
1747  << i << ", expected " << modResults[i] << ", but got "
1748  << outputValues[i].getType() << ".";
1749  return failure();
1750  }
1751  }
1752 
1753  return success();
1754 }
1755 
1756 //===----------------------------------------------------------------------===//
1757 // Other Operations
1758 //===----------------------------------------------------------------------===//
1759 
1760 static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1761  Type &idxType) {
1762  Type type;
1763  if (p.parseType(type))
1764  return p.emitError(p.getCurrentLocation(), "Expected type");
1765  auto arrType = type_dyn_cast<ArrayType>(type);
1766  if (!arrType)
1767  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1768  srcType = type;
1769  unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1770  idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1771  return success();
1772 }
1773 
1774 static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1775  Type idxType) {
1776  p.printType(srcType);
1777 }
1778 
1779 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1780  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1781  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1782  Type elemType;
1783 
1784  if (parser.parseOperandList(operands) ||
1785  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1786  parser.parseType(elemType))
1787  return failure();
1788 
1789  if (operands.size() == 0)
1790  return parser.emitError(inputOperandsLoc,
1791  "Cannot construct an array of length 0");
1792  result.addTypes({ArrayType::get(elemType, operands.size())});
1793 
1794  for (auto operand : operands)
1795  if (parser.resolveOperand(operand, elemType, result.operands))
1796  return failure();
1797  return success();
1798 }
1799 
1800 void ArrayCreateOp::print(OpAsmPrinter &p) {
1801  p << " ";
1802  p.printOperands(getInputs());
1803  p.printOptionalAttrDict((*this)->getAttrs());
1804  p << " : " << getInputs()[0].getType();
1805 }
1806 
1807 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1808  ValueRange values) {
1809  assert(values.size() > 0 && "Cannot build array of zero elements");
1810  Type elemType = values[0].getType();
1811  assert(llvm::all_of(
1812  values,
1813  [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1814  "All values must have same type.");
1815  build(b, state, ArrayType::get(elemType, values.size()), values);
1816 }
1817 
1818 LogicalResult ArrayCreateOp::verify() {
1819  unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1820  if (getInputs().size() != returnSize)
1821  return failure();
1822  return success();
1823 }
1824 
1825 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1826  if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1827  return {};
1828  return ArrayAttr::get(getContext(), adaptor.getInputs());
1829 }
1830 
1831 // Check whether an integer value is an offset from a base.
1832 bool hw::isOffset(Value base, Value index, uint64_t offset) {
1833  if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1834  if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1835  // If both values are a constant, check if index == base + offset.
1836  // To account for overflow, the addition is performed with an extra bit
1837  // and the offset is asserted to fit in the bit width of the base.
1838  auto baseValue = constBase.getValue();
1839  auto indexValue = constIndex.getValue();
1840 
1841  unsigned bits = baseValue.getBitWidth();
1842  assert(bits == indexValue.getBitWidth() && "mismatched widths");
1843 
1844  if (bits < 64 && offset >= (1ull << bits))
1845  return false;
1846 
1847  APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1848  APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1849  return baseExt + offset == indexExt;
1850  }
1851  }
1852  return false;
1853 }
1854 
1855 // Canonicalize a create of consecutive elements to a slice.
1856 static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1857  PatternRewriter &rewriter) {
1858  // Do not canonicalize create of get into a slice.
1859  auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1860  if (arrayTy.getNumElements() <= 1)
1861  return failure();
1862  auto elemTy = arrayTy.getElementType();
1863 
1864  // Check if create arguments are consecutive elements of the same array.
1865  // Attempt to break a create of gets into a sequence of consecutive intervals.
1866  struct Chunk {
1867  Value input;
1868  Value index;
1869  size_t size;
1870  };
1871  SmallVector<Chunk> chunks;
1872  for (Value value : llvm::reverse(op.getInputs())) {
1873  auto get = value.getDefiningOp<ArrayGetOp>();
1874  if (!get)
1875  return failure();
1876 
1877  Value input = get.getInput();
1878  Value index = get.getIndex();
1879  if (!chunks.empty()) {
1880  auto &c = *chunks.rbegin();
1881  if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1882  c.size++;
1883  continue;
1884  }
1885  }
1886 
1887  chunks.push_back(Chunk{input, index, 1});
1888  }
1889 
1890  // If there is a single slice, eliminate the create.
1891  if (chunks.size() == 1) {
1892  auto &chunk = chunks[0];
1893  rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1894  op.getLoc(), arrayTy, chunk.input, chunk.index));
1895  return success();
1896  }
1897 
1898  // If the number of chunks is significantly less than the number of
1899  // elements, replace the create with a concat of the identified slices.
1900  if (chunks.size() * 2 < arrayTy.getNumElements()) {
1901  SmallVector<Value> slices;
1902  for (auto &chunk : llvm::reverse(chunks)) {
1903  auto sliceTy = ArrayType::get(elemTy, chunk.size);
1904  slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1905  op.getLoc(), sliceTy, chunk.input, chunk.index));
1906  }
1907  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1908  return success();
1909  }
1910 
1911  return failure();
1912 }
1913 
1915  PatternRewriter &rewriter) {
1916  if (succeeded(foldCreateToSlice(op, rewriter)))
1917  return success();
1918  return failure();
1919 }
1920 
1921 Value ArrayCreateOp::getUniformElement() {
1922  if (!getInputs().empty() && llvm::all_equal(getInputs()))
1923  return getInputs()[0];
1924  return {};
1925 }
1926 
1927 static std::optional<uint64_t> getUIntFromValue(Value value) {
1928  auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1929  if (!idxOp)
1930  return std::nullopt;
1931  APInt idxAttr = idxOp.getValue();
1932  if (idxAttr.getBitWidth() > 64)
1933  return std::nullopt;
1934  return idxAttr.getLimitedValue();
1935 }
1936 
1937 LogicalResult ArraySliceOp::verify() {
1938  unsigned inputSize =
1939  type_cast<ArrayType>(getInput().getType()).getNumElements();
1940  if (llvm::Log2_64_Ceil(inputSize) !=
1941  getLowIndex().getType().getIntOrFloatBitWidth())
1942  return emitOpError(
1943  "ArraySlice: index width must match clog2 of array size");
1944  return success();
1945 }
1946 
1947 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1948  // If we are slicing the entire input, then return it.
1949  if (getType() == getInput().getType())
1950  return getInput();
1951  return {};
1952 }
1953 
1954 LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1955  PatternRewriter &rewriter) {
1956  auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1957  auto elemTy = sliceTy.getElementType();
1958  uint64_t sliceSize = sliceTy.getNumElements();
1959  if (sliceSize == 0)
1960  return failure();
1961 
1962  if (sliceSize == 1) {
1963  // slice(a, n) -> create(a[n])
1964  auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1965  op.getLowIndex());
1966  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1967  get.getResult());
1968  return success();
1969  }
1970 
1971  auto offsetOpt = getUIntFromValue(op.getLowIndex());
1972  if (!offsetOpt)
1973  return failure();
1974 
1975  auto inputOp = op.getInput().getDefiningOp();
1976  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1977  // slice(slice(a, n), m) -> slice(a, n + m)
1978  if (inputSlice == op)
1979  return failure();
1980 
1981  auto inputIndex = inputSlice.getLowIndex();
1982  auto inputOffsetOpt = getUIntFromValue(inputIndex);
1983  if (!inputOffsetOpt)
1984  return failure();
1985 
1986  uint64_t offset = *offsetOpt + *inputOffsetOpt;
1987  auto lowIndex =
1988  rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1989  rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1990  inputSlice.getInput(), lowIndex);
1991  return success();
1992  }
1993 
1994  if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1995  // slice(create(a0, a1, ..., an), m) -> create(am, ...)
1996  auto inputs = inputCreate.getInputs();
1997 
1998  uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1999  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
2000  inputs.slice(begin, sliceSize));
2001  return success();
2002  }
2003 
2004  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2005  // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2006  SmallVector<Value> chunks;
2007  uint64_t sliceStart = *offsetOpt;
2008  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2009  // Check whether the input intersects with the slice.
2010  uint64_t inputSize =
2011  hw::type_cast<ArrayType>(input.getType()).getNumElements();
2012  if (inputSize == 0 || inputSize <= sliceStart) {
2013  sliceStart -= inputSize;
2014  continue;
2015  }
2016 
2017  // Find the indices to slice from this input by intersection.
2018  uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2019  uint64_t cutSize = cutEnd - sliceStart;
2020  assert(cutSize != 0 && "slice cannot be empty");
2021 
2022  if (cutSize == inputSize) {
2023  // The whole input fits in the slice, add it.
2024  assert(sliceStart == 0 && "invalid cut size");
2025  chunks.push_back(input);
2026  } else {
2027  // Slice the required bits from the input.
2028  unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2029  auto lowIndex = rewriter.create<ConstantOp>(
2030  op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2031  chunks.push_back(rewriter.create<ArraySliceOp>(
2032  op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
2033  }
2034 
2035  sliceStart = 0;
2036  sliceSize -= cutSize;
2037  if (sliceSize == 0)
2038  break;
2039  }
2040 
2041  assert(chunks.size() > 0 && "missing sliced items");
2042  if (chunks.size() == 1)
2043  rewriter.replaceOp(op, chunks[0]);
2044  else
2045  rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2046  op, llvm::to_vector(llvm::reverse(chunks)));
2047  return success();
2048  }
2049  return failure();
2050 }
2051 
2052 //===----------------------------------------------------------------------===//
2053 // ArrayConcatOp
2054 //===----------------------------------------------------------------------===//
2055 
2056 static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2057  SmallVectorImpl<Type> &inputTypes,
2058  Type &resultType) {
2059  Type elemType;
2060  uint64_t resultSize = 0;
2061 
2062  auto parseElement = [&]() -> ParseResult {
2063  Type ty;
2064  if (p.parseType(ty))
2065  return failure();
2066  auto arrTy = type_dyn_cast<ArrayType>(ty);
2067  if (!arrTy)
2068  return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2069  if (elemType && elemType != arrTy.getElementType())
2070  return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2071  << elemType;
2072 
2073  elemType = arrTy.getElementType();
2074  inputTypes.push_back(ty);
2075  resultSize += arrTy.getNumElements();
2076  return success();
2077  };
2078 
2079  if (p.parseCommaSeparatedList(parseElement))
2080  return failure();
2081 
2082  resultType = ArrayType::get(elemType, resultSize);
2083  return success();
2084 }
2085 
2086 static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2087  TypeRange inputTypes, Type resultType) {
2088  llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2089 }
2090 
2091 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2092  ValueRange values) {
2093  assert(!values.empty() && "Cannot build array of zero elements");
2094  ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2095  Type elemTy = arrayTy.getElementType();
2096  assert(llvm::all_of(values,
2097  [elemTy](Value v) -> bool {
2098  return isa<ArrayType>(v.getType()) &&
2099  cast<ArrayType>(v.getType()).getElementType() ==
2100  elemTy;
2101  }) &&
2102  "All values must be of ArrayType with the same element type.");
2103 
2104  uint64_t resultSize = 0;
2105  for (Value val : values)
2106  resultSize += cast<ArrayType>(val.getType()).getNumElements();
2107  build(b, state, ArrayType::get(elemTy, resultSize), values);
2108 }
2109 
2110 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2111  auto inputs = adaptor.getInputs();
2112  SmallVector<Attribute> array;
2113  for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2114  if (!inputs[i])
2115  return {};
2116  llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2117  }
2118  return ArrayAttr::get(getContext(), array);
2119 }
2120 
2121 // Flatten a concatenation of array creates into a single create.
2122 static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2123  for (auto input : op.getInputs())
2124  if (!input.getDefiningOp<ArrayCreateOp>())
2125  return false;
2126 
2127  SmallVector<Value> items;
2128  for (auto input : op.getInputs()) {
2129  auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2130  for (auto item : create.getInputs())
2131  items.push_back(item);
2132  }
2133 
2134  rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2135  return true;
2136 }
2137 
2138 // Merge consecutive slice expressions in a concatenation.
2139 static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2140  struct Slice {
2141  Value input;
2142  Value index;
2143  size_t size;
2144  Value op;
2145  SmallVector<Location> locs;
2146  };
2147 
2148  SmallVector<Value> items;
2149  std::optional<Slice> last;
2150  bool changed = false;
2151 
2152  auto concatenate = [&] {
2153  // If there is only one op in the slice, place it to the items list.
2154  if (!last)
2155  return;
2156  if (last->op) {
2157  items.push_back(last->op);
2158  last.reset();
2159  return;
2160  }
2161 
2162  // Otherwise, create a new slice of with the given size and place it.
2163  // In this case, the concat op is replaced, using the new argument.
2164  changed = true;
2165  auto loc = FusedLoc::get(op.getContext(), last->locs);
2166  auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2167  auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2168  items.push_back(rewriter.createOrFold<ArraySliceOp>(
2169  loc, arrayTy, last->input, last->index));
2170 
2171  last.reset();
2172  };
2173 
2174  auto append = [&](Value op, Value input, Value index, size_t size) {
2175  // If this slice is an extension of the previous one, extend the size
2176  // saved. In this case, a new slice of is created and the concatenation
2177  // operator is rewritten. Otherwise, flush the last slice.
2178  if (last) {
2179  if (last->input == input && isOffset(last->index, index, last->size)) {
2180  last->size += size;
2181  last->op = {};
2182  last->locs.push_back(op.getLoc());
2183  return;
2184  }
2185  concatenate();
2186  }
2187  last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2188  };
2189 
2190  for (auto item : llvm::reverse(op.getInputs())) {
2191  if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2192  auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2193  append(item, slice.getInput(), slice.getLowIndex(), size);
2194  continue;
2195  }
2196 
2197  if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2198  if (create.getInputs().size() == 1) {
2199  if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2200  append(item, get.getInput(), get.getIndex(), 1);
2201  continue;
2202  }
2203  }
2204  }
2205 
2206  concatenate();
2207  items.push_back(item);
2208  }
2209  concatenate();
2210 
2211  if (!changed)
2212  return false;
2213 
2214  if (items.size() == 1) {
2215  rewriter.replaceOp(op, items[0]);
2216  } else {
2217  std::reverse(items.begin(), items.end());
2218  rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2219  }
2220  return true;
2221 }
2222 
2224  PatternRewriter &rewriter) {
2225  // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2226  if (flattenConcatOp(op, rewriter))
2227  return success();
2228 
2229  // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2230  if (mergeConcatSlices(op, rewriter))
2231  return success();
2232 
2233  return failure();
2234 }
2235 
2236 //===----------------------------------------------------------------------===//
2237 // EnumConstantOp
2238 //===----------------------------------------------------------------------===//
2239 
2240 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2241  // Parse a Type instead of an EnumType since the type might be a type alias.
2242  // The validity of the canonical type is checked during construction of the
2243  // EnumFieldAttr.
2244  Type type;
2245  StringRef field;
2246 
2247  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2248  if (parser.parseKeyword(&field) || parser.parseColonType(type))
2249  return failure();
2250 
2251  auto fieldAttr = EnumFieldAttr::get(
2252  loc, StringAttr::get(parser.getContext(), field), type);
2253 
2254  if (!fieldAttr)
2255  return failure();
2256 
2257  result.addAttribute("field", fieldAttr);
2258  result.addTypes(type);
2259 
2260  return success();
2261 }
2262 
2263 void EnumConstantOp::print(OpAsmPrinter &p) {
2264  p << " " << getField().getField().getValue() << " : "
2265  << getField().getType().getValue();
2266 }
2267 
2269  function_ref<void(Value, StringRef)> setNameFn) {
2270  setNameFn(getResult(), getField().getField().str());
2271 }
2272 
2273 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2274  EnumFieldAttr field) {
2275  return build(builder, odsState, field.getType().getValue(), field);
2276 }
2277 
2278 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2279  assert(adaptor.getOperands().empty() && "constant has no operands");
2280  return getFieldAttr();
2281 }
2282 
2283 LogicalResult EnumConstantOp::verify() {
2284  auto fieldAttr = getFieldAttr();
2285  auto fieldType = fieldAttr.getType().getValue();
2286  // This check ensures that we are using the exact same type, without looking
2287  // through type aliases.
2288  if (fieldType != getType())
2289  emitOpError("return type ")
2290  << getType() << " does not match attribute type " << fieldAttr;
2291  return success();
2292 }
2293 
2294 //===----------------------------------------------------------------------===//
2295 // EnumCmpOp
2296 //===----------------------------------------------------------------------===//
2297 
2298 LogicalResult EnumCmpOp::verify() {
2299  // Compare the canonical types.
2300  auto lhsType = type_cast<EnumType>(getLhs().getType());
2301  auto rhsType = type_cast<EnumType>(getRhs().getType());
2302  if (rhsType != lhsType)
2303  emitOpError("types do not match");
2304  return success();
2305 }
2306 
2307 //===----------------------------------------------------------------------===//
2308 // StructCreateOp
2309 //===----------------------------------------------------------------------===//
2310 
2311 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2312  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2313  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2314  Type declOrAliasType;
2315 
2316  if (parser.parseLParen() || parser.parseOperandList(operands) ||
2317  parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2318  parser.parseColonType(declOrAliasType))
2319  return failure();
2320 
2321  auto declType = type_dyn_cast<StructType>(declOrAliasType);
2322  if (!declType)
2323  return parser.emitError(parser.getNameLoc(),
2324  "expected !hw.struct type or alias");
2325 
2326  llvm::SmallVector<Type, 4> structInnerTypes;
2327  declType.getInnerTypes(structInnerTypes);
2328  result.addTypes(declOrAliasType);
2329 
2330  if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2331  result.operands))
2332  return failure();
2333  return success();
2334 }
2335 
2336 void StructCreateOp::print(OpAsmPrinter &printer) {
2337  printer << " (";
2338  printer.printOperands(getInput());
2339  printer << ")";
2340  printer.printOptionalAttrDict((*this)->getAttrs());
2341  printer << " : " << getType();
2342 }
2343 
2344 LogicalResult StructCreateOp::verify() {
2345  auto elements = hw::type_cast<StructType>(getType()).getElements();
2346 
2347  if (elements.size() != getInput().size())
2348  return emitOpError("structure field count mismatch");
2349 
2350  for (const auto &[field, value] : llvm::zip(elements, getInput()))
2351  if (field.type != value.getType())
2352  return emitOpError("structure field `")
2353  << field.name << "` type does not match";
2354 
2355  return success();
2356 }
2357 
2358 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2359  // struct_create(struct_explode(x)) => x
2360  if (!getInput().empty())
2361  if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2362  explodeOp && getInput() == explodeOp.getResults() &&
2363  getResult().getType() == explodeOp.getInput().getType())
2364  return explodeOp.getInput();
2365 
2366  auto inputs = adaptor.getInput();
2367  if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2368  return {};
2369  return ArrayAttr::get(getContext(), inputs);
2370 }
2371 
2372 //===----------------------------------------------------------------------===//
2373 // StructExplodeOp
2374 //===----------------------------------------------------------------------===//
2375 
2376 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2377  OperationState &result) {
2378  OpAsmParser::UnresolvedOperand operand;
2379  Type declType;
2380 
2381  if (parser.parseOperand(operand) ||
2382  parser.parseOptionalAttrDict(result.attributes) ||
2383  parser.parseColonType(declType))
2384  return failure();
2385  auto structType = type_dyn_cast<StructType>(declType);
2386  if (!structType)
2387  return parser.emitError(parser.getNameLoc(),
2388  "invalid kind of type specified");
2389 
2390  llvm::SmallVector<Type, 4> structInnerTypes;
2391  structType.getInnerTypes(structInnerTypes);
2392  result.addTypes(structInnerTypes);
2393 
2394  if (parser.resolveOperand(operand, declType, result.operands))
2395  return failure();
2396  return success();
2397 }
2398 
2399 void StructExplodeOp::print(OpAsmPrinter &printer) {
2400  printer << " ";
2401  printer.printOperand(getInput());
2402  printer.printOptionalAttrDict((*this)->getAttrs());
2403  printer << " : " << getInput().getType();
2404 }
2405 
2406 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2407  SmallVectorImpl<OpFoldResult> &results) {
2408  auto input = adaptor.getInput();
2409  if (!input)
2410  return failure();
2411  llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2412  return success();
2413 }
2414 
2415 LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2416  PatternRewriter &rewriter) {
2417  auto *inputOp = op.getInput().getDefiningOp();
2418  auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2419  auto result = failure();
2420  auto opResults = op.getResults();
2421  for (uint32_t index = 0; index < elements.size(); index++) {
2422  if (auto foldResult = foldStructExtract(inputOp, index)) {
2423  rewriter.replaceAllUsesWith(opResults[index], foldResult);
2424  result = success();
2425  }
2426  }
2427  return result;
2428 }
2429 
2431  function_ref<void(Value, StringRef)> setNameFn) {
2432  auto structType = type_cast<StructType>(getInput().getType());
2433  for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2434  setNameFn(res, field.name.str());
2435 }
2436 
2437 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2438  Value input) {
2439  StructType inputType = dyn_cast<StructType>(input.getType());
2440  assert(inputType);
2441  SmallVector<Type, 16> fieldTypes;
2442  for (auto field : inputType.getElements())
2443  fieldTypes.push_back(field.type);
2444  build(odsBuilder, odsState, fieldTypes, input);
2445 }
2446 
2447 //===----------------------------------------------------------------------===//
2448 // StructExtractOp
2449 //===----------------------------------------------------------------------===//
2450 
2451 /// Ensure an aggregate op's field index is within the bounds of
2452 /// the aggregate type and the accessed field is of 'elementType'.
2453 template <typename AggregateOp, typename AggregateType>
2454 static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2455  AggregateType aggType,
2456  Type elementType) {
2457  auto index = op.getFieldIndex();
2458  if (index >= aggType.getElements().size())
2459  return op.emitOpError() << "field index " << index
2460  << " exceeds element count of aggregate type";
2461 
2463  getCanonicalType(aggType.getElements()[index].type))
2464  return op.emitOpError()
2465  << "type " << aggType.getElements()[index].type
2466  << " of accessed field in aggregate at index " << index
2467  << " does not match expected type " << elementType;
2468 
2469  return success();
2470 }
2471 
2472 LogicalResult StructExtractOp::verify() {
2473  return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2474  *this, getInput().getType(), getType());
2475 }
2476 
2477 /// Use the same parser for both struct_extract and union_extract since the
2478 /// syntax is identical.
2479 template <typename AggregateType>
2480 static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2481  OpAsmParser::UnresolvedOperand operand;
2482  StringAttr fieldName;
2483  Type declType;
2484 
2485  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2486  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2487  parser.parseOptionalAttrDict(result.attributes) ||
2488  parser.parseColonType(declType))
2489  return failure();
2490  auto aggType = type_dyn_cast<AggregateType>(declType);
2491  if (!aggType)
2492  return parser.emitError(parser.getNameLoc(),
2493  "invalid kind of type specified");
2494 
2495  auto fieldIndex = aggType.getFieldIndex(fieldName);
2496  if (!fieldIndex) {
2497  parser.emitError(parser.getNameLoc(), "field name '" +
2498  fieldName.getValue() +
2499  "' not found in aggregate type");
2500  return failure();
2501  }
2502 
2503  auto indexAttr =
2504  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2505  result.addAttribute("fieldIndex", indexAttr);
2506  Type resultType = aggType.getElements()[*fieldIndex].type;
2507  result.addTypes(resultType);
2508 
2509  if (parser.resolveOperand(operand, declType, result.operands))
2510  return failure();
2511  return success();
2512 }
2513 
2514 /// Use the same printer for both struct_extract and union_extract since the
2515 /// syntax is identical.
2516 template <typename AggType>
2517 static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2518  printer << " ";
2519  printer.printOperand(op.getInput());
2520  printer << "[\"" << op.getFieldName() << "\"]";
2521  printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2522  printer << " : " << op.getInput().getType();
2523 }
2524 
2525 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2526  OperationState &result) {
2527  return parseExtractOp<StructType>(parser, result);
2528 }
2529 
2530 void StructExtractOp::print(OpAsmPrinter &printer) {
2531  printExtractOp(printer, *this);
2532 }
2533 
2534 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2535  Value input, StructType::FieldInfo field) {
2536  auto fieldIndex =
2537  type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2538  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2539  build(builder, odsState, field.type, input, *fieldIndex);
2540 }
2541 
2542 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2543  Value input, StringAttr fieldName) {
2544  auto structType = type_cast<StructType>(input.getType());
2545  auto fieldIndex = structType.getFieldIndex(fieldName);
2546  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2547  auto resultType = structType.getElements()[*fieldIndex].type;
2548  build(builder, odsState, resultType, input, *fieldIndex);
2549 }
2550 
2551 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2552  if (auto constOperand = adaptor.getInput()) {
2553  // Fold extract from aggregate constant
2554  auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2555  return operandAttr.getValue()[getFieldIndex()];
2556  }
2557 
2558  if (auto foldResult =
2559  foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2560  return foldResult;
2561  return {};
2562 }
2563 
2565  PatternRewriter &rewriter) {
2566  auto inputOp = op.getInput().getDefiningOp();
2567 
2568  // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2569  if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2570  if (structInject.getFieldIndex() != op.getFieldIndex()) {
2571  rewriter.replaceOpWithNewOp<StructExtractOp>(
2572  op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2573  return success();
2574  }
2575  }
2576 
2577  return failure();
2578 }
2579 
2581  function_ref<void(Value, StringRef)> setNameFn) {
2582  setNameFn(getResult(), getFieldName());
2583 }
2584 
2585 //===----------------------------------------------------------------------===//
2586 // StructInjectOp
2587 //===----------------------------------------------------------------------===//
2588 
2589 void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2590  Value input, StringAttr fieldName, Value newValue) {
2591  auto structType = type_cast<StructType>(input.getType());
2592  auto fieldIndex = structType.getFieldIndex(fieldName);
2593  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2594  build(builder, odsState, input, *fieldIndex, newValue);
2595 }
2596 
2597 LogicalResult StructInjectOp::verify() {
2598  return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2599  *this, getInput().getType(), getNewValue().getType());
2600 }
2601 
2602 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2603  llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2604  OpAsmParser::UnresolvedOperand operand, val;
2605  StringAttr fieldName;
2606  Type declType;
2607 
2608  if (parser.parseOperand(operand) || parser.parseLSquare() ||
2609  parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2610  parser.parseComma() || parser.parseOperand(val) ||
2611  parser.parseOptionalAttrDict(result.attributes) ||
2612  parser.parseColonType(declType))
2613  return failure();
2614  auto structType = type_dyn_cast<StructType>(declType);
2615  if (!structType)
2616  return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2617 
2618  auto fieldIndex = structType.getFieldIndex(fieldName);
2619  if (!fieldIndex) {
2620  parser.emitError(parser.getNameLoc(), "field name '" +
2621  fieldName.getValue() +
2622  "' not found in aggregate type");
2623  return failure();
2624  }
2625 
2626  auto indexAttr =
2627  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2628  result.addAttribute("fieldIndex", indexAttr);
2629  result.addTypes(declType);
2630 
2631  Type resultType = structType.getElements()[*fieldIndex].type;
2632  if (parser.resolveOperands({operand, val}, {declType, resultType},
2633  inputOperandsLoc, result.operands))
2634  return failure();
2635  return success();
2636 }
2637 
2638 void StructInjectOp::print(OpAsmPrinter &printer) {
2639  printer << " ";
2640  printer.printOperand(getInput());
2641  printer << "[\"" << getFieldName() << "\"], ";
2642  printer.printOperand(getNewValue());
2643  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2644  printer << " : " << getInput().getType();
2645 }
2646 
2647 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2648  auto input = adaptor.getInput();
2649  auto newValue = adaptor.getNewValue();
2650  if (!input || !newValue)
2651  return {};
2652  SmallVector<Attribute> array;
2653  llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2654  array[getFieldIndex()] = newValue;
2655  return ArrayAttr::get(getContext(), array);
2656 }
2657 
2658 LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2659  PatternRewriter &rewriter) {
2660  // Canonicalize multiple injects into a create op and eliminate overwrites.
2661  SmallPtrSet<Operation *, 4> injects;
2662  DenseMap<StringAttr, Value> fields;
2663 
2664  // Chase a chain of injects. Bail out if cycles are present.
2665  StructInjectOp inject = op;
2666  Value input;
2667  do {
2668  if (!injects.insert(inject).second)
2669  return failure();
2670 
2671  fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2672  input = inject.getInput();
2673  inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2674  } while (inject);
2675  assert(input && "missing input to inject chain");
2676 
2677  auto ty = hw::type_cast<StructType>(op.getType());
2678  auto elements = ty.getElements();
2679 
2680  // If the inject chain sets all fields, canonicalize to create.
2681  if (fields.size() == elements.size()) {
2682  SmallVector<Value> createFields;
2683  for (const auto &field : elements) {
2684  auto it = fields.find(field.name);
2685  assert(it != fields.end() && "missing field");
2686  createFields.push_back(it->second);
2687  }
2688  rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2689  return success();
2690  }
2691 
2692  // Nothing to canonicalize, only the original inject in the chain.
2693  if (injects.size() == fields.size())
2694  return failure();
2695 
2696  // Eliminate overwrites. The hash map contains the last write to each field.
2697  for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2698  auto it = fields.find(elements[fieldIndex].name);
2699  if (it == fields.end())
2700  continue;
2701  input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2702  it->second);
2703  }
2704 
2705  rewriter.replaceOp(op, input);
2706  return success();
2707 }
2708 
2709 //===----------------------------------------------------------------------===//
2710 // UnionCreateOp
2711 //===----------------------------------------------------------------------===//
2712 
2713 LogicalResult UnionCreateOp::verify() {
2714  return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2715  *this, getType(), getInput().getType());
2716 }
2717 
2718 void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2719  Type unionType, StringAttr fieldName, Value input) {
2720  auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2721  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2722  build(builder, odsState, unionType, *fieldIndex, input);
2723 }
2724 
2725 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2726  Type declOrAliasType;
2727  StringAttr fieldName;
2728  OpAsmParser::UnresolvedOperand input;
2729  llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2730 
2731  if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2732  parser.parseOperand(input) ||
2733  parser.parseOptionalAttrDict(result.attributes) ||
2734  parser.parseColonType(declOrAliasType))
2735  return failure();
2736 
2737  auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2738  if (!declType)
2739  return parser.emitError(parser.getNameLoc(),
2740  "expected !hw.union type or alias");
2741 
2742  auto fieldIndex = declType.getFieldIndex(fieldName);
2743  if (!fieldIndex) {
2744  parser.emitError(fieldLoc, "cannot find union field '")
2745  << fieldName.getValue() << '\'';
2746  return failure();
2747  }
2748 
2749  auto indexAttr =
2750  IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2751  result.addAttribute("fieldIndex", indexAttr);
2752  Type inputType = declType.getElements()[*fieldIndex].type;
2753 
2754  if (parser.resolveOperand(input, inputType, result.operands))
2755  return failure();
2756  result.addTypes({declOrAliasType});
2757  return success();
2758 }
2759 
2760 void UnionCreateOp::print(OpAsmPrinter &printer) {
2761  printer << " \"" << getFieldName() << "\", ";
2762  printer.printOperand(getInput());
2763  printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2764  printer << " : " << getType();
2765 }
2766 
2767 //===----------------------------------------------------------------------===//
2768 // UnionExtractOp
2769 //===----------------------------------------------------------------------===//
2770 
2771 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2772  return parseExtractOp<UnionType>(parser, result);
2773 }
2774 
2775 void UnionExtractOp::print(OpAsmPrinter &printer) {
2776  printExtractOp(printer, *this);
2777 }
2778 
2779 LogicalResult UnionExtractOp::inferReturnTypes(
2780  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2781  DictionaryAttr attrs, mlir::OpaqueProperties properties,
2782  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2783  Adaptor adaptor(operands, attrs, properties, regions);
2784  auto unionElements =
2785  hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2786  unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2787  if (fieldIndex >= unionElements.size()) {
2788  if (loc)
2789  mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2790  " exceeds element count of aggregate type");
2791  return failure();
2792  }
2793  results.push_back(unionElements[fieldIndex].type);
2794  return success();
2795 }
2796 
2797 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2798  Value input, StringAttr fieldName) {
2799  auto unionType = type_cast<UnionType>(input.getType());
2800  auto fieldIndex = unionType.getFieldIndex(fieldName);
2801  assert(fieldIndex.has_value() && "field name not found in aggregate type");
2802  auto resultType = unionType.getElements()[*fieldIndex].type;
2803  build(odsBuilder, odsState, resultType, input, *fieldIndex);
2804 }
2805 
2806 //===----------------------------------------------------------------------===//
2807 // ArrayGetOp
2808 //===----------------------------------------------------------------------===//
2809 
2810 // An array_get of an array_create with a constant index can just be the
2811 // array_create operand at the constant index. If the array_create has a
2812 // single uniform value for each element, just return that value regardless of
2813 // the index. If the array is constructed from a constant by a bitcast
2814 // operation, we can fold into a constant.
2815 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2816  auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2817  auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2818 
2819  if (inputCst) {
2820  // Constant array index.
2821  if (indexCst) {
2822  auto indexVal = indexCst.getValue();
2823  if (indexVal.getBitWidth() < 64) {
2824  auto index = indexVal.getZExtValue();
2825  return inputCst[inputCst.size() - 1 - index];
2826  }
2827  }
2828  // If all elements of the array are the same, we can return any element of
2829  // array.
2830  if (!inputCst.empty() && llvm::all_equal(inputCst))
2831  return inputCst[0];
2832  }
2833 
2834  // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2835  if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2836  auto intTy = dyn_cast<IntegerType>(getType());
2837  if (!intTy)
2838  return {};
2839  auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2840  if (!bitcastInputOp)
2841  return {};
2842  if (!indexCst)
2843  return {};
2844  auto bitcastInputCst = bitcastInputOp.getValue();
2845  // Calculate the index. Make sure to zero-extend the index value before
2846  // multiplying the element width.
2847  auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2848  getType().getIntOrFloatBitWidth();
2849  // Extract [startIdx + width - 1: startIdx].
2850  return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2851  intTy.getIntOrFloatBitWidth()));
2852  }
2853 
2854  auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2855  if (!inputCreate)
2856  return {};
2857 
2858  if (auto uniformValue = inputCreate.getUniformElement())
2859  return uniformValue;
2860 
2861  if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2862  return {};
2863 
2864  uint64_t index = indexCst.getValue().getLimitedValue();
2865  auto createInputs = inputCreate.getInputs();
2866  if (index >= createInputs.size())
2867  return {};
2868  return createInputs[createInputs.size() - index - 1];
2869 }
2870 
2871 LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2872  PatternRewriter &rewriter) {
2873  auto idxOpt = getUIntFromValue(op.getIndex());
2874  if (!idxOpt)
2875  return failure();
2876 
2877  auto *inputOp = op.getInput().getDefiningOp();
2878  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2879  // get(slice(a, n), m) -> get(a, n + m)
2880  auto offsetOp = inputSlice.getLowIndex();
2881  auto offsetOpt = getUIntFromValue(offsetOp);
2882  if (!offsetOpt)
2883  return failure();
2884 
2885  uint64_t offset = *offsetOpt + *idxOpt;
2886  auto newOffset =
2887  rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2888  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2889  newOffset);
2890  return success();
2891  }
2892 
2893  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2894  // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2895  uint64_t elemIndex = *idxOpt;
2896  for (auto input : llvm::reverse(inputConcat.getInputs())) {
2897  size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2898  if (elemIndex >= size) {
2899  elemIndex -= size;
2900  continue;
2901  }
2902 
2903  unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2904  auto newIdxOp = rewriter.create<ConstantOp>(
2905  op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2906 
2907  rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2908  return success();
2909  }
2910  return failure();
2911  }
2912 
2913  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2914  // array_get sel, (array_create (array_get const a), (array_get const b),
2915  // (array_get const, c), (array_get const, d))
2916  if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2917  if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2918  if (auto create =
2919  innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2920 
2921  SmallVector<Value> newValues;
2922  for (auto operand : create.getOperands())
2923  newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2924  op.getLoc(), operand, op.getIndex()));
2925 
2926  rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2927  op,
2928  rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2929  innerGet.getIndex());
2930  return success();
2931  }
2932  }
2933  }
2934 
2935  return failure();
2936 }
2937 
2938 //===----------------------------------------------------------------------===//
2939 // TypedeclOp
2940 //===----------------------------------------------------------------------===//
2941 
2942 StringRef TypedeclOp::getPreferredName() {
2943  return getVerilogName().value_or(getName());
2944 }
2945 
2946 Type TypedeclOp::getAliasType() {
2947  auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2948  return hw::TypeAliasType::get(
2949  SymbolRefAttr::get(parentScope.getSymNameAttr(),
2950  {FlatSymbolRefAttr::get(*this)}),
2951  getType());
2952 }
2953 
2954 //===----------------------------------------------------------------------===//
2955 // BitcastOp
2956 //===----------------------------------------------------------------------===//
2957 
2958 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2959  // Identity.
2960  // bitcast(%a) : A -> A ==> %a
2961  if (getOperand().getType() == getType())
2962  return getOperand();
2963 
2964  return {};
2965 }
2966 
2967 LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2968  // Composition.
2969  // %b = bitcast(%a) : A -> B
2970  // bitcast(%b) : B -> C
2971  // ===> bitcast(%a) : A -> C
2972  auto inputBitcast =
2973  dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2974  if (!inputBitcast)
2975  return failure();
2976  auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2977  inputBitcast.getInput());
2978  rewriter.replaceOp(op, bitcast);
2979  return success();
2980 }
2981 
2982 LogicalResult BitcastOp::verify() {
2983  if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2984  return this->emitOpError("Bitwidth of input must match result");
2985  return success();
2986 }
2987 
2988 //===----------------------------------------------------------------------===//
2989 // HierPathOp helpers.
2990 //===----------------------------------------------------------------------===//
2991 
2992 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2993  SmallVector<Attribute, 4> newPath;
2994  bool updateMade = false;
2995  for (auto nameRef : getNamepath()) {
2996  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2997  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2998  if (ref.getModule() == moduleToDrop)
2999  updateMade = true;
3000  else
3001  newPath.push_back(ref);
3002  } else {
3003  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3004  updateMade = true;
3005  else
3006  newPath.push_back(nameRef);
3007  }
3008  }
3009  if (updateMade)
3010  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3011  return updateMade;
3012 }
3013 
3014 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3015  SmallVector<Attribute, 4> newPath;
3016  bool updateMade = false;
3017  StringRef inlinedInstanceName = "";
3018  for (auto nameRef : getNamepath()) {
3019  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3020  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3021  if (ref.getModule() == moduleToDrop) {
3022  inlinedInstanceName = ref.getName().getValue();
3023  updateMade = true;
3024  } else if (!inlinedInstanceName.empty()) {
3025  newPath.push_back(hw::InnerRefAttr::get(
3026  ref.getModule(),
3027  StringAttr::get(getContext(), inlinedInstanceName + "_" +
3028  ref.getName().getValue())));
3029  inlinedInstanceName = "";
3030  } else
3031  newPath.push_back(ref);
3032  } else {
3033  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3034  updateMade = true;
3035  else
3036  newPath.push_back(nameRef);
3037  }
3038  }
3039  if (updateMade)
3040  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3041  return updateMade;
3042 }
3043 
3044 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3045  SmallVector<Attribute, 4> newPath;
3046  bool updateMade = false;
3047  for (auto nameRef : getNamepath()) {
3048  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3049  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3050  if (ref.getModule() == oldMod) {
3051  newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3052  updateMade = true;
3053  } else
3054  newPath.push_back(ref);
3055  } else {
3056  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3057  newPath.push_back(FlatSymbolRefAttr::get(newMod));
3058  updateMade = true;
3059  } else
3060  newPath.push_back(nameRef);
3061  }
3062  }
3063  if (updateMade)
3064  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3065  return updateMade;
3066 }
3067 
3068 bool HierPathOp::updateModuleAndInnerRef(
3069  StringAttr oldMod, StringAttr newMod,
3070  const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3071  auto fromRef = FlatSymbolRefAttr::get(oldMod);
3072  if (oldMod == newMod)
3073  return false;
3074 
3075  auto namepathNew = getNamepath().getValue().vec();
3076  bool updateMade = false;
3077  // Break from the loop if the module is found, since it can occur only once.
3078  for (auto &element : namepathNew) {
3079  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3080  if (innerRef.getModule() != oldMod)
3081  continue;
3082  auto symName = innerRef.getName();
3083  // Since the module got updated, the old innerRef symbol inside oldMod
3084  // should also be updated to the new symbol inside the newMod.
3085  auto to = innerSymRenameMap.find(symName);
3086  if (to != innerSymRenameMap.end())
3087  symName = to->second;
3088  updateMade = true;
3089  element = hw::InnerRefAttr::get(newMod, symName);
3090  break;
3091  }
3092  if (element != fromRef)
3093  continue;
3094 
3095  updateMade = true;
3096  element = FlatSymbolRefAttr::get(newMod);
3097  break;
3098  }
3099  if (updateMade)
3100  setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3101  return updateMade;
3102 }
3103 
3104 bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3105  SmallVector<Attribute, 4> newPath;
3106  bool updateMade = false;
3107  for (auto nameRef : getNamepath()) {
3108  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3109  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3110  if (ref.getModule() == atMod) {
3111  updateMade = true;
3112  if (includeMod)
3113  newPath.push_back(ref);
3114  } else
3115  newPath.push_back(ref);
3116  } else {
3117  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3118  updateMade = true;
3119  else
3120  newPath.push_back(nameRef);
3121  }
3122  if (updateMade)
3123  break;
3124  }
3125  if (updateMade)
3126  setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3127  return updateMade;
3128 }
3129 
3130 /// Return just the module part of the namepath at a specific index.
3131 StringAttr HierPathOp::modPart(unsigned i) {
3132  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3133  .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3134  .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3135 }
3136 
3137 /// Return the root module.
3138 StringAttr HierPathOp::root() {
3139  assert(!getNamepath().empty());
3140  return modPart(0);
3141 }
3142 
3143 /// Return true if the NLA has the module in its path.
3144 bool HierPathOp::hasModule(StringAttr modName) {
3145  for (auto nameRef : getNamepath()) {
3146  // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3147  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3148  if (ref.getModule() == modName)
3149  return true;
3150  } else {
3151  if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3152  return true;
3153  }
3154  }
3155  return false;
3156 }
3157 
3158 /// Return true if the NLA has the InnerSym .
3159 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3160  for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3161  if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3162  if (ref.getName() == symName && ref.getModule() == modName)
3163  return true;
3164 
3165  return false;
3166 }
3167 
3168 /// Return just the reference part of the namepath at a specific index. This
3169 /// will return an empty attribute if this is the leaf and the leaf is a module.
3170 StringAttr HierPathOp::refPart(unsigned i) {
3171  return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3172  .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3173  .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3174 }
3175 
3176 /// Return the leaf reference. This returns an empty attribute if the leaf
3177 /// reference is a module.
3178 StringAttr HierPathOp::ref() {
3179  assert(!getNamepath().empty());
3180  return refPart(getNamepath().size() - 1);
3181 }
3182 
3183 /// Return the leaf module.
3184 StringAttr HierPathOp::leafMod() {
3185  assert(!getNamepath().empty());
3186  return modPart(getNamepath().size() - 1);
3187 }
3188 
3189 /// Returns true if this NLA targets an instance of a module (as opposed to
3190 /// an instance's port or something inside an instance).
3191 bool HierPathOp::isModule() { return !ref(); }
3192 
3193 /// Returns true if this NLA targets something inside a module (as opposed
3194 /// to a module or an instance of a module);
3195 bool HierPathOp::isComponent() { return (bool)ref(); }
3196 
3197 // Verify the HierPathOp.
3198 // 1. Iterate over the namepath.
3199 // 2. The namepath should be a valid instance path, specified either on a
3200 // module or a declaration inside a module.
3201 // 3. Each element in the namepath is an InnerRefAttr except possibly the
3202 // last element.
3203 // 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3204 // and the corresponding inner_sym on the instance.
3205 // 5. Make sure that the instance path is legal, by verifying the sequence of
3206 // instance and the expected module occurs as the next element in the path.
3207 // 6. The last element of the namepath, can be an InnerRefAttr on either a
3208 // module port or a declaration inside the module.
3209 // 7. The last element of the namepath can also be a module symbol.
3210 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3211  ArrayAttr expectedModuleNames = {};
3212  auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3213  if (!expectedModuleNames)
3214  return success();
3215  if (llvm::any_of(expectedModuleNames,
3216  [name](Attribute attr) { return attr == name; }))
3217  return success();
3218  auto diag = emitOpError() << "instance path is incorrect. Expected ";
3219  size_t n = expectedModuleNames.size();
3220  if (n != 1) {
3221  diag << "one of ";
3222  }
3223  for (size_t i = 0; i < n; ++i) {
3224  if (i != 0)
3225  diag << ((i + 1 == n) ? " or " : ", ");
3226  diag << cast<StringAttr>(expectedModuleNames[i]);
3227  }
3228  diag << ". Instead found: " << name;
3229  return diag;
3230  };
3231 
3232  if (!getNamepath() || getNamepath().empty())
3233  return emitOpError() << "the instance path cannot be empty";
3234  for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3235  hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3236  if (!innerRef)
3237  return emitOpError()
3238  << "the instance path can only contain inner sym reference"
3239  << ", only the leaf can refer to a module symbol";
3240 
3241  if (failed(checkExpectedModule(innerRef.getModule())))
3242  return failure();
3243 
3244  auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3245  if (!instOp)
3246  return emitOpError() << " module: " << innerRef.getModule()
3247  << " does not contain any instance with symbol: "
3248  << innerRef.getName();
3249  expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3250  }
3251 
3252  // The instance path has been verified. Now verify the last element.
3253  auto leafRef = getNamepath()[getNamepath().size() - 1];
3254  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3255  if (!ns.lookup(innerRef)) {
3256  return emitOpError() << " operation with symbol: " << innerRef
3257  << " was not found ";
3258  }
3259  if (failed(checkExpectedModule(innerRef.getModule())))
3260  return failure();
3261  } else if (failed(checkExpectedModule(
3262  cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3263  return failure();
3264  }
3265  return success();
3266 }
3267 
3268 void HierPathOp::print(OpAsmPrinter &p) {
3269  p << " ";
3270 
3271  // Print visibility if present.
3272  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3273  if (auto visibility =
3274  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3275  p << visibility.getValue() << ' ';
3276 
3277  p.printSymbolName(getSymName());
3278  p << " [";
3279  llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3280  if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3281  p.printSymbolName(ref.getModule().getValue());
3282  p << "::";
3283  p.printSymbolName(ref.getName().getValue());
3284  } else {
3285  p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3286  }
3287  });
3288  p << "]";
3289  p.printOptionalAttrDict(
3290  (*this)->getAttrs(),
3291  {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3292 }
3293 
3294 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3295  // Parse the visibility attribute.
3296  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3297 
3298  // Parse the symbol name.
3299  StringAttr symName;
3300  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3301  result.attributes))
3302  return failure();
3303 
3304  // Parse the namepath.
3305  SmallVector<Attribute> namepath;
3306  if (parser.parseCommaSeparatedList(
3307  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3308  auto loc = parser.getCurrentLocation();
3309  SymbolRefAttr ref;
3310  if (parser.parseAttribute(ref))
3311  return failure();
3312 
3313  // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3314  auto pathLength = ref.getNestedReferences().size();
3315  if (pathLength == 0)
3316  namepath.push_back(
3317  FlatSymbolRefAttr::get(ref.getRootReference()));
3318  else if (pathLength == 1)
3319  namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3320  ref.getLeafReference()));
3321  else
3322  return parser.emitError(loc,
3323  "only one nested reference is allowed");
3324  return success();
3325  }))
3326  return failure();
3327  result.addAttribute("namepath",
3328  ArrayAttr::get(parser.getContext(), namepath));
3329 
3330  if (parser.parseOptionalAttrDict(result.attributes))
3331  return failure();
3332 
3333  return success();
3334 }
3335 
3336 //===----------------------------------------------------------------------===//
3337 // TriggeredOp
3338 //===----------------------------------------------------------------------===//
3339 
3340 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3341  EventControlAttr event, Value trigger,
3342  ValueRange inputs) {
3343  odsState.addOperands(trigger);
3344  odsState.addOperands(inputs);
3345  odsState.addAttribute(getEventAttrName(odsState.name), event);
3346  auto *r = odsState.addRegion();
3347  Block *b = new Block();
3348  r->push_back(b);
3349 
3350  llvm::SmallVector<Location> argLocs;
3351  llvm::transform(inputs, std::back_inserter(argLocs),
3352  [&](Value v) { return v.getLoc(); });
3353  b->addArguments(inputs.getTypes(), argLocs);
3354 }
3355 
3356 //===----------------------------------------------------------------------===//
3357 // TableGen generated logic.
3358 //===----------------------------------------------------------------------===//
3359 
3360 // Provide the autogenerated implementation guts for the Op classes.
3361 #define GET_OP_CLASSES
3362 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
int32_t width
Definition: FIRRTL.cpp:36
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition: HWOps.cpp:1076
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition: HWOps.cpp:491
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition: HWOps.cpp:1021
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2122
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:1856
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition: HWOps.cpp:1196
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition: HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition: HWOps.cpp:1013
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
Definition: HWOps.cpp:547
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2517
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition: HWOps.cpp:684
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition: HWOps.cpp:2086
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition: HWOps.cpp:585
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition: HWOps.cpp:1760
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition: HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition: HWOps.cpp:891
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition: HWOps.cpp:2139
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition: HWOps.cpp:2480
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition: HWOps.cpp:1266
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition: HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition: HWOps.cpp:1339
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1418
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition: HWOps.cpp:483
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition: HWOps.cpp:397
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition: HWOps.cpp:1927
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition: HWOps.cpp:899
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition: HWOps.cpp:2454
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition: HWOps.cpp:1438
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition: HWOps.cpp:1774
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition: HWOps.cpp:343
Delimiter
Definition: HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition: HWOps.cpp:2056
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
void setOutput(unsigned i, Value v)
Definition: HWOps.cpp:241
Value getInput(unsigned i)
Definition: HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition: HWOps.h:118
llvm::SmallVector< Value > inputArgs
Definition: HWOps.h:117
llvm::StringMap< unsigned > outputIdx
Definition: HWOps.h:116
llvm::StringMap< unsigned > inputIdx
Definition: HWOps.h:116
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition: HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition: HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition: HWVisitors.h:27
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2459
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
Definition: FIRRTLOps.cpp:301
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:1023
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New-style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1832
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition: HWOps.h:123
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
Definition: HWOps.cpp:59
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
Definition: HWOps.cpp:36
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition: HWOps.cpp:519
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition: HWOps.cpp:202
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
Definition: HWOps.cpp:221
bool isValidIndexBitWidth(Value index, Value array)
Definition: HWOps.cpp:48
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:110
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition: HWOps.cpp:513
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition: HWOps.cpp:537
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition: hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:182
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
mlir::Type type
Definition: HWTypes.h:31
mlir::StringAttr name
Definition: HWTypes.h:30
This holds the name, type, direction of a module's ports.