CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWOps.cpp
Go to the documentation of this file.
1//===- HWOps.cpp - Implement the HW operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the HW ops.
10//
11//===----------------------------------------------------------------------===//
12
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "llvm/ADT/BitVector.h"
28#include "llvm/ADT/SmallPtrSet.h"
29#include "llvm/ADT/StringSet.h"
30
31using namespace circt;
32using namespace hw;
33using mlir::TypedAttr;
34
35/// Flip a port direction.
37 switch (direction) {
38 case ModulePort::Direction::Input:
39 return ModulePort::Direction::Output;
40 case ModulePort::Direction::Output:
41 return ModulePort::Direction::Input;
42 case ModulePort::Direction::InOut:
43 return ModulePort::Direction::InOut;
44 }
45 llvm_unreachable("unknown PortDirection");
46}
47
48bool hw::isValidIndexBitWidth(Value index, Value array) {
49 hw::ArrayType arrayType =
50 dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51 assert(arrayType && "expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
56}
57
58/// Return true if the specified operation is a combinational logic op.
59bool hw::isCombinational(Operation *op) {
60 struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61 bool visitInvalidTypeOp(Operation *op) { return false; }
62 bool visitUnhandledTypeOp(Operation *op) { return true; }
63 };
64
65 return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66 IsCombClassifier().dispatchTypeOpVisitor(op);
67}
68
69static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70 // A struct extract of a struct create -> corresponding struct create operand.
71 if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
73 }
74
75 // Extracting injected field -> corresponding field
76 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
78 return {};
79 return structInject.getNewValue();
80 }
81 return {};
82}
83
84static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85 ArrayRef<Attribute> attrs) {
86 if (attrs.empty())
87 return ArrayAttr::get(context, {});
88 bool empty = true;
89 for (auto a : attrs)
90 if (a && !cast<DictionaryAttr>(a).empty()) {
91 empty = false;
92 break;
93 }
94 if (empty)
95 return ArrayAttr::get(context, {});
96 return ArrayAttr::get(context, attrs);
97}
98
99/// Get a special name to use when printing the entry block arguments of the
100/// region contained by an operation in this dialect.
101static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102 OpAsmSetValueNameFn setNameFn) {
103 if (region.empty())
104 return;
105 // Assign port names to the bbargs.
106 auto module = cast<HWModuleOp>(region.getParentOp());
107
108 auto *block = &region.front();
109 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
111 // Let mlir deterministically convert names to valid identifiers
112 setNameFn(block->getArgument(i), name);
113 }
114}
115
/// The kinds of list delimiter distinguished in this file.
enum class Delimiter {
  None,
  /// The list is enclosed in `()`.
  Paren,
  /// The list is enclosed in `<>`, or is absent altogether.
  OptionalLessGreater,
};
121
122/// Check parameter specified by `value` to see if it is valid according to the
123/// module's parameters. If not, emit an error to the diagnostic provided as an
124/// argument to the lambda 'instanceError' and return failure, otherwise return
125/// success.
126///
127/// If `disallowParamRefs` is true, then parameter references are not allowed.
128LogicalResult hw::checkParameterInContext(
129 Attribute value, ArrayAttr moduleParameters,
130 const instance_like_impl::EmitErrorFn &instanceError,
131 bool disallowParamRefs) {
132 // Literals are always ok. Their types are already known to match
133 // expectations.
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
136 return success();
137
138 // Check both subexpressions of an expression.
139 if (auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (auto op : expr.getOperands())
141 if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142 disallowParamRefs)))
143 return failure();
144 return success();
145 }
146
147 // Parameter references need more analysis to make sure they are valid within
148 // this module.
149 if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
151
152 // Don't allow references to parameters from the default values of a
153 // parameter list.
154 if (disallowParamRefs) {
155 instanceError([&](auto &diag) {
156 diag << "parameter " << nameAttr
157 << " cannot be used as a default value for a parameter";
158 return false;
159 });
160 return failure();
161 }
162
163 // Find the corresponding attribute in the module.
164 for (auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
167 continue;
168
169 // If the types match then the reference is ok.
170 if (paramAttr.getType() == parameterRef.getType())
171 return success();
172
173 instanceError([&](auto &diag) {
174 diag << "parameter " << nameAttr << " used with type "
175 << parameterRef.getType() << "; should have type "
176 << paramAttr.getType();
177 return true;
178 });
179 return failure();
180 }
181
182 instanceError([&](auto &diag) {
183 diag << "use of unknown parameter " << nameAttr;
184 return true;
185 });
186 return failure();
187 }
188
189 instanceError([&](auto &diag) {
190 diag << "invalid parameter value " << value;
191 return false;
192 });
193 return failure();
194}
195
196/// Check parameter specified by `value` to see if it is valid within the scope
197/// of the specified module `module`. If not, emit an error at the location of
198/// `usingOp` and return failure, otherwise return success. If `usingOp` is
199/// null, then no diagnostic is generated.
200///
201/// If `disallowParamRefs` is true, then parameter references are not allowed.
202LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203 Operation *usingOp,
204 bool disallowParamRefs) {
206 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
207 if (usingOp) {
208 auto diag = usingOp->emitOpError();
209 if (fn(diag))
210 diag.attachNote(module->getLoc()) << "module declared here";
211 }
212 };
213
214 return checkParameterInContext(value,
215 module->getAttrOfType<ArrayAttr>("parameters"),
216 emitError, disallowParamRefs);
217}
218
219/// Return true if the specified attribute tree is made up of nodes that are
220/// valid in a parameter expression.
221bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222 return succeeded(checkParameterInContext(attr, module, nullptr, false));
223}
224
226 const ModulePortInfo &info,
227 Region &bodyRegion)
228 : info(info) {
229 inputArgs.resize(info.sizeInputs());
230 for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
231 inputIdx[info.at(i).name.str()] = i;
232 inputArgs[i] = barg;
233 }
234
236 for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237 outputIdx[outputInfo.name.str()] = i;
238 }
239}
240
241void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242 assert(outputOperands.size() > i && "invalid output index");
243 assert(outputOperands[i] == Value() && "output already set");
244 outputOperands[i] = v;
245}
246
248 assert(inputArgs.size() > i && "invalid input index");
249 return inputArgs[i];
250}
251Value HWModulePortAccessor::getInput(StringRef name) {
252 return getInput(inputIdx.find(name.str())->second);
253}
254void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255 setOutput(outputIdx.find(name.str())->second, v);
256}
257
258//===----------------------------------------------------------------------===//
259// ConstantOp
260//===----------------------------------------------------------------------===//
261
262void ConstantOp::print(OpAsmPrinter &p) {
263 p << " ";
264 p.printAttribute(getValueAttr());
265 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
266}
267
268ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269 IntegerAttr valueAttr;
270
271 if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
272 parser.parseOptionalAttrDict(result.attributes))
273 return failure();
274
275 result.addTypes(valueAttr.getType());
276 return success();
277}
278
279LogicalResult ConstantOp::verify() {
280 // If the result type has a bitwidth, then the attribute must match its width.
281 if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
282 return emitError(
283 "hw.constant attribute bitwidth doesn't match return type");
284
285 return success();
286}
287
288/// Build a ConstantOp from an APInt, infering the result type from the
289/// width of the APInt.
290void ConstantOp::build(OpBuilder &builder, OperationState &result,
291 const APInt &value) {
292
293 auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
294 auto attr = builder.getIntegerAttr(type, value);
295 return build(builder, result, type, attr);
296}
297
298/// Build a ConstantOp from an APInt, infering the result type from the
299/// width of the APInt.
300void ConstantOp::build(OpBuilder &builder, OperationState &result,
301 IntegerAttr value) {
302 return build(builder, result, value.getType(), value);
303}
304
305/// This builder allows construction of small signed integers like 0, 1, -1
306/// matching a specified MLIR IntegerType. This shouldn't be used for general
307/// constant folding because it only works with values that can be expressed in
308/// an int64_t. Use APInt's instead.
309void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
310 int64_t value) {
311 auto numBits = cast<IntegerType>(type).getWidth();
312 build(builder, result,
313 APInt(numBits, (uint64_t)value, /*isSigned=*/true,
314 /*implicitTrunc=*/true));
315}
316
317void ConstantOp::getAsmResultNames(
318 function_ref<void(Value, StringRef)> setNameFn) {
319 auto intTy = getType();
320 auto intCst = getValue();
321
322 // Sugar i1 constants with 'true' and 'false'.
323 if (cast<IntegerType>(intTy).getWidth() == 1)
324 return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
325
326 // Otherwise, build a complex name with the value and type.
327 SmallVector<char, 32> specialNameBuffer;
328 llvm::raw_svector_ostream specialName(specialNameBuffer);
329 specialName << 'c' << intCst << '_' << intTy;
330 setNameFn(getResult(), specialName.str());
331}
332
333OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
334 assert(adaptor.getOperands().empty() && "constant has no operands");
335 return getValueAttr();
336}
337
338//===----------------------------------------------------------------------===//
339// WireOp
340//===----------------------------------------------------------------------===//
341
342/// Check whether an operation has any additional attributes set beyond its
343/// standard list of attributes returned by `getAttributeNames`.
344template <class Op>
345static bool hasAdditionalAttributes(Op op,
346 ArrayRef<StringRef> ignoredAttrs = {}) {
347 auto names = op.getAttributeNames();
348 llvm::SmallDenseSet<StringRef> nameSet;
349 nameSet.reserve(names.size() + ignoredAttrs.size());
350 nameSet.insert(names.begin(), names.end());
351 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
352 return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
353 return !nameSet.contains(namedAttr.getName());
354 });
355}
356
357void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
358 // If the wire has an optional 'name' attribute, use it.
359 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
360 if (nameAttr && !nameAttr.getValue().empty())
361 setNameFn(getResult(), nameAttr.getValue());
362}
363
364std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
365
366OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
367 // If the wire has no additional attributes, no name, and no symbol, just
368 // forward its input.
369 if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
370 !getInnerSymAttr())
371 return getInput();
372 return {};
373}
374
375LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
376 // Block if the wire has any attributes.
377 if (hasAdditionalAttributes(wire, {"sv.namehint"}))
378 return failure();
379
380 // If the wire has a symbol, then we can't delete it.
381 if (wire.getInnerSymAttr())
382 return failure();
383
384 // If the wire has a name or an `sv.namehint` attribute, propagate it as an
385 // `sv.namehint` to the expression.
386 if (auto *inputOp = wire.getInput().getDefiningOp())
387 if (auto name = chooseName(wire, inputOp))
388 rewriter.modifyOpInPlace(inputOp,
389 [&] { inputOp->setAttr("sv.namehint", name); });
390
391 rewriter.replaceOp(wire, wire.getInput());
392 return success();
393}
394
395//===----------------------------------------------------------------------===//
396// AggregateConstantOp
397//===----------------------------------------------------------------------===//
398
399static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
400 // If this is a type alias, get the underlying type.
401 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
402 type = typeAlias.getCanonicalType();
403
404 if (auto structType = dyn_cast<StructType>(type)) {
405 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
406 if (!arrayAttr)
407 return op->emitOpError("expected array attribute for constant of type ")
408 << type;
409 if (structType.getElements().size() != arrayAttr.size())
410 return op->emitOpError("array attribute (")
411 << arrayAttr.size() << ") has wrong size for struct constant ("
412 << structType.getElements().size() << ")";
413
414 for (auto [attr, fieldInfo] :
415 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
416 if (failed(checkAttributes(op, attr, fieldInfo.type)))
417 return failure();
418 }
419 } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
420 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
421 if (!arrayAttr)
422 return op->emitOpError("expected array attribute for constant of type ")
423 << type;
424 if (arrayType.getNumElements() != arrayAttr.size())
425 return op->emitOpError("array attribute (")
426 << arrayAttr.size() << ") has wrong size for array constant ("
427 << arrayType.getNumElements() << ")";
428
429 auto elementType = arrayType.getElementType();
430 for (auto attr : arrayAttr.getValue()) {
431 if (failed(checkAttributes(op, attr, elementType)))
432 return failure();
433 }
434 } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
435 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
436 if (!arrayAttr)
437 return op->emitOpError("expected array attribute for constant of type ")
438 << type;
439 auto elementType = arrayType.getElementType();
440 if (arrayType.getNumElements() != arrayAttr.size())
441 return op->emitOpError("array attribute (")
442 << arrayAttr.size()
443 << ") has wrong size for unpacked array constant ("
444 << arrayType.getNumElements() << ")";
445
446 for (auto attr : arrayAttr.getValue()) {
447 if (failed(checkAttributes(op, attr, elementType)))
448 return failure();
449 }
450 } else if (auto enumType = dyn_cast<EnumType>(type)) {
451 auto stringAttr = dyn_cast<StringAttr>(attr);
452 if (!stringAttr)
453 return op->emitOpError("expected string attribute for constant of type ")
454 << type;
455 } else if (auto intType = dyn_cast<IntegerType>(type)) {
456 // Check the attribute kind is correct.
457 auto intAttr = dyn_cast<IntegerAttr>(attr);
458 if (!intAttr)
459 return op->emitOpError("expected integer attribute for constant of type ")
460 << type;
461 // Check the bitwidth is correct.
462 if (intAttr.getValue().getBitWidth() != intType.getWidth())
463 return op->emitOpError("hw.constant attribute bitwidth "
464 "doesn't match return type");
465 } else if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
466 if (typedAttr.getType() != type)
467 return op->emitOpError("typed attr doesn't match the return type ")
468 << type;
469 } else {
470 return op->emitOpError("unknown element type ") << type;
471 }
472 return success();
473}
474
475LogicalResult AggregateConstantOp::verify() {
476 return checkAttributes(*this, getFieldsAttr(), getType());
477}
478
479OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
480
481//===----------------------------------------------------------------------===//
482// ParamValueOp
483//===----------------------------------------------------------------------===//
484
485static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
486 Type &resultType) {
487 if (p.parseType(resultType) || p.parseEqual() ||
488 p.parseAttribute(value, resultType))
489 return failure();
490 return success();
491}
492
493static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
494 Type resultType) {
495 p << resultType << " = ";
496 p.printAttributeWithoutType(value);
497}
498
499LogicalResult ParamValueOp::verify() {
500 // Check that the attribute expression is valid in this module.
502 getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
503}
504
505OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
506 assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
507 return getValueAttr();
508}
509
510//===----------------------------------------------------------------------===//
511// HWModuleOp
512//===----------------------------------------------------------------------===//
513
514/// Return true if isAnyModule or instance.
515bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
516 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
517}
518
519/// Return the signature for a module as a function type from the module itself
520/// or from an hw::InstanceOp.
521FunctionType hw::getModuleType(Operation *moduleOrInstance) {
522 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
523 .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
524 SmallVector<Type> inputs(instance->getOperandTypes());
525 SmallVector<Type> results(instance->getResultTypes());
526 return FunctionType::get(instance->getContext(), inputs, results);
527 })
528 .Case<HWModuleLike>(
529 [](auto mod) { return mod.getHWModuleType().getFuncType(); })
530 .Default([](Operation *op) {
531 return cast<FunctionType>(
532 cast<mlir::FunctionOpInterface>(op).getFunctionType());
533 });
534}
535
536/// Return the name to use for the Verilog module that we're referencing
537/// here. This is typically the symbol, but can be overridden with the
538/// verilogName attribute.
539StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
540 auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
541 if (nameAttr)
542 return nameAttr;
543
544 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
545}
546
547template <typename ModuleTy>
548static void
549buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
550 const ModulePortInfo &ports, ArrayAttr parameters,
551 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
552 using namespace mlir::function_interface_impl;
553
554 // Add an attribute for the name.
555 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
556
557 SmallVector<Attribute> perPortAttrs;
558 SmallVector<ModulePort> portTypes;
559
560 for (auto elt : ports) {
561 portTypes.push_back(elt);
562 llvm::SmallVector<NamedAttribute> portAttrs;
563 if (elt.attrs)
564 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
565 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
566 }
567
568 // Allow clients to pass in null for the parameters list.
569 if (!parameters)
570 parameters = builder.getArrayAttr({});
571
572 // Record the argument and result types as an attribute.
573 auto type = ModuleType::get(builder.getContext(), portTypes);
574 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
575 TypeAttr::get(type));
576 result.addAttribute("per_port_attrs",
577 arrayOrEmpty(builder.getContext(), perPortAttrs));
578 result.addAttribute("parameters", parameters);
579 if (!comment)
580 comment = builder.getStringAttr("");
581 result.addAttribute("comment", comment);
582 result.addAttributes(attributes);
583 result.addRegion();
584}
585
586/// Internal implementation of argument/result insertion and removal on modules.
588 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
589 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
590 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
591 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
592 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
593 SmallVector<Location> &newArgLocs, Block *body = nullptr) {
594
595#ifndef NDEBUG
596 // Check that the `insertArgs` and `removeArgs` indices are in ascending
597 // order.
598 assert(llvm::is_sorted(insertArgs,
599 [](auto &a, auto &b) { return a.first < b.first; }) &&
600 "insertArgs must be in ascending order");
601 assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
602 "removeArgs must be in ascending order");
603#endif
604
605 auto oldArgCount = oldArgTypes.size();
606 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
607 assert((int)newArgCount >= 0);
608
609 newArgNames.reserve(newArgCount);
610 newArgTypes.reserve(newArgCount);
611 newArgAttrs.reserve(newArgCount);
612 newArgLocs.reserve(newArgCount);
613
614 auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
615 auto emptyDictAttr = DictionaryAttr::get(context, {});
616 auto unknownLoc = UnknownLoc::get(context);
617
618 BitVector erasedIndices;
619 if (body)
620 erasedIndices.resize(oldArgCount + insertArgs.size());
621
622 for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
623 // Insert new ports at this position.
624 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
625 auto port = insertArgs[0].second;
626 if (port.dir == ModulePort::Direction::InOut &&
627 !isa<InOutType>(port.type))
628 port.type = InOutType::get(port.type);
629 auto sym = port.getSym();
630 Attribute attr =
631 (sym && !sym.empty())
632 ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
633 : emptyDictAttr;
634 newArgNames.push_back(port.name);
635 newArgTypes.push_back(port.type);
636 newArgAttrs.push_back(attr);
637 insertArgs = insertArgs.drop_front();
638 LocationAttr loc = port.loc ? port.loc : unknownLoc;
639 newArgLocs.push_back(loc);
640 if (body)
641 body->insertArgument(idx++, port.type, loc);
642 }
643 if (argIdx == oldArgCount)
644 break;
645
646 // Migrate the old port at this position.
647 bool removed = false;
648 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
649 removeArgs = removeArgs.drop_front();
650 removed = true;
651 }
652
653 if (removed) {
654 if (body)
655 erasedIndices.set(idx);
656 } else {
657 newArgNames.push_back(oldArgNames[argIdx]);
658 newArgTypes.push_back(oldArgTypes[argIdx]);
659 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
660 : oldArgAttrs[argIdx]);
661 newArgLocs.push_back(oldArgLocs[argIdx]);
662 }
663 }
664
665 if (body)
666 body->eraseArguments(erasedIndices);
667
668 assert(newArgNames.size() == newArgCount);
669 assert(newArgTypes.size() == newArgCount);
670 assert(newArgAttrs.size() == newArgCount);
671 assert(newArgLocs.size() == newArgCount);
672}
673
674/// Insert and remove ports of a module. The insertion and removal indices must
675/// be in ascending order. The indices refer to the port positions before any
676/// insertion or removal occurs. Ports inserted at the same index will appear in
677/// the module in the same order as they were listed in the `insert*` array.
678///
679/// The operation must be any of the module-like operations.
680///
681/// This is marked deprecated as it's only used from HandshakeToHW and
682/// PortConverter and is likely broken and not currently tested. Users of this
683/// are still written dealing with input and output ports separately, which is
684/// an old and broken style.
685[[deprecated]] static void
686modifyModulePorts(Operation *op,
687 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
688 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
689 ArrayRef<unsigned> removeInputs,
690 ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
691 auto moduleOp = cast<HWModuleLike>(op);
692 auto *context = moduleOp.getContext();
693
694 // Dig up the old argument and result data.
695 auto oldArgNames = moduleOp.getInputNames();
696 auto oldArgTypes = moduleOp.getInputTypes();
697 auto oldArgAttrs = moduleOp.getAllInputAttrs();
698 auto oldArgLocs = moduleOp.getInputLocs();
699
700 auto oldResultNames = moduleOp.getOutputNames();
701 auto oldResultTypes = moduleOp.getOutputTypes();
702 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
703 auto oldResultLocs = moduleOp.getOutputLocs();
704
705 // Modify the ports.
706 SmallVector<Attribute> newArgNames, newResultNames;
707 SmallVector<Type> newArgTypes, newResultTypes;
708 SmallVector<Attribute> newArgAttrs, newResultAttrs;
709 SmallVector<Location> newArgLocs, newResultLocs;
710
711 modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
712 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
713 newArgTypes, newArgAttrs, newArgLocs, body);
714
715 modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
716 oldResultTypes, oldResultAttrs, oldResultLocs,
717 newResultNames, newResultTypes, newResultAttrs,
718 newResultLocs);
719
720 // Update the module operation types and attributes.
721 auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
722 auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
723 moduleOp.setHWModuleType(modty);
724 moduleOp.setAllInputAttrs(newArgAttrs);
725 moduleOp.setAllOutputAttrs(newResultAttrs);
726
727 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
728 moduleOp.setAllPortLocs(newArgLocs);
729}
730
731void HWModuleOp::build(OpBuilder &builder, OperationState &result,
732 StringAttr name, const ModulePortInfo &ports,
733 ArrayAttr parameters,
734 ArrayRef<NamedAttribute> attributes, StringAttr comment,
735 bool shouldEnsureTerminator) {
736 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
737 comment);
738
739 // Create a region and a block for the body.
740 auto *bodyRegion = result.regions[0].get();
741 Block *body = new Block();
742 bodyRegion->push_back(body);
743
744 // Add arguments to the body block.
745 auto unknownLoc = builder.getUnknownLoc();
746 for (auto port : ports.getInputs()) {
747 auto loc = port.loc ? Location(port.loc) : unknownLoc;
748 auto type = port.type;
749 if (port.isInOut() && !isa<InOutType>(type))
750 type = InOutType::get(type);
751 body->addArgument(type, loc);
752 }
753
754 // Add result ports attribute.
755 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
756 SmallVector<Attribute> resultLocs;
757 for (auto port : ports.getOutputs())
758 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
759 result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
760
761 if (shouldEnsureTerminator)
762 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
763}
764
765void HWModuleOp::build(OpBuilder &builder, OperationState &result,
766 StringAttr name, ArrayRef<PortInfo> ports,
767 ArrayAttr parameters,
768 ArrayRef<NamedAttribute> attributes,
769 StringAttr comment) {
770 build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
771 comment);
772}
773
774void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
775 StringAttr name, const ModulePortInfo &ports,
776 HWModuleBuilder modBuilder, ArrayAttr parameters,
777 ArrayRef<NamedAttribute> attributes,
778 StringAttr comment) {
779 build(builder, odsState, name, ports, parameters, attributes, comment,
780 /*shouldEnsureTerminator=*/false);
781 auto *bodyRegion = odsState.regions[0].get();
782 OpBuilder::InsertionGuard guard(builder);
783 auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
784 builder.setInsertionPointToEnd(&bodyRegion->front());
785 modBuilder(builder, accessor);
786 // Create output operands.
787 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
788 builder.create<hw::OutputOp>(odsState.location, outputOperands);
789}
790
791void HWModuleOp::modifyPorts(
792 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
793 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
794 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
795 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
796 eraseOutputs);
797}
798
799/// Return the name to use for the Verilog module that we're referencing
800/// here. This is typically the symbol, but can be overridden with the
801/// verilogName attribute.
802StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
803 if (auto vName = getVerilogNameAttr())
804 return vName;
805
806 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
807}
808
809StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
810 if (auto vName = getVerilogNameAttr()) {
811 return vName;
812 }
813 return (*this)->getAttrOfType<StringAttr>(
814 ::mlir::SymbolTable::getSymbolAttrName());
815}
816
817void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
818 StringAttr name, const ModulePortInfo &ports,
819 StringRef verilogName, ArrayAttr parameters,
820 ArrayRef<NamedAttribute> attributes) {
821 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
822 attributes, {});
823
824 // Add the port locations.
825 LocationAttr unknownLoc = builder.getUnknownLoc();
826 SmallVector<Attribute> portLocs;
827 for (auto elt : ports)
828 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
829 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
830
831 if (!verilogName.empty())
832 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
833}
834
835void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
836 StringAttr name, ArrayRef<PortInfo> ports,
837 StringRef verilogName, ArrayAttr parameters,
838 ArrayRef<NamedAttribute> attributes) {
839 build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
840 attributes);
841}
842
843void HWModuleExternOp::modifyPorts(
844 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
845 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
846 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
847 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
848 eraseOutputs);
849}
850
851void HWModuleExternOp::appendOutputs(
852 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
853
854void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
855 FlatSymbolRefAttr genKind, StringAttr name,
856 const ModulePortInfo &ports,
857 StringRef verilogName, ArrayAttr parameters,
858 ArrayRef<NamedAttribute> attributes) {
859 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
860 attributes, {});
861 // Add the port locations.
862 LocationAttr unknownLoc = builder.getUnknownLoc();
863 SmallVector<Attribute> portLocs;
864 for (auto elt : ports)
865 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
866 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
867
868 result.addAttribute("generatorKind", genKind);
869 if (!verilogName.empty())
870 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
871}
872
873void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
874 FlatSymbolRefAttr genKind, StringAttr name,
875 ArrayRef<PortInfo> ports, StringRef verilogName,
876 ArrayAttr parameters,
877 ArrayRef<NamedAttribute> attributes) {
878 build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
879 parameters, attributes);
880}
881
882void HWModuleGeneratedOp::modifyPorts(
883 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
884 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
885 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
886 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
887 eraseOutputs);
888}
889
890void HWModuleGeneratedOp::appendOutputs(
891 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
892
893static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
894 for (auto &argAttr : attrs)
895 if (argAttr.getName() == name)
896 return true;
897 return false;
898}
899
900template <typename ModuleTy>
901static ParseResult parseHWModuleOp(OpAsmParser &parser,
902 OperationState &result) {
903
904 using namespace mlir::function_interface_impl;
905 auto builder = parser.getBuilder();
906 auto loc = parser.getCurrentLocation();
907
908 // Parse the visibility attribute.
909 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
910
911 // Parse the name as a symbol.
912 StringAttr nameAttr;
913 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
914 result.attributes))
915 return failure();
916
917 // Parse the generator information.
918 FlatSymbolRefAttr kindAttr;
919 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
920 if (parser.parseComma() ||
921 parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
922 return failure();
923 }
924 }
925
926 // Parse the parameters.
927 ArrayAttr parameters;
928 if (parseOptionalParameterList(parser, parameters))
929 return failure();
930
931 SmallVector<module_like_impl::PortParse> ports;
932 TypeAttr modType;
933 if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
934 return failure();
935
936 // Parse the attribute dict.
937 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
938 return failure();
939
940 if (hasAttribute("parameters", result.attributes)) {
941 parser.emitError(loc, "explicit `parameters` attributes not allowed");
942 return failure();
943 }
944
945 result.addAttribute("parameters", parameters);
946 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
947
948 // Convert the specified array of dictionary attrs (which may have null
949 // entries) to an ArrayAttr of dictionaries.
950 SmallVector<Attribute> attrs;
951 for (auto &port : ports)
952 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
953 // Add the attributes to the ports.
954 auto nonEmptyAttrsFn = [](Attribute attr) {
955 return attr && !cast<DictionaryAttr>(attr).empty();
956 };
957 if (llvm::any_of(attrs, nonEmptyAttrsFn))
958 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
959 builder.getArrayAttr(attrs));
960
961 // Add the port locations.
962 auto unknownLoc = builder.getUnknownLoc();
963 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
964 return attr && cast<Location>(attr) != unknownLoc;
965 };
966 SmallVector<Attribute> locs;
967 StringAttr portLocsAttrName;
968 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
969 // Plain modules only store the output port locations, as the input port
970 // locations will be stored in the basic block arguments.
971 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
972 for (auto &port : ports)
973 if (port.direction == ModulePort::Direction::Output)
974 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
975 } else {
976 // All other modules store all port locations in a single array.
977 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
978 for (auto &port : ports)
979 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
980 }
981 if (llvm::any_of(locs, nonEmptyLocsFn))
982 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
983
984 // Add the entry block arguments.
985 SmallVector<OpAsmParser::Argument, 4> entryArgs;
986 for (auto &port : ports)
987 if (port.direction != ModulePort::Direction::Output)
988 entryArgs.push_back(port);
989
990 // Parse the optional function body.
991 auto *body = result.addRegion();
992 if (std::is_same_v<ModuleTy, HWModuleOp>) {
993 if (parser.parseRegion(*body, entryArgs))
994 return failure();
995
996 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
997 }
998 return success();
999}
1000
1001ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1002 return parseHWModuleOp<HWModuleOp>(parser, result);
1003}
1004
1005ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1006 OperationState &result) {
1007 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1008}
1009
1010ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1011 OperationState &result) {
1012 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1013}
1014
1015FunctionType getHWModuleOpType(Operation *op) {
1016 if (auto mod = dyn_cast<HWModuleLike>(op))
1017 return mod.getHWModuleType().getFuncType();
1018 return cast<FunctionType>(
1019 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1020}
1021
1022template <typename ModuleTy>
1023static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1024 p << ' ';
1025 // Print the visibility of the module.
1026 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1027 if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1028 visibilityAttrName))
1029 p << visibility.getValue() << ' ';
1030
1031 // Print the operation and the function name.
1032 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1033 if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1034 p << ", ";
1035 p.printSymbolName(gen.getGeneratorKind());
1036 }
1037
1038 // Print the parameter list if present.
1039 printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1040
1042
1043 SmallVector<StringRef, 3> omittedAttrs;
1044 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1045 omittedAttrs.push_back("generatorKind");
1046 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1047 omittedAttrs.push_back(mod.getResultLocsAttrName());
1048 else
1049 omittedAttrs.push_back(mod.getPortLocsAttrName());
1050 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1051 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1052 omittedAttrs.push_back(mod.getParametersAttrName());
1053 omittedAttrs.push_back(visibilityAttrName);
1054 if (auto cmt =
1055 mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1056 if (cmt.getValue().empty())
1057 omittedAttrs.push_back("comment");
1058
1059 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1060 omittedAttrs);
1061}
1062
1063void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1064void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1065
1066void HWModuleOp::print(OpAsmPrinter &p) {
1067 printModuleOp(p, *this);
1068
1069 // Print the body if this is not an external function.
1070 Region &body = getBody();
1071 if (!body.empty()) {
1072 p << " ";
1073 p.printRegion(body, /*printEntryBlockArgs=*/false,
1074 /*printBlockTerminators=*/true);
1075 }
1076}
1077
1078static LogicalResult verifyModuleCommon(HWModuleLike module) {
1079 assert(isa<HWModuleLike>(module) &&
1080 "verifier hook should only be called on modules");
1081
1082 SmallPtrSet<Attribute, 4> paramNames;
1083
1084 // Check parameter default values are sensible.
1085 for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1086 auto paramAttr = cast<ParamDeclAttr>(param);
1087
1088 // Check that we don't have any redundant parameter names. These are
1089 // resolved by string name: reuse of the same name would cause ambiguities.
1090 if (!paramNames.insert(paramAttr.getName()).second)
1091 return module->emitOpError("parameter ")
1092 << paramAttr << " has the same name as a previous parameter";
1093
1094 // Default values are allowed to be missing, check them if present.
1095 auto value = paramAttr.getValue();
1096 if (!value)
1097 continue;
1098
1099 auto typedValue = dyn_cast<TypedAttr>(value);
1100 if (!typedValue)
1101 return module->emitOpError("parameter ")
1102 << paramAttr << " should have a typed value; has value " << value;
1103
1104 if (typedValue.getType() != paramAttr.getType())
1105 return module->emitOpError("parameter ")
1106 << paramAttr << " should have type " << paramAttr.getType()
1107 << "; has type " << typedValue.getType();
1108
1109 // Verify that this is a valid parameter value, disallowing parameter
1110 // references. We could allow parameters to refer to each other in the
1111 // future with lexical ordering if there is a need.
1112 if (failed(checkParameterInContext(value, module, module,
1113 /*disallowParamRefs=*/true)))
1114 return failure();
1115 }
1116 return success();
1117}
1118
1119LogicalResult HWModuleOp::verify() {
1120 if (failed(verifyModuleCommon(*this)))
1121 return failure();
1122
1123 auto type = getModuleType();
1124 auto *body = getBodyBlock();
1125
1126 // Verify the number of block arguments.
1127 auto numInputs = type.getNumInputs();
1128 if (body->getNumArguments() != numInputs)
1129 return emitOpError("entry block must have")
1130 << numInputs << " arguments to match module signature";
1131
1132 return success();
1133}
1134
1135LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1136
1137std::pair<StringAttr, BlockArgument>
1138HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1139 // Find a unique name for the wire.
1140 Namespace ns;
1141 auto ports = getPortList();
1142 for (auto port : ports)
1143 ns.newName(port.name.getValue());
1144 auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1145
1146 Block *body = getBodyBlock();
1147
1148 // Create a new port for the host clock.
1149 PortInfo port;
1150 port.name = nameAttr;
1152 port.type = ty;
1153 modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1154 body);
1155
1156 // Add a new argument.
1157 return {nameAttr, body->getArgument(index)};
1158}
1159
1160void HWModuleOp::insertOutputs(unsigned index,
1161 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1162
1163 auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1164 assert(index <= output->getNumOperands() && "invalid output index");
1165
1166 // Rewrite the port list of the module.
1167 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1168 for (auto &[name, value] : outputs) {
1169 PortInfo port;
1170 port.name = name;
1172 port.type = value.getType();
1173 indexedNewPorts.emplace_back(index, port);
1174 }
1175 modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1176 getBodyBlock());
1177
1178 // Rewrite the output op.
1179 for (auto &[name, value] : outputs)
1180 output->insertOperands(index++, value);
1181}
1182
1183void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1184 return insertOutputs(getNumOutputPorts(), outputs);
1185}
1186
1187void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1188 mlir::OpAsmSetValueNameFn setNameFn) {
1189 getAsmBlockArgumentNamesImpl(region, setNameFn);
1190}
1191
1192void HWModuleExternOp::getAsmBlockArgumentNames(
1193 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1194 getAsmBlockArgumentNamesImpl(region, setNameFn);
1195}
1196
1197template <typename ModTy>
1198static SmallVector<Location> getAllPortLocs(ModTy module) {
1199 auto locs = module.getPortLocs();
1200 if (locs) {
1201 SmallVector<Location> retval;
1202 retval.reserve(locs->size());
1203 for (auto l : *locs)
1204 retval.push_back(cast<Location>(l));
1205 // Either we have a length of 0 or the correct length
1206 assert(!locs->size() || locs->size() == module.getNumPorts());
1207 return retval;
1208 }
1209 return SmallVector<Location>(module.getNumPorts(),
1210 UnknownLoc::get(module.getContext()));
1211}
1212
1213SmallVector<Location> HWModuleOp::getAllPortLocs() {
1214 SmallVector<Location> portLocs;
1215 portLocs.reserve(getNumPorts());
1216 auto resultLocs = getResultLocsAttr();
1217 unsigned inputCount = 0;
1218 auto modType = getModuleType();
1219 auto unknownLoc = UnknownLoc::get(getContext());
1220 auto *body = getBodyBlock();
1221 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1222 if (modType.isOutput(i)) {
1223 auto loc = resultLocs
1224 ? cast<Location>(
1225 resultLocs.getValue()[portLocs.size() - inputCount])
1226 : unknownLoc;
1227 portLocs.push_back(loc);
1228 } else {
1229 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1230 portLocs.push_back(loc);
1231 ++inputCount;
1232 }
1233 }
1234 return portLocs;
1235}
1236
1237SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1238 return ::getAllPortLocs(*this);
1239}
1240
1241SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1242 return ::getAllPortLocs(*this);
1243}
1244
1245void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1246 SmallVector<Attribute> resultLocs;
1247 unsigned inputCount = 0;
1248 auto modType = getModuleType();
1249 auto *body = getBodyBlock();
1250 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1251 if (modType.isOutput(i))
1252 resultLocs.push_back(locs[i]);
1253 else
1254 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1255 }
1256 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1257}
1258
1259void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1260 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1261}
1262
1263void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1264 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1265}
1266
1267template <typename ModTy>
1268static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1269 auto numInputs = module.getNumInputPorts();
1270 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1271 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1272 auto oldType = module.getModuleType();
1273 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1274 oldType.getPorts().end());
1275 for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1276 newPorts[i].name = cast<StringAttr>(names[i]);
1277 auto newType = ModuleType::get(module.getContext(), newPorts);
1278 module.setModuleType(newType);
1279}
1280
1281void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1282 ::setAllPortNames(names, *this);
1283}
1284
1285void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1286 ::setAllPortNames(names, *this);
1287}
1288
1289void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1290 ::setAllPortNames(names, *this);
1291}
1292
1293ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1294 auto attrs = getPerPortAttrs();
1295 if (attrs && !attrs->empty())
1296 return attrs->getValue();
1297 return {};
1298}
1299
1300ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1301 auto attrs = getPerPortAttrs();
1302 if (attrs && !attrs->empty())
1303 return attrs->getValue();
1304 return {};
1305}
1306
1307ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1308 auto attrs = getPerPortAttrs();
1309 if (attrs && !attrs->empty())
1310 return attrs->getValue();
1311 return {};
1312}
1313
1314void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1315 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1316}
1317
1318void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1319 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1320}
1321
1322void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1323 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1324}
1325
1326void HWModuleOp::removeAllPortAttrs() {
1327 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1328}
1329
1330void HWModuleExternOp::removeAllPortAttrs() {
1331 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1332}
1333
1334void HWModuleGeneratedOp::removeAllPortAttrs() {
1335 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1336}
1337
1338// This probably does really unexpected stuff when you change the number of
1339
1340template <typename ModTy>
1341static void setHWModuleType(ModTy &mod, ModuleType type) {
1342 auto argAttrs = mod.getAllInputAttrs();
1343 auto resAttrs = mod.getAllOutputAttrs();
1344 mod.setModuleTypeAttr(TypeAttr::get(type));
1345 unsigned newNumArgs = type.getNumInputs();
1346 unsigned newNumResults = type.getNumOutputs();
1347
1348 auto emptyDict = DictionaryAttr::get(mod.getContext());
1349 argAttrs.resize(newNumArgs, emptyDict);
1350 resAttrs.resize(newNumResults, emptyDict);
1351
1352 SmallVector<Attribute> attrs;
1353 attrs.append(argAttrs.begin(), argAttrs.end());
1354 attrs.append(resAttrs.begin(), resAttrs.end());
1355
1356 if (attrs.empty())
1357 return mod.removeAllPortAttrs();
1358 mod.setAllPortAttrs(attrs);
1359}
1360
1361void HWModuleOp::setHWModuleType(ModuleType type) {
1362 return ::setHWModuleType(*this, type);
1363}
1364
1365void HWModuleExternOp::setHWModuleType(ModuleType type) {
1366 return ::setHWModuleType(*this, type);
1367}
1368
1369void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1370 return ::setHWModuleType(*this, type);
1371}
1372
1373/// Lookup the generator for the symbol. This returns null on
1374/// invalid IR.
1375Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1376 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1377 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1378}
1379
1380LogicalResult
1381HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1382 auto *referencedKind =
1383 symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1384
1385 if (referencedKind == nullptr)
1386 return emitError("Cannot find generator definition '")
1387 << getGeneratorKind() << "'";
1388
1389 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1390 return emitError("Symbol resolved to '")
1391 << referencedKind->getName()
1392 << "' which is not a HWGeneratorSchemaOp";
1393
1394 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1395 auto paramRef = referencedKindOp.getRequiredAttrs();
1396 auto dict = (*this)->getAttrDictionary();
1397 for (auto str : paramRef) {
1398 auto strAttr = dyn_cast<StringAttr>(str);
1399 if (!strAttr)
1400 return emitError("Unknown attribute type, expected a string");
1401 if (!dict.get(strAttr.getValue()))
1402 return emitError("Missing attribute '") << strAttr.getValue() << "'";
1403 }
1404
1405 return success();
1406}
1407
1408LogicalResult HWModuleGeneratedOp::verify() {
1409 return verifyModuleCommon(*this);
1410}
1411
1412void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1413 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1414 getAsmBlockArgumentNamesImpl(region, setNameFn);
1415}
1416
1417LogicalResult HWModuleOp::verifyBody() { return success(); }
1418
1419template <typename ModuleTy>
1420static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1421 auto modTy = mod.getHWModuleType();
1422 auto emptyDict = DictionaryAttr::get(mod.getContext());
1423 SmallVector<PortInfo> retval;
1424 auto locs = mod.getAllPortLocs();
1425 for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1426 LocationAttr loc = locs[i];
1427 DictionaryAttr attrs =
1428 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1429 if (!attrs)
1430 attrs = emptyDict;
1431 retval.push_back({modTy.getPorts()[i],
1432 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1433 : modTy.getInputIdForPortId(i),
1434 attrs, loc});
1435 }
1436 return retval;
1437}
1438
1439template <typename ModuleTy>
1440static PortInfo getPort(ModuleTy &mod, size_t idx) {
1441 auto modTy = mod.getHWModuleType();
1442 auto emptyDict = DictionaryAttr::get(mod.getContext());
1443 LocationAttr loc = mod.getPortLoc(idx);
1444 DictionaryAttr attrs =
1445 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1446 if (!attrs)
1447 attrs = emptyDict;
1448 return {modTy.getPorts()[idx],
1449 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1450 : modTy.getInputIdForPortId(idx),
1451 attrs, loc};
1452}
1453
1454//===----------------------------------------------------------------------===//
1455// InstanceOp
1456//===----------------------------------------------------------------------===//
1457
1458/// Create a instance that refers to a known module.
1459void InstanceOp::build(OpBuilder &builder, OperationState &result,
1460 Operation *module, StringAttr name,
1461 ArrayRef<Value> inputs, ArrayAttr parameters,
1462 InnerSymAttr innerSym) {
1463 if (!parameters)
1464 parameters = builder.getArrayAttr({});
1465
1466 auto mod = cast<hw::HWModuleLike>(module);
1467 auto argNames = builder.getArrayAttr(mod.getInputNames());
1468 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1469
1470 // Try to resolve the parameterized module type. If failed, use the module's
1471 // parmeterized type. If the client doesn't fix this error, the verifier will
1472 // fail.
1473 ModuleType modType = mod.getHWModuleType();
1474 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1475 parameters, result.location, /*emitErrors=*/false);
1476 if (succeeded(resolvedModType))
1477 modType = *resolvedModType;
1478 FunctionType funcType = resolvedModType->getFuncType();
1479 build(builder, result, funcType.getResults(), name,
1480 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1481 argNames, resultNames, parameters, innerSym, /*doNotPrint=*/{});
1482}
1483
1484std::optional<size_t> InstanceOp::getTargetResultIndex() {
1485 // Inner symbols on instance operations target the op not any result.
1486 return std::nullopt;
1487}
1488
1489LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1491 *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1492 getResultNames(), getParameters(), symbolTable);
1493}
1494
1495LogicalResult InstanceOp::verify() {
1496 auto module = (*this)->getParentOfType<HWModuleOp>();
1497 if (!module)
1498 return success();
1499
1500 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1502 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1503 auto diag = emitOpError();
1504 if (fn(diag))
1505 diag.attachNote(module->getLoc()) << "module declared here";
1506 };
1508 getParameters(), moduleParameters, emitError);
1509}
1510
1511ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1512 StringAttr instanceNameAttr;
1513 InnerSymAttr innerSym;
1514 FlatSymbolRefAttr moduleNameAttr;
1515 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1516 SmallVector<Type, 1> inputsTypes, allResultTypes;
1517 ArrayAttr argNames, resultNames, parameters;
1518 auto noneType = parser.getBuilder().getType<NoneType>();
1519
1520 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1521 result.attributes))
1522 return failure();
1523
1524 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1525 // Parsing an optional symbol name doesn't fail, so no need to check the
1526 // result.
1527 if (parser.parseCustomAttributeWithFallback(innerSym))
1528 return failure();
1529 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1530 }
1531
1532 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1533 if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1534 result.attributes) ||
1535 parser.getCurrentLocation(&parametersLoc) ||
1536 parseOptionalParameterList(parser, parameters) ||
1537 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1538 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1539 result.operands) ||
1540 parser.parseArrow() ||
1541 parseOutputPortList(parser, allResultTypes, resultNames) ||
1542 parser.parseOptionalAttrDict(result.attributes)) {
1543 return failure();
1544 }
1545
1546 result.addAttribute("argNames", argNames);
1547 result.addAttribute("resultNames", resultNames);
1548 result.addAttribute("parameters", parameters);
1549 result.addTypes(allResultTypes);
1550 return success();
1551}
1552
1553void InstanceOp::print(OpAsmPrinter &p) {
1554 p << ' ';
1555 p.printAttributeWithoutType(getInstanceNameAttr());
1556 if (auto attr = getInnerSymAttr()) {
1557 p << " sym ";
1558 attr.print(p);
1559 }
1560 p << ' ';
1561 p.printAttributeWithoutType(getModuleNameAttr());
1562 printOptionalParameterList(p, *this, getParameters());
1563 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1564 getArgNames());
1565 p << " -> ";
1566 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1567
1568 p.printOptionalAttrDict(
1569 (*this)->getAttrs(),
1570 /*elidedAttrs=*/{"instanceName",
1571 InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1572 "argNames", "resultNames", "parameters"});
1573}
1574
1575//===----------------------------------------------------------------------===//
1576// InstanceChoiceOp
1577//===----------------------------------------------------------------------===//
1578
1579std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1580 // Inner symbols on instance operations target the op not any result.
1581 return std::nullopt;
1582}
1583
1584LogicalResult
1585InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1586 for (Attribute name : getModuleNamesAttr()) {
1588 *this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1589 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1590 return failure();
1591 }
1592 }
1593 return success();
1594}
1595
1596LogicalResult InstanceChoiceOp::verify() {
1597 auto module = (*this)->getParentOfType<HWModuleOp>();
1598 if (!module)
1599 return success();
1600
1601 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1603 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1604 auto diag = emitOpError();
1605 if (fn(diag))
1606 diag.attachNote(module->getLoc()) << "module declared here";
1607 };
1609 getParameters(), moduleParameters, emitError);
1610}
1611
1612ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1613 OperationState &result) {
1614 StringAttr optionNameAttr;
1615 StringAttr instanceNameAttr;
1616 InnerSymAttr innerSym;
1617 SmallVector<Attribute> moduleNames;
1618 SmallVector<Attribute> caseNames;
1619 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1620 SmallVector<Type, 1> inputsTypes, allResultTypes;
1621 ArrayAttr argNames, resultNames, parameters;
1622 auto noneType = parser.getBuilder().getType<NoneType>();
1623
1624 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1625 result.attributes))
1626 return failure();
1627
1628 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1629 // Parsing an optional symbol name doesn't fail, so no need to check the
1630 // result.
1631 if (parser.parseCustomAttributeWithFallback(innerSym))
1632 return failure();
1633 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1634 }
1635
1636 if (parser.parseKeyword("option") ||
1637 parser.parseAttribute(optionNameAttr, noneType, "optionName",
1638 result.attributes))
1639 return failure();
1640
1641 FlatSymbolRefAttr defaultModuleName;
1642 if (parser.parseAttribute(defaultModuleName))
1643 return failure();
1644 moduleNames.push_back(defaultModuleName);
1645
1646 while (succeeded(parser.parseOptionalKeyword("or"))) {
1647 FlatSymbolRefAttr moduleName;
1648 StringAttr targetName;
1649 if (parser.parseAttribute(moduleName) ||
1650 parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1651 return failure();
1652 moduleNames.push_back(moduleName);
1653 caseNames.push_back(targetName);
1654 }
1655
1656 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1657 if (parser.getCurrentLocation(&parametersLoc) ||
1658 parseOptionalParameterList(parser, parameters) ||
1659 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1660 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1661 result.operands) ||
1662 parser.parseArrow() ||
1663 parseOutputPortList(parser, allResultTypes, resultNames) ||
1664 parser.parseOptionalAttrDict(result.attributes)) {
1665 return failure();
1666 }
1667
1668 result.addAttribute("moduleNames",
1669 ArrayAttr::get(parser.getContext(), moduleNames));
1670 result.addAttribute("caseNames",
1671 ArrayAttr::get(parser.getContext(), caseNames));
1672 result.addAttribute("argNames", argNames);
1673 result.addAttribute("resultNames", resultNames);
1674 result.addAttribute("parameters", parameters);
1675 result.addTypes(allResultTypes);
1676 return success();
1677}
1678
1679void InstanceChoiceOp::print(OpAsmPrinter &p) {
1680 p << ' ';
1681 p.printAttributeWithoutType(getInstanceNameAttr());
1682 if (auto attr = getInnerSymAttr()) {
1683 p << " sym ";
1684 attr.print(p);
1685 }
1686 p << " option " << getOptionNameAttr() << ' ';
1687
1688 auto moduleNames = getModuleNamesAttr();
1689 auto caseNames = getCaseNamesAttr();
1690 assert(moduleNames.size() == caseNames.size() + 1);
1691
1692 p.printAttributeWithoutType(moduleNames[0]);
1693 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1694 p << " or ";
1695 p.printAttributeWithoutType(moduleNames[i + 1]);
1696 p << " if ";
1697 p.printAttributeWithoutType(caseNames[i]);
1698 }
1699
1700 printOptionalParameterList(p, *this, getParameters());
1701 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1702 getArgNames());
1703 p << " -> ";
1704 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1705
1706 p.printOptionalAttrDict(
1707 (*this)->getAttrs(),
1708 /*elidedAttrs=*/{"instanceName",
1709 InnerSymbolTable::getInnerSymbolAttrName(),
1710 "moduleNames", "caseNames", "argNames", "resultNames",
1711 "parameters", "optionName"});
1712}
1713
1714ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1715 SmallVector<Attribute> moduleNames;
1716 for (Attribute attr : getModuleNamesAttr()) {
1717 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1718 }
1719 return ArrayAttr::get(getContext(), moduleNames);
1720}
1721
1722//===----------------------------------------------------------------------===//
1723// HWOutputOp
1724//===----------------------------------------------------------------------===//
1725
1726/// Verify that the num of operands and types fit the declared results.
1727LogicalResult OutputOp::verify() {
1728 // Check that the we (hw.output) have the same number of operands as our
1729 // region has results.
1730 ModuleType modType;
1731 if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1732 modType = mod.getHWModuleType();
1733 else {
1734 emitOpError("must have a module parent");
1735 return failure();
1736 }
1737 auto modResults = modType.getOutputTypes();
1738 OperandRange outputValues = getOperands();
1739 if (modResults.size() != outputValues.size()) {
1740 emitOpError("must have same number of operands as region results.");
1741 return failure();
1742 }
1743
1744 // Check that the types of our operands and the region's results match.
1745 for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1746 if (modResults[i] != outputValues[i].getType()) {
1747 emitOpError("output types must match module. In "
1748 "operand ")
1749 << i << ", expected " << modResults[i] << ", but got "
1750 << outputValues[i].getType() << ".";
1751 return failure();
1752 }
1753 }
1754
1755 return success();
1756}
1757
1758//===----------------------------------------------------------------------===//
1759// Other Operations
1760//===----------------------------------------------------------------------===//
1761
1762static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1763 Type &idxType) {
1764 Type type;
1765 if (p.parseType(type))
1766 return p.emitError(p.getCurrentLocation(), "Expected type");
1767 auto arrType = type_dyn_cast<ArrayType>(type);
1768 if (!arrType)
1769 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1770 srcType = type;
1771 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1772 idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1773 return success();
1774}
1775
1776static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1777 Type idxType) {
1778 p.printType(srcType);
1779}
1780
1781ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1782 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1783 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1784 Type elemType;
1785
1786 if (parser.parseOperandList(operands) ||
1787 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1788 parser.parseType(elemType))
1789 return failure();
1790
1791 if (operands.size() == 0)
1792 return parser.emitError(inputOperandsLoc,
1793 "Cannot construct an array of length 0");
1794 result.addTypes({ArrayType::get(elemType, operands.size())});
1795
1796 for (auto operand : operands)
1797 if (parser.resolveOperand(operand, elemType, result.operands))
1798 return failure();
1799 return success();
1800}
1801
1802void ArrayCreateOp::print(OpAsmPrinter &p) {
1803 p << " ";
1804 p.printOperands(getInputs());
1805 p.printOptionalAttrDict((*this)->getAttrs());
1806 p << " : " << getInputs()[0].getType();
1807}
1808
1809void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1810 ValueRange values) {
1811 assert(values.size() > 0 && "Cannot build array of zero elements");
1812 Type elemType = values[0].getType();
1813 assert(llvm::all_of(
1814 values,
1815 [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1816 "All values must have same type.");
1817 build(b, state, ArrayType::get(elemType, values.size()), values);
1818}
1819
1820LogicalResult ArrayCreateOp::verify() {
1821 unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1822 if (getInputs().size() != returnSize)
1823 return failure();
1824 return success();
1825}
1826
1827OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1828 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1829 return {};
1830 return ArrayAttr::get(getContext(), adaptor.getInputs());
1831}
1832
1833// Check whether an integer value is an offset from a base.
1834bool hw::isOffset(Value base, Value index, uint64_t offset) {
1835 if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1836 if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1837 // If both values are a constant, check if index == base + offset.
1838 // To account for overflow, the addition is performed with an extra bit
1839 // and the offset is asserted to fit in the bit width of the base.
1840 auto baseValue = constBase.getValue();
1841 auto indexValue = constIndex.getValue();
1842
1843 unsigned bits = baseValue.getBitWidth();
1844 assert(bits == indexValue.getBitWidth() && "mismatched widths");
1845
1846 if (bits < 64 && offset >= (1ull << bits))
1847 return false;
1848
1849 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1850 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1851 return baseExt + offset == indexExt;
1852 }
1853 }
1854 return false;
1855}
1856
1857// Canonicalize a create of consecutive elements to a slice.
1858static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1859 PatternRewriter &rewriter) {
1860 // Do not canonicalize create of get into a slice.
1861 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1862 if (arrayTy.getNumElements() <= 1)
1863 return failure();
1864 auto elemTy = arrayTy.getElementType();
1865
1866 // Check if create arguments are consecutive elements of the same array.
1867 // Attempt to break a create of gets into a sequence of consecutive intervals.
1868 struct Chunk {
1869 Value input;
1870 Value index;
1871 size_t size;
1872 };
1873 SmallVector<Chunk> chunks;
1874 for (Value value : llvm::reverse(op.getInputs())) {
1875 auto get = value.getDefiningOp<ArrayGetOp>();
1876 if (!get)
1877 return failure();
1878
1879 Value input = get.getInput();
1880 Value index = get.getIndex();
1881 if (!chunks.empty()) {
1882 auto &c = *chunks.rbegin();
1883 if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1884 c.size++;
1885 continue;
1886 }
1887 }
1888
1889 chunks.push_back(Chunk{input, index, 1});
1890 }
1891
1892 // If there is a single slice, eliminate the create.
1893 if (chunks.size() == 1) {
1894 auto &chunk = chunks[0];
1895 rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1896 op.getLoc(), arrayTy, chunk.input, chunk.index));
1897 return success();
1898 }
1899
1900 // If the number of chunks is significantly less than the number of
1901 // elements, replace the create with a concat of the identified slices.
1902 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1903 SmallVector<Value> slices;
1904 for (auto &chunk : llvm::reverse(chunks)) {
1905 auto sliceTy = ArrayType::get(elemTy, chunk.size);
1906 slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1907 op.getLoc(), sliceTy, chunk.input, chunk.index));
1908 }
1909 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1910 return success();
1911 }
1912
1913 return failure();
1914}
1915
1916LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1917 PatternRewriter &rewriter) {
1918 if (succeeded(foldCreateToSlice(op, rewriter)))
1919 return success();
1920 return failure();
1921}
1922
1923Value ArrayCreateOp::getUniformElement() {
1924 if (!getInputs().empty() && llvm::all_equal(getInputs()))
1925 return getInputs()[0];
1926 return {};
1927}
1928
1929static std::optional<uint64_t> getUIntFromValue(Value value) {
1930 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1931 if (!idxOp)
1932 return std::nullopt;
1933 APInt idxAttr = idxOp.getValue();
1934 if (idxAttr.getBitWidth() > 64)
1935 return std::nullopt;
1936 return idxAttr.getLimitedValue();
1937}
1938
1939LogicalResult ArraySliceOp::verify() {
1940 unsigned inputSize =
1941 type_cast<ArrayType>(getInput().getType()).getNumElements();
1942 if (llvm::Log2_64_Ceil(inputSize) !=
1943 getLowIndex().getType().getIntOrFloatBitWidth())
1944 return emitOpError(
1945 "ArraySlice: index width must match clog2 of array size");
1946 return success();
1947}
1948
1949OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1950 // If we are slicing the entire input, then return it.
1951 if (getType() == getInput().getType())
1952 return getInput();
1953 return {};
1954}
1955
1956LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1957 PatternRewriter &rewriter) {
1958 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1959 auto elemTy = sliceTy.getElementType();
1960 uint64_t sliceSize = sliceTy.getNumElements();
1961 if (sliceSize == 0)
1962 return failure();
1963
1964 if (sliceSize == 1) {
1965 // slice(a, n) -> create(a[n])
1966 auto get = rewriter.create<ArrayGetOp>(op.getLoc(), op.getInput(),
1967 op.getLowIndex());
1968 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1969 get.getResult());
1970 return success();
1971 }
1972
1973 auto offsetOpt = getUIntFromValue(op.getLowIndex());
1974 if (!offsetOpt)
1975 return failure();
1976
1977 auto inputOp = op.getInput().getDefiningOp();
1978 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1979 // slice(slice(a, n), m) -> slice(a, n + m)
1980 if (inputSlice == op)
1981 return failure();
1982
1983 auto inputIndex = inputSlice.getLowIndex();
1984 auto inputOffsetOpt = getUIntFromValue(inputIndex);
1985 if (!inputOffsetOpt)
1986 return failure();
1987
1988 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1989 auto lowIndex =
1990 rewriter.create<ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1991 rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1992 inputSlice.getInput(), lowIndex);
1993 return success();
1994 }
1995
1996 if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1997 // slice(create(a0, a1, ..., an), m) -> create(am, ...)
1998 auto inputs = inputCreate.getInputs();
1999
2000 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
2001 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
2002 inputs.slice(begin, sliceSize));
2003 return success();
2004 }
2005
2006 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2007 // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2008 SmallVector<Value> chunks;
2009 uint64_t sliceStart = *offsetOpt;
2010 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2011 // Check whether the input intersects with the slice.
2012 uint64_t inputSize =
2013 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2014 if (inputSize == 0 || inputSize <= sliceStart) {
2015 sliceStart -= inputSize;
2016 continue;
2017 }
2018
2019 // Find the indices to slice from this input by intersection.
2020 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2021 uint64_t cutSize = cutEnd - sliceStart;
2022 assert(cutSize != 0 && "slice cannot be empty");
2023
2024 if (cutSize == inputSize) {
2025 // The whole input fits in the slice, add it.
2026 assert(sliceStart == 0 && "invalid cut size");
2027 chunks.push_back(input);
2028 } else {
2029 // Slice the required bits from the input.
2030 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2031 auto lowIndex = rewriter.create<ConstantOp>(
2032 op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2033 chunks.push_back(rewriter.create<ArraySliceOp>(
2034 op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex));
2035 }
2036
2037 sliceStart = 0;
2038 sliceSize -= cutSize;
2039 if (sliceSize == 0)
2040 break;
2041 }
2042
2043 assert(chunks.size() > 0 && "missing sliced items");
2044 if (chunks.size() == 1)
2045 rewriter.replaceOp(op, chunks[0]);
2046 else
2047 rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2048 op, llvm::to_vector(llvm::reverse(chunks)));
2049 return success();
2050 }
2051 return failure();
2052}
2053
2054//===----------------------------------------------------------------------===//
2055// ArrayConcatOp
2056//===----------------------------------------------------------------------===//
2057
2058static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2059 SmallVectorImpl<Type> &inputTypes,
2060 Type &resultType) {
2061 Type elemType;
2062 uint64_t resultSize = 0;
2063
2064 auto parseElement = [&]() -> ParseResult {
2065 Type ty;
2066 if (p.parseType(ty))
2067 return failure();
2068 auto arrTy = type_dyn_cast<ArrayType>(ty);
2069 if (!arrTy)
2070 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2071 if (elemType && elemType != arrTy.getElementType())
2072 return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2073 << elemType;
2074
2075 elemType = arrTy.getElementType();
2076 inputTypes.push_back(ty);
2077 resultSize += arrTy.getNumElements();
2078 return success();
2079 };
2080
2081 if (p.parseCommaSeparatedList(parseElement))
2082 return failure();
2083
2084 resultType = ArrayType::get(elemType, resultSize);
2085 return success();
2086}
2087
2088static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2089 TypeRange inputTypes, Type resultType) {
2090 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2091}
2092
2093void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2094 ValueRange values) {
2095 assert(!values.empty() && "Cannot build array of zero elements");
2096 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2097 Type elemTy = arrayTy.getElementType();
2098 assert(llvm::all_of(values,
2099 [elemTy](Value v) -> bool {
2100 return isa<ArrayType>(v.getType()) &&
2101 cast<ArrayType>(v.getType()).getElementType() ==
2102 elemTy;
2103 }) &&
2104 "All values must be of ArrayType with the same element type.");
2105
2106 uint64_t resultSize = 0;
2107 for (Value val : values)
2108 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2109 build(b, state, ArrayType::get(elemTy, resultSize), values);
2110}
2111
2112OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2113 auto inputs = adaptor.getInputs();
2114 SmallVector<Attribute> array;
2115 for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2116 if (!inputs[i])
2117 return {};
2118 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2119 }
2120 return ArrayAttr::get(getContext(), array);
2121}
2122
2123// Flatten a concatenation of array creates into a single create.
2124static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2125 for (auto input : op.getInputs())
2126 if (!input.getDefiningOp<ArrayCreateOp>())
2127 return false;
2128
2129 SmallVector<Value> items;
2130 for (auto input : op.getInputs()) {
2131 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2132 for (auto item : create.getInputs())
2133 items.push_back(item);
2134 }
2135
2136 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2137 return true;
2138}
2139
2140// Merge consecutive slice expressions in a concatenation.
2141static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2142 struct Slice {
2143 Value input;
2144 Value index;
2145 size_t size;
2146 Value op;
2147 SmallVector<Location> locs;
2148 };
2149
2150 SmallVector<Value> items;
2151 std::optional<Slice> last;
2152 bool changed = false;
2153
2154 auto concatenate = [&] {
2155 // If there is only one op in the slice, place it to the items list.
2156 if (!last)
2157 return;
2158 if (last->op) {
2159 items.push_back(last->op);
2160 last.reset();
2161 return;
2162 }
2163
2164 // Otherwise, create a new slice of with the given size and place it.
2165 // In this case, the concat op is replaced, using the new argument.
2166 changed = true;
2167 auto loc = FusedLoc::get(op.getContext(), last->locs);
2168 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2169 auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2170 items.push_back(rewriter.createOrFold<ArraySliceOp>(
2171 loc, arrayTy, last->input, last->index));
2172
2173 last.reset();
2174 };
2175
2176 auto append = [&](Value op, Value input, Value index, size_t size) {
2177 // If this slice is an extension of the previous one, extend the size
2178 // saved. In this case, a new slice of is created and the concatenation
2179 // operator is rewritten. Otherwise, flush the last slice.
2180 if (last) {
2181 if (last->input == input && isOffset(last->index, index, last->size)) {
2182 last->size += size;
2183 last->op = {};
2184 last->locs.push_back(op.getLoc());
2185 return;
2186 }
2187 concatenate();
2188 }
2189 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2190 };
2191
2192 for (auto item : llvm::reverse(op.getInputs())) {
2193 if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2194 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2195 append(item, slice.getInput(), slice.getLowIndex(), size);
2196 continue;
2197 }
2198
2199 if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2200 if (create.getInputs().size() == 1) {
2201 if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2202 append(item, get.getInput(), get.getIndex(), 1);
2203 continue;
2204 }
2205 }
2206 }
2207
2208 concatenate();
2209 items.push_back(item);
2210 }
2211 concatenate();
2212
2213 if (!changed)
2214 return false;
2215
2216 if (items.size() == 1) {
2217 rewriter.replaceOp(op, items[0]);
2218 } else {
2219 std::reverse(items.begin(), items.end());
2220 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2221 }
2222 return true;
2223}
2224
2225LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2226 PatternRewriter &rewriter) {
2227 // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2228 if (flattenConcatOp(op, rewriter))
2229 return success();
2230
2231 // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2232 if (mergeConcatSlices(op, rewriter))
2233 return success();
2234
2235 return failure();
2236}
2237
2238//===----------------------------------------------------------------------===//
2239// EnumConstantOp
2240//===----------------------------------------------------------------------===//
2241
2242ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2243 // Parse a Type instead of an EnumType since the type might be a type alias.
2244 // The validity of the canonical type is checked during construction of the
2245 // EnumFieldAttr.
2246 Type type;
2247 StringRef field;
2248
2249 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2250 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2251 return failure();
2252
2253 auto fieldAttr = EnumFieldAttr::get(
2254 loc, StringAttr::get(parser.getContext(), field), type);
2255
2256 if (!fieldAttr)
2257 return failure();
2258
2259 result.addAttribute("field", fieldAttr);
2260 result.addTypes(type);
2261
2262 return success();
2263}
2264
2265void EnumConstantOp::print(OpAsmPrinter &p) {
2266 p << " " << getField().getField().getValue() << " : "
2267 << getField().getType().getValue();
2268}
2269
2270void EnumConstantOp::getAsmResultNames(
2271 function_ref<void(Value, StringRef)> setNameFn) {
2272 setNameFn(getResult(), getField().getField().str());
2273}
2274
2275void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2276 EnumFieldAttr field) {
2277 return build(builder, odsState, field.getType().getValue(), field);
2278}
2279
2280OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2281 assert(adaptor.getOperands().empty() && "constant has no operands");
2282 return getFieldAttr();
2283}
2284
2285LogicalResult EnumConstantOp::verify() {
2286 auto fieldAttr = getFieldAttr();
2287 auto fieldType = fieldAttr.getType().getValue();
2288 // This check ensures that we are using the exact same type, without looking
2289 // through type aliases.
2290 if (fieldType != getType())
2291 emitOpError("return type ")
2292 << getType() << " does not match attribute type " << fieldAttr;
2293 return success();
2294}
2295
2296//===----------------------------------------------------------------------===//
2297// EnumCmpOp
2298//===----------------------------------------------------------------------===//
2299
2300LogicalResult EnumCmpOp::verify() {
2301 // Compare the canonical types.
2302 auto lhsType = type_cast<EnumType>(getLhs().getType());
2303 auto rhsType = type_cast<EnumType>(getRhs().getType());
2304 if (rhsType != lhsType)
2305 emitOpError("types do not match");
2306 return success();
2307}
2308
2309//===----------------------------------------------------------------------===//
2310// StructCreateOp
2311//===----------------------------------------------------------------------===//
2312
2313ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2314 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2315 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2316 Type declOrAliasType;
2317
2318 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2319 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2320 parser.parseColonType(declOrAliasType))
2321 return failure();
2322
2323 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2324 if (!declType)
2325 return parser.emitError(parser.getNameLoc(),
2326 "expected !hw.struct type or alias");
2327
2328 llvm::SmallVector<Type, 4> structInnerTypes;
2329 declType.getInnerTypes(structInnerTypes);
2330 result.addTypes(declOrAliasType);
2331
2332 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2333 result.operands))
2334 return failure();
2335 return success();
2336}
2337
2338void StructCreateOp::print(OpAsmPrinter &printer) {
2339 printer << " (";
2340 printer.printOperands(getInput());
2341 printer << ")";
2342 printer.printOptionalAttrDict((*this)->getAttrs());
2343 printer << " : " << getType();
2344}
2345
2346LogicalResult StructCreateOp::verify() {
2347 auto elements = hw::type_cast<StructType>(getType()).getElements();
2348
2349 if (elements.size() != getInput().size())
2350 return emitOpError("structure field count mismatch");
2351
2352 for (const auto &[field, value] : llvm::zip(elements, getInput()))
2353 if (field.type != value.getType())
2354 return emitOpError("structure field `")
2355 << field.name << "` type does not match";
2356
2357 return success();
2358}
2359
2360OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2361 // struct_create(struct_explode(x)) => x
2362 if (!getInput().empty())
2363 if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2364 explodeOp && getInput() == explodeOp.getResults() &&
2365 getResult().getType() == explodeOp.getInput().getType())
2366 return explodeOp.getInput();
2367
2368 auto inputs = adaptor.getInput();
2369 if (llvm::any_of(inputs, [](Attribute attr) { return !attr; }))
2370 return {};
2371 return ArrayAttr::get(getContext(), inputs);
2372}
2373
2374//===----------------------------------------------------------------------===//
2375// StructExplodeOp
2376//===----------------------------------------------------------------------===//
2377
2378ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2379 OperationState &result) {
2380 OpAsmParser::UnresolvedOperand operand;
2381 Type declType;
2382
2383 if (parser.parseOperand(operand) ||
2384 parser.parseOptionalAttrDict(result.attributes) ||
2385 parser.parseColonType(declType))
2386 return failure();
2387 auto structType = type_dyn_cast<StructType>(declType);
2388 if (!structType)
2389 return parser.emitError(parser.getNameLoc(),
2390 "invalid kind of type specified");
2391
2392 llvm::SmallVector<Type, 4> structInnerTypes;
2393 structType.getInnerTypes(structInnerTypes);
2394 result.addTypes(structInnerTypes);
2395
2396 if (parser.resolveOperand(operand, declType, result.operands))
2397 return failure();
2398 return success();
2399}
2400
2401void StructExplodeOp::print(OpAsmPrinter &printer) {
2402 printer << " ";
2403 printer.printOperand(getInput());
2404 printer.printOptionalAttrDict((*this)->getAttrs());
2405 printer << " : " << getInput().getType();
2406}
2407
2408LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2409 SmallVectorImpl<OpFoldResult> &results) {
2410 auto input = adaptor.getInput();
2411 if (!input)
2412 return failure();
2413 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2414 return success();
2415}
2416
2417LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2418 PatternRewriter &rewriter) {
2419 auto *inputOp = op.getInput().getDefiningOp();
2420 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2421 auto result = failure();
2422 auto opResults = op.getResults();
2423 for (uint32_t index = 0; index < elements.size(); index++) {
2424 if (auto foldResult = foldStructExtract(inputOp, index)) {
2425 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2426 result = success();
2427 }
2428 }
2429 return result;
2430}
2431
2432void StructExplodeOp::getAsmResultNames(
2433 function_ref<void(Value, StringRef)> setNameFn) {
2434 auto structType = type_cast<StructType>(getInput().getType());
2435 for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2436 setNameFn(res, field.name.str());
2437}
2438
2439void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2440 Value input) {
2441 StructType inputType = dyn_cast<StructType>(input.getType());
2442 assert(inputType);
2443 SmallVector<Type, 16> fieldTypes;
2444 for (auto field : inputType.getElements())
2445 fieldTypes.push_back(field.type);
2446 build(odsBuilder, odsState, fieldTypes, input);
2447}
2448
2449//===----------------------------------------------------------------------===//
2450// StructExtractOp
2451//===----------------------------------------------------------------------===//
2452
2453/// Ensure an aggregate op's field index is within the bounds of
2454/// the aggregate type and the accessed field is of 'elementType'.
2455template <typename AggregateOp, typename AggregateType>
2456static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2457 AggregateType aggType,
2458 Type elementType) {
2459 auto index = op.getFieldIndex();
2460 if (index >= aggType.getElements().size())
2461 return op.emitOpError() << "field index " << index
2462 << " exceeds element count of aggregate type";
2463
2465 getCanonicalType(aggType.getElements()[index].type))
2466 return op.emitOpError()
2467 << "type " << aggType.getElements()[index].type
2468 << " of accessed field in aggregate at index " << index
2469 << " does not match expected type " << elementType;
2470
2471 return success();
2472}
2473
2474LogicalResult StructExtractOp::verify() {
2475 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2476 *this, getInput().getType(), getType());
2477}
2478
2479/// Use the same parser for both struct_extract and union_extract since the
2480/// syntax is identical.
2481template <typename AggregateType>
2482static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2483 OpAsmParser::UnresolvedOperand operand;
2484 StringAttr fieldName;
2485 Type declType;
2486
2487 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2488 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2489 parser.parseOptionalAttrDict(result.attributes) ||
2490 parser.parseColonType(declType))
2491 return failure();
2492 auto aggType = type_dyn_cast<AggregateType>(declType);
2493 if (!aggType)
2494 return parser.emitError(parser.getNameLoc(),
2495 "invalid kind of type specified");
2496
2497 auto fieldIndex = aggType.getFieldIndex(fieldName);
2498 if (!fieldIndex) {
2499 parser.emitError(parser.getNameLoc(), "field name '" +
2500 fieldName.getValue() +
2501 "' not found in aggregate type");
2502 return failure();
2503 }
2504
2505 auto indexAttr =
2506 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2507 result.addAttribute("fieldIndex", indexAttr);
2508 Type resultType = aggType.getElements()[*fieldIndex].type;
2509 result.addTypes(resultType);
2510
2511 if (parser.resolveOperand(operand, declType, result.operands))
2512 return failure();
2513 return success();
2514}
2515
2516/// Use the same printer for both struct_extract and union_extract since the
2517/// syntax is identical.
2518template <typename AggType>
2519static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2520 printer << " ";
2521 printer.printOperand(op.getInput());
2522 printer << "[\"" << op.getFieldName() << "\"]";
2523 printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2524 printer << " : " << op.getInput().getType();
2525}
2526
2527ParseResult StructExtractOp::parse(OpAsmParser &parser,
2528 OperationState &result) {
2529 return parseExtractOp<StructType>(parser, result);
2530}
2531
2532void StructExtractOp::print(OpAsmPrinter &printer) {
2533 printExtractOp(printer, *this);
2534}
2535
2536void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2537 Value input, StructType::FieldInfo field) {
2538 auto fieldIndex =
2539 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2540 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2541 build(builder, odsState, field.type, input, *fieldIndex);
2542}
2543
2544void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2545 Value input, StringAttr fieldName) {
2546 auto structType = type_cast<StructType>(input.getType());
2547 auto fieldIndex = structType.getFieldIndex(fieldName);
2548 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2549 auto resultType = structType.getElements()[*fieldIndex].type;
2550 build(builder, odsState, resultType, input, *fieldIndex);
2551}
2552
2553OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2554 if (auto constOperand = adaptor.getInput()) {
2555 // Fold extract from aggregate constant
2556 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2557 return operandAttr.getValue()[getFieldIndex()];
2558 }
2559
2560 if (auto foldResult =
2561 foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2562 return foldResult;
2563 return {};
2564}
2565
2566LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2567 PatternRewriter &rewriter) {
2568 auto inputOp = op.getInput().getDefiningOp();
2569
2570 // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2571 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2572 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2573 rewriter.replaceOpWithNewOp<StructExtractOp>(
2574 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2575 return success();
2576 }
2577 }
2578
2579 return failure();
2580}
2581
2582void StructExtractOp::getAsmResultNames(
2583 function_ref<void(Value, StringRef)> setNameFn) {
2584 setNameFn(getResult(), getFieldName());
2585}
2586
2587//===----------------------------------------------------------------------===//
2588// StructInjectOp
2589//===----------------------------------------------------------------------===//
2590
2591void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2592 Value input, StringAttr fieldName, Value newValue) {
2593 auto structType = type_cast<StructType>(input.getType());
2594 auto fieldIndex = structType.getFieldIndex(fieldName);
2595 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2596 build(builder, odsState, input, *fieldIndex, newValue);
2597}
2598
2599LogicalResult StructInjectOp::verify() {
2600 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2601 *this, getInput().getType(), getNewValue().getType());
2602}
2603
2604ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2605 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2606 OpAsmParser::UnresolvedOperand operand, val;
2607 StringAttr fieldName;
2608 Type declType;
2609
2610 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2611 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2612 parser.parseComma() || parser.parseOperand(val) ||
2613 parser.parseOptionalAttrDict(result.attributes) ||
2614 parser.parseColonType(declType))
2615 return failure();
2616 auto structType = type_dyn_cast<StructType>(declType);
2617 if (!structType)
2618 return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2619
2620 auto fieldIndex = structType.getFieldIndex(fieldName);
2621 if (!fieldIndex) {
2622 parser.emitError(parser.getNameLoc(), "field name '" +
2623 fieldName.getValue() +
2624 "' not found in aggregate type");
2625 return failure();
2626 }
2627
2628 auto indexAttr =
2629 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2630 result.addAttribute("fieldIndex", indexAttr);
2631 result.addTypes(declType);
2632
2633 Type resultType = structType.getElements()[*fieldIndex].type;
2634 if (parser.resolveOperands({operand, val}, {declType, resultType},
2635 inputOperandsLoc, result.operands))
2636 return failure();
2637 return success();
2638}
2639
2640void StructInjectOp::print(OpAsmPrinter &printer) {
2641 printer << " ";
2642 printer.printOperand(getInput());
2643 printer << "[\"" << getFieldName() << "\"], ";
2644 printer.printOperand(getNewValue());
2645 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2646 printer << " : " << getInput().getType();
2647}
2648
2649OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2650 auto input = adaptor.getInput();
2651 auto newValue = adaptor.getNewValue();
2652 if (!input || !newValue)
2653 return {};
2654 SmallVector<Attribute> array;
2655 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2656 array[getFieldIndex()] = newValue;
2657 return ArrayAttr::get(getContext(), array);
2658}
2659
2660LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2661 PatternRewriter &rewriter) {
2662 // Canonicalize multiple injects into a create op and eliminate overwrites.
2663 SmallPtrSet<Operation *, 4> injects;
2664 DenseMap<StringAttr, Value> fields;
2665
2666 // Chase a chain of injects. Bail out if cycles are present.
2667 StructInjectOp inject = op;
2668 Value input;
2669 do {
2670 if (!injects.insert(inject).second)
2671 return failure();
2672
2673 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2674 input = inject.getInput();
2675 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2676 } while (inject);
2677 assert(input && "missing input to inject chain");
2678
2679 auto ty = hw::type_cast<StructType>(op.getType());
2680 auto elements = ty.getElements();
2681
2682 // If the inject chain sets all fields, canonicalize to create.
2683 if (fields.size() == elements.size()) {
2684 SmallVector<Value> createFields;
2685 for (const auto &field : elements) {
2686 auto it = fields.find(field.name);
2687 assert(it != fields.end() && "missing field");
2688 createFields.push_back(it->second);
2689 }
2690 rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2691 return success();
2692 }
2693
2694 // Nothing to canonicalize, only the original inject in the chain.
2695 if (injects.size() == fields.size())
2696 return failure();
2697
2698 // Eliminate overwrites. The hash map contains the last write to each field.
2699 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2700 auto it = fields.find(elements[fieldIndex].name);
2701 if (it == fields.end())
2702 continue;
2703 input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2704 it->second);
2705 }
2706
2707 rewriter.replaceOp(op, input);
2708 return success();
2709}
2710
2711//===----------------------------------------------------------------------===//
2712// UnionCreateOp
2713//===----------------------------------------------------------------------===//
2714
2715LogicalResult UnionCreateOp::verify() {
2716 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2717 *this, getType(), getInput().getType());
2718}
2719
2720void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2721 Type unionType, StringAttr fieldName, Value input) {
2722 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2723 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2724 build(builder, odsState, unionType, *fieldIndex, input);
2725}
2726
2727ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2728 Type declOrAliasType;
2729 StringAttr fieldName;
2730 OpAsmParser::UnresolvedOperand input;
2731 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2732
2733 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2734 parser.parseOperand(input) ||
2735 parser.parseOptionalAttrDict(result.attributes) ||
2736 parser.parseColonType(declOrAliasType))
2737 return failure();
2738
2739 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2740 if (!declType)
2741 return parser.emitError(parser.getNameLoc(),
2742 "expected !hw.union type or alias");
2743
2744 auto fieldIndex = declType.getFieldIndex(fieldName);
2745 if (!fieldIndex) {
2746 parser.emitError(fieldLoc, "cannot find union field '")
2747 << fieldName.getValue() << '\'';
2748 return failure();
2749 }
2750
2751 auto indexAttr =
2752 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2753 result.addAttribute("fieldIndex", indexAttr);
2754 Type inputType = declType.getElements()[*fieldIndex].type;
2755
2756 if (parser.resolveOperand(input, inputType, result.operands))
2757 return failure();
2758 result.addTypes({declOrAliasType});
2759 return success();
2760}
2761
2762void UnionCreateOp::print(OpAsmPrinter &printer) {
2763 printer << " \"" << getFieldName() << "\", ";
2764 printer.printOperand(getInput());
2765 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2766 printer << " : " << getType();
2767}
2768
2769//===----------------------------------------------------------------------===//
2770// UnionExtractOp
2771//===----------------------------------------------------------------------===//
2772
2773ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2774 return parseExtractOp<UnionType>(parser, result);
2775}
2776
2777void UnionExtractOp::print(OpAsmPrinter &printer) {
2778 printExtractOp(printer, *this);
2779}
2780
2781LogicalResult UnionExtractOp::inferReturnTypes(
2782 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2783 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2784 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2785 Adaptor adaptor(operands, attrs, properties, regions);
2786 auto unionElements =
2787 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2788 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2789 if (fieldIndex >= unionElements.size()) {
2790 if (loc)
2791 mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2792 " exceeds element count of aggregate type");
2793 return failure();
2794 }
2795 results.push_back(unionElements[fieldIndex].type);
2796 return success();
2797}
2798
2799void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2800 Value input, StringAttr fieldName) {
2801 auto unionType = type_cast<UnionType>(input.getType());
2802 auto fieldIndex = unionType.getFieldIndex(fieldName);
2803 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2804 auto resultType = unionType.getElements()[*fieldIndex].type;
2805 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2806}
2807
2808//===----------------------------------------------------------------------===//
2809// ArrayGetOp
2810//===----------------------------------------------------------------------===//
2811
2812// An array_get of an array_create with a constant index can just be the
2813// array_create operand at the constant index. If the array_create has a
2814// single uniform value for each element, just return that value regardless of
2815// the index. If the array is constructed from a constant by a bitcast
2816// operation, we can fold into a constant.
2817OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2818 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2819 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2820
2821 if (inputCst) {
2822 // Constant array index.
2823 if (indexCst) {
2824 auto indexVal = indexCst.getValue();
2825 if (indexVal.getBitWidth() < 64) {
2826 auto index = indexVal.getZExtValue();
2827 return inputCst[inputCst.size() - 1 - index];
2828 }
2829 }
2830 // If all elements of the array are the same, we can return any element of
2831 // array.
2832 if (!inputCst.empty() && llvm::all_equal(inputCst))
2833 return inputCst[0];
2834 }
2835
2836 // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2837 if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2838 auto intTy = dyn_cast<IntegerType>(getType());
2839 if (!intTy)
2840 return {};
2841 auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2842 if (!bitcastInputOp)
2843 return {};
2844 if (!indexCst)
2845 return {};
2846 auto bitcastInputCst = bitcastInputOp.getValue();
2847 // Calculate the index. Make sure to zero-extend the index value before
2848 // multiplying the element width.
2849 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2850 getType().getIntOrFloatBitWidth();
2851 // Extract [startIdx + width - 1: startIdx].
2852 return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2853 intTy.getIntOrFloatBitWidth()));
2854 }
2855
2856 auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2857 if (!inputCreate)
2858 return {};
2859
2860 if (auto uniformValue = inputCreate.getUniformElement())
2861 return uniformValue;
2862
2863 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2864 return {};
2865
2866 uint64_t index = indexCst.getValue().getLimitedValue();
2867 auto createInputs = inputCreate.getInputs();
2868 if (index >= createInputs.size())
2869 return {};
2870 return createInputs[createInputs.size() - index - 1];
2871}
2872
2873LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2874 PatternRewriter &rewriter) {
2875 auto idxOpt = getUIntFromValue(op.getIndex());
2876 if (!idxOpt)
2877 return failure();
2878
2879 auto *inputOp = op.getInput().getDefiningOp();
2880 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2881 // get(slice(a, n), m) -> get(a, n + m)
2882 auto offsetOp = inputSlice.getLowIndex();
2883 auto offsetOpt = getUIntFromValue(offsetOp);
2884 if (!offsetOpt)
2885 return failure();
2886
2887 uint64_t offset = *offsetOpt + *idxOpt;
2888 auto newOffset =
2889 rewriter.create<ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2890 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2891 newOffset);
2892 return success();
2893 }
2894
2895 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2896 // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2897 uint64_t elemIndex = *idxOpt;
2898 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2899 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2900 if (elemIndex >= size) {
2901 elemIndex -= size;
2902 continue;
2903 }
2904
2905 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2906 auto newIdxOp = rewriter.create<ConstantOp>(
2907 op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2908
2909 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2910 return success();
2911 }
2912 return failure();
2913 }
2914
2915 // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2916 // array_get sel, (array_create (array_get const a), (array_get const b),
2917 // (array_get const, c), (array_get const, d))
2918 if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2919 if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2920 if (auto create =
2921 innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2922
2923 SmallVector<Value> newValues;
2924 for (auto operand : create.getOperands())
2925 newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2926 op.getLoc(), operand, op.getIndex()));
2927
2928 rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2929 op,
2930 rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2931 innerGet.getIndex());
2932 return success();
2933 }
2934 }
2935 }
2936
2937 return failure();
2938}
2939
2940//===----------------------------------------------------------------------===//
2941// TypedeclOp
2942//===----------------------------------------------------------------------===//
2943
2944StringRef TypedeclOp::getPreferredName() {
2945 return getVerilogName().value_or(getName());
2946}
2947
2948Type TypedeclOp::getAliasType() {
2949 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2950 return hw::TypeAliasType::get(
2951 SymbolRefAttr::get(parentScope.getSymNameAttr(),
2952 {FlatSymbolRefAttr::get(*this)}),
2953 getType());
2954}
2955
2956//===----------------------------------------------------------------------===//
2957// BitcastOp
2958//===----------------------------------------------------------------------===//
2959
2960OpFoldResult BitcastOp::fold(FoldAdaptor) {
2961 // Identity.
2962 // bitcast(%a) : A -> A ==> %a
2963 if (getOperand().getType() == getType())
2964 return getOperand();
2965
2966 return {};
2967}
2968
2969LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2970 // Composition.
2971 // %b = bitcast(%a) : A -> B
2972 // bitcast(%b) : B -> C
2973 // ===> bitcast(%a) : A -> C
2974 auto inputBitcast =
2975 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2976 if (!inputBitcast)
2977 return failure();
2978 auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2979 inputBitcast.getInput());
2980 rewriter.replaceOp(op, bitcast);
2981 return success();
2982}
2983
2984LogicalResult BitcastOp::verify() {
2985 if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2986 return this->emitOpError("Bitwidth of input must match result");
2987 return success();
2988}
2989
2990//===----------------------------------------------------------------------===//
2991// HierPathOp helpers.
2992//===----------------------------------------------------------------------===//
2993
2994bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2995 SmallVector<Attribute, 4> newPath;
2996 bool updateMade = false;
2997 for (auto nameRef : getNamepath()) {
2998 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
2999 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3000 if (ref.getModule() == moduleToDrop)
3001 updateMade = true;
3002 else
3003 newPath.push_back(ref);
3004 } else {
3005 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3006 updateMade = true;
3007 else
3008 newPath.push_back(nameRef);
3009 }
3010 }
3011 if (updateMade)
3012 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3013 return updateMade;
3014}
3015
3016bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3017 SmallVector<Attribute, 4> newPath;
3018 bool updateMade = false;
3019 StringRef inlinedInstanceName = "";
3020 for (auto nameRef : getNamepath()) {
3021 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3022 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3023 if (ref.getModule() == moduleToDrop) {
3024 inlinedInstanceName = ref.getName().getValue();
3025 updateMade = true;
3026 } else if (!inlinedInstanceName.empty()) {
3027 newPath.push_back(hw::InnerRefAttr::get(
3028 ref.getModule(),
3029 StringAttr::get(getContext(), inlinedInstanceName + "_" +
3030 ref.getName().getValue())));
3031 inlinedInstanceName = "";
3032 } else
3033 newPath.push_back(ref);
3034 } else {
3035 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3036 updateMade = true;
3037 else
3038 newPath.push_back(nameRef);
3039 }
3040 }
3041 if (updateMade)
3042 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3043 return updateMade;
3044}
3045
3046bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3047 SmallVector<Attribute, 4> newPath;
3048 bool updateMade = false;
3049 for (auto nameRef : getNamepath()) {
3050 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3051 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3052 if (ref.getModule() == oldMod) {
3053 newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3054 updateMade = true;
3055 } else
3056 newPath.push_back(ref);
3057 } else {
3058 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3059 newPath.push_back(FlatSymbolRefAttr::get(newMod));
3060 updateMade = true;
3061 } else
3062 newPath.push_back(nameRef);
3063 }
3064 }
3065 if (updateMade)
3066 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3067 return updateMade;
3068}
3069
3070bool HierPathOp::updateModuleAndInnerRef(
3071 StringAttr oldMod, StringAttr newMod,
3072 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3073 auto fromRef = FlatSymbolRefAttr::get(oldMod);
3074 if (oldMod == newMod)
3075 return false;
3076
3077 auto namepathNew = getNamepath().getValue().vec();
3078 bool updateMade = false;
3079 // Break from the loop if the module is found, since it can occur only once.
3080 for (auto &element : namepathNew) {
3081 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3082 if (innerRef.getModule() != oldMod)
3083 continue;
3084 auto symName = innerRef.getName();
3085 // Since the module got updated, the old innerRef symbol inside oldMod
3086 // should also be updated to the new symbol inside the newMod.
3087 auto to = innerSymRenameMap.find(symName);
3088 if (to != innerSymRenameMap.end())
3089 symName = to->second;
3090 updateMade = true;
3091 element = hw::InnerRefAttr::get(newMod, symName);
3092 break;
3093 }
3094 if (element != fromRef)
3095 continue;
3096
3097 updateMade = true;
3098 element = FlatSymbolRefAttr::get(newMod);
3099 break;
3100 }
3101 if (updateMade)
3102 setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3103 return updateMade;
3104}
3105
3106bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3107 SmallVector<Attribute, 4> newPath;
3108 bool updateMade = false;
3109 for (auto nameRef : getNamepath()) {
3110 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3111 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3112 if (ref.getModule() == atMod) {
3113 updateMade = true;
3114 if (includeMod)
3115 newPath.push_back(ref);
3116 } else
3117 newPath.push_back(ref);
3118 } else {
3119 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3120 updateMade = true;
3121 else
3122 newPath.push_back(nameRef);
3123 }
3124 if (updateMade)
3125 break;
3126 }
3127 if (updateMade)
3128 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3129 return updateMade;
3130}
3131
3132/// Return just the module part of the namepath at a specific index.
3133StringAttr HierPathOp::modPart(unsigned i) {
3134 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3135 .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3136 .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3137}
3138
3139/// Return the root module.
3140StringAttr HierPathOp::root() {
3141 assert(!getNamepath().empty());
3142 return modPart(0);
3143}
3144
3145/// Return true if the NLA has the module in its path.
3146bool HierPathOp::hasModule(StringAttr modName) {
3147 for (auto nameRef : getNamepath()) {
3148 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3149 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3150 if (ref.getModule() == modName)
3151 return true;
3152 } else {
3153 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3154 return true;
3155 }
3156 }
3157 return false;
3158}
3159
3160/// Return true if the NLA has the InnerSym .
3161bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3162 for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3163 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3164 if (ref.getName() == symName && ref.getModule() == modName)
3165 return true;
3166
3167 return false;
3168}
3169
3170/// Return just the reference part of the namepath at a specific index. This
3171/// will return an empty attribute if this is the leaf and the leaf is a module.
3172StringAttr HierPathOp::refPart(unsigned i) {
3173 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3174 .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3175 .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3176}
3177
3178/// Return the leaf reference. This returns an empty attribute if the leaf
3179/// reference is a module.
3180StringAttr HierPathOp::ref() {
3181 assert(!getNamepath().empty());
3182 return refPart(getNamepath().size() - 1);
3183}
3184
3185/// Return the leaf module.
3186StringAttr HierPathOp::leafMod() {
3187 assert(!getNamepath().empty());
3188 return modPart(getNamepath().size() - 1);
3189}
3190
3191/// Returns true if this NLA targets an instance of a module (as opposed to
3192/// an instance's port or something inside an instance).
3193bool HierPathOp::isModule() { return !ref(); }
3194
3195/// Returns true if this NLA targets something inside a module (as opposed
3196/// to a module or an instance of a module);
3197bool HierPathOp::isComponent() { return (bool)ref(); }
3198
3199// Verify the HierPathOp.
3200// 1. Iterate over the namepath.
3201// 2. The namepath should be a valid instance path, specified either on a
3202// module or a declaration inside a module.
3203// 3. Each element in the namepath is an InnerRefAttr except possibly the
3204// last element.
3205// 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3206// and the corresponding inner_sym on the instance.
3207// 5. Make sure that the instance path is legal, by verifying the sequence of
3208// instance and the expected module occurs as the next element in the path.
3209// 6. The last element of the namepath, can be an InnerRefAttr on either a
3210// module port or a declaration inside the module.
3211// 7. The last element of the namepath can also be a module symbol.
3212LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3213 ArrayAttr expectedModuleNames = {};
3214 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3215 if (!expectedModuleNames)
3216 return success();
3217 if (llvm::any_of(expectedModuleNames,
3218 [name](Attribute attr) { return attr == name; }))
3219 return success();
3220 auto diag = emitOpError() << "instance path is incorrect. Expected ";
3221 size_t n = expectedModuleNames.size();
3222 if (n != 1) {
3223 diag << "one of ";
3224 }
3225 for (size_t i = 0; i < n; ++i) {
3226 if (i != 0)
3227 diag << ((i + 1 == n) ? " or " : ", ");
3228 diag << cast<StringAttr>(expectedModuleNames[i]);
3229 }
3230 diag << ". Instead found: " << name;
3231 return diag;
3232 };
3233
3234 if (!getNamepath() || getNamepath().empty())
3235 return emitOpError() << "the instance path cannot be empty";
3236 for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3237 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3238 if (!innerRef)
3239 return emitOpError()
3240 << "the instance path can only contain inner sym reference"
3241 << ", only the leaf can refer to a module symbol";
3242
3243 if (failed(checkExpectedModule(innerRef.getModule())))
3244 return failure();
3245
3246 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3247 if (!instOp)
3248 return emitOpError() << " module: " << innerRef.getModule()
3249 << " does not contain any instance with symbol: "
3250 << innerRef.getName();
3251 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3252 }
3253
3254 // The instance path has been verified. Now verify the last element.
3255 auto leafRef = getNamepath()[getNamepath().size() - 1];
3256 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3257 if (!ns.lookup(innerRef)) {
3258 return emitOpError() << " operation with symbol: " << innerRef
3259 << " was not found ";
3260 }
3261 if (failed(checkExpectedModule(innerRef.getModule())))
3262 return failure();
3263 } else if (failed(checkExpectedModule(
3264 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3265 return failure();
3266 }
3267 return success();
3268}
3269
3270void HierPathOp::print(OpAsmPrinter &p) {
3271 p << " ";
3272
3273 // Print visibility if present.
3274 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3275 if (auto visibility =
3276 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3277 p << visibility.getValue() << ' ';
3278
3279 p.printSymbolName(getSymName());
3280 p << " [";
3281 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3282 if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3283 p.printSymbolName(ref.getModule().getValue());
3284 p << "::";
3285 p.printSymbolName(ref.getName().getValue());
3286 } else {
3287 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3288 }
3289 });
3290 p << "]";
3291 p.printOptionalAttrDict(
3292 (*this)->getAttrs(),
3293 {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3294}
3295
3296ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3297 // Parse the visibility attribute.
3298 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3299
3300 // Parse the symbol name.
3301 StringAttr symName;
3302 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3303 result.attributes))
3304 return failure();
3305
3306 // Parse the namepath.
3307 SmallVector<Attribute> namepath;
3308 if (parser.parseCommaSeparatedList(
3309 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3310 auto loc = parser.getCurrentLocation();
3311 SymbolRefAttr ref;
3312 if (parser.parseAttribute(ref))
3313 return failure();
3314
3315 // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3316 auto pathLength = ref.getNestedReferences().size();
3317 if (pathLength == 0)
3318 namepath.push_back(
3319 FlatSymbolRefAttr::get(ref.getRootReference()));
3320 else if (pathLength == 1)
3321 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3322 ref.getLeafReference()));
3323 else
3324 return parser.emitError(loc,
3325 "only one nested reference is allowed");
3326 return success();
3327 }))
3328 return failure();
3329 result.addAttribute("namepath",
3330 ArrayAttr::get(parser.getContext(), namepath));
3331
3332 if (parser.parseOptionalAttrDict(result.attributes))
3333 return failure();
3334
3335 return success();
3336}
3337
3338//===----------------------------------------------------------------------===//
3339// TriggeredOp
3340//===----------------------------------------------------------------------===//
3341
3342void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3343 EventControlAttr event, Value trigger,
3344 ValueRange inputs) {
3345 odsState.addOperands(trigger);
3346 odsState.addOperands(inputs);
3347 odsState.addAttribute(getEventAttrName(odsState.name), event);
3348 auto *r = odsState.addRegion();
3349 Block *b = new Block();
3350 r->push_back(b);
3351
3352 llvm::SmallVector<Location> argLocs;
3353 llvm::transform(inputs, std::back_inserter(argLocs),
3354 [&](Value v) { return v.getLoc(); });
3355 b->addArguments(inputs.getTypes(), argLocs);
3356}
3357
3358//===----------------------------------------------------------------------===//
3359// TableGen generated logic.
3360//===----------------------------------------------------------------------===//
3361
3362// Provide the autogenerated implementation guts for the Op classes.
3363#define GET_OP_CLASSES
3364#include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition HWOps.cpp:1078
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition HWOps.cpp:493
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition HWOps.cpp:1023
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2124
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:1858
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1420
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition HWOps.cpp:1015
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2519
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition HWOps.cpp:2088
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition HWOps.cpp:1762
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo > > insertInputs, ArrayRef< std::pair< unsigned, PortInfo > > insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition HWOps.cpp:686
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition HWOps.cpp:893
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo > > insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition HWOps.cpp:587
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2141
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition HWOps.cpp:1198
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2482
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition HWOps.cpp:1268
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition HWOps.cpp:1341
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition HWOps.cpp:485
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition HWOps.cpp:399
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition HWOps.cpp:1929
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition HWOps.cpp:901
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition HWOps.cpp:2456
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1440
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition HWOps.cpp:1776
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition HWOps.cpp:345
Delimiter
Definition HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition HWOps.cpp:2058
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:85
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
Value getInput(unsigned i)
Definition HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition HWOps.h:118
llvm::SmallVector< Value > inputArgs
Definition HWOps.h:117
llvm::StringMap< unsigned > outputIdx
Definition HWOps.h:116
llvm::StringMap< unsigned > inputIdx
Definition HWOps.h:116
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition HWVisitors.h:27
ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any combinational operations that are not handled by the concrete visitor...
Definition HWVisitors.h:56
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any non-expression operations.
Definition HWVisitors.h:49
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1023
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition HWOps.cpp:1834
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition HWOps.h:123
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:521
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition HWOps.cpp:202
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition HWOps.cpp:515
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition HWOps.cpp:539
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:182
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
Operation * lookupOp(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.