16 #include "mlir/Dialect/CommonFolders.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/Region.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
32 using namespace circt;
// Constant-folds a unary operation on `operands[0]`, applying `calculate`
// to the element value(s).
// NOTE(review): this excerpt omits several original lines (the name/parameter
// line and the return statements). Per the signature listed near the end of
// the file this is
// `constFoldUnaryOp(ArrayRef<Attribute> operands, const CalculationT &calculate)`.
35 template <
class AttrElementT,
36 class ElementValueT =
typename AttrElementT::ValueType,
37 class CalculationT = function_ref<ElementValueT(ElementValueT)>>
39 const CalculationT &calculate) {
40 assert(operands.size() == 1 &&
"unary op takes one operand");
// Case 1: the operand is a scalar attribute of type AttrElementT. The folding
// statement for this case is not shown in this excerpt.
44 if (
auto val = dyn_cast<AttrElementT>(operands[0])) {
46 }
// Case 2: the operand is a splat, so `calculate` runs once on the splat value.
else if (
auto val = dyn_cast<SplatElementsAttr>(operands[0])) {
49 auto elementResult = calculate(val.getSplatValue<ElementValueT>());
// Case 3: a general ElementsAttr. `calculate` is applied to every element in
// iteration order and the results are collected into `elementResults`.
52 if (
auto val = dyn_cast<ElementsAttr>(operands[0])) {
55 auto valIt = val.getValues<ElementValueT>().begin();
56 SmallVector<ElementValueT, 4> elementResults;
57 elementResults.reserve(val.getNumElements());
58 for (
size_t i = 0, e = val.getNumElements(); i < e; ++i, ++valIt)
59 elementResults.push_back(calculate(*valIt));
// Constant-folds a ternary operation, applying `calculate` to corresponding
// element values of the three operands.
// NOTE(review): this excerpt omits several original lines. Per the signature
// listed near the end of the file this is
// `constFoldTernaryOp(ArrayRef<Attribute> operands, const CalculationT &calculate)`.
65 template <
class AttrElementT,
66 class ElementValueT =
typename AttrElementT::ValueType,
67 class CalculationT = function_ref<
68 ElementValueT(ElementValueT, ElementValueT, ElementValueT)>>
70 const CalculationT &calculate) {
71 assert(operands.size() == 3 &&
"ternary op takes three operands");
// Guard against a missing (null) constant operand. The early-exit statement
// itself is not shown in this excerpt.
72 if (!operands[0] || !operands[1] || !operands[2])
// All three operands are scalar AttrElementT attributes.
75 if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1]) &&
76 isa<AttrElementT>(operands[2])) {
77 auto fst = cast<AttrElementT>(operands[0]);
78 auto snd = cast<AttrElementT>(operands[1]);
79 auto trd = cast<AttrElementT>(operands[2]);
83 calculate(fst.getValue(), snd.getValue(), trd.getValue()));
// All three operands are splats, so `calculate` runs once on the splat values.
85 if (isa<SplatElementsAttr>(operands[0]) &&
86 isa<SplatElementsAttr>(operands[1]) &&
87 isa<SplatElementsAttr>(operands[2])) {
90 auto fst = cast<SplatElementsAttr>(operands[0]);
91 auto snd = cast<SplatElementsAttr>(operands[1]);
92 auto trd = cast<SplatElementsAttr>(operands[2]);
94 auto elementResult = calculate(fst.getSplatValue<ElementValueT>(),
95 snd.getSplatValue<ElementValueT>(),
96 trd.getSplatValue<ElementValueT>());
// General case: all three operands are ElementsAttr and are walked in lockstep.
// NOTE(review): the visible loop bound comes from `fst` only. Confirm that
// `snd` and `trd` are guaranteed the same element count (e.g. by the op's type
// constraints); otherwise their iterators could run past the end.
99 if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1]) &&
100 isa<ElementsAttr>(operands[2])) {
103 auto fst = cast<ElementsAttr>(operands[0]);
104 auto snd = cast<ElementsAttr>(operands[1]);
105 auto trd = cast<ElementsAttr>(operands[2]);
107 auto fstIt = fst.getValues<ElementValueT>().begin();
108 auto sndIt = snd.getValues<ElementValueT>().begin();
109 auto trdIt = trd.getValues<ElementValueT>().begin();
110 SmallVector<ElementValueT, 4> elementResults;
111 elementResults.reserve(fst.getNumElements());
112 for (
size_t i = 0, e = fst.getNumElements(); i < e;
113 ++i, ++fstIt, ++sndIt, ++trdIt)
114 elementResults.push_back(calculate(*fstIt, *sndIt, *trdIt));
// Op matcher whose visible condition first binds the op's constant integer
// value through mlir::detail::constant_int_value_binder. The rest of the
// condition (presumably an all-ones test, per the struct name) and the
// declaration of `value` are not part of this excerpt.
122 struct constant_int_all_ones_matcher {
123 bool match(Operation *op) {
125 return mlir::detail::constant_int_value_binder(&value).match(op) &&
// Presumably the body of `getLLHDTypeWidth(Type type)`. Its signature is listed
// near the end of the file; the header line is not part of this excerpt.
// First strip one level of hw::InOutType or llhd::PtrType wrapping.
133 if (
auto sig = dyn_cast<hw::InOutType>(type))
134 type = sig.getElementType();
135 else if (
auto ptr = dyn_cast<llhd::PtrType>(type))
136 type = ptr.getElementType();
// Aggregates report their element or field count, not a bit width.
137 if (
auto array = dyn_cast<hw::ArrayType>(type))
138 return array.getNumElements();
139 if (
auto tup = dyn_cast<hw::StructType>(type))
140 return tup.getElements().size();
// Any other type is expected to be an integer or float type, since
// getIntOrFloatBitWidth() is only valid for those.
141 return type.getIntOrFloatBitWidth();
// Presumably the body of `getLLHDElementType(Type type)`. Its signature is
// listed near the end of the file; the header and the final fallback return
// are not part of this excerpt.
// Strip one level of hw::InOutType or llhd::PtrType wrapping.
145 if (
auto sig = dyn_cast<hw::InOutType>(type))
146 type = sig.getElementType();
147 else if (
auto ptr = dyn_cast<llhd::PtrType>(type))
148 type = ptr.getElementType();
// If the unwrapped type is an array, its element type is the result.
149 if (
auto array = dyn_cast<hw::ArrayType>(type))
150 return array.getElementType();
/// Folds the constant-time op to its own `value` attribute. The op takes no
/// operands, which the assert checks.
162 OpFoldResult llhd::ConstantTimeOp::fold(FoldAdaptor adaptor) {
163 assert(adaptor.getOperands().empty() &&
"const has no operands");
164 return getValueAttr();
/// Builder overload that creates a TimeAttr from (time, timeUnit, delta,
/// epsilon) in the builder's context.
/// NOTE(review): the lines that attach `attr` to `result` are not part of this
/// excerpt.
167 void llhd::ConstantTimeOp::build(OpBuilder &builder, OperationState &result,
168 unsigned time,
const StringRef &timeUnit,
169 unsigned delta,
unsigned epsilon) {
170 auto *ctx = builder.getContext();
171 auto attr =
TimeAttr::get(ctx, time, timeUnit, delta, epsilon);
// Presumably part of foldSigPtrExtractOp (signature listed near the end of the
// file). The surrounding lines are not in this excerpt.
// If the result is as wide as the input and the low index (operands[1]) is the
// constant 0, the op folds to its input.
// NOTE(review): cast<IntegerAttr>(operands[1]) asserts when operand 1 is not a
// constant (null attribute). Confirm that an earlier, unshown line rejects a
// null operands[1].
186 if (op.getResultWidth() == op.getInputWidth() &&
187 cast<IntegerAttr>(operands[1]).getValue().isZero())
188 return op.getInput();
193 OpFoldResult llhd::SigExtractOp::fold(FoldAdaptor adaptor) {
197 OpFoldResult llhd::PtrExtractOp::fold(FoldAdaptor adaptor) {
// Presumably the parameter line and body of foldSigPtrArraySliceOp (signature
// listed near the end of the file). Some lines are not in this excerpt.
207 ArrayRef<Attribute> operands) {
// If the result is as wide as the input and the low index (operands[1]) is the
// constant 0, the op folds to its input.
// NOTE(review): cast<IntegerAttr>(operands[1]) asserts when operand 1 is not a
// constant (null attribute). Confirm that an unshown line rejects that case.
212 if (op.getResultWidth() == op.getInputWidth() &&
213 cast<IntegerAttr>(operands[1]).getValue().isZero())
214 return op.getInput();
219 OpFoldResult llhd::SigArraySliceOp::fold(FoldAdaptor adaptor) {
223 OpFoldResult llhd::PtrArraySliceOp::fold(FoldAdaptor adaptor) {
// Presumably part of canonicalizeSigPtrArraySliceOp (signature listed near the
// end of the file). Some lines, including the declaration of `a` and of
// `newIndex`, are not in this excerpt.
// Pattern: when `op`'s low index is a constant, and its input is produced by
// another `Op` whose low index is also a constant, `op` is rewritten in place.
// It then slices that inner op's input directly, at the sum of the two
// constant indices.
229 PatternRewriter &rewriter) {
230 IntegerAttr indexAttr;
231 if (!matchPattern(op.getLowIndex(), m_Constant(&indexAttr)))
// The input must be an `Op` with any first operand and a constant second
// operand, which is bound to `a`.
237 if (matchPattern(op.getInput(),
238 m_Op<Op>(matchers::m_Any(), m_Constant(&a)))) {
239 auto sliceOp = op.getInput().template getDefiningOp<Op>();
// NOTE(review): APInt operator+ requires both operands to have the same bit
// width and wraps on overflow. Confirm both index widths always match here.
240 rewriter.modifyOpInPlace(op, [&]() {
241 op.getInputMutable().assign(sliceOp.getInput());
243 op->getLoc(), a.getValue() + indexAttr.getValue());
244 op.getLowIndexMutable().assign(newIndex);
254 PatternRewriter &rewriter) {
259 PatternRewriter &rewriter) {
// Shared return-type inference for struct-extract ops, parameterized on the
// wrapper type of operand 0 (this file instantiates it with hw::InOutType and
// llhd::PtrType). It looks up the field named by the "field" string attribute
// in the hw::StructType wrapped by operand 0.
// Some lines (the result declaration, the failure check and the returns) are
// not part of this excerpt.
// NOTE(review): attrs.getNamed("field") returns an optional that is
// dereferenced without a check. Confirm "field" is always present in `attrs`.
267 template <
class SigPtrType>
269 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
270 DictionaryAttr attrs, mlir::OpaqueProperties properties,
271 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
273 cast<hw::StructType>(
274 cast<SigPtrType>(operands[0].getType()).getElementType())
276 cast<StringAttr>(attrs.getNamed(
"field")->getValue()).getValue());
// When the field lookup fails (the condition is not shown), an error is emitted
// at `loc`, falling back to UnknownLoc().
// NOTE(review): `UnknownLoc()` default-constructs a null location.
// `UnknownLoc::get(context)` is the usual way to get a valid unknown location;
// confirm the diagnostic engine tolerates a null location here.
278 context->getDiagEngine().emit(loc.value_or(UnknownLoc()),
279 DiagnosticSeverity::Error)
280 <<
"invalid field name specified";
// Parameter list and body of the inferReturnTypes hook for the struct-extract
// op over hw::InOutType operands. The op/method name line is not part of this
// excerpt. All work is delegated to the shared template.
288 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
289 DictionaryAttr attrs, mlir::OpaqueProperties properties,
290 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
291 return inferReturnTypesOfStructExtractOp<hw::InOutType>(
292 context, loc, operands, attrs, properties, regions, results);
// Parameter list and body of the inferReturnTypes hook for the struct-extract
// op over llhd::PtrType operands. The op/method name line is not part of this
// excerpt. All work is delegated to the shared template.
296 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
297 DictionaryAttr attrs, mlir::OpaqueProperties properties,
298 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
299 return inferReturnTypesOfStructExtractOp<llhd::PtrType>(
300 context, loc, operands, attrs, properties, regions, results);
/// In-place fold for the drive op. If the enable operand is the constant 1, the
/// enable operand is cleared. The remaining lines, including the returns, are
/// not part of this excerpt.
307 LogicalResult llhd::DrvOp::fold(FoldAdaptor adaptor,
308 SmallVectorImpl<OpFoldResult> &result) {
312 if (matchPattern(getEnable(), m_One())) {
313 getEnableMutable().clear();
// Canonicalization pattern (presumably for DrvOp, given its position after
// DrvOp::fold). A drive whose enable operand is the constant 0 is erased.
// The header and return statements are not part of this excerpt.
321 PatternRewriter &rewriter) {
325 if (matchPattern(op.getEnable(), m_Zero())) {
326 rewriter.eraseOp(op);
/// Returns the operands forwarded to successor `index`. Only index 0 is valid,
/// which the assert checks, and its operands are the mutable `destOps` range.
338 SuccessorOperands llhd::WaitOp::getSuccessorOperands(
unsigned index) {
339 assert(index == 0 &&
"invalid successor index");
340 return SuccessorOperands(getDestOpsMutable());
// Canonicalization that erases `op` when its lhs and rhs operands are the same
// value. The op type and the return statements are not part of this excerpt.
// NOTE(review): confirm the pattern reports success only when it actually
// erased the op.
348 PatternRewriter &rewriter) {
349 if (op.getLhs() == op.getRhs())
350 rewriter.eraseOp(op);
/// Custom parser for llhd.reg. The grammar visible in this excerpt is:
///   %signal ( `,` `(` %value `,` "mode" %trigger `after` %delay
///             (`if` %gate)? `:` value-type `)` )* attr-dict `:` signal-type
/// Modes are stored as an i64 array attribute "modes". "gateMask" holds, for
/// each entry, the 1-based index of its gate, with a 0 entry presumably meaning
/// "no gate". "operandSegmentSizes" holds
/// [1, #values, #triggers, #delays, #gates].
/// Some original lines (error returns, a few declarations) are not part of this
/// excerpt.
358 ParseResult llhd::RegOp::parse(OpAsmParser &parser, OperationState &result) {
359 OpAsmParser::UnresolvedOperand signal;
361 SmallVector<OpAsmParser::UnresolvedOperand, 8> valueOperands;
362 SmallVector<OpAsmParser::UnresolvedOperand, 8> triggerOperands;
363 SmallVector<OpAsmParser::UnresolvedOperand, 8> delayOperands;
364 SmallVector<OpAsmParser::UnresolvedOperand, 8> gateOperands;
365 SmallVector<Type, 8> valueTypes;
366 llvm::SmallVector<int64_t, 8> modesArray;
367 llvm::SmallVector<int64_t, 8> gateMask;
368 int64_t gateCount = 0;
// The register's signal operand comes first.
370 if (parser.parseOperand(signal))
// Each following `, (...)` group is one trigger entry.
372 while (succeeded(parser.parseOptionalComma())) {
373 OpAsmParser::UnresolvedOperand value;
374 OpAsmParser::UnresolvedOperand trigger;
375 OpAsmParser::UnresolvedOperand delay;
376 OpAsmParser::UnresolvedOperand gate;
379 NamedAttrList attrStorage;
381 if (parser.parseLParen())
383 if (parser.parseOperand(value) || parser.parseComma())
385 if (parser.parseAttribute(modeAttr, parser.getBuilder().getNoneType(),
386 "modes", attrStorage))
// The mode string is mapped to llhd::RegMode and stored as its integer value.
388 auto attrOptional = llhd::symbolizeRegMode(modeAttr.getValue());
390 return parser.emitError(parser.getCurrentLocation(),
391 "invalid string attribute");
392 modesArray.push_back(
static_cast<int64_t
>(*attrOptional));
393 if (parser.parseOperand(trigger))
395 if (parser.parseKeyword(
"after") || parser.parseOperand(delay))
// Optional `if %gate`: a gated entry records its 1-based gate number.
397 if (succeeded(parser.parseOptionalKeyword(
"if"))) {
398 gateMask.push_back(++gateCount);
399 if (parser.parseOperand(gate))
401 gateOperands.push_back(gate);
403 gateMask.push_back(0);
// Each entry ends with `: value-type )`.
405 if (parser.parseColon() || parser.parseType(valueType) ||
406 parser.parseRParen())
408 valueOperands.push_back(value);
409 triggerOperands.push_back(trigger);
410 delayOperands.push_back(delay);
411 valueTypes.push_back(valueType);
// Trailer: optional attribute dictionary, then `: signal-type`.
413 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
414 parser.parseType(signalType))
// Resolve operands in segment order: signal, values, triggers (i1), delays,
// then gates (i1).
416 if (parser.resolveOperand(signal, signalType, result.operands))
418 if (parser.resolveOperands(valueOperands, valueTypes,
419 parser.getCurrentLocation(), result.operands))
421 for (
auto operand : triggerOperands)
422 if (parser.resolveOperand(operand, parser.getBuilder().getI1Type(),
425 for (
auto operand : delayOperands)
426 if (parser.resolveOperand(
430 for (
auto operand : gateOperands)
431 if (parser.resolveOperand(operand, parser.getBuilder().getI1Type(),
// Attach the bookkeeping attributes collected during parsing.
434 result.addAttribute(
"gateMask",
435 parser.getBuilder().getI64ArrayAttr(gateMask));
436 result.addAttribute(
"modes", parser.getBuilder().getI64ArrayAttr(modesArray));
437 llvm::SmallVector<int32_t, 5> operandSizes;
438 operandSizes.push_back(1);
439 operandSizes.push_back(valueOperands.size());
440 operandSizes.push_back(triggerOperands.size());
441 operandSizes.push_back(delayOperands.size());
442 operandSizes.push_back(gateOperands.size());
443 result.addAttribute(
"operandSegmentSizes",
444 parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
/// Custom printer for llhd.reg. It emits the signal, then one
/// `, (value, "mode" trigger after delay [if gate] : type)` group per entry,
/// the attribute dictionary without the parser-managed "modes", "gateMask" and
/// "operandSegmentSizes" attributes, and finally `: signal-type`.
/// Some original lines (the mode-validity check, the gate condition and the
/// closing braces) are not part of this excerpt.
449 void llhd::RegOp::print(OpAsmPrinter &printer) {
450 printer <<
" " << getSignal();
451 for (
size_t i = 0, e = getValues().size(); i < e; ++i) {
// Decode the stored integer mode back into an llhd::RegMode.
452 std::optional<llhd::RegMode> mode = llhd::symbolizeRegMode(
453 cast<IntegerAttr>(getModes().getValue()[i]).
getInt());
// An undecodable mode is reported via emitError. The surrounding condition and
// early exit are not part of this excerpt.
455 emitError(
"invalid RegMode");
458 printer <<
", (" << getValues()[i] <<
", \""
459 << llhd::stringifyRegMode(*mode) <<
"\" " << getTriggers()[i]
460 <<
" after " << getDelays()[i];
// The gate clause is printed under a condition that is not part of this
// excerpt (presumably only for gated entries).
462 printer <<
" if " << getGateAt(i);
463 printer <<
" : " << getValues()[i].getType() <<
")";
465 printer.printOptionalAttrDict((*this)->getAttrs(),
466 {
"modes",
"gateMask",
"operandSegmentSizes"});
467 printer <<
" : " << getSignal().getType();
// Presumably the body of RegOp's verifier. The header line and several
// interior lines are not part of this excerpt.
// At least one trigger entry is required.
472 if (getTriggers().size() < 1)
473 return emitError(
"At least one trigger quadruple has to be present.");
// Each per-entry array (values, delays, modes, gateMask) must have exactly one
// element per trigger.
// NOTE(review): all four diagnostics below contain the fragment " modes, but "
// even when the mismatch is in 'values', 'delays' or 'gateMask'. This looks
// like copy-paste; confirm against the full message text.
476 if (getValues().size() != getTriggers().size())
477 return emitOpError(
"Number of 'values' is not equal to the number of "
479 << getValues().size() <<
" modes, but " << getTriggers().size()
483 if (getDelays().size() != getTriggers().size())
484 return emitOpError(
"Number of 'delays' is not equal to the number of "
486 << getDelays().size() <<
" modes, but " << getTriggers().size()
491 if (getModes().size() != getTriggers().size())
492 return emitOpError(
"Number of 'modes' is not equal to the number of "
494 << getModes().size() <<
" modes, but " << getTriggers().size()
499 if (getGateMask().size() != getTriggers().size())
500 return emitOpError(
"Size of 'gateMask' is not equal to the size of "
502 << getGateMask().size() <<
" modes, but " << getTriggers().size()
// gateMask check: entries must be non-negative, and the non-zero entries must
// appear as consecutive integers starting at 1 (prevElement starts at 0 and is
// pre-incremented before each comparison).
// NOTE(review): that visible check admits 1..<number of gates>, yet the error
// message says "number of gates minus one". Confirm which is intended and
// align the message.
508 unsigned counter = 0;
509 unsigned prevElement = 0;
510 for (Attribute maskElem : getGateMask().getValue()) {
511 int64_t val = cast<IntegerAttr>(maskElem).getInt();
513 return emitError(
"Element in 'gateMask' must not be negative!");
516 if (val != ++prevElement)
518 "'gateMask' has to contain every number from 1 to the "
519 "number of gates minus one exactly once in increasing order "
520 "(may have zeros in-between).");
// The count of non-zero gateMask entries must equal the number of gate operands.
523 if (getGates().size() != counter)
524 return emitError(
"The number of non-zero elements in 'gateMask' and the "
525 "size of the 'gates' variadic have to match.");
// Each value's type must be either the signal's type or the signal's
// underlying (hw::InOutType element) type.
// NOTE(review): cast<hw::InOutType> asserts if the signal is not an InOut type.
// Confirm the op's type constraints guarantee it.
529 for (
auto val : getValues()) {
530 if (val.getType() != getSignal().getType() &&
532 cast<hw::InOutType>(getSignal().getType()).getElementType()) {
534 "type of each 'value' has to be either the same as the "
535 "type of 'signal' or the underlying type of 'signal'");
541 #include "circt/Dialect/LLHD/IR/LLHDEnums.cpp.inc"
543 #define GET_OP_CLASSES
544 #include "circt/Dialect/LLHD/IR/LLHD.cpp.inc"
assert(baseType &&"element must be base type")
static Attribute constFoldTernaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
static LogicalResult inferReturnTypesOfStructExtractOp(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results)
static OpFoldResult foldSigPtrArraySliceOp(Op op, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSigPtrArraySliceOp(Op op, PatternRewriter &rewriter)
static Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
static OpFoldResult foldSigPtrExtractOp(Op op, ArrayRef< Attribute > operands)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
unsigned getLLHDTypeWidth(Type type)
Type getLLHDElementType(Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.