CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWOps.cpp
Go to the documentation of this file.
1//===- HWOps.cpp - Implement the HW operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the HW ops.
10//
11//===----------------------------------------------------------------------===//
12
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "llvm/ADT/BitVector.h"
28#include "llvm/ADT/SmallPtrSet.h"
29#include "llvm/ADT/StringSet.h"
30
31using namespace circt;
32using namespace hw;
33using mlir::TypedAttr;
34
/// Flip a port direction.
///
/// Input and Output are swapped; InOut is returned unchanged. Falling out of
/// the switch (an enumerator not handled above) is treated as unreachable.
// NOTE(review): the signature line of this function (doc line 36) is missing
// from this rendering of the file.
37 switch (direction) {
38 case ModulePort::Direction::Input:
39 return ModulePort::Direction::Output;
40 case ModulePort::Direction::Output:
41 return ModulePort::Direction::Input;
42 case ModulePort::Direction::InOut:
43 return ModulePort::Direction::InOut;
44 }
45 llvm_unreachable("unknown PortDirection");
46}
47
48bool hw::isValidIndexBitWidth(Value index, Value array) {
49 hw::ArrayType arrayType =
50 dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51 assert(arrayType && "expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
56}
57
58/// Return true if the specified operation is a combinational logic op.
59bool hw::isCombinational(Operation *op) {
60 struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61 bool visitInvalidTypeOp(Operation *op) { return false; }
62 bool visitUnhandledTypeOp(Operation *op) { return true; }
63 };
64
65 return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66 IsCombClassifier().dispatchTypeOpVisitor(op);
67}
68
69static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70 // A struct extract of a struct create -> corresponding struct create operand.
71 if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
73 }
74
75 // Extracting injected field -> corresponding field
76 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
78 return {};
79 return structInject.getNewValue();
80 }
81 return {};
82}
83
84static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85 ArrayRef<Attribute> attrs) {
86 if (attrs.empty())
87 return ArrayAttr::get(context, {});
88 bool empty = true;
89 for (auto a : attrs)
90 if (a && !cast<DictionaryAttr>(a).empty()) {
91 empty = false;
92 break;
93 }
94 if (empty)
95 return ArrayAttr::get(context, {});
96 return ArrayAttr::get(context, attrs);
97}
98
99/// Get a special name to use when printing the entry block arguments of the
100/// region contained by an operation in this dialect.
101static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102 OpAsmSetValueNameFn setNameFn) {
103 if (region.empty())
104 return;
105 // Assign port names to the bbargs.
106 auto module = cast<HWModuleOp>(region.getParentOp());
107
108 auto *block = &region.front();
109 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
111 // Let mlir deterministically convert names to valid identifiers
112 setNameFn(block->getArgument(i), name);
113 }
114}
115
/// Bracketing style for a delimited list (presumably consumed by the custom
/// parsers/printers in this file).
enum class Delimiter : int {
  /// No surrounding delimiters.
  None,
  /// Enclosed in parentheses: `( ... )`.
  Paren,
  /// Enclosed in angle brackets `< ... >`, or absent entirely.
  OptionalLessGreater
};
121
122/// Check parameter specified by `value` to see if it is valid according to the
123/// module's parameters. If not, emit an error to the diagnostic provided as an
124/// argument to the lambda 'instanceError' and return failure, otherwise return
125/// success.
126///
127/// If `disallowParamRefs` is true, then parameter references are not allowed.
128LogicalResult hw::checkParameterInContext(
129 Attribute value, ArrayAttr moduleParameters,
130 const instance_like_impl::EmitErrorFn &instanceError,
131 bool disallowParamRefs) {
132 // Literals are always ok. Their types are already known to match
133 // expectations.
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
136 return success();
137
138 // Check both subexpressions of an expression.
139 if (auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (auto op : expr.getOperands())
141 if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142 disallowParamRefs)))
143 return failure();
144 return success();
145 }
146
147 // Parameter references need more analysis to make sure they are valid within
148 // this module.
149 if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
151
152 // Don't allow references to parameters from the default values of a
153 // parameter list.
154 if (disallowParamRefs) {
155 instanceError([&](auto &diag) {
156 diag << "parameter " << nameAttr
157 << " cannot be used as a default value for a parameter";
158 return false;
159 });
160 return failure();
161 }
162
163 // Find the corresponding attribute in the module.
164 for (auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
167 continue;
168
169 // If the types match then the reference is ok.
170 if (paramAttr.getType() == parameterRef.getType())
171 return success();
172
173 instanceError([&](auto &diag) {
174 diag << "parameter " << nameAttr << " used with type "
175 << parameterRef.getType() << "; should have type "
176 << paramAttr.getType();
177 return true;
178 });
179 return failure();
180 }
181
182 instanceError([&](auto &diag) {
183 diag << "use of unknown parameter " << nameAttr;
184 return true;
185 });
186 return failure();
187 }
188
189 instanceError([&](auto &diag) {
190 diag << "invalid parameter value " << value;
191 return false;
192 });
193 return failure();
194}
195
/// Check parameter specified by `value` to see if it is valid within the scope
/// of the specified module `module`. If not, emit an error at the location of
/// `usingOp` and return failure, otherwise return success. If `usingOp` is
/// null, then no diagnostic is generated.
///
/// The module's parameter list is read from its `parameters` attribute and the
/// check is delegated to the ArrayAttr-based overload. When the diagnostic
/// callback returns true, a "module declared here" note pointing at `module`
/// is attached to the error.
///
/// If `disallowParamRefs` is true, then parameter references are not allowed.
// NOTE(review): `module` is dereferenced unconditionally below, so it must be
// non-null. If it has no `parameters` ArrayAttr, a null ArrayAttr is forwarded
// to the other overload; make sure that overload handles a null list.
// NOTE(review): doc line 205 (presumably the start of the `emitError`
// declaration, which is used below) is missing from this rendering.
202LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203 Operation *usingOp,
204 bool disallowParamRefs) {
206 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
207 if (usingOp) {
208 auto diag = usingOp->emitOpError();
209 if (fn(diag))
210 diag.attachNote(module->getLoc()) << "module declared here";
211 }
212 };
213
214 return checkParameterInContext(value,
215 module->getAttrOfType<ArrayAttr>("parameters"),
216 emitError, disallowParamRefs);
217}
218
219/// Return true if the specified attribute tree is made up of nodes that are
220/// valid in a parameter expression.
221bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222 return succeeded(checkParameterInContext(attr, module, nullptr, false));
223}
224
226 const ModulePortInfo &info,
227 Region &bodyRegion)
228 : info(info) {
// Record each body block argument and map its input port name to its index,
// then map each output port name to its index among the outputs.
// NOTE(review): doc lines 225 (start of the constructor signature) and 235
// are missing from this rendering. setOutput() indexes `outputOperands`, so
// it is presumably sized on line 235; confirm.
229 inputArgs.resize(info.sizeInputs());
230 for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
// NOTE(review): `i` counts block arguments (inputs only). If `info.at(i)`
// indexes the combined input+output port list, this is only correct when all
// inputs precede the outputs in `info`; an input-only accessor would be safer.
231 inputIdx[info.at(i).name.str()] = i;
232 inputArgs[i] = barg;
233 }
234
236 for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237 outputIdx[outputInfo.name.str()] = i;
238 }
239}
240
241void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242 assert(outputOperands.size() > i && "invalid output index");
243 assert(outputOperands[i] == Value() && "output already set");
244 outputOperands[i] = v;
245}
246
// Return the stored entry-block argument for input index `i`; `i` must be in
// range of `inputArgs`.
// NOTE(review): the signature line (doc line 247) is missing from this
// rendering of the file.
248 assert(inputArgs.size() > i && "invalid input index");
249 return inputArgs[i];
250}
251Value HWModulePortAccessor::getInput(StringRef name) {
252 return getInput(inputIdx.find(name.str())->second);
253}
254void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255 setOutput(outputIdx.find(name.str())->second, v);
256}
257
258//===----------------------------------------------------------------------===//
259// Declarative Canonicalization Patterns
260//===----------------------------------------------------------------------===//
261
262namespace {
263#include "circt/Dialect/HW/HWCanonicalization.cpp.inc"
264} // namespace
265
266//===----------------------------------------------------------------------===//
267// ConstantOp
268//===----------------------------------------------------------------------===//
269
270void ConstantOp::print(OpAsmPrinter &p) {
271 p << " ";
272 p.printAttribute(getValueAttr());
273 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
274}
275
276ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
277 IntegerAttr valueAttr;
278
279 if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
280 parser.parseOptionalAttrDict(result.attributes))
281 return failure();
282
283 result.addTypes(valueAttr.getType());
284 return success();
285}
286
287LogicalResult ConstantOp::verify() {
288 // If the result type has a bitwidth, then the attribute must match its width.
289 if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
290 return emitError(
291 "hw.constant attribute bitwidth doesn't match return type");
292
293 return success();
294}
295
296/// Build a ConstantOp from an APInt, infering the result type from the
297/// width of the APInt.
298void ConstantOp::build(OpBuilder &builder, OperationState &result,
299 const APInt &value) {
300
301 auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
302 auto attr = builder.getIntegerAttr(type, value);
303 return build(builder, result, type, attr);
304}
305
306/// Build a ConstantOp from an APInt, infering the result type from the
307/// width of the APInt.
308void ConstantOp::build(OpBuilder &builder, OperationState &result,
309 IntegerAttr value) {
310 return build(builder, result, value.getType(), value);
311}
312
313/// This builder allows construction of small signed integers like 0, 1, -1
314/// matching a specified MLIR IntegerType. This shouldn't be used for general
315/// constant folding because it only works with values that can be expressed in
316/// an int64_t. Use APInt's instead.
317void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
318 int64_t value) {
319 auto numBits = cast<IntegerType>(type).getWidth();
320 build(builder, result,
321 APInt(numBits, (uint64_t)value, /*isSigned=*/true,
322 /*implicitTrunc=*/true));
323}
324
325void ConstantOp::getAsmResultNames(
326 function_ref<void(Value, StringRef)> setNameFn) {
327 auto intTy = getType();
328 auto intCst = getValue();
329
330 // Sugar i1 constants with 'true' and 'false'.
331 if (cast<IntegerType>(intTy).getWidth() == 1)
332 return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
333
334 // Otherwise, build a complex name with the value and type.
335 SmallVector<char, 32> specialNameBuffer;
336 llvm::raw_svector_ostream specialName(specialNameBuffer);
337 specialName << 'c' << intCst << '_' << intTy;
338 setNameFn(getResult(), specialName.str());
339}
340
341OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
342 assert(adaptor.getOperands().empty() && "constant has no operands");
343 return getValueAttr();
344}
345
346//===----------------------------------------------------------------------===//
347// WireOp
348//===----------------------------------------------------------------------===//
349
350/// Check whether an operation has any additional attributes set beyond its
351/// standard list of attributes returned by `getAttributeNames`.
352template <class Op>
353static bool hasAdditionalAttributes(Op op,
354 ArrayRef<StringRef> ignoredAttrs = {}) {
355 auto names = op.getAttributeNames();
356 llvm::SmallDenseSet<StringRef> nameSet;
357 nameSet.reserve(names.size() + ignoredAttrs.size());
358 nameSet.insert(names.begin(), names.end());
359 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
360 return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
361 return !nameSet.contains(namedAttr.getName());
362 });
363}
364
365void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
366 // If the wire has an optional 'name' attribute, use it.
367 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
368 if (nameAttr && !nameAttr.getValue().empty())
369 setNameFn(getResult(), nameAttr.getValue());
370}
371
372std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
373
374OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
375 // If the wire has no additional attributes, no name, and no symbol, just
376 // forward its input.
377 if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
378 !getInnerSymAttr())
379 return getInput();
380 return {};
381}
382
383LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
384 // Block if the wire has any attributes.
385 if (hasAdditionalAttributes(wire, {"sv.namehint"}))
386 return failure();
387
388 // If the wire has a symbol, then we can't delete it.
389 if (wire.getInnerSymAttr())
390 return failure();
391
392 // If the wire is self-referential (its input is itself), we can't remove it.
393 if (wire.getInput() == wire.getResult())
394 return failure();
395
396 // If the wire has a name or an `sv.namehint` attribute, propagate it as an
397 // `sv.namehint` to the expression.
398 if (auto *inputOp = wire.getInput().getDefiningOp())
399 if (auto name = chooseName(wire, inputOp))
400 rewriter.modifyOpInPlace(inputOp,
401 [&] { inputOp->setAttr("sv.namehint", name); });
402
403 rewriter.replaceOp(wire, wire.getInput());
404 return success();
405}
406
407//===----------------------------------------------------------------------===//
408// AggregateConstantOp
409//===----------------------------------------------------------------------===//
410
411static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
412 // If this is a type alias, get the underlying type.
413 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
414 type = typeAlias.getCanonicalType();
415
416 if (auto structType = dyn_cast<StructType>(type)) {
417 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
418 if (!arrayAttr)
419 return op->emitOpError("expected array attribute for constant of type ")
420 << type;
421 if (structType.getElements().size() != arrayAttr.size())
422 return op->emitOpError("array attribute (")
423 << arrayAttr.size() << ") has wrong size for struct constant ("
424 << structType.getElements().size() << ")";
425
426 for (auto [attr, fieldInfo] :
427 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
428 if (failed(checkAttributes(op, attr, fieldInfo.type)))
429 return failure();
430 }
431 } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
432 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
433 if (!arrayAttr)
434 return op->emitOpError("expected array attribute for constant of type ")
435 << type;
436 if (arrayType.getNumElements() != arrayAttr.size())
437 return op->emitOpError("array attribute (")
438 << arrayAttr.size() << ") has wrong size for array constant ("
439 << arrayType.getNumElements() << ")";
440
441 auto elementType = arrayType.getElementType();
442 for (auto attr : arrayAttr.getValue()) {
443 if (failed(checkAttributes(op, attr, elementType)))
444 return failure();
445 }
446 } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
447 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
448 if (!arrayAttr)
449 return op->emitOpError("expected array attribute for constant of type ")
450 << type;
451 auto elementType = arrayType.getElementType();
452 if (arrayType.getNumElements() != arrayAttr.size())
453 return op->emitOpError("array attribute (")
454 << arrayAttr.size()
455 << ") has wrong size for unpacked array constant ("
456 << arrayType.getNumElements() << ")";
457
458 for (auto attr : arrayAttr.getValue()) {
459 if (failed(checkAttributes(op, attr, elementType)))
460 return failure();
461 }
462 } else if (auto enumType = dyn_cast<EnumType>(type)) {
463 auto stringAttr = dyn_cast<StringAttr>(attr);
464 if (!stringAttr)
465 return op->emitOpError("expected string attribute for constant of type ")
466 << type;
467 } else if (auto intType = dyn_cast<IntegerType>(type)) {
468 // Check the attribute kind is correct.
469 auto intAttr = dyn_cast<IntegerAttr>(attr);
470 if (!intAttr)
471 return op->emitOpError("expected integer attribute for constant of type ")
472 << type;
473 // Check the bitwidth is correct.
474 if (intAttr.getValue().getBitWidth() != intType.getWidth())
475 return op->emitOpError("hw.constant attribute bitwidth "
476 "doesn't match return type");
477 } else if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
478 if (typedAttr.getType() != type)
479 return op->emitOpError("typed attr doesn't match the return type ")
480 << type;
481 } else {
482 return op->emitOpError("unknown element type ") << type;
483 }
484 return success();
485}
486
487LogicalResult AggregateConstantOp::verify() {
488 return checkAttributes(*this, getFieldsAttr(), getType());
489}
490
491OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
492
493//===----------------------------------------------------------------------===//
494// ParamValueOp
495//===----------------------------------------------------------------------===//
496
497static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
498 Type &resultType) {
499 if (p.parseType(resultType) || p.parseEqual() ||
500 p.parseAttribute(value, resultType))
501 return failure();
502 return success();
503}
504
505static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
506 Type resultType) {
507 p << resultType << " = ";
508 p.printAttributeWithoutType(value);
509}
510
/// Verify that the parameter value expression is valid in the enclosing
/// hw.module, with this op used as the diagnostic location.
// NOTE(review): doc line 513 (presumably the start of a
// `hw::checkParameterInContext(` call) is missing from this rendering.
// getParentOfType returns null when the op is not nested in an hw.module;
// confirm the callee copes with a null module.
511LogicalResult ParamValueOp::verify() {
512 // Check that the attribute expression is valid in this module.
514 getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
515}
516
517OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
518 assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
519 return getValueAttr();
520}
521
522//===----------------------------------------------------------------------===//
523// HWModuleOp
524//===----------------------------------------------------------------------===//
525
526/// Return true if isAnyModule or instance.
527bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
528 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
529}
530
531/// Return the signature for a module as a function type from the module itself
532/// or from an hw::InstanceOp.
533FunctionType hw::getModuleType(Operation *moduleOrInstance) {
534 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
535 .Case<InstanceOp>([](auto instance) {
536 SmallVector<Type> inputs(instance->getOperandTypes());
537 SmallVector<Type> results(instance->getResultTypes());
538 return FunctionType::get(instance->getContext(), inputs, results);
539 })
540 .Case<HWModuleLike>(
541 [](auto mod) { return mod.getHWModuleType().getFuncType(); })
542 .Default([](Operation *op) {
543 return cast<FunctionType>(
544 cast<mlir::FunctionOpInterface>(op).getFunctionType());
545 });
546}
547
548/// Return the name to use for the Verilog module that we're referencing
549/// here. This is typically the symbol, but can be overridden with the
550/// verilogName attribute.
551StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
552 auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
553 if (nameAttr)
554 return nameAttr;
555
556 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
557}
558
559template <typename ModuleTy>
560static void
561buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
562 const ModulePortInfo &ports, ArrayAttr parameters,
563 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
564 using namespace mlir::function_interface_impl;
565
566 // Add an attribute for the name.
567 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
568
569 SmallVector<Attribute> perPortAttrs;
570 SmallVector<ModulePort> portTypes;
571
572 for (auto elt : ports) {
573 portTypes.push_back(elt);
574 llvm::SmallVector<NamedAttribute> portAttrs;
575 if (elt.attrs)
576 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
577 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
578 }
579
580 // Allow clients to pass in null for the parameters list.
581 if (!parameters)
582 parameters = builder.getArrayAttr({});
583
584 // Record the argument and result types as an attribute.
585 auto type = ModuleType::get(builder.getContext(), portTypes);
586 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
587 TypeAttr::get(type));
588 result.addAttribute("per_port_attrs",
589 arrayOrEmpty(builder.getContext(), perPortAttrs));
590 result.addAttribute("parameters", parameters);
591 if (!comment)
592 comment = builder.getStringAttr("");
593 result.addAttribute("comment", comment);
594 result.addAttributes(attributes);
595 result.addRegion();
596}
597
/// Internal implementation of argument/result insertion and removal on modules.
///
/// Fills `newArgNames`, `newArgTypes`, `newArgAttrs` and `newArgLocs` from the
/// old port lists. Indices in `insertArgs` and `removeArgs` refer to positions
/// in the old list and must be ascending. Ports inserted at the same index keep
/// their order in `insertArgs`; inserting at `oldArgTypes.size()` appends.
///
/// For inserted ports:
///   - an InOut port whose type is not already an InOutType is wrapped in one;
///   - a port with a non-empty symbol gets an `hw.exportPort` attribute,
///     otherwise an empty dictionary;
///   - a port without a location gets UnknownLoc.
/// If `oldArgAttrs` is empty, surviving old ports get an empty dictionary.
///
/// If `body` is non-null, its block arguments are updated in step: each
/// inserted port becomes a new block argument, and the arguments of removed
/// ports are erased at the end.
// NOTE(review): the function's name/first signature line (doc line 599) is
// missing from this rendering of the file.
600 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
601 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
602 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
603 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
604 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
605 SmallVector<Location> &newArgLocs, Block *body = nullptr) {
606
607#ifndef NDEBUG
608 // Check that the `insertArgs` and `removeArgs` indices are in ascending
609 // order.
610 assert(llvm::is_sorted(insertArgs,
611 [](auto &a, auto &b) { return a.first < b.first; }) &&
612 "insertArgs must be in ascending order");
613 assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
614 "removeArgs must be in ascending order");
615#endif
616
// NOTE(review): newArgCount is unsigned. The assert below only catches more
// removals than ports if the wrapped-around value turns negative when cast to
// int.
617 auto oldArgCount = oldArgTypes.size();
618 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
619 assert((int)newArgCount >= 0);
620
621 newArgNames.reserve(newArgCount);
622 newArgTypes.reserve(newArgCount);
623 newArgAttrs.reserve(newArgCount);
624 newArgLocs.reserve(newArgCount);
625
626 auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
627 auto emptyDictAttr = DictionaryAttr::get(context, {});
628 auto unknownLoc = UnknownLoc::get(context);
629
// Positions in `body`'s argument list (after insertions) whose arguments are
// erased once the walk below finishes.
630 BitVector erasedIndices;
631 if (body)
632 erasedIndices.resize(oldArgCount + insertArgs.size());
633
// `argIdx` walks the old port list, taking one extra step so ports can be
// appended at the end. `idx` tracks the matching position in `body`'s argument
// list, which still contains removed arguments until eraseArguments below.
634 for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
635 // Insert new ports at this position.
636 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
637 auto port = insertArgs[0].second;
638 if (port.dir == ModulePort::Direction::InOut &&
639 !isa<InOutType>(port.type))
640 port.type = InOutType::get(port.type);
// Ports with a non-empty symbol carry it in an `hw.exportPort` attribute.
641 auto sym = port.getSym();
642 Attribute attr =
643 (sym && !sym.empty())
644 ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
645 : emptyDictAttr;
646 newArgNames.push_back(port.name);
647 newArgTypes.push_back(port.type);
648 newArgAttrs.push_back(attr);
649 insertArgs = insertArgs.drop_front();
650 LocationAttr loc = port.loc ? port.loc : unknownLoc;
651 newArgLocs.push_back(loc);
652 if (body)
653 body->insertArgument(idx++, port.type, loc);
654 }
655 if (argIdx == oldArgCount)
656 break;
657
658 // Migrate the old port at this position.
659 bool removed = false;
// Repeated removal indices for the same port are all consumed here.
660 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
661 removeArgs = removeArgs.drop_front();
662 removed = true;
663 }
664
665 if (removed) {
666 if (body)
667 erasedIndices.set(idx);
668 } else {
669 newArgNames.push_back(oldArgNames[argIdx]);
670 newArgTypes.push_back(oldArgTypes[argIdx]);
671 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
672 : oldArgAttrs[argIdx]);
673 newArgLocs.push_back(oldArgLocs[argIdx]);
674 }
675 }
676
677 if (body)
678 body->eraseArguments(erasedIndices);
679
680 assert(newArgNames.size() == newArgCount);
681 assert(newArgTypes.size() == newArgCount);
682 assert(newArgAttrs.size() == newArgCount);
683 assert(newArgLocs.size() == newArgCount);
684}
685
686/// Insert and remove ports of a module. The insertion and removal indices must
687/// be in ascending order. The indices refer to the port positions before any
688/// insertion or removal occurs. Ports inserted at the same index will appear in
689/// the module in the same order as they were listed in the `insert*` array.
690///
691/// The operation must be any of the module-like operations.
692///
693/// This is marked deprecated as it's only used from HandshakeToHW and
694/// PortConverter and is likely broken and not currently tested. Users of this
695/// are still written dealing with input and output ports separately, which is
696/// an old and broken style.
697[[deprecated]] static void
698modifyModulePorts(Operation *op,
699 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
700 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
701 ArrayRef<unsigned> removeInputs,
702 ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
703 auto moduleOp = cast<HWModuleLike>(op);
704 auto *context = moduleOp.getContext();
705
706 // Dig up the old argument and result data.
707 auto oldArgNames = moduleOp.getInputNames();
708 auto oldArgTypes = moduleOp.getInputTypes();
709 auto oldArgAttrs = moduleOp.getAllInputAttrs();
710 auto oldArgLocs = moduleOp.getInputLocs();
711
712 auto oldResultNames = moduleOp.getOutputNames();
713 auto oldResultTypes = moduleOp.getOutputTypes();
714 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
715 auto oldResultLocs = moduleOp.getOutputLocs();
716
717 // Modify the ports.
718 SmallVector<Attribute> newArgNames, newResultNames;
719 SmallVector<Type> newArgTypes, newResultTypes;
720 SmallVector<Attribute> newArgAttrs, newResultAttrs;
721 SmallVector<Location> newArgLocs, newResultLocs;
722
723 modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
724 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
725 newArgTypes, newArgAttrs, newArgLocs, body);
726
727 modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
728 oldResultTypes, oldResultAttrs, oldResultLocs,
729 newResultNames, newResultTypes, newResultAttrs,
730 newResultLocs);
731
732 // Update the module operation types and attributes.
733 auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
734 auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
735 moduleOp.setHWModuleType(modty);
736 moduleOp.setAllInputAttrs(newArgAttrs);
737 moduleOp.setAllOutputAttrs(newResultAttrs);
738
739 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
740 moduleOp.setAllPortLocs(newArgLocs);
741}
742
743void HWModuleOp::build(OpBuilder &builder, OperationState &result,
744 StringAttr name, const ModulePortInfo &ports,
745 ArrayAttr parameters,
746 ArrayRef<NamedAttribute> attributes, StringAttr comment,
747 bool shouldEnsureTerminator) {
748 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
749 comment);
750
751 // Create a region and a block for the body.
752 auto *bodyRegion = result.regions[0].get();
753 Block *body = new Block();
754 bodyRegion->push_back(body);
755
756 // Add arguments to the body block.
757 auto unknownLoc = builder.getUnknownLoc();
758 for (auto port : ports.getInputs()) {
759 auto loc = port.loc ? Location(port.loc) : unknownLoc;
760 auto type = port.type;
761 if (port.isInOut() && !isa<InOutType>(type))
762 type = InOutType::get(type);
763 body->addArgument(type, loc);
764 }
765
766 // Add result ports attribute.
767 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
768 SmallVector<Attribute> resultLocs;
769 for (auto port : ports.getOutputs())
770 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
771 result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
772
773 if (shouldEnsureTerminator)
774 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
775}
776
777void HWModuleOp::build(OpBuilder &builder, OperationState &result,
778 StringAttr name, ArrayRef<PortInfo> ports,
779 ArrayAttr parameters,
780 ArrayRef<NamedAttribute> attributes,
781 StringAttr comment) {
782 build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
783 comment);
784}
785
786void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
787 StringAttr name, const ModulePortInfo &ports,
788 HWModuleBuilder modBuilder, ArrayAttr parameters,
789 ArrayRef<NamedAttribute> attributes,
790 StringAttr comment) {
791 build(builder, odsState, name, ports, parameters, attributes, comment,
792 /*shouldEnsureTerminator=*/false);
793 auto *bodyRegion = odsState.regions[0].get();
794 OpBuilder::InsertionGuard guard(builder);
795 auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
796 builder.setInsertionPointToEnd(&bodyRegion->front());
797 modBuilder(builder, accessor);
798 // Create output operands.
799 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
800 hw::OutputOp::create(builder, odsState.location, outputOperands);
801}
802
803void HWModuleOp::modifyPorts(
804 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
805 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
806 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
807 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
808 eraseOutputs);
809}
810
811/// Return the name to use for the Verilog module that we're referencing
812/// here. This is typically the symbol, but can be overridden with the
813/// verilogName attribute.
814StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
815 if (auto vName = getVerilogNameAttr())
816 return vName;
817
818 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
819}
820
821StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
822 if (auto vName = getVerilogNameAttr()) {
823 return vName;
824 }
825 return (*this)->getAttrOfType<StringAttr>(
826 ::mlir::SymbolTable::getSymbolAttrName());
827}
828
829void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
830 StringAttr name, const ModulePortInfo &ports,
831 StringRef verilogName, ArrayAttr parameters,
832 ArrayRef<NamedAttribute> attributes) {
833 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
834 attributes, {});
835
836 // Add the port locations.
837 LocationAttr unknownLoc = builder.getUnknownLoc();
838 SmallVector<Attribute> portLocs;
839 for (auto elt : ports)
840 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
841 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
842
843 if (!verilogName.empty())
844 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
845}
846
847void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
848 StringAttr name, ArrayRef<PortInfo> ports,
849 StringRef verilogName, ArrayAttr parameters,
850 ArrayRef<NamedAttribute> attributes) {
851 build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
852 attributes);
853}
854
855void HWModuleExternOp::modifyPorts(
856 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
857 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
858 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
859 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
860 eraseOutputs);
861}
862
863void HWModuleExternOp::appendOutputs(
864 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
865
866void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
867 FlatSymbolRefAttr genKind, StringAttr name,
868 const ModulePortInfo &ports,
869 StringRef verilogName, ArrayAttr parameters,
870 ArrayRef<NamedAttribute> attributes) {
871 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
872 attributes, {});
873 // Add the port locations.
874 LocationAttr unknownLoc = builder.getUnknownLoc();
875 SmallVector<Attribute> portLocs;
876 for (auto elt : ports)
877 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
878 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
879
880 result.addAttribute("generatorKind", genKind);
881 if (!verilogName.empty())
882 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
883}
884
885void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
886 FlatSymbolRefAttr genKind, StringAttr name,
887 ArrayRef<PortInfo> ports, StringRef verilogName,
888 ArrayAttr parameters,
889 ArrayRef<NamedAttribute> attributes) {
890 build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
891 parameters, attributes);
892}
893
894void HWModuleGeneratedOp::modifyPorts(
895 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
896 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
897 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
898 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
899 eraseOutputs);
900}
901
902void HWModuleGeneratedOp::appendOutputs(
903 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
904
905static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
906 for (auto &argAttr : attrs)
907 if (argAttr.getName() == name)
908 return true;
909 return false;
910}
911
912template <typename ModuleTy>
913static ParseResult parseHWModuleOp(OpAsmParser &parser,
914 OperationState &result) {
915
916 using namespace mlir::function_interface_impl;
917 auto builder = parser.getBuilder();
918 auto loc = parser.getCurrentLocation();
919
920 // Parse the visibility attribute.
921 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
922
923 // Parse the name as a symbol.
924 StringAttr nameAttr;
925 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
926 result.attributes))
927 return failure();
928
929 // Parse the generator information.
930 FlatSymbolRefAttr kindAttr;
931 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
932 if (parser.parseComma() ||
933 parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
934 return failure();
935 }
936 }
937
938 // Parse the parameters.
939 ArrayAttr parameters;
940 if (parseOptionalParameterList(parser, parameters))
941 return failure();
942
943 SmallVector<module_like_impl::PortParse> ports;
944 TypeAttr modType;
945 if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
946 return failure();
947
948 // Parse the attribute dict.
949 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
950 return failure();
951
952 if (hasAttribute("parameters", result.attributes)) {
953 parser.emitError(loc, "explicit `parameters` attributes not allowed");
954 return failure();
955 }
956
957 result.addAttribute("parameters", parameters);
958 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
959
960 // Convert the specified array of dictionary attrs (which may have null
961 // entries) to an ArrayAttr of dictionaries.
962 SmallVector<Attribute> attrs;
963 for (auto &port : ports)
964 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
965 // Add the attributes to the ports.
966 auto nonEmptyAttrsFn = [](Attribute attr) {
967 return attr && !cast<DictionaryAttr>(attr).empty();
968 };
969 if (llvm::any_of(attrs, nonEmptyAttrsFn))
970 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
971 builder.getArrayAttr(attrs));
972
973 // Add the port locations.
974 auto unknownLoc = builder.getUnknownLoc();
975 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
976 return attr && cast<Location>(attr) != unknownLoc;
977 };
978 SmallVector<Attribute> locs;
979 StringAttr portLocsAttrName;
980 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
981 // Plain modules only store the output port locations, as the input port
982 // locations will be stored in the basic block arguments.
983 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
984 for (auto &port : ports)
985 if (port.direction == ModulePort::Direction::Output)
986 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
987 } else {
988 // All other modules store all port locations in a single array.
989 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
990 for (auto &port : ports)
991 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
992 }
993 if (llvm::any_of(locs, nonEmptyLocsFn))
994 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
995
996 // Add the entry block arguments.
997 SmallVector<OpAsmParser::Argument, 4> entryArgs;
998 for (auto &port : ports)
999 if (port.direction != ModulePort::Direction::Output)
1000 entryArgs.push_back(port);
1001
1002 // Parse the optional function body.
1003 auto *body = result.addRegion();
1004 if (std::is_same_v<ModuleTy, HWModuleOp>) {
1005 if (parser.parseRegion(*body, entryArgs))
1006 return failure();
1007
1008 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1009 }
1010 return success();
1011}
1012
1013ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1014 return parseHWModuleOp<HWModuleOp>(parser, result);
1015}
1016
1017ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1018 OperationState &result) {
1019 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1020}
1021
1022ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1023 OperationState &result) {
1024 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1025}
1026
1027FunctionType getHWModuleOpType(Operation *op) {
1028 if (auto mod = dyn_cast<HWModuleLike>(op))
1029 return mod.getHWModuleType().getFuncType();
1030 return cast<FunctionType>(
1031 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1032}
1033
1034template <typename ModuleTy>
1035static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1036 p << ' ';
1037 // Print the visibility of the module.
1038 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1039 if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1040 visibilityAttrName))
1041 p << visibility.getValue() << ' ';
1042
1043 // Print the operation and the function name.
1044 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1045 if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1046 p << ", ";
1047 p.printSymbolName(gen.getGeneratorKind());
1048 }
1049
1050 // Print the parameter list if present.
1051 printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1052
1054
1055 SmallVector<StringRef, 3> omittedAttrs;
1056 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1057 omittedAttrs.push_back("generatorKind");
1058 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1059 omittedAttrs.push_back(mod.getResultLocsAttrName());
1060 else
1061 omittedAttrs.push_back(mod.getPortLocsAttrName());
1062 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1063 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1064 omittedAttrs.push_back(mod.getParametersAttrName());
1065 omittedAttrs.push_back(visibilityAttrName);
1066 if (auto cmt =
1067 mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1068 if (cmt.getValue().empty())
1069 omittedAttrs.push_back("comment");
1070
1071 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1072 omittedAttrs);
1073}
1074
1075void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1076void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1077
1078void HWModuleOp::print(OpAsmPrinter &p) {
1079 printModuleOp(p, *this);
1080
1081 // Print the body if this is not an external function.
1082 Region &body = getBody();
1083 if (!body.empty()) {
1084 p << " ";
1085 p.printRegion(body, /*printEntryBlockArgs=*/false,
1086 /*printBlockTerminators=*/true);
1087 }
1088}
1089
1090static LogicalResult verifyModuleCommon(HWModuleLike module) {
1091 assert(isa<HWModuleLike>(module) &&
1092 "verifier hook should only be called on modules");
1093
1094 if (auto portLocs = module->getAttrOfType<ArrayAttr>("port_locs"))
1095 if (!portLocs.empty() && portLocs.size() != module.getNumPorts())
1096 return module->emitOpError("requires ")
1097 << module.getNumPorts() << " port locations but got "
1098 << portLocs.size();
1099
1100 SmallPtrSet<Attribute, 4> paramNames;
1101
1102 // Check parameter default values are sensible.
1103 for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1104 auto paramAttr = cast<ParamDeclAttr>(param);
1105
1106 // Check that we don't have any redundant parameter names. These are
1107 // resolved by string name: reuse of the same name would cause ambiguities.
1108 if (!paramNames.insert(paramAttr.getName()).second)
1109 return module->emitOpError("parameter ")
1110 << paramAttr << " has the same name as a previous parameter";
1111
1112 // Default values are allowed to be missing, check them if present.
1113 auto value = paramAttr.getValue();
1114 if (!value)
1115 continue;
1116
1117 auto typedValue = dyn_cast<TypedAttr>(value);
1118 if (!typedValue)
1119 return module->emitOpError("parameter ")
1120 << paramAttr << " should have a typed value; has value " << value;
1121
1122 if (typedValue.getType() != paramAttr.getType())
1123 return module->emitOpError("parameter ")
1124 << paramAttr << " should have type " << paramAttr.getType()
1125 << "; has type " << typedValue.getType();
1126
1127 // Verify that this is a valid parameter value, disallowing parameter
1128 // references. We could allow parameters to refer to each other in the
1129 // future with lexical ordering if there is a need.
1130 if (failed(checkParameterInContext(value, module, module,
1131 /*disallowParamRefs=*/true)))
1132 return failure();
1133 }
1134 return success();
1135}
1136
1137LogicalResult HWModuleOp::verify() {
1138 if (failed(verifyModuleCommon(*this)))
1139 return failure();
1140
1141 auto type = getModuleType();
1142 auto *body = getBodyBlock();
1143
1144 // Verify the number of block arguments.
1145 auto numInputs = type.getNumInputs();
1146 if (body->getNumArguments() != numInputs)
1147 return emitOpError("entry block must have ")
1148 << numInputs << " arguments to match module signature";
1149
1150 return success();
1151}
1152
1153LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1154
1155std::pair<StringAttr, BlockArgument>
1156HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1157 // Find a unique name for the wire.
1158 Namespace ns;
1159 auto ports = getPortList();
1160 for (auto port : ports)
1161 ns.newName(port.name.getValue());
1162 auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1163
1164 Block *body = getBodyBlock();
1165
1166 // Create a new port for the host clock.
1167 PortInfo port;
1168 port.name = nameAttr;
1170 port.type = ty;
1171 modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1172 body);
1173
1174 // Add a new argument.
1175 return {nameAttr, body->getArgument(index)};
1176}
1177
1178void HWModuleOp::insertOutputs(unsigned index,
1179 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1180
1181 auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1182 assert(index <= output->getNumOperands() && "invalid output index");
1183
1184 // Rewrite the port list of the module.
1185 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1186 for (auto &[name, value] : outputs) {
1187 PortInfo port;
1188 port.name = name;
1190 port.type = value.getType();
1191 indexedNewPorts.emplace_back(index, port);
1192 }
1193 modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1194 getBodyBlock());
1195
1196 // Rewrite the output op.
1197 for (auto &[name, value] : outputs)
1198 output->insertOperands(index++, value);
1199}
1200
1201void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1202 return insertOutputs(getNumOutputPorts(), outputs);
1203}
1204
1205void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1206 mlir::OpAsmSetValueNameFn setNameFn) {
1207 getAsmBlockArgumentNamesImpl(region, setNameFn);
1208}
1209
1210void HWModuleExternOp::getAsmBlockArgumentNames(
1211 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1212 getAsmBlockArgumentNamesImpl(region, setNameFn);
1213}
1214
1215template <typename ModTy>
1216static SmallVector<Location> getAllPortLocs(ModTy module) {
1217 auto locs = module.getPortLocs();
1218 if (locs) {
1219 SmallVector<Location> retval;
1220 retval.reserve(locs->size());
1221 for (auto l : *locs)
1222 retval.push_back(cast<Location>(l));
1223 // Either we have a length of 0 or the correct length
1224 assert(!locs->size() || locs->size() == module.getNumPorts());
1225 return retval;
1226 }
1227 return SmallVector<Location>(module.getNumPorts(),
1228 UnknownLoc::get(module.getContext()));
1229}
1230
1231SmallVector<Location> HWModuleOp::getAllPortLocs() {
1232 SmallVector<Location> portLocs;
1233 portLocs.reserve(getNumPorts());
1234 auto resultLocs = getResultLocsAttr();
1235 unsigned inputCount = 0;
1236 auto modType = getModuleType();
1237 auto unknownLoc = UnknownLoc::get(getContext());
1238 auto *body = getBodyBlock();
1239 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1240 if (modType.isOutput(i)) {
1241 auto loc = resultLocs
1242 ? cast<Location>(
1243 resultLocs.getValue()[portLocs.size() - inputCount])
1244 : unknownLoc;
1245 portLocs.push_back(loc);
1246 } else {
1247 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1248 portLocs.push_back(loc);
1249 ++inputCount;
1250 }
1251 }
1252 return portLocs;
1253}
1254
1255SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1256 return ::getAllPortLocs(*this);
1257}
1258
1259SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1260 return ::getAllPortLocs(*this);
1261}
1262
1263void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1264 SmallVector<Attribute> resultLocs;
1265 unsigned inputCount = 0;
1266 auto modType = getModuleType();
1267 auto *body = getBodyBlock();
1268 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1269 if (modType.isOutput(i))
1270 resultLocs.push_back(locs[i]);
1271 else
1272 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1273 }
1274 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1275}
1276
1277void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1278 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1279}
1280
1281void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1282 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1283}
1284
1285template <typename ModTy>
1286static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1287 auto numInputs = module.getNumInputPorts();
1288 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1289 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1290 auto oldType = module.getModuleType();
1291 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1292 oldType.getPorts().end());
1293 for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1294 newPorts[i].name = cast<StringAttr>(names[i]);
1295 auto newType = ModuleType::get(module.getContext(), newPorts);
1296 module.setModuleType(newType);
1297}
1298
1299void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1300 ::setAllPortNames(names, *this);
1301}
1302
1303void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1304 ::setAllPortNames(names, *this);
1305}
1306
1307void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1308 ::setAllPortNames(names, *this);
1309}
1310
1311ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1312 auto attrs = getPerPortAttrs();
1313 if (attrs && !attrs->empty())
1314 return attrs->getValue();
1315 return {};
1316}
1317
1318ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1319 auto attrs = getPerPortAttrs();
1320 if (attrs && !attrs->empty())
1321 return attrs->getValue();
1322 return {};
1323}
1324
1325ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1326 auto attrs = getPerPortAttrs();
1327 if (attrs && !attrs->empty())
1328 return attrs->getValue();
1329 return {};
1330}
1331
1332void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1333 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1334}
1335
1336void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1337 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1338}
1339
1340void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1341 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1342}
1343
1344void HWModuleOp::removeAllPortAttrs() {
1345 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1346}
1347
1348void HWModuleExternOp::removeAllPortAttrs() {
1349 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1350}
1351
1352void HWModuleGeneratedOp::removeAllPortAttrs() {
1353 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1354}
1355
1356// This probably does really unexpected stuff when you change the number of
1357
1358template <typename ModTy>
1359static void setHWModuleType(ModTy &mod, ModuleType type) {
1360 auto argAttrs = mod.getAllInputAttrs();
1361 auto resAttrs = mod.getAllOutputAttrs();
1362 mod.setModuleTypeAttr(TypeAttr::get(type));
1363 unsigned newNumArgs = type.getNumInputs();
1364 unsigned newNumResults = type.getNumOutputs();
1365
1366 auto emptyDict = DictionaryAttr::get(mod.getContext());
1367 argAttrs.resize(newNumArgs, emptyDict);
1368 resAttrs.resize(newNumResults, emptyDict);
1369
1370 SmallVector<Attribute> attrs;
1371 attrs.append(argAttrs.begin(), argAttrs.end());
1372 attrs.append(resAttrs.begin(), resAttrs.end());
1373
1374 if (attrs.empty())
1375 return mod.removeAllPortAttrs();
1376 mod.setAllPortAttrs(attrs);
1377}
1378
1379void HWModuleOp::setHWModuleType(ModuleType type) {
1380 return ::setHWModuleType(*this, type);
1381}
1382
1383void HWModuleExternOp::setHWModuleType(ModuleType type) {
1384 return ::setHWModuleType(*this, type);
1385}
1386
1387void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1388 return ::setHWModuleType(*this, type);
1389}
1390
1391/// Lookup the generator for the symbol. This returns null on
1392/// invalid IR.
1393Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1394 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1395 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1396}
1397
1398LogicalResult
1399HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1400 auto *referencedKind =
1401 symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1402
1403 if (referencedKind == nullptr)
1404 return emitError("Cannot find generator definition '")
1405 << getGeneratorKind() << "'";
1406
1407 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1408 return emitError("Symbol resolved to '")
1409 << referencedKind->getName()
1410 << "' which is not a HWGeneratorSchemaOp";
1411
1412 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1413 auto paramRef = referencedKindOp.getRequiredAttrs();
1414 auto dict = (*this)->getAttrDictionary();
1415 for (auto str : paramRef) {
1416 auto strAttr = dyn_cast<StringAttr>(str);
1417 if (!strAttr)
1418 return emitError("Unknown attribute type, expected a string");
1419 if (!dict.get(strAttr.getValue()))
1420 return emitError("Missing attribute '") << strAttr.getValue() << "'";
1421 }
1422
1423 return success();
1424}
1425
1426LogicalResult HWModuleGeneratedOp::verify() {
1427 return verifyModuleCommon(*this);
1428}
1429
1430void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1431 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1432 getAsmBlockArgumentNamesImpl(region, setNameFn);
1433}
1434
1435LogicalResult HWModuleOp::verifyBody() { return success(); }
1436
1437template <typename ModuleTy>
1438static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1439 auto modTy = mod.getHWModuleType();
1440 auto emptyDict = DictionaryAttr::get(mod.getContext());
1441 SmallVector<PortInfo> retval;
1442 auto locs = mod.getAllPortLocs();
1443 for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1444 LocationAttr loc = locs[i];
1445 DictionaryAttr attrs =
1446 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1447 if (!attrs)
1448 attrs = emptyDict;
1449 retval.push_back({modTy.getPorts()[i],
1450 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1451 : modTy.getInputIdForPortId(i),
1452 attrs, loc});
1453 }
1454 return retval;
1455}
1456
1457template <typename ModuleTy>
1458static PortInfo getPort(ModuleTy &mod, size_t idx) {
1459 auto modTy = mod.getHWModuleType();
1460 auto emptyDict = DictionaryAttr::get(mod.getContext());
1461 LocationAttr loc = mod.getPortLoc(idx);
1462 DictionaryAttr attrs =
1463 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1464 if (!attrs)
1465 attrs = emptyDict;
1466 return {modTy.getPorts()[idx],
1467 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1468 : modTy.getInputIdForPortId(idx),
1469 attrs, loc};
1470}
1471
1472//===----------------------------------------------------------------------===//
1473// InstanceOp
1474//===----------------------------------------------------------------------===//
1475
1476/// Create a instance that refers to a known module.
1477void InstanceOp::build(OpBuilder &builder, OperationState &result,
1478 Operation *module, StringAttr name,
1479 ArrayRef<Value> inputs, ArrayAttr parameters,
1480 InnerSymAttr innerSym) {
1481 if (!parameters)
1482 parameters = builder.getArrayAttr({});
1483
1484 auto mod = cast<hw::HWModuleLike>(module);
1485 auto argNames = builder.getArrayAttr(mod.getInputNames());
1486 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1487
1488 // Try to resolve the parameterized module type. If failed, use the module's
1489 // parmeterized type. If the client doesn't fix this error, the verifier will
1490 // fail.
1491 ModuleType modType = mod.getHWModuleType();
1492 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1493 parameters, result.location, /*emitErrors=*/false);
1494 if (succeeded(resolvedModType))
1495 modType = *resolvedModType;
1496 FunctionType funcType = resolvedModType->getFuncType();
1497 build(builder, result, funcType.getResults(), name,
1498 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1499 argNames, resultNames, parameters, innerSym, /*doNotPrint=*/{});
1500}
1501
1502std::optional<size_t> InstanceOp::getTargetResultIndex() {
1503 // Inner symbols on instance operations target the op not any result.
1504 return std::nullopt;
1505}
1506
1507LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1509 *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1510 getResultNames(), getParameters(), symbolTable);
1511}
1512
1513LogicalResult InstanceOp::verify() {
1514 auto module = (*this)->getParentOfType<HWModuleOp>();
1515 if (!module)
1516 return success();
1517
1518 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
1520 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1521 auto diag = emitOpError();
1522 if (fn(diag))
1523 diag.attachNote(module->getLoc()) << "module declared here";
1524 };
1526 getParameters(), moduleParameters, emitError);
1527}
1528
1529ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1530 StringAttr instanceNameAttr;
1531 InnerSymAttr innerSym;
1532 FlatSymbolRefAttr moduleNameAttr;
1533 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1534 SmallVector<Type, 1> inputsTypes, allResultTypes;
1535 ArrayAttr argNames, resultNames, parameters;
1536 auto noneType = parser.getBuilder().getType<NoneType>();
1537
1538 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1539 result.attributes))
1540 return failure();
1541
1542 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1543 // Parsing an optional symbol name doesn't fail, so no need to check the
1544 // result.
1545 if (parser.parseCustomAttributeWithFallback(innerSym))
1546 return failure();
1547 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1548 }
1549
1550 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1551 if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1552 result.attributes) ||
1553 parser.getCurrentLocation(&parametersLoc) ||
1554 parseOptionalParameterList(parser, parameters) ||
1555 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1556 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1557 result.operands) ||
1558 parser.parseArrow() ||
1559 parseOutputPortList(parser, allResultTypes, resultNames) ||
1560 parser.parseOptionalAttrDict(result.attributes)) {
1561 return failure();
1562 }
1563
1564 result.addAttribute("argNames", argNames);
1565 result.addAttribute("resultNames", resultNames);
1566 result.addAttribute("parameters", parameters);
1567 result.addTypes(allResultTypes);
1568 return success();
1569}
1570
1571void InstanceOp::print(OpAsmPrinter &p) {
1572 p << ' ';
1573 p.printAttributeWithoutType(getInstanceNameAttr());
1574 if (auto attr = getInnerSymAttr()) {
1575 p << " sym ";
1576 attr.print(p);
1577 }
1578 p << ' ';
1579 p.printAttributeWithoutType(getModuleNameAttr());
1580 printOptionalParameterList(p, *this, getParameters());
1581 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1582 getArgNames());
1583 p << " -> ";
1584 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1585
1586 p.printOptionalAttrDict(
1587 (*this)->getAttrs(),
1588 /*elidedAttrs=*/{"instanceName",
1589 InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1590 "argNames", "resultNames", "parameters"});
1591}
1592
1593//===----------------------------------------------------------------------===//
1594// HWOutputOp
1595//===----------------------------------------------------------------------===//
1596
1597/// Verify that the num of operands and types fit the declared results.
1598LogicalResult OutputOp::verify() {
1599 // Check that the we (hw.output) have the same number of operands as our
1600 // region has results.
1601 ModuleType modType;
1602 if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1603 modType = mod.getHWModuleType();
1604 else {
1605 emitOpError("must have a module parent");
1606 return failure();
1607 }
1608 auto modResults = modType.getOutputTypes();
1609 OperandRange outputValues = getOperands();
1610 if (modResults.size() != outputValues.size()) {
1611 emitOpError("must have same number of operands as region results.");
1612 return failure();
1613 }
1614
1615 // Check that the types of our operands and the region's results match.
1616 for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1617 if (modResults[i] != outputValues[i].getType()) {
1618 emitOpError("output types must match module. In "
1619 "operand ")
1620 << i << ", expected " << modResults[i] << ", but got "
1621 << outputValues[i].getType() << ".";
1622 return failure();
1623 }
1624 }
1625
1626 return success();
1627}
1628
1629//===----------------------------------------------------------------------===//
1630// Other Operations
1631//===----------------------------------------------------------------------===//
1632
1633static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1634 Type &idxType) {
1635 Type type;
1636 if (p.parseType(type))
1637 return p.emitError(p.getCurrentLocation(), "Expected type");
1638 auto arrType = type_dyn_cast<ArrayType>(type);
1639 if (!arrType)
1640 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1641 srcType = type;
1642 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1643 idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1644 return success();
1645}
1646
1647static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1648 Type idxType) {
1649 p.printType(srcType);
1650}
1651
1652ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1653 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1654 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1655 Type elemType;
1656
1657 if (parser.parseOperandList(operands) ||
1658 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1659 parser.parseType(elemType))
1660 return failure();
1661
1662 if (operands.size() == 0)
1663 return parser.emitError(inputOperandsLoc,
1664 "Cannot construct an array of length 0");
1665
1666 // Check for optional result type: " -> <type>"
1667 Type resultType;
1668 if (parser.parseOptionalArrow().succeeded()) {
1669 if (parser.parseType(resultType))
1670 return failure();
1671 result.addTypes(resultType);
1672 } else {
1673 result.addTypes({ArrayType::get(elemType, operands.size())});
1674 }
1675
1676 for (auto operand : operands)
1677 if (parser.resolveOperand(operand, elemType, result.operands))
1678 return failure();
1679 return success();
1680}
1681
1682void ArrayCreateOp::print(OpAsmPrinter &p) {
1683 p << " ";
1684 p.printOperands(getInputs());
1685 p.printOptionalAttrDict((*this)->getAttrs());
1686 p << " : " << getInputs()[0].getType();
1687
1688 // Print optional result type if it's not the default constructed type
1689 Type expectedType =
1690 ArrayType::get(getInputs()[0].getType(), getInputs().size());
1691 if (getType() != expectedType)
1692 p << " -> " << getType();
1693}
1694
1695void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1696 ValueRange values) {
1697 assert(values.size() > 0 && "Cannot build array of zero elements");
1698 Type elemType = values[0].getType();
1699 assert(llvm::all_of(
1700 values,
1701 [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1702 "All values must have same type.");
1703 build(b, state, ArrayType::get(elemType, values.size()), values);
1704}
1705
1706LogicalResult ArrayCreateOp::verify() {
1707 unsigned returnSize = hw::type_cast<ArrayType>(getType()).getNumElements();
1708 if (getInputs().size() != returnSize)
1709 return failure();
1710 return success();
1711}
1712
1713OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1714 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) {
1715 return !isa_and_nonnull<IntegerAttr>(attr);
1716 }))
1717 return {};
1718 return ArrayAttr::get(getContext(), adaptor.getInputs());
1719}
1720
1721// Check whether an integer value is an offset from a base.
1722bool hw::isOffset(Value base, Value index, uint64_t offset) {
1723 if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1724 if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1725 // If both values are a constant, check if index == base + offset.
1726 // To account for overflow, the addition is performed with an extra bit
1727 // and the offset is asserted to fit in the bit width of the base.
1728 auto baseValue = constBase.getValue();
1729 auto indexValue = constIndex.getValue();
1730
1731 unsigned bits = baseValue.getBitWidth();
1732 assert(bits == indexValue.getBitWidth() && "mismatched widths");
1733
1734 if (bits < 64 && offset >= (1ull << bits))
1735 return false;
1736
1737 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1738 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1739 return baseExt + offset == indexExt;
1740 }
1741 }
1742 return false;
1743}
1744
1745// Canonicalize a create of consecutive elements to a slice.
1746static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1747 PatternRewriter &rewriter) {
1748 // Do not canonicalize create of get into a slice.
1749 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1750 if (arrayTy.getNumElements() <= 1)
1751 return failure();
1752 auto elemTy = arrayTy.getElementType();
1753
1754 // Check if create arguments are consecutive elements of the same array.
1755 // Attempt to break a create of gets into a sequence of consecutive intervals.
1756 struct Chunk {
1757 Value input;
1758 Value index;
1759 size_t size;
1760 };
1761 SmallVector<Chunk> chunks;
1762 for (Value value : llvm::reverse(op.getInputs())) {
1763 auto get = value.getDefiningOp<ArrayGetOp>();
1764 if (!get)
1765 return failure();
1766
1767 Value input = get.getInput();
1768 Value index = get.getIndex();
1769 if (!chunks.empty()) {
1770 auto &c = *chunks.rbegin();
1771 if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1772 c.size++;
1773 continue;
1774 }
1775 }
1776
1777 chunks.push_back(Chunk{input, index, 1});
1778 }
1779
1780 // If there is a single slice, eliminate the create.
1781 if (chunks.size() == 1) {
1782 auto &chunk = chunks[0];
1783 rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1784 op.getLoc(), arrayTy, chunk.input, chunk.index));
1785 return success();
1786 }
1787
1788 // If the number of chunks is significantly less than the number of
1789 // elements, replace the create with a concat of the identified slices.
1790 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1791 SmallVector<Value> slices;
1792 for (auto &chunk : llvm::reverse(chunks)) {
1793 auto sliceTy = ArrayType::get(elemTy, chunk.size);
1794 slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1795 op.getLoc(), sliceTy, chunk.input, chunk.index));
1796 }
1797 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1798 return success();
1799 }
1800
1801 return failure();
1802}
1803
1804LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1805 PatternRewriter &rewriter) {
1806 if (succeeded(foldCreateToSlice(op, rewriter)))
1807 return success();
1808 return failure();
1809}
1810
1811Value ArrayCreateOp::getUniformElement() {
1812 if (!getInputs().empty() && llvm::all_equal(getInputs()))
1813 return getInputs()[0];
1814 return {};
1815}
1816
1817static std::optional<uint64_t> getUIntFromValue(Value value) {
1818 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1819 if (!idxOp)
1820 return std::nullopt;
1821 APInt idxAttr = idxOp.getValue();
1822 if (idxAttr.getBitWidth() > 64)
1823 return std::nullopt;
1824 return idxAttr.getLimitedValue();
1825}
1826
1827LogicalResult ArraySliceOp::verify() {
1828 unsigned inputSize =
1829 type_cast<ArrayType>(getInput().getType()).getNumElements();
1830 if (llvm::Log2_64_Ceil(inputSize) !=
1831 getLowIndex().getType().getIntOrFloatBitWidth())
1832 return emitOpError(
1833 "ArraySlice: index width must match clog2 of array size");
1834 return success();
1835}
1836
1837OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1838 // If we are slicing the entire input, then return it.
1839 if (getType() == getInput().getType())
1840 return getInput();
1841 return {};
1842}
1843
1844LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1845 PatternRewriter &rewriter) {
1846 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1847 auto elemTy = sliceTy.getElementType();
1848 uint64_t sliceSize = sliceTy.getNumElements();
1849 if (sliceSize == 0)
1850 return failure();
1851
1852 if (sliceSize == 1) {
1853 // slice(a, n) -> create(a[n])
1854 auto get = ArrayGetOp::create(rewriter, op.getLoc(), op.getInput(),
1855 op.getLowIndex());
1856 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1857 get.getResult());
1858 return success();
1859 }
1860
1861 auto offsetOpt = getUIntFromValue(op.getLowIndex());
1862 if (!offsetOpt)
1863 return failure();
1864
1865 auto *inputOp = op.getInput().getDefiningOp();
1866 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1867 // slice(slice(a, n), m) -> slice(a, n + m)
1868 if (inputSlice == op)
1869 return failure();
1870
1871 auto inputIndex = inputSlice.getLowIndex();
1872 auto inputOffsetOpt = getUIntFromValue(inputIndex);
1873 if (!inputOffsetOpt)
1874 return failure();
1875
1876 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1877 auto lowIndex =
1878 ConstantOp::create(rewriter, op.getLoc(), inputIndex.getType(), offset);
1879 rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
1880 inputSlice.getInput(), lowIndex);
1881 return success();
1882 }
1883
1884 if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1885 // slice(create(a0, a1, ..., an), m) -> create(am, ...)
1886 auto inputs = inputCreate.getInputs();
1887
1888 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1889 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1890 inputs.slice(begin, sliceSize));
1891 return success();
1892 }
1893
1894 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
1895 // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
1896 SmallVector<Value> chunks;
1897 uint64_t sliceStart = *offsetOpt;
1898 for (auto input : llvm::reverse(inputConcat.getInputs())) {
1899 // Check whether the input intersects with the slice.
1900 uint64_t inputSize =
1901 hw::type_cast<ArrayType>(input.getType()).getNumElements();
1902 if (inputSize == 0 || inputSize <= sliceStart) {
1903 sliceStart -= inputSize;
1904 continue;
1905 }
1906
1907 // Find the indices to slice from this input by intersection.
1908 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
1909 uint64_t cutSize = cutEnd - sliceStart;
1910 assert(cutSize != 0 && "slice cannot be empty");
1911
1912 if (cutSize == inputSize) {
1913 // The whole input fits in the slice, add it.
1914 assert(sliceStart == 0 && "invalid cut size");
1915 chunks.push_back(input);
1916 } else {
1917 // Slice the required bits from the input.
1918 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
1919 auto lowIndex = ConstantOp::create(
1920 rewriter, op.getLoc(), rewriter.getIntegerType(width), sliceStart);
1921 chunks.push_back(ArraySliceOp::create(
1922 rewriter, op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input,
1923 lowIndex));
1924 }
1925
1926 sliceStart = 0;
1927 sliceSize -= cutSize;
1928 if (sliceSize == 0)
1929 break;
1930 }
1931
1932 assert(chunks.size() > 0 && "missing sliced items");
1933 if (chunks.size() == 1)
1934 rewriter.replaceOp(op, chunks[0]);
1935 else
1936 rewriter.replaceOpWithNewOp<ArrayConcatOp>(
1937 op, llvm::to_vector(llvm::reverse(chunks)));
1938 return success();
1939 }
1940 return failure();
1941}
1942
1943//===----------------------------------------------------------------------===//
1944// ArrayConcatOp
1945//===----------------------------------------------------------------------===//
1946
1947static ParseResult parseArrayConcatTypes(OpAsmParser &p,
1948 SmallVectorImpl<Type> &inputTypes,
1949 Type &resultType) {
1950 Type elemType;
1951 uint64_t resultSize = 0;
1952
1953 auto parseElement = [&]() -> ParseResult {
1954 Type ty;
1955 if (p.parseType(ty))
1956 return failure();
1957 auto arrTy = type_dyn_cast<ArrayType>(ty);
1958 if (!arrTy)
1959 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1960 if (elemType && elemType != arrTy.getElementType())
1961 return p.emitError(p.getCurrentLocation(), "Expected array element type ")
1962 << elemType;
1963
1964 elemType = arrTy.getElementType();
1965 inputTypes.push_back(ty);
1966 resultSize += arrTy.getNumElements();
1967 return success();
1968 };
1969
1970 if (p.parseCommaSeparatedList(parseElement))
1971 return failure();
1972
1973 resultType = ArrayType::get(elemType, resultSize);
1974 return success();
1975}
1976
1977static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
1978 TypeRange inputTypes, Type resultType) {
1979 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
1980}
1981
1982void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
1983 ValueRange values) {
1984 assert(!values.empty() && "Cannot build array of zero elements");
1985 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
1986 Type elemTy = arrayTy.getElementType();
1987 assert(llvm::all_of(values,
1988 [elemTy](Value v) -> bool {
1989 return isa<ArrayType>(v.getType()) &&
1990 cast<ArrayType>(v.getType()).getElementType() ==
1991 elemTy;
1992 }) &&
1993 "All values must be of ArrayType with the same element type.");
1994
1995 uint64_t resultSize = 0;
1996 for (Value val : values)
1997 resultSize += cast<ArrayType>(val.getType()).getNumElements();
1998 build(b, state, ArrayType::get(elemTy, resultSize), values);
1999}
2000
2001OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2002 if (getInputs().size() == 1)
2003 return getInputs()[0];
2004
2005 auto inputs = adaptor.getInputs();
2006 SmallVector<Attribute> array;
2007 for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2008 if (!inputs[i])
2009 return {};
2010 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2011 }
2012 return ArrayAttr::get(getContext(), array);
2013}
2014
2015// Flatten a concatenation of array creates into a single create.
2016static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2017 for (auto input : op.getInputs())
2018 if (!input.getDefiningOp<ArrayCreateOp>())
2019 return false;
2020
2021 SmallVector<Value> items;
2022 for (auto input : op.getInputs()) {
2023 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2024 for (auto item : create.getInputs())
2025 items.push_back(item);
2026 }
2027
2028 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2029 return true;
2030}
2031
2032// Merge consecutive slice expressions in a concatenation.
2033static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2034 struct Slice {
2035 Value input;
2036 Value index;
2037 size_t size;
2038 Value op;
2039 SmallVector<Location> locs;
2040 };
2041
2042 SmallVector<Value> items;
2043 std::optional<Slice> last;
2044 bool changed = false;
2045
2046 auto concatenate = [&] {
2047 // If there is only one op in the slice, place it to the items list.
2048 if (!last)
2049 return;
2050 if (last->op) {
2051 items.push_back(last->op);
2052 last.reset();
2053 return;
2054 }
2055
2056 // Otherwise, create a new slice of with the given size and place it.
2057 // In this case, the concat op is replaced, using the new argument.
2058 changed = true;
2059 auto loc = FusedLoc::get(op.getContext(), last->locs);
2060 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2061 auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2062 items.push_back(rewriter.createOrFold<ArraySliceOp>(
2063 loc, arrayTy, last->input, last->index));
2064
2065 last.reset();
2066 };
2067
2068 auto append = [&](Value op, Value input, Value index, size_t size) {
2069 // If this slice is an extension of the previous one, extend the size
2070 // saved. In this case, a new slice of is created and the concatenation
2071 // operator is rewritten. Otherwise, flush the last slice.
2072 if (last) {
2073 if (last->input == input && isOffset(last->index, index, last->size)) {
2074 last->size += size;
2075 last->op = {};
2076 last->locs.push_back(op.getLoc());
2077 return;
2078 }
2079 concatenate();
2080 }
2081 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2082 };
2083
2084 for (auto item : llvm::reverse(op.getInputs())) {
2085 if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2086 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2087 append(item, slice.getInput(), slice.getLowIndex(), size);
2088 continue;
2089 }
2090
2091 if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2092 if (create.getInputs().size() == 1) {
2093 if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2094 append(item, get.getInput(), get.getIndex(), 1);
2095 continue;
2096 }
2097 }
2098 }
2099
2100 concatenate();
2101 items.push_back(item);
2102 }
2103 concatenate();
2104
2105 if (!changed)
2106 return false;
2107
2108 if (items.size() == 1) {
2109 rewriter.replaceOp(op, items[0]);
2110 } else {
2111 std::reverse(items.begin(), items.end());
2112 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2113 }
2114 return true;
2115}
2116
2117LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2118 PatternRewriter &rewriter) {
2119 // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2120 if (flattenConcatOp(op, rewriter))
2121 return success();
2122
2123 // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2124 if (mergeConcatSlices(op, rewriter))
2125 return success();
2126
2127 return failure();
2128}
2129
2130//===----------------------------------------------------------------------===//
2131// EnumConstantOp
2132//===----------------------------------------------------------------------===//
2133
2134ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2135 // Parse a Type instead of an EnumType since the type might be a type alias.
2136 // The validity of the canonical type is checked during construction of the
2137 // EnumFieldAttr.
2138 Type type;
2139 StringRef field;
2140
2141 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2142 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2143 return failure();
2144
2145 auto fieldAttr = EnumFieldAttr::get(
2146 loc, StringAttr::get(parser.getContext(), field), type);
2147
2148 if (!fieldAttr)
2149 return failure();
2150
2151 result.addAttribute("field", fieldAttr);
2152 result.addTypes(type);
2153
2154 return success();
2155}
2156
2157void EnumConstantOp::print(OpAsmPrinter &p) {
2158 p << " " << getField().getField().getValue() << " : "
2159 << getField().getType().getValue();
2160}
2161
2162void EnumConstantOp::getAsmResultNames(
2163 function_ref<void(Value, StringRef)> setNameFn) {
2164 setNameFn(getResult(), getField().getField().str());
2165}
2166
2167void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2168 EnumFieldAttr field) {
2169 return build(builder, odsState, field.getType().getValue(), field);
2170}
2171
2172OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2173 assert(adaptor.getOperands().empty() && "constant has no operands");
2174 return getFieldAttr();
2175}
2176
2177LogicalResult EnumConstantOp::verify() {
2178 auto fieldAttr = getFieldAttr();
2179 auto fieldType = fieldAttr.getType().getValue();
2180 // This check ensures that we are using the exact same type, without looking
2181 // through type aliases.
2182 if (fieldType != getType())
2183 emitOpError("return type ")
2184 << getType() << " does not match attribute type " << fieldAttr;
2185 return success();
2186}
2187
2188//===----------------------------------------------------------------------===//
2189// EnumCmpOp
2190//===----------------------------------------------------------------------===//
2191
2192LogicalResult EnumCmpOp::verify() {
2193 // Compare the canonical types.
2194 auto lhsType = type_cast<EnumType>(getLhs().getType());
2195 auto rhsType = type_cast<EnumType>(getRhs().getType());
2196 if (rhsType != lhsType)
2197 emitOpError("types do not match");
2198 return success();
2199}
2200
2201//===----------------------------------------------------------------------===//
2202// StructCreateOp
2203//===----------------------------------------------------------------------===//
2204
2205ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2206 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2207 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2208 Type declOrAliasType;
2209
2210 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2211 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2212 parser.parseColonType(declOrAliasType))
2213 return failure();
2214
2215 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2216 if (!declType)
2217 return parser.emitError(parser.getNameLoc(),
2218 "expected !hw.struct type or alias");
2219
2220 llvm::SmallVector<Type, 4> structInnerTypes;
2221 declType.getInnerTypes(structInnerTypes);
2222 result.addTypes(declOrAliasType);
2223
2224 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2225 result.operands))
2226 return failure();
2227 return success();
2228}
2229
2230void StructCreateOp::print(OpAsmPrinter &printer) {
2231 printer << " (";
2232 printer.printOperands(getInput());
2233 printer << ")";
2234 printer.printOptionalAttrDict((*this)->getAttrs());
2235 printer << " : " << getType();
2236}
2237
2238LogicalResult StructCreateOp::verify() {
2239 auto elements = hw::type_cast<StructType>(getType()).getElements();
2240
2241 if (elements.size() != getInput().size())
2242 return emitOpError("structure field count mismatch");
2243
2244 for (const auto &[field, value] : llvm::zip(elements, getInput()))
2245 if (field.type != value.getType())
2246 return emitOpError("structure field `")
2247 << field.name << "` type does not match";
2248
2249 return success();
2250}
2251
2252OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2253 // struct_create(struct_explode(x)) => x
2254 if (!getInput().empty())
2255 if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2256 explodeOp && getInput() == explodeOp.getResults() &&
2257 getResult().getType() == explodeOp.getInput().getType())
2258 return explodeOp.getInput();
2259
2260 auto inputs = adaptor.getInput();
2261 if (llvm::any_of(inputs, [](Attribute attr) {
2262 return !isa_and_nonnull<IntegerAttr>(attr);
2263 }))
2264 return {};
2265 return ArrayAttr::get(getContext(), inputs);
2266}
2267
2268//===----------------------------------------------------------------------===//
2269// StructExplodeOp
2270//===----------------------------------------------------------------------===//
2271
2272ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2273 OperationState &result) {
2274 OpAsmParser::UnresolvedOperand operand;
2275 Type declType;
2276
2277 if (parser.parseOperand(operand) ||
2278 parser.parseOptionalAttrDict(result.attributes) ||
2279 parser.parseColonType(declType))
2280 return failure();
2281 auto structType = type_dyn_cast<StructType>(declType);
2282 if (!structType)
2283 return parser.emitError(parser.getNameLoc(),
2284 "invalid kind of type specified");
2285
2286 llvm::SmallVector<Type, 4> structInnerTypes;
2287 structType.getInnerTypes(structInnerTypes);
2288 result.addTypes(structInnerTypes);
2289
2290 if (parser.resolveOperand(operand, declType, result.operands))
2291 return failure();
2292 return success();
2293}
2294
2295void StructExplodeOp::print(OpAsmPrinter &printer) {
2296 printer << " ";
2297 printer.printOperand(getInput());
2298 printer.printOptionalAttrDict((*this)->getAttrs());
2299 printer << " : " << getInput().getType();
2300}
2301
2302LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2303 SmallVectorImpl<OpFoldResult> &results) {
2304 auto input = adaptor.getInput();
2305 if (!input)
2306 return failure();
2307 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2308 return success();
2309}
2310
2311LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2312 PatternRewriter &rewriter) {
2313 auto *inputOp = op.getInput().getDefiningOp();
2314 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2315 auto result = failure();
2316 auto opResults = op.getResults();
2317 for (uint32_t index = 0; index < elements.size(); index++) {
2318 if (auto foldResult = foldStructExtract(inputOp, index)) {
2319 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2320 result = success();
2321 }
2322 }
2323 return result;
2324}
2325
2326void StructExplodeOp::getAsmResultNames(
2327 function_ref<void(Value, StringRef)> setNameFn) {
2328 auto structType = type_cast<StructType>(getInput().getType());
2329 for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2330 setNameFn(res, field.name.str());
2331}
2332
2333void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2334 Value input) {
2335 StructType inputType = dyn_cast<StructType>(input.getType());
2336 assert(inputType);
2337 SmallVector<Type, 16> fieldTypes;
2338 for (auto field : inputType.getElements())
2339 fieldTypes.push_back(field.type);
2340 build(odsBuilder, odsState, fieldTypes, input);
2341}
2342
2343//===----------------------------------------------------------------------===//
2344// StructExtractOp
2345//===----------------------------------------------------------------------===//
2346
2347/// Ensure an aggregate op's field index is within the bounds of
2348/// the aggregate type and the accessed field is of 'elementType'.
2349template <typename AggregateOp, typename AggregateType>
2350static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2351 AggregateType aggType,
2352 Type elementType) {
2353 auto index = op.getFieldIndex();
2354 if (index >= aggType.getElements().size())
2355 return op.emitOpError() << "field index " << index
2356 << " exceeds element count of aggregate type";
2357
2359 getCanonicalType(aggType.getElements()[index].type))
2360 return op.emitOpError()
2361 << "type " << aggType.getElements()[index].type
2362 << " of accessed field in aggregate at index " << index
2363 << " does not match expected type " << elementType;
2364
2365 return success();
2366}
2367
2368LogicalResult StructExtractOp::verify() {
2369 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2370 *this, getInput().getType(), getType());
2371}
2372
2373/// Use the same parser for both struct_extract and union_extract since the
2374/// syntax is identical.
2375template <typename AggregateType>
2376static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2377 OpAsmParser::UnresolvedOperand operand;
2378 StringAttr fieldName;
2379 Type declType;
2380
2381 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2382 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2383 parser.parseOptionalAttrDict(result.attributes) ||
2384 parser.parseColonType(declType))
2385 return failure();
2386 auto aggType = type_dyn_cast<AggregateType>(declType);
2387 if (!aggType)
2388 return parser.emitError(parser.getNameLoc(),
2389 "invalid kind of type specified");
2390
2391 auto fieldIndex = aggType.getFieldIndex(fieldName);
2392 if (!fieldIndex) {
2393 parser.emitError(parser.getNameLoc(), "field name '" +
2394 fieldName.getValue() +
2395 "' not found in aggregate type");
2396 return failure();
2397 }
2398
2399 auto indexAttr =
2400 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2401 result.addAttribute("fieldIndex", indexAttr);
2402 Type resultType = aggType.getElements()[*fieldIndex].type;
2403 result.addTypes(resultType);
2404
2405 if (parser.resolveOperand(operand, declType, result.operands))
2406 return failure();
2407 return success();
2408}
2409
2410/// Use the same printer for both struct_extract and union_extract since the
2411/// syntax is identical.
2412template <typename AggType>
2413static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2414 printer << " ";
2415 printer.printOperand(op.getInput());
2416 printer << "[\"" << op.getFieldName() << "\"]";
2417 printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2418 printer << " : " << op.getInput().getType();
2419}
2420
2421ParseResult StructExtractOp::parse(OpAsmParser &parser,
2422 OperationState &result) {
2423 return parseExtractOp<StructType>(parser, result);
2424}
2425
2426void StructExtractOp::print(OpAsmPrinter &printer) {
2427 printExtractOp(printer, *this);
2428}
2429
2430void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2431 Value input, StructType::FieldInfo field) {
2432 auto fieldIndex =
2433 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2434 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2435 build(builder, odsState, field.type, input, *fieldIndex);
2436}
2437
2438void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2439 Value input, StringAttr fieldName) {
2440 auto structType = type_cast<StructType>(input.getType());
2441 auto fieldIndex = structType.getFieldIndex(fieldName);
2442 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2443 auto resultType = structType.getElements()[*fieldIndex].type;
2444 build(builder, odsState, resultType, input, *fieldIndex);
2445}
2446
2447OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2448 if (auto constOperand = adaptor.getInput()) {
2449 // Fold extract from aggregate constant
2450 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2451 return operandAttr.getValue()[getFieldIndex()];
2452 }
2453
2454 if (auto foldResult =
2455 foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2456 return foldResult;
2457 return {};
2458}
2459
2460LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2461 PatternRewriter &rewriter) {
2462 auto *inputOp = op.getInput().getDefiningOp();
2463
2464 // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2465 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2466 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2467 rewriter.replaceOpWithNewOp<StructExtractOp>(
2468 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2469 return success();
2470 }
2471 }
2472
2473 return failure();
2474}
2475
2476void StructExtractOp::getAsmResultNames(
2477 function_ref<void(Value, StringRef)> setNameFn) {
2478 setNameFn(getResult(), getFieldName());
2479}
2480
2481//===----------------------------------------------------------------------===//
2482// StructInjectOp
2483//===----------------------------------------------------------------------===//
2484
2485void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2486 Value input, StringAttr fieldName, Value newValue) {
2487 auto structType = type_cast<StructType>(input.getType());
2488 auto fieldIndex = structType.getFieldIndex(fieldName);
2489 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2490 build(builder, odsState, input, *fieldIndex, newValue);
2491}
2492
2493LogicalResult StructInjectOp::verify() {
2494 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2495 *this, getInput().getType(), getNewValue().getType());
2496}
2497
2498ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2499 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2500 OpAsmParser::UnresolvedOperand operand, val;
2501 StringAttr fieldName;
2502 Type declType;
2503
2504 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2505 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2506 parser.parseComma() || parser.parseOperand(val) ||
2507 parser.parseOptionalAttrDict(result.attributes) ||
2508 parser.parseColonType(declType))
2509 return failure();
2510 auto structType = type_dyn_cast<StructType>(declType);
2511 if (!structType)
2512 return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2513
2514 auto fieldIndex = structType.getFieldIndex(fieldName);
2515 if (!fieldIndex) {
2516 parser.emitError(parser.getNameLoc(), "field name '" +
2517 fieldName.getValue() +
2518 "' not found in aggregate type");
2519 return failure();
2520 }
2521
2522 auto indexAttr =
2523 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2524 result.addAttribute("fieldIndex", indexAttr);
2525 result.addTypes(declType);
2526
2527 Type resultType = structType.getElements()[*fieldIndex].type;
2528 if (parser.resolveOperands({operand, val}, {declType, resultType},
2529 inputOperandsLoc, result.operands))
2530 return failure();
2531 return success();
2532}
2533
2534void StructInjectOp::print(OpAsmPrinter &printer) {
2535 printer << " ";
2536 printer.printOperand(getInput());
2537 printer << "[\"" << getFieldName() << "\"], ";
2538 printer.printOperand(getNewValue());
2539 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2540 printer << " : " << getInput().getType();
2541}
2542
2543OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2544 auto input = adaptor.getInput();
2545 auto newValue = adaptor.getNewValue();
2546 if (!input || !newValue)
2547 return {};
2548 SmallVector<Attribute> array;
2549 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2550 array[getFieldIndex()] = newValue;
2551 return ArrayAttr::get(getContext(), array);
2552}
2553
/// Canonicalize a chain of struct_inject ops that ends in `op`.
///
/// The chain is followed through each inject's `input` operand, keeping the
/// outermost (latest) value written to each field. If the chain writes every
/// field of the struct, it is replaced by a struct_create. Otherwise, if some
/// field is written more than once, the chain is rebuilt on top of its root
/// input, with a single inject per written field in field-index order.
LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
                                           PatternRewriter &rewriter) {
  // If this inject is only used as an input to another inject, don't try to
  // canonicalize it. It will be included in that other op's canonicalization
  // attempt. This avoids doing redundant work.
  if (op->hasOneUse()) {
    auto &use = *op->use_begin();
    if (isa<StructInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
      return failure();
  }

  // Canonicalize multiple injects into a create op and eliminate overwrites.
  // `injects` records every inject visited in the chain; `fields` maps each
  // field name to the value of its outermost write.
  SmallPtrSet<Operation *, 4> injects;
  DenseMap<StringAttr, Value> fields;

  // Chase a chain of injects. Stop chasing when an inject is revisited, which
  // only happens if the chain is cyclic.
  // NOTE(review): in the cyclic case `input` ends up as the result of an
  // inject inside the chain, and the overwrite-elimination path below then
  // uses it as the root of the rebuilt chain. Confirm this is intended.
  StructInjectOp inject = op;
  Value input;
  do {
    if (!injects.insert(inject).second)
      break;

    // try_emplace keeps the first value seen for a field, i.e. the outermost
    // write.
    fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
    input = inject.getInput();
    inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
  } while (inject);
  assert(input && "missing input to inject chain");

  auto ty = hw::type_cast<StructType>(op.getType());
  auto elements = ty.getElements();

  // If the inject chain sets all fields, canonicalize to create.
  if (fields.size() == elements.size()) {
    SmallVector<Value> createFields;
    for (const auto &field : elements) {
      auto it = fields.find(field.name);
      assert(it != fields.end() && "missing field");
      createFields.push_back(it->second);
    }
    rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
    return success();
  }

  // Nothing to canonicalize: every inject in the chain writes a distinct
  // field, so there are no overwrites to eliminate.
  if (injects.size() == fields.size())
    return failure();

  // Eliminate overwrites. The hash map contains the last write to each field.
  // Rebuild the chain on top of the root `input` with one inject per written
  // field.
  for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
    auto it = fields.find(elements[fieldIndex].name);
    if (it == fields.end())
      continue;
    input = StructInjectOp::create(rewriter, op.getLoc(), ty, input, fieldIndex,
                                   it->second);
  }

  rewriter.replaceOp(op, input);
  return success();
}
2613
2614//===----------------------------------------------------------------------===//
2615// UnionCreateOp
2616//===----------------------------------------------------------------------===//
2617
2618LogicalResult UnionCreateOp::verify() {
2619 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2620 *this, getType(), getInput().getType());
2621}
2622
2623void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2624 Type unionType, StringAttr fieldName, Value input) {
2625 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2626 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2627 build(builder, odsState, unionType, *fieldIndex, input);
2628}
2629
2630ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2631 Type declOrAliasType;
2632 StringAttr fieldName;
2633 OpAsmParser::UnresolvedOperand input;
2634 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2635
2636 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2637 parser.parseOperand(input) ||
2638 parser.parseOptionalAttrDict(result.attributes) ||
2639 parser.parseColonType(declOrAliasType))
2640 return failure();
2641
2642 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2643 if (!declType)
2644 return parser.emitError(parser.getNameLoc(),
2645 "expected !hw.union type or alias");
2646
2647 auto fieldIndex = declType.getFieldIndex(fieldName);
2648 if (!fieldIndex) {
2649 parser.emitError(fieldLoc, "cannot find union field '")
2650 << fieldName.getValue() << '\'';
2651 return failure();
2652 }
2653
2654 auto indexAttr =
2655 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2656 result.addAttribute("fieldIndex", indexAttr);
2657 Type inputType = declType.getElements()[*fieldIndex].type;
2658
2659 if (parser.resolveOperand(input, inputType, result.operands))
2660 return failure();
2661 result.addTypes({declOrAliasType});
2662 return success();
2663}
2664
2665void UnionCreateOp::print(OpAsmPrinter &printer) {
2666 printer << " \"" << getFieldName() << "\", ";
2667 printer.printOperand(getInput());
2668 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2669 printer << " : " << getType();
2670}
2671
2672//===----------------------------------------------------------------------===//
2673// UnionExtractOp
2674//===----------------------------------------------------------------------===//
2675
2676ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2677 return parseExtractOp<UnionType>(parser, result);
2678}
2679
2680void UnionExtractOp::print(OpAsmPrinter &printer) {
2681 printExtractOp(printer, *this);
2682}
2683
2684LogicalResult UnionExtractOp::inferReturnTypes(
2685 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2686 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2687 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2688 Adaptor adaptor(operands, attrs, properties, regions);
2689 auto unionElements =
2690 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2691 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2692 if (fieldIndex >= unionElements.size()) {
2693 if (loc)
2694 mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2695 " exceeds element count of aggregate type");
2696 return failure();
2697 }
2698 results.push_back(unionElements[fieldIndex].type);
2699 return success();
2700}
2701
2702void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2703 Value input, StringAttr fieldName) {
2704 auto unionType = type_cast<UnionType>(input.getType());
2705 auto fieldIndex = unionType.getFieldIndex(fieldName);
2706 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2707 auto resultType = unionType.getElements()[*fieldIndex].type;
2708 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2709}
2710
2711//===----------------------------------------------------------------------===//
2712// ArrayGetOp
2713//===----------------------------------------------------------------------===//
2714
2715// An array_get of an array_create with a constant index can just be the
2716// array_create operand at the constant index. If the array_create has a
2717// single uniform value for each element, just return that value regardless of
2718// the index. If the array is constructed from a constant by a bitcast
2719// operation, we can fold into a constant.
2720OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2721 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2722 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2723
2724 if (inputCst) {
2725 // Constant array index.
2726 if (indexCst) {
2727 auto indexVal = indexCst.getValue();
2728 if (indexVal.getBitWidth() < 64) {
2729 auto index = indexVal.getZExtValue();
2730 return inputCst[inputCst.size() - 1 - index];
2731 }
2732 }
2733 // If all elements of the array are the same, we can return any element of
2734 // array.
2735 if (!inputCst.empty() && llvm::all_equal(inputCst))
2736 return inputCst[0];
2737 }
2738
2739 // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2740 if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2741 auto intTy = dyn_cast<IntegerType>(getType());
2742 if (!intTy)
2743 return {};
2744 auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2745 if (!bitcastInputOp)
2746 return {};
2747 if (!indexCst)
2748 return {};
2749 auto bitcastInputCst = bitcastInputOp.getValue();
2750 // Calculate the index. Make sure to zero-extend the index value before
2751 // multiplying the element width.
2752 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2753 getType().getIntOrFloatBitWidth();
2754 // Extract [startIdx + width - 1: startIdx].
2755 return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2756 intTy.getIntOrFloatBitWidth()));
2757 }
2758
2759 // array_get(array_inject(_, index, element), index) -> element
2760 if (auto inject = getInput().getDefiningOp<ArrayInjectOp>())
2761 if (getIndex() == inject.getIndex())
2762 return inject.getElement();
2763
2764 auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2765 if (!inputCreate)
2766 return {};
2767
2768 if (auto uniformValue = inputCreate.getUniformElement())
2769 return uniformValue;
2770
2771 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2772 return {};
2773
2774 uint64_t index = indexCst.getValue().getLimitedValue();
2775 auto createInputs = inputCreate.getInputs();
2776 if (index >= createInputs.size())
2777 return {};
2778 return createInputs[createInputs.size() - index - 1];
2779}
2780
/// Canonicalize an array_get whose index is a known unsigned constant (as
/// reported by getUIntFromValue; if it is not, no rewrite is attempted):
///  - get(slice(a, n), m) -> get(a, n + m), when the slice offset n is also a
///    known constant.
///  - get(concat(...), m) -> get(x, m'), where x is the concat operand that
///    contains element m and m' is the position of that element within x.
///  - get(get(create(a, b, ...), sel), m) with a non-constant `sel` ->
///    get(create(get(a, m), get(b, m), ...), sel).
LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
                                       PatternRewriter &rewriter) {
  auto idxOpt = getUIntFromValue(op.getIndex());
  if (!idxOpt)
    return failure();

  auto *inputOp = op.getInput().getDefiningOp();
  if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
    // get(slice(a, n), m) -> get(a, n + m)
    auto offsetOp = inputSlice.getLowIndex();
    auto offsetOpt = getUIntFromValue(offsetOp);
    if (!offsetOpt)
      return failure();

    // The combined index is materialized with the slice offset's type, which
    // is the index type of `a`.
    // NOTE(review): if n + m does not fit in that width (possible when m is
    // out of range for the slice), the outcome depends on ConstantOp's
    // integer builder — confirm it truncates rather than asserts.
    uint64_t offset = *offsetOpt + *idxOpt;
    auto newOffset =
        ConstantOp::create(rewriter, op.getLoc(), offsetOp.getType(), offset);
    rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
                                            newOffset);
    return success();
  }

  if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
    // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
    // The operands are walked from last to first, i.e. the last operand is
    // treated as holding the lowest element indices. Each operand that lies
    // entirely below element m is skipped by subtracting its size.
    uint64_t elemIndex = *idxOpt;
    for (auto input : llvm::reverse(inputConcat.getInputs())) {
      size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
      if (elemIndex >= size) {
        elemIndex -= size;
        continue;
      }

      // Index width for an array of `size` elements: ceil(log2(size)), with a
      // minimum of one bit for single-element arrays.
      unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
      auto newIdxOp =
          ConstantOp::create(rewriter, op.getLoc(),
                             rewriter.getIntegerType(indexWidth), elemIndex);

      rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
      return success();
    }
    // The constant index is past the end of the concatenation.
    return failure();
  }

  // array_get const, (array_get sel, (array_create a, b, c, d)) -->
  // array_get sel, (array_create (array_get const a), (array_get const b),
  // (array_get const c), (array_get const d))
  // Only applies when the inner index `sel` is not a hw.constant.
  if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
    if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
      if (auto create =
              innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {

        // Apply the outer constant index to every candidate array first.
        SmallVector<Value> newValues;
        for (auto operand : create.getOperands())
          newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
              op.getLoc(), operand, op.getIndex()));

        // Then select among the results with the dynamic inner index.
        rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
            op,
            rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
            innerGet.getIndex());
        return success();
      }
    }
  }

  return failure();
}
2848
2849//===----------------------------------------------------------------------===//
2850// ArrayInjectOp
2851//===----------------------------------------------------------------------===//
2852
2853OpFoldResult ArrayInjectOp::fold(FoldAdaptor adaptor) {
2854 auto inputAttr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2855 auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2856 auto elementAttr = adaptor.getElement();
2857
2858 // inject(constant[xs, y, zs], iy, a) -> constant[x, a, z]
2859 if (inputAttr && indexAttr && elementAttr) {
2860 if (auto index = indexAttr.getValue().tryZExtValue()) {
2861 if (*index < inputAttr.size()) {
2862 SmallVector<Attribute> elements(inputAttr.getValue());
2863 elements[inputAttr.size() - 1 - *index] = elementAttr;
2864 return ArrayAttr::get(getContext(), elements);
2865 }
2866 }
2867 }
2868
2869 return {};
2870}
2871
2872static LogicalResult canonicalizeArrayInjectChain(ArrayInjectOp op,
2873 PatternRewriter &rewriter) {
2874 // If this inject is only used as an input to another inject, don't try to
2875 // canonicalize it. It will be included in that other op's canonicalization
2876 // attempt. This avoids doing redundant work.
2877 if (op->hasOneUse()) {
2878 auto &use = *op->use_begin();
2879 if (isa<ArrayInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
2880 return failure();
2881 }
2882
2883 // Collect all injects to constant indices.
2884 auto arrayLength = type_cast<ArrayType>(op.getType()).getNumElements();
2885 Value input = op;
2887 while (auto inject = input.getDefiningOp<ArrayInjectOp>()) {
2888 // Determine the constant index.
2889 APInt indexAPInt;
2890 if (!matchPattern(inject.getIndex(), mlir::m_ConstantInt(&indexAPInt)))
2891 break;
2892 if (indexAPInt.getActiveBits() > 32)
2893 break;
2894 uint32_t index = indexAPInt.getZExtValue();
2895
2896 // Track the injected value. Make sure to only track indices that are in
2897 // bounds. This will allow us to later check if `elements.size()` matches
2898 // the array length.
2899 if (index < arrayLength)
2900 elements.insert({index, inject.getElement()});
2901
2902 // Step to the next inject op.
2903 input = inject.getInput();
2904 if (input == op)
2905 break; // break cycles
2906 }
2907
2908 // If we are assigning every single element, replace the op with an
2909 // `hw.array_create`.
2910 if (elements.size() == arrayLength) {
2911 SmallVector<Value, 4> operands;
2912 operands.reserve(arrayLength);
2913 for (uint32_t idx = 0; idx < arrayLength; ++idx)
2914 operands.push_back(elements.at(arrayLength - idx - 1));
2915 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(), operands);
2916 return success();
2917 }
2918
2919 return failure();
2920}
2921
2922static LogicalResult
2923canonicalizeArrayInjectIntoCreate(ArrayInjectOp op, PatternRewriter &rewriter) {
2924 auto createOp = op.getInput().getDefiningOp<ArrayCreateOp>();
2925 if (!createOp)
2926 return failure();
2927
2928 // Make sure the access is in bounds.
2929 APInt indexAPInt;
2930 if (!matchPattern(op.getIndex(), mlir::m_ConstantInt(&indexAPInt)) ||
2931 !indexAPInt.ult(createOp.getInputs().size()))
2932 return failure();
2933
2934 // Substitute the injected value.
2935 SmallVector<Value> elements = createOp.getInputs();
2936 elements[elements.size() - indexAPInt.getZExtValue() - 1] = op.getElement();
2937 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, elements);
2938 return success();
2939}
2940
2941void ArrayInjectOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2942 MLIRContext *context) {
2943 patterns.add<ArrayInjectToSameIndex>(context);
2946}
2947
2948//===----------------------------------------------------------------------===//
2949// TypedeclOp
2950//===----------------------------------------------------------------------===//
2951
2952StringRef TypedeclOp::getPreferredName() {
2953 return getVerilogName().value_or(getName());
2954}
2955
2956Type TypedeclOp::getAliasType() {
2957 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2958 return hw::TypeAliasType::get(
2959 SymbolRefAttr::get(parentScope.getSymNameAttr(),
2960 {FlatSymbolRefAttr::get(*this)}),
2961 getType());
2962}
2963
2964//===----------------------------------------------------------------------===//
2965// BitcastOp
2966//===----------------------------------------------------------------------===//
2967
2968OpFoldResult BitcastOp::fold(FoldAdaptor) {
2969 // Identity.
2970 // bitcast(%a) : A -> A ==> %a
2971 if (getOperand().getType() == getType())
2972 return getOperand();
2973
2974 return {};
2975}
2976
2977LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
2978 // Composition.
2979 // %b = bitcast(%a) : A -> B
2980 // bitcast(%b) : B -> C
2981 // ===> bitcast(%a) : A -> C
2982 auto inputBitcast =
2983 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2984 if (!inputBitcast)
2985 return failure();
2986 auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
2987 inputBitcast.getInput());
2988 rewriter.replaceOp(op, bitcast);
2989 return success();
2990}
2991
2992LogicalResult BitcastOp::verify() {
2993 if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
2994 return this->emitOpError("Bitwidth of input must match result");
2995 return success();
2996}
2997
2998//===----------------------------------------------------------------------===//
2999// HierPathOp helpers.
3000//===----------------------------------------------------------------------===//
3001
3002bool HierPathOp::dropModule(StringAttr moduleToDrop) {
3003 SmallVector<Attribute, 4> newPath;
3004 bool updateMade = false;
3005 for (auto nameRef : getNamepath()) {
3006 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3007 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3008 if (ref.getModule() == moduleToDrop)
3009 updateMade = true;
3010 else
3011 newPath.push_back(ref);
3012 } else {
3013 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3014 updateMade = true;
3015 else
3016 newPath.push_back(nameRef);
3017 }
3018 }
3019 if (updateMade)
3020 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3021 return updateMade;
3022}
3023
3024bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3025 SmallVector<Attribute, 4> newPath;
3026 bool updateMade = false;
3027 StringRef inlinedInstanceName = "";
3028 for (auto nameRef : getNamepath()) {
3029 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3030 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3031 if (ref.getModule() == moduleToDrop) {
3032 inlinedInstanceName = ref.getName().getValue();
3033 updateMade = true;
3034 } else if (!inlinedInstanceName.empty()) {
3035 newPath.push_back(hw::InnerRefAttr::get(
3036 ref.getModule(),
3037 StringAttr::get(getContext(), inlinedInstanceName + "_" +
3038 ref.getName().getValue())));
3039 inlinedInstanceName = "";
3040 } else
3041 newPath.push_back(ref);
3042 } else {
3043 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3044 updateMade = true;
3045 else
3046 newPath.push_back(nameRef);
3047 }
3048 }
3049 if (updateMade)
3050 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3051 return updateMade;
3052}
3053
3054bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3055 SmallVector<Attribute, 4> newPath;
3056 bool updateMade = false;
3057 for (auto nameRef : getNamepath()) {
3058 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3059 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3060 if (ref.getModule() == oldMod) {
3061 newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3062 updateMade = true;
3063 } else
3064 newPath.push_back(ref);
3065 } else {
3066 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3067 newPath.push_back(FlatSymbolRefAttr::get(newMod));
3068 updateMade = true;
3069 } else
3070 newPath.push_back(nameRef);
3071 }
3072 }
3073 if (updateMade)
3074 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3075 return updateMade;
3076}
3077
3078bool HierPathOp::updateModuleAndInnerRef(
3079 StringAttr oldMod, StringAttr newMod,
3080 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3081 auto fromRef = FlatSymbolRefAttr::get(oldMod);
3082 if (oldMod == newMod)
3083 return false;
3084
3085 auto namepathNew = getNamepath().getValue().vec();
3086 bool updateMade = false;
3087 // Break from the loop if the module is found, since it can occur only once.
3088 for (auto &element : namepathNew) {
3089 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3090 if (innerRef.getModule() != oldMod)
3091 continue;
3092 auto symName = innerRef.getName();
3093 // Since the module got updated, the old innerRef symbol inside oldMod
3094 // should also be updated to the new symbol inside the newMod.
3095 auto to = innerSymRenameMap.find(symName);
3096 if (to != innerSymRenameMap.end())
3097 symName = to->second;
3098 updateMade = true;
3099 element = hw::InnerRefAttr::get(newMod, symName);
3100 break;
3101 }
3102 if (element != fromRef)
3103 continue;
3104
3105 updateMade = true;
3106 element = FlatSymbolRefAttr::get(newMod);
3107 break;
3108 }
3109 if (updateMade)
3110 setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3111 return updateMade;
3112}
3113
3114bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3115 SmallVector<Attribute, 4> newPath;
3116 bool updateMade = false;
3117 for (auto nameRef : getNamepath()) {
3118 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3119 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3120 if (ref.getModule() == atMod) {
3121 updateMade = true;
3122 if (includeMod)
3123 newPath.push_back(ref);
3124 } else
3125 newPath.push_back(ref);
3126 } else {
3127 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3128 updateMade = true;
3129 else
3130 newPath.push_back(nameRef);
3131 }
3132 if (updateMade)
3133 break;
3134 }
3135 if (updateMade)
3136 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3137 return updateMade;
3138}
3139
3140/// Return just the module part of the namepath at a specific index.
3141StringAttr HierPathOp::modPart(unsigned i) {
3142 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3143 .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3144 .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3145}
3146
3147/// Return the root module.
3148StringAttr HierPathOp::root() {
3149 assert(!getNamepath().empty());
3150 return modPart(0);
3151}
3152
3153/// Return true if the NLA has the module in its path.
3154bool HierPathOp::hasModule(StringAttr modName) {
3155 for (auto nameRef : getNamepath()) {
3156 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3157 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3158 if (ref.getModule() == modName)
3159 return true;
3160 } else {
3161 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3162 return true;
3163 }
3164 }
3165 return false;
3166}
3167
3168/// Return true if the NLA has the InnerSym .
3169bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3170 for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3171 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3172 if (ref.getName() == symName && ref.getModule() == modName)
3173 return true;
3174
3175 return false;
3176}
3177
3178/// Return just the reference part of the namepath at a specific index. This
3179/// will return an empty attribute if this is the leaf and the leaf is a module.
3180StringAttr HierPathOp::refPart(unsigned i) {
3181 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3182 .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3183 .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3184}
3185
3186/// Return the leaf reference. This returns an empty attribute if the leaf
3187/// reference is a module.
3188StringAttr HierPathOp::ref() {
3189 assert(!getNamepath().empty());
3190 return refPart(getNamepath().size() - 1);
3191}
3192
3193/// Return the leaf module.
3194StringAttr HierPathOp::leafMod() {
3195 assert(!getNamepath().empty());
3196 return modPart(getNamepath().size() - 1);
3197}
3198
3199/// Returns true if this NLA targets an instance of a module (as opposed to
3200/// an instance's port or something inside an instance).
3201bool HierPathOp::isModule() { return !ref(); }
3202
3203/// Returns true if this NLA targets something inside a module (as opposed
3204/// to a module or an instance of a module);
3205bool HierPathOp::isComponent() { return (bool)ref(); }
3206
// Verify the inner references of a HierPathOp against the namespace `ns`:
// 1. The namepath must not be empty.
// 2. Every element except the leaf must be an InnerRefAttr that names an
//    instance (an igraph::InstanceOpInterface op) found in `ns`.
// 3. Each element's module must be one of the modules referenced by the
//    instance named by the previous element, so the path follows a legal
//    instance hierarchy.
// 4. The leaf is either an InnerRefAttr to some op found in `ns` (for example
//    a module port or a declaration inside the module), or a FlatSymbolRefAttr
//    naming a module. In both cases its module must also satisfy rule 3.
LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
  // Modules that the next path element is allowed to refer to. These are
  // taken from the instance named by the previous element; null while
  // checking the root.
  ArrayAttr expectedModuleNames = {};
  // Succeed if no constraint has been established yet, or if `name` is one of
  // `expectedModuleNames`. Otherwise emit a diagnostic listing the allowed
  // module names.
  auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
    if (!expectedModuleNames)
      return success();
    if (llvm::any_of(expectedModuleNames,
                     [name](Attribute attr) { return attr == name; }))
      return success();
    auto diag = emitOpError() << "instance path is incorrect. Expected ";
    size_t n = expectedModuleNames.size();
    if (n != 1) {
      diag << "one of ";
    }
    for (size_t i = 0; i < n; ++i) {
      if (i != 0)
        diag << ((i + 1 == n) ? " or " : ", ");
      diag << cast<StringAttr>(expectedModuleNames[i]);
    }
    diag << ". Instead found: " << name;
    return diag;
  };

  if (!getNamepath() || getNamepath().empty())
    return emitOpError() << "the instance path cannot be empty";
  // Check every non-leaf element: it must be an inner reference to an
  // instance, and its module must be one the previous instance can refer to.
  for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
    hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
    if (!innerRef)
      return emitOpError()
             << "the instance path can only contain inner sym reference"
             << ", only the leaf can refer to a module symbol";

    if (failed(checkExpectedModule(innerRef.getModule())))
      return failure();

    auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
    if (!instOp)
      return emitOpError() << " module: " << innerRef.getModule()
                           << " does not contain any instance with symbol: "
                           << innerRef.getName();
    // The next element must live in one of the modules this instance refers
    // to.
    expectedModuleNames = instOp.getReferencedModuleNamesAttr();
  }

  // The instance path has been verified. Now verify the last element.
  auto leafRef = getNamepath()[getNamepath().size() - 1];
  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
    if (!ns.lookup(innerRef)) {
      return emitOpError() << " operation with symbol: " << innerRef
                           << " was not found ";
    }
    if (failed(checkExpectedModule(innerRef.getModule())))
      return failure();
  } else if (failed(checkExpectedModule(
                 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
    return failure();
  }
  return success();
}
3277
3278void HierPathOp::print(OpAsmPrinter &p) {
3279 p << " ";
3280
3281 // Print visibility if present.
3282 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3283 if (auto visibility =
3284 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3285 p << visibility.getValue() << ' ';
3286
3287 p.printSymbolName(getSymName());
3288 p << " [";
3289 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3290 if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3291 p.printSymbolName(ref.getModule().getValue());
3292 p << "::";
3293 p.printSymbolName(ref.getName().getValue());
3294 } else {
3295 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3296 }
3297 });
3298 p << "]";
3299 p.printOptionalAttrDict(
3300 (*this)->getAttrs(),
3301 {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3302}
3303
3304ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3305 // Parse the visibility attribute.
3306 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3307
3308 // Parse the symbol name.
3309 StringAttr symName;
3310 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3311 result.attributes))
3312 return failure();
3313
3314 // Parse the namepath.
3315 SmallVector<Attribute> namepath;
3316 if (parser.parseCommaSeparatedList(
3317 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3318 auto loc = parser.getCurrentLocation();
3319 SymbolRefAttr ref;
3320 if (parser.parseAttribute(ref))
3321 return failure();
3322
3323 // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3324 auto pathLength = ref.getNestedReferences().size();
3325 if (pathLength == 0)
3326 namepath.push_back(
3327 FlatSymbolRefAttr::get(ref.getRootReference()));
3328 else if (pathLength == 1)
3329 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3330 ref.getLeafReference()));
3331 else
3332 return parser.emitError(loc,
3333 "only one nested reference is allowed");
3334 return success();
3335 }))
3336 return failure();
3337 result.addAttribute("namepath",
3338 ArrayAttr::get(parser.getContext(), namepath));
3339
3340 if (parser.parseOptionalAttrDict(result.attributes))
3341 return failure();
3342
3343 return success();
3344}
3345
3346//===----------------------------------------------------------------------===//
3347// TriggeredOp
3348//===----------------------------------------------------------------------===//
3349
3350void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3351 EventControlAttr event, Value trigger,
3352 ValueRange inputs) {
3353 odsState.addOperands(trigger);
3354 odsState.addOperands(inputs);
3355 odsState.addAttribute(getEventAttrName(odsState.name), event);
3356 auto *r = odsState.addRegion();
3357 Block *b = new Block();
3358 r->push_back(b);
3359
3360 llvm::SmallVector<Location> argLocs;
3361 llvm::transform(inputs, std::back_inserter(argLocs),
3362 [&](Value v) { return v.getLoc(); });
3363 b->addArguments(inputs.getTypes(), argLocs);
3364}
3365
3366//===----------------------------------------------------------------------===//
3367// TableGen generated logic.
3368//===----------------------------------------------------------------------===//
3369
3370// Provide the autogenerated implementation guts for the Op classes.
3371#define GET_OP_CLASSES
3372#include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition HWOps.cpp:1090
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition HWOps.cpp:505
static LogicalResult canonicalizeArrayInjectChain(ArrayInjectOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2872
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition HWOps.cpp:1035
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2016
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:1746
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1438
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition HWOps.cpp:1027
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2413
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition HWOps.cpp:1977
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition HWOps.cpp:1633
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo > > insertInputs, ArrayRef< std::pair< unsigned, PortInfo > > insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition HWOps.cpp:698
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition HWOps.cpp:905
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo > > insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition HWOps.cpp:599
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2033
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition HWOps.cpp:1216
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2376
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition HWOps.cpp:1286
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition HWOps.cpp:1359
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition HWOps.cpp:497
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition HWOps.cpp:411
static LogicalResult canonicalizeArrayInjectIntoCreate(ArrayInjectOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2923
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition HWOps.cpp:1817
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition HWOps.cpp:913
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition HWOps.cpp:2350
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1458
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition HWOps.cpp:1647
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition HWOps.cpp:353
Delimiter
Definition HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition HWOps.cpp:1947
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
Value getInput(unsigned i)
Definition HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition HWOps.h:119
llvm::SmallVector< Value > inputArgs
Definition HWOps.h:118
llvm::StringMap< unsigned > outputIdx
Definition HWOps.h:117
llvm::StringMap< unsigned > inputIdx
Definition HWOps.h:117
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition HWVisitors.h:27
ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any combinational operations that are not handled by the concrete visitor...
Definition HWVisitors.h:57
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any non-expression operations.
Definition HWVisitors.h:50
create(array_value, idx)
Definition hw.py:450
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1052
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition HWOps.cpp:1722
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition HWOps.h:124
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:533
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition HWOps.cpp:202
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition HWOps.cpp:527
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition HWOps.cpp:551
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
Operation * lookupOp(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.