10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/OpImplementation.h"
12 #include "llvm/ADT/APSInt.h"
14 using namespace circt;
22 LogicalResult BVConstantOp::inferReturnTypes(
23 mlir::MLIRContext *context, std::optional<mlir::Location> location,
24 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
25 ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
26 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
27 inferredReturnTypes.push_back(
28 properties.as<Properties *>()->getValue().getType());
33 function_ref<
void(Value, StringRef)> setNameFn) {
34 SmallVector<char, 128> specialNameBuffer;
35 llvm::raw_svector_ostream specialName(specialNameBuffer);
36 specialName <<
"c" << getValue().getValue() <<
"_bv"
37 << getValue().getValue().getBitWidth();
38 setNameFn(getResult(), specialName.str());
41 OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) {
42 assert(adaptor.getOperands().empty() &&
"constant has no operands");
43 return getValueAttr();
51 function_ref<
void(Value, StringRef)> setNameFn) {
52 setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() :
"");
59 LogicalResult SolverOp::verifyRegions() {
60 if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes())
61 return emitOpError() <<
"types of yielded values must match return values";
62 if (getBody()->getArgumentTypes() != getInputs().getTypes())
64 <<
"block argument types must match the types of the 'inputs'";
73 LogicalResult CheckOp::verifyRegions() {
74 if (getSatRegion().front().getTerminator()->getOperands().getTypes() !=
76 return emitOpError() <<
"types of yielded values in 'sat' region must "
77 "match return values";
78 if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() !=
80 return emitOpError() <<
"types of yielded values in 'unknown' region must "
81 "match return values";
82 if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() !=
84 return emitOpError() <<
"types of yielded values in 'unsat' region must "
85 "match return values";
96 OperationState &result) {
97 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
98 SMLoc loc = parser.getCurrentLocation();
101 if (parser.parseOperandList(inputs) ||
102 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
103 parser.parseType(type))
107 if (parser.resolveOperands(inputs, SmallVector<Type>(inputs.size(), type),
108 loc, result.operands))
114 ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) {
118 void EqOp::print(OpAsmPrinter &printer) {
119 printer <<
' ' << getInputs();
120 printer.printOptionalAttrDict(getOperation()->getAttrs());
121 printer <<
" : " << getInputs().front().getType();
125 if (getInputs().size() < 2)
126 return emitOpError() <<
"'inputs' must have at least size 2, but got "
127 << getInputs().size();
136 ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) {
140 void DistinctOp::print(OpAsmPrinter &printer) {
141 printer <<
' ' << getInputs();
142 printer.printOptionalAttrDict(getOperation()->getAttrs());
143 printer <<
" : " << getInputs().front().getType();
147 if (getInputs().size() < 2)
148 return emitOpError() <<
"'inputs' must have at least size 2, but got "
149 << getInputs().size();
159 unsigned rangeWidth = getType().getWidth();
160 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
161 if (getLowBit() + rangeWidth > inputWidth)
162 return emitOpError(
"range to be extracted is too big, expected range "
163 "starting at index ")
164 << getLowBit() <<
" of length " << rangeWidth
165 <<
" requires input width of at least " << (getLowBit() + rangeWidth)
166 <<
", but the input width is only " << inputWidth;
174 LogicalResult ConcatOp::inferReturnTypes(
175 MLIRContext *context, std::optional<Location> location, ValueRange operands,
176 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
177 SmallVectorImpl<Type> &inferredReturnTypes) {
179 context, cast<BitVectorType>(operands[0].getType()).
getWidth() +
180 cast<BitVectorType>(operands[1].getType()).
getWidth()));
189 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
190 unsigned resultWidth = getType().getWidth();
191 if (resultWidth % inputWidth != 0)
192 return emitOpError() <<
"result bit-vector width must be a multiple of the "
193 "input bit-vector width";
198 unsigned RepeatOp::getCount() {
199 unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth();
200 unsigned resultWidth = getType().getWidth();
201 return resultWidth / inputWidth;
204 void RepeatOp::build(OpBuilder &builder, OperationState &state,
unsigned count,
206 unsigned inputWidth = cast<BitVectorType>(input.getType()).getWidth();
208 build(builder, state, resultTy, input);
211 ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) {
212 OpAsmParser::UnresolvedOperand input;
214 llvm::SMLoc countLoc = parser.getCurrentLocation();
217 if (parser.parseInteger(count) || parser.parseKeyword(
"times"))
220 if (count.isNonPositive())
221 return parser.emitError(countLoc) <<
"integer must be positive";
223 llvm::SMLoc inputLoc = parser.getCurrentLocation();
224 if (parser.parseOperand(input) ||
225 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
226 parser.parseType(inputType))
229 if (parser.resolveOperand(input, inputType, result.operands))
232 auto bvInputTy = dyn_cast<BitVectorType>(inputType);
234 return parser.emitError(inputLoc) <<
"input must have bit-vector type";
238 const unsigned maxBw = 64;
239 if (count.getActiveBits() > maxBw)
240 return parser.emitError(countLoc)
241 <<
"integer must fit into " << maxBw <<
" bits";
246 APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);
247 if (resultBw.getActiveBits() > maxBw)
248 return parser.emitError(countLoc)
249 <<
"result bit-width (provided integer times bit-width of the input "
250 "type) must fit into "
255 result.addTypes(resultTy);
259 void RepeatOp::print(OpAsmPrinter &printer) {
260 printer <<
" " << getCount() <<
" times " << getInput();
261 printer.printOptionalAttrDict((*this)->getAttrs());
262 printer <<
" : " << getInput().getType();
270 function_ref<
void(Value, StringRef)> setNameFn) {
271 setNameFn(getResult(), getValue() ?
"true" :
"false");
274 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
275 assert(adaptor.getOperands().empty() &&
"constant has no operands");
276 return getValueAttr();
284 function_ref<
void(Value, StringRef)> setNameFn) {
285 SmallVector<char, 32> specialNameBuffer;
286 llvm::raw_svector_ostream specialName(specialNameBuffer);
287 specialName <<
"c" << getValue();
288 setNameFn(getResult(), specialName.str());
291 OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) {
292 assert(adaptor.getOperands().empty() &&
"constant has no operands");
293 return getValueAttr();
296 void IntConstantOp::print(OpAsmPrinter &p) {
297 p <<
" " << getValue();
298 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
301 ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) {
303 if (parser.parseInteger(value))
306 result.getOrAddProperties<Properties>().setValue(
309 if (parser.parseOptionalAttrDict(result.attributes))
320 template <
typename QuantifierOp>
322 if (op.getBoundVarNames() &&
323 op.getBody().getNumArguments() != op.getBoundVarNames()->size())
324 return op.emitOpError(
325 "number of bound variable names must match number of block arguments");
327 return op.emitOpError()
328 <<
"bound variables must by any non-function SMT value";
330 if (op.getBody().front().getTerminator()->getNumOperands() != 1)
331 return op.emitOpError(
"must have exactly one yielded value");
333 op.getBody().front().getTerminator()->getOperand(0).getType()))
334 return op.emitOpError(
"yielded value must be of '!smt.bool' type");
336 for (
auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
337 unsigned i = regionWithIndex.index();
338 Region ®ion = regionWithIndex.value();
340 if (op.getBody().getArgumentTypes() != region.getArgumentTypes())
341 return op.emitOpError()
342 <<
"block argument number and types of the 'body' "
343 "and 'patterns' region #"
344 << i <<
" must match";
345 if (region.front().getTerminator()->getNumOperands() < 1)
346 return op.emitOpError() <<
"'patterns' region #" << i
347 <<
" must have at least one yielded value";
350 auto result = region.walk([&](Operation *childOp) {
351 if (!isa<SMTDialect>(childOp->getDialect())) {
352 auto diag = op.emitOpError()
353 <<
"the 'patterns' region #" << i
354 <<
" may only contain SMT dialect operations";
355 diag.attachNote(childOp->getLoc()) <<
"first non-SMT operation here";
356 return WalkResult::interrupt();
361 if (isa<ForallOp, ExistsOp>(childOp)) {
362 auto diag = op.emitOpError() <<
"the 'patterns' region #" << i
363 <<
" must not contain "
364 "any variable binding operations";
365 diag.attachNote(childOp->getLoc()) <<
"first violating operation here";
366 return WalkResult::interrupt();
369 return WalkResult::advance();
371 if (result.wasInterrupted())
378 template <
typename Properties>
380 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
381 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
382 std::optional<ArrayRef<StringRef>> boundVarNames,
383 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
384 uint32_t weight,
bool noPattern) {
387 odsState.getOrAddProperties<Properties>().weight =
388 odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
390 odsState.getOrAddProperties<Properties>().noPattern =
391 odsBuilder.getUnitAttr();
392 if (boundVarNames.has_value()) {
393 SmallVector<Attribute> boundVarNamesList;
394 for (StringRef str : *boundVarNames)
395 boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));
396 odsState.getOrAddProperties<Properties>().boundVarNames =
397 odsBuilder.getArrayAttr(boundVarNamesList);
400 OpBuilder::InsertionGuard guard(odsBuilder);
401 Region *region = odsState.addRegion();
402 Block *block = odsBuilder.createBlock(region);
405 SmallVector<Location>(boundVarTypes.size(), odsState.location));
407 bodyBuilder(odsBuilder, odsState.location, block->getArguments());
408 odsBuilder.create<smt::YieldOp>(odsState.location, returnVal);
410 if (patternBuilder) {
411 Region *region = odsState.addRegion();
412 OpBuilder::InsertionGuard guard(odsBuilder);
413 Block *block = odsBuilder.createBlock(region);
416 SmallVector<Location>(boundVarTypes.size(), odsState.location));
417 ValueRange returnVals =
418 patternBuilder(odsBuilder, odsState.location, block->getArguments());
419 odsBuilder.create<smt::YieldOp>(odsState.location, returnVals);
424 if (!getPatterns().
empty() && getNoPattern())
425 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
426 "specified at the same time";
431 LogicalResult ForallOp::verifyRegions() {
435 void ForallOp::build(
436 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
437 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
438 std::optional<ArrayRef<StringRef>> boundVarNames,
439 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
440 uint32_t weight,
bool noPattern) {
441 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442 boundVarNames, patternBuilder, weight, noPattern);
450 if (!getPatterns().
empty() && getNoPattern())
451 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
452 "specified at the same time";
457 LogicalResult ExistsOp::verifyRegions() {
461 void ExistsOp::build(
462 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
463 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
464 std::optional<ArrayRef<StringRef>> boundVarNames,
465 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
466 uint32_t weight,
bool noPattern) {
467 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468 boundVarNames, patternBuilder, weight, noPattern);
471 #define GET_OP_CLASSES
472 #include "circt/Dialect/SMT/SMT.cpp.inc"
assert(baseType &&"element must be base type")
static InstancePath empty
static void buildQuantifier(OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, function_ref< Value(OpBuilder &, Location, ValueRange)> bodyBuilder, std::optional< ArrayRef< StringRef >> boundVarNames, function_ref< ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight, bool noPattern)
static LogicalResult verifyQuantifierRegions(QuantifierOp op)
static LogicalResult parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
bool isAnyNonFuncSMTValueType(mlir::Type type)
Returns whether the given type is an SMT value type (excluding functions).
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.