CIRCT  19.0.0git
LLHDOps.cpp
Go to the documentation of this file.
1 //===- LLHDOps.cpp - Implement the LLHD operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the LLHD operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "circt/Dialect/HW/HWOps.h"
16 #include "mlir/Dialect/CommonFolders.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/OpImplementation.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/Region.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 
32 using namespace circt;
33 using namespace mlir;
34 
35 template <class AttrElementT,
36  class ElementValueT = typename AttrElementT::ValueType,
37  class CalculationT = function_ref<ElementValueT(ElementValueT)>>
38 static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
39  const CalculationT &calculate) {
40  assert(operands.size() == 1 && "unary op takes one operand");
41  if (!operands[0])
42  return {};
43 
44  if (auto val = dyn_cast<AttrElementT>(operands[0])) {
45  return AttrElementT::get(val.getType(), calculate(val.getValue()));
46  } else if (auto val = dyn_cast<SplatElementsAttr>(operands[0])) {
47  // Operand is a splat so we can avoid expanding the value out and
48  // just fold based on the splat value.
49  auto elementResult = calculate(val.getSplatValue<ElementValueT>());
50  return DenseElementsAttr::get(val.getType(), elementResult);
51  }
52  if (auto val = dyn_cast<ElementsAttr>(operands[0])) {
53  // Operand is ElementsAttr-derived; perform an element-wise fold by
54  // expanding the values.
55  auto valIt = val.getValues<ElementValueT>().begin();
56  SmallVector<ElementValueT, 4> elementResults;
57  elementResults.reserve(val.getNumElements());
58  for (size_t i = 0, e = val.getNumElements(); i < e; ++i, ++valIt)
59  elementResults.push_back(calculate(*valIt));
60  return DenseElementsAttr::get(val.getType(), elementResults);
61  }
62  return {};
63 }
64 
65 template <class AttrElementT,
66  class ElementValueT = typename AttrElementT::ValueType,
67  class CalculationT = function_ref<
68  ElementValueT(ElementValueT, ElementValueT, ElementValueT)>>
69 static Attribute constFoldTernaryOp(ArrayRef<Attribute> operands,
70  const CalculationT &calculate) {
71  assert(operands.size() == 3 && "ternary op takes three operands");
72  if (!operands[0] || !operands[1] || !operands[2])
73  return {};
74 
75  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1]) &&
76  isa<AttrElementT>(operands[2])) {
77  auto fst = cast<AttrElementT>(operands[0]);
78  auto snd = cast<AttrElementT>(operands[1]);
79  auto trd = cast<AttrElementT>(operands[2]);
80 
81  return AttrElementT::get(
82  fst.getType(),
83  calculate(fst.getValue(), snd.getValue(), trd.getValue()));
84  }
85  if (isa<SplatElementsAttr>(operands[0]) &&
86  isa<SplatElementsAttr>(operands[1]) &&
87  isa<SplatElementsAttr>(operands[2])) {
88  // Operands are splats so we can avoid expanding the values out and
89  // just fold based on the splat value.
90  auto fst = cast<SplatElementsAttr>(operands[0]);
91  auto snd = cast<SplatElementsAttr>(operands[1]);
92  auto trd = cast<SplatElementsAttr>(operands[2]);
93 
94  auto elementResult = calculate(fst.getSplatValue<ElementValueT>(),
95  snd.getSplatValue<ElementValueT>(),
96  trd.getSplatValue<ElementValueT>());
97  return DenseElementsAttr::get(fst.getType(), elementResult);
98  }
99  if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1]) &&
100  isa<ElementsAttr>(operands[2])) {
101  // Operands are ElementsAttr-derived; perform an element-wise fold by
102  // expanding the values.
103  auto fst = cast<ElementsAttr>(operands[0]);
104  auto snd = cast<ElementsAttr>(operands[1]);
105  auto trd = cast<ElementsAttr>(operands[2]);
106 
107  auto fstIt = fst.getValues<ElementValueT>().begin();
108  auto sndIt = snd.getValues<ElementValueT>().begin();
109  auto trdIt = trd.getValues<ElementValueT>().begin();
110  SmallVector<ElementValueT, 4> elementResults;
111  elementResults.reserve(fst.getNumElements());
112  for (size_t i = 0, e = fst.getNumElements(); i < e;
113  ++i, ++fstIt, ++sndIt, ++trdIt)
114  elementResults.push_back(calculate(*fstIt, *sndIt, *trdIt));
115  return DenseElementsAttr::get(fst.getType(), elementResults);
116  }
117  return {};
118 }
119 
120 namespace {
121 
122 struct constant_int_all_ones_matcher {
123  bool match(Operation *op) {
124  APInt value;
125  return mlir::detail::constant_int_value_binder(&value).match(op) &&
126  value.isAllOnes();
127  }
128 };
129 
130 } // anonymous namespace
131 
132 unsigned circt::llhd::getLLHDTypeWidth(Type type) {
133  if (auto sig = dyn_cast<hw::InOutType>(type))
134  type = sig.getElementType();
135  else if (auto ptr = dyn_cast<llhd::PtrType>(type))
136  type = ptr.getElementType();
137  if (auto array = dyn_cast<hw::ArrayType>(type))
138  return array.getNumElements();
139  if (auto tup = dyn_cast<hw::StructType>(type))
140  return tup.getElements().size();
141  return type.getIntOrFloatBitWidth();
142 }
143 
145  if (auto sig = dyn_cast<hw::InOutType>(type))
146  type = sig.getElementType();
147  else if (auto ptr = dyn_cast<llhd::PtrType>(type))
148  type = ptr.getElementType();
149  if (auto array = dyn_cast<hw::ArrayType>(type))
150  return array.getElementType();
151  return type;
152 }
153 
154 //===---------------------------------------------------------------------===//
155 // LLHD Operations
156 //===---------------------------------------------------------------------===//
157 
158 //===----------------------------------------------------------------------===//
159 // ConstantTimeOp
160 //===----------------------------------------------------------------------===//
161 
162 OpFoldResult llhd::ConstantTimeOp::fold(FoldAdaptor adaptor) {
163  assert(adaptor.getOperands().empty() && "const has no operands");
164  return getValueAttr();
165 }
166 
167 void llhd::ConstantTimeOp::build(OpBuilder &builder, OperationState &result,
168  unsigned time, const StringRef &timeUnit,
169  unsigned delta, unsigned epsilon) {
170  auto *ctx = builder.getContext();
171  auto attr = TimeAttr::get(ctx, time, timeUnit, delta, epsilon);
172  return build(builder, result, TimeType::get(ctx), attr);
173 }
174 
175 //===----------------------------------------------------------------------===//
176 // SigExtractOp and PtrExtractOp
177 //===----------------------------------------------------------------------===//
178 
179 template <class Op>
180 static OpFoldResult foldSigPtrExtractOp(Op op, ArrayRef<Attribute> operands) {
181 
182  if (!operands[1])
183  return nullptr;
184 
185  // llhd.sig.extract(input, 0) with inputWidth == resultWidth => input
186  if (op.getResultWidth() == op.getInputWidth() &&
187  cast<IntegerAttr>(operands[1]).getValue().isZero())
188  return op.getInput();
189 
190  return nullptr;
191 }
192 
193 OpFoldResult llhd::SigExtractOp::fold(FoldAdaptor adaptor) {
194  return foldSigPtrExtractOp(*this, adaptor.getOperands());
195 }
196 
197 OpFoldResult llhd::PtrExtractOp::fold(FoldAdaptor adaptor) {
198  return foldSigPtrExtractOp(*this, adaptor.getOperands());
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // SigArraySliceOp and PtrArraySliceOp
203 //===----------------------------------------------------------------------===//
204 
205 template <class Op>
206 static OpFoldResult foldSigPtrArraySliceOp(Op op,
207  ArrayRef<Attribute> operands) {
208  if (!operands[1])
209  return nullptr;
210 
211  // llhd.sig.array_slice(input, 0) with inputWidth == resultWidth => input
212  if (op.getResultWidth() == op.getInputWidth() &&
213  cast<IntegerAttr>(operands[1]).getValue().isZero())
214  return op.getInput();
215 
216  return nullptr;
217 }
218 
219 OpFoldResult llhd::SigArraySliceOp::fold(FoldAdaptor adaptor) {
220  return foldSigPtrArraySliceOp(*this, adaptor.getOperands());
221 }
222 
223 OpFoldResult llhd::PtrArraySliceOp::fold(FoldAdaptor adaptor) {
224  return foldSigPtrArraySliceOp(*this, adaptor.getOperands());
225 }
226 
227 template <class Op>
228 static LogicalResult canonicalizeSigPtrArraySliceOp(Op op,
229  PatternRewriter &rewriter) {
230  IntegerAttr indexAttr;
231  if (!matchPattern(op.getLowIndex(), m_Constant(&indexAttr)))
232  return failure();
233 
234  // llhd.sig.array_slice(llhd.sig.array_slice(target, a), b)
235  // => llhd.sig.array_slice(target, a+b)
236  IntegerAttr a;
237  if (matchPattern(op.getInput(),
238  m_Op<Op>(matchers::m_Any(), m_Constant(&a)))) {
239  auto sliceOp = op.getInput().template getDefiningOp<Op>();
240  rewriter.modifyOpInPlace(op, [&]() {
241  op.getInputMutable().assign(sliceOp.getInput());
242  Value newIndex = rewriter.create<hw::ConstantOp>(
243  op->getLoc(), a.getValue() + indexAttr.getValue());
244  op.getLowIndexMutable().assign(newIndex);
245  });
246 
247  return success();
248  }
249 
250  return failure();
251 }
252 
253 LogicalResult llhd::SigArraySliceOp::canonicalize(llhd::SigArraySliceOp op,
254  PatternRewriter &rewriter) {
255  return canonicalizeSigPtrArraySliceOp(op, rewriter);
256 }
257 
258 LogicalResult llhd::PtrArraySliceOp::canonicalize(llhd::PtrArraySliceOp op,
259  PatternRewriter &rewriter) {
260  return canonicalizeSigPtrArraySliceOp(op, rewriter);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // SigStructExtractOp and PtrStructExtractOp
265 //===----------------------------------------------------------------------===//
266 
267 template <class SigPtrType>
269  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
270  DictionaryAttr attrs, mlir::OpaqueProperties properties,
271  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
272  Type type =
273  cast<hw::StructType>(
274  cast<SigPtrType>(operands[0].getType()).getElementType())
275  .getFieldType(
276  cast<StringAttr>(attrs.getNamed("field")->getValue()).getValue());
277  if (!type) {
278  context->getDiagEngine().emit(loc.value_or(UnknownLoc()),
279  DiagnosticSeverity::Error)
280  << "invalid field name specified";
281  return failure();
282  }
283  results.push_back(SigPtrType::get(type));
284  return success();
285 }
286 
    MLIRContext *context, std::optional<Location> loc, ValueRange operands,
    DictionaryAttr attrs, mlir::OpaqueProperties properties,
    mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
  // Extracting a field from a struct signal yields a signal (hw::InOutType)
  // of the field's type. The shared helper performs the lookup and reports
  // an invalid field name.
  return inferReturnTypesOfStructExtractOp<hw::InOutType>(
      context, loc, operands, attrs, properties, regions, results);
}
294 
    MLIRContext *context, std::optional<Location> loc, ValueRange operands,
    DictionaryAttr attrs, mlir::OpaqueProperties properties,
    mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
  // Extracting a field through a struct pointer yields a pointer
  // (llhd::PtrType) to the field's type. The shared helper performs the
  // lookup and reports an invalid field name.
  return inferReturnTypesOfStructExtractOp<llhd::PtrType>(
      context, loc, operands, attrs, properties, regions, results);
}
302 
303 //===----------------------------------------------------------------------===//
304 // DrvOp
305 //===----------------------------------------------------------------------===//
306 
307 LogicalResult llhd::DrvOp::fold(FoldAdaptor adaptor,
308  SmallVectorImpl<OpFoldResult> &result) {
309  if (!getEnable())
310  return failure();
311 
312  if (matchPattern(getEnable(), m_One())) {
313  getEnableMutable().clear();
314  return success();
315  }
316 
317  return failure();
318 }
319 
320 LogicalResult llhd::DrvOp::canonicalize(llhd::DrvOp op,
321  PatternRewriter &rewriter) {
322  if (!op.getEnable())
323  return failure();
324 
325  if (matchPattern(op.getEnable(), m_Zero())) {
326  rewriter.eraseOp(op);
327  return success();
328  }
329 
330  return failure();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // WaitOp
335 //===----------------------------------------------------------------------===//
336 
337 // Implement this operation for the BranchOpInterface
338 SuccessorOperands llhd::WaitOp::getSuccessorOperands(unsigned index) {
339  assert(index == 0 && "invalid successor index");
340  return SuccessorOperands(getDestOpsMutable());
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // ConnectOp
345 //===----------------------------------------------------------------------===//
346 
347 LogicalResult llhd::ConnectOp::canonicalize(llhd::ConnectOp op,
348  PatternRewriter &rewriter) {
349  if (op.getLhs() == op.getRhs())
350  rewriter.eraseOp(op);
351  return success();
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // RegOp
356 //===----------------------------------------------------------------------===//
357 
/// Custom parser for RegOp.
///
/// The accepted form is: a signal operand, then zero or more comma-separated
/// trigger tuples, then an optional attribute dictionary and
/// `: signalType`. Each tuple has the form
///   `(value, "mode" trigger after delay [if gate] : valueType)`.
/// Trigger and gate operands are resolved as i1, and delays as
/// llhd::TimeType. The `modes`, `gateMask` and `operandSegmentSizes`
/// attributes are synthesized from the tuples.
ParseResult llhd::RegOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand signal;
  Type signalType;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> valueOperands;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> triggerOperands;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> delayOperands;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> gateOperands;
  SmallVector<Type, 8> valueTypes;
  // One RegMode enum value (as int64_t) per trigger tuple.
  llvm::SmallVector<int64_t, 8> modesArray;
  // One entry per trigger tuple: 0 when the tuple has no gate, otherwise the
  // 1-based position of its gate among the parsed gates.
  llvm::SmallVector<int64_t, 8> gateMask;
  int64_t gateCount = 0;

  if (parser.parseOperand(signal))
    return failure();
  // Each iteration parses one parenthesized trigger tuple.
  while (succeeded(parser.parseOptionalComma())) {
    OpAsmParser::UnresolvedOperand value;
    OpAsmParser::UnresolvedOperand trigger;
    OpAsmParser::UnresolvedOperand delay;
    OpAsmParser::UnresolvedOperand gate;
    Type valueType;
    StringAttr modeAttr;
    // Scratch list required by parseAttribute; its contents are discarded.
    NamedAttrList attrStorage;

    if (parser.parseLParen())
      return failure();
    if (parser.parseOperand(value) || parser.parseComma())
      return failure();
    // The mode is written as a string and converted to its RegMode value.
    if (parser.parseAttribute(modeAttr, parser.getBuilder().getNoneType(),
                              "modes", attrStorage))
      return failure();
    auto attrOptional = llhd::symbolizeRegMode(modeAttr.getValue());
    if (!attrOptional)
      return parser.emitError(parser.getCurrentLocation(),
                              "invalid string attribute");
    modesArray.push_back(static_cast<int64_t>(*attrOptional));
    if (parser.parseOperand(trigger))
      return failure();
    if (parser.parseKeyword("after") || parser.parseOperand(delay))
      return failure();
    // Optional `if gate` clause; record which tuples are gated.
    if (succeeded(parser.parseOptionalKeyword("if"))) {
      gateMask.push_back(++gateCount);
      if (parser.parseOperand(gate))
        return failure();
      gateOperands.push_back(gate);
    } else {
      gateMask.push_back(0);
    }
    if (parser.parseColon() || parser.parseType(valueType) ||
        parser.parseRParen())
      return failure();
    valueOperands.push_back(value);
    triggerOperands.push_back(trigger);
    delayOperands.push_back(delay);
    valueTypes.push_back(valueType);
  }
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(signalType))
    return failure();
  // Operands are resolved in segment order: signal, values, triggers,
  // delays, gates.
  if (parser.resolveOperand(signal, signalType, result.operands))
    return failure();
  if (parser.resolveOperands(valueOperands, valueTypes,
                             parser.getCurrentLocation(), result.operands))
    return failure();
  for (auto operand : triggerOperands)
    if (parser.resolveOperand(operand, parser.getBuilder().getI1Type(),
                              result.operands))
      return failure();
  for (auto operand : delayOperands)
    if (parser.resolveOperand(
            operand, llhd::TimeType::get(parser.getBuilder().getContext()),
            result.operands))
      return failure();
  for (auto operand : gateOperands)
    if (parser.resolveOperand(operand, parser.getBuilder().getI1Type(),
                              result.operands))
      return failure();
  result.addAttribute("gateMask",
                      parser.getBuilder().getI64ArrayAttr(gateMask));
  result.addAttribute("modes", parser.getBuilder().getI64ArrayAttr(modesArray));
  // Segment sizes must follow the resolution order above.
  llvm::SmallVector<int32_t, 5> operandSizes;
  operandSizes.push_back(1);
  operandSizes.push_back(valueOperands.size());
  operandSizes.push_back(triggerOperands.size());
  operandSizes.push_back(delayOperands.size());
  operandSizes.push_back(gateOperands.size());
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(operandSizes));

  return success();
}
448 
/// Custom printer for RegOp, mirroring RegOp::parse.
///
/// Each trigger tuple is printed as
///   `, (value, "mode" trigger after delay [if gate] : valueType)`.
/// The `modes`, `gateMask` and `operandSegmentSizes` attributes are elided
/// because the tuple syntax encodes them.
void llhd::RegOp::print(OpAsmPrinter &printer) {
  printer << " " << getSignal();
  for (size_t i = 0, e = getValues().size(); i < e; ++i) {
    std::optional<llhd::RegMode> mode = llhd::symbolizeRegMode(
        cast<IntegerAttr>(getModes().getValue()[i]).getInt());
    // NOTE(review): an unknown mode aborts printing mid-op and leaves partial
    // output; presumably the verifier is expected to reject such ops first.
    if (!mode) {
      emitError("invalid RegMode");
      return;
    }
    printer << ", (" << getValues()[i] << ", \""
            << llhd::stringifyRegMode(*mode) << "\" " << getTriggers()[i]
            << " after " << getDelays()[i];
    if (hasGate(i))
      printer << " if " << getGateAt(i);
    printer << " : " << getValues()[i].getType() << ")";
  }
  printer.printOptionalAttrDict((*this)->getAttrs(),
                                {"modes", "gateMask", "operandSegmentSizes"});
  printer << " : " << getSignal().getType();
}
469 
470 LogicalResult llhd::RegOp::verify() {
471  // At least one trigger has to be present
472  if (getTriggers().size() < 1)
473  return emitError("At least one trigger quadruple has to be present.");
474 
475  // Values variadic operand must have the same size as the triggers variadic
476  if (getValues().size() != getTriggers().size())
477  return emitOpError("Number of 'values' is not equal to the number of "
478  "'triggers', got ")
479  << getValues().size() << " modes, but " << getTriggers().size()
480  << " triggers!";
481 
482  // Delay variadic operand must have the same size as the triggers variadic
483  if (getDelays().size() != getTriggers().size())
484  return emitOpError("Number of 'delays' is not equal to the number of "
485  "'triggers', got ")
486  << getDelays().size() << " modes, but " << getTriggers().size()
487  << " triggers!";
488 
489  // Array Attribute of RegModes must have the same number of elements as the
490  // variadics
491  if (getModes().size() != getTriggers().size())
492  return emitOpError("Number of 'modes' is not equal to the number of "
493  "'triggers', got ")
494  << getModes().size() << " modes, but " << getTriggers().size()
495  << " triggers!";
496 
497  // Array Attribute 'gateMask' must have the same number of elements as the
498  // triggers and values variadics
499  if (getGateMask().size() != getTriggers().size())
500  return emitOpError("Size of 'gateMask' is not equal to the size of "
501  "'triggers', got ")
502  << getGateMask().size() << " modes, but " << getTriggers().size()
503  << " triggers!";
504 
505  // Number of non-zero elements in 'gateMask' has to be the same as the size
506  // of the gates variadic, also each number from 1 to size-1 has to occur
507  // only once and in increasing order
508  unsigned counter = 0;
509  unsigned prevElement = 0;
510  for (Attribute maskElem : getGateMask().getValue()) {
511  int64_t val = cast<IntegerAttr>(maskElem).getInt();
512  if (val < 0)
513  return emitError("Element in 'gateMask' must not be negative!");
514  if (val == 0)
515  continue;
516  if (val != ++prevElement)
517  return emitError(
518  "'gateMask' has to contain every number from 1 to the "
519  "number of gates minus one exactly once in increasing order "
520  "(may have zeros in-between).");
521  counter++;
522  }
523  if (getGates().size() != counter)
524  return emitError("The number of non-zero elements in 'gateMask' and the "
525  "size of the 'gates' variadic have to match.");
526 
527  // Each value must be either the same type as the 'signal' or the underlying
528  // type of the 'signal'
529  for (auto val : getValues()) {
530  if (val.getType() != getSignal().getType() &&
531  val.getType() !=
532  cast<hw::InOutType>(getSignal().getType()).getElementType()) {
533  return emitOpError(
534  "type of each 'value' has to be either the same as the "
535  "type of 'signal' or the underlying type of 'signal'");
536  }
537  }
538  return success();
539 }
540 
541 #include "circt/Dialect/LLHD/IR/LLHDEnums.cpp.inc"
542 
543 #define GET_OP_CLASSES
544 #include "circt/Dialect/LLHD/IR/LLHD.cpp.inc"
assert(baseType &&"element must be base type")
static Attribute constFoldTernaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
Definition: LLHDOps.cpp:69
static LogicalResult inferReturnTypesOfStructExtractOp(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results)
Definition: LLHDOps.cpp:268
static OpFoldResult foldSigPtrArraySliceOp(Op op, ArrayRef< Attribute > operands)
Definition: LLHDOps.cpp:206
static LogicalResult canonicalizeSigPtrArraySliceOp(Op op, PatternRewriter &rewriter)
Definition: LLHDOps.cpp:228
static Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, const CalculationT &calculate)
Definition: LLHDOps.cpp:38
static OpFoldResult foldSigPtrExtractOp(Op op, ArrayRef< Attribute > operands)
Definition: LLHDOps.cpp:180
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2443
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
unsigned getLLHDTypeWidth(Type type)
Definition: LLHDOps.cpp:132
Type getLLHDElementType(Type type)
Definition: LLHDOps.cpp:144
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21