13#include "mlir/Analysis/TopologicalSortUtils.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/IR/Value.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/Casting.h"
22#include "llvm/Support/LogicalResult.h"
27using namespace circt::synth::mig;
28using namespace circt::synth::aig;
31#include "circt/Dialect/Synth/Synth.cpp.inc"
// Verifier for ChoiceOp: a choice must have at least one operand.
// NOTE(review): extracted fragment; the success path and closing brace are
// not present here.
33LogicalResult ChoiceOp::verify() {
34 if (getNumOperands() < 1)
35 return emitOpError(
"requires at least one operand");
// Folder for ChoiceOp. The visible check targets a choice with exactly one
// input; the value returned in that case is not shown in this fragment
// (presumably the sole input itself -- TODO confirm).
39OpFoldResult ChoiceOp::fold(FoldAdaptor adaptor) {
40 if (adaptor.getInputs().size() == 1)
// Canonicalization for ChoiceOp: collects every ChoiceOp in the same block
// that is connected to the values being explored (as a defining op or as a
// user), gathers their inputs, and replaces all visited choices with one
// merged ChoiceOp.
// NOTE(review): fragment -- the worklist seeding and several closing braces
// are not present in this extract.
54LogicalResult ChoiceOp::canonicalize(ChoiceOp op, PatternRewriter &rewriter) {
55 llvm::SetVector<Value> worklist;
56 llvm::SmallSetVector<Operation *, 4> visitedChoices;
// Accepts only choices in op's block that have not been visited yet and
// appends their inputs to the worklist. The bool result presumably reports
// whether the choice was newly accepted (its return statements are not
// visible here) -- TODO confirm.
58 auto addToWorklist = [&](ChoiceOp choice) ->
bool {
59 if (choice->getBlock() == op->getBlock() && visitedChoices.insert(choice)) {
60 worklist.insert(choice.getInputs().begin(), choice.getInputs().end());
68 bool mergedOtherChoices =
false;
// Index-based loop on purpose: the worklist grows while it is being walked.
71 for (
unsigned i = 0; i < worklist.size(); ++i) {
72 Value val = worklist[i];
// Upward: an input produced by another ChoiceOp is merged as well.
73 if (
auto defOp = val.getDefiningOp<synth::ChoiceOp>()) {
75 if (addToWorklist(defOp))
76 mergedOtherChoices =
true;
// Downward: other ChoiceOps that use this value are merged as well.
79 for (Operation *user : val.getUsers()) {
80 if (
auto userChoice = llvm::dyn_cast<synth::ChoiceOp>(user)) {
81 if (addToWorklist(userChoice)) {
82 mergedOtherChoices =
true;
// Keep only values that are not themselves results of a visited choice.
88 llvm::SmallVector<mlir::Value> finalOperands;
89 for (Value v : worklist) {
90 if (!visitedChoices.contains(v.getDefiningOp())) {
91 finalOperands.push_back(v);
// Bail out when no other choice was merged and the operand count of `op`
// did not change.
95 if (!mergedOtherChoices && finalOperands.size() == op.getInputs().size())
96 return llvm::failure();
98 auto newChoice = synth::ChoiceOp::create(rewriter, op->getLoc(), op.getType(),
// Every visited choice (presumably including `op` itself, depending on the
// seeding not shown here) is replaced by the merged op.
100 for (Operation *visited : visitedChoices.takeVector())
101 rewriter.replaceOp(visited, newChoice);
// Redirect all other uses of each input to the merged choice; the uses inside
// newChoice itself are excluded so it keeps its own operands.
103 for (
auto value : newChoice.getInputs())
104 rewriter.replaceAllUsesExcept(value, newChoice.getResult(), newChoice);
// Verifier for MajorityInverterOp: the operand count must be odd, so that a
// strict majority (count > n / 2) is always well defined.
109LogicalResult MajorityInverterOp::verify() {
110 if (getNumOperands() % 2 != 1)
111 return emitOpError(
"requires an odd number of operands");
// Evaluates the majority function bitwise over `inputs`, complementing each
// input whose inverted flag is set. `inputs` must hold one value per operand.
// NOTE(review): fragment -- the per-bit counter, the bit-setting statement and
// the closing lines are not present in this extract.
116llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
117 assert(inputs.size() == getNumOperands() &&
118 "Number of inputs must match number of operands");
// Fast path for three inputs: maj(a, b, c) = ab | ac | bc.
120 if (inputs.size() == 3) {
121 auto a = (isInverted(0) ? ~inputs[0] : inputs[0]);
122 auto b = (isInverted(1) ? ~inputs[1] : inputs[1]);
123 auto c = (isInverted(2) ? ~inputs[2] : inputs[2]);
124 return (a & b) | (
a & c) | (b & c);
// General case: for each bit position, count the (possibly inverted) inputs
// that are 1 and set the result bit when they form a strict majority.
128 auto width = inputs[0].getBitWidth();
129 APInt result(width, 0);
131 for (
size_t bit = 0; bit < width; ++bit) {
133 for (
size_t i = 0; i < inputs.size(); ++i) {
// XOR with the inversion flag yields the bit of the complemented input.
135 if (isInverted(i) ^ inputs[i][bit])
139 if (count > inputs.size() / 2)
// Folder for MajorityInverterOp.
// - If every input is constant, the op is evaluated to a constant attribute.
// - Otherwise only the 3-operand form is simplified further. The visible
//   outcomes are: returning a constant, rewriting this op in place into a
//   single inverted operand, or forwarding operand k. The conditions guarding
//   each outcome are not present in this fragment.
146OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
148 SmallVector<APInt, 3> inputValues;
149 SmallVector<size_t, 3> nonConstantValues;
// Split inputs into constant values and indices of non-constant inputs (the
// branch choosing between the two is not visible here).
150 for (
auto [i, input] :
llvm::enumerate(adaptor.getInputs())) {
151 auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
153 inputValues.push_back(attr.getValue());
155 nonConstantValues.push_back(i);
158 if (nonConstantValues.size() == 0)
159 return IntegerAttr::get(getType(), evaluate(inputValues));
161 if (getNumOperands() != 3)
// Returns the constant value of operand `index` with its inversion already
// applied, when that operand matches a constant integer.
164 auto getConstant = [&](
unsigned index) -> std::optional<llvm::APInt> {
166 if (mlir::matchPattern(getInputs()[index], mlir::m_ConstantInt(&value)))
167 return isInverted(index) ? ~value : value;
// Exactly one non-constant operand k: i and j index the other two operands,
// which are constant.
170 if (nonConstantValues.size() == 1) {
171 auto k = nonConstantValues[0];
172 auto i = (k + 1) % 3;
173 auto j = (k + 2) % 3;
180 return IntegerAttr::get(IntegerType::get(getContext(), c1->getBitWidth()),
// In-place update: this op becomes a single-operand majority of operand i
// with its inverted flag set.
184 (*this)->setOperands({getOperand(i)});
185 (*this).setInverted({
true});
188 return getOperand(k);
// Canonicalization for MajorityInverterOp.
// - A single-operand op is replaced by its operand; the inverted case is
//   handled on a line not present in this fragment.
// - For three operands, a repeated operand is simplified using the majority
//   identities maj(x, x, y) = x and maj(x, ~x, y) = y.
194LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
195 PatternRewriter &rewriter) {
196 if (op.getNumOperands() == 1) {
197 if (op.getInverted()[0])
199 rewriter.replaceOp(op, op.getOperand(0));
204 if (op.getNumOperands() != 3)
// Replaces `op` by operand `index` while keeping its polarity: an inverted
// operand is wrapped in a new single-operand inverted MajorityInverterOp.
208 auto replaceWithIndex = [&](
int index) {
209 bool inverted = op.isInverted(index);
211 rewriter.replaceOpWithNewOp<MajorityInverterOp>(
212 op, op.getType(), op.getOperand(index),
true);
214 rewriter.replaceOp(op, op.getOperand(index));
// Look for a pair of identical operands (i, j). `k` is computed on a line not
// shown here; presumably it is the remaining third index -- TODO confirm.
221 for (
int i = 0; i < 2; ++i) {
222 for (
int j = i + 1; j < 3; ++j) {
226 if (op.getOperand(i) == op.getOperand(j)) {
228 if (op.isInverted(i) != op.isInverted(j))
229 return replaceWithIndex(k);
230 return replaceWithIndex(i);
// Folder for AndInverterOp.
// - A single non-inverted operand folds to that operand.
// - With two operands whose second input is a constant integer, the visible
//   outcomes are returning a constant of the input's width, or returning
//   operand 0 when the constant is all ones (x & 1...1 = x). The conditions
//   guarding the constant result, and any inversion checks, are not present
//   in this fragment.
241OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
242 if (getNumOperands() == 1 && !isInverted(0))
243 return getOperand(0);
245 auto inputs = adaptor.getInputs();
246 if (inputs.size() == 2)
247 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
248 auto value = intAttr.getValue();
252 return IntegerAttr::get(
253 IntegerType::get(getContext(), value.getBitWidth()), value);
254 if (value.isAllOnes()) {
258 return getOperand(0);
// Canonicalization for AndInverterOp:
// - constant operands are combined into one constant (complemented when the
//   operand is inverted);
// - an input produced by a single-operand inverted AndInverterOp is looked
//   through, flipping that input's polarity;
// - duplicate inputs are dropped, and x & ~x folds the whole op to zero;
// - the op is rebuilt from the remaining unique inputs plus the combined
//   constant, unless nothing changed.
// NOTE(review): fragment -- the `seen` map, the constOp matching and several
// branches are not present in this extract.
264LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
265 PatternRewriter &rewriter) {
267 SmallVector<Value> uniqueValues;
268 SmallVector<bool> uniqueInverts;
271 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
// invertedConstFound is set when an inverted constant is folded. flippedFound
// is presumably set when an inverter is looked through (its assignment is not
// visible here).
273 bool invertedConstFound =
false;
274 bool flippedFound =
false;
276 for (
auto [value, inverted] :
llvm::zip(op.getInputs(), op.getInverted())) {
277 bool newInverted = inverted;
280 constValue &= ~constOp.getValue();
281 invertedConstFound =
true;
283 constValue &= constOp.getValue();
// Look through a single-operand inverted AndInverterOp and use its input with
// the polarity flipped. The guard requires isInverted(0) to be true, so the
// XOR below is equivalent to !inverted.
288 if (
auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
289 if (andInverterOp.getInputs().size() == 1 &&
290 andInverterOp.isInverted(0)) {
291 value = andInverterOp.getOperand(0);
292 newInverted = andInverterOp.isInverted(0) ^ inverted;
// Deduplicate: a value seen again with the same polarity is not re-added;
// seen with the opposite polarity, the op is replaced with zero (x & ~x = 0).
297 auto it = seen.find(value);
298 if (it == seen.end()) {
299 seen.insert({value, newInverted});
300 uniqueValues.push_back(value);
301 uniqueInverts.push_back(newInverted);
302 }
else if (it->second != newInverted) {
305 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
311 if (constValue.isZero()) {
// Left unchanged (the return on the following missing line is presumably
// failure()) when all inputs were unique with no polarity flip, or when the
// only non-unique input was a single non-inverted, non-all-ones constant.
317 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
318 (!constValue.isAllOnes() && !invertedConstFound &&
319 uniqueValues.size() + 1 == op.getInputs().size()))
// Re-append the combined constant as a non-inverted operand, unless it is all
// ones (the identity of AND).
322 if (!constValue.isAllOnes()) {
324 uniqueInverts.push_back(
false);
325 uniqueValues.push_back(constOp);
329 if (uniqueValues.empty()) {
// Rebuild the op from the unique inputs (the helper presumably also copies the
// original op's name hint, per its name -- TODO confirm).
335 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
336 rewriter, op, uniqueValues, uniqueInverts);
// Evaluates this op on concrete `inputs` (one non-empty list entry per
// operand). The accumulator starts as all ones, the identity of AND.
// NOTE(review): the loop body is not present in this fragment; presumably it
// ANDs each input, complemented when isInverted(idx) -- TODO confirm.
340APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {
341 assert(inputs.size() == getNumOperands() &&
342 "Expected as many inputs as operands");
343 assert(!inputs.empty() &&
"Expected non-empty input list");
344 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
345 for (
auto [idx, input] :
llvm::enumerate(inputs)) {
355 ArrayRef<bool> inverts,
356 PatternRewriter &rewriter) {
// Lowers a variadic AND-inverter over `operands`/`inverts` into a balanced
// tree of binary AndInverterOps by recursively splitting the range in half.
// NOTE(review): fragment -- the function name, the case labels and the lhs/rhs
// recursive calls are only partially present in this extract.
357 switch (operands.size()) {
// An empty range is a caller bug.
359 assert(0 &&
"cannot be called with empty operand range");
// One operand: emitted as a single-operand inverted op (presumably only when
// inverts[0] is set; the guard is not visible here).
363 return AndInverterOp::create(rewriter, op.getLoc(), operands[0],
true);
// Two operands: one binary op carrying both inversion flags.
367 return AndInverterOp::create(rewriter, op.getLoc(), operands[0],
368 operands[1], inverts[0], inverts[1]);
// More operands: lower each half, then AND the two halves with no inversion.
370 auto firstHalf = operands.size() / 2;
373 inverts.take_front(firstHalf), rewriter);
376 inverts.drop_front(firstHalf), rewriter);
377 return AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
383 AndInverterOp op, PatternRewriter &rewriter)
const {
// Replaces an AndInverterOp that has more than two inputs with a balanced tree
// of binary ops built by lowerVariadicAndInverterOp. Ops with at most two
// inputs are handled on a line not present here (presumably left unchanged).
384 if (op.getInputs().size() <= 2)
390 op, op.getOperands(), op.getInverted(), rewriter));
396 llvm::function_ref<
bool(mlir::Value, mlir::Operation *)> isOperandReady) {
// Topologically sorts the operations of every block in the graph regions
// nested under `op`, using `isOperandReady` to decide operand readiness.
// Regions whose parent reports SSA dominance are skipped. Returns failure if
// sorting any block fails.
398 auto walkResult = op->walk([&](Region *region) {
400 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
402 regionKindOp.hasSSADominance(region->getRegionNumber()))
403 return WalkResult::advance();
// sortTopologically returns false when a block could not be fully sorted;
// stop the walk in that case.
406 for (
auto &block : *region) {
407 if (!mlir::sortTopologically(&block, isOperandReady))
408 return WalkResult::interrupt();
410 return WalkResult::advance();
413 return success(!walkResult.wasInterrupted());
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter)
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
mlir::LogicalResult matchAndRewrite(aig::AndInverterOp op, mlir::PatternRewriter &rewriter) const override