CIRCT  20.0.0git
SMTOps.cpp
Go to the documentation of this file.
1 //===- SMTOps.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/APSInt.h"
13 
14 using namespace circt;
15 using namespace smt;
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // BVConstantOp
20 //===----------------------------------------------------------------------===//
21 
22 LogicalResult BVConstantOp::inferReturnTypes(
23  mlir::MLIRContext *context, std::optional<mlir::Location> location,
24  ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
25  ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
26  ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
27  inferredReturnTypes.push_back(
28  properties.as<Properties *>()->getValue().getType());
29  return success();
30 }
31 
33  function_ref<void(Value, StringRef)> setNameFn) {
34  SmallVector<char, 128> specialNameBuffer;
35  llvm::raw_svector_ostream specialName(specialNameBuffer);
36  specialName << "c" << getValue().getValue() << "_bv"
37  << getValue().getValue().getBitWidth();
38  setNameFn(getResult(), specialName.str());
39 }
40 
41 OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {
42  assert(adaptor.getOperands().empty() && "constant has no operands");
43  return getValueAttr();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // DeclareFunOp
48 //===----------------------------------------------------------------------===//
49 
51  function_ref<void(Value, StringRef)> setNameFn) {
52  setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "");
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // SolverOp
57 //===----------------------------------------------------------------------===//
58 
59 LogicalResult SolverOp::verifyRegions() {
60  if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes())
61  return emitOpError() << "types of yielded values must match return values";
62  if (getBody()->getArgumentTypes() != getInputs().getTypes())
63  return emitOpError()
64  << "block argument types must match the types of the 'inputs'";
65 
66  return success();
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // CheckOp
71 //===----------------------------------------------------------------------===//
72 
73 LogicalResult CheckOp::verifyRegions() {
74  if (getSatRegion().front().getTerminator()->getOperands().getTypes() !=
75  getResultTypes())
76  return emitOpError() << "types of yielded values in 'sat' region must "
77  "match return values";
78  if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() !=
79  getResultTypes())
80  return emitOpError() << "types of yielded values in 'unknown' region must "
81  "match return values";
82  if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() !=
83  getResultTypes())
84  return emitOpError() << "types of yielded values in 'unsat' region must "
85  "match return values";
86 
87  return success();
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // EqOp
92 //===----------------------------------------------------------------------===//
93 
94 static LogicalResult
96  OperationState &result) {
97  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
98  SMLoc loc = parser.getCurrentLocation();
99  Type type;
100 
101  if (parser.parseOperandList(inputs) ||
102  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
103  parser.parseType(type))
104  return failure();
105 
106  result.addTypes(BoolType::get(parser.getContext()));
107  if (parser.resolveOperands(inputs, SmallVector<Type>(inputs.size(), type),
108  loc, result.operands))
109  return failure();
110 
111  return success();
112 }
113 
114 ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) {
115  return parseSameOperandTypeVariadicToBoolOp(parser, result);
116 }
117 
118 void EqOp::print(OpAsmPrinter &printer) {
119  printer << ' ' << getInputs();
120  printer.printOptionalAttrDict(getOperation()->getAttrs());
121  printer << " : " << getInputs().front().getType();
122 }
123 
124 LogicalResult EqOp::verify() {
125  if (getInputs().size() < 2)
126  return emitOpError() << "'inputs' must have at least size 2, but got "
127  << getInputs().size();
128 
129  return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // DistinctOp
134 //===----------------------------------------------------------------------===//
135 
136 ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) {
137  return parseSameOperandTypeVariadicToBoolOp(parser, result);
138 }
139 
140 void DistinctOp::print(OpAsmPrinter &printer) {
141  printer << ' ' << getInputs();
142  printer.printOptionalAttrDict(getOperation()->getAttrs());
143  printer << " : " << getInputs().front().getType();
144 }
145 
146 LogicalResult DistinctOp::verify() {
147  if (getInputs().size() < 2)
148  return emitOpError() << "'inputs' must have at least size 2, but got "
149  << getInputs().size();
150 
151  return success();
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // ExtractOp
156 //===----------------------------------------------------------------------===//
157 
158 LogicalResult ExtractOp::verify() {
159  unsigned rangeWidth = getType().getWidth();
160  unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
161  if (getLowBit() + rangeWidth > inputWidth)
162  return emitOpError("range to be extracted is too big, expected range "
163  "starting at index ")
164  << getLowBit() << " of length " << rangeWidth
165  << " requires input width of at least " << (getLowBit() + rangeWidth)
166  << ", but the input width is only " << inputWidth;
167  return success();
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // ConcatOp
172 //===----------------------------------------------------------------------===//
173 
174 LogicalResult ConcatOp::inferReturnTypes(
175  MLIRContext *context, std::optional<Location> location, ValueRange operands,
176  DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
177  SmallVectorImpl<Type> &inferredReturnTypes) {
178  inferredReturnTypes.push_back(BitVectorType::get(
179  context, cast<BitVectorType>(operands[0].getType()).getWidth() +
180  cast<BitVectorType>(operands[1].getType()).getWidth()));
181  return success();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // RepeatOp
186 //===----------------------------------------------------------------------===//
187 
188 LogicalResult RepeatOp::verify() {
189  unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
190  unsigned resultWidth = getType().getWidth();
191  if (resultWidth % inputWidth != 0)
192  return emitOpError() << "result bit-vector width must be a multiple of the "
193  "input bit-vector width";
194 
195  return success();
196 }
197 
198 unsigned RepeatOp::getCount() {
199  unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
200  unsigned resultWidth = getType().getWidth();
201  return resultWidth / inputWidth;
202 }
203 
204 void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count,
205  Value input) {
206  unsigned inputWidth = cast<BitVectorType>(input.getType()).getWidth();
207  Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count);
208  build(builder, state, resultTy, input);
209 }
210 
211 ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) {
212  OpAsmParser::UnresolvedOperand input;
213  Type inputType;
214  llvm::SMLoc countLoc = parser.getCurrentLocation();
215 
216  APInt count;
217  if (parser.parseInteger(count) || parser.parseKeyword("times"))
218  return failure();
219 
220  if (count.isNonPositive())
221  return parser.emitError(countLoc) << "integer must be positive";
222 
223  llvm::SMLoc inputLoc = parser.getCurrentLocation();
224  if (parser.parseOperand(input) ||
225  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
226  parser.parseType(inputType))
227  return failure();
228 
229  if (parser.resolveOperand(input, inputType, result.operands))
230  return failure();
231 
232  auto bvInputTy = dyn_cast<BitVectorType>(inputType);
233  if (!bvInputTy)
234  return parser.emitError(inputLoc) << "input must have bit-vector type";
235 
236  // Make sure no assertions can trigger and no silent overflows can happen
237  // Bit-width is stored as 'uint64_t' parameter in 'BitVectorType'
238  const unsigned maxBw = 64;
239  if (count.getActiveBits() > maxBw)
240  return parser.emitError(countLoc)
241  << "integer must fit into " << maxBw << " bits";
242 
243  // Store multiplication in an APInt twice the size to not have any overflow
244  // and check if it can be truncated to 'maxBw' bits without cutting of
245  // important bits.
246  APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);
247  if (resultBw.getActiveBits() > maxBw)
248  return parser.emitError(countLoc)
249  << "result bit-width (provided integer times bit-width of the input "
250  "type) must fit into "
251  << maxBw << " bits";
252 
253  Type resultTy =
254  BitVectorType::get(parser.getContext(), resultBw.getZExtValue());
255  result.addTypes(resultTy);
256  return success();
257 }
258 
259 void RepeatOp::print(OpAsmPrinter &printer) {
260  printer << " " << getCount() << " times " << getInput();
261  printer.printOptionalAttrDict((*this)->getAttrs());
262  printer << " : " << getInput().getType();
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // BoolConstantOp
267 //===----------------------------------------------------------------------===//
268 
270  function_ref<void(Value, StringRef)> setNameFn) {
271  setNameFn(getResult(), getValue() ? "true" : "false");
272 }
273 
274 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
275  assert(adaptor.getOperands().empty() && "constant has no operands");
276  return getValueAttr();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // IntConstantOp
281 //===----------------------------------------------------------------------===//
282 
284  function_ref<void(Value, StringRef)> setNameFn) {
285  SmallVector<char, 32> specialNameBuffer;
286  llvm::raw_svector_ostream specialName(specialNameBuffer);
287  specialName << "c" << getValue();
288  setNameFn(getResult(), specialName.str());
289 }
290 
291 OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) {
292  assert(adaptor.getOperands().empty() && "constant has no operands");
293  return getValueAttr();
294 }
295 
296 void IntConstantOp::print(OpAsmPrinter &p) {
297  p << " " << getValue();
298  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
299 }
300 
301 ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) {
302  APInt value;
303  if (parser.parseInteger(value))
304  return failure();
305 
306  result.getOrAddProperties<Properties>().setValue(
307  IntegerAttr::get(parser.getContext(), APSInt(value)));
308 
309  if (parser.parseOptionalAttrDict(result.attributes))
310  return failure();
311 
312  result.addTypes(smt::IntType::get(parser.getContext()));
313  return success();
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // ForallOp
318 //===----------------------------------------------------------------------===//
319 
320 template <typename QuantifierOp>
321 static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
322  if (op.getBoundVarNames() &&
323  op.getBody().getNumArguments() != op.getBoundVarNames()->size())
324  return op.emitOpError(
325  "number of bound variable names must match number of block arguments");
326  if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType))
327  return op.emitOpError()
328  << "bound variables must by any non-function SMT value";
329 
330  if (op.getBody().front().getTerminator()->getNumOperands() != 1)
331  return op.emitOpError("must have exactly one yielded value");
332  if (!isa<BoolType>(
333  op.getBody().front().getTerminator()->getOperand(0).getType()))
334  return op.emitOpError("yielded value must be of '!smt.bool' type");
335 
336  for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
337  unsigned i = regionWithIndex.index();
338  Region &region = regionWithIndex.value();
339 
340  if (op.getBody().getArgumentTypes() != region.getArgumentTypes())
341  return op.emitOpError()
342  << "block argument number and types of the 'body' "
343  "and 'patterns' region #"
344  << i << " must match";
345  if (region.front().getTerminator()->getNumOperands() < 1)
346  return op.emitOpError() << "'patterns' region #" << i
347  << " must have at least one yielded value";
348 
349  // All operations in the 'patterns' region must be SMT operations.
350  auto result = region.walk([&](Operation *childOp) {
351  if (!isa<SMTDialect>(childOp->getDialect())) {
352  auto diag = op.emitOpError()
353  << "the 'patterns' region #" << i
354  << " may only contain SMT dialect operations";
355  diag.attachNote(childOp->getLoc()) << "first non-SMT operation here";
356  return WalkResult::interrupt();
357  }
358 
359  // There may be no quantifier (or other variable binding) operations in
360  // the 'patterns' region.
361  if (isa<ForallOp, ExistsOp>(childOp)) {
362  auto diag = op.emitOpError() << "the 'patterns' region #" << i
363  << " must not contain "
364  "any variable binding operations";
365  diag.attachNote(childOp->getLoc()) << "first violating operation here";
366  return WalkResult::interrupt();
367  }
368 
369  return WalkResult::advance();
370  });
371  if (result.wasInterrupted())
372  return failure();
373  }
374 
375  return success();
376 }
377 
378 template <typename Properties>
379 static void buildQuantifier(
380  OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
381  function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
382  std::optional<ArrayRef<StringRef>> boundVarNames,
383  function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
384  uint32_t weight, bool noPattern) {
385  odsState.addTypes(BoolType::get(odsBuilder.getContext()));
386  if (weight != 0)
387  odsState.getOrAddProperties<Properties>().weight =
388  odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
389  if (noPattern)
390  odsState.getOrAddProperties<Properties>().noPattern =
391  odsBuilder.getUnitAttr();
392  if (boundVarNames.has_value()) {
393  SmallVector<Attribute> boundVarNamesList;
394  for (StringRef str : *boundVarNames)
395  boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));
396  odsState.getOrAddProperties<Properties>().boundVarNames =
397  odsBuilder.getArrayAttr(boundVarNamesList);
398  }
399  {
400  OpBuilder::InsertionGuard guard(odsBuilder);
401  Region *region = odsState.addRegion();
402  Block *block = odsBuilder.createBlock(region);
403  block->addArguments(
404  boundVarTypes,
405  SmallVector<Location>(boundVarTypes.size(), odsState.location));
406  Value returnVal =
407  bodyBuilder(odsBuilder, odsState.location, block->getArguments());
408  odsBuilder.create<smt::YieldOp>(odsState.location, returnVal);
409  }
410  if (patternBuilder) {
411  Region *region = odsState.addRegion();
412  OpBuilder::InsertionGuard guard(odsBuilder);
413  Block *block = odsBuilder.createBlock(region);
414  block->addArguments(
415  boundVarTypes,
416  SmallVector<Location>(boundVarTypes.size(), odsState.location));
417  ValueRange returnVals =
418  patternBuilder(odsBuilder, odsState.location, block->getArguments());
419  odsBuilder.create<smt::YieldOp>(odsState.location, returnVals);
420  }
421 }
422 
423 LogicalResult ForallOp::verify() {
424  if (!getPatterns().empty() && getNoPattern())
425  return emitOpError() << "patterns and the no_pattern attribute must not be "
426  "specified at the same time";
427 
428  return success();
429 }
430 
431 LogicalResult ForallOp::verifyRegions() {
432  return verifyQuantifierRegions(*this);
433 }
434 
435 void ForallOp::build(
436  OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
437  function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
438  std::optional<ArrayRef<StringRef>> boundVarNames,
439  function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
440  uint32_t weight, bool noPattern) {
441  buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442  boundVarNames, patternBuilder, weight, noPattern);
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // ExistsOp
447 //===----------------------------------------------------------------------===//
448 
449 LogicalResult ExistsOp::verify() {
450  if (!getPatterns().empty() && getNoPattern())
451  return emitOpError() << "patterns and the no_pattern attribute must not be "
452  "specified at the same time";
453 
454  return success();
455 }
456 
457 LogicalResult ExistsOp::verifyRegions() {
458  return verifyQuantifierRegions(*this);
459 }
460 
461 void ExistsOp::build(
462  OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
463  function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
464  std::optional<ArrayRef<StringRef>> boundVarNames,
465  function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
466  uint32_t weight, bool noPattern) {
467  buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468  boundVarNames, patternBuilder, weight, noPattern);
469 }
470 
471 #define GET_OP_CLASSES
472 #include "circt/Dialect/SMT/SMT.cpp.inc"
assert(baseType &&"element must be base type")
static InstancePath empty
static void buildQuantifier(OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, function_ref< Value(OpBuilder &, Location, ValueRange)> bodyBuilder, std::optional< ArrayRef< StringRef >> boundVarNames, function_ref< ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight, bool noPattern)
Definition: SMTOps.cpp:379
static LogicalResult verifyQuantifierRegions(QuantifierOp op)
Definition: SMTOps.cpp:321
static LogicalResult parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result)
Definition: SMTOps.cpp:95
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2467
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
bool isAnyNonFuncSMTValueType(mlir::Type type)
Returns whether the given type is an SMT value type (excluding functions).
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21