CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWOps.cpp
Go to the documentation of this file.
1//===- HWOps.cpp - Implement the HW operations ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the HW ops.
10//
11//===----------------------------------------------------------------------===//
12
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "llvm/ADT/BitVector.h"
28#include "llvm/ADT/SmallPtrSet.h"
29#include "llvm/ADT/StringSet.h"
30
31using namespace circt;
32using namespace hw;
33using mlir::TypedAttr;
34
35/// Flip a port direction.
37 switch (direction) {
38 case ModulePort::Direction::Input:
39 return ModulePort::Direction::Output;
40 case ModulePort::Direction::Output:
41 return ModulePort::Direction::Input;
42 case ModulePort::Direction::InOut:
43 return ModulePort::Direction::InOut;
44 }
45 llvm_unreachable("unknown PortDirection");
46}
47
48bool hw::isValidIndexBitWidth(Value index, Value array) {
49 hw::ArrayType arrayType =
50 dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51 assert(arrayType && "expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
56}
57
58/// Return true if the specified operation is a combinational logic op.
59bool hw::isCombinational(Operation *op) {
60 struct IsCombClassifier : public TypeOpVisitor<IsCombClassifier, bool> {
61 bool visitInvalidTypeOp(Operation *op) { return false; }
62 bool visitUnhandledTypeOp(Operation *op) { return true; }
63 };
64
65 return (op->getDialect() && op->getDialect()->getNamespace() == "comb") ||
66 IsCombClassifier().dispatchTypeOpVisitor(op);
67}
68
69static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex) {
70 // A struct extract of a struct create -> corresponding struct create operand.
71 if (auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
73 }
74
75 // Extracting injected field -> corresponding field
76 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
78 return {};
79 return structInject.getNewValue();
80 }
81 return {};
82}
83
84static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context,
85 ArrayRef<Attribute> attrs) {
86 if (attrs.empty())
87 return ArrayAttr::get(context, {});
88 bool empty = true;
89 for (auto a : attrs)
90 if (a && !cast<DictionaryAttr>(a).empty()) {
91 empty = false;
92 break;
93 }
94 if (empty)
95 return ArrayAttr::get(context, {});
96 return ArrayAttr::get(context, attrs);
97}
98
99/// Get a special name to use when printing the entry block arguments of the
100/// region contained by an operation in this dialect.
101static void getAsmBlockArgumentNamesImpl(mlir::Region &region,
102 OpAsmSetValueNameFn setNameFn) {
103 if (region.empty())
104 return;
105 // Assign port names to the bbargs.
106 auto module = cast<HWModuleOp>(region.getParentOp());
107
108 auto *block = &region.front();
109 for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
111 // Let mlir deterministically convert names to valid identifiers
112 setNameFn(block->getArgument(i), name);
113 }
114}
115
/// Kinds of delimiters that may surround a list.
enum class Delimiter {
  None,
  /// The list is enclosed in `(` and `)`.
  Paren,
  /// The list is enclosed in `<` and `>`, or is absent altogether.
  OptionalLessGreater,
};
121
122/// Check parameter specified by `value` to see if it is valid according to the
123/// module's parameters. If not, emit an error to the diagnostic provided as an
124/// argument to the lambda 'instanceError' and return failure, otherwise return
125/// success.
126///
127/// If `disallowParamRefs` is true, then parameter references are not allowed.
128LogicalResult hw::checkParameterInContext(
129 Attribute value, ArrayAttr moduleParameters,
130 const instance_like_impl::EmitErrorFn &instanceError,
131 bool disallowParamRefs) {
132 // Literals are always ok. Their types are already known to match
133 // expectations.
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
136 return success();
137
138 // Check both subexpressions of an expression.
139 if (auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (auto op : expr.getOperands())
141 if (failed(checkParameterInContext(op, moduleParameters, instanceError,
142 disallowParamRefs)))
143 return failure();
144 return success();
145 }
146
147 // Parameter references need more analysis to make sure they are valid within
148 // this module.
149 if (auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
151
152 // Don't allow references to parameters from the default values of a
153 // parameter list.
154 if (disallowParamRefs) {
155 instanceError([&](auto &diag) {
156 diag << "parameter " << nameAttr
157 << " cannot be used as a default value for a parameter";
158 return false;
159 });
160 return failure();
161 }
162
163 // Find the corresponding attribute in the module.
164 for (auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
167 continue;
168
169 // If the types match then the reference is ok.
170 if (paramAttr.getType() == parameterRef.getType())
171 return success();
172
173 instanceError([&](auto &diag) {
174 diag << "parameter " << nameAttr << " used with type "
175 << parameterRef.getType() << "; should have type "
176 << paramAttr.getType();
177 return true;
178 });
179 return failure();
180 }
181
182 instanceError([&](auto &diag) {
183 diag << "use of unknown parameter " << nameAttr;
184 return true;
185 });
186 return failure();
187 }
188
189 instanceError([&](auto &diag) {
190 diag << "invalid parameter value " << value;
191 return false;
192 });
193 return failure();
194}
195
196/// Check parameter specified by `value` to see if it is valid within the scope
197/// of the specified module `module`. If not, emit an error at the location of
198/// `usingOp` and return failure, otherwise return success. If `usingOp` is
199/// null, then no diagnostic is generated.
200///
201/// If `disallowParamRefs` is true, then parameter references are not allowed.
202LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
203 Operation *usingOp,
204 bool disallowParamRefs) {
206 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
207 if (usingOp) {
208 auto diag = usingOp->emitOpError();
209 if (fn(diag))
210 diag.attachNote(module->getLoc()) << "module declared here";
211 }
212 };
213
214 return checkParameterInContext(value,
215 module->getAttrOfType<ArrayAttr>("parameters"),
216 emitError, disallowParamRefs);
217}
218
219/// Return true if the specified attribute tree is made up of nodes that are
220/// valid in a parameter expression.
221bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
222 return succeeded(checkParameterInContext(attr, module, nullptr, false));
223}
224
226 const ModulePortInfo &info,
227 Region &bodyRegion)
228 : info(info) {
229 inputArgs.resize(info.sizeInputs());
230 for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
231 inputIdx[info.at(i).name.str()] = i;
232 inputArgs[i] = barg;
233 }
234
236 for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) {
237 outputIdx[outputInfo.name.str()] = i;
238 }
239}
240
241void HWModulePortAccessor::setOutput(unsigned i, Value v) {
242 assert(outputOperands.size() > i && "invalid output index");
243 assert(outputOperands[i] == Value() && "output already set");
244 outputOperands[i] = v;
245}
246
248 assert(inputArgs.size() > i && "invalid input index");
249 return inputArgs[i];
250}
251Value HWModulePortAccessor::getInput(StringRef name) {
252 return getInput(inputIdx.find(name.str())->second);
253}
254void HWModulePortAccessor::setOutput(StringRef name, Value v) {
255 setOutput(outputIdx.find(name.str())->second, v);
256}
257
258//===----------------------------------------------------------------------===//
259// Declarative Canonicalization Patterns
260//===----------------------------------------------------------------------===//
261
262namespace {
263#include "circt/Dialect/HW/HWCanonicalization.cpp.inc"
264} // namespace
265
266//===----------------------------------------------------------------------===//
267// ConstantOp
268//===----------------------------------------------------------------------===//
269
270void ConstantOp::print(OpAsmPrinter &p) {
271 p << " ";
272 p.printAttribute(getValueAttr());
273 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
274}
275
276ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
277 IntegerAttr valueAttr;
278
279 if (parser.parseAttribute(valueAttr, "value", result.attributes) ||
280 parser.parseOptionalAttrDict(result.attributes))
281 return failure();
282
283 result.addTypes(valueAttr.getType());
284 return success();
285}
286
287LogicalResult ConstantOp::verify() {
288 // If the result type has a bitwidth, then the attribute must match its width.
289 if (getValue().getBitWidth() != cast<IntegerType>(getType()).getWidth())
290 return emitError(
291 "hw.constant attribute bitwidth doesn't match return type");
292
293 return success();
294}
295
296/// Build a ConstantOp from an APInt, infering the result type from the
297/// width of the APInt.
298void ConstantOp::build(OpBuilder &builder, OperationState &result,
299 const APInt &value) {
300
301 auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
302 auto attr = builder.getIntegerAttr(type, value);
303 return build(builder, result, type, attr);
304}
305
306/// Build a ConstantOp from an APInt, infering the result type from the
307/// width of the APInt.
308void ConstantOp::build(OpBuilder &builder, OperationState &result,
309 IntegerAttr value) {
310 return build(builder, result, value.getType(), value);
311}
312
313/// This builder allows construction of small signed integers like 0, 1, -1
314/// matching a specified MLIR IntegerType. This shouldn't be used for general
315/// constant folding because it only works with values that can be expressed in
316/// an int64_t. Use APInt's instead.
317void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
318 int64_t value) {
319 auto numBits = cast<IntegerType>(type).getWidth();
320 build(builder, result,
321 APInt(numBits, (uint64_t)value, /*isSigned=*/true,
322 /*implicitTrunc=*/true));
323}
324
325void ConstantOp::getAsmResultNames(
326 function_ref<void(Value, StringRef)> setNameFn) {
327 auto intTy = getType();
328 auto intCst = getValue();
329
330 // Sugar i1 constants with 'true' and 'false'.
331 if (cast<IntegerType>(intTy).getWidth() == 1)
332 return setNameFn(getResult(), intCst.isZero() ? "false" : "true");
333
334 // Otherwise, build a complex name with the value and type.
335 SmallVector<char, 32> specialNameBuffer;
336 llvm::raw_svector_ostream specialName(specialNameBuffer);
337 specialName << 'c' << intCst << '_' << intTy;
338 setNameFn(getResult(), specialName.str());
339}
340
341OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
342 assert(adaptor.getOperands().empty() && "constant has no operands");
343 return getValueAttr();
344}
345
346//===----------------------------------------------------------------------===//
347// WireOp
348//===----------------------------------------------------------------------===//
349
350/// Check whether an operation has any additional attributes set beyond its
351/// standard list of attributes returned by `getAttributeNames`.
352template <class Op>
353static bool hasAdditionalAttributes(Op op,
354 ArrayRef<StringRef> ignoredAttrs = {}) {
355 auto names = op.getAttributeNames();
356 llvm::SmallDenseSet<StringRef> nameSet;
357 nameSet.reserve(names.size() + ignoredAttrs.size());
358 nameSet.insert(names.begin(), names.end());
359 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
360 return llvm::any_of(op->getAttrs(), [&](auto namedAttr) {
361 return !nameSet.contains(namedAttr.getName());
362 });
363}
364
365void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
366 // If the wire has an optional 'name' attribute, use it.
367 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
368 if (nameAttr && !nameAttr.getValue().empty())
369 setNameFn(getResult(), nameAttr.getValue());
370}
371
372std::optional<size_t> WireOp::getTargetResultIndex() { return 0; }
373
374OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
375 // If the wire has no additional attributes, no name, and no symbol, just
376 // forward its input.
377 if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() &&
378 !getInnerSymAttr())
379 return getInput();
380 return {};
381}
382
383LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
384 // Block if the wire has any attributes.
385 if (hasAdditionalAttributes(wire, {"sv.namehint"}))
386 return failure();
387
388 // If the wire has a symbol, then we can't delete it.
389 if (wire.getInnerSymAttr())
390 return failure();
391
392 // If the wire has a name or an `sv.namehint` attribute, propagate it as an
393 // `sv.namehint` to the expression.
394 if (auto *inputOp = wire.getInput().getDefiningOp())
395 if (auto name = chooseName(wire, inputOp))
396 rewriter.modifyOpInPlace(inputOp,
397 [&] { inputOp->setAttr("sv.namehint", name); });
398
399 rewriter.replaceOp(wire, wire.getInput());
400 return success();
401}
402
403//===----------------------------------------------------------------------===//
404// AggregateConstantOp
405//===----------------------------------------------------------------------===//
406
407static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) {
408 // If this is a type alias, get the underlying type.
409 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
410 type = typeAlias.getCanonicalType();
411
412 if (auto structType = dyn_cast<StructType>(type)) {
413 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
414 if (!arrayAttr)
415 return op->emitOpError("expected array attribute for constant of type ")
416 << type;
417 if (structType.getElements().size() != arrayAttr.size())
418 return op->emitOpError("array attribute (")
419 << arrayAttr.size() << ") has wrong size for struct constant ("
420 << structType.getElements().size() << ")";
421
422 for (auto [attr, fieldInfo] :
423 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
424 if (failed(checkAttributes(op, attr, fieldInfo.type)))
425 return failure();
426 }
427 } else if (auto arrayType = dyn_cast<ArrayType>(type)) {
428 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
429 if (!arrayAttr)
430 return op->emitOpError("expected array attribute for constant of type ")
431 << type;
432 if (arrayType.getNumElements() != arrayAttr.size())
433 return op->emitOpError("array attribute (")
434 << arrayAttr.size() << ") has wrong size for array constant ("
435 << arrayType.getNumElements() << ")";
436
437 auto elementType = arrayType.getElementType();
438 for (auto attr : arrayAttr.getValue()) {
439 if (failed(checkAttributes(op, attr, elementType)))
440 return failure();
441 }
442 } else if (auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
443 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
444 if (!arrayAttr)
445 return op->emitOpError("expected array attribute for constant of type ")
446 << type;
447 auto elementType = arrayType.getElementType();
448 if (arrayType.getNumElements() != arrayAttr.size())
449 return op->emitOpError("array attribute (")
450 << arrayAttr.size()
451 << ") has wrong size for unpacked array constant ("
452 << arrayType.getNumElements() << ")";
453
454 for (auto attr : arrayAttr.getValue()) {
455 if (failed(checkAttributes(op, attr, elementType)))
456 return failure();
457 }
458 } else if (auto enumType = dyn_cast<EnumType>(type)) {
459 auto stringAttr = dyn_cast<StringAttr>(attr);
460 if (!stringAttr)
461 return op->emitOpError("expected string attribute for constant of type ")
462 << type;
463 } else if (auto intType = dyn_cast<IntegerType>(type)) {
464 // Check the attribute kind is correct.
465 auto intAttr = dyn_cast<IntegerAttr>(attr);
466 if (!intAttr)
467 return op->emitOpError("expected integer attribute for constant of type ")
468 << type;
469 // Check the bitwidth is correct.
470 if (intAttr.getValue().getBitWidth() != intType.getWidth())
471 return op->emitOpError("hw.constant attribute bitwidth "
472 "doesn't match return type");
473 } else if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
474 if (typedAttr.getType() != type)
475 return op->emitOpError("typed attr doesn't match the return type ")
476 << type;
477 } else {
478 return op->emitOpError("unknown element type ") << type;
479 }
480 return success();
481}
482
483LogicalResult AggregateConstantOp::verify() {
484 return checkAttributes(*this, getFieldsAttr(), getType());
485}
486
487OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); }
488
489//===----------------------------------------------------------------------===//
490// ParamValueOp
491//===----------------------------------------------------------------------===//
492
493static ParseResult parseParamValue(OpAsmParser &p, Attribute &value,
494 Type &resultType) {
495 if (p.parseType(resultType) || p.parseEqual() ||
496 p.parseAttribute(value, resultType))
497 return failure();
498 return success();
499}
500
501static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value,
502 Type resultType) {
503 p << resultType << " = ";
504 p.printAttributeWithoutType(value);
505}
506
507LogicalResult ParamValueOp::verify() {
508 // Check that the attribute expression is valid in this module.
510 getValue(), (*this)->getParentOfType<hw::HWModuleOp>(), *this);
511}
512
513OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
514 assert(adaptor.getOperands().empty() && "hw.param.value has no operands");
515 return getValueAttr();
516}
517
518//===----------------------------------------------------------------------===//
519// HWModuleOp
520//===----------------------------------------------------------------------===//
521
522/// Return true if isAnyModule or instance.
523bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) {
524 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
525}
526
527/// Return the signature for a module as a function type from the module itself
528/// or from an hw::InstanceOp.
529FunctionType hw::getModuleType(Operation *moduleOrInstance) {
530 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
531 .Case<InstanceOp, InstanceChoiceOp>([](auto instance) {
532 SmallVector<Type> inputs(instance->getOperandTypes());
533 SmallVector<Type> results(instance->getResultTypes());
534 return FunctionType::get(instance->getContext(), inputs, results);
535 })
536 .Case<HWModuleLike>(
537 [](auto mod) { return mod.getHWModuleType().getFuncType(); })
538 .Default([](Operation *op) {
539 return cast<FunctionType>(
540 cast<mlir::FunctionOpInterface>(op).getFunctionType());
541 });
542}
543
544/// Return the name to use for the Verilog module that we're referencing
545/// here. This is typically the symbol, but can be overridden with the
546/// verilogName attribute.
547StringAttr hw::getVerilogModuleNameAttr(Operation *module) {
548 auto nameAttr = module->getAttrOfType<StringAttr>("verilogName");
549 if (nameAttr)
550 return nameAttr;
551
552 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
553}
554
555template <typename ModuleTy>
556static void
557buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
558 const ModulePortInfo &ports, ArrayAttr parameters,
559 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
560 using namespace mlir::function_interface_impl;
561
562 // Add an attribute for the name.
563 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
564
565 SmallVector<Attribute> perPortAttrs;
566 SmallVector<ModulePort> portTypes;
567
568 for (auto elt : ports) {
569 portTypes.push_back(elt);
570 llvm::SmallVector<NamedAttribute> portAttrs;
571 if (elt.attrs)
572 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
573 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
574 }
575
576 // Allow clients to pass in null for the parameters list.
577 if (!parameters)
578 parameters = builder.getArrayAttr({});
579
580 // Record the argument and result types as an attribute.
581 auto type = ModuleType::get(builder.getContext(), portTypes);
582 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
583 TypeAttr::get(type));
584 result.addAttribute("per_port_attrs",
585 arrayOrEmpty(builder.getContext(), perPortAttrs));
586 result.addAttribute("parameters", parameters);
587 if (!comment)
588 comment = builder.getStringAttr("");
589 result.addAttribute("comment", comment);
590 result.addAttributes(attributes);
591 result.addRegion();
592}
593
594/// Internal implementation of argument/result insertion and removal on modules.
596 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
597 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
598 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
599 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
600 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
601 SmallVector<Location> &newArgLocs, Block *body = nullptr) {
602
603#ifndef NDEBUG
604 // Check that the `insertArgs` and `removeArgs` indices are in ascending
605 // order.
606 assert(llvm::is_sorted(insertArgs,
607 [](auto &a, auto &b) { return a.first < b.first; }) &&
608 "insertArgs must be in ascending order");
609 assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) &&
610 "removeArgs must be in ascending order");
611#endif
612
613 auto oldArgCount = oldArgTypes.size();
614 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
615 assert((int)newArgCount >= 0);
616
617 newArgNames.reserve(newArgCount);
618 newArgTypes.reserve(newArgCount);
619 newArgAttrs.reserve(newArgCount);
620 newArgLocs.reserve(newArgCount);
621
622 auto exportPortAttrName = StringAttr::get(context, "hw.exportPort");
623 auto emptyDictAttr = DictionaryAttr::get(context, {});
624 auto unknownLoc = UnknownLoc::get(context);
625
626 BitVector erasedIndices;
627 if (body)
628 erasedIndices.resize(oldArgCount + insertArgs.size());
629
630 for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
631 // Insert new ports at this position.
632 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
633 auto port = insertArgs[0].second;
634 if (port.dir == ModulePort::Direction::InOut &&
635 !isa<InOutType>(port.type))
636 port.type = InOutType::get(port.type);
637 auto sym = port.getSym();
638 Attribute attr =
639 (sym && !sym.empty())
640 ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
641 : emptyDictAttr;
642 newArgNames.push_back(port.name);
643 newArgTypes.push_back(port.type);
644 newArgAttrs.push_back(attr);
645 insertArgs = insertArgs.drop_front();
646 LocationAttr loc = port.loc ? port.loc : unknownLoc;
647 newArgLocs.push_back(loc);
648 if (body)
649 body->insertArgument(idx++, port.type, loc);
650 }
651 if (argIdx == oldArgCount)
652 break;
653
654 // Migrate the old port at this position.
655 bool removed = false;
656 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
657 removeArgs = removeArgs.drop_front();
658 removed = true;
659 }
660
661 if (removed) {
662 if (body)
663 erasedIndices.set(idx);
664 } else {
665 newArgNames.push_back(oldArgNames[argIdx]);
666 newArgTypes.push_back(oldArgTypes[argIdx]);
667 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
668 : oldArgAttrs[argIdx]);
669 newArgLocs.push_back(oldArgLocs[argIdx]);
670 }
671 }
672
673 if (body)
674 body->eraseArguments(erasedIndices);
675
676 assert(newArgNames.size() == newArgCount);
677 assert(newArgTypes.size() == newArgCount);
678 assert(newArgAttrs.size() == newArgCount);
679 assert(newArgLocs.size() == newArgCount);
680}
681
682/// Insert and remove ports of a module. The insertion and removal indices must
683/// be in ascending order. The indices refer to the port positions before any
684/// insertion or removal occurs. Ports inserted at the same index will appear in
685/// the module in the same order as they were listed in the `insert*` array.
686///
687/// The operation must be any of the module-like operations.
688///
689/// This is marked deprecated as it's only used from HandshakeToHW and
690/// PortConverter and is likely broken and not currently tested. Users of this
691/// are still written dealing with input and output ports separately, which is
692/// an old and broken style.
693[[deprecated]] static void
694modifyModulePorts(Operation *op,
695 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
696 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
697 ArrayRef<unsigned> removeInputs,
698 ArrayRef<unsigned> removeOutputs, Block *body = nullptr) {
699 auto moduleOp = cast<HWModuleLike>(op);
700 auto *context = moduleOp.getContext();
701
702 // Dig up the old argument and result data.
703 auto oldArgNames = moduleOp.getInputNames();
704 auto oldArgTypes = moduleOp.getInputTypes();
705 auto oldArgAttrs = moduleOp.getAllInputAttrs();
706 auto oldArgLocs = moduleOp.getInputLocs();
707
708 auto oldResultNames = moduleOp.getOutputNames();
709 auto oldResultTypes = moduleOp.getOutputTypes();
710 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
711 auto oldResultLocs = moduleOp.getOutputLocs();
712
713 // Modify the ports.
714 SmallVector<Attribute> newArgNames, newResultNames;
715 SmallVector<Type> newArgTypes, newResultTypes;
716 SmallVector<Attribute> newArgAttrs, newResultAttrs;
717 SmallVector<Location> newArgLocs, newResultLocs;
718
719 modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames,
720 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
721 newArgTypes, newArgAttrs, newArgLocs, body);
722
723 modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames,
724 oldResultTypes, oldResultAttrs, oldResultLocs,
725 newResultNames, newResultTypes, newResultAttrs,
726 newResultLocs);
727
728 // Update the module operation types and attributes.
729 auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
730 auto modty = detail::fnToMod(fnty, newArgNames, newResultNames);
731 moduleOp.setHWModuleType(modty);
732 moduleOp.setAllInputAttrs(newArgAttrs);
733 moduleOp.setAllOutputAttrs(newResultAttrs);
734
735 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
736 moduleOp.setAllPortLocs(newArgLocs);
737}
738
739void HWModuleOp::build(OpBuilder &builder, OperationState &result,
740 StringAttr name, const ModulePortInfo &ports,
741 ArrayAttr parameters,
742 ArrayRef<NamedAttribute> attributes, StringAttr comment,
743 bool shouldEnsureTerminator) {
744 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
745 comment);
746
747 // Create a region and a block for the body.
748 auto *bodyRegion = result.regions[0].get();
749 Block *body = new Block();
750 bodyRegion->push_back(body);
751
752 // Add arguments to the body block.
753 auto unknownLoc = builder.getUnknownLoc();
754 for (auto port : ports.getInputs()) {
755 auto loc = port.loc ? Location(port.loc) : unknownLoc;
756 auto type = port.type;
757 if (port.isInOut() && !isa<InOutType>(type))
758 type = InOutType::get(type);
759 body->addArgument(type, loc);
760 }
761
762 // Add result ports attribute.
763 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
764 SmallVector<Attribute> resultLocs;
765 for (auto port : ports.getOutputs())
766 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
767 result.addAttribute("result_locs", builder.getArrayAttr(resultLocs));
768
769 if (shouldEnsureTerminator)
770 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
771}
772
773void HWModuleOp::build(OpBuilder &builder, OperationState &result,
774 StringAttr name, ArrayRef<PortInfo> ports,
775 ArrayAttr parameters,
776 ArrayRef<NamedAttribute> attributes,
777 StringAttr comment) {
778 build(builder, result, name, ModulePortInfo(ports), parameters, attributes,
779 comment);
780}
781
782void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
783 StringAttr name, const ModulePortInfo &ports,
784 HWModuleBuilder modBuilder, ArrayAttr parameters,
785 ArrayRef<NamedAttribute> attributes,
786 StringAttr comment) {
787 build(builder, odsState, name, ports, parameters, attributes, comment,
788 /*shouldEnsureTerminator=*/false);
789 auto *bodyRegion = odsState.regions[0].get();
790 OpBuilder::InsertionGuard guard(builder);
791 auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion);
792 builder.setInsertionPointToEnd(&bodyRegion->front());
793 modBuilder(builder, accessor);
794 // Create output operands.
795 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
796 hw::OutputOp::create(builder, odsState.location, outputOperands);
797}
798
799void HWModuleOp::modifyPorts(
800 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
801 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
802 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
803 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
804 eraseOutputs);
805}
806
807/// Return the name to use for the Verilog module that we're referencing
808/// here. This is typically the symbol, but can be overridden with the
809/// verilogName attribute.
810StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
811 if (auto vName = getVerilogNameAttr())
812 return vName;
813
814 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
815}
816
817StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
818 if (auto vName = getVerilogNameAttr()) {
819 return vName;
820 }
821 return (*this)->getAttrOfType<StringAttr>(
822 ::mlir::SymbolTable::getSymbolAttrName());
823}
824
825void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
826 StringAttr name, const ModulePortInfo &ports,
827 StringRef verilogName, ArrayAttr parameters,
828 ArrayRef<NamedAttribute> attributes) {
829 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
830 attributes, {});
831
832 // Add the port locations.
833 LocationAttr unknownLoc = builder.getUnknownLoc();
834 SmallVector<Attribute> portLocs;
835 for (auto elt : ports)
836 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
837 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
838
839 if (!verilogName.empty())
840 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
841}
842
843void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
844 StringAttr name, ArrayRef<PortInfo> ports,
845 StringRef verilogName, ArrayAttr parameters,
846 ArrayRef<NamedAttribute> attributes) {
847 build(builder, result, name, ModulePortInfo(ports), verilogName, parameters,
848 attributes);
849}
850
851void HWModuleExternOp::modifyPorts(
852 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
853 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
854 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
855 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
856 eraseOutputs);
857}
858
859void HWModuleExternOp::appendOutputs(
860 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
861
862void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
863 FlatSymbolRefAttr genKind, StringAttr name,
864 const ModulePortInfo &ports,
865 StringRef verilogName, ArrayAttr parameters,
866 ArrayRef<NamedAttribute> attributes) {
867 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
868 attributes, {});
869 // Add the port locations.
870 LocationAttr unknownLoc = builder.getUnknownLoc();
871 SmallVector<Attribute> portLocs;
872 for (auto elt : ports)
873 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
874 result.addAttribute("port_locs", builder.getArrayAttr(portLocs));
875
876 result.addAttribute("generatorKind", genKind);
877 if (!verilogName.empty())
878 result.addAttribute("verilogName", builder.getStringAttr(verilogName));
879}
880
881void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
882 FlatSymbolRefAttr genKind, StringAttr name,
883 ArrayRef<PortInfo> ports, StringRef verilogName,
884 ArrayAttr parameters,
885 ArrayRef<NamedAttribute> attributes) {
886 build(builder, result, genKind, name, ModulePortInfo(ports), verilogName,
887 parameters, attributes);
888}
889
890void HWModuleGeneratedOp::modifyPorts(
891 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
892 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
893 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
894 modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs,
895 eraseOutputs);
896}
897
898void HWModuleGeneratedOp::appendOutputs(
899 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
900
901static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
902 for (auto &argAttr : attrs)
903 if (argAttr.getName() == name)
904 return true;
905 return false;
906}
907
908template <typename ModuleTy>
909static ParseResult parseHWModuleOp(OpAsmParser &parser,
910 OperationState &result) {
911
912 using namespace mlir::function_interface_impl;
913 auto builder = parser.getBuilder();
914 auto loc = parser.getCurrentLocation();
915
916 // Parse the visibility attribute.
917 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
918
919 // Parse the name as a symbol.
920 StringAttr nameAttr;
921 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
922 result.attributes))
923 return failure();
924
925 // Parse the generator information.
926 FlatSymbolRefAttr kindAttr;
927 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
928 if (parser.parseComma() ||
929 parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) {
930 return failure();
931 }
932 }
933
934 // Parse the parameters.
935 ArrayAttr parameters;
936 if (parseOptionalParameterList(parser, parameters))
937 return failure();
938
939 SmallVector<module_like_impl::PortParse> ports;
940 TypeAttr modType;
941 if (failed(module_like_impl::parseModuleSignature(parser, ports, modType)))
942 return failure();
943
944 // Parse the attribute dict.
945 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
946 return failure();
947
948 if (hasAttribute("parameters", result.attributes)) {
949 parser.emitError(loc, "explicit `parameters` attributes not allowed");
950 return failure();
951 }
952
953 result.addAttribute("parameters", parameters);
954 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
955
956 // Convert the specified array of dictionary attrs (which may have null
957 // entries) to an ArrayAttr of dictionaries.
958 SmallVector<Attribute> attrs;
959 for (auto &port : ports)
960 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
961 // Add the attributes to the ports.
962 auto nonEmptyAttrsFn = [](Attribute attr) {
963 return attr && !cast<DictionaryAttr>(attr).empty();
964 };
965 if (llvm::any_of(attrs, nonEmptyAttrsFn))
966 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
967 builder.getArrayAttr(attrs));
968
969 // Add the port locations.
970 auto unknownLoc = builder.getUnknownLoc();
971 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
972 return attr && cast<Location>(attr) != unknownLoc;
973 };
974 SmallVector<Attribute> locs;
975 StringAttr portLocsAttrName;
976 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
977 // Plain modules only store the output port locations, as the input port
978 // locations will be stored in the basic block arguments.
979 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
980 for (auto &port : ports)
981 if (port.direction == ModulePort::Direction::Output)
982 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
983 } else {
984 // All other modules store all port locations in a single array.
985 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
986 for (auto &port : ports)
987 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
988 }
989 if (llvm::any_of(locs, nonEmptyLocsFn))
990 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
991
992 // Add the entry block arguments.
993 SmallVector<OpAsmParser::Argument, 4> entryArgs;
994 for (auto &port : ports)
995 if (port.direction != ModulePort::Direction::Output)
996 entryArgs.push_back(port);
997
998 // Parse the optional function body.
999 auto *body = result.addRegion();
1000 if (std::is_same_v<ModuleTy, HWModuleOp>) {
1001 if (parser.parseRegion(*body, entryArgs))
1002 return failure();
1003
1004 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1005 }
1006 return success();
1007}
1008
1009ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1010 return parseHWModuleOp<HWModuleOp>(parser, result);
1011}
1012
1013ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1014 OperationState &result) {
1015 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1016}
1017
1018ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1019 OperationState &result) {
1020 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1021}
1022
1023FunctionType getHWModuleOpType(Operation *op) {
1024 if (auto mod = dyn_cast<HWModuleLike>(op))
1025 return mod.getHWModuleType().getFuncType();
1026 return cast<FunctionType>(
1027 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1028}
1029
/// Print the custom-assembly pieces shared by the module ops: a leading
/// space, the visibility (if the attribute is present), the symbol name, the
/// generator kind for `HWModuleGeneratedOp`, the optional parameter list and
/// finally the remaining attributes with the already-printed ones omitted.
///
/// NOTE(review): original line 1049 is absent from this listing; presumably
/// the module signature is printed there -- confirm against the upstream file.
1030template <typename ModuleTy>
1031static void printModuleOp(OpAsmPrinter &p, ModuleTy mod) {
1032 p << ' ';
1033 // Print the visibility of the module.
1034 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1035 if (auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1036 visibilityAttrName))
1037 p << visibility.getValue() << ' ';
1038
1039 // Print the operation and the function name.
1040 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
 // Generated modules additionally print ", @<generatorKind>".
1041 if (auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1042 p << ", ";
1043 p.printSymbolName(gen.getGeneratorKind());
1044 }
1045
1046 // Print the parameter list if present.
1047 printOptionalParameterList(p, mod.getOperation(), mod.getParameters());
1048
1050
 // Collect the attributes that must not show up in the trailing attribute
 // dictionary.
1051 SmallVector<StringRef, 3> omittedAttrs;
1052 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1053 omittedAttrs.push_back("generatorKind");
 // Plain modules keep result locations; the other kinds keep all port
 // locations under a different attribute name.
1054 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1055 omittedAttrs.push_back(mod.getResultLocsAttrName());
1056 else
1057 omittedAttrs.push_back(mod.getPortLocsAttrName());
1058 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1059 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1060 omittedAttrs.push_back(mod.getParametersAttrName());
1061 omittedAttrs.push_back(visibilityAttrName);
 // An empty "comment" string attribute is not printed either.
1062 if (auto cmt =
1063 mod.getOperation()->template getAttrOfType<StringAttr>("comment"))
1064 if (cmt.getValue().empty())
1065 omittedAttrs.push_back("comment");
1066
1067 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1068 omittedAttrs);
1069}
1070
1071void HWModuleExternOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1072void HWModuleGeneratedOp::print(OpAsmPrinter &p) { printModuleOp(p, *this); }
1073
1074void HWModuleOp::print(OpAsmPrinter &p) {
1075 printModuleOp(p, *this);
1076
1077 // Print the body if this is not an external function.
1078 Region &body = getBody();
1079 if (!body.empty()) {
1080 p << " ";
1081 p.printRegion(body, /*printEntryBlockArgs=*/false,
1082 /*printBlockTerminators=*/true);
1083 }
1084}
1085
1086static LogicalResult verifyModuleCommon(HWModuleLike module) {
1087 assert(isa<HWModuleLike>(module) &&
1088 "verifier hook should only be called on modules");
1089
1090 SmallPtrSet<Attribute, 4> paramNames;
1091
1092 // Check parameter default values are sensible.
1093 for (auto param : module->getAttrOfType<ArrayAttr>("parameters")) {
1094 auto paramAttr = cast<ParamDeclAttr>(param);
1095
1096 // Check that we don't have any redundant parameter names. These are
1097 // resolved by string name: reuse of the same name would cause ambiguities.
1098 if (!paramNames.insert(paramAttr.getName()).second)
1099 return module->emitOpError("parameter ")
1100 << paramAttr << " has the same name as a previous parameter";
1101
1102 // Default values are allowed to be missing, check them if present.
1103 auto value = paramAttr.getValue();
1104 if (!value)
1105 continue;
1106
1107 auto typedValue = dyn_cast<TypedAttr>(value);
1108 if (!typedValue)
1109 return module->emitOpError("parameter ")
1110 << paramAttr << " should have a typed value; has value " << value;
1111
1112 if (typedValue.getType() != paramAttr.getType())
1113 return module->emitOpError("parameter ")
1114 << paramAttr << " should have type " << paramAttr.getType()
1115 << "; has type " << typedValue.getType();
1116
1117 // Verify that this is a valid parameter value, disallowing parameter
1118 // references. We could allow parameters to refer to each other in the
1119 // future with lexical ordering if there is a need.
1120 if (failed(checkParameterInContext(value, module, module,
1121 /*disallowParamRefs=*/true)))
1122 return failure();
1123 }
1124 return success();
1125}
1126
1127LogicalResult HWModuleOp::verify() {
1128 if (failed(verifyModuleCommon(*this)))
1129 return failure();
1130
1131 auto type = getModuleType();
1132 auto *body = getBodyBlock();
1133
1134 // Verify the number of block arguments.
1135 auto numInputs = type.getNumInputs();
1136 if (body->getNumArguments() != numInputs)
1137 return emitOpError("entry block must have ")
1138 << numInputs << " arguments to match module signature";
1139
1140 return success();
1141}
1142
1143LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); }
1144
/// Insert a new input port of type `ty` at input position `index`. The
/// requested `name` is made unique with respect to the existing port names.
/// Returns the name that was actually used together with the block argument
/// at `index` in the body block.
///
/// NOTE(review): original line 1159 is absent from this listing; presumably
/// it sets the direction of `port` -- confirm against the upstream file.
1145std::pair<StringAttr, BlockArgument>
1146HWModuleOp::insertInput(unsigned index, StringAttr name, Type ty) {
1147 // Find a unique name for the wire.
1148 Namespace ns;
1149 auto ports = getPortList();
 // Seed the namespace with every existing port name so the new name cannot
 // collide with one of them.
1150 for (auto port : ports)
1151 ns.newName(port.name.getValue());
1152 auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue()));
1153
1154 Block *body = getBodyBlock();
1155
1156 // Create the new port.
1157 PortInfo port;
1158 port.name = nameAttr;
1160 port.type = ty;
1161 modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, {},
1162 body);
1163
1164 // Hand back the block argument found at `index` after the port update.
1165 return {nameAttr, body->getArgument(index)};
1166}
1167
/// Insert the given (name, value) pairs as new output ports starting at
/// output position `index`, and add the values as operands of the body's
/// terminating `OutputOp` at the same positions.
///
/// NOTE(review): original line 1179 is absent from this listing; presumably
/// it sets the direction of `port` -- confirm against the upstream file.
1168void HWModuleOp::insertOutputs(unsigned index,
1169 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1170
1171 auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1172 assert(index <= output->getNumOperands() && "invalid output index");
1173
1174 // Rewrite the port list of the module.
 // Every new port is registered at the same `index`.
1175 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1176 for (auto &[name, value] : outputs) {
1177 PortInfo port;
1178 port.name = name;
1180 port.type = value.getType();
1181 indexedNewPorts.emplace_back(index, port);
1182 }
1183 modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {},
1184 getBodyBlock());
1185
1186 // Rewrite the output op.
 // `index` advances per operand so the values keep the order of `outputs`.
