CIRCT 20.0.0git
Loading...
Searching...
No Matches
SMTAttributes.cpp
Go to the documentation of this file.
//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "llvm/ADT/TypeSwitch.h"
15#include "llvm/Support/Format.h"
16
17using namespace circt;
18using namespace circt::smt;
19
//===----------------------------------------------------------------------===//
// BitVectorAttr
//===----------------------------------------------------------------------===//
23
24namespace circt {
25namespace smt {
26namespace detail {
27struct BitVectorAttrStorage : public mlir::AttributeStorage {
28 using KeyTy = APInt;
29 BitVectorAttrStorage(APInt value) : value(std::move(value)) {}
30
31 KeyTy getAsKey() const { return value; }
32
33 // NOTE: the implementation of this operator is the reason we need to define
34 // the storage manually. The auto-generated version would just do the direct
35 // equality check of the APInt, but that asserts the bitwidth of both to be
36 // the same, leading to a crash. This implementation, therefore, checks for
37 // matching bit-width beforehand.
38 bool operator==(const KeyTy &key) const {
39 return (value.getBitWidth() == key.getBitWidth() && value == key);
40 }
41
42 static llvm::hash_code hashKey(const KeyTy &key) {
43 return llvm::hash_value(key);
44 }
45
47 construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
48 return new (allocator.allocate<BitVectorAttrStorage>())
49 BitVectorAttrStorage(std::move(key));
50 }
51
52 APInt value;
53};
54} // namespace detail
55} // namespace smt
56} // namespace circt
57
58APInt BitVectorAttr::getValue() const { return getImpl()->value; }
59
60LogicalResult BitVectorAttr::verify(
61 function_ref<InFlightDiagnostic()> emitError,
62 APInt value) { // NOLINT(performance-unnecessary-value-param)
63 if (value.getBitWidth() < 1)
64 return emitError() << "bit-width must be at least 1, but got "
65 << value.getBitWidth();
66 return success();
67}
68
69std::string BitVectorAttr::getValueAsString(bool prefix) const {
70 unsigned width = getValue().getBitWidth();
71 SmallVector<char> toPrint;
72 StringRef pref = prefix ? "#" : "";
73 if (width % 4 == 0) {
74 getValue().toString(toPrint, 16, false, false, false);
75 // APInt's 'toString' omits leading zeros. However, those are critical here
76 // because they determine the bit-width of the bit-vector.
77 SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
78 return (pref + "x" + Twine(leadingZeros) + toPrint).str();
79 }
80
81 getValue().toString(toPrint, 2, false, false, false);
82 // APInt's 'toString' omits leading zeros
83 SmallVector<char> leadingZeros(width - toPrint.size(), '0');
84 return (pref + "b" + Twine(leadingZeros) + toPrint).str();
85}
86
87/// Parse an SMT-LIB formatted bit-vector string.
88static FailureOr<APInt>
89parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
90 StringRef value) {
91 if (value[0] != '#')
92 return emitError() << "expected '#'";
93
94 if (value.size() < 3)
95 return emitError() << "expected at least one digit";
96
97 if (value[1] == 'b')
98 return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
99 2);
100
101 if (value[1] == 'x')
102 return APInt((value.size() - 2) * 4,
103 std::string(value.begin() + 2, value.end()), 16);
104
105 return emitError() << "expected either 'b' or 'x'";
106}
107
108BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
109 auto maybeValue = parseBitVectorString(nullptr, value);
110
111 assert(succeeded(maybeValue) && "string must have SMT-LIB format");
112 return Base::get(context, *maybeValue);
113}
114
115BitVectorAttr
116BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
117 MLIRContext *context, StringRef value) {
118 auto maybeValue = parseBitVectorString(emitError, value);
119 if (failed(maybeValue))
120 return {};
121
122 return Base::getChecked(emitError, context, *maybeValue);
123}
124
125BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
126 unsigned width) {
127 return Base::get(context, APInt(width, value));
128}
129
130BitVectorAttr
131BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
132 MLIRContext *context, uint64_t value,
133 unsigned width) {
134 if (width < 64 && value >= (UINT64_C(1) << width)) {
135 emitError() << "value does not fit in a bit-vector of desired width";
136 return {};
137 }
138 return Base::getChecked(emitError, context, APInt(width, value));
139}
140
141Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
142 llvm::SMLoc loc = odsParser.getCurrentLocation();
143
144 APInt val;
145 if (odsParser.parseLess() || odsParser.parseInteger(val) ||
146 odsParser.parseGreater())
147 return {};
148
149 // Requires the use of `quantified(<attr>)` in operation assembly formats.
150 if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
151 odsParser.emitError(loc) << "explicit bit-vector type required";
152 return {};
153 }
154
155 unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
156
157 if (width > val.getBitWidth()) {
158 // sext is always safe here, even for unsigned values, because the
159 // parseOptionalInteger method will return something with a zero in the
160 // top bits if it is a positive number.
161 val = val.sext(width);
162 } else if (width < val.getBitWidth()) {
163 // The parser can return an unnecessarily wide result.
164 // This isn't a problem, but truncating off bits is bad.
165 unsigned neededBits =
166 val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
167 if (width < neededBits) {
168 odsParser.emitError(loc)
169 << "integer value out of range for given bit-vector type " << odsType;
170 return {};
171 }
172 val = val.trunc(width);
173 }
174
175 return BitVectorAttr::get(odsParser.getContext(), val);
176}
177
178void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
179 // This printer only works for the extended format where the MLIR
180 // infrastructure prints the type for us. This means, the attribute should
181 // never be used without `quantified` in an assembly format.
182 odsPrinter << "<" << getValue() << ">";
183}
184
185Type BitVectorAttr::getType() const {
186 return BitVectorType::get(getContext(), getValue().getBitWidth());
187}
188
//===----------------------------------------------------------------------===//
// ODS Boilerplate
//===----------------------------------------------------------------------===//
192
193#define GET_ATTRDEF_CLASSES
194#include "circt/Dialect/SMT/SMTAttributes.cpp.inc"
195
196void SMTDialect::registerAttributes() {
197 addAttributes<
198#define GET_ATTRDEF_LIST
199#include "circt/Dialect/SMT/SMTAttributes.cpp.inc"
200 >();
201}
assert(baseType &&"element must be base type")
static FailureOr< APInt > parseBitVectorString(function_ref< InFlightDiagnostic()> emitError, StringRef value)
Parse an SMT-LIB formatted bit-vector string.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition smt.py:1
static llvm::hash_code hashKey(const KeyTy &key)
bool operator==(const KeyTy &key) const
static BitVectorAttrStorage * construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key)