Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWOps.cpp
Go to the documentation of this file.
1//===- HWOps.cpp - Implement the HW operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the HW ops.
10//
11//===----------------------------------------------------------------------===//
12
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Interfaces/FunctionImplementation.h"
26#include "llvm/ADT/BitVector.h"
27#include "llvm/ADT/SmallPtrSet.h"
28#include "llvm/ADT/StringSet.h"
29
30using namespace circt;
31using namespace hw;
32using mlir::TypedAttr;
33
34/// Flip a port direction.
36 switch (direction) {
37 case ModulePort::Direction::Input:
38 return ModulePort::Direction::Output;
39 case ModulePort::Direction::Output:
40 return ModulePort::Direction::Input;
41 case ModulePort::Direction::InOut:
42 return ModulePort::Direction::InOut;
43 }
44 llvm_unreachable("unknown PortDirection");
45}
46
47bool hw::isValidIndexBitWidth(Value index, Value array) {
48 hw::ArrayType arrayType =
49 dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
50 assert(arrayType && "expected array type");
51 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
52 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
53 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
54 : indexWidth == requiredWidth;
55}
56
57/// Return true if the specified operation is a combinational logic op.
58bool hw::isCombinational(Operation *op) {
59 struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
60 bool visitInvalidTypeOp(Operation *op) { return false; }
61 bool visitUnhandledTypeOp(Operation *op) { return true; }
62 };
63
64 return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
65 IsCombClassifier().dispatchTypeOpVisitor(op);
66}
67
68static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
69 // A struct extract of a struct create -> corresponding struct create operand.
70 if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
71 return structCreate.getOperand(fieldIndex);
72 }
73
74 // Extracting injected field -> corresponding field
75 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
76 if (structInject.getFieldIndex() != fieldIndex)
77 return {};
78 return structInject.getNewValue();
79 }
80 return {};
81}
82
83static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
84 ArrayRef<Attribute> attrs) {
85 if (attrs.empty())
86 return ArrayAttr::get(context, {});
87 bool empty = true;
88 for (auto a : attrs)
89 if (a && !cast<DictionaryAttr>(a).empty()) {
90 empty = false;
91 break;
92 }
93 if (empty)
94 return ArrayAttr::get(context, {});
95 return ArrayAttr::get(context, attrs);
96}
97
98/// Get a special name to use when printing the entry block arguments of the
99/// region contained by an operation in this dialect.
100static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
101 OpAsmSetValueNameFn setNameFn) {
102 if (region.empty())
103 return;
104 // Assign port names to the bbargs.
105 auto module = cast<HWModuleOp>(region.getParentOp());
106
107 auto *block = &region.front();
108 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
109 auto name = module.getInputName(i);
110 // Let mlir deterministically convert names to valid identifiers
111 setNameFn(block->getArgument(i), name);
112 }
113}
114
/// Kind of delimiter enclosing a list: none, parentheses, or an optional pair
/// of angle brackets.
enum class Delimiter {
  None,                // no enclosing delimiters
  Paren,               // () enclosed list
  OptionalLessGreater, // <> enclosed list or absent
};
120
121/// Check parameter specified by `value` to see if it is valid according to the
122/// module's parameters. If not, emit an error to the diagnostic provided as an
123/// argument to the lambda 'instanceError' and return failure, otherwise return
124/// success.
125///
126/// If `disallowParamRefs` is true, then parameter references are not allowed.
127LogicalResult hw::checkParameterInContext(
128 Attribute value, ArrayAttr moduleParameters,
129 const instance_like_impl::EmitErrorFn &instanceError,
130 bool disallowParamRefs) {
131 // Literals are always ok. Their types are already known to match
132 // expectations.
133 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
134 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
135 return success();
136
137 // Check both subexpressions of an expression.
138 if (auto expr = dyn_cast<ParamExprAttr>(value)) {
139 for (auto op : expr.getOperands())
140 if (failed(checkParameterInContext(op, moduleParameters, instanceError,
141 disallowParamRefs)))
142 return failure();
143 return success();
144 }
145
146 // Parameter references need more analysis to make sure they are valid within
147 // this module.
148 if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
149 auto nameAttr = parameterRef.getName();
150
151 // Don't allow references to parameters from the default values of a
152 // parameter list.
153 if (disallowParamRefs) {
154 instanceError([&](auto &diag) {
155 diag << "parameter " << nameAttr
156 << " cannot be used as a default value for a parameter";
157 return false;
158 });
159 return failure();
160 }
161
162 // Find the corresponding attribute in the module.
163 for (auto param : moduleParameters) {
164 auto paramAttr = cast<ParamDeclAttr>(param);
165 if (paramAttr.getName() != nameAttr)
166 continue;
167
168 // If the types match then the reference is ok.
169 if (paramAttr.getType() == parameterRef.getType())
170 return success();
171
172 instanceError([&](auto &diag) {
173 diag << "parameter " << nameAttr << " used with type "
174 << parameterRef.getType() << "; should have type "
175 << paramAttr.getType();
176 return true;
177 });
178 return failure();
179 }
180
181 instanceError([&](auto &diag) {
182 diag << "use of unknown parameter " << nameAttr;
183 return true;
184 });
185 return failure();
186 }
187
188 instanceError([&](auto &diag) {
189 diag << "invalid parameter value " << value;
190 return false;
191 });
192 return failure();
193}
194
195/// Check parameter specified by `value` to see if it is valid within the scope
196/// of the specified module `module`. If not, emit an error at the location of
197/// `usingOp` and return failure, otherwise return success. If `usingOp` is
198/// null, then no diagnostic is generated.
199///
200/// If `disallowParamRefs` is true, then parameter references are not allowed.
201LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
202 Operation *usingOp,
203 bool disallowParamRefs) {
205 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
206 if (usingOp) {
207 auto diag = usingOp->emitOpError();
208 if (fn(diag))
209 diag.attachNote(module->getLoc()) << "module declared here";
210 }
211 };
212
213 return checkParameterInContext(value,
214 module->getAttrOfType<ArrayAttr>("parameters"),
215 emitError, disallowParamRefs);
216}
217
218/// Return true if the specified attribute tree is made up of nodes that are
219/// valid in a parameter expression.
220bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
221 return succeeded(checkParameterInContext(attr, module, nullptr, false));
222}
223
225 const ModulePortInfo &info,
226 Region &bodyRegion)
227 : info(info) {
228 inputArgs.resize(info.sizeInputs());
229 for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
230 inputIdx[info.at(i).name.str()] = i;
231 inputArgs[i] = barg;
232 }
233
235 for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
236 outputIdx[outputInfo.name.str()] = i;
237 }
238}
239
240void HWModulePortAccessor::setOutput(unsigned i, Value v) {
241 assert(outputOperands.size() > i && "invalid output index");
242 assert(outputOperands[i] == Value() && "output already set");
243 outputOperands[i] = v;
244}
245
247 assert(inputArgs.size() > i && "invalid input index");
248 return inputArgs[i];
249}
250Value HWModulePortAccessor::getInput(StringRef name) {
251 return getInput(inputIdx.find(name.str())->second);
252}
253void HWModulePortAccessor::setOutput(StringRef name, Value v) {
254 setOutput(outputIdx.find(name.str())->second, v);
255}
256
257//===----------------------------------------------------------------------===//
258// Declarative Canonicalization Patterns
259//===----------------------------------------------------------------------===//
260
261namespace {
262#include "circt/Dialect/HW/HWCanonicalization.cpp.inc"
263} // namespace
264
265//===----------------------------------------------------------------------===//
266// ConstantOp
267//===----------------------------------------------------------------------===//
268
269void ConstantOp::print(OpAsmPrinter &p) {
270 p << " ";
271 p.printAttribute(getValueAttr());
272 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
273}
274
275ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
276 IntegerAttr valueAttr;
277
278 if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
279 parser.parseOptionalAttrDict(result.attributes))
280 return failure();
281
282 result.addTypes(valueAttr.getType());
283 return success();
284}
285
286LogicalResult ConstantOp::verify() {
287 // If the result type has a bitwidth, then the attribute must match its width.
288 if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
289 return emitError(
290 "hw.constant attribute bitwidth doesn't match return type");
291
292 return success();
293}
294
295/// Build a ConstantOp from an APInt, infering the result type from the
296/// width of the APInt.
297void ConstantOp::build(OpBuilder &builder, OperationState &result,
298 const APInt &value) {
299
300 auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
301 auto attr = builder.getIntegerAttr(type, value);
302 return build(builder, result, type, attr);
303}
304
305/// Build a ConstantOp from an APInt, infering the result type from the
306/// width of the APInt.
307void ConstantOp::build(OpBuilder &builder, OperationState &result,
308 IntegerAttr value) {
309 return build(builder, result, value.getType(), value);
310}
311
312/// This builder allows construction of small signed integers like 0, 1, -1
313/// matching a specified MLIR IntegerType. This shouldn't be used for general
314/// constant folding because it only works with values that can be expressed in
315/// an int64_t. Use APInt's instead.
316void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
317 int64_t value) {
318 auto numBits = cast<IntegerType>(type).getWidth();
319 build(builder, result,
320 APInt(numBits, (uint64_t)value, /*isSigned=*/true,
321 /*implicitTrunc=*/true));
322}
323
324void ConstantOp::getAsmResultNames(
325 function_ref<void(Value, StringRef)> setNameFn) {
326 auto intTy = getType();
327 auto intCst = getValue();
328
329 // Sugar i1 constants with 'true' and 'false'.
330 if (cast<IntegerType>(intTy).getWidth() == 1)
331 return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
332
333 // Otherwise, build a complex name with the value and type.
334 SmallVector<char, 32> specialNameBuffer;
335 llvm::raw_svector_ostream specialName(specialNameBuffer);
336 specialName << 'c' << intCst << '_' << intTy;
337 setNameFn(getResult(), specialName.str());
338}
339
340OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
341 assert(adaptor.getOperands().empty() && "constant has no operands");
342 return getValueAttr();
343}
344
345//===----------------------------------------------------------------------===//
346// WireOp
347//===----------------------------------------------------------------------===//
348
349/// Check whether an operation has any additional attributes set beyond its
350/// standard list of attributes returned by `getAttributeNames`.
351template <class Op>
352static bool hasAdditionalAttributes(Op op,
353 ArrayRef<StringRef> ignoredAttrs = {}) {
354 auto names = op.getAttributeNames();
355 llvm::SmallDenseSet<StringRef> nameSet;
356 nameSet.reserve(names.size() + ignoredAttrs.size());
357 nameSet.insert(names.begin(), names.end());
358 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
359 return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
360 return !nameSet.contains(namedAttr.getName());
361 });
362}
363
364void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
365 // If the wire has an optional 'name' attribute, use it.
366 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
367 if (nameAttr && !nameAttr.getValue().empty())
368 setNameFn(getResult(), nameAttr.getValue());
369}
370
371std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
372
373OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
374 // If the wire has no additional attributes, no name, and no symbol, just
375 // forward its input.
376 if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
377 !getInnerSymAttr())
378 return getInput();
379 return {};
380}
381
382LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
383 // Block if the wire has any attributes.
384 if (hasAdditionalAttributes(wire, {"sv.namehint"}))
385 return failure();
386
387 // If the wire has a symbol, then we can't delete it.
388 if (wire.getInnerSymAttr())
389 return failure();
390
391 // If the wire has a name or an `sv.namehint` attribute, propagate it as an
392 // `sv.namehint` to the expression.
393 if (auto *inputOp = wire.getInput().getDefiningOp())
394 if (auto name = chooseName(wire, inputOp))
395 rewriter.modifyOpInPlace(inputOp,
396 [&] { inputOp->setAttr("sv.namehint", name); });
397
398 rewriter.replaceOp(wire, wire.getInput());
399 return success();
400}
401
402//===----------------------------------------------------------------------===//
403// AggregateConstantOp
404//===----------------------------------------------------------------------===//
405
406static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
407 // If this is a type alias, get the underlying type.
408 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
409 type = typeAlias.getCanonicalType();
410
411 if (auto structType = dyn_cast<StructType>(type)) {
412 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
413 if (!arrayAttr)
414 return op->emitOpError("expected array attribute for constant of type ")
415 << type;
416 if (structType.getElements().size() != arrayAttr.size())
417 return op->emitOpError("array attribute (")
418 << arrayAttr.size() << ") has wrong size for struct constant ("
419 << structType.getElements().size() << ")";
420
421 for (auto [attr, fieldInfo] :
422 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
423 if (failed(checkAttributes(op, attr, fieldInfo.type)))
424 return failure();
425 }
426 } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
427 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
428 if (!arrayAttr)
429 return op->emitOpError("expected array attribute for constant of type ")
430 << type;
431 if (arrayType.getNumElements() != arrayAttr.size())
432 return op->emitOpError("array attribute (")
433 << arrayAttr.size() << ") has wrong size for array constant ("
434 << arrayType.getNumElements() << ")";
435
436 auto elementType = arrayType.getElementType();
437 for (auto attr : arrayAttr.getValue()) {
438 if (failed(checkAttributes(op, attr, elementType)))
439 return failure();
440 }
441 } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
442 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
443 if (!arrayAttr)
444 return op->emitOpError("expected array attribute for constant of type ")
445 << type;
446 auto elementType = arrayType.getElementType();
447 if (arrayType.getNumElements() != arrayAttr.size())
448 return op->emitOpError("array attribute (")
449 << arrayAttr.size()
450 << ") has wrong size for unpacked array constant ("
451 << arrayType.getNumElements() << ")";
452
453 for (auto attr : arrayAttr.getValue()) {
454 if (failed(checkAttributes(op, attr, elementType)))
455 return failure();
456 }
457 } else if (auto enumType = dyn_cast<EnumType>(type)) {
458 auto stringAttr = dyn_cast<StringAttr>(attr);
459 if (!stringAttr)
460 return op->emitOpError("expected string attribute for constant of type ")
461 << type;
462 } else if (auto intType = dyn_cast<IntegerType>(type)) {
463 // Check the attribute kind is correct.
464 auto intAttr = dyn_cast<IntegerAttr>(attr);
465 if (!intAttr)
466 return op->emitOpError("expected integer attribute for constant of type ")
467 << type;
468 // Check the bitwidth is correct.
469 if (intAttr.getValue().getBitWidth() != intType.getWidth())
470 return op->emitOpError("hw.constant attribute bitwidth "
471 "doesn't match return type");
472 } else if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
473 if (typedAttr.getType() != type)
474 return op->emitOpError("typed attr doesn't match the return type ")
475 << type;
476 } else {
477 return op->emitOpError("unknown element type ") << type;
478 }
479 return success();
480}
481
482LogicalResult AggregateConstantOp::verify() {
483 return checkAttributes(*this, getFieldsAttr(), getType());
484}
485
486OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
487
488//===----------------------------------------------------------------------===//
489// ParamValueOp
490//===----------------------------------------------------------------------===//
491
492static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
493 Type &resultType) {
494 if (p.parseType(resultType) || p.parseEqual() ||
495 p.parseAttribute(value, resultType))
496 return failure();
497 return success();
498}
499
500static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
501 Type resultType) {
502 p << resultType << " = ";
503 p.printAttributeWithoutType(value);
504}
505
506LogicalResult ParamValueOp::verify() {
507 // Check that the attribute expression is valid in this module.
509 getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
510}
511
512OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
513 assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
514 return getValueAttr();
515}
516
517//===----------------------------------------------------------------------===//
518// HWModuleOp
519//===----------------------------------------------------------------------===/
520
521/// Return true if isAnyModule or instance.
522bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
523 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
524}
525
526/// Return the signature for a module as a function type from the module itself
527/// or from an hw::InstanceOp.
528FunctionType hw::getModuleType(Operation *moduleOrInstance) {
529 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
530 .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
531 SmallVector<Type> inputs(instance->getOperandTypes());
532 SmallVector<Type> results(instance->getResultTypes());
533 return FunctionType::get(instance->getContext(), inputs, results);
534 })
535 .Case<HWModuleLike>(
536 [](auto mod) { return mod.getHWModuleType().getFuncType(); })
537 .Default([](Operation *op) {
538 return cast<FunctionType>(
539 cast<mlir::FunctionOpInterface>(op).getFunctionType());
540 });
541}
542
543/// Return the name to use for the Verilog module that we're referencing
544/// here. This is typically the symbol, but can be overridden with the
545/// verilogName attribute.
546StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
547 auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
548 if (nameAttr)
549 return nameAttr;
550
551 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
552}
553
554template <typename ModuleTy>
555static void
556buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
557 const ModulePortInfo &ports, ArrayAttr parameters,
558 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
559 using namespace mlir::function_interface_impl;
560
561 // Add an attribute for the name.
562 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
563
564 SmallVector<Attribute> perPortAttrs;
565 SmallVector<ModulePort> portTypes;
566
567 for (auto elt : ports) {
568 portTypes.push_back(elt);
569 llvm::SmallVector<NamedAttribute> portAttrs;
570 if (elt.attrs)
571 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
572 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
573 }
574
575 // Allow clients to pass in null for the parameters list.
576 if (!parameters)
577 parameters = builder.getArrayAttr({});
578
579 // Record the argument and result types as an attribute.
580 auto type = ModuleType::get(builder.getContext(), portTypes);
581 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
582 TypeAttr::get(type));
583 result.addAttribute("per_port_attrs",
584 arrayOrEmpty(builder.getContext(), perPortAttrs));
585 result.addAttribute("parameters", parameters);
586 if (!comment)
587 comment = builder.getStringAttr("");
588 result.addAttribute("comment", comment);
589 result.addAttributes(attributes);
590 result.addRegion();
591}
592
593/// Internal implementation of argument/result insertion and removal on modules.
595 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
596 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
597 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
598 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
599 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
600 SmallVector<Location> &newArgLocs, Block *body = nullptr) {
601
602#ifndef NDEBUG
603 // Check that the `insertArgs` and `removeArgs` indices are in ascending
604 // order.
605 assert(llvm::is_sorted(insertArgs,
606 [](auto &a, auto &b) { return a.first < b.first; }) &&
607 "insertArgs must be in ascending order");
608 assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
609 "removeArgs must be in ascending order");
610#endif
611
612 auto oldArgCount = oldArgTypes.size();
613 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
614 assert((int)newArgCount >= 0);
615
616 newArgNames.reserve(newArgCount);
617 newArgTypes.reserve(newArgCount);
618 newArgAttrs.reserve(newArgCount);
619 newArgLocs.reserve(newArgCount);
620
621 auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
622 auto emptyDictAttr = DictionaryAttr::get(context, {});
623 auto unknownLoc = UnknownLoc::get(context);
624
625 BitVector erasedIndices;
626 if (body)
627 erasedIndices.resize(oldArgCount + insertArgs.size());
628
629 for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
630 // Insert new ports at this position.
631 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
632 auto port = insertArgs[0].second;
633 if (port.dir == ModulePort::Direction::InOut &&
634 !isa<InOutType>(port.type))
635 port.type = InOutType::get(port.type);
636 auto sym = port.getSym();
637 Attribute attr =
638 (sym && !sym.empty())
639 ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
640 : emptyDictAttr;
641 newArgNames.push_back(port.name);
642 newArgTypes.push_back(port.type);
643 newArgAttrs.push_back(attr);
644 insertArgs = insertArgs.drop_front();
645 LocationAttr loc = port.loc ? port.loc : unknownLoc;
646 newArgLocs.push_back(loc);
647 if (body)
648 body->insertArgument(idx++, port.type, loc);
649 }
650 if (argIdx == oldArgCount)
651 break;
652
653 // Migrate the old port at this position.
654 bool removed = false;
655 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
656 removeArgs = removeArgs.drop_front();
657 removed = true;
658 }
659
660 if (removed) {
661 if (body)
662 erasedIndices.set(idx);
663 } else {
664 newArgNames.push_back(oldArgNames[argIdx]);
665 newArgTypes.push_back(oldArgTypes[argIdx]);
666 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
667 : oldArgAttrs[argIdx]);
668 newArgLocs.push_back(oldArgLocs[argIdx]);
669 }
670 }
671
672 if (body)
673 body->eraseArguments(erasedIndices);
674
675 assert(newArgNames.size() == newArgCount);
676 assert(newArgTypes.size() == newArgCount);
677 assert(newArgAttrs.size() == newArgCount);
678 assert(newArgLocs.size() == newArgCount);
679}
680
/// Insert and remove ports of a module. The insertion and removal indices must
/// be in ascending order. The indices refer to the port positions before any
/// insertion or removal occurs. Ports inserted at the same index will appear in
/// the module in the same order as they were listed in the `insert*` array.
///
/// The operation must be any of the module-like operations.
///
/// If `body` is non-null, it is forwarded only to the *input* side of the
/// update, so its block arguments are inserted/erased together with the input
/// ports. Output port changes never touch `body`.
///
/// This is marked deprecated as it's only used from HandshakeToHW and
/// PortConverter and is likely broken and not currently tested. Users of this
/// are still written dealing with input and output ports separately, which is
/// an old and broken style.
[[deprecated]] static void
modifyModulePorts(Operation *op,
                  ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
                  ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
                  ArrayRef<unsigned> removeInputs,
                  ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
  auto moduleOp = cast<HWModuleLike>(op);
  auto *context = moduleOp.getContext();

  // Dig up the old argument and result data. All of it is captured before
  // anything below modifies the module or its body.
  auto oldArgNames = moduleOp.getInputNames();
  auto oldArgTypes = moduleOp.getInputTypes();
  auto oldArgAttrs = moduleOp.getAllInputAttrs();
  auto oldArgLocs = moduleOp.getInputLocs();

  auto oldResultNames = moduleOp.getOutputNames();
  auto oldResultTypes = moduleOp.getOutputTypes();
  auto oldResultAttrs = moduleOp.getAllOutputAttrs();
  auto oldResultLocs = moduleOp.getOutputLocs();

  // Modify the ports: inputs and outputs are rebuilt independently.
  SmallVector<Attribute> newArgNames, newResultNames;
  SmallVector<Type> newArgTypes, newResultTypes;
  SmallVector<Attribute> newArgAttrs, newResultAttrs;
  SmallVector<Location> newArgLocs, newResultLocs;

  modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
                   oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
                   newArgTypes, newArgAttrs, newArgLocs, body);

  modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
                   oldResultTypes, oldResultAttrs, oldResultLocs,
                   newResultNames, newResultTypes, newResultAttrs,
                   newResultLocs);

  // Update the module operation types and attributes. The module type is
  // derived from a FunctionType plus the new input/output names.
  auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
  auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
  moduleOp.setHWModuleType(modty);
  moduleOp.setAllInputAttrs(newArgAttrs);
  moduleOp.setAllOutputAttrs(newResultAttrs);

  // Port locations are set as one list: all inputs first, then all outputs.
  newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
  moduleOp.setAllPortLocs(newArgLocs);
}
737
738void HWModuleOp::build(OpBuilder &builder, OperationState &result,
739 StringAttr name, const ModulePortInfo &ports,
740 ArrayAttr parameters,
741 ArrayRef<NamedAttribute> attributes, StringAttr comment,
742 bool shouldEnsureTerminator) {
743 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
744 comment);
745
746 // Create a region and a block for the body.
747 auto *bodyRegion = result.regions[0].get();
748 Block *body = new Block();
749 bodyRegion->push_back(body);
750
751 // Add arguments to the body block.
752 auto unknownLoc = builder.getUnknownLoc();
753 for (auto port : ports.getInputs()) {
754 auto loc = port.loc ? Location(port.loc) : unknownLoc;
755 auto type = port.type;
756 if (port.isInOut() && !isa<InOutType>(type))
757 type = InOutType::get(type);
758 body->addArgument(type, loc);
759 }
760
761 // Add result ports attribute.
762 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
763 SmallVector<Attribute> resultLocs;
764 for (auto port : ports.getOutputs())
765 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
766 result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
767
768 if (shouldEnsureTerminator)
769 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
770}
771
772void HWModuleOp::build(OpBuilder &builder, OperationState &result,
773 StringAttr name, ArrayRef<PortInfo> ports,
774 ArrayAttr parameters,
775 ArrayRef<NamedAttribute> attributes,
776 StringAttr comment) {
777 build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
778 comment);
779}
780
781void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
782 StringAttr name, const ModulePortInfo &ports,
783 HWModuleBuilder modBuilder, ArrayAttr parameters,
784 ArrayRef<NamedAttribute> attributes,
785 StringAttr comment) {
786 build(builder, odsState, name, ports, parameters, attributes, comment,
787 /*shouldEnsureTerminator=*/false);
788 auto *bodyRegion = odsState.regions[0].get();
789 OpBuilder::InsertionGuard guard(builder);
790 auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
791 builder.setInsertionPointToEnd(&bodyRegion->front());
792 modBuilder(builder, accessor);
793 // Create output operands.
794 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
795 builder.create<hw::OutputOp>(odsState.location, outputOperands);
796}
797
/// Insert and/or erase input and output ports of this module; indices refer to
/// port positions before any modification. Delegates to modifyModulePorts
/// without a body block, so the entry block's arguments are not updated here.
/// NOTE(review): callers presumably adjust the body arguments themselves —
/// confirm.
void HWModuleOp::modifyPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
    ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
    ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
                    eraseOutputs);
}
805
806/// Return the name to use for the Verilog module that we're referencing
807/// here. This is typically the symbol, but can be overridden with the
808/// verilogName attribute.
809StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
810 if (auto vName = getVerilogNameAttr())
811 return vName;
812
813 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
814}
815
816StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
817 if (auto vName = getVerilogNameAttr()) {
818 return vName;
819 }
820 return (*this)->getAttrOfType<StringAttr>(
821 ::mlir::SymbolTable::getSymbolAttrName());
822}
823
824void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
825 StringAttr name, const ModulePortInfo &ports,
826 StringRef verilogName, ArrayAttr parameters,
827 ArrayRef<NamedAttribute> attributes) {
828 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
829 attributes, {});
830
831 // Add the port locations.
832 LocationAttr unknownLoc = builder.getUnknownLoc();
833 SmallVector<Attribute> portLocs;
834 for (auto elt : ports)
835 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
836 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
837
838 if (!verilogName.empty())
839 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
840}
841
842void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
843 StringAttr name, ArrayRef<PortInfo> ports,
844 StringRef verilogName, ArrayAttr parameters,
845 ArrayRef<NamedAttribute> attributes) {
846 build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
847 attributes);
848}
849
/// Insert and/or erase input and output ports of this external module; indices
/// refer to port positions before any modification. Delegates to
/// modifyModulePorts without a body block.
void HWModuleExternOp::modifyPorts(
    ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
    ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
    ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
  modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
                    eraseOutputs);
}
857
858void HWModuleExternOp::appendOutputs(
859 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
860
861void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
862 FlatSymbolRefAttr genKind, StringAttr name,
863 const ModulePortInfo &ports,
864 StringRef verilogName, ArrayAttr parameters,
865 ArrayRef<NamedAttribute> attributes) {
866 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
867 attributes, {});
868 // Add the port locations.
869 LocationAttr unknownLoc = builder.getUnknownLoc();
870 SmallVector<Attribute> portLocs;
871 for (auto elt : ports)
872 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
873 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
874
875 result.addAttribute("generatorKind", genKind);
876 if (!verilogName.empty())
877 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
878}
879
880void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
881 FlatSymbolRefAttr genKind, StringAttr name,
882 ArrayRef<PortInfo> ports, StringRef verilogName,
883 ArrayAttr parameters,
884 ArrayRef<NamedAttribute> attributes) {
885 build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
886 parameters, attributes);
887}
888
889void HWModuleGeneratedOp::modifyPorts(
890 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
891 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
892 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
893 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
894 eraseOutputs);
895}
896
897void HWModuleGeneratedOp::appendOutputs(
898 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
899
900static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
901 for (auto &argAttr : attrs)
902 if (argAttr.getName() == name)
903 return true;
904 return false;
905}
906
907template <typename ModuleTy>
908static ParseResult parseHWModuleOp(OpAsmParser &parser,
909 OperationState &result) {
910
911 using namespace mlir::function_interface_impl;
912 auto builder = parser.getBuilder();
913 auto loc = parser.getCurrentLocation();
914
915 // Parse the visibility attribute.
916 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
917
918 // Parse the name as a symbol.
919 StringAttr nameAttr;
920 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
921 result.attributes))
922 return failure();
923
924 // Parse the generator information.
925 FlatSymbolRefAttr kindAttr;
926 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
927 if (parser.parseComma() ||
928 parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
929 return failure();
930 }
931 }
932
933 // Parse the parameters.
934 ArrayAttr parameters;
935 if (parseOptionalParameterList(parser, parameters))
936 return failure();
937
938 SmallVector<module_like_impl::PortParse> ports;
939 TypeAttr modType;
940 if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
941 return failure();
942
943 // Parse the attribute dict.
944 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
945 return failure();
946
947 if (hasAttribute("parameters", result.attributes)) {
948 parser.emitError(loc, "explicit `parameters` attributes not allowed");
949 return failure();
950 }
951
952 result.addAttribute("parameters", parameters);
953 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
954
955 // Convert the specified array of dictionary attrs (which may have null
956 // entries) to an ArrayAttr of dictionaries.
957 SmallVector<Attribute> attrs;
958 for (auto &port : ports)
959 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
960 // Add the attributes to the ports.
961 auto nonEmptyAttrsFn = [](Attribute attr) {
962 return attr && !cast<DictionaryAttr>(attr).empty();
963 };
964 if (llvm::any_of(attrs, nonEmptyAttrsFn))
965 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
966 builder.getArrayAttr(attrs));
967
968 // Add the port locations.
969 auto unknownLoc = builder.getUnknownLoc();
970 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
971 return attr && cast<Location>(attr) != unknownLoc;
972 };
973 SmallVector<Attribute> locs;
974 StringAttr portLocsAttrName;
975 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
976 // Plain modules only store the output port locations, as the input port
977 // locations will be stored in the basic block arguments.
978 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
979 for (auto &port : ports)
980 if (port.direction == ModulePort::Direction::Output)
981 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
982 } else {
983 // All other modules store all port locations in a single array.
984 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
985 for (auto &port : ports)
986 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
987 }
988 if (llvm::any_of(locs, nonEmptyLocsFn))
989 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
990
991 // Add the entry block arguments.
992 SmallVector<OpAsmParser::Argument, 4> entryArgs;
993 for (auto &port : ports)
994 if (port.direction != ModulePort::Direction::Output)
995 entryArgs.push_back(port);
996
997 // Parse the optional function body.
998 auto *body = result.addRegion();
999 if (std::is_same_v<ModuleTy, HWModuleOp>) {
1000 if (parser.parseRegion(*body, entryArgs))
1001 return failure();
1002
1003 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1004 }
1005 return success();
1006}
1007
1008ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1009 return parseHWModuleOp<HWModuleOp>(parser, result);
1010}
1011
1012ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1013 OperationState &result) {
1014 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1015}
1016
1017ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1018 OperationState &result) {
1019 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1020}
1021
1022FunctionType getHWModuleOpType(Operation *op) {
1023 if (auto mod = dyn_cast<HWModuleLike>(op))
1024 return mod.getHWModuleType().getFuncType();
1025 return cast<FunctionType>(
1026 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1027}
1028
1029template <typename ModuleTy>
1030static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1031 p << ' ';
1032 // Print the visibility of the module.
1033 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1034 if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1035 visibilityAttrName))
1036 p << visibility.getValue() << ' ';
1037
1038 // Print the operation and the function name.
1039 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1040 if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1041 p << ", ";
1042 p.printSymbolName(gen.getGeneratorKind());
1043 }
1044
1045 // Print the parameter list if present.
1046 printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1047
1049
1050 SmallVector<StringRef, 3> omittedAttrs;
1051 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1052 omittedAttrs.push_back("generatorKind");
1053 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1054 omittedAttrs.push_back(mod.getResultLocsAttrName());
1055 else
1056 omittedAttrs.push_back(mod.getPortLocsAttrName());
1057 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1058 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1059 omittedAttrs.push_back(mod.getParametersAttrName());
1060 omittedAttrs.push_back(visibilityAttrName);
1061 if (auto cmt =
1062 mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1063 if (cmt.getValue().empty())
1064 omittedAttrs.push_back("comment");
1065
1066 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1067 omittedAttrs);
1068}
1069
1070void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1071void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1072
1073void HWModuleOp::print(OpAsmPrinter &p) {
1074 printModuleOp(p, *this);
1075
1076 // Print the body if this is not an external function.
1077 Region &body = getBody();
1078 if (!body.empty()) {
1079 p << " ";
1080 p.printRegion(body, /*printEntryBlockArgs=*/false,
1081 /*printBlockTerminators=*/true);
1082 }
1083}
1084
1085static LogicalResult verifyModuleCommon(HWModuleLike module) {
1086 assert(isa<HWModuleLike>(module) &&
1087 "verifier hook should only be called on modules");
1088
1089 SmallPtrSet<Attribute, 4> paramNames;
1090
1091 // Check parameter default values are sensible.
1092 for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1093 auto paramAttr = cast<ParamDeclAttr>(param);
1094
1095 // Check that we don't have any redundant parameter names. These are
1096 // resolved by string name: reuse of the same name would cause ambiguities.
1097 if (!paramNames.insert(paramAttr.getName()).second)
1098 return module->emitOpError("parameter ")
1099 << paramAttr << " has the same name as a previous parameter";
1100
1101 // Default values are allowed to be missing, check them if present.
1102 auto value = paramAttr.getValue();
1103 if (!value)
1104 continue;
1105
1106 auto typedValue = dyn_cast<TypedAttr>(value);
1107 if (!typedValue)
1108 return module->emitOpError("parameter ")
1109 << paramAttr << " should have a typed value; has value " << value;
1110
1111 if (typedValue.getType() != paramAttr.getType())
1112 return module->emitOpError("parameter ")
1113 << paramAttr << " should have type " << paramAttr.getType()
1114 << "; has type " << typedValue.getType();
1115
1116 // Verify that this is a valid parameter value, disallowing parameter
1117 // references. We could allow parameters to refer to each other in the
1118 // future with lexical ordering if there is a need.
1119 if (failed(checkParameterInContext(value, module, module,
1120 /*disallowParamRefs=*/true)))
1121 return failure();
1122 }
1123 return success();
1124}
1125
1126LogicalResult HWModuleOp::verify() {
1127 if (failed(verifyModuleCommon(*this)))
1128 return failure();
1129
1130 auto type = getModuleType();
1131 auto *body = getBodyBlock();
1132
1133 // Verify the number of block arguments.
1134 auto numInputs = type.getNumInputs();
1135 if (body->getNumArguments() != numInputs)
1136 return emitOpError("entry block must have")
1137 << numInputs << " arguments to match module signature";
1138
1139 return success();
1140}
1141
1142LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1143
1144std::pair<StringAttr, BlockArgument>
1145HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1146 // Find a unique name for the wire.
1147 Namespace ns;
1148 auto ports = getPortList();
1149 for (auto port : ports)
1150 ns.newName(port.name.getValue());
1151 auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1152
1153 Block *body = getBodyBlock();
1154
1155 // Create a new port for the host clock.
1156 PortInfo port;
1157 port.name = nameAttr;
1159 port.type = ty;
1160 modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1161 body);
1162
1163 // Add a new argument.
1164 return {nameAttr, body->getArgument(index)};
1165}
1166
1167void HWModuleOp::insertOutputs(unsigned index,
1168 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1169
1170 auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1171 assert(index <= output->getNumOperands() && "invalid output index");
1172
1173 // Rewrite the port list of the module.
1174 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1175 for (auto &[name, value] : outputs) {
1176 PortInfo port;
1177 port.name = name;
1179 port.type = value.getType();
1180 indexedNewPorts.emplace_back(index, port);
1181 }
1182 modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1183 getBodyBlock());
1184
1185 // Rewrite the output op.
1186 for (auto &[name, value] : outputs)
1187 output->insertOperands(index++, value);
1188}
1189
1190void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1191 return insertOutputs(getNumOutputPorts(), outputs);
1192}
1193
1194void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1195 mlir::OpAsmSetValueNameFn setNameFn) {
1196 getAsmBlockArgumentNamesImpl(region, setNameFn);
1197}
1198
1199void HWModuleExternOp::getAsmBlockArgumentNames(
1200 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1201 getAsmBlockArgumentNamesImpl(region, setNameFn);
1202}
1203
1204template <typename ModTy>
1205static SmallVector<Location> getAllPortLocs(ModTy module) {
1206 auto locs = module.getPortLocs();
1207 if (locs) {
1208 SmallVector<Location> retval;
1209 retval.reserve(locs->size());
1210 for (auto l : *locs)
1211 retval.push_back(cast<Location>(l));
1212 // Either we have a length of 0 or the correct length
1213 assert(!locs->size() || locs->size() == module.getNumPorts());
1214 return retval;
1215 }
1216 return SmallVector<Location>(module.getNumPorts(),
1217 UnknownLoc::get(module.getContext()));
1218}
1219
1220SmallVector<Location> HWModuleOp::getAllPortLocs() {
1221 SmallVector<Location> portLocs;
1222 portLocs.reserve(getNumPorts());
1223 auto resultLocs = getResultLocsAttr();
1224 unsigned inputCount = 0;
1225 auto modType = getModuleType();
1226 auto unknownLoc = UnknownLoc::get(getContext());
1227 auto *body = getBodyBlock();
1228 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1229 if (modType.isOutput(i)) {
1230 auto loc = resultLocs
1231 ? cast<Location>(
1232 resultLocs.getValue()[portLocs.size() - inputCount])
1233 : unknownLoc;
1234 portLocs.push_back(loc);
1235 } else {
1236 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1237 portLocs.push_back(loc);
1238 ++inputCount;
1239 }
1240 }
1241 return portLocs;
1242}
1243
1244SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1245 return ::getAllPortLocs(*this);
1246}
1247
1248SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1249 return ::getAllPortLocs(*this);
1250}
1251
1252void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1253 SmallVector<Attribute> resultLocs;
1254 unsigned inputCount = 0;
1255 auto modType = getModuleType();
1256 auto *body = getBodyBlock();
1257 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1258 if (modType.isOutput(i))
1259 resultLocs.push_back(locs[i]);
1260 else
1261 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1262 }
1263 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1264}
1265
1266void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1267 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1268}
1269
1270void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1271 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1272}
1273
1274template <typename ModTy>
1275static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1276 auto numInputs = module.getNumInputPorts();
1277 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1278 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1279 auto oldType = module.getModuleType();
1280 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1281 oldType.getPorts().end());
1282 for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1283 newPorts[i].name = cast<StringAttr>(names[i]);
1284 auto newType = ModuleType::get(module.getContext(), newPorts);
1285 module.setModuleType(newType);
1286}
1287
1288void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1289 ::setAllPortNames(names, *this);
1290}
1291
1292void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1293 ::setAllPortNames(names, *this);
1294}
1295
1296void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1297 ::setAllPortNames(names, *this);
1298}
1299
1300ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1301 auto attrs = getPerPortAttrs();
1302 if (attrs && !attrs->empty())
1303 return attrs->getValue();
1304 return {};
1305}
1306
1307ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1308 auto attrs = getPerPortAttrs();
1309 if (attrs && !attrs->empty())
1310 return attrs->getValue();
1311 return {};
1312}
1313
1314ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1315 auto attrs = getPerPortAttrs();
1316 if (attrs && !attrs->empty())
1317 return attrs->getValue();
1318 return {};
1319}
1320
1321void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1322 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1323}
1324
1325void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1326 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1327}
1328
1329void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1330 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1331}
1332
1333void HWModuleOp::removeAllPortAttrs() {
1334 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1335}
1336
1337void HWModuleExternOp::removeAllPortAttrs() {
1338 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1339}
1340
1341void HWModuleGeneratedOp::removeAllPortAttrs() {
1342 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1343}
1344
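// NOTE: Changing the number of ports here may have unexpected results. The
// existing port attributes are only padded (with empty dictionaries) or
// truncated to the new input and output counts; they are not remapped to
// the new ports.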
1346
1347template <typename ModTy>
1348static void setHWModuleType(ModTy &mod, ModuleType type) {
1349 auto argAttrs = mod.getAllInputAttrs();
1350 auto resAttrs = mod.getAllOutputAttrs();
1351 mod.setModuleTypeAttr(TypeAttr::get(type));
1352 unsigned newNumArgs = type.getNumInputs();
1353 unsigned newNumResults = type.getNumOutputs();
1354
1355 auto emptyDict = DictionaryAttr::get(mod.getContext());
1356 argAttrs.resize(newNumArgs, emptyDict);
1357 resAttrs.resize(newNumResults, emptyDict);
1358
1359 SmallVector<Attribute> attrs;
1360 attrs.append(argAttrs.begin(), argAttrs.end());
1361 attrs.append(resAttrs.begin(), resAttrs.end());
1362
1363 if (attrs.empty())
1364 return mod.removeAllPortAttrs();
1365 mod.setAllPortAttrs(attrs);
1366}
1367
1368void HWModuleOp::setHWModuleType(ModuleType type) {
1369 return ::setHWModuleType(*this, type);
1370}
1371
1372void HWModuleExternOp::setHWModuleType(ModuleType type) {
1373 return ::setHWModuleType(*this, type);
1374}
1375
1376void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1377 return ::setHWModuleType(*this, type);
1378}
1379
1380/// Lookup the generator for the symbol. This returns null on
1381/// invalid IR.
1382Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1383 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1384 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1385}
1386
1387LogicalResult
1388HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1389 auto *referencedKind =
1390 symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1391
1392 if (referencedKind == nullptr)
1393 return emitError("Cannot find generator definition '")
1394 << getGeneratorKind() << "'";
1395
1396 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1397 return emitError("Symbol resolved to '")
1398 << referencedKind->getName()
1399 << "' which is not a HWGeneratorSchemaOp";
1400
1401 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1402 auto paramRef = referencedKindOp.getRequiredAttrs();
1403 auto dict = (*this)->getAttrDictionary();
1404 for (auto str : paramRef) {
1405 auto strAttr = dyn_cast<StringAttr>(str);
1406 if (!strAttr)
1407 return emitError("Unknown attribute type, expected a string");
1408 if (!dict.get(strAttr.getValue()))
1409 return emitError("Missing attribute '") << strAttr.getValue() << "'";
1410 }
1411
1412 return success();
1413}
1414
1415LogicalResult HWModuleGeneratedOp::verify() {
1416 return verifyModuleCommon(*this);
1417}
1418
1419void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1420 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1421 getAsmBlockArgumentNamesImpl(region, setNameFn);
1422}
1423
1424LogicalResult HWModuleOp::verifyBody() { return success(); }
1425
1426template <typename ModuleTy>
1427static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1428 auto modTy = mod.getHWModuleType();
1429 auto emptyDict = DictionaryAttr::get(mod.getContext());
1430 SmallVector<PortInfo> retval;
1431 auto locs = mod.getAllPortLocs();
1432 for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1433 LocationAttr loc = locs[i];
1434 DictionaryAttr attrs =
1435 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1436 if (!attrs)
1437 attrs = emptyDict;
1438 retval.push_back({modTy.getPorts()[i],
1439 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1440 : modTy.getInputIdForPortId(i),
1441 attrs, loc});
1442 }
1443 return retval;
1444}
1445
1446template <typename ModuleTy>
1447static PortInfo getPort(ModuleTy &mod, size_t idx) {
1448 auto modTy = mod.getHWModuleType();
1449 auto emptyDict = DictionaryAttr::get(mod.getContext());
1450 LocationAttr loc = mod.getPortLoc(idx);
1451 DictionaryAttr attrs =
1452 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1453 if (!attrs)
1454 attrs = emptyDict;
1455 return {modTy.getPorts()[idx],
1456 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1457 : modTy.getInputIdForPortId(idx),
1458 attrs, loc};
1459}
1460
1461//===----------------------------------------------------------------------===//
1462// InstanceOp
1463//===----------------------------------------------------------------------===//
1464
1465/// Create a instance that refers to a known module.
1466void InstanceOp::build(OpBuilder &builder, OperationState &result,
1467 Operation *module, StringAttr name,
1468 ArrayRef<Value> inputs, ArrayAttr parameters,
1469 InnerSymAttr innerSym) {
1470 if (!parameters)
1471 parameters = builder.getArrayAttr({});
1472
1473 auto mod = cast<hw::HWModuleLike>(module);
1474 auto argNames = builder.getArrayAttr(mod.getInputNames());
1475 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1476
1477 // Try to resolve the parameterized module type. If failed, use the module's
1478 // parmeterized type. If the client doesn't fix this error, the verifier will
1479 // fail.
1480 ModuleType modType = mod.getHWModuleType();
1481 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1482 parameters, result.location, /*emitErrors=*/false);
1483 if (succeeded(resolvedModType))
1484 modType = *resolvedModType;
1485 FunctionType funcType = resolvedModType->getFuncType();
1486 build(builder, result, funcType.getResults(), name,
1487 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1488 argNames, resultNames, parameters, innerSym, /*doNotPrint=*/{});
1489}
1490
1491std::optional<size_t> InstanceOp::getTargetResultIndex() {
1492 // Inner symbols on instance operations target the op not any result.
1493 return std::nullopt;
1494}
1495
1496LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1498 *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1499 getResultNames(), getParameters(), symbolTable);
1500}
1501
1502LogicalResult InstanceOp::verify() {
1503 auto module = (*this)->getParentOfType<HWModuleOp>();
1504 if (!module)
1505 return success();
1506
1507 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1509 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1510 auto diag = emitOpError();
1511 if (fn(diag))
1512 diag.attachNote(module->getLoc()) << "module declared here";
1513 };
1515 getParameters(), moduleParameters, emitError);
1516}
1517
1518ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1519 StringAttr instanceNameAttr;
1520 InnerSymAttr innerSym;
1521 FlatSymbolRefAttr moduleNameAttr;
1522 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1523 SmallVector<Type, 1> inputsTypes, allResultTypes;
1524 ArrayAttr argNames, resultNames, parameters;
1525 auto noneType = parser.getBuilder().getType<NoneType>();
1526
1527 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1528 result.attributes))
1529 return failure();
1530
1531 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1532 // Parsing an optional symbol name doesn't fail, so no need to check the
1533 // result.
1534 if (parser.parseCustomAttributeWithFallback(innerSym))
1535 return failure();
1536 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1537 }
1538
1539 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1540 if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1541 result.attributes) ||
1542 parser.getCurrentLocation(&parametersLoc) ||
1543 parseOptionalParameterList(parser, parameters) ||
1544 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1545 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1546 result.operands) ||
1547 parser.parseArrow() ||
1548 parseOutputPortList(parser, allResultTypes, resultNames) ||
1549 parser.parseOptionalAttrDict(result.attributes)) {
1550 return failure();
1551 }
1552
1553 result.addAttribute("argNames", argNames);
1554 result.addAttribute("resultNames", resultNames);
1555 result.addAttribute("parameters", parameters);
1556 result.addTypes(allResultTypes);
1557 return success();
1558}
1559
1560void InstanceOp::print(OpAsmPrinter &p) {
1561 p << ' ';
1562 p.printAttributeWithoutType(getInstanceNameAttr());
1563 if (auto attr = getInnerSymAttr()) {
1564 p << " sym ";
1565 attr.print(p);
1566 }
1567 p << ' ';
1568 p.printAttributeWithoutType(getModuleNameAttr());
1569 printOptionalParameterList(p, *this, getParameters());
1570 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1571 getArgNames());
1572 p << " -> ";
1573 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1574
1575 p.printOptionalAttrDict(
1576 (*this)->getAttrs(),
1577 /*elidedAttrs=*/{"instanceName",
1578 InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1579 "argNames", "resultNames", "parameters"});
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// InstanceChoiceOp
1584//===----------------------------------------------------------------------===//
1585
1586std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1587 // Inner symbols on instance operations target the op not any result.
1588 return std::nullopt;
1589}
1590
1591LogicalResult
1592InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1593 for (Attribute name : getModuleNamesAttr()) {
1595 *this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1596 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1597 return failure();
1598 }
1599 }
1600 return success();
1601}
1602
1603LogicalResult InstanceChoiceOp::verify() {
1604 auto module = (*this)->getParentOfType<HWModuleOp>();
1605 if (!module)
1606 return success();
1607
1608 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1610 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1611 auto diag = emitOpError();
1612 if (fn(diag))
1613 diag.attachNote(module->getLoc()) << "module declared here";
1614 };
1616 getParameters(), moduleParameters, emitError);
1617}
1618
1619ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1620 OperationState &result) {
1621 StringAttr optionNameAttr;
1622 StringAttr instanceNameAttr;
1623 InnerSymAttr innerSym;
1624 SmallVector<Attribute> moduleNames;
1625 SmallVector<Attribute> caseNames;
1626 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1627 SmallVector<Type, 1> inputsTypes, allResultTypes;
1628 ArrayAttr argNames, resultNames, parameters;
1629 auto noneType = parser.getBuilder().getType<NoneType>();
1630
1631 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1632 result.attributes))
1633 return failure();
1634
1635 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1636 // Parsing an optional symbol name doesn't fail, so no need to check the
1637 // result.
1638 if (parser.parseCustomAttributeWithFallback(innerSym))
1639 return failure();
1640 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1641 }
1642
1643 if (parser.parseKeyword("option") ||
1644 parser.parseAttribute(optionNameAttr, noneType, "optionName",
1645 result.attributes))
1646 return failure();
1647
1648 FlatSymbolRefAttr defaultModuleName;
1649 if (parser.parseAttribute(defaultModuleName))
1650 return failure();
1651 moduleNames.push_back(defaultModuleName);
1652
1653 while (succeeded(parser.parseOptionalKeyword("or"))) {
1654 FlatSymbolRefAttr moduleName;
1655 StringAttr targetName;
1656 if (parser.parseAttribute(moduleName) ||
1657 parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1658 return failure();
1659 moduleNames.push_back(moduleName);
1660 caseNames.push_back(targetName);
1661 }
1662
1663 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1664 if (parser.getCurrentLocation(&parametersLoc) ||
1665 parseOptionalParameterList(parser, parameters) ||
1666 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1667 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1668 result.operands) ||
1669 parser.parseArrow() ||
1670 parseOutputPortList(parser, allResultTypes, resultNames) ||
1671 parser.parseOptionalAttrDict(result.attributes)) {
1672 return failure();
1673 }
1674
1675 result.addAttribute("moduleNames",
1676 ArrayAttr::get(parser.getContext(), moduleNames));
1677 result.addAttribute("caseNames",
1678 ArrayAttr::get(parser.getContext(), caseNames));
1679 result.addAttribute("argNames", argNames);
1680 result.addAttribute("resultNames", resultNames);
1681 result.addAttribute("parameters", parameters);
1682 result.addTypes(allResultTypes);
1683 return success();
1684}
1685
1686void InstanceChoiceOp::print(OpAsmPrinter &p) {
1687 p << ' ';
1688 p.printAttributeWithoutType(getInstanceNameAttr());
1689 if (auto attr = getInnerSymAttr()) {
1690 p << " sym ";
1691 attr.print(p);
1692 }
1693 p << " option " << getOptionNameAttr() << ' ';
1694
1695 auto moduleNames = getModuleNamesAttr();
1696 auto caseNames = getCaseNamesAttr();
1697 assert(moduleNames.size() == caseNames.size() + 1);
1698
1699 p.printAttributeWithoutType(moduleNames[0]);
1700 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1701 p << " or ";
1702 p.printAttributeWithoutType(moduleNames[i + 1]);
1703 p << " if ";
1704 p.printAttributeWithoutType(caseNames[i]);
1705 }
1706
1707 printOptionalParameterList(p, *this, getParameters());
1708 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1709 getArgNames());
1710 p << " -> ";
1711 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1712
1713 p.printOptionalAttrDict(
1714 (*this)->getAttrs(),
1715 /*elidedAttrs=*/{"instanceName",
1716 InnerSymbolTable::getInnerSymbolAttrName(),
1717 "moduleNames", "caseNames", "argNames", "resultNames",
1718 "parameters", "optionName"});
1719}
1720
1721ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1722 SmallVector<Attribute> moduleNames;
1723 for (Attribute attr : getModuleNamesAttr()) {
1724 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1725 }
1726 return ArrayAttr::get(getContext(), moduleNames);
1727}
1728
1729//===----------------------------------------------------------------------===//
1730// HWOutputOp
1731//===----------------------------------------------------------------------===//
1732
1733/// Verify that the num of operands and types fit the declared results.
1734LogicalResult OutputOp::verify() {
1735 // Check that the we (hw.output) have the same number of operands as our
1736 // region has results.
1737 ModuleType modType;
1738 if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1739 modType = mod.getHWModuleType();
1740 else {
1741 emitOpError("must have a module parent");
1742 return failure();
1743 }
1744 auto modResults = modType.getOutputTypes();
1745 OperandRange outputValues = getOperands();
1746 if (modResults.size() != outputValues.size()) {
1747 emitOpError("must have same number of operands as region results.");
1748 return failure();
1749 }
1750
1751 // Check that the types of our operands and the region's results match.
1752 for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1753 if (modResults[i] != outputValues[i].getType()) {
1754 emitOpError("output types must match module. In "
1755 "operand ")
1756 << i << ", expected " << modResults[i] << ", but got "
1757 << outputValues[i].getType() << ".";
1758 return failure();
1759 }
1760 }
1761
1762 return success();
1763}
1764
1765//===----------------------------------------------------------------------===//
1766// Other Operations
1767//===----------------------------------------------------------------------===//
1768
1769static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1770 Type &idxType) {
1771 Type type;
1772 if (p.parseType(type))
1773 return p.emitError(p.getCurrentLocation(), "Expected type");
1774 auto arrType = type_dyn_cast<ArrayType>(type);
1775 if (!arrType)
1776 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1777 srcType = type;
1778 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1779 idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1780 return success();
1781}
1782
1783static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1784 Type idxType) {
1785 p.printType(srcType);
1786}
1787
1788ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1789 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1790 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1791 Type elemType;
1792
1793 if (parser.parseOperandList(operands) ||
1794 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1795 parser.parseType(elemType))
1796 return failure();
1797
1798 if (operands.size() == 0)
1799 return parser.emitError(inputOperandsLoc,
1800 "Cannot construct an array of length 0");
1801 result.addTypes({ArrayType::get(elemType, operands.size())});
1802
1803 for (auto operand : operands)
1804 if (parser.resolveOperand(operand, elemType, result.operands))
1805 return failure();
1806 return success();
1807}
1808
1809void ArrayCreateOp::print(OpAsmPrinter &p) {
1810 p << " ";
1811 p.printOperands(getInputs());
1812 p.printOptionalAttrDict((*this)->getAttrs());
1813 p << " : " << getInputs()[0].getType();
1814}
1815
1816void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1817 ValueRange values) {
1818 assert(values.size() > 0 && "Cannot build array of zero elements");
1819 Type elemType = values[0].getType();
1820 assert(llvm::all_of(
1821 values,
1822 [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1823 "All values must have same type.");
1824 build(b, state, ArrayType::get(elemType, values.size()), values);
1825}
1826
1827LogicalResult ArrayCreateOp::verify() {
1828 unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1829 if (getInputs().size() != returnSize)
1830 return failure();
1831 return success();
1832}
1833
1834OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1835 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1836 return {};
1837 return ArrayAttr::get(getContext(), adaptor.getInputs());
1838}
1839
1840// Check whether an integer value is an offset from a base.
1841bool hw::isOffset(Value base, Value index, uint64_t offset) {
1842 if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1843 if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1844 // If both values are a constant, check if index == base + offset.
1845 // To account for overflow, the addition is performed with an extra bit
1846 // and the offset is asserted to fit in the bit width of the base.
1847 auto baseValue = constBase.getValue();
1848 auto indexValue = constIndex.getValue();
1849
1850 unsigned bits = baseValue.getBitWidth();
1851 assert(bits == indexValue.getBitWidth() && "mismatched widths");
1852
1853 if (bits < 64 && offset >= (1ull << bits))
1854 return false;
1855
1856 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1857 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1858 return baseExt + offset == indexExt;
1859 }
1860 }
1861 return false;
1862}
1863
1864// Canonicalize a create of consecutive elements to a slice.
1865static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1866 PatternRewriter &rewriter) {
1867 // Do not canonicalize create of get into a slice.
1868 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1869 if (arrayTy.getNumElements() <= 1)
1870 return failure();
1871 auto elemTy = arrayTy.getElementType();
1872
1873 // Check if create arguments are consecutive elements of the same array.
1874 // Attempt to break a create of gets into a sequence of consecutive intervals.
1875 struct Chunk {
1876 Value input;
1877 Value index;
1878 size_t size;
1879 };
1880 SmallVector<Chunk> chunks;
1881 for (Value value : llvm::reverse(op.getInputs())) {
1882 auto get = value.getDefiningOp<ArrayGetOp>();
1883 if (!get)
1884 return failure();
1885
1886 Value input = get.getInput();
1887 Value index = get.getIndex();
1888 if (!chunks.empty()) {
1889 auto &c = *chunks.rbegin();
1890 if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1891 c.size++;
1892 continue;
1893 }
1894 }
1895
1896 chunks.push_back(Chunk{input, index, 1});
1897 }
1898
1899 // If there is a single slice, eliminate the create.
1900 if (chunks.size() == 1) {
1901 auto &chunk = chunks[0];
1902 rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1903 op.getLoc(), arrayTy, chunk.input, chunk.index));
1904 return success();
1905 }
1906
1907 // If the number of chunks is significantly less than the number of
1908 // elements, replace the create with a concat of the identified slices.
1909 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1910 SmallVector<Value> slices;
1911 for (auto &chunk : llvm::reverse(chunks)) {
1912 auto sliceTy = ArrayType::get(elemTy, chunk.size);
1913 slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1914 op.getLoc(), sliceTy, chunk.input, chunk.index));
1915 }
1916 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1917 return success();
1918 }
1919
1920 return failure();
1921}
1922
1923LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1924 PatternRewriter &rewriter) {
1925 if (succeeded(foldCreateToSlice(op, rewriter)))
1926 return success();
1927 return failure();
1928}
1929
1930Value ArrayCreateOp::getUniformElement() {
1931 if (!getInputs().empty() && llvm::all_equal(getInputs()))
1932 return getInputs()[0];
1933 return {};
1934}
1935
1936static std::optional<uint64_t> getUIntFromValue(Value value) {
1937 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1938 if (!idxOp)
1939 return std::nullopt;
1940 APInt idxAttr = idxOp.getValue();
1941 if (idxAttr.getBitWidth() > 64)
1942 return std::nullopt;
1943 return idxAttr.getLimitedValue();
1944}
1945
1946LogicalResult ArraySliceOp::verify() {
1947 unsigned inputSize =
1948 type_cast<ArrayType>(getInput().getType()).getNumElements();
1949 if (llvm::Log2_64_Ceil(inputSize) !=
1950 getLowIndex().getType().getIntOrFloatBitWidth())
1951 return emitOpError(
1952 "ArraySlice: index width must match clog2 of array size");
1953 return success();
1954}
1955
1956OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1957 // If we are slicing the entire input, then return it.
1958 if (getType() == getInput().getType())
1959 return getInput();
1960 return {};
1961}
1962
1963LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1964 PatternRewriter &rewriter) {
1965 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1966 auto elemTy = sliceTy.getElementType();
1967 uint64_t sliceSize = sliceTy.getNumElements();
1968 if (sliceSize == 0)
1969 return failure();
1970
1971 if (sliceSize == 1) {
1972 // slice(a, n) -> create(a[n])
1973 auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1974 op.getLowIndex());
1975 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1976 get.getResult());
1977 return success();
1978 }
1979
1980 auto offsetOpt = getUIntFromValue(op.getLowIndex());
1981 if (!offsetOpt)
1982 return failure();
1983
1984 auto inputOp = op.getInput().getDefiningOp();
1985 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1986 // slice(slice(a, n), m) -> slice(a, n + m)
1987 if (inputSlice == op)
1988 return failure();
1989
1990 auto inputIndex = inputSlice.getLowIndex();
1991 auto inputOffsetOpt = getUIntFromValue(inputIndex);
1992 if (!inputOffsetOpt)
1993 return failure();
1994
1995 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1996 auto lowIndex =
1997 rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1998 rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1999 inputSlice.getInput(), lowIndex);
2000 return success();
2001 }
2002
2003 if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
2004 // slice(create(a0, a1, ..., an), m) -> create(am, ...)
2005 auto inputs = inputCreate.getInputs();
2006
2007 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
2008 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
2009 inputs.slice(begin, sliceSize));
2010 return success();
2011 }
2012
2013 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2014 // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2015 SmallVector<Value> chunks;
2016 uint64_t sliceStart = *offsetOpt;
2017 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2018 // Check whether the input intersects with the slice.
2019 uint64_t inputSize =
2020 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2021 if (inputSize == 0 || inputSize <= sliceStart) {
2022 sliceStart -= inputSize;
2023 continue;
2024 }
2025
2026 // Find the indices to slice from this input by intersection.
2027 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2028 uint64_t cutSize = cutEnd - sliceStart;
2029 assert(cutSize != 0 && "slice cannot be empty");
2030
2031 if (cutSize == inputSize) {
2032 // The whole input fits in the slice, add it.
2033 assert(sliceStart == 0 && "invalid cut size");
2034 chunks.push_back(input);
2035 } else {
2036 // Slice the required bits from the input.
2037 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2038 auto lowIndex = rewriter.create<ConstantOp>(
2039 op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2040 chunks.push_back(rewriter.create<ArraySliceOp>(
2041 op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
2042 }
2043
2044 sliceStart = 0;
2045 sliceSize -= cutSize;
2046 if (sliceSize == 0)
2047 break;
2048 }
2049
2050 assert(chunks.size() > 0 && "missing sliced items");
2051 if (chunks.size() == 1)
2052 rewriter.replaceOp(op, chunks[0]);
2053 else
2054 rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2055 op, llvm::to_vector(llvm::reverse(chunks)));
2056 return success();
2057 }
2058 return failure();
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// ArrayConcatOp
2063//===----------------------------------------------------------------------===//
2064
2065static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2066 SmallVectorImpl<Type> &inputTypes,
2067 Type &resultType) {
2068 Type elemType;
2069 uint64_t resultSize = 0;
2070
2071 auto parseElement = [&]() -> ParseResult {
2072 Type ty;
2073 if (p.parseType(ty))
2074 return failure();
2075 auto arrTy = type_dyn_cast<ArrayType>(ty);
2076 if (!arrTy)
2077 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2078 if (elemType && elemType != arrTy.getElementType())
2079 return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2080 << elemType;
2081
2082 elemType = arrTy.getElementType();
2083 inputTypes.push_back(ty);
2084 resultSize += arrTy.getNumElements();
2085 return success();
2086 };
2087
2088 if (p.parseCommaSeparatedList(parseElement))
2089 return failure();
2090
2091 resultType = ArrayType::get(elemType, resultSize);
2092 return success();
2093}
2094
2095static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2096 TypeRange inputTypes, Type resultType) {
2097 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2098}
2099
2100void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2101 ValueRange values) {
2102 assert(!values.empty() && "Cannot build array of zero elements");
2103 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2104 Type elemTy = arrayTy.getElementType();
2105 assert(llvm::all_of(values,
2106 [elemTy](Value v) -> bool {
2107 return isa<ArrayType>(v.getType()) &&
2108 cast<ArrayType>(v.getType()).getElementType() ==
2109 elemTy;
2110 }) &&
2111 "All values must be of ArrayType with the same element type.");
2112
2113 uint64_t resultSize = 0;
2114 for (Value val : values)
2115 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2116 build(b, state, ArrayType::get(elemTy, resultSize), values);
2117}
2118
2119OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2120 if (getInputs().size() == 1)
2121 return getInputs()[0];
2122
2123 auto inputs = adaptor.getInputs();
2124 SmallVector<Attribute> array;
2125 for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2126 if (!inputs[i])
2127 return {};
2128 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2129 }
2130 return ArrayAttr::get(getContext(), array);
2131}
2132
2133// Flatten a concatenation of array creates into a single create.
2134static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2135 for (auto input : op.getInputs())
2136 if (!input.getDefiningOp<ArrayCreateOp>())
2137 return false;
2138
2139 SmallVector<Value> items;
2140 for (auto input : op.getInputs()) {
2141 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2142 for (auto item : create.getInputs())
2143 items.push_back(item);
2144 }
2145
2146 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2147 return true;
2148}
2149
// Merge consecutive slice expressions in a concatenation.
//
// Operands are visited from last to first. Judging by the isOffset() check
// below, this is the order of increasing element index. Operands that are
// either an hw.array_slice or a single-element hw.array_create of an
// hw.array_get are grouped into runs that read consecutive elements of the
// same array. Each run made of more than one operand is replaced by one new
// slice. Returns true iff the concat was rewritten.
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
  // A pending run of consecutive elements.
  struct Slice {
    Value input;                // Array the elements are read from.
    Value index;                // Low index of the run within `input`.
    size_t size;                // Number of elements in the run.
    Value op;                   // The sole operand forming the run; null once
                                // two or more operands have been merged.
    SmallVector<Location> locs; // Locations of every operand in the run.
  };

  SmallVector<Value> items;
  std::optional<Slice> last;
  bool changed = false;

  // Flush the pending run (if any) into `items`.
  auto concatenate = [&] {
    if (!last)
      return;
    // If the run consists of a single original operand, reuse it unchanged.
    if (last->op) {
      items.push_back(last->op);
      last.reset();
      return;
    }

    // Otherwise, create a new slice of the merged size and use it instead.
    // This means the concat op must be rebuilt with the new operand.
    changed = true;
    auto loc = FusedLoc::get(op.getContext(), last->locs);
    auto origTy = hw::type_cast<ArrayType>(last->input.getType());
    auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
    items.push_back(rewriter.createOrFold<ArraySliceOp>(
        loc, arrayTy, last->input, last->index));

    last.reset();
  };

  // Add a run of `size` elements of `input` starting at `index`, produced by
  // the concat operand `op` (this parameter shadows the outer concat op).
  auto append = [&](Value op, Value input, Value index, size_t size) {
    // If this slice is an extension of the previous one, extend the size
    // saved. In this case, a new slice is created and the concatenation
    // operator is rewritten. Otherwise, flush the last slice.
    if (last) {
      if (last->input == input && isOffset(last->index, index, last->size)) {
        last->size += size;
        last->op = {};
        last->locs.push_back(op.getLoc());
        return;
      }
      concatenate();
    }
    last.emplace(Slice{input, index, size, op, {op.getLoc()}});
  };

  for (auto item : llvm::reverse(op.getInputs())) {
    if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
      auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
      append(item, slice.getInput(), slice.getLowIndex(), size);
      continue;
    }

    // create(get(a, i)) is treated as a one-element slice of `a` at `i`.
    if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
      if (create.getInputs().size() == 1) {
        if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
          append(item, get.getInput(), get.getIndex(), 1);
          continue;
        }
      }
    }

    // Any other operand terminates the current run and is kept as-is.
    concatenate();
    items.push_back(item);
  }
  concatenate();

  if (!changed)
    return false;

  // `items` was collected in reverse operand order; restore it.
  if (items.size() == 1) {
    rewriter.replaceOp(op, items[0]);
  } else {
    std::reverse(items.begin(), items.end());
    rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
  }
  return true;
}
2233}
2234
2235LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2236 PatternRewriter &rewriter) {
2237 // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2238 if (flattenConcatOp(op, rewriter))
2239 return success();
2240
2241 // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2242 if (mergeConcatSlices(op, rewriter))
2243 return success();
2244
2245 return failure();
2246}
2247
2248//===----------------------------------------------------------------------===//
2249// EnumConstantOp
2250//===----------------------------------------------------------------------===//
2251
2252ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2253 // Parse a Type instead of an EnumType since the type might be a type alias.
2254 // The validity of the canonical type is checked during construction of the
2255 // EnumFieldAttr.
2256 Type type;
2257 StringRef field;
2258
2259 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2260 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2261 return failure();
2262
2263 auto fieldAttr = EnumFieldAttr::get(
2264 loc, StringAttr::get(parser.getContext(), field), type);
2265
2266 if (!fieldAttr)
2267 return failure();
2268
2269 result.addAttribute("field", fieldAttr);
2270 result.addTypes(type);
2271
2272 return success();
2273}
2274
2275void EnumConstantOp::print(OpAsmPrinter &p) {
2276 p << " " << getField().getField().getValue() << " : "
2277 << getField().getType().getValue();
2278}
2279
2280void EnumConstantOp::getAsmResultNames(
2281 function_ref<void(Value, StringRef)> setNameFn) {
2282 setNameFn(getResult(), getField().getField().str());
2283}
2284
2285void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2286 EnumFieldAttr field) {
2287 return build(builder, odsState, field.getType().getValue(), field);
2288}
2289
2290OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2291 assert(adaptor.getOperands().empty() && "constant has no operands");
2292 return getFieldAttr();
2293}
2294
2295LogicalResult EnumConstantOp::verify() {
2296 auto fieldAttr = getFieldAttr();
2297 auto fieldType = fieldAttr.getType().getValue();
2298 // This check ensures that we are using the exact same type, without looking
2299 // through type aliases.
2300 if (fieldType != getType())
2301 emitOpError("return type ")
2302 << getType() << " does not match attribute type " << fieldAttr;
2303 return success();
2304}
2305
2306//===----------------------------------------------------------------------===//
2307// EnumCmpOp
2308//===----------------------------------------------------------------------===//
2309
2310LogicalResult EnumCmpOp::verify() {
2311 // Compare the canonical types.
2312 auto lhsType = type_cast<EnumType>(getLhs().getType());
2313 auto rhsType = type_cast<EnumType>(getRhs().getType());
2314 if (rhsType != lhsType)
2315 emitOpError("types do not match");
2316 return success();
2317}
2318
2319//===----------------------------------------------------------------------===//
2320// StructCreateOp
2321//===----------------------------------------------------------------------===//
2322
2323ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2324 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2325 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2326 Type declOrAliasType;
2327
2328 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2329 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2330 parser.parseColonType(declOrAliasType))
2331 return failure();
2332
2333 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2334 if (!declType)
2335 return parser.emitError(parser.getNameLoc(),
2336 "expected !hw.struct type or alias");
2337
2338 llvm::SmallVector<Type, 4> structInnerTypes;
2339 declType.getInnerTypes(structInnerTypes);
2340 result.addTypes(declOrAliasType);
2341
2342 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2343 result.operands))
2344 return failure();
2345 return success();
2346}
2347
2348void StructCreateOp::print(OpAsmPrinter &printer) {
2349 printer << " (";
2350 printer.printOperands(getInput());
2351 printer << ")";
2352 printer.printOptionalAttrDict((*this)->getAttrs());
2353 printer << " : " << getType();
2354}
2355
2356LogicalResult StructCreateOp::verify() {
2357 auto elements = hw::type_cast<StructType>(getType()).getElements();
2358
2359 if (elements.size() != getInput().size())
2360 return emitOpError("structure field count mismatch");
2361
2362 for (const auto &[field, value] : llvm::zip(elements, getInput()))
2363 if (field.type != value.getType())
2364 return emitOpError("structure field `")
2365 << field.name << "` type does not match";
2366
2367 return success();
2368}
2369
2370OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2371 // struct_create(struct_explode(x)) => x
2372 if (!getInput().empty())
2373 if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2374 explodeOp && getInput() == explodeOp.getResults() &&
2375 getResult().getType() == explodeOp.getInput().getType())
2376 return explodeOp.getInput();
2377
2378 auto inputs = adaptor.getInput();
2379 if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2380 return {};
2381 return ArrayAttr::get(getContext(), inputs);
2382}
2383
2384//===----------------------------------------------------------------------===//
2385// StructExplodeOp
2386//===----------------------------------------------------------------------===//
2387
2388ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2389 OperationState &result) {
2390 OpAsmParser::UnresolvedOperand operand;
2391 Type declType;
2392
2393 if (parser.parseOperand(operand) ||
2394 parser.parseOptionalAttrDict(result.attributes) ||
2395 parser.parseColonType(declType))
2396 return failure();
2397 auto structType = type_dyn_cast<StructType>(declType);
2398 if (!structType)
2399 return parser.emitError(parser.getNameLoc(),
2400 "invalid kind of type specified");
2401
2402 llvm::SmallVector<Type, 4> structInnerTypes;
2403 structType.getInnerTypes(structInnerTypes);
2404 result.addTypes(structInnerTypes);
2405
2406 if (parser.resolveOperand(operand, declType, result.operands))
2407 return failure();
2408 return success();
2409}
2410
2411void StructExplodeOp::print(OpAsmPrinter &printer) {
2412 printer << " ";
2413 printer.printOperand(getInput());
2414 printer.printOptionalAttrDict((*this)->getAttrs());
2415 printer << " : " << getInput().getType();
2416}
2417
2418LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2419 SmallVectorImpl<OpFoldResult> &results) {
2420 auto input = adaptor.getInput();
2421 if (!input)
2422 return failure();
2423 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2424 return success();
2425}
2426
2427LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2428 PatternRewriter &rewriter) {
2429 auto *inputOp = op.getInput().getDefiningOp();
2430 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2431 auto result = failure();
2432 auto opResults = op.getResults();
2433 for (uint32_t index = 0; index < elements.size(); index++) {
2434 if (auto foldResult = foldStructExtract(inputOp, index)) {
2435 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2436 result = success();
2437 }
2438 }
2439 return result;
2440}
2441
2442void StructExplodeOp::getAsmResultNames(
2443 function_ref<void(Value, StringRef)> setNameFn) {
2444 auto structType = type_cast<StructType>(getInput().getType());
2445 for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2446 setNameFn(res, field.name.str());
2447}
2448
2449void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2450 Value input) {
2451 StructType inputType = dyn_cast<StructType>(input.getType());
2452 assert(inputType);
2453 SmallVector<Type, 16> fieldTypes;
2454 for (auto field : inputType.getElements())
2455 fieldTypes.push_back(field.type);
2456 build(odsBuilder, odsState, fieldTypes, input);
2457}
2458
2459//===----------------------------------------------------------------------===//
2460// StructExtractOp
2461//===----------------------------------------------------------------------===//
2462
2463/// Ensure an aggregate op's field index is within the bounds of
2464/// the aggregate type and the accessed field is of 'elementType'.
2465template <typename AggregateOp, typename AggregateType>
2466static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2467 AggregateType aggType,
2468 Type elementType) {
2469 auto index = op.getFieldIndex();
2470 if (index >= aggType.getElements().size())
2471 return op.emitOpError() << "field index " << index
2472 << " exceeds element count of aggregate type";
2473
2475 getCanonicalType(aggType.getElements()[index].type))
2476 return op.emitOpError()
2477 << "type " << aggType.getElements()[index].type
2478 << " of accessed field in aggregate at index " << index
2479 << " does not match expected type " << elementType;
2480
2481 return success();
2482}
2483
2484LogicalResult StructExtractOp::verify() {
2485 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2486 *this, getInput().getType(), getType());
2487}
2488
2489/// Use the same parser for both struct_extract and union_extract since the
2490/// syntax is identical.
2491template <typename AggregateType>
2492static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2493 OpAsmParser::UnresolvedOperand operand;
2494 StringAttr fieldName;
2495 Type declType;
2496
2497 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2498 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2499 parser.parseOptionalAttrDict(result.attributes) ||
2500 parser.parseColonType(declType))
2501 return failure();
2502 auto aggType = type_dyn_cast<AggregateType>(declType);
2503 if (!aggType)
2504 return parser.emitError(parser.getNameLoc(),
2505 "invalid kind of type specified");
2506
2507 auto fieldIndex = aggType.getFieldIndex(fieldName);
2508 if (!fieldIndex) {
2509 parser.emitError(parser.getNameLoc(), "field name '" +
2510 fieldName.getValue() +
2511 "' not found in aggregate type");
2512 return failure();
2513 }
2514
2515 auto indexAttr =
2516 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2517 result.addAttribute("fieldIndex", indexAttr);
2518 Type resultType = aggType.getElements()[*fieldIndex].type;
2519 result.addTypes(resultType);
2520
2521 if (parser.resolveOperand(operand, declType, result.operands))
2522 return failure();
2523 return success();
2524}
2525
2526/// Use the same printer for both struct_extract and union_extract since the
2527/// syntax is identical.
2528template <typename AggType>
2529static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2530 printer << " ";
2531 printer.printOperand(op.getInput());
2532 printer << "[\"" << op.getFieldName() << "\"]";
2533 printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2534 printer << " : " << op.getInput().getType();
2535}
2536
2537ParseResult StructExtractOp::parse(OpAsmParser &parser,
2538 OperationState &result) {
2539 return parseExtractOp<StructType>(parser, result);
2540}
2541
2542void StructExtractOp::print(OpAsmPrinter &printer) {
2543 printExtractOp(printer, *this);
2544}
2545
2546void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2547 Value input, StructType::FieldInfo field) {
2548 auto fieldIndex =
2549 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2550 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2551 build(builder, odsState, field.type, input, *fieldIndex);
2552}
2553
2554void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2555 Value input, StringAttr fieldName) {
2556 auto structType = type_cast<StructType>(input.getType());
2557 auto fieldIndex = structType.getFieldIndex(fieldName);
2558 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2559 auto resultType = structType.getElements()[*fieldIndex].type;
2560 build(builder, odsState, resultType, input, *fieldIndex);
2561}
2562
2563OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2564 if (auto constOperand = adaptor.getInput()) {
2565 // Fold extract from aggregate constant
2566 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2567 return operandAttr.getValue()[getFieldIndex()];
2568 }
2569
2570 if (auto foldResult =
2571 foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2572 return foldResult;
2573 return {};
2574}
2575
2576LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2577 PatternRewriter &rewriter) {
2578 auto inputOp = op.getInput().getDefiningOp();
2579
2580 // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2581 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2582 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2583 rewriter.replaceOpWithNewOp<StructExtractOp>(
2584 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2585 return success();
2586 }
2587 }
2588
2589 return failure();
2590}
2591
2592void StructExtractOp::getAsmResultNames(
2593 function_ref<void(Value, StringRef)> setNameFn) {
2594 setNameFn(getResult(), getFieldName());
2595}
2596
2597//===----------------------------------------------------------------------===//
2598// StructInjectOp
2599//===----------------------------------------------------------------------===//
2600
2601void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2602 Value input, StringAttr fieldName, Value newValue) {
2603 auto structType = type_cast<StructType>(input.getType());
2604 auto fieldIndex = structType.getFieldIndex(fieldName);
2605 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2606 build(builder, odsState, input, *fieldIndex, newValue);
2607}
2608
2609LogicalResult StructInjectOp::verify() {
2610 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2611 *this, getInput().getType(), getNewValue().getType());
2612}
2613
2614ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2615 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2616 OpAsmParser::UnresolvedOperand operand, val;
2617 StringAttr fieldName;
2618 Type declType;
2619
2620 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2621 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2622 parser.parseComma() || parser.parseOperand(val) ||
2623 parser.parseOptionalAttrDict(result.attributes) ||
2624 parser.parseColonType(declType))
2625 return failure();
2626 auto structType = type_dyn_cast<StructType>(declType);
2627 if (!structType)
2628 return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2629
2630 auto fieldIndex = structType.getFieldIndex(fieldName);
2631 if (!fieldIndex) {
2632 parser.emitError(parser.getNameLoc(), "field name '" +
2633 fieldName.getValue() +
2634 "' not found in aggregate type");
2635 return failure();
2636 }
2637
2638 auto indexAttr =
2639 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2640 result.addAttribute("fieldIndex", indexAttr);
2641 result.addTypes(declType);
2642
2643 Type resultType = structType.getElements()[*fieldIndex].type;
2644 if (parser.resolveOperands({operand, val}, {declType, resultType},
2645 inputOperandsLoc, result.operands))
2646 return failure();
2647 return success();
2648}
2649
2650void StructInjectOp::print(OpAsmPrinter &printer) {
2651 printer << " ";
2652 printer.printOperand(getInput());
2653 printer << "[\"" << getFieldName() << "\"], ";
2654 printer.printOperand(getNewValue());
2655 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2656 printer << " : " << getInput().getType();
2657}
2658
2659OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2660 auto input = adaptor.getInput();
2661 auto newValue = adaptor.getNewValue();
2662 if (!input || !newValue)
2663 return {};
2664 SmallVector<Attribute> array;
2665 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2666 array[getFieldIndex()] = newValue;
2667 return ArrayAttr::get(getContext(), array);
2668}
2669
2670LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2671 PatternRewriter &rewriter) {
2672 // Canonicalize multiple injects into a create op and eliminate overwrites.
2673 SmallPtrSet<Operation *, 4> injects;
2674 DenseMap<StringAttr, Value> fields;
2675
2676 // Chase a chain of injects. Bail out if cycles are present.
2677 StructInjectOp inject = op;
2678 Value input;
2679 do {
2680 if (!injects.insert(inject).second)
2681 return failure();
2682
2683 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2684 input = inject.getInput();
2685 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2686 } while (inject);
2687 assert(input && "missing input to inject chain");
2688
2689 auto ty = hw::type_cast<StructType>(op.getType());
2690 auto elements = ty.getElements();
2691
2692 // If the inject chain sets all fields, canonicalize to create.
2693 if (fields.size() == elements.size()) {
2694 SmallVector<Value> createFields;
2695 for (const auto &field : elements) {
2696 auto it = fields.find(field.name);
2697 assert(it != fields.end() && "missing field");
2698 createFields.push_back(it->second);
2699 }
2700 rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2701 return success();
2702 }
2703
2704 // Nothing to canonicalize, only the original inject in the chain.
2705 if (injects.size() == fields.size())
2706 return failure();
2707
2708 // Eliminate overwrites. The hash map contains the last write to each field.
2709 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2710 auto it = fields.find(elements[fieldIndex].name);
2711 if (it == fields.end())
2712 continue;
2713 input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2714 it->second);
2715 }
2716
2717 rewriter.replaceOp(op, input);
2718 return success();
2719}
2720
2721//===----------------------------------------------------------------------===//
2722// UnionCreateOp
2723//===----------------------------------------------------------------------===//
2724
2725LogicalResult UnionCreateOp::verify() {
2726 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2727 *this, getType(), getInput().getType());
2728}
2729
2730void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2731 Type unionType, StringAttr fieldName, Value input) {
2732 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2733 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2734 build(builder, odsState, unionType, *fieldIndex, input);
2735}
2736
2737ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2738 Type declOrAliasType;
2739 StringAttr fieldName;
2740 OpAsmParser::UnresolvedOperand input;
2741 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2742
2743 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2744 parser.parseOperand(input) ||
2745 parser.parseOptionalAttrDict(result.attributes) ||
2746 parser.parseColonType(declOrAliasType))
2747 return failure();
2748
2749 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2750 if (!declType)
2751 return parser.emitError(parser.getNameLoc(),
2752 "expected !hw.union type or alias");
2753
2754 auto fieldIndex = declType.getFieldIndex(fieldName);
2755 if (!fieldIndex) {
2756 parser.emitError(fieldLoc, "cannot find union field '")
2757 << fieldName.getValue() << '\'';
2758 return failure();
2759 }
2760
2761 auto indexAttr =
2762 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2763 result.addAttribute("fieldIndex", indexAttr);
2764 Type inputType = declType.getElements()[*fieldIndex].type;
2765
2766 if (parser.resolveOperand(input, inputType, result.operands))
2767 return failure();
2768 result.addTypes({declOrAliasType});
2769 return success();
2770}
2771
2772void UnionCreateOp::print(OpAsmPrinter &printer) {
2773 printer << " \"" << getFieldName() << "\", ";
2774 printer.printOperand(getInput());
2775 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2776 printer << " : " << getType();
2777}
2778
2779//===----------------------------------------------------------------------===//
2780// UnionExtractOp
2781//===----------------------------------------------------------------------===//
2782
2783ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2784 return parseExtractOp<UnionType>(parser, result);
2785}
2786
2787void UnionExtractOp::print(OpAsmPrinter &printer) {
2788 printExtractOp(printer, *this);
2789}
2790
2791LogicalResult UnionExtractOp::inferReturnTypes(
2792 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2793 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2794 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2795 Adaptor adaptor(operands, attrs, properties, regions);
2796 auto unionElements =
2797 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2798 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2799 if (fieldIndex >= unionElements.size()) {
2800 if (loc)
2801 mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2802 " exceeds element count of aggregate type");
2803 return failure();
2804 }
2805 results.push_back(unionElements[fieldIndex].type);
2806 return success();
2807}
2808
2809void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2810 Value input, StringAttr fieldName) {
2811 auto unionType = type_cast<UnionType>(input.getType());
2812 auto fieldIndex = unionType.getFieldIndex(fieldName);
2813 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2814 auto resultType = unionType.getElements()[*fieldIndex].type;
2815 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2816}
2817
2818//===----------------------------------------------------------------------===//
2819// ArrayGetOp
2820//===----------------------------------------------------------------------===//
2821
2822// An array_get of an array_create with a constant index can just be the
2823// array_create operand at the constant index. If the array_create has a
2824// single uniform value for each element, just return that value regardless of
2825// the index. If the array is constructed from a constant by a bitcast
2826// operation, we can fold into a constant.
2827OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2828 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2829 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2830
2831 if (inputCst) {
2832 // Constant array index.
2833 if (indexCst) {
2834 auto indexVal = indexCst.getValue();
2835 if (indexVal.getBitWidth() < 64) {
2836 auto index = indexVal.getZExtValue();
2837 return inputCst[inputCst.size() - 1 - index];
2838 }
2839 }
2840 // If all elements of the array are the same, we can return any element of
2841 // array.
2842 if (!inputCst.empty() && llvm::all_equal(inputCst))
2843 return inputCst[0];
2844 }
2845
2846 // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2847 if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2848 auto intTy = dyn_cast<IntegerType>(getType());
2849 if (!intTy)
2850 return {};
2851 auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2852 if (!bitcastInputOp)
2853 return {};
2854 if (!indexCst)
2855 return {};
2856 auto bitcastInputCst = bitcastInputOp.getValue();
2857 // Calculate the index. Make sure to zero-extend the index value before
2858 // multiplying the element width.
2859 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2860 getType().getIntOrFloatBitWidth();
2861 // Extract [startIdx + width - 1: startIdx].
2862 return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2863 intTy.getIntOrFloatBitWidth()));
2864 }
2865
2866 // array_get(array_inject(_, index, element), index) -> element
2867 if (auto inject = getInput().getDefiningOp<ArrayInjectOp>())
2868 if (getIndex() == inject.getIndex())
2869 return inject.getElement();
2870
2871 auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2872 if (!inputCreate)
2873 return {};
2874
2875 if (auto uniformValue = inputCreate.getUniformElement())
2876 return uniformValue;
2877
2878 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2879 return {};
2880
2881 uint64_t index = indexCst.getValue().getLimitedValue();
2882 auto createInputs = inputCreate.getInputs();
2883 if (index >= createInputs.size())
2884 return {};
2885 return createInputs[createInputs.size() - index - 1];
2886}
2887
2888LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2889 PatternRewriter &rewriter) {
2890 auto idxOpt = getUIntFromValue(op.getIndex());
2891 if (!idxOpt)
2892 return failure();
2893
2894 auto *inputOp = op.getInput().getDefiningOp();
2895 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2896 // get(slice(a, n), m) -> get(a, n + m)
2897 auto offsetOp = inputSlice.getLowIndex();
2898 auto offsetOpt = getUIntFromValue(offsetOp);
2899 if (!offsetOpt)
2900 return failure();
2901
2902 uint64_t offset = *offsetOpt + *idxOpt;
2903 auto newOffset =
2904 rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2905 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2906 newOffset);
2907 return success();
2908 }
2909
2910 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2911 // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2912 uint64_t elemIndex = *idxOpt;
2913 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2914 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2915 if (elemIndex >= size) {
2916 elemIndex -= size;
2917 continue;
2918 }
2919
2920 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2921 auto newIdxOp = rewriter.create<ConstantOp>(
2922 op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2923
2924 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2925 return success();
2926 }
2927 return failure();
2928 }
2929
2930 // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2931 // array_get sel, (array_create (array_get const a), (array_get const b),
2932 // (array_get const, c), (array_get const, d))
2933 if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2934 if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2935 if (auto create =
2936 innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2937
2938 SmallVector<Value> newValues;
2939 for (auto operand : create.getOperands())
2940 newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2941 op.getLoc(), operand, op.getIndex()));
2942
2943 rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2944 op,
2945 rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2946 innerGet.getIndex());
2947 return success();
2948 }
2949 }
2950 }
2951
2952 return failure();
2953}
2954
2955//===----------------------------------------------------------------------===//
2956// ArrayInjectOp
2957//===----------------------------------------------------------------------===//
2958
2959OpFoldResult ArrayInjectOp::fold(FoldAdaptor adaptor) {
2960 auto inputAttr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2961 auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2962 auto elementAttr = adaptor.getElement();
2963
2964 // inject(constant[xs, y, zs], iy, a) -> constant[x, a, z]
2965 if (inputAttr && indexAttr && elementAttr) {
2966 if (auto index = indexAttr.getValue().tryZExtValue()) {
2967 if (*index < inputAttr.size()) {
2968 SmallVector<Attribute> elements(inputAttr.getValue());
2969 elements[inputAttr.size() - 1 - *index] = elementAttr;
2970 return ArrayAttr::get(getContext(), elements);
2971 }
2972 }
2973 }
2974
2975 return {};
2976}
2977
2978void ArrayInjectOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2979 MLIRContext *context) {
2980 patterns.add<ArrayInjectToSameIndex>(context);
2981}
2982
2983//===----------------------------------------------------------------------===//
2984// TypedeclOp
2985//===----------------------------------------------------------------------===//
2986
2987StringRef TypedeclOp::getPreferredName() {
2988 return getVerilogName().value_or(getName());
2989}
2990
2991Type TypedeclOp::getAliasType() {
2992 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2993 return hw::TypeAliasType::get(
2994 SymbolRefAttr::get(parentScope.getSymNameAttr(),
2995 {FlatSymbolRefAttr::get(*this)}),
2996 getType());
2997}
2998
2999//===----------------------------------------------------------------------===//
3000// BitcastOp
3001//===----------------------------------------------------------------------===//
3002
3003OpFoldResult BitcastOp::fold(FoldAdaptor) {
3004 // Identity.
3005 // bitcast(%a) : A -> A ==> %a
3006 if (getOperand().getType() == getType())
3007 return getOperand();
3008
3009 return {};
3010}
3011
3012LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
3013 // Composition.
3014 // %b = bitcast(%a) : A -> B
3015 // bitcast(%b) : B -> C
3016 // ===> bitcast(%a) : A -> C
3017 auto inputBitcast =
3018 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
3019 if (!inputBitcast)
3020 return failure();
3021 auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
3022 inputBitcast.getInput());
3023 rewriter.replaceOp(op, bitcast);
3024 return success();
3025}
3026
3027LogicalResult BitcastOp::verify() {
3028 if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
3029 return this->emitOpError("Bitwidth of input must match result");
3030 return success();
3031}
3032
3033//===----------------------------------------------------------------------===//
3034// HierPathOp helpers.
3035//===----------------------------------------------------------------------===//
3036
3037bool HierPathOp::dropModule(StringAttr moduleToDrop) {
3038 SmallVector<Attribute, 4> newPath;
3039 bool updateMade = false;
3040 for (auto nameRef : getNamepath()) {
3041 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3042 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3043 if (ref.getModule() == moduleToDrop)
3044 updateMade = true;
3045 else
3046 newPath.push_back(ref);
3047 } else {
3048 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3049 updateMade = true;
3050 else
3051 newPath.push_back(nameRef);
3052 }
3053 }
3054 if (updateMade)
3055 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3056 return updateMade;
3057}
3058
3059bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3060 SmallVector<Attribute, 4> newPath;
3061 bool updateMade = false;
3062 StringRef inlinedInstanceName = "";
3063 for (auto nameRef : getNamepath()) {
3064 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3065 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3066 if (ref.getModule() == moduleToDrop) {
3067 inlinedInstanceName = ref.getName().getValue();
3068 updateMade = true;
3069 } else if (!inlinedInstanceName.empty()) {
3070 newPath.push_back(hw::InnerRefAttr::get(
3071 ref.getModule(),
3072 StringAttr::get(getContext(), inlinedInstanceName + "_" +
3073 ref.getName().getValue())));
3074 inlinedInstanceName = "";
3075 } else
3076 newPath.push_back(ref);
3077 } else {
3078 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3079 updateMade = true;
3080 else
3081 newPath.push_back(nameRef);
3082 }
3083 }
3084 if (updateMade)
3085 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3086 return updateMade;
3087}
3088
3089bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3090 SmallVector<Attribute, 4> newPath;
3091 bool updateMade = false;
3092 for (auto nameRef : getNamepath()) {
3093 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3094 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3095 if (ref.getModule() == oldMod) {
3096 newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3097 updateMade = true;
3098 } else
3099 newPath.push_back(ref);
3100 } else {
3101 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3102 newPath.push_back(FlatSymbolRefAttr::get(newMod));
3103 updateMade = true;
3104 } else
3105 newPath.push_back(nameRef);
3106 }
3107 }
3108 if (updateMade)
3109 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3110 return updateMade;
3111}
3112
3113bool HierPathOp::updateModuleAndInnerRef(
3114 StringAttr oldMod, StringAttr newMod,
3115 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3116 auto fromRef = FlatSymbolRefAttr::get(oldMod);
3117 if (oldMod == newMod)
3118 return false;
3119
3120 auto namepathNew = getNamepath().getValue().vec();
3121 bool updateMade = false;
3122 // Break from the loop if the module is found, since it can occur only once.
3123 for (auto &element : namepathNew) {
3124 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3125 if (innerRef.getModule() != oldMod)
3126 continue;
3127 auto symName = innerRef.getName();
3128 // Since the module got updated, the old innerRef symbol inside oldMod
3129 // should also be updated to the new symbol inside the newMod.
3130 auto to = innerSymRenameMap.find(symName);
3131 if (to != innerSymRenameMap.end())
3132 symName = to->second;
3133 updateMade = true;
3134 element = hw::InnerRefAttr::get(newMod, symName);
3135 break;
3136 }
3137 if (element != fromRef)
3138 continue;
3139
3140 updateMade = true;
3141 element = FlatSymbolRefAttr::get(newMod);
3142 break;
3143 }
3144 if (updateMade)
3145 setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3146 return updateMade;
3147}
3148
3149bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3150 SmallVector<Attribute, 4> newPath;
3151 bool updateMade = false;
3152 for (auto nameRef : getNamepath()) {
3153 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3154 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3155 if (ref.getModule() == atMod) {
3156 updateMade = true;
3157 if (includeMod)
3158 newPath.push_back(ref);
3159 } else
3160 newPath.push_back(ref);
3161 } else {
3162 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3163 updateMade = true;
3164 else
3165 newPath.push_back(nameRef);
3166 }
3167 if (updateMade)
3168 break;
3169 }
3170 if (updateMade)
3171 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3172 return updateMade;
3173}
3174
3175/// Return just the module part of the namepath at a specific index.
3176StringAttr HierPathOp::modPart(unsigned i) {
3177 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3178 .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3179 .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3180}
3181
3182/// Return the root module.
3183StringAttr HierPathOp::root() {
3184 assert(!getNamepath().empty());
3185 return modPart(0);
3186}
3187
3188/// Return true if the NLA has the module in its path.
3189bool HierPathOp::hasModule(StringAttr modName) {
3190 for (auto nameRef : getNamepath()) {
3191 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3192 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3193 if (ref.getModule() == modName)
3194 return true;
3195 } else {
3196 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3197 return true;
3198 }
3199 }
3200 return false;
3201}
3202
3203/// Return true if the NLA has the InnerSym .
3204bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3205 for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3206 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3207 if (ref.getName() == symName && ref.getModule() == modName)
3208 return true;
3209
3210 return false;
3211}
3212
3213/// Return just the reference part of the namepath at a specific index. This
3214/// will return an empty attribute if this is the leaf and the leaf is a module.
3215StringAttr HierPathOp::refPart(unsigned i) {
3216 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3217 .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3218 .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3219}
3220
3221/// Return the leaf reference. This returns an empty attribute if the leaf
3222/// reference is a module.
3223StringAttr HierPathOp::ref() {
3224 assert(!getNamepath().empty());
3225 return refPart(getNamepath().size() - 1);
3226}
3227
3228/// Return the leaf module.
3229StringAttr HierPathOp::leafMod() {
3230 assert(!getNamepath().empty());
3231 return modPart(getNamepath().size() - 1);
3232}
3233
3234/// Returns true if this NLA targets an instance of a module (as opposed to
3235/// an instance's port or something inside an instance).
3236bool HierPathOp::isModule() { return !ref(); }
3237
3238/// Returns true if this NLA targets something inside a module (as opposed
3239/// to a module or an instance of a module);
3240bool HierPathOp::isComponent() { return (bool)ref(); }
3241
3242// Verify the HierPathOp.
3243// 1. Iterate over the namepath.
3244// 2. The namepath should be a valid instance path, specified either on a
3245// module or a declaration inside a module.
3246// 3. Each element in the namepath is an InnerRefAttr except possibly the
3247// last element.
3248// 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3249// and the corresponding inner_sym on the instance.
3250// 5. Make sure that the instance path is legal, by verifying the sequence of
3251// instance and the expected module occurs as the next element in the path.
3252// 6. The last element of the namepath, can be an InnerRefAttr on either a
3253// module port or a declaration inside the module.
3254// 7. The last element of the namepath can also be a module symbol.
3255LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3256 ArrayAttr expectedModuleNames = {};
3257 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3258 if (!expectedModuleNames)
3259 return success();
3260 if (llvm::any_of(expectedModuleNames,
3261 [name](Attribute attr) { return attr == name; }))
3262 return success();
3263 auto diag = emitOpError() << "instance path is incorrect. Expected ";
3264 size_t n = expectedModuleNames.size();
3265 if (n != 1) {
3266 diag << "one of ";
3267 }
3268 for (size_t i = 0; i < n; ++i) {
3269 if (i != 0)
3270 diag << ((i + 1 == n) ? " or " : ", ");
3271 diag << cast<StringAttr>(expectedModuleNames[i]);
3272 }
3273 diag << ". Instead found: " << name;
3274 return diag;
3275 };
3276
3277 if (!getNamepath() || getNamepath().empty())
3278 return emitOpError() << "the instance path cannot be empty";
3279 for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3280 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3281 if (!innerRef)
3282 return emitOpError()
3283 << "the instance path can only contain inner sym reference"
3284 << ", only the leaf can refer to a module symbol";
3285
3286 if (failed(checkExpectedModule(innerRef.getModule())))
3287 return failure();
3288
3289 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3290 if (!instOp)
3291 return emitOpError() << " module: " << innerRef.getModule()
3292 << " does not contain any instance with symbol: "
3293 << innerRef.getName();
3294 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3295 }
3296
3297 // The instance path has been verified. Now verify the last element.
3298 auto leafRef = getNamepath()[getNamepath().size() - 1];
3299 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3300 if (!ns.lookup(innerRef)) {
3301 return emitOpError() << " operation with symbol: " << innerRef
3302 << " was not found ";
3303 }
3304 if (failed(checkExpectedModule(innerRef.getModule())))
3305 return failure();
3306 } else if (failed(checkExpectedModule(
3307 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3308 return failure();
3309 }
3310 return success();
3311}
3312
3313void HierPathOp::print(OpAsmPrinter &p) {
3314 p << " ";
3315
3316 // Print visibility if present.
3317 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3318 if (auto visibility =
3319 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3320 p << visibility.getValue() << ' ';
3321
3322 p.printSymbolName(getSymName());
3323 p << " [";
3324 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3325 if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3326 p.printSymbolName(ref.getModule().getValue());
3327 p << "::";
3328 p.printSymbolName(ref.getName().getValue());
3329 } else {
3330 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3331 }
3332 });
3333 p << "]";
3334 p.printOptionalAttrDict(
3335 (*this)->getAttrs(),
3336 {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3337}
3338
3339ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3340 // Parse the visibility attribute.
3341 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3342
3343 // Parse the symbol name.
3344 StringAttr symName;
3345 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3346 result.attributes))
3347 return failure();
3348
3349 // Parse the namepath.
3350 SmallVector<Attribute> namepath;
3351 if (parser.parseCommaSeparatedList(
3352 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3353 auto loc = parser.getCurrentLocation();
3354 SymbolRefAttr ref;
3355 if (parser.parseAttribute(ref))
3356 return failure();
3357
3358 // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3359 auto pathLength = ref.getNestedReferences().size();
3360 if (pathLength == 0)
3361 namepath.push_back(
3362 FlatSymbolRefAttr::get(ref.getRootReference()));
3363 else if (pathLength == 1)
3364 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3365 ref.getLeafReference()));
3366 else
3367 return parser.emitError(loc,
3368 "only one nested reference is allowed");
3369 return success();
3370 }))
3371 return failure();
3372 result.addAttribute("namepath",
3373 ArrayAttr::get(parser.getContext(), namepath));
3374
3375 if (parser.parseOptionalAttrDict(result.attributes))
3376 return failure();
3377
3378 return success();
3379}
3380
3381//===----------------------------------------------------------------------===//
3382// TriggeredOp
3383//===----------------------------------------------------------------------===//
3384
3385void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3386 EventControlAttr event, Value trigger,
3387 ValueRange inputs) {
3388 odsState.addOperands(trigger);
3389 odsState.addOperands(inputs);
3390 odsState.addAttribute(getEventAttrName(odsState.name), event);
3391 auto *r = odsState.addRegion();
3392 Block *b = new Block();
3393 r->push_back(b);
3394
3395 llvm::SmallVector<Location> argLocs;
3396 llvm::transform(inputs, std::back_inserter(argLocs),
3397 [&](Value v) { return v.getLoc(); });
3398 b->addArguments(inputs.getTypes(), argLocs);
3399}
3400
3401//===----------------------------------------------------------------------===//
3402// TableGen generated logic.
3403//===----------------------------------------------------------------------===//
3404
3405// Provide the autogenerated implementation guts for the Op classes.
3406#define GET_OP_CLASSES
3407#include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition HWOps.cpp:1085
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition HWOps.cpp:500
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition HWOps.cpp:1030
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2134
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:1865
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1427
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition HWOps.cpp:83
FunctionType getHWModuleOpType(Operation *op)
Definition HWOps.cpp:1022
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2529
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition HWOps.cpp:2095
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition HWOps.cpp:1769
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo > > insertInputs, ArrayRef< std::pair< unsigned, PortInfo > > insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition HWOps.cpp:693
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition HWOps.cpp:68
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition HWOps.cpp:900
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo > > insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition HWOps.cpp:594
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2151
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition HWOps.cpp:1205
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2492
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition HWOps.cpp:1275
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition HWOps.cpp:100
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition HWOps.cpp:1348
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition HWOps.cpp:492
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition HWOps.cpp:406
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition HWOps.cpp:1936
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition HWOps.cpp:908
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition HWOps.cpp:2466
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1447
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition HWOps.cpp:1783
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition HWOps.cpp:352
Delimiter
Definition HWOps.cpp:115
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition HWOps.cpp:2065
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:217
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:240
Value getInput(unsigned i)
Definition HWOps.cpp:246
llvm::SmallVector< Value > outputOperands
Definition HWOps.h:119
llvm::SmallVector< Value > inputArgs
Definition HWOps.h:118
llvm::StringMap< unsigned > outputIdx
Definition HWOps.h:117
llvm::StringMap< unsigned > inputIdx
Definition HWOps.h:117
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition HWOps.cpp:224
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition HWVisitors.h:27
ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any combinational operations that are not handled by the concrete visitor...
Definition HWVisitors.h:57
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any non-expression operations.
Definition HWVisitors.h:50
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1023
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition HWOps.cpp:1841
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition HWOps.h:124
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:528
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition HWOps.cpp:201
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition HWOps.cpp:522
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition HWOps.cpp:546
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
Operation * lookupOp(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.