1187 for (auto &[name, value] : outputs)
1188 output->insertOperands(index++, value);
1189}
1190
1191void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1192 return insertOutputs(getNumOutputPorts(), outputs);
1193}
1194
1195void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1196 mlir::OpAsmSetValueNameFn setNameFn) {
1197 getAsmBlockArgumentNamesImpl(region, setNameFn);
1198}
1199
1200void HWModuleExternOp::getAsmBlockArgumentNames(
1201 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1202 getAsmBlockArgumentNamesImpl(region, setNameFn);
1203}
1204
1205template <typename ModTy>
1206static SmallVector<Location> getAllPortLocs(ModTy module) {
1207 auto locs = module.getPortLocs();
1208 if (locs) {
1209 SmallVector<Location> retval;
1210 retval.reserve(locs->size());
1211 for (auto l : *locs)
1212 retval.push_back(cast<Location>(l));
1213 // Either we have a length of 0 or the correct length
1214 assert(!locs->size() || locs->size() == module.getNumPorts());
1215 return retval;
1216 }
1217 return SmallVector<Location>(module.getNumPorts(),
1218 UnknownLoc::get(module.getContext()));
1219}
1220
1221SmallVector<Location> HWModuleOp::getAllPortLocs() {
1222 SmallVector<Location> portLocs;
1223 portLocs.reserve(getNumPorts());
1224 auto resultLocs = getResultLocsAttr();
1225 unsigned inputCount = 0;
1226 auto modType = getModuleType();
1227 auto unknownLoc = UnknownLoc::get(getContext());
1228 auto *body = getBodyBlock();
1229 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1230 if (modType.isOutput(i)) {
1231 auto loc = resultLocs
1232 ? cast<Location>(
1233 resultLocs.getValue()[portLocs.size() - inputCount])
1234 : unknownLoc;
1235 portLocs.push_back(loc);
1236 } else {
1237 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1238 portLocs.push_back(loc);
1239 ++inputCount;
1240 }
1241 }
1242 return portLocs;
1243}
1244
1245SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1246 return ::getAllPortLocs(*this);
1247}
1248
1249SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1250 return ::getAllPortLocs(*this);
1251}
1252
1253void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1254 SmallVector<Attribute> resultLocs;
1255 unsigned inputCount = 0;
1256 auto modType = getModuleType();
1257 auto *body = getBodyBlock();
1258 for (unsigned i = 0, e = getNumPorts(); i < e; ++i) {
1259 if (modType.isOutput(i))
1260 resultLocs.push_back(locs[i]);
1261 else
1262 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1263 }
1264 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1265}
1266
1267void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1268 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1269}
1270
1271void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1272 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1273}
1274
1275template <typename ModTy>
1276static void setAllPortNames(ArrayRef<Attribute> names, ModTy module) {
1277 auto numInputs = module.getNumInputPorts();
1278 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1279 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1280 auto oldType = module.getModuleType();
1281 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1282 oldType.getPorts().end());
1283 for (size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1284 newPorts[i].name = cast<StringAttr>(names[i]);
1285 auto newType = ModuleType::get(module.getContext(), newPorts);
1286 module.setModuleType(newType);
1287}
1288
1289void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1290 ::setAllPortNames(names, *this);
1291}
1292
1293void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1294 ::setAllPortNames(names, *this);
1295}
1296
1297void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1298 ::setAllPortNames(names, *this);
1299}
1300
1301ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1302 auto attrs = getPerPortAttrs();
1303 if (attrs && !attrs->empty())
1304 return attrs->getValue();
1305 return {};
1306}
1307
1308ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1309 auto attrs = getPerPortAttrs();
1310 if (attrs && !attrs->empty())
1311 return attrs->getValue();
1312 return {};
1313}
1314
1315ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1316 auto attrs = getPerPortAttrs();
1317 if (attrs && !attrs->empty())
1318 return attrs->getValue();
1319 return {};
1320}
1321
1322void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1323 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1324}
1325
1326void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1327 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1328}
1329
1330void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1331 setPerPortAttrsAttr(arrayOrEmpty(getContext(), attrs));
1332}
1333
1334void HWModuleOp::removeAllPortAttrs() {
1335 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1336}
1337
1338void HWModuleExternOp::removeAllPortAttrs() {
1339 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1340}
1341
1342void HWModuleGeneratedOp::removeAllPortAttrs() {
1343 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1344}
1345
1346// This probably does really unexpected stuff when you change the number of
1347
1348template <typename ModTy>
1349static void setHWModuleType(ModTy &mod, ModuleType type) {
1350 auto argAttrs = mod.getAllInputAttrs();
1351 auto resAttrs = mod.getAllOutputAttrs();
1352 mod.setModuleTypeAttr(TypeAttr::get(type));
1353 unsigned newNumArgs = type.getNumInputs();
1354 unsigned newNumResults = type.getNumOutputs();
1355
1356 auto emptyDict = DictionaryAttr::get(mod.getContext());
1357 argAttrs.resize(newNumArgs, emptyDict);
1358 resAttrs.resize(newNumResults, emptyDict);
1359
1360 SmallVector<Attribute> attrs;
1361 attrs.append(argAttrs.begin(), argAttrs.end());
1362 attrs.append(resAttrs.begin(), resAttrs.end());
1363
1364 if (attrs.empty())
1365 return mod.removeAllPortAttrs();
1366 mod.setAllPortAttrs(attrs);
1367}
1368
1369void HWModuleOp::setHWModuleType(ModuleType type) {
1370 return ::setHWModuleType(*this, type);
1371}
1372
1373void HWModuleExternOp::setHWModuleType(ModuleType type) {
1374 return ::setHWModuleType(*this, type);
1375}
1376
1377void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1378 return ::setHWModuleType(*this, type);
1379}
1380
1381/// Lookup the generator for the symbol. This returns null on
1382/// invalid IR.
1383Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1384 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1385 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1386}
1387
1388LogicalResult
1389HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1390 auto *referencedKind =
1391 symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr());
1392
1393 if (referencedKind == nullptr)
1394 return emitError("Cannot find generator definition '")
1395 << getGeneratorKind() << "'";
1396
1397 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1398 return emitError("Symbol resolved to '")
1399 << referencedKind->getName()
1400 << "' which is not a HWGeneratorSchemaOp";
1401
1402 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1403 auto paramRef = referencedKindOp.getRequiredAttrs();
1404 auto dict = (*this)->getAttrDictionary();
1405 for (auto str : paramRef) {
1406 auto strAttr = dyn_cast<StringAttr>(str);
1407 if (!strAttr)
1408 return emitError("Unknown attribute type, expected a string");
1409 if (!dict.get(strAttr.getValue()))
1410 return emitError("Missing attribute '") << strAttr.getValue() << "'";
1411 }
1412
1413 return success();
1414}
1415
1416LogicalResult HWModuleGeneratedOp::verify() {
1417 return verifyModuleCommon(*this);
1418}
1419
1420void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1421 mlir::Region &region, mlir::OpAsmSetValueNameFn setNameFn) {
1422 getAsmBlockArgumentNamesImpl(region, setNameFn);
1423}
1424
1425LogicalResult HWModuleOp::verifyBody() { return success(); }
1426
1427template <typename ModuleTy>
1428static SmallVector<PortInfo> getPortList(ModuleTy &mod) {
1429 auto modTy = mod.getHWModuleType();
1430 auto emptyDict = DictionaryAttr::get(mod.getContext());
1431 SmallVector<PortInfo> retval;
1432 auto locs = mod.getAllPortLocs();
1433 for (unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1434 LocationAttr loc = locs[i];
1435 DictionaryAttr attrs =
1436 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1437 if (!attrs)
1438 attrs = emptyDict;
1439 retval.push_back({modTy.getPorts()[i],
1440 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1441 : modTy.getInputIdForPortId(i),
1442 attrs, loc});
1443 }
1444 return retval;
1445}
1446
1447template <typename ModuleTy>
1448static PortInfo getPort(ModuleTy &mod, size_t idx) {
1449 auto modTy = mod.getHWModuleType();
1450 auto emptyDict = DictionaryAttr::get(mod.getContext());
1451 LocationAttr loc = mod.getPortLoc(idx);
1452 DictionaryAttr attrs =
1453 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1454 if (!attrs)
1455 attrs = emptyDict;
1456 return {modTy.getPorts()[idx],
1457 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1458 : modTy.getInputIdForPortId(idx),
1459 attrs, loc};
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// InstanceOp
1464//===----------------------------------------------------------------------===//
1465
1466/// Create a instance that refers to a known module.
1467void InstanceOp::build(OpBuilder &builder, OperationState &result,
1468 Operation *module, StringAttr name,
1469 ArrayRef<Value> inputs, ArrayAttr parameters,
1470 InnerSymAttr innerSym) {
1471 if (!parameters)
1472 parameters = builder.getArrayAttr({});
1473
1474 auto mod = cast<hw::HWModuleLike>(module);
1475 auto argNames = builder.getArrayAttr(mod.getInputNames());
1476 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1477
1478 // Try to resolve the parameterized module type. If failed, use the module's
1479 // parmeterized type. If the client doesn't fix this error, the verifier will
1480 // fail.
1481 ModuleType modType = mod.getHWModuleType();
1482 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1483 parameters, result.location, /*emitErrors=*/false);
1484 if (succeeded(resolvedModType))
1485 modType = *resolvedModType;
1486 FunctionType funcType = resolvedModType->getFuncType();
1487 build(builder, result, funcType.getResults(), name,
1488 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1489 argNames, resultNames, parameters, innerSym, /*doNotPrint=*/{});
1490}
1491
1492std::optional<size_t> InstanceOp::getTargetResultIndex() {
1493 // Inner symbols on instance operations target the op not any result.
1494 return std::nullopt;
1495}
1496
/// Symbol-use verification for an instance. The module name, inputs, result
/// types, argument/result names, parameters and the symbol table collection
/// are handed to a shared verification helper.
///
/// NOTE(review): original line 1498, which opens the `return <helper>(` call
/// these arguments belong to, is absent from this listing -- confirm the
/// helper's name against the upstream file.
1497LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1499 *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1500 getResultNames(), getParameters(), symbolTable);
1501}
1502
/// Verify the instance's parameters against the "parameters" attribute of the
/// enclosing `HWModuleOp`. An instance that is not nested in an `HWModuleOp`
/// is accepted without further checks.
///
/// NOTE(review): original lines 1509 and 1515 are absent from this listing.
/// Presumably 1509 declares the `emitError` callback that the lambda below
/// initializes and 1515 opens the `return <helper>(` call whose arguments
/// follow -- confirm against the upstream file.
1503LogicalResult InstanceOp::verify() {
1504 auto module = (*this)->getParentOfType<HWModuleOp>();
1505 if (!module)
1506 return success();
1507
1508 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
 // The callback starts an op error, lets `fn` fill it in, and attaches a note
 // pointing at the enclosing module when `fn` returns true.
1510 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1511 auto diag = emitOpError();
1512 if (fn(diag))
1513 diag.attachNote(module->getLoc()) << "module declared here";
1514 };
1516 getParameters(), moduleParameters, emitError);
1517}
1518
1519ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1520 StringAttr instanceNameAttr;
1521 InnerSymAttr innerSym;
1522 FlatSymbolRefAttr moduleNameAttr;
1523 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1524 SmallVector<Type, 1> inputsTypes, allResultTypes;
1525 ArrayAttr argNames, resultNames, parameters;
1526 auto noneType = parser.getBuilder().getType<NoneType>();
1527
1528 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1529 result.attributes))
1530 return failure();
1531
1532 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1533 // Parsing an optional symbol name doesn't fail, so no need to check the
1534 // result.
1535 if (parser.parseCustomAttributeWithFallback(innerSym))
1536 return failure();
1537 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1538 }
1539
1540 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1541 if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
1542 result.attributes) ||
1543 parser.getCurrentLocation(&parametersLoc) ||
1544 parseOptionalParameterList(parser, parameters) ||
1545 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1546 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1547 result.operands) ||
1548 parser.parseArrow() ||
1549 parseOutputPortList(parser, allResultTypes, resultNames) ||
1550 parser.parseOptionalAttrDict(result.attributes)) {
1551 return failure();
1552 }
1553
1554 result.addAttribute("argNames", argNames);
1555 result.addAttribute("resultNames", resultNames);
1556 result.addAttribute("parameters", parameters);
1557 result.addTypes(allResultTypes);
1558 return success();
1559}
1560
1561void InstanceOp::print(OpAsmPrinter &p) {
1562 p << ' ';
1563 p.printAttributeWithoutType(getInstanceNameAttr());
1564 if (auto attr = getInnerSymAttr()) {
1565 p << " sym ";
1566 attr.print(p);
1567 }
1568 p << ' ';
1569 p.printAttributeWithoutType(getModuleNameAttr());
1570 printOptionalParameterList(p, *this, getParameters());
1571 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1572 getArgNames());
1573 p << " -> ";
1574 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1575
1576 p.printOptionalAttrDict(
1577 (*this)->getAttrs(),
1578 /*elidedAttrs=*/{"instanceName",
1579 InnerSymbolTable::getInnerSymbolAttrName(), "moduleName",
1580 "argNames", "resultNames", "parameters"});
1581}
1582
1583//===----------------------------------------------------------------------===//
1584// InstanceChoiceOp
1585//===----------------------------------------------------------------------===//
1586
1587std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1588 // Inner symbols on instance operations target the op not any result.
1589 return std::nullopt;
1590}
1591
/// Symbol-use verification for an instance choice: every entry of the module
/// names attribute is checked in turn, and the first failing check fails the
/// whole verification.
///
/// NOTE(review): original line 1595, which opens the `if (failed(<helper>(`
/// call these arguments belong to, is absent from this listing -- confirm
/// the helper's name against the upstream file.
1592LogicalResult
1593InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1594 for (Attribute name : getModuleNamesAttr()) {
1596 *this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1597 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1598 return failure();
1599 }
1600 }
1601 return success();
1602}
1603
/// Verify the instance's parameters against the "parameters" attribute of the
/// enclosing `HWModuleOp`. An instance choice that is not nested in an
/// `HWModuleOp` is accepted without further checks.
///
/// NOTE(review): original lines 1610 and 1616 are absent from this listing.
/// Presumably 1610 declares the `emitError` callback that the lambda below
/// initializes and 1616 opens the `return <helper>(` call whose arguments
/// follow -- confirm against the upstream file.
1604LogicalResult InstanceChoiceOp::verify() {
1605 auto module = (*this)->getParentOfType<HWModuleOp>();
1606 if (!module)
1607 return success();
1608
1609 auto moduleParameters = module->getAttrOfType<ArrayAttr>("parameters");
 // The callback starts an op error, lets `fn` fill it in, and attaches a note
 // pointing at the enclosing module when `fn` returns true.
1611 [&](const std::function<bool(InFlightDiagnostic &)> &fn) {
1612 auto diag = emitOpError();
1613 if (fn(diag))
1614 diag.attachNote(module->getLoc()) << "module declared here";
1615 };
1617 getParameters(), moduleParameters, emitError);
1618}
1619
1620ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1621 OperationState &result) {
1622 StringAttr optionNameAttr;
1623 StringAttr instanceNameAttr;
1624 InnerSymAttr innerSym;
1625 SmallVector<Attribute> moduleNames;
1626 SmallVector<Attribute> caseNames;
1627 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1628 SmallVector<Type, 1> inputsTypes, allResultTypes;
1629 ArrayAttr argNames, resultNames, parameters;
1630 auto noneType = parser.getBuilder().getType<NoneType>();
1631
1632 if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
1633 result.attributes))
1634 return failure();
1635
1636 if (succeeded(parser.parseOptionalKeyword("sym"))) {
1637 // Parsing an optional symbol name doesn't fail, so no need to check the
1638 // result.
1639 if (parser.parseCustomAttributeWithFallback(innerSym))
1640 return failure();
1641 result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym);
1642 }
1643
1644 if (parser.parseKeyword("option") ||
1645 parser.parseAttribute(optionNameAttr, noneType, "optionName",
1646 result.attributes))
1647 return failure();
1648
1649 FlatSymbolRefAttr defaultModuleName;
1650 if (parser.parseAttribute(defaultModuleName))
1651 return failure();
1652 moduleNames.push_back(defaultModuleName);
1653
1654 while (succeeded(parser.parseOptionalKeyword("or"))) {
1655 FlatSymbolRefAttr moduleName;
1656 StringAttr targetName;
1657 if (parser.parseAttribute(moduleName) ||
1658 parser.parseOptionalKeyword("if") || parser.parseAttribute(targetName))
1659 return failure();
1660 moduleNames.push_back(moduleName);
1661 caseNames.push_back(targetName);
1662 }
1663
1664 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1665 if (parser.getCurrentLocation(&parametersLoc) ||
1666 parseOptionalParameterList(parser, parameters) ||
1667 parseInputPortList(parser, inputsOperands, inputsTypes, argNames) ||
1668 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1669 result.operands) ||
1670 parser.parseArrow() ||
1671 parseOutputPortList(parser, allResultTypes, resultNames) ||
1672 parser.parseOptionalAttrDict(result.attributes)) {
1673 return failure();
1674 }
1675
1676 result.addAttribute("moduleNames",
1677 ArrayAttr::get(parser.getContext(), moduleNames));
1678 result.addAttribute("caseNames",
1679 ArrayAttr::get(parser.getContext(), caseNames));
1680 result.addAttribute("argNames", argNames);
1681 result.addAttribute("resultNames", resultNames);
1682 result.addAttribute("parameters", parameters);
1683 result.addTypes(allResultTypes);
1684 return success();
1685}
1686
1687void InstanceChoiceOp::print(OpAsmPrinter &p) {
1688 p << ' ';
1689 p.printAttributeWithoutType(getInstanceNameAttr());
1690 if (auto attr = getInnerSymAttr()) {
1691 p << " sym ";
1692 attr.print(p);
1693 }
1694 p << " option " << getOptionNameAttr() << ' ';
1695
1696 auto moduleNames = getModuleNamesAttr();
1697 auto caseNames = getCaseNamesAttr();
1698 assert(moduleNames.size() == caseNames.size() + 1);
1699
1700 p.printAttributeWithoutType(moduleNames[0]);
1701 for (size_t i = 0, n = caseNames.size(); i < n; ++i) {
1702 p << " or ";
1703 p.printAttributeWithoutType(moduleNames[i + 1]);
1704 p << " if ";
1705 p.printAttributeWithoutType(caseNames[i]);
1706 }
1707
1708 printOptionalParameterList(p, *this, getParameters());
1709 printInputPortList(p, *this, getInputs(), getInputs().getTypes(),
1710 getArgNames());
1711 p << " -> ";
1712 printOutputPortList(p, *this, getResultTypes(), getResultNames());
1713
1714 p.printOptionalAttrDict(
1715 (*this)->getAttrs(),
1716 /*elidedAttrs=*/{"instanceName",
1717 InnerSymbolTable::getInnerSymbolAttrName(),
1718 "moduleNames", "caseNames", "argNames", "resultNames",
1719 "parameters", "optionName"});
1720}
1721
1722ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1723 SmallVector<Attribute> moduleNames;
1724 for (Attribute attr : getModuleNamesAttr()) {
1725 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1726 }
1727 return ArrayAttr::get(getContext(), moduleNames);
1728}
1729
1730//===----------------------------------------------------------------------===//
1731// HWOutputOp
1732//===----------------------------------------------------------------------===//
1733
1734/// Verify that the num of operands and types fit the declared results.
1735LogicalResult OutputOp::verify() {
1736 // Check that the we (hw.output) have the same number of operands as our
1737 // region has results.
1738 ModuleType modType;
1739 if (auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1740 modType = mod.getHWModuleType();
1741 else {
1742 emitOpError("must have a module parent");
1743 return failure();
1744 }
1745 auto modResults = modType.getOutputTypes();
1746 OperandRange outputValues = getOperands();
1747 if (modResults.size() != outputValues.size()) {
1748 emitOpError("must have same number of operands as region results.");
1749 return failure();
1750 }
1751
1752 // Check that the types of our operands and the region's results match.
1753 for (size_t i = 0, e = modResults.size(); i < e; ++i) {
1754 if (modResults[i] != outputValues[i].getType()) {
1755 emitOpError("output types must match module. In "
1756 "operand ")
1757 << i << ", expected " << modResults[i] << ", but got "
1758 << outputValues[i].getType() << ".";
1759 return failure();
1760 }
1761 }
1762
1763 return success();
1764}
1765
1766//===----------------------------------------------------------------------===//
1767// Other Operations
1768//===----------------------------------------------------------------------===//
1769
1770static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType,
1771 Type &idxType) {
1772 Type type;
1773 if (p.parseType(type))
1774 return p.emitError(p.getCurrentLocation(), "Expected type");
1775 auto arrType = type_dyn_cast<ArrayType>(type);
1776 if (!arrType)
1777 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
1778 srcType = type;
1779 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1780 idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
1781 return success();
1782}
1783
1784static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType,
1785 Type idxType) {
1786 p.printType(srcType);
1787}
1788
1789ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1790 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1791 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1792 Type elemType;
1793
1794 if (parser.parseOperandList(operands) ||
1795 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1796 parser.parseType(elemType))
1797 return failure();
1798
1799 if (operands.size() == 0)
1800 return parser.emitError(inputOperandsLoc,
1801 "Cannot construct an array of length 0");
1802
1803 // Check for optional result type: " -> <type>"
1804 Type resultType;
1805 if (parser.parseOptionalArrow().succeeded()) {
1806 if (parser.parseType(resultType))
1807 return failure();
1808 result.addTypes(resultType);
1809 } else {
1810 result.addTypes({ArrayType::get(elemType, operands.size())});
1811 }
1812
1813 for (auto operand : operands)
1814 if (parser.resolveOperand(operand, elemType, result.operands))
1815 return failure();
1816 return success();
1817}
1818
1819void ArrayCreateOp::print(OpAsmPrinter &p) {
1820 p << " ";
1821 p.printOperands(getInputs());
1822 p.printOptionalAttrDict((*this)->getAttrs());
1823 p << " : " << getInputs()[0].getType();
1824
1825 // Print optional result type if it's not the default constructed type
1826 Type expectedType =
1827 ArrayType::get(getInputs()[0].getType(), getInputs().size());
1828 if (getType() != expectedType)
1829 p << " -> " << getType();
1830}
1831
1832void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1833 ValueRange values) {
1834 assert(values.size() > 0 && "Cannot build array of zero elements");
1835 Type elemType = values[0].getType();
1836 assert(llvm::all_of(
1837 values,
1838 [elemType](Value v) -> bool { return v.getType() == elemType; }) &&
1839 "All values must have same type.");
1840 build(b, state, ArrayType::get(elemType, values.size()), values);
1841}
1842
1843LogicalResult ArrayCreateOp::verify() {
1844 unsigned returnSize = hw::type_cast<ArrayType>(getType()).getNumElements();
1845 if (getInputs().size() != returnSize)
1846 return failure();
1847 return success();
1848}
1849
1850OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1851 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) {
1852 return !isa_and_nonnull<IntegerAttr>(attr);
1853 }))
1854 return {};
1855 return ArrayAttr::get(getContext(), adaptor.getInputs());
1856}
1857
1858// Check whether an integer value is an offset from a base.
1859bool hw::isOffset(Value base, Value index, uint64_t offset) {
1860 if (auto constBase = base.getDefiningOp<hw::ConstantOp>()) {
1861 if (auto constIndex = index.getDefiningOp<hw::ConstantOp>()) {
1862 // If both values are a constant, check if index == base + offset.
1863 // To account for overflow, the addition is performed with an extra bit
1864 // and the offset is asserted to fit in the bit width of the base.
1865 auto baseValue = constBase.getValue();
1866 auto indexValue = constIndex.getValue();
1867
1868 unsigned bits = baseValue.getBitWidth();
1869 assert(bits == indexValue.getBitWidth() && "mismatched widths");
1870
1871 if (bits < 64 && offset >= (1ull << bits))
1872 return false;
1873
1874 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1875 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1876 return baseExt + offset == indexExt;
1877 }
1878 }
1879 return false;
1880}
1881
1882// Canonicalize a create of consecutive elements to a slice.
1883static LogicalResult foldCreateToSlice(ArrayCreateOp op,
1884 PatternRewriter &rewriter) {
1885 // Do not canonicalize create of get into a slice.
1886 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1887 if (arrayTy.getNumElements() <= 1)
1888 return failure();
1889 auto elemTy = arrayTy.getElementType();
1890
1891 // Check if create arguments are consecutive elements of the same array.
1892 // Attempt to break a create of gets into a sequence of consecutive intervals.
1893 struct Chunk {
1894 Value input;
1895 Value index;
1896 size_t size;
1897 };
1898 SmallVector<Chunk> chunks;
1899 for (Value value : llvm::reverse(op.getInputs())) {
1900 auto get = value.getDefiningOp<ArrayGetOp>();
1901 if (!get)
1902 return failure();
1903
1904 Value input = get.getInput();
1905 Value index = get.getIndex();
1906 if (!chunks.empty()) {
1907 auto &c = *chunks.rbegin();
1908 if (c.input == get.getInput() && isOffset(c.index, index, c.size)) {
1909 c.size++;
1910 continue;
1911 }
1912 }
1913
1914 chunks.push_back(Chunk{input, index, 1});
1915 }
1916
1917 // If there is a single slice, eliminate the create.
1918 if (chunks.size() == 1) {
1919 auto &chunk = chunks[0];
1920 rewriter.replaceOp(op, rewriter.createOrFold<ArraySliceOp>(
1921 op.getLoc(), arrayTy, chunk.input, chunk.index));
1922 return success();
1923 }
1924
1925 // If the number of chunks is significantly less than the number of
1926 // elements, replace the create with a concat of the identified slices.
1927 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1928 SmallVector<Value> slices;
1929 for (auto &chunk : llvm::reverse(chunks)) {
1930 auto sliceTy = ArrayType::get(elemTy, chunk.size);
1931 slices.push_back(rewriter.createOrFold<ArraySliceOp>(
1932 op.getLoc(), sliceTy, chunk.input, chunk.index));
1933 }
1934 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, arrayTy, slices);
1935 return success();
1936 }
1937
1938 return failure();
1939}
1940
1941LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op,
1942 PatternRewriter &rewriter) {
1943 if (succeeded(foldCreateToSlice(op, rewriter)))
1944 return success();
1945 return failure();
1946}
1947
1948Value ArrayCreateOp::getUniformElement() {
1949 if (!getInputs().empty() && llvm::all_equal(getInputs()))
1950 return getInputs()[0];
1951 return {};
1952}
1953
1954static std::optional<uint64_t> getUIntFromValue(Value value) {
1955 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1956 if (!idxOp)
1957 return std::nullopt;
1958 APInt idxAttr = idxOp.getValue();
1959 if (idxAttr.getBitWidth() > 64)
1960 return std::nullopt;
1961 return idxAttr.getLimitedValue();
1962}
1963
1964LogicalResult ArraySliceOp::verify() {
1965 unsigned inputSize =
1966 type_cast<ArrayType>(getInput().getType()).getNumElements();
1967 if (llvm::Log2_64_Ceil(inputSize) !=
1968 getLowIndex().getType().getIntOrFloatBitWidth())
1969 return emitOpError(
1970 "ArraySlice: index width must match clog2 of array size");
1971 return success();
1972}
1973
1974OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1975 // If we are slicing the entire input, then return it.
1976 if (getType() == getInput().getType())
1977 return getInput();
1978 return {};
1979}
1980
1981LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op,
1982 PatternRewriter &rewriter) {
1983 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1984 auto elemTy = sliceTy.getElementType();
1985 uint64_t sliceSize = sliceTy.getNumElements();
1986 if (sliceSize == 0)
1987 return failure();
1988
1989 if (sliceSize == 1) {
1990 // slice(a, n) -> create(a[n])
1991 auto get = ArrayGetOp::create(rewriter, op.getLoc(), op.getInput(),
1992 op.getLowIndex());
1993 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
1994 get.getResult());
1995 return success();
1996 }
1997
1998 auto offsetOpt = getUIntFromValue(op.getLowIndex());
1999 if (!offsetOpt)
2000 return failure();
2001
2002 auto *inputOp = op.getInput().getDefiningOp();
2003 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2004 // slice(slice(a, n), m) -> slice(a, n + m)
2005 if (inputSlice == op)
2006 return failure();
2007
2008 auto inputIndex = inputSlice.getLowIndex();
2009 auto inputOffsetOpt = getUIntFromValue(inputIndex);
2010 if (!inputOffsetOpt)
2011 return failure();
2012
2013 uint64_t offset = *offsetOpt + *inputOffsetOpt;
2014 auto lowIndex =
2015 ConstantOp::create(rewriter, op.getLoc(), inputIndex.getType(), offset);
2016 rewriter.replaceOpWithNewOp<ArraySliceOp>(op, op.getType(),
2017 inputSlice.getInput(), lowIndex);
2018 return success();
2019 }
2020
2021 if (auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
2022 // slice(create(a0, a1, ..., an), m) -> create(am, ...)
2023 auto inputs = inputCreate.getInputs();
2024
2025 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
2026 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(),
2027 inputs.slice(begin, sliceSize));
2028 return success();
2029 }
2030
2031 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2032 // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...)
2033 SmallVector<Value> chunks;
2034 uint64_t sliceStart = *offsetOpt;
2035 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2036 // Check whether the input intersects with the slice.
2037 uint64_t inputSize =
2038 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2039 if (inputSize == 0 || inputSize <= sliceStart) {
2040 sliceStart -= inputSize;
2041 continue;
2042 }
2043
2044 // Find the indices to slice from this input by intersection.
2045 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2046 uint64_t cutSize = cutEnd - sliceStart;
2047 assert(cutSize != 0 && "slice cannot be empty");
2048
2049 if (cutSize == inputSize) {
2050 // The whole input fits in the slice, add it.
2051 assert(sliceStart == 0 && "invalid cut size");
2052 chunks.push_back(input);
2053 } else {
2054 // Slice the required bits from the input.
2055 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2056 auto lowIndex = ConstantOp::create(
2057 rewriter, op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2058 chunks.push_back(ArraySliceOp::create(
2059 rewriter, op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input,
2060 lowIndex));
2061 }
2062
2063 sliceStart = 0;
2064 sliceSize -= cutSize;
2065 if (sliceSize == 0)
2066 break;
2067 }
2068
2069 assert(chunks.size() > 0 && "missing sliced items");
2070 if (chunks.size() == 1)
2071 rewriter.replaceOp(op, chunks[0]);
2072 else
2073 rewriter.replaceOpWithNewOp<ArrayConcatOp>(
2074 op, llvm::to_vector(llvm::reverse(chunks)));
2075 return success();
2076 }
2077 return failure();
2078}
2079
2080//===----------------------------------------------------------------------===//
2081// ArrayConcatOp
2082//===----------------------------------------------------------------------===//
2083
2084static ParseResult parseArrayConcatTypes(OpAsmParser &p,
2085 SmallVectorImpl<Type> &inputTypes,
2086 Type &resultType) {
2087 Type elemType;
2088 uint64_t resultSize = 0;
2089
2090 auto parseElement = [&]() -> ParseResult {
2091 Type ty;
2092 if (p.parseType(ty))
2093 return failure();
2094 auto arrTy = type_dyn_cast<ArrayType>(ty);
2095 if (!arrTy)
2096 return p.emitError(p.getCurrentLocation(), "Expected !hw.array type");
2097 if (elemType && elemType != arrTy.getElementType())
2098 return p.emitError(p.getCurrentLocation(), "Expected array element type ")
2099 << elemType;
2100
2101 elemType = arrTy.getElementType();
2102 inputTypes.push_back(ty);
2103 resultSize += arrTy.getNumElements();
2104 return success();
2105 };
2106
2107 if (p.parseCommaSeparatedList(parseElement))
2108 return failure();
2109
2110 resultType = ArrayType::get(elemType, resultSize);
2111 return success();
2112}
2113
2114static void printArrayConcatTypes(OpAsmPrinter &p, Operation *,
2115 TypeRange inputTypes, Type resultType) {
2116 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2117}
2118
2119void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2120 ValueRange values) {
2121 assert(!values.empty() && "Cannot build array of zero elements");
2122 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2123 Type elemTy = arrayTy.getElementType();
2124 assert(llvm::all_of(values,
2125 [elemTy](Value v) -> bool {
2126 return isa<ArrayType>(v.getType()) &&
2127 cast<ArrayType>(v.getType()).getElementType() ==
2128 elemTy;
2129 }) &&
2130 "All values must be of ArrayType with the same element type.");
2131
2132 uint64_t resultSize = 0;
2133 for (Value val : values)
2134 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2135 build(b, state, ArrayType::get(elemTy, resultSize), values);
2136}
2137
2138OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2139 if (getInputs().size() == 1)
2140 return getInputs()[0];
2141
2142 auto inputs = adaptor.getInputs();
2143 SmallVector<Attribute> array;
2144 for (size_t i = 0, e = getNumOperands(); i < e; ++i) {
2145 if (!inputs[i])
2146 return {};
2147 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2148 }
2149 return ArrayAttr::get(getContext(), array);
2150}
2151
2152// Flatten a concatenation of array creates into a single create.
2153static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) {
2154 for (auto input : op.getInputs())
2155 if (!input.getDefiningOp<ArrayCreateOp>())
2156 return false;
2157
2158 SmallVector<Value> items;
2159 for (auto input : op.getInputs()) {
2160 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2161 for (auto item : create.getInputs())
2162 items.push_back(item);
2163 }
2164
2165 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, items);
2166 return true;
2167}
2168
2169// Merge consecutive slice expressions in a concatenation.
2170static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) {
2171 struct Slice {
2172 Value input;
2173 Value index;
2174 size_t size;
2175 Value op;
2176 SmallVector<Location> locs;
2177 };
2178
2179 SmallVector<Value> items;
2180 std::optional<Slice> last;
2181 bool changed = false;
2182
2183 auto concatenate = [&] {
2184 // If there is only one op in the slice, place it to the items list.
2185 if (!last)
2186 return;
2187 if (last->op) {
2188 items.push_back(last->op);
2189 last.reset();
2190 return;
2191 }
2192
2193 // Otherwise, create a new slice of with the given size and place it.
2194 // In this case, the concat op is replaced, using the new argument.
2195 changed = true;
2196 auto loc = FusedLoc::get(op.getContext(), last->locs);
2197 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2198 auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2199 items.push_back(rewriter.createOrFold<ArraySliceOp>(
2200 loc, arrayTy, last->input, last->index));
2201
2202 last.reset();
2203 };
2204
2205 auto append = [&](Value op, Value input, Value index, size_t size) {
2206 // If this slice is an extension of the previous one, extend the size
2207 // saved. In this case, a new slice of is created and the concatenation
2208 // operator is rewritten. Otherwise, flush the last slice.
2209 if (last) {
2210 if (last->input == input && isOffset(last->index, index, last->size)) {
2211 last->size += size;
2212 last->op = {};
2213 last->locs.push_back(op.getLoc());
2214 return;
2215 }
2216 concatenate();
2217 }
2218 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2219 };
2220
2221 for (auto item : llvm::reverse(op.getInputs())) {
2222 if (auto slice = item.getDefiningOp<ArraySliceOp>()) {
2223 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2224 append(item, slice.getInput(), slice.getLowIndex(), size);
2225 continue;
2226 }
2227
2228 if (auto create = item.getDefiningOp<ArrayCreateOp>()) {
2229 if (create.getInputs().size() == 1) {
2230 if (auto get = create.getInputs()[0].getDefiningOp<ArrayGetOp>()) {
2231 append(item, get.getInput(), get.getIndex(), 1);
2232 continue;
2233 }
2234 }
2235 }
2236
2237 concatenate();
2238 items.push_back(item);
2239 }
2240 concatenate();
2241
2242 if (!changed)
2243 return false;
2244
2245 if (items.size() == 1) {
2246 rewriter.replaceOp(op, items[0]);
2247 } else {
2248 std::reverse(items.begin(), items.end());
2249 rewriter.replaceOpWithNewOp<ArrayConcatOp>(op, items);
2250 }
2251 return true;
2252}
2253
2254LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op,
2255 PatternRewriter &rewriter) {
2256 // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...)
2257 if (flattenConcatOp(op, rewriter))
2258 return success();
2259
2260 // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p))
2261 if (mergeConcatSlices(op, rewriter))
2262 return success();
2263
2264 return failure();
2265}
2266
2267//===----------------------------------------------------------------------===//
2268// EnumConstantOp
2269//===----------------------------------------------------------------------===//
2270
2271ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2272 // Parse a Type instead of an EnumType since the type might be a type alias.
2273 // The validity of the canonical type is checked during construction of the
2274 // EnumFieldAttr.
2275 Type type;
2276 StringRef field;
2277
2278 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2279 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2280 return failure();
2281
2282 auto fieldAttr = EnumFieldAttr::get(
2283 loc, StringAttr::get(parser.getContext(), field), type);
2284
2285 if (!fieldAttr)
2286 return failure();
2287
2288 result.addAttribute("field", fieldAttr);
2289 result.addTypes(type);
2290
2291 return success();
2292}
2293
2294void EnumConstantOp::print(OpAsmPrinter &p) {
2295 p << " " << getField().getField().getValue() << " : "
2296 << getField().getType().getValue();
2297}
2298
2299void EnumConstantOp::getAsmResultNames(
2300 function_ref<void(Value, StringRef)> setNameFn) {
2301 setNameFn(getResult(), getField().getField().str());
2302}
2303
2304void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2305 EnumFieldAttr field) {
2306 return build(builder, odsState, field.getType().getValue(), field);
2307}
2308
2309OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2310 assert(adaptor.getOperands().empty() && "constant has no operands");
2311 return getFieldAttr();
2312}
2313
2314LogicalResult EnumConstantOp::verify() {
2315 auto fieldAttr = getFieldAttr();
2316 auto fieldType = fieldAttr.getType().getValue();
2317 // This check ensures that we are using the exact same type, without looking
2318 // through type aliases.
2319 if (fieldType != getType())
2320 emitOpError("return type ")
2321 << getType() << " does not match attribute type " << fieldAttr;
2322 return success();
2323}
2324
2325//===----------------------------------------------------------------------===//
2326// EnumCmpOp
2327//===----------------------------------------------------------------------===//
2328
2329LogicalResult EnumCmpOp::verify() {
2330 // Compare the canonical types.
2331 auto lhsType = type_cast<EnumType>(getLhs().getType());
2332 auto rhsType = type_cast<EnumType>(getRhs().getType());
2333 if (rhsType != lhsType)
2334 emitOpError("types do not match");
2335 return success();
2336}
2337
2338//===----------------------------------------------------------------------===//
2339// StructCreateOp
2340//===----------------------------------------------------------------------===//
2341
2342ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2343 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2344 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2345 Type declOrAliasType;
2346
2347 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2348 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2349 parser.parseColonType(declOrAliasType))
2350 return failure();
2351
2352 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2353 if (!declType)
2354 return parser.emitError(parser.getNameLoc(),
2355 "expected !hw.struct type or alias");
2356
2357 llvm::SmallVector<Type, 4> structInnerTypes;
2358 declType.getInnerTypes(structInnerTypes);
2359 result.addTypes(declOrAliasType);
2360
2361 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2362 result.operands))
2363 return failure();
2364 return success();
2365}
2366
2367void StructCreateOp::print(OpAsmPrinter &printer) {
2368 printer << " (";
2369 printer.printOperands(getInput());
2370 printer << ")";
2371 printer.printOptionalAttrDict((*this)->getAttrs());
2372 printer << " : " << getType();
2373}
2374
2375LogicalResult StructCreateOp::verify() {
2376 auto elements = hw::type_cast<StructType>(getType()).getElements();
2377
2378 if (elements.size() != getInput().size())
2379 return emitOpError("structure field count mismatch");
2380
2381 for (const auto &[field, value] : llvm::zip(elements, getInput()))
2382 if (field.type != value.getType())
2383 return emitOpError("structure field `")
2384 << field.name << "` type does not match";
2385
2386 return success();
2387}
2388
2389OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2390 // struct_create(struct_explode(x)) => x
2391 if (!getInput().empty())
2392 if (auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2393 explodeOp && getInput() == explodeOp.getResults() &&
2394 getResult().getType() == explodeOp.getInput().getType())
2395 return explodeOp.getInput();
2396
2397 auto inputs = adaptor.getInput();
2398 if (llvm::any_of(inputs, [](Attribute attr) {
2399 return !isa_and_nonnull<IntegerAttr>(attr);
2400 }))
2401 return {};
2402 return ArrayAttr::get(getContext(), inputs);
2403}
2404
2405//===----------------------------------------------------------------------===//
2406// StructExplodeOp
2407//===----------------------------------------------------------------------===//
2408
2409ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2410 OperationState &result) {
2411 OpAsmParser::UnresolvedOperand operand;
2412 Type declType;
2413
2414 if (parser.parseOperand(operand) ||
2415 parser.parseOptionalAttrDict(result.attributes) ||
2416 parser.parseColonType(declType))
2417 return failure();
2418 auto structType = type_dyn_cast<StructType>(declType);
2419 if (!structType)
2420 return parser.emitError(parser.getNameLoc(),
2421 "invalid kind of type specified");
2422
2423 llvm::SmallVector<Type, 4> structInnerTypes;
2424 structType.getInnerTypes(structInnerTypes);
2425 result.addTypes(structInnerTypes);
2426
2427 if (parser.resolveOperand(operand, declType, result.operands))
2428 return failure();
2429 return success();
2430}
2431
2432void StructExplodeOp::print(OpAsmPrinter &printer) {
2433 printer << " ";
2434 printer.printOperand(getInput());
2435 printer.printOptionalAttrDict((*this)->getAttrs());
2436 printer << " : " << getInput().getType();
2437}
2438
2439LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2440 SmallVectorImpl<OpFoldResult> &results) {
2441 auto input = adaptor.getInput();
2442 if (!input)
2443 return failure();
2444 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2445 return success();
2446}
2447
2448LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2449 PatternRewriter &rewriter) {
2450 auto *inputOp = op.getInput().getDefiningOp();
2451 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2452 auto result = failure();
2453 auto opResults = op.getResults();
2454 for (uint32_t index = 0; index < elements.size(); index++) {
2455 if (auto foldResult = foldStructExtract(inputOp, index)) {
2456 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2457 result = success();
2458 }
2459 }
2460 return result;
2461}
2462
2463void StructExplodeOp::getAsmResultNames(
2464 function_ref<void(Value, StringRef)> setNameFn) {
2465 auto structType = type_cast<StructType>(getInput().getType());
2466 for (auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2467 setNameFn(res, field.name.str());
2468}
2469
2470void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2471 Value input) {
2472 StructType inputType = dyn_cast<StructType>(input.getType());
2473 assert(inputType);
2474 SmallVector<Type, 16> fieldTypes;
2475 for (auto field : inputType.getElements())
2476 fieldTypes.push_back(field.type);
2477 build(odsBuilder, odsState, fieldTypes, input);
2478}
2479
2480//===----------------------------------------------------------------------===//
2481// StructExtractOp
2482//===----------------------------------------------------------------------===//
2483
2484/// Ensure an aggregate op's field index is within the bounds of
2485/// the aggregate type and the accessed field is of 'elementType'.
2486template <typename AggregateOp, typename AggregateType>
2487static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op,
2488 AggregateType aggType,
2489 Type elementType) {
2490 auto index = op.getFieldIndex();
2491 if (index >= aggType.getElements().size())
2492 return op.emitOpError() << "field index " << index
2493 << " exceeds element count of aggregate type";
2494
2496 getCanonicalType(aggType.getElements()[index].type))
2497 return op.emitOpError()
2498 << "type " << aggType.getElements()[index].type
2499 << " of accessed field in aggregate at index " << index
2500 << " does not match expected type " << elementType;
2501
2502 return success();
2503}
2504
2505LogicalResult StructExtractOp::verify() {
2506 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2507 *this, getInput().getType(), getType());
2508}
2509
2510/// Use the same parser for both struct_extract and union_extract since the
2511/// syntax is identical.
2512template <typename AggregateType>
2513static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
2514 OpAsmParser::UnresolvedOperand operand;
2515 StringAttr fieldName;
2516 Type declType;
2517
2518 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2519 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2520 parser.parseOptionalAttrDict(result.attributes) ||
2521 parser.parseColonType(declType))
2522 return failure();
2523 auto aggType = type_dyn_cast<AggregateType>(declType);
2524 if (!aggType)
2525 return parser.emitError(parser.getNameLoc(),
2526 "invalid kind of type specified");
2527
2528 auto fieldIndex = aggType.getFieldIndex(fieldName);
2529 if (!fieldIndex) {
2530 parser.emitError(parser.getNameLoc(), "field name '" +
2531 fieldName.getValue() +
2532 "' not found in aggregate type");
2533 return failure();
2534 }
2535
2536 auto indexAttr =
2537 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2538 result.addAttribute("fieldIndex", indexAttr);
2539 Type resultType = aggType.getElements()[*fieldIndex].type;
2540 result.addTypes(resultType);
2541
2542 if (parser.resolveOperand(operand, declType, result.operands))
2543 return failure();
2544 return success();
2545}
2546
2547/// Use the same printer for both struct_extract and union_extract since the
2548/// syntax is identical.
2549template <typename AggType>
2550static void printExtractOp(OpAsmPrinter &printer, AggType op) {
2551 printer << " ";
2552 printer.printOperand(op.getInput());
2553 printer << "[\"" << op.getFieldName() << "\"]";
2554 printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"});
2555 printer << " : " << op.getInput().getType();
2556}
2557
2558ParseResult StructExtractOp::parse(OpAsmParser &parser,
2559 OperationState &result) {
2560 return parseExtractOp<StructType>(parser, result);
2561}
2562
2563void StructExtractOp::print(OpAsmPrinter &printer) {
2564 printExtractOp(printer, *this);
2565}
2566
2567void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2568 Value input, StructType::FieldInfo field) {
2569 auto fieldIndex =
2570 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2571 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2572 build(builder, odsState, field.type, input, *fieldIndex);
2573}
2574
2575void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2576 Value input, StringAttr fieldName) {
2577 auto structType = type_cast<StructType>(input.getType());
2578 auto fieldIndex = structType.getFieldIndex(fieldName);
2579 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2580 auto resultType = structType.getElements()[*fieldIndex].type;
2581 build(builder, odsState, resultType, input, *fieldIndex);
2582}
2583
2584OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2585 if (auto constOperand = adaptor.getInput()) {
2586 // Fold extract from aggregate constant
2587 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2588 return operandAttr.getValue()[getFieldIndex()];
2589 }
2590
2591 if (auto foldResult =
2592 foldStructExtract(getInput().getDefiningOp(), getFieldIndex()))
2593 return foldResult;
2594 return {};
2595}
2596
2597LogicalResult StructExtractOp::canonicalize(StructExtractOp op,
2598 PatternRewriter &rewriter) {
2599 auto *inputOp = op.getInput().getDefiningOp();
2600
2601 // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b")
2602 if (auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2603 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2604 rewriter.replaceOpWithNewOp<StructExtractOp>(
2605 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2606 return success();
2607 }
2608 }
2609
2610 return failure();
2611}
2612
2613void StructExtractOp::getAsmResultNames(
2614 function_ref<void(Value, StringRef)> setNameFn) {
2615 setNameFn(getResult(), getFieldName());
2616}
2617
2618//===----------------------------------------------------------------------===//
2619// StructInjectOp
2620//===----------------------------------------------------------------------===//
2621
2622void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2623 Value input, StringAttr fieldName, Value newValue) {
2624 auto structType = type_cast<StructType>(input.getType());
2625 auto fieldIndex = structType.getFieldIndex(fieldName);
2626 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2627 build(builder, odsState, input, *fieldIndex, newValue);
2628}
2629
2630LogicalResult StructInjectOp::verify() {
2631 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2632 *this, getInput().getType(), getNewValue().getType());
2633}
2634
2635ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2636 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2637 OpAsmParser::UnresolvedOperand operand, val;
2638 StringAttr fieldName;
2639 Type declType;
2640
2641 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2642 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2643 parser.parseComma() || parser.parseOperand(val) ||
2644 parser.parseOptionalAttrDict(result.attributes) ||
2645 parser.parseColonType(declType))
2646 return failure();
2647 auto structType = type_dyn_cast<StructType>(declType);
2648 if (!structType)
2649 return parser.emitError(inputOperandsLoc, "invalid kind of type specified");
2650
2651 auto fieldIndex = structType.getFieldIndex(fieldName);
2652 if (!fieldIndex) {
2653 parser.emitError(parser.getNameLoc(), "field name '" +
2654 fieldName.getValue() +
2655 "' not found in aggregate type");
2656 return failure();
2657 }
2658
2659 auto indexAttr =
2660 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2661 result.addAttribute("fieldIndex", indexAttr);
2662 result.addTypes(declType);
2663
2664 Type resultType = structType.getElements()[*fieldIndex].type;
2665 if (parser.resolveOperands({operand, val}, {declType, resultType},
2666 inputOperandsLoc, result.operands))
2667 return failure();
2668 return success();
2669}
2670
2671void StructInjectOp::print(OpAsmPrinter &printer) {
2672 printer << " ";
2673 printer.printOperand(getInput());
2674 printer << "[\"" << getFieldName() << "\"], ";
2675 printer.printOperand(getNewValue());
2676 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2677 printer << " : " << getInput().getType();
2678}
2679
2680OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2681 auto input = adaptor.getInput();
2682 auto newValue = adaptor.getNewValue();
2683 if (!input || !newValue)
2684 return {};
2685 SmallVector<Attribute> array;
2686 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2687 array[getFieldIndex()] = newValue;
2688 return ArrayAttr::get(getContext(), array);
2689}
2690
2691LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2692 PatternRewriter &rewriter) {
2693 // If this inject is only used as an input to another inject, don't try to
2694 // canonicalize it. It will be included in that other op's canonicalization
2695 // attempt. This avoids doing redundant work.
2696 if (op->hasOneUse()) {
2697 auto &use = *op->use_begin();
2698 if (isa<StructInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
2699 return failure();
2700 }
2701
2702 // Canonicalize multiple injects into a create op and eliminate overwrites.
2703 SmallPtrSet<Operation *, 4> injects;
2704 DenseMap<StringAttr, Value> fields;
2705
2706 // Chase a chain of injects. Bail out if cycles are present.
2707 StructInjectOp inject = op;
2708 Value input;
2709 do {
2710 if (!injects.insert(inject).second)
2711 break;
2712
2713 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2714 input = inject.getInput();
2715 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2716 } while (inject);
2717 assert(input && "missing input to inject chain");
2718
2719 auto ty = hw::type_cast<StructType>(op.getType());
2720 auto elements = ty.getElements();
2721
2722 // If the inject chain sets all fields, canonicalize to create.
2723 if (fields.size() == elements.size()) {
2724 SmallVector<Value> createFields;
2725 for (const auto &field : elements) {
2726 auto it = fields.find(field.name);
2727 assert(it != fields.end() && "missing field");
2728 createFields.push_back(it->second);
2729 }
2730 rewriter.replaceOpWithNewOp<StructCreateOp>(op, ty, createFields);
2731 return success();
2732 }
2733
2734 // Nothing to canonicalize, only the original inject in the chain.
2735 if (injects.size() == fields.size())
2736 return failure();
2737
2738 // Eliminate overwrites. The hash map contains the last write to each field.
2739 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2740 auto it = fields.find(elements[fieldIndex].name);
2741 if (it == fields.end())
2742 continue;
2743 input = StructInjectOp::create(rewriter, op.getLoc(), ty, input, fieldIndex,
2744 it->second);
2745 }
2746
2747 rewriter.replaceOp(op, input);
2748 return success();
2749}
2750
2751//===----------------------------------------------------------------------===//
2752// UnionCreateOp
2753//===----------------------------------------------------------------------===//
2754
2755LogicalResult UnionCreateOp::verify() {
2756 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2757 *this, getType(), getInput().getType());
2758}
2759
2760void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2761 Type unionType, StringAttr fieldName, Value input) {
2762 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2763 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2764 build(builder, odsState, unionType, *fieldIndex, input);
2765}
2766
2767ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2768 Type declOrAliasType;
2769 StringAttr fieldName;
2770 OpAsmParser::UnresolvedOperand input;
2771 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2772
2773 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2774 parser.parseOperand(input) ||
2775 parser.parseOptionalAttrDict(result.attributes) ||
2776 parser.parseColonType(declOrAliasType))
2777 return failure();
2778
2779 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2780 if (!declType)
2781 return parser.emitError(parser.getNameLoc(),
2782 "expected !hw.union type or alias");
2783
2784 auto fieldIndex = declType.getFieldIndex(fieldName);
2785 if (!fieldIndex) {
2786 parser.emitError(fieldLoc, "cannot find union field '")
2787 << fieldName.getValue() << '\'';
2788 return failure();
2789 }
2790
2791 auto indexAttr =
2792 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2793 result.addAttribute("fieldIndex", indexAttr);
2794 Type inputType = declType.getElements()[*fieldIndex].type;
2795
2796 if (parser.resolveOperand(input, inputType, result.operands))
2797 return failure();
2798 result.addTypes({declOrAliasType});
2799 return success();
2800}
2801
2802void UnionCreateOp::print(OpAsmPrinter &printer) {
2803 printer << " \"" << getFieldName() << "\", ";
2804 printer.printOperand(getInput());
2805 printer.printOptionalAttrDict((*this)->getAttrs(), {"fieldIndex"});
2806 printer << " : " << getType();
2807}
2808
2809//===----------------------------------------------------------------------===//
2810// UnionExtractOp
2811//===----------------------------------------------------------------------===//
2812
2813ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2814 return parseExtractOp<UnionType>(parser, result);
2815}
2816
2817void UnionExtractOp::print(OpAsmPrinter &printer) {
2818 printExtractOp(printer, *this);
2819}
2820
2821LogicalResult UnionExtractOp::inferReturnTypes(
2822 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2823 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2824 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2825 Adaptor adaptor(operands, attrs, properties, regions);
2826 auto unionElements =
2827 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2828 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2829 if (fieldIndex >= unionElements.size()) {
2830 if (loc)
2831 mlir::emitError(*loc, "field index " + Twine(fieldIndex) +
2832 " exceeds element count of aggregate type");
2833 return failure();
2834 }
2835 results.push_back(unionElements[fieldIndex].type);
2836 return success();
2837}
2838
2839void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2840 Value input, StringAttr fieldName) {
2841 auto unionType = type_cast<UnionType>(input.getType());
2842 auto fieldIndex = unionType.getFieldIndex(fieldName);
2843 assert(fieldIndex.has_value() && "field name not found in aggregate type");
2844 auto resultType = unionType.getElements()[*fieldIndex].type;
2845 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2846}
2847
2848//===----------------------------------------------------------------------===//
2849// ArrayGetOp
2850//===----------------------------------------------------------------------===//
2851
2852// An array_get of an array_create with a constant index can just be the
2853// array_create operand at the constant index. If the array_create has a
2854// single uniform value for each element, just return that value regardless of
2855// the index. If the array is constructed from a constant by a bitcast
2856// operation, we can fold into a constant.
2857OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2858 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2859 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2860
2861 if (inputCst) {
2862 // Constant array index.
2863 if (indexCst) {
2864 auto indexVal = indexCst.getValue();
2865 if (indexVal.getBitWidth() < 64) {
2866 auto index = indexVal.getZExtValue();
2867 return inputCst[inputCst.size() - 1 - index];
2868 }
2869 }
2870 // If all elements of the array are the same, we can return any element of
2871 // array.
2872 if (!inputCst.empty() && llvm::all_equal(inputCst))
2873 return inputCst[0];
2874 }
2875
2876 // array_get(bitcast(c), i) -> c[i*w+w-1:i*w]
2877 if (auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2878 auto intTy = dyn_cast<IntegerType>(getType());
2879 if (!intTy)
2880 return {};
2881 auto bitcastInputOp = bitcast.getInput().getDefiningOp<hw::ConstantOp>();
2882 if (!bitcastInputOp)
2883 return {};
2884 if (!indexCst)
2885 return {};
2886 auto bitcastInputCst = bitcastInputOp.getValue();
2887 // Calculate the index. Make sure to zero-extend the index value before
2888 // multiplying the element width.
2889 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2890 getType().getIntOrFloatBitWidth();
2891 // Extract [startIdx + width - 1: startIdx].
2892 return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2893 intTy.getIntOrFloatBitWidth()));
2894 }
2895
2896 // array_get(array_inject(_, index, element), index) -> element
2897 if (auto inject = getInput().getDefiningOp<ArrayInjectOp>())
2898 if (getIndex() == inject.getIndex())
2899 return inject.getElement();
2900
2901 auto inputCreate = getInput().getDefiningOp<ArrayCreateOp>();
2902 if (!inputCreate)
2903 return {};
2904
2905 if (auto uniformValue = inputCreate.getUniformElement())
2906 return uniformValue;
2907
2908 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2909 return {};
2910
2911 uint64_t index = indexCst.getValue().getLimitedValue();
2912 auto createInputs = inputCreate.getInputs();
2913 if (index >= createInputs.size())
2914 return {};
2915 return createInputs[createInputs.size() - index - 1];
2916}
2917
2918LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op,
2919 PatternRewriter &rewriter) {
2920 auto idxOpt = getUIntFromValue(op.getIndex());
2921 if (!idxOpt)
2922 return failure();
2923
2924 auto *inputOp = op.getInput().getDefiningOp();
2925 if (auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2926 // get(slice(a, n), m) -> get(a, n + m)
2927 auto offsetOp = inputSlice.getLowIndex();
2928 auto offsetOpt = getUIntFromValue(offsetOp);
2929 if (!offsetOpt)
2930 return failure();
2931
2932 uint64_t offset = *offsetOpt + *idxOpt;
2933 auto newOffset =
2934 ConstantOp::create(rewriter, op.getLoc(), offsetOp.getType(), offset);
2935 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, inputSlice.getInput(),
2936 newOffset);
2937 return success();
2938 }
2939
2940 if (auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2941 // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...)
2942 uint64_t elemIndex = *idxOpt;
2943 for (auto input : llvm::reverse(inputConcat.getInputs())) {
2944 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2945 if (elemIndex >= size) {
2946 elemIndex -= size;
2947 continue;
2948 }
2949
2950 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2951 auto newIdxOp =
2952 ConstantOp::create(rewriter, op.getLoc(),
2953 rewriter.getIntegerType(indexWidth), elemIndex);
2954
2955 rewriter.replaceOpWithNewOp<ArrayGetOp>(op, input, newIdxOp);
2956 return success();
2957 }
2958 return failure();
2959 }
2960
2961 // array_get const, (array_get sel, (array_create a, b, c, d)) -->
2962 // array_get sel, (array_create (array_get const a), (array_get const b),
2963 // (array_get const, c), (array_get const, d))
2964 if (auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2965 if (!innerGet.getIndex().getDefiningOp<hw::ConstantOp>()) {
2966 if (auto create =
2967 innerGet.getInput().getDefiningOp<hw::ArrayCreateOp>()) {
2968
2969 SmallVector<Value> newValues;
2970 for (auto operand : create.getOperands())
2971 newValues.push_back(rewriter.createOrFold<hw::ArrayGetOp>(
2972 op.getLoc(), operand, op.getIndex()));
2973
2974 rewriter.replaceOpWithNewOp<hw::ArrayGetOp>(
2975 op,
2976 rewriter.createOrFold<hw::ArrayCreateOp>(op.getLoc(), newValues),
2977 innerGet.getIndex());
2978 return success();
2979 }
2980 }
2981 }
2982
2983 return failure();
2984}
2985
2986//===----------------------------------------------------------------------===//
2987// ArrayInjectOp
2988//===----------------------------------------------------------------------===//
2989
2990OpFoldResult ArrayInjectOp::fold(FoldAdaptor adaptor) {
2991 auto inputAttr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2992 auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2993 auto elementAttr = adaptor.getElement();
2994
2995 // inject(constant[xs, y, zs], iy, a) -> constant[x, a, z]
2996 if (inputAttr && indexAttr && elementAttr) {
2997 if (auto index = indexAttr.getValue().tryZExtValue()) {
2998 if (*index < inputAttr.size()) {
2999 SmallVector<Attribute> elements(inputAttr.getValue());
3000 elements[inputAttr.size() - 1 - *index] = elementAttr;
3001 return ArrayAttr::get(getContext(), elements);
3002 }
3003 }
3004 }
3005
3006 return {};
3007}
3008
3009static LogicalResult canonicalizeArrayInjectChain(ArrayInjectOp op,
3010 PatternRewriter &rewriter) {
3011 // If this inject is only used as an input to another inject, don't try to
3012 // canonicalize it. It will be included in that other op's canonicalization
3013 // attempt. This avoids doing redundant work.
3014 if (op->hasOneUse()) {
3015 auto &use = *op->use_begin();
3016 if (isa<ArrayInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
3017 return failure();
3018 }
3019
3020 // Collect all injects to constant indices.
3021 auto arrayLength = type_cast<ArrayType>(op.getType()).getNumElements();
3022 Value input = op;
3024 while (auto inject = input.getDefiningOp<ArrayInjectOp>()) {
3025 // Determine the constant index.
3026 APInt indexAPInt;
3027 if (!matchPattern(inject.getIndex(), mlir::m_ConstantInt(&indexAPInt)))
3028 break;
3029 if (indexAPInt.getActiveBits() > 32)
3030 break;
3031 uint32_t index = indexAPInt.getZExtValue();
3032
3033 // Track the injected value. Make sure to only track indices that are in
3034 // bounds. This will allow us to later check if `elements.size()` matches
3035 // the array length.
3036 if (index < arrayLength)
3037 elements.insert({index, inject.getElement()});
3038
3039 // Step to the next inject op.
3040 input = inject.getInput();
3041 if (input == op)
3042 break; // break cycles
3043 }
3044
3045 // If we are assigning every single element, replace the op with an
3046 // `hw.array_create`.
3047 if (elements.size() == arrayLength) {
3048 SmallVector<Value, 4> operands;
3049 operands.reserve(arrayLength);
3050 for (uint32_t idx = 0; idx < arrayLength; ++idx)
3051 operands.push_back(elements.at(arrayLength - idx - 1));
3052 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(), operands);
3053 return success();
3054 }
3055
3056 return failure();
3057}
3058
3059static LogicalResult
3060canonicalizeArrayInjectIntoCreate(ArrayInjectOp op, PatternRewriter &rewriter) {
3061 auto createOp = op.getInput().getDefiningOp<ArrayCreateOp>();
3062 if (!createOp)
3063 return failure();
3064
3065 // Make sure the access is in bounds.
3066 APInt indexAPInt;
3067 if (!matchPattern(op.getIndex(), mlir::m_ConstantInt(&indexAPInt)) ||
3068 !indexAPInt.ult(createOp.getInputs().size()))
3069 return failure();
3070
3071 // Substitute the injected value.
3072 SmallVector<Value> elements = createOp.getInputs();
3073 elements[elements.size() - indexAPInt.getZExtValue() - 1] = op.getElement();
3074 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, elements);
3075 return success();
3076}
3077
3078void ArrayInjectOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3079 MLIRContext *context) {
3080 patterns.add<ArrayInjectToSameIndex>(context);
3083}
3084
3085//===----------------------------------------------------------------------===//
3086// TypedeclOp
3087//===----------------------------------------------------------------------===//
3088
3089StringRef TypedeclOp::getPreferredName() {
3090 return getVerilogName().value_or(getName());
3091}
3092
3093Type TypedeclOp::getAliasType() {
3094 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
3095 return hw::TypeAliasType::get(
3096 SymbolRefAttr::get(parentScope.getSymNameAttr(),
3097 {FlatSymbolRefAttr::get(*this)}),
3098 getType());
3099}
3100
3101//===----------------------------------------------------------------------===//
3102// BitcastOp
3103//===----------------------------------------------------------------------===//
3104
3105OpFoldResult BitcastOp::fold(FoldAdaptor) {
3106 // Identity.
3107 // bitcast(%a) : A -> A ==> %a
3108 if (getOperand().getType() == getType())
3109 return getOperand();
3110
3111 return {};
3112}
3113
3114LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) {
3115 // Composition.
3116 // %b = bitcast(%a) : A -> B
3117 // bitcast(%b) : B -> C
3118 // ===> bitcast(%a) : A -> C
3119 auto inputBitcast =
3120 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
3121 if (!inputBitcast)
3122 return failure();
3123 auto bitcast = rewriter.createOrFold<BitcastOp>(op.getLoc(), op.getType(),
3124 inputBitcast.getInput());
3125 rewriter.replaceOp(op, bitcast);
3126 return success();
3127}
3128
3129LogicalResult BitcastOp::verify() {
3130 if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType()))
3131 return this->emitOpError("Bitwidth of input must match result");
3132 return success();
3133}
3134
3135//===----------------------------------------------------------------------===//
3136// HierPathOp helpers.
3137//===----------------------------------------------------------------------===//
3138
3139bool HierPathOp::dropModule(StringAttr moduleToDrop) {
3140 SmallVector<Attribute, 4> newPath;
3141 bool updateMade = false;
3142 for (auto nameRef : getNamepath()) {
3143 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3144 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3145 if (ref.getModule() == moduleToDrop)
3146 updateMade = true;
3147 else
3148 newPath.push_back(ref);
3149 } else {
3150 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3151 updateMade = true;
3152 else
3153 newPath.push_back(nameRef);
3154 }
3155 }
3156 if (updateMade)
3157 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3158 return updateMade;
3159}
3160
3161bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3162 SmallVector<Attribute, 4> newPath;
3163 bool updateMade = false;
3164 StringRef inlinedInstanceName = "";
3165 for (auto nameRef : getNamepath()) {
3166 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3167 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3168 if (ref.getModule() == moduleToDrop) {
3169 inlinedInstanceName = ref.getName().getValue();
3170 updateMade = true;
3171 } else if (!inlinedInstanceName.empty()) {
3172 newPath.push_back(hw::InnerRefAttr::get(
3173 ref.getModule(),
3174 StringAttr::get(getContext(), inlinedInstanceName + "_" +
3175 ref.getName().getValue())));
3176 inlinedInstanceName = "";
3177 } else
3178 newPath.push_back(ref);
3179 } else {
3180 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3181 updateMade = true;
3182 else
3183 newPath.push_back(nameRef);
3184 }
3185 }
3186 if (updateMade)
3187 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3188 return updateMade;
3189}
3190
3191bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3192 SmallVector<Attribute, 4> newPath;
3193 bool updateMade = false;
3194 for (auto nameRef : getNamepath()) {
3195 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3196 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3197 if (ref.getModule() == oldMod) {
3198 newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3199 updateMade = true;
3200 } else
3201 newPath.push_back(ref);
3202 } else {
3203 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3204 newPath.push_back(FlatSymbolRefAttr::get(newMod));
3205 updateMade = true;
3206 } else
3207 newPath.push_back(nameRef);
3208 }
3209 }
3210 if (updateMade)
3211 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3212 return updateMade;
3213}
3214
3215bool HierPathOp::updateModuleAndInnerRef(
3216 StringAttr oldMod, StringAttr newMod,
3217 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3218 auto fromRef = FlatSymbolRefAttr::get(oldMod);
3219 if (oldMod == newMod)
3220 return false;
3221
3222 auto namepathNew = getNamepath().getValue().vec();
3223 bool updateMade = false;
3224 // Break from the loop if the module is found, since it can occur only once.
3225 for (auto &element : namepathNew) {
3226 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3227 if (innerRef.getModule() != oldMod)
3228 continue;
3229 auto symName = innerRef.getName();
3230 // Since the module got updated, the old innerRef symbol inside oldMod
3231 // should also be updated to the new symbol inside the newMod.
3232 auto to = innerSymRenameMap.find(symName);
3233 if (to != innerSymRenameMap.end())
3234 symName = to->second;
3235 updateMade = true;
3236 element = hw::InnerRefAttr::get(newMod, symName);
3237 break;
3238 }
3239 if (element != fromRef)
3240 continue;
3241
3242 updateMade = true;
3243 element = FlatSymbolRefAttr::get(newMod);
3244 break;
3245 }
3246 if (updateMade)
3247 setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3248 return updateMade;
3249}
3250
3251bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) {
3252 SmallVector<Attribute, 4> newPath;
3253 bool updateMade = false;
3254 for (auto nameRef : getNamepath()) {
3255 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3256 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3257 if (ref.getModule() == atMod) {
3258 updateMade = true;
3259 if (includeMod)
3260 newPath.push_back(ref);
3261 } else
3262 newPath.push_back(ref);
3263 } else {
3264 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3265 updateMade = true;
3266 else
3267 newPath.push_back(nameRef);
3268 }
3269 if (updateMade)
3270 break;
3271 }
3272 if (updateMade)
3273 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3274 return updateMade;
3275}
3276
3277/// Return just the module part of the namepath at a specific index.
3278StringAttr HierPathOp::modPart(unsigned i) {
3279 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3280 .Case<FlatSymbolRefAttr>([](auto a) { return a.getAttr(); })
3281 .Case<hw::InnerRefAttr>([](auto a) { return a.getModule(); });
3282}
3283
3284/// Return the root module.
3285StringAttr HierPathOp::root() {
3286 assert(!getNamepath().empty());
3287 return modPart(0);
3288}
3289
3290/// Return true if the NLA has the module in its path.
3291bool HierPathOp::hasModule(StringAttr modName) {
3292 for (auto nameRef : getNamepath()) {
3293 // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr.
3294 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3295 if (ref.getModule() == modName)
3296 return true;
3297 } else {
3298 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3299 return true;
3300 }
3301 }
3302 return false;
3303}
3304
3305/// Return true if the NLA has the InnerSym .
3306bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const {
3307 for (auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3308 if (auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3309 if (ref.getName() == symName && ref.getModule() == modName)
3310 return true;
3311
3312 return false;
3313}
3314
3315/// Return just the reference part of the namepath at a specific index. This
3316/// will return an empty attribute if this is the leaf and the leaf is a module.
3317StringAttr HierPathOp::refPart(unsigned i) {
3318 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3319 .Case<FlatSymbolRefAttr>([](auto a) { return StringAttr({}); })
3320 .Case<hw::InnerRefAttr>([](auto a) { return a.getName(); });
3321}
3322
3323/// Return the leaf reference. This returns an empty attribute if the leaf
3324/// reference is a module.
3325StringAttr HierPathOp::ref() {
3326 assert(!getNamepath().empty());
3327 return refPart(getNamepath().size() - 1);
3328}
3329
3330/// Return the leaf module.
3331StringAttr HierPathOp::leafMod() {
3332 assert(!getNamepath().empty());
3333 return modPart(getNamepath().size() - 1);
3334}
3335
3336/// Returns true if this NLA targets an instance of a module (as opposed to
3337/// an instance's port or something inside an instance).
3338bool HierPathOp::isModule() { return !ref(); }
3339
3340/// Returns true if this NLA targets something inside a module (as opposed
3341/// to a module or an instance of a module);
3342bool HierPathOp::isComponent() { return (bool)ref(); }
3343
3344// Verify the HierPathOp.
3345// 1. Iterate over the namepath.
3346// 2. The namepath should be a valid instance path, specified either on a
3347// module or a declaration inside a module.
3348// 3. Each element in the namepath is an InnerRefAttr except possibly the
3349// last element.
3350// 4. Make sure that the InnerRefAttr is legal, by verifying the module name
3351// and the corresponding inner_sym on the instance.
3352// 5. Make sure that the instance path is legal, by verifying the sequence of
3353// instance and the expected module occurs as the next element in the path.
3354// 6. The last element of the namepath, can be an InnerRefAttr on either a
3355// module port or a declaration inside the module.
3356// 7. The last element of the namepath can also be a module symbol.
3357LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3358 ArrayAttr expectedModuleNames = {};
3359 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3360 if (!expectedModuleNames)
3361 return success();
3362 if (llvm::any_of(expectedModuleNames,
3363 [name](Attribute attr) { return attr == name; }))
3364 return success();
3365 auto diag = emitOpError() << "instance path is incorrect. Expected ";
3366 size_t n = expectedModuleNames.size();
3367 if (n != 1) {
3368 diag << "one of ";
3369 }
3370 for (size_t i = 0; i < n; ++i) {
3371 if (i != 0)
3372 diag << ((i + 1 == n) ? " or " : ", ");
3373 diag << cast<StringAttr>(expectedModuleNames[i]);
3374 }
3375 diag << ". Instead found: " << name;
3376 return diag;
3377 };
3378
3379 if (!getNamepath() || getNamepath().empty())
3380 return emitOpError() << "the instance path cannot be empty";
3381 for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3382 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3383 if (!innerRef)
3384 return emitOpError()
3385 << "the instance path can only contain inner sym reference"
3386 << ", only the leaf can refer to a module symbol";
3387
3388 if (failed(checkExpectedModule(innerRef.getModule())))
3389 return failure();
3390
3391 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3392 if (!instOp)
3393 return emitOpError() << " module: " << innerRef.getModule()
3394 << " does not contain any instance with symbol: "
3395 << innerRef.getName();
3396 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3397 }
3398
3399 // The instance path has been verified. Now verify the last element.
3400 auto leafRef = getNamepath()[getNamepath().size() - 1];
3401 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3402 if (!ns.lookup(innerRef)) {
3403 return emitOpError() << " operation with symbol: " << innerRef
3404 << " was not found ";
3405 }
3406 if (failed(checkExpectedModule(innerRef.getModule())))
3407 return failure();
3408 } else if (failed(checkExpectedModule(
3409 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3410 return failure();
3411 }
3412 return success();
3413}
3414
3415void HierPathOp::print(OpAsmPrinter &p) {
3416 p << " ";
3417
3418 // Print visibility if present.
3419 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3420 if (auto visibility =
3421 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3422 p << visibility.getValue() << ' ';
3423
3424 p.printSymbolName(getSymName());
3425 p << " [";
3426 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3427 if (auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3428 p.printSymbolName(ref.getModule().getValue());
3429 p << "::";
3430 p.printSymbolName(ref.getName().getValue());
3431 } else {
3432 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3433 }
3434 });
3435 p << "]";
3436 p.printOptionalAttrDict(
3437 (*this)->getAttrs(),
3438 {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName});
3439}
3440
3441ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3442 // Parse the visibility attribute.
3443 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3444
3445 // Parse the symbol name.
3446 StringAttr symName;
3447 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3448 result.attributes))
3449 return failure();
3450
3451 // Parse the namepath.
3452 SmallVector<Attribute> namepath;
3453 if (parser.parseCommaSeparatedList(
3454 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3455 auto loc = parser.getCurrentLocation();
3456 SymbolRefAttr ref;
3457 if (parser.parseAttribute(ref))
3458 return failure();
3459
3460 // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error.
3461 auto pathLength = ref.getNestedReferences().size();
3462 if (pathLength == 0)
3463 namepath.push_back(
3464 FlatSymbolRefAttr::get(ref.getRootReference()));
3465 else if (pathLength == 1)
3466 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3467 ref.getLeafReference()));
3468 else
3469 return parser.emitError(loc,
3470 "only one nested reference is allowed");
3471 return success();
3472 }))
3473 return failure();
3474 result.addAttribute("namepath",
3475 ArrayAttr::get(parser.getContext(), namepath));
3476
3477 if (parser.parseOptionalAttrDict(result.attributes))
3478 return failure();
3479
3480 return success();
3481}
3482
3483//===----------------------------------------------------------------------===//
3484// TriggeredOp
3485//===----------------------------------------------------------------------===//
3486
3487void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3488 EventControlAttr event, Value trigger,
3489 ValueRange inputs) {
3490 odsState.addOperands(trigger);
3491 odsState.addOperands(inputs);
3492 odsState.addAttribute(getEventAttrName(odsState.name), event);
3493 auto *r = odsState.addRegion();
3494 Block *b = new Block();
3495 r->push_back(b);
3496
3497 llvm::SmallVector<Location> argLocs;
3498 llvm::transform(inputs, std::back_inserter(argLocs),
3499 [&](Value v) { return v.getLoc(); });
3500 b->addArguments(inputs.getTypes(), argLocs);
3501}
3502
3503//===----------------------------------------------------------------------===//
3504// TableGen generated logic.
3505//===----------------------------------------------------------------------===//
3506
3507// Provide the autogenerated implementation guts for the Op classes.
3508#define GET_OP_CLASSES
3509#include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static LogicalResult verifyModuleCommon(HWModuleLike module)
Definition HWOps.cpp:1086
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
Definition HWOps.cpp:501
static LogicalResult canonicalizeArrayInjectChain(ArrayInjectOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:3009
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
Definition HWOps.cpp:1031
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2153
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:1883
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition HWOps.cpp:1428
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
Definition HWOps.cpp:84
FunctionType getHWModuleOpType(Operation *op)
Definition HWOps.cpp:1023
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2550
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
Definition HWOps.cpp:2114
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
Definition HWOps.cpp:1770
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo > > insertInputs, ArrayRef< std::pair< unsigned, PortInfo > > insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
Definition HWOps.cpp:694
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
Definition HWOps.cpp:69
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
Definition HWOps.cpp:901
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo > > insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
Definition HWOps.cpp:595
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:2170
static SmallVector< Location > getAllPortLocs(ModTy module)
Definition HWOps.cpp:1206
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
Definition HWOps.cpp:2513
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
Definition HWOps.cpp:1276
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
Definition HWOps.cpp:101
static void setHWModuleType(ModTy &mod, ModuleType type)
Definition HWOps.cpp:1349
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
Definition HWOps.cpp:493
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
Definition HWOps.cpp:407
static LogicalResult canonicalizeArrayInjectIntoCreate(ArrayInjectOp op, PatternRewriter &rewriter)
Definition HWOps.cpp:3060
static std::optional< uint64_t > getUIntFromValue(Value value)
Definition HWOps.cpp:1954
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
Definition HWOps.cpp:909
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
Definition HWOps.cpp:2487
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1448
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
Definition HWOps.cpp:1784
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
Definition HWOps.cpp:353
Delimiter
Definition HWOps.cpp:116
@ OptionalLessGreater
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
Definition HWOps.cpp:2084
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void setOutput(unsigned i, Value v)
Definition HWOps.cpp:241
Value getInput(unsigned i)
Definition HWOps.cpp:247
llvm::SmallVector< Value > outputOperands
Definition HWOps.h:119
llvm::SmallVector< Value > inputArgs
Definition HWOps.h:118
llvm::StringMap< unsigned > outputIdx
Definition HWOps.h:117
llvm::StringMap< unsigned > inputIdx
Definition HWOps.h:117
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
Definition HWOps.cpp:225
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
Definition HWVisitors.h:25
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Definition HWVisitors.h:27
ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any combinational operations that are not handled by the concrete visitor...
Definition HWVisitors.h:57
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any non-expression operations.
Definition HWVisitors.h:50
create(array_value, idx)
Definition hw.py:450
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1052
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
Definition HWOps.cpp:1859
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
Definition HWOps.h:124
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
Definition HWOps.cpp:529
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
Definition HWOps.cpp:202
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
Definition HWOps.cpp:523
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
Definition HWOps.cpp:547
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
Operation * lookupOp(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
PortDirectionRange getOutputs()
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.