16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/OpDefinition.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/IR/Value.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/LogicalResult.h"
30using namespace circt::synth::aig;
32using namespace matchers;
35#include "circt/Dialect/Synth/Synth.cpp.inc"
// Returns `value` with its known-zero and known-one masks exchanged, which are
// the known bits of the bitwise complement of `value`.
// NOTE(review): the guard on `inverted` and the return statement are on lines
// not visible in this excerpt; presumably the swap only happens when
// `inverted` is true -- confirm.
39inline llvm::KnownBits applyInversion(llvm::KnownBits value,
bool inverted) {
41 std::swap(value.Zero, value.One);
// Pattern matcher for a "complement": a BooleanLogicOpInterface op that has
// exactly one input, with that input's inversion flag set. When such an op is
// found, the nested matcher `lhs` is applied to its operand 0.
// NOTE(review): the declaration of the `lhs` member is not visible in this
// excerpt.
45template <
typename SubType>
46struct ComplementMatcher {
48 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
// True iff `op` is a single-input, inverted boolean-logic op and the nested
// matcher accepts that input.
49 bool match(Operation *op) {
50 auto boolOp = dyn_cast<BooleanLogicOpInterface>(op);
51 return boolOp && boolOp.getInputs().size() == 1 && boolOp.isInverted(0) &&
52 lhs.match(op->getOperand(0));
// Builds a ComplementMatcher around `subExpr`, following MLIR's `m_*` matcher
// naming. XorInverterOp::canonicalize uses it as
// `m_Complement(m_Any(&matchedVal))` to look through a complement and capture
// the complemented value.
56template <
typename SubType>
57static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
58 return ComplementMatcher<SubType>(subExpr);
// Verifier: a choice op must have at least one operand.
63LogicalResult ChoiceOp::verify() {
64 if (getNumOperands() < 1)
65 return emitOpError(
"requires at least one operand");
// Folder. Only the single-input test is visible in this excerpt. What that case
// folds to is on a hidden line (presumably the sole input -- confirm).
69OpFoldResult ChoiceOp::fold(FoldAdaptor adaptor) {
70 if (adaptor.getInputs().size() == 1)
// Canonicalization: merges `op` with every other ChoiceOp in the same block
// that is connected to it through a shared value. A connection is either a
// choice that defines a collected value or a choice that uses one. The result
// is a single ChoiceOp whose operands are all collected values except results
// of the merged choices.
// Fails (no rewrite) when no other choice was merged and no operand was
// dropped.
84LogicalResult ChoiceOp::canonicalize(ChoiceOp op, PatternRewriter &rewriter) {
// Values still to be examined, in insertion order without duplicates.
85 llvm::SetVector<Value> worklist;
// NOTE(review): `visitedChoices` is declared on a line not visible here. From
// its uses (insert of a ChoiceOp, contains(Operation *), takeVector() yielding
// Operation *) it appears to be a SetVector of Operation * -- confirm.
// addToWorklist enqueues the inputs of `choice` the first time that choice is
// seen, but only for choices in the same block as `op`. Its return statement
// is hidden; presumably it returns true when `choice` was newly added.
88 auto addToWorklist = [&](ChoiceOp choice) ->
bool {
89 if (choice->getBlock() == op->getBlock() && visitedChoices.insert(choice)) {
90 worklist.insert(choice.getInputs().begin(), choice.getInputs().end());
// Records whether the traversal below pulled in any further ChoiceOp.
98 bool mergedOtherChoices =
false;
// Index-based on purpose: addToWorklist appends to `worklist` while it is
// being traversed, and the appended values must be visited as well.
101 for (
unsigned i = 0; i < worklist.size(); ++i) {
102 Value val = worklist[i];
// A value produced by another choice op: merge that choice.
103 if (
auto defOp = val.getDefiningOp<synth::ChoiceOp>()) {
105 if (addToWorklist(defOp))
106 mergedOtherChoices =
true;
// A choice op that uses this value: merge that choice as well.
109 for (Operation *user : val.getUsers()) {
110 if (
auto userChoice = llvm::dyn_cast<synth::ChoiceOp>(user)) {
111 if (addToWorklist(userChoice)) {
112 mergedOtherChoices =
true;
// Operands of the merged choice: every collected value except results of the
// choice ops being merged.
118 llvm::SmallVector<mlir::Value> finalOperands;
119 for (Value v : worklist) {
120 if (!visitedChoices.contains(v.getDefiningOp())) {
121 finalOperands.push_back(v);
// Nothing to do if no other choice was merged and no operand was dropped.
125 if (!mergedOtherChoices && finalOperands.size() == op.getInputs().size())
126 return llvm::failure();
// Build the merged choice (its operand argument is on a hidden line,
// presumably `finalOperands`), then replace every visited choice op with it.
128 auto newChoice = synth::ChoiceOp::create(rewriter, op->getLoc(), op.getType(),
130 for (Operation *visited : visitedChoices.takeVector())
131 rewriter.replaceOp(visited, newChoice);
// Redirect all remaining uses of each input of the merged choice to the merged
// choice's result, except the use by the merged choice itself.
// NOTE(review): this treats every choice input as interchangeable with the
// choice result.
133 for (
auto value : newChoice.getInputs())
134 rewriter.replaceAllUsesExcept(value, newChoice.getResult(), newChoice);
143bool AndInverterOp::areInputsPermutationInvariant() {
return true; }
// Folder.
// (1) A single non-inverted operand folds to that operand.
// (2) For a two-input op whose second input folds to an IntegerAttr, the
//     visible code either folds to a constant built from that value, or to
//     operand 0 when the value is all ones.
// The conditions guarding these returns, and any adjustment of `value` for
// input 1's inversion flag, are on lines not visible here -- confirm.
145OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
// and(x) with no inversion is just x.
146 if (getNumOperands() == 1 && !isInverted(0))
147 return getOperand(0);
149 auto inputs = adaptor.getInputs();
150 if (inputs.size() == 2)
151 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
152 auto value = intAttr.getValue();
// Fold to a constant of the same width. The guarding condition is hidden;
// presumably it fires when the mask is zero -- confirm.
156 return IntegerAttr::get(
157 IntegerType::get(getContext(), value.getBitWidth()), value);
// An all-ones mask is the AND identity, so the result is operand 0. Any check
// on operand 0's inversion is on hidden lines.
158 if (value.isAllOnes()) {
162 return getOperand(0);
// Canonicalization for an AND-inverter node. It:
// - folds all constant operands into a single mask;
// - looks through single-input inverted AndInverterOps (~~x => x);
// - removes duplicate operands;
// - rewrites to zero when some value appears with both polarities (x & ~x);
// - rebuilds the op from the surviving operands, re-appending the mask unless
//   it is all ones.
// Several lines of this function are not visible in this excerpt.
168LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
169 PatternRewriter &rewriter) {
// Deduplicated operands and their effective inversion flags, in first-seen
// order. Constant operands are instead ANDed into one mask, which starts at
// all ones (the AND identity). The declaration's left-hand side is on a hidden
// line; later uses name it `constValue`.
171 SmallVector<Value> uniqueValues;
172 SmallVector<bool> uniqueInverts;
175 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
// Whether any inverted constant was folded, and whether any double inversion
// was looked through. `flippedFound` is assigned on a line not visible here,
// presumably in the look-through branch -- confirm.
177 bool invertedConstFound =
false;
178 bool flippedFound =
false;
180 for (
auto [value, inverted] :
llvm::zip(op.getInputs(), op.getInverted())) {
181 bool newInverted = inverted;
// Inverted constant operand: AND in its complement.
184 constValue &= ~constOp.getValue();
185 invertedConstFound =
true;
// Non-inverted constant operand: AND in its value.
187 constValue &= constOp.getValue();
// Look through a single-input inverted AndInverterOp (a plain inverter): use
// its operand and flip the polarity. isInverted(0) is true on this path, so
// newInverted becomes !inverted.
192 if (
auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
193 if (andInverterOp.getInputs().size() == 1 &&
194 andInverterOp.isInverted(0)) {
195 value = andInverterOp.getOperand(0);
196 newInverted = andInverterOp.isInverted(0) ^ inverted;
// Keep only the first occurrence of each value. Seeing a value again with the
// opposite polarity means x & ~x, so the op is replaced by an all-zero value.
201 auto it = seen.find(value);
202 if (it == seen.end()) {
203 seen.insert({value, newInverted});
204 uniqueValues.push_back(value);
205 uniqueInverts.push_back(newInverted);
206 }
else if (it->second != newInverted) {
209 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
// An all-zero constant mask forces the result to zero (handled on hidden
// lines).
215 if (constValue.isZero()) {
// Bail out when rebuilding would reproduce the same op. That happens when
// either (a) no operand was removed and nothing was looked through, or
// (b) the only removed operand is a single non-inverted, non-all-ones constant
// that would just be re-appended.
// NOTE(review): clause (b) does not test `flippedFound`. A double-inversion
// look-through is therefore not applied when exactly one such constant is
// present -- confirm whether this is intentional.
221 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
222 (!constValue.isAllOnes() && !invertedConstFound &&
223 uniqueValues.size() + 1 == op.getInputs().size()))
// Re-append the folded mask as a non-inverted operand unless it is the
// identity. `constOp` is created on a hidden line.
226 if (!constValue.isAllOnes()) {
228 uniqueInverts.push_back(
false);
229 uniqueValues.push_back(constOp);
// No operands remain only when every operand was a constant and the mask is
// all ones (replacement on hidden lines).
233 if (uniqueValues.empty()) {
// Otherwise rebuild the op from the simplified operands, carrying over the
// name hint.
239 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
240 rewriter, op, uniqueValues, uniqueInverts);
// Evaluates the AND of `inputs`, ignoring the op's inversion flags (the name
// suggests callers apply them separately -- confirm). The accumulator starts at
// all ones, the AND identity, with the width of the first input. The loop body
// is on a hidden line (presumably `result &= input`). `inputs` must be
// non-empty.
244APInt AndInverterOp::evaluateBooleanLogicWithoutInversion(
245 llvm::ArrayRef<APInt> inputs) {
246 assert(!inputs.empty() &&
"expected non-empty input list");
247 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
248 for (
const APInt &input : inputs)
// Any positive number of inputs is accepted. A one-input op is allowed; it acts
// as a buffer or, with its input inverted, as an inverter.
253bool AndInverterOp::supportsNumInputs(
unsigned numInputs) {
254 return numInputs >= 1;
// Known-bits transfer function for AND: intersects the known bits of every
// input, each complemented when its inversion flag is set.
// `getInputKnownBits(i)` supplies the known bits of operand i.
257llvm::KnownBits AndInverterOp::computeKnownBits(
258 llvm::function_ref<
const llvm::KnownBits &(
unsigned)> getInputKnownBits) {
259 assert(getNumOperands() > 0 &&
"Expected non-empty input list");
261 auto width = getInputKnownBits(0).getBitWidth();
262 llvm::KnownBits result(width);
// Seed with the AND identity: every bit known to be one.
263 result.One = APInt::getAllOnes(width);
264 result.Zero = APInt::getZero(width);
// Fold in each input, swapping its Zero/One masks when inverted. The return is
// on a hidden line.
266 for (
auto [i, inverted] :
llvm::enumerate(getInverted()))
267 result &= applyInversion(getInputKnownBits(i), inverted);
// Logic-depth estimate: ceil(log2(#operands)). This is the depth of the
// balanced two-input tree produced by lowerVariadicInvertibleOp, which splits
// operands in halves. A single-operand op has depth 0.
272int64_t AndInverterOp::getLogicDepthCost() {
273 return llvm::Log2_64_Ceil(getNumOperands());
// Area estimate: (#operands - 1) two-input gates per result bit, so a
// single-operand op costs 0. The lines between the width query and the return
// are hidden. Presumably they return std::nullopt when hw::getBitWidth cannot
// determine a width (negative result) -- confirm.
276std::optional<uint64_t> AndInverterOp::getLogicAreaCost() {
277 int64_t bitWidth = hw::getBitWidth(getType());
280 return static_cast<uint64_t
>(getNumOperands() - 1) * bitWidth;
// Emits CNF clauses constraining SAT variable `outVar` to the AND of
// `inputVars`, without the op's inversion flags (per the name).
// - `addClause` receives each clause.
// - `newVar` allocates a fresh auxiliary variable.
// The body is not visible in this excerpt. This file declares an
// `addAndClauses` helper (outVar <=> and(inputLits)) that it presumably
// uses -- confirm.
283void AndInverterOp::emitCNFWithoutInversion(
284 int outVar, llvm::ArrayRef<int> inputVars,
285 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
286 llvm::function_ref<
int()> newVar) {
295bool XorInverterOp::areInputsPermutationInvariant() {
return true; }
// Folder.
// (1) A single non-inverted operand folds to that operand.
// (2) For a two-input op whose second input folds to an IntegerAttr, the
//     visible code folds to operand 0 under a condition on hidden lines
//     (presumably when the effective constant is zero, the XOR identity --
//     confirm).
297OpFoldResult XorInverterOp::fold(FoldAdaptor adaptor) {
// xor(x) with no inversion is just x.
299 if (getNumOperands() == 1 && !isInverted(0))
300 return getOperand(0);
302 auto inputs = adaptor.getInputs();
303 if (inputs.size() == 2)
304 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
305 auto value = intAttr.getValue();
310 return getOperand(0);
// Canonicalization for an XOR-inverter node. It:
// - folds constant operands into one mask (named `constValue`), which starts
//   at zero, the XOR identity;
// - looks through complements matched by m_Complement;
// - cancels operands that occur twice: x ^ x = 0, and x ^ ~x = all ones,
//   which is tracked by complementing the mask;
// - rebuilds the op from the surviving operands.
// Several lines of this function are not visible in this excerpt.
315LogicalResult XorInverterOp::canonicalize(XorInverterOp op,
316 PatternRewriter &rewriter) {
323 APInt::getZero(op.getResult().getType().getIntOrFloatBitWidth());
// Whether a constant operand was seen, and whether any other simplification
// happened. Both are assigned on lines not visible here.
325 bool constFound =
false;
326 bool changed =
false;
328 for (
auto [value, inverted] :
llvm::zip(op.getInputs(), op.getInverted())) {
// Effective operand and polarity after the look-throughs below.
329 Value currentValue = value;
330 bool newInverted = inverted;
// Constant operands are accumulated into the mask (on hidden lines). A
// complemented value is replaced by the value it complements; the polarity
// update for that case is on a hidden line.
334 if (
auto constOp = currentValue.getDefiningOp<
hw::ConstantOp>()) {
335 APInt val = constOp.getValue();
346 matchPattern(currentValue,
m_Complement(m_Any(&matchedVal)))) {
347 currentValue = matchedVal;
// `activeOperands` is declared on a hidden line. It is used as an
// insertion-ordered map from Value to inversion flag (presumably an
// llvm::MapVector -- confirm) and holds operands that have not cancelled.
// A second occurrence cancels the first. If the two polarities differ, the
// pair contributes all ones, so the mask is complemented.
354 if (activeOperands.count(currentValue)) {
357 if (activeOperands[currentValue] != newInverted)
358 constValue.flipAllBits();
359 activeOperands.erase(currentValue);
362 activeOperands[currentValue] = newInverted;
// No rewrite if nothing changed (the bail-out statement is on a hidden line).
368 if (!changed && !constFound && activeOperands.size() == op.getInputs().size())
// Re-materialize a non-zero mask. An all-ones mask is absorbed by toggling the
// inversion flag of the last surviving operand instead of adding a constant.
// The two-input test guards a case handled on hidden lines (presumably a
// bail-out so an equivalent op is not rebuilt -- confirm). Otherwise a
// constant operand, created on a hidden line, is appended non-inverted.
373 if (!constValue.isZero()) {
374 if (constValue.isAllOnes() && !activeOperands.empty()) {
376 activeOperands.back().second = !activeOperands.back().second;
378 if (op.getInputs().size() == 2 && !op.getInverted()[1] &&
379 activeOperands.size() == 1)
382 activeOperands.insert({constOp,
false});
// With no operands left, the op is replaced by an all-zero value. Otherwise it
// is rebuilt from the surviving (value, inversion) pairs.
386 if (activeOperands.empty()) {
388 op, APInt::getZero(op.getResult().getType().getIntOrFloatBitWidth()));
393 XorInverterOp::create(rewriter, op.getLoc(),
394 activeOperands.getArrayRef()));
// Evaluates the XOR of `inputs`, ignoring the op's inversion flags (the name
// suggests callers apply them separately -- confirm). The accumulator starts at
// zero, the XOR identity, with the width of the first input. The loop body is
// on a hidden line (presumably `result ^= input`). `inputs` must be non-empty.
398APInt XorInverterOp::evaluateBooleanLogicWithoutInversion(
399 llvm::ArrayRef<APInt> inputs) {
400 assert(!inputs.empty() &&
"expected non-empty input list");
401 APInt result = APInt::getZero(inputs.front().getBitWidth());
402 for (
const APInt &input : inputs)
// Any positive number of inputs is accepted, including a single (optionally
// inverted) input.
407bool XorInverterOp::supportsNumInputs(
unsigned numInputs) {
408 return numInputs >= 1;
411llvm::KnownBits XorInverterOp::computeKnownBits(
412 llvm::function_ref<
const llvm::KnownBits &(
unsigned)> getInputKnownBits) {
413 assert(getNumOperands() > 0 &&
"Expected non-empty input list");
415 llvm::KnownBits result(getInputKnownBits(0).
getBitWidth());
416 for (
auto [i, inverted] :
llvm::enumerate(getInverted()))
417 result ^= applyInversion(getInputKnownBits(i), inverted);
// Logic-depth estimate: ceil(log2(#operands)). This is the depth of the
// balanced two-input tree produced by lowerVariadicInvertibleOp. A
// single-operand op has depth 0.
421int64_t XorInverterOp::getLogicDepthCost() {
422 return llvm::Log2_64_Ceil(getNumOperands());
// Area estimate: (#operands - 1) two-input gates per result bit. The lines
// between the width query and the return are hidden. Presumably they return
// std::nullopt when hw::getBitWidth cannot determine a width (negative
// result) -- confirm.
425std::optional<uint64_t> XorInverterOp::getLogicAreaCost() {
426 int64_t bitWidth = hw::getBitWidth(getType());
429 return static_cast<uint64_t
>(getNumOperands() - 1) * bitWidth;
// Emits CNF clauses constraining SAT variable `outVar` to the XOR of
// `inputVars`, without the op's inversion flags (per the name).
// - `addClause` receives each clause.
// - `newVar` allocates auxiliary variables.
// The body is not visible in this excerpt. This file declares an
// `addParityClauses` helper (outVar <=> parity(inputLits)) that takes a
// `newVar` callback and is presumably used here -- confirm.
432void XorInverterOp::emitCNFWithoutInversion(
433 int outVar, llvm::ArrayRef<int> inputVars,
434 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
435 llvm::function_ref<
int()> newVar) {
// Custom parser for the ternary dot op. It:
// - reads an operand list with per-operand inversion flags (the parsing call
//   and the declarations of `resultType` and `attrs` are on hidden lines,
//   presumably parseVariadicInvertibleOperands -- confirm);
// - requires exactly three operands;
// - resolves them all against the result type;
// - records the result type, the remaining attributes, and the `inverted`
//   flags.
// NOTE(review): the operand-count error is reported at the parser's current
// location, i.e. after the operand list has already been consumed.
443ParseResult DotOp::parse(OpAsmParser &parser, OperationState &result) {
444 SmallVector<OpAsmParser::UnresolvedOperand> operands;
446 DenseBoolArrayAttr inverted;
452 if (operands.size() != 3)
453 return parser.emitError(parser.getCurrentLocation())
454 <<
"expected exactly three operands";
455 if (parser.resolveOperands(operands, resultType, result.operands))
458 result.addTypes(resultType);
459 result.addAttributes(attrs);
460 result.addAttribute(
"inverted", inverted);
// Custom printer, the counterpart of DotOp::parse. The callee and its leading
// arguments are on hidden lines (presumably printVariadicInvertibleOperands --
// confirm). It receives the result type, the inversion flags, and the
// attribute dictionary.
464void DotOp::print(OpAsmPrinter &printer) {
467 getType(), getInvertedAttr(),
468 (*this)->getAttrDictionary());
// Verifier: the dot op needs one inversion flag per operand of its three
// operands.
// NOTE(review): the visible lines do not check the operand count itself; only
// the custom parser enforces three operands. Confirm that ODS constraints or
// the hidden lines cover the generic form.
471LogicalResult DotOp::verify() {
472 if (getInverted().size() != 3)
473 return emitOpError(
"requires exactly three inversion flags");
// Evaluates the dot function on exactly three input values, ignoring inversion
// flags. The computation is on hidden lines. Presumably it uses this file's
// evaluateDotLogic helper, documented as x ^ (z | (x & y)) -- confirm.
477APInt DotOp::evaluateBooleanLogicWithoutInversion(
478 llvm::ArrayRef<APInt> inputs) {
479 assert(supportsNumInputs(inputs.size()) &&
480 "dot expects exactly three operands");
484bool DotOp::areInputsPermutationInvariant() {
return false; }
486bool DotOp::supportsNumInputs(
unsigned numInputs) {
return numInputs == 3; }
// Known-bits transfer function for the dot op. It first gets the known bits of
// each of the three operands, complemented when that operand's inversion flag
// is set. How x, y and z are combined is on hidden lines (presumably
// evaluateDotLogic, i.e. x ^ (z | (x & y)) -- confirm).
488llvm::KnownBits DotOp::computeKnownBits(
489 llvm::function_ref<
const llvm::KnownBits &(
unsigned)> getInputKnownBits) {
490 auto x = applyInversion(getInputKnownBits(0), isInverted(0));
491 auto y = applyInversion(getInputKnownBits(1), isInverted(1));
492 auto z = applyInversion(getInputKnownBits(2), isInverted(2));
// Area estimate: one dot gate per result bit. The lines between the width
// query and the return are hidden. Presumably they return std::nullopt when
// hw::getBitWidth cannot determine a width -- confirm.
496std::optional<uint64_t> DotOp::getLogicAreaCost() {
497 int64_t bitWidth = hw::getBitWidth(getType());
500 return static_cast<uint64_t
>(bitWidth);
// Emits CNF clauses constraining `outVar` to the dot function of the three
// input variables (inversion flags excluded, per the name).
// - `addClause` receives each clause.
// - `newVar` allocates auxiliary variables.
503void DotOp::emitCNFWithoutInversion(
504 int outVar, llvm::ArrayRef<int> inputVars,
505 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
506 llvm::function_ref<
int()> newVar) {
507 assert(inputVars.size() == 3 &&
"expected one SAT variable per operand");
// Two auxiliary variables. Judging by their names, they hold the intermediate
// (x & y) and (z | (x & y)) terms. The clause emission is on hidden lines --
// confirm.
508 int andVar = newVar();
509 int orVar = newVar();
// (Parameter list and body of lowerVariadicInvertibleOp; the function head is
// on a hidden line.)
// Lowers a variadic invertible op into a balanced tree of unary/binary ops:
// - `inverts[i]` is the inversion flag of `operands[i]`;
// - `createUnary` builds a one-input (inverting) op;
// - `createBinary` builds a two-input op with per-input inversion flags.
519 Location loc, ValueRange operands, ArrayRef<bool> inverts,
520 PatternRewriter &rewriter,
521 llvm::function_ref<Value(Value,
bool)> createUnary,
522 llvm::function_ref<Value(Value, Value,
bool,
bool)> createBinary) {
523 switch (operands.size()) {
// Empty range: caller bug.
// NOTE(review): with assertions disabled, what follows this assert is on a
// hidden line. llvm_unreachable would make the contract explicit.
525 assert(0 &&
"cannot be called with empty operand range");
// One operand: materialize an inverter only when the operand is inverted.
528 return inverts[0] ? createUnary(operands[0],
true) : operands[0];
// Two operands: one binary op carrying both inversion flags.
530 return createBinary(operands[0], operands[1], inverts[0], inverts[1]);
// More operands: split at n/2 and lower each half recursively, so inversion
// flags stay on the leaves. The halves are joined with a non-inverted binary
// op, giving a tree of depth ceil(log2(n)).
532 auto firstHalf = operands.size() / 2;
534 inverts.take_front(firstHalf),
535 rewriter, createUnary, createBinary);
537 inverts.drop_front(firstHalf),
538 rewriter, createUnary, createBinary);
539 return createBinary(lhs, rhs,
false,
false);
// Rewrite pattern that lowers an OpTy with more than two inputs into a
// balanced tree of one- and two-input OpTy ops, via lowerVariadicInvertibleOp.
// Despite its name, it is instantiated for both AndInverterOp and
// XorInverterOp. The function signature and the bail-out/replacement
// statements are on hidden lines.
544template <
typename OpTy>
546 PatternRewriter &rewriter) {
// Ops with at most two inputs are already in the target form.
547 if (op.getInputs().size() <= 2)
550 op.getLoc(), op.getOperands(), op.getInverted(), rewriter,
551 [&](Value input,
bool invert) {
552 return OpTy::create(rewriter, op.getLoc(), input, invert);
554 [&](Value lhs, Value rhs,
bool invertLhs,
bool invertRhs) {
555 return OpTy::create(rewriter, op.getLoc(), lhs, rhs, invertLhs,
// Registers the variadic-to-binary lowering for AndInverterOp. The enclosing
// populate function's header is on a hidden line.
564 patterns.add(lowerVariadicAndInverterOpConversion<aig::AndInverterOp>);
// The same lowering template, registered for XorInverterOp in a separate
// populate function (header hidden).
569 patterns.add(lowerVariadicAndInverterOpConversion<XorInverterOp>);
// Returns whether an op belongs to the logic network: boolean-logic ops,
// choice ops, comb.extract, and further types listed on hidden lines.
// Presumably this is the body of isLogicNetworkOp -- confirm.
573 return isa<synth::BooleanLogicOpInterface, synth::ChoiceOp,
comb::ExtractOp,
// (Tail of topologicallySortGraphRegionBlocks' parameter list; the head is on
// a hidden line.)
// Topologically sorts the operations of every block in the graph regions
// nested under `op`. `isOperandReady` decides when an operand counts as
// available.
// - Regions whose parent implements RegionKindInterface and reports SSA
//   dominance are skipped. The first half of that condition is on a hidden
//   line.
// - Returns failure if mlir::sortTopologically cannot fully sort some block
//   (e.g. because of a cycle).
579 llvm::function_ref<
bool(mlir::Value, mlir::Operation *)> isOperandReady) {
581 auto walkResult = op->walk([&](Region *region) {
583 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
585 regionKindOp.hasSSADominance(region->getRegionNumber()))
586 return WalkResult::advance();
// Sort each block, stopping the walk at the first block that cannot be sorted.
589 for (
auto &block : *region) {
590 if (!mlir::sortTopologically(&block, isOperandReady))
591 return WalkResult::interrupt();
593 return WalkResult::advance();
596 return success(!walkResult.wasInterrupted());
assert(baseType &&"element must be base type")
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
LogicalResult lowerVariadicAndInverterOpConversion(OpTy op, PatternRewriter &rewriter)
static Value lowerVariadicInvertibleOp(Location loc, ValueRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter, llvm::function_ref< Value(Value, bool)> createUnary, llvm::function_ref< Value(Value, Value, bool, bool)> createBinary)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
void populateVariadicXorInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
bool isLogicNetworkOp(mlir::Operation *op)
void populateVariadicAndInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseVariadicInvertibleOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, Type &resultType, mlir::DenseBoolArrayAttr &inverted, NamedAttrList &attrDict)
Parse a variadic list of operands that may be prefixed with an optional not keyword.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
void addXorClauses(int outVar, int lhsLit, int rhsLit, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> (lhsLit xor rhsLit).
void printVariadicInvertibleOperands(OpAsmPrinter &printer, Operation *op, OperandRange operands, Type resultType, mlir::DenseBoolArrayAttr inverted, DictionaryAttr attrDict)
Print a variadic list of operands that may be prefixed with an optional not keyword.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).