CIRCT  19.0.0git
SMTAttributes.cpp
Go to the documentation of this file.
1 //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/Support/Format.h"
16 
17 using namespace circt;
18 using namespace circt::smt;
19 
20 //===----------------------------------------------------------------------===//
21 // BitVectorAttr
22 //===----------------------------------------------------------------------===//
23 
24 namespace circt {
25 namespace smt {
26 namespace detail {
27 struct BitVectorAttrStorage : public mlir::AttributeStorage {
28  using KeyTy = APInt;
29  BitVectorAttrStorage(APInt value) : value(std::move(value)) {}
30 
31  KeyTy getAsKey() const { return value; }
32 
33  // NOTE: the implementation of this operator is the reason we need to define
34  // the storage manually. The auto-generated version would just do the direct
35  // equality check of the APInt, but that asserts the bitwidth of both to be
36  // the same, leading to a crash. This implementation, therefore, checks for
37  // matching bit-width beforehand.
38  bool operator==(const KeyTy &key) const {
39  return (value.getBitWidth() == key.getBitWidth() && value == key);
40  }
41 
42  static llvm::hash_code hashKey(const KeyTy &key) {
43  return llvm::hash_value(key);
44  }
45 
46  static BitVectorAttrStorage *
47  construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
48  return new (allocator.allocate<BitVectorAttrStorage>())
49  BitVectorAttrStorage(std::move(key));
50  }
51 
52  APInt value;
53 };
54 } // namespace detail
55 } // namespace smt
56 } // namespace circt
57 
58 APInt BitVectorAttr::getValue() const { return getImpl()->value; }
59 
60 LogicalResult BitVectorAttr::verify(
61  function_ref<InFlightDiagnostic()> emitError,
62  APInt value) { // NOLINT(performance-unnecessary-value-param)
63  if (value.getBitWidth() < 1)
64  return emitError() << "bit-width must be at least 1, but got "
65  << value.getBitWidth();
66  return success();
67 }
68 
69 std::string BitVectorAttr::getValueAsString(bool prefix) const {
70  unsigned width = getValue().getBitWidth();
71  SmallVector<char> toPrint;
72  StringRef pref = prefix ? "#" : "";
73  if (width % 4 == 0) {
74  getValue().toString(toPrint, 16, false, false, false);
75  // APInt's 'toString' omits leading zeros. However, those are critical here
76  // because they determine the bit-width of the bit-vector.
77  SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
78  return (pref + "x" + Twine(leadingZeros) + toPrint).str();
79  }
80 
81  getValue().toString(toPrint, 2, false, false, false);
82  // APInt's 'toString' omits leading zeros
83  SmallVector<char> leadingZeros(width - toPrint.size(), '0');
84  return (pref + "b" + Twine(leadingZeros) + toPrint).str();
85 }
86 
87 /// Parse an SMT-LIB formatted bit-vector string.
88 static FailureOr<APInt>
89 parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
90  StringRef value) {
91  if (value[0] != '#')
92  return emitError() << "expected '#'";
93 
94  if (value.size() < 3)
95  return emitError() << "expected at least one digit";
96 
97  if (value[1] == 'b')
98  return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
99  2);
100 
101  if (value[1] == 'x')
102  return APInt((value.size() - 2) * 4,
103  std::string(value.begin() + 2, value.end()), 16);
104 
105  return emitError() << "expected either 'b' or 'x'";
106 }
107 
108 BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
109  auto maybeValue = parseBitVectorString(nullptr, value);
110 
111  assert(succeeded(maybeValue) && "string must have SMT-LIB format");
112  return Base::get(context, *maybeValue);
113 }
114 
115 BitVectorAttr
116 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
117  MLIRContext *context, StringRef value) {
118  auto maybeValue = parseBitVectorString(emitError, value);
119  if (failed(maybeValue))
120  return {};
121 
122  return Base::getChecked(emitError, context, *maybeValue);
123 }
124 
125 BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
126  unsigned width) {
127  return Base::get(context, APInt(width, value));
128 }
129 
130 BitVectorAttr
131 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
132  MLIRContext *context, uint64_t value,
133  unsigned width) {
134  if (width < 64 && value >= (UINT64_C(1) << width)) {
135  emitError() << "value does not fit in a bit-vector of desired width";
136  return {};
137  }
138  return Base::getChecked(emitError, context, APInt(width, value));
139 }
140 
141 Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
142  llvm::SMLoc loc = odsParser.getCurrentLocation();
143 
144  APInt val;
145  if (odsParser.parseLess() || odsParser.parseInteger(val) ||
146  odsParser.parseGreater())
147  return {};
148 
149  // Requires the use of `quantified(<attr>)` in operation assembly formats.
150  if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
151  odsParser.emitError(loc) << "explicit bit-vector type required";
152  return {};
153  }
154 
155  unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
156 
157  if (width > val.getBitWidth()) {
158  // sext is always safe here, even for unsigned values, because the
159  // parseOptionalInteger method will return something with a zero in the
160  // top bits if it is a positive number.
161  val = val.sext(width);
162  } else if (width < val.getBitWidth()) {
163  // The parser can return an unnecessarily wide result.
164  // This isn't a problem, but truncating off bits is bad.
165  unsigned neededBits =
166  val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
167  if (width < neededBits) {
168  odsParser.emitError(loc)
169  << "integer value out of range for given bit-vector type " << odsType;
170  return {};
171  }
172  val = val.trunc(width);
173  }
174 
175  return BitVectorAttr::get(odsParser.getContext(), val);
176 }
177 
178 void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
179  // This printer only works for the extended format where the MLIR
180  // infrastructure prints the type for us. This means, the attribute should
181  // never be used without `quantified` in an assembly format.
182  odsPrinter << "<" << getValue() << ">";
183 }
184 
185 Type BitVectorAttr::getType() const {
186  return BitVectorType::get(getContext(), getValue().getBitWidth());
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // ODS Boilerplate
191 //===----------------------------------------------------------------------===//
192 
193 #define GET_ATTRDEF_CLASSES
194 #include "circt/Dialect/SMT/SMTAttributes.cpp.inc"
195 
196 void SMTDialect::registerAttributes() {
197  addAttributes<
198 #define GET_ATTRDEF_LIST
199 #include "circt/Dialect/SMT/SMTAttributes.cpp.inc"
200  >();
201 }
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:36
static FailureOr< APInt > parseBitVectorString(function_ref< InFlightDiagnostic()> emitError, StringRef value)
Parse an SMT-LIB formatted bit-vector string.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
llvm::hash_code hash_value(const BundledChannel channel)
Definition: ESITypes.h:48
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
static llvm::hash_code hashKey(const KeyTy &key)
bool operator==(const KeyTy &key) const
static BitVectorAttrStorage * construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key)