CIRCT 20.0.0git
Loading...
Searching...
No Matches
SMTOps.cpp
Go to the documentation of this file.
1//===- SMTOps.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/OpImplementation.h"
12#include "llvm/ADT/APSInt.h"
13
14using namespace circt;
15using namespace smt;
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// BVConstantOp
20//===----------------------------------------------------------------------===//
21
22LogicalResult BVConstantOp::inferReturnTypes(
23 mlir::MLIRContext *context, std::optional<mlir::Location> location,
24 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
25 ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
26 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
27 inferredReturnTypes.push_back(
28 properties.as<Properties *>()->getValue().getType());
29 return success();
30}
31
32void BVConstantOp::getAsmResultNames(
33 function_ref<void(Value, StringRef)> setNameFn) {
34 SmallVector<char, 128> specialNameBuffer;
35 llvm::raw_svector_ostream specialName(specialNameBuffer);
36 specialName << "c" << getValue().getValue() << "_bv"
37 << getValue().getValue().getBitWidth();
38 setNameFn(getResult(), specialName.str());
39}
40
41OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {
42 assert(adaptor.getOperands().empty() && "constant has no operands");
43 return getValueAttr();
44}
45
46//===----------------------------------------------------------------------===//
47// DeclareFunOp
48//===----------------------------------------------------------------------===//
49
50void DeclareFunOp::getAsmResultNames(
51 function_ref<void(Value, StringRef)> setNameFn) {
52 setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "");
53}
54
55//===----------------------------------------------------------------------===//
56// SolverOp
57//===----------------------------------------------------------------------===//
58
59LogicalResult SolverOp::verifyRegions() {
60 if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes())
61 return emitOpError() << "types of yielded values must match return values";
62 if (getBody()->getArgumentTypes() != getInputs().getTypes())
63 return emitOpError()
64 << "block argument types must match the types of the 'inputs'";
65
66 return success();
67}
68
69//===----------------------------------------------------------------------===//
70// CheckOp
71//===----------------------------------------------------------------------===//
72
73LogicalResult CheckOp::verifyRegions() {
74 if (getSatRegion().front().getTerminator()->getOperands().getTypes() !=
75 getResultTypes())
76 return emitOpError() << "types of yielded values in 'sat' region must "
77 "match return values";
78 if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() !=
79 getResultTypes())
80 return emitOpError() << "types of yielded values in 'unknown' region must "
81 "match return values";
82 if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() !=
83 getResultTypes())
84 return emitOpError() << "types of yielded values in 'unsat' region must "
85 "match return values";
86
87 return success();
88}
89
90//===----------------------------------------------------------------------===//
91// EqOp
92//===----------------------------------------------------------------------===//
93
94static LogicalResult
96 OperationState &result) {
97 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
98 SMLoc loc = parser.getCurrentLocation();
99 Type type;
100
101 if (parser.parseOperandList(inputs) ||
102 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
103 parser.parseType(type))
104 return failure();
105
106 result.addTypes(BoolType::get(parser.getContext()));
107 if (parser.resolveOperands(inputs, SmallVector<Type>(inputs.size(), type),
108 loc, result.operands))
109 return failure();
110
111 return success();
112}
113
114ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) {
115 return parseSameOperandTypeVariadicToBoolOp(parser, result);
116}
117
118void EqOp::print(OpAsmPrinter &printer) {
119 printer << ' ' << getInputs();
120 printer.printOptionalAttrDict(getOperation()->getAttrs());
121 printer << " : " << getInputs().front().getType();
122}
123
124LogicalResult EqOp::verify() {
125 if (getInputs().size() < 2)
126 return emitOpError() << "'inputs' must have at least size 2, but got "
127 << getInputs().size();
128
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// DistinctOp
134//===----------------------------------------------------------------------===//
135
136ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) {
137 return parseSameOperandTypeVariadicToBoolOp(parser, result);
138}
139
140void DistinctOp::print(OpAsmPrinter &printer) {
141 printer << ' ' << getInputs();
142 printer.printOptionalAttrDict(getOperation()->getAttrs());
143 printer << " : " << getInputs().front().getType();
144}
145
146LogicalResult DistinctOp::verify() {
147 if (getInputs().size() < 2)
148 return emitOpError() << "'inputs' must have at least size 2, but got "
149 << getInputs().size();
150
151 return success();
152}
153
154//===----------------------------------------------------------------------===//
155// ExtractOp
156//===----------------------------------------------------------------------===//
157
158LogicalResult ExtractOp::verify() {
159 unsigned rangeWidth = getType().getWidth();
160 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
161 if (getLowBit() + rangeWidth > inputWidth)
162 return emitOpError("range to be extracted is too big, expected range "
163 "starting at index ")
164 << getLowBit() << " of length " << rangeWidth
165 << " requires input width of at least " << (getLowBit() + rangeWidth)
166 << ", but the input width is only " << inputWidth;
167 return success();
168}
169
170//===----------------------------------------------------------------------===//
171// ConcatOp
172//===----------------------------------------------------------------------===//
173
174LogicalResult ConcatOp::inferReturnTypes(
175 MLIRContext *context, std::optional<Location> location, ValueRange operands,
176 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
177 SmallVectorImpl<Type> &inferredReturnTypes) {
178 inferredReturnTypes.push_back(BitVectorType::get(
179 context, cast<BitVectorType>(operands[0].getType()).getWidth() +
180 cast<BitVectorType>(operands[1].getType()).getWidth()));
181 return success();
182}
183
184//===----------------------------------------------------------------------===//
185// RepeatOp
186//===----------------------------------------------------------------------===//
187
188LogicalResult RepeatOp::verify() {
189 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
190 unsigned resultWidth = getType().getWidth();
191 if (resultWidth % inputWidth != 0)
192 return emitOpError() << "result bit-vector width must be a multiple of the "
193 "input bit-vector width";
194
195 return success();
196}
197
198unsigned RepeatOp::getCount() {
199 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
200 unsigned resultWidth = getType().getWidth();
201 return resultWidth / inputWidth;
202}
203
204void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count,
205 Value input) {
206 unsigned inputWidth = cast<BitVectorType>(input.getType()).getWidth();
207 Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count);
208 build(builder, state, resultTy, input);
209}
210
211ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) {
212 OpAsmParser::UnresolvedOperand input;
213 Type inputType;
214 llvm::SMLoc countLoc = parser.getCurrentLocation();
215
216 APInt count;
217 if (parser.parseInteger(count) || parser.parseKeyword("times"))
218 return failure();
219
220 if (count.isNonPositive())
221 return parser.emitError(countLoc) << "integer must be positive";
222
223 llvm::SMLoc inputLoc = parser.getCurrentLocation();
224 if (parser.parseOperand(input) ||
225 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
226 parser.parseType(inputType))
227 return failure();
228
229 if (parser.resolveOperand(input, inputType, result.operands))
230 return failure();
231
232 auto bvInputTy = dyn_cast<BitVectorType>(inputType);
233 if (!bvInputTy)
234 return parser.emitError(inputLoc) << "input must have bit-vector type";
235
236 // Make sure no assertions can trigger and no silent overflows can happen
237 // Bit-width is stored as 'int64_t' parameter in 'BitVectorType'
238 const unsigned maxBw = 63;
239 if (count.getActiveBits() > maxBw)
240 return parser.emitError(countLoc)
241 << "integer must fit into " << maxBw << " bits";
242
243 // Store multiplication in an APInt twice the size to not have any overflow
244 // and check if it can be truncated to 'maxBw' bits without cutting of
245 // important bits.
246 APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);
247 if (resultBw.getActiveBits() > maxBw)
248 return parser.emitError(countLoc)
249 << "result bit-width (provided integer times bit-width of the input "
250 "type) must fit into "
251 << maxBw << " bits";
252
253 Type resultTy =
254 BitVectorType::get(parser.getContext(), resultBw.getZExtValue());
255 result.addTypes(resultTy);
256 return success();
257}
258
259void RepeatOp::print(OpAsmPrinter &printer) {
260 printer << " " << getCount() << " times " << getInput();
261 printer.printOptionalAttrDict((*this)->getAttrs());
262 printer << " : " << getInput().getType();
263}
264
265//===----------------------------------------------------------------------===//
266// BoolConstantOp
267//===----------------------------------------------------------------------===//
268
269void BoolConstantOp::getAsmResultNames(
270 function_ref<void(Value, StringRef)> setNameFn) {
271 setNameFn(getResult(), getValue() ? "true" : "false");
272}
273
274OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
275 assert(adaptor.getOperands().empty() && "constant has no operands");
276 return getValueAttr();
277}
278
279//===----------------------------------------------------------------------===//
280// IntConstantOp
281//===----------------------------------------------------------------------===//
282
283void IntConstantOp::getAsmResultNames(
284 function_ref<void(Value, StringRef)> setNameFn) {
285 SmallVector<char, 32> specialNameBuffer;
286 llvm::raw_svector_ostream specialName(specialNameBuffer);
287 specialName << "c" << getValue();
288 setNameFn(getResult(), specialName.str());
289}
290
291OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) {
292 assert(adaptor.getOperands().empty() && "constant has no operands");
293 return getValueAttr();
294}
295
296void IntConstantOp::print(OpAsmPrinter &p) {
297 p << " " << getValue();
298 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
299}
300
301ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) {
302 APInt value;
303 if (parser.parseInteger(value))
304 return failure();
305
306 result.getOrAddProperties<Properties>().setValue(
307 IntegerAttr::get(parser.getContext(), APSInt(value)));
308
309 if (parser.parseOptionalAttrDict(result.attributes))
310 return failure();
311
312 result.addTypes(smt::IntType::get(parser.getContext()));
313 return success();
314}
315
316//===----------------------------------------------------------------------===//
317// ForallOp
318//===----------------------------------------------------------------------===//
319
320template <typename QuantifierOp>
321static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
322 if (op.getBoundVarNames() &&
323 op.getBody().getNumArguments() != op.getBoundVarNames()->size())
324 return op.emitOpError(
325 "number of bound variable names must match number of block arguments");
326 if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType))
327 return op.emitOpError()
328 << "bound variables must by any non-function SMT value";
329
330 if (op.getBody().front().getTerminator()->getNumOperands() != 1)
331 return op.emitOpError("must have exactly one yielded value");
332 if (!isa<BoolType>(
333 op.getBody().front().getTerminator()->getOperand(0).getType()))
334 return op.emitOpError("yielded value must be of '!smt.bool' type");
335
336 for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
337 unsigned i = regionWithIndex.index();
338 Region &region = regionWithIndex.value();
339
340 if (op.getBody().getArgumentTypes() != region.getArgumentTypes())
341 return op.emitOpError()
342 << "block argument number and types of the 'body' "
343 "and 'patterns' region #"
344 << i << " must match";
345 if (region.front().getTerminator()->getNumOperands() < 1)
346 return op.emitOpError() << "'patterns' region #" << i
347 << " must have at least one yielded value";
348
349 // All operations in the 'patterns' region must be SMT operations.
350 auto result = region.walk([&](Operation *childOp) {
351 if (!isa<SMTDialect>(childOp->getDialect())) {
352 auto diag = op.emitOpError()
353 << "the 'patterns' region #" << i
354 << " may only contain SMT dialect operations";
355 diag.attachNote(childOp->getLoc()) << "first non-SMT operation here";
356 return WalkResult::interrupt();
357 }
358
359 // There may be no quantifier (or other variable binding) operations in
360 // the 'patterns' region.
361 if (isa<ForallOp, ExistsOp>(childOp)) {
362 auto diag = op.emitOpError() << "the 'patterns' region #" << i
363 << " must not contain "
364 "any variable binding operations";
365 diag.attachNote(childOp->getLoc()) << "first violating operation here";
366 return WalkResult::interrupt();
367 }
368
369 return WalkResult::advance();
370 });
371 if (result.wasInterrupted())
372 return failure();
373 }
374
375 return success();
376}
377
378template <typename Properties>
379static void buildQuantifier(
380 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
381 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
382 std::optional<ArrayRef<StringRef>> boundVarNames,
383 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
384 uint32_t weight, bool noPattern) {
385 odsState.addTypes(BoolType::get(odsBuilder.getContext()));
386 if (weight != 0)
387 odsState.getOrAddProperties<Properties>().weight =
388 odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
389 if (noPattern)
390 odsState.getOrAddProperties<Properties>().noPattern =
391 odsBuilder.getUnitAttr();
392 if (boundVarNames.has_value()) {
393 SmallVector<Attribute> boundVarNamesList;
394 for (StringRef str : *boundVarNames)
395 boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));
396 odsState.getOrAddProperties<Properties>().boundVarNames =
397 odsBuilder.getArrayAttr(boundVarNamesList);
398 }
399 {
400 OpBuilder::InsertionGuard guard(odsBuilder);
401 Region *region = odsState.addRegion();
402 Block *block = odsBuilder.createBlock(region);
403 block->addArguments(
404 boundVarTypes,
405 SmallVector<Location>(boundVarTypes.size(), odsState.location));
406 Value returnVal =
407 bodyBuilder(odsBuilder, odsState.location, block->getArguments());
408 odsBuilder.create<smt::YieldOp>(odsState.location, returnVal);
409 }
410 if (patternBuilder) {
411 Region *region = odsState.addRegion();
412 OpBuilder::InsertionGuard guard(odsBuilder);
413 Block *block = odsBuilder.createBlock(region);
414 block->addArguments(
415 boundVarTypes,
416 SmallVector<Location>(boundVarTypes.size(), odsState.location));
417 ValueRange returnVals =
418 patternBuilder(odsBuilder, odsState.location, block->getArguments());
419 odsBuilder.create<smt::YieldOp>(odsState.location, returnVals);
420 }
421}
422
423LogicalResult ForallOp::verify() {
424 if (!getPatterns().empty() && getNoPattern())
425 return emitOpError() << "patterns and the no_pattern attribute must not be "
426 "specified at the same time";
427
428 return success();
429}
430
431LogicalResult ForallOp::verifyRegions() {
432 return verifyQuantifierRegions(*this);
433}
434
435void ForallOp::build(
436 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
437 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
438 std::optional<ArrayRef<StringRef>> boundVarNames,
439 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
440 uint32_t weight, bool noPattern) {
441 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442 boundVarNames, patternBuilder, weight, noPattern);
443}
444
445//===----------------------------------------------------------------------===//
446// ExistsOp
447//===----------------------------------------------------------------------===//
448
449LogicalResult ExistsOp::verify() {
450 if (!getPatterns().empty() && getNoPattern())
451 return emitOpError() << "patterns and the no_pattern attribute must not be "
452 "specified at the same time";
453
454 return success();
455}
456
457LogicalResult ExistsOp::verifyRegions() {
458 return verifyQuantifierRegions(*this);
459}
460
461void ExistsOp::build(
462 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
463 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
464 std::optional<ArrayRef<StringRef>> boundVarNames,
465 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
466 uint32_t weight, bool noPattern) {
467 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468 boundVarNames, patternBuilder, weight, noPattern);
469}
470
471#define GET_OP_CLASSES
472#include "circt/Dialect/SMT/SMT.cpp.inc"
assert(baseType &&"element must be base type")
static InstancePath empty
static LogicalResult verifyQuantifierRegions(QuantifierOp op)
Definition SMTOps.cpp:321
static LogicalResult parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result)
Definition SMTOps.cpp:95
static void buildQuantifier(OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, function_ref< Value(OpBuilder &, Location, ValueRange)> bodyBuilder, std::optional< ArrayRef< StringRef > > boundVarNames, function_ref< ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight, bool noPattern)
Definition SMTOps.cpp:379
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition smt.py:1