CIRCT 23.0.0git
Loading...
Searching...
No Matches
SynthOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/Analysis/TopologicalSortUtils.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/IR/Value.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/Support/Casting.h"
24#include "llvm/Support/LogicalResult.h"
25
26using namespace mlir;
27using namespace circt;
28using namespace circt::synth;
29using namespace circt::synth::aig;
30
31#define GET_OP_CLASSES
32#include "circt/Dialect/Synth/Synth.cpp.inc"
33
34namespace {
35
36// Keep inversion semantics identical across folding, analysis, and CNF
37// lowering so new invertible Synth ops can reuse the same helpers.
38inline APInt applyInversion(APInt value, bool inverted) {
39 if (inverted)
40 value.flipAllBits();
41 return value;
42}
43
44inline llvm::KnownBits applyInversion(llvm::KnownBits value, bool inverted) {
45 if (inverted)
46 std::swap(value.Zero, value.One);
47 return value;
48}
49
50} // namespace
51
52LogicalResult ChoiceOp::verify() {
53 if (getNumOperands() < 1)
54 return emitOpError("requires at least one operand");
55 return success();
56}
57
58OpFoldResult ChoiceOp::fold(FoldAdaptor adaptor) {
59 if (adaptor.getInputs().size() == 1)
60 return getOperand(0);
61 return {};
62}
63
64// Canonicalize a network of synth.choice operations by computing their
65// transitive closure and flattening them into a single choice operation.
66// This merges nested choices and deduplicates shared operands.
67// Pattern matched:
68// %0 = synth.choice %x, %y, %z
69// %1 = synth.choice %0, %u
70// %2 = synth.choice %z, %v
71// =>
72// %merged = synth.choice %x, %y, %z, %u, %v
73LogicalResult ChoiceOp::canonicalize(ChoiceOp op, PatternRewriter &rewriter) {
 // Operands of every absorbed choice, in first-seen order, without duplicates.
74 llvm::SetVector<Value> worklist;
 // NOTE(review): `visitedChoices` is used below, but its declaration (original
 // line 75) is missing from this listing. Its uses (`insert` yielding a bool,
 // `contains`, `takeVector` producing `Operation *`) suggest a SetVector of
 // `Operation *` holding the choice ops absorbed so far -- confirm against the
 // real source.
76
 // Absorb `choice` if it lives in the same block as `op` and was not absorbed
 // before, appending its operands to the worklist. Returns true only when the
 // choice was newly absorbed.
77 auto addToWorklist = [&](ChoiceOp choice) -> bool {
78 if (choice->getBlock() == op->getBlock() && visitedChoices.insert(choice)) {
79 worklist.insert(choice.getInputs().begin(), choice.getInputs().end());
80 return true;
81 }
82 return false;
83 };
84
 // Seed with `op` itself; it therefore ends up in `visitedChoices` too.
85 addToWorklist(op);
86
 // True once any choice other than `op` has been absorbed.
87 bool mergedOtherChoices = false;
88
89 // Look up and down at definitions and users.
 // The worklist grows while it is scanned, so iterate by index and re-read
 // size() on every step instead of using a range-for.
90 for (unsigned i = 0; i < worklist.size(); ++i) {
91 Value val = worklist[i];
 // Upwards: the operand is itself produced by a choice.
92 if (auto defOp = val.getDefiningOp<synth::ChoiceOp>()) {
93
94 if (addToWorklist(defOp))
95 mergedOtherChoices = true;
96 }
97
 // Downwards: other choices that share this operand.
98 for (Operation *user : val.getUsers()) {
99 if (auto userChoice = llvm::dyn_cast<synth::ChoiceOp>(user)) {
100 if (addToWorklist(userChoice)) {
101 mergedOtherChoices = true;
102 }
103 }
104 }
105 }
106
 // Drop operands that are results of absorbed choices; those choices are
 // replaced by the merged op below.
107 llvm::SmallVector<mlir::Value> finalOperands;
108 for (Value v : worklist) {
109 if (!visitedChoices.contains(v.getDefiningOp())) {
110 finalOperands.push_back(v);
111 }
112 }
113
 // Nothing absorbed and no operand deduplicated or dropped: already canonical.
114 if (!mergedOtherChoices && finalOperands.size() == op.getInputs().size())
115 return llvm::failure();
116
 // Replace every absorbed choice (including `op`) with one flattened choice.
117 auto newChoice = synth::ChoiceOp::create(rewriter, op->getLoc(), op.getType(),
118 finalOperands);
119 for (Operation *visited : visitedChoices.takeVector())
120 rewriter.replaceOp(visited, newChoice);
121
 // Redirect every remaining use of each alternative, except the merged choice's
 // own operand uses, to the merged choice result. This presumably relies on
 // all alternatives of a choice being interchangeable.
122 for (auto value : newChoice.getInputs())
123 rewriter.replaceAllUsesExcept(value, newChoice.getResult(), newChoice);
124
125 return success();
126}
127
128//===----------------------------------------------------------------------===//
129// AndInverterOp
130//===----------------------------------------------------------------------===//
131
132bool AndInverterOp::areInputsPermutationInvariant() { return true; }
133
134OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
135 if (getNumOperands() == 1 && !isInverted(0))
136 return getOperand(0);
137
138 auto inputs = adaptor.getInputs();
139 if (inputs.size() == 2)
140 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
141 auto value = intAttr.getValue();
142 if (isInverted(1))
143 value = ~value;
144 if (value.isZero())
145 return IntegerAttr::get(
146 IntegerType::get(getContext(), value.getBitWidth()), value);
147 if (value.isAllOnes()) {
148 if (isInverted(0))
149 return {};
150
151 return getOperand(0);
152 }
153 }
154 return {};
155}
156
// Canonicalize an AND-inverter:
//  * all hw.constant operands (after applying their inversion flag) are ANDed
//    into one constant; if it is zero the whole op becomes zero;
//  * an operand produced by a single-input AndInverterOp whose input is
//    inverted (a plain NOT) is replaced by that input with its flag toggled;
//  * duplicate operands are dropped, and a value appearing both inverted and
//    non-inverted makes the op zero (x & ~x == 0).
157LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
158 PatternRewriter &rewriter) {
 // NOTE(review): `seen` is used below, but its declaration (original line 159)
 // is missing from this listing. Its uses (`find`, `insert({value, bool})`,
 // `it->second`) suggest a map from Value to that operand's effective
 // inversion flag -- confirm against the real source.
160 SmallVector<Value> uniqueValues;
161 SmallVector<bool> uniqueInverts;
162
 // Running AND of all constant operands; all-ones is the AND identity.
163 APInt constValue =
164 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
165
 // An inverted constant must be re-materialized as a plain constant.
166 bool invertedConstFound = false;
 // Some operand was looked through a NOT, so the op changes even if no
 // operand is dropped.
167 bool flippedFound = false;
168
169 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
170 bool newInverted = inverted;
171 if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
172 if (inverted) {
173 constValue &= ~constOp.getValue();
174 invertedConstFound = true;
175 } else {
176 constValue &= constOp.getValue();
177 }
178 continue;
179 }
180
 // Look through a single-input inverted AndInverterOp (i.e. a NOT).
181 if (auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
182 if (andInverterOp.getInputs().size() == 1 &&
183 andInverterOp.isInverted(0)) {
184 value = andInverterOp.getOperand(0);
185 newInverted = andInverterOp.isInverted(0) ^ inverted;
186 flippedFound = true;
187 }
188 }
189
190 auto it = seen.find(value);
191 if (it == seen.end()) {
192 seen.insert({value, newInverted});
193 uniqueValues.push_back(value);
194 uniqueInverts.push_back(newInverted);
195 } else if (it->second != newInverted) {
196 // Same value seen both inverted and non-inverted: x & ~x == 0.
197 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
198 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
199 return success();
200 }
201 }
202
203 // If the constant is zero, we can just replace with zero.
204 if (constValue.isZero()) {
205 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
206 return success();
207 }
208
209 // No change.
 // Either every non-constant operand survived unchanged, or the op already has
 // exactly one constant operand (non-inverted, not all-ones) and all other
 // operands are unique.
 // NOTE(review): the second clause does not check `flippedFound`. An op such as
 // and(<result of a NOT>, c) with a single plain constant `c` is therefore left
 // as-is, even though the NOT could be looked through. Adding `!flippedFound`
 // to that clause would enable the simplification.
210 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
211 (!constValue.isAllOnes() && !invertedConstFound &&
212 uniqueValues.size() + 1 == op.getInputs().size()))
213 return failure();
214
 // Re-add the merged constant unless it is the AND identity.
215 if (!constValue.isAllOnes()) {
216 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
217 uniqueInverts.push_back(false);
218 uniqueValues.push_back(constOp);
219 }
220
221 // It means the input is reduced to all ones.
222 if (uniqueValues.empty()) {
223 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
224 return success();
225 }
226
227 // build new op with reduced input values
228 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
229 rewriter, op, uniqueValues, uniqueInverts);
230 return success();
231}
232
233APInt AndInverterOp::evaluateBooleanLogic(
234 llvm::function_ref<const APInt &(unsigned)> getInputValue) {
235 assert(getNumOperands() > 0 && "Expected non-empty input list");
236 APInt result = APInt::getAllOnes(getInputValue(0).getBitWidth());
237 for (auto [idx, inverted] : llvm::enumerate(getInverted())) {
238 const APInt &input = getInputValue(idx);
239 // Model each operand inversion before intersecting with the running AND.
240 result &= applyInversion(input, inverted);
241 }
242 return result;
243}
244
245llvm::KnownBits AndInverterOp::computeKnownBits(
246 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
247 assert(getNumOperands() > 0 && "Expected non-empty input list");
248
249 auto width = getInputKnownBits(0).getBitWidth();
250 llvm::KnownBits result(width);
251 result.One = APInt::getAllOnes(width);
252 result.Zero = APInt::getZero(width);
253
254 for (auto [i, inverted] : llvm::enumerate(getInverted()))
255 result &= applyInversion(getInputKnownBits(i), inverted);
256
257 return result;
258}
259
260int64_t AndInverterOp::getLogicDepthCost() {
261 return llvm::Log2_64_Ceil(getNumOperands());
262}
263
264std::optional<uint64_t> AndInverterOp::getLogicAreaCost() {
265 int64_t bitWidth = hw::getBitWidth(getType());
266 if (bitWidth < 0)
267 return std::nullopt;
268 return static_cast<uint64_t>(getNumOperands() - 1) * bitWidth;
269}
270
271void AndInverterOp::emitCNFWithoutInversion(
272 int outVar, llvm::ArrayRef<int> inputVars,
273 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
274 llvm::function_ref<int()> newVar) {
275 (void)newVar;
276 circt::addAndClauses(outVar, inputVars, addClause);
277}
278
279//===----------------------------------------------------------------------===//
280// XorInverterOp
281//===----------------------------------------------------------------------===//
282
283bool XorInverterOp::areInputsPermutationInvariant() { return true; }
284
285APInt XorInverterOp::evaluateBooleanLogic(
286 llvm::function_ref<const APInt &(unsigned)> getInputValue) {
287 assert(getNumOperands() > 0 && "Expected non-empty input list");
288 APInt result = APInt::getZero(getInputValue(0).getBitWidth());
289 for (auto [idx, inverted] : llvm::enumerate(getInverted()))
290 result ^= applyInversion(getInputValue(idx), inverted);
291 return result;
292}
293
294llvm::KnownBits XorInverterOp::computeKnownBits(
295 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
296 assert(getNumOperands() > 0 && "Expected non-empty input list");
297
298 llvm::KnownBits result(getInputKnownBits(0).getBitWidth());
299 for (auto [i, inverted] : llvm::enumerate(getInverted()))
300 result ^= applyInversion(getInputKnownBits(i), inverted);
301 return result;
302}
303
304int64_t XorInverterOp::getLogicDepthCost() {
305 return llvm::Log2_64_Ceil(getNumOperands());
306}
307
308std::optional<uint64_t> XorInverterOp::getLogicAreaCost() {
309 int64_t bitWidth = hw::getBitWidth(getType());
310 if (bitWidth < 0)
311 return std::nullopt;
312 return static_cast<uint64_t>(getNumOperands() - 1) * bitWidth;
313}
314
315void XorInverterOp::emitCNFWithoutInversion(
316 int outVar, llvm::ArrayRef<int> inputVars,
317 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
318 llvm::function_ref<int()> newVar) {
319 circt::addParityClauses(outVar, inputVars, addClause, newVar);
320}
321
323 Location loc, ValueRange operands, ArrayRef<bool> inverts,
324 PatternRewriter &rewriter,
325 llvm::function_ref<Value(Value, bool)> createUnary,
326 llvm::function_ref<Value(Value, Value, bool, bool)> createBinary) {
 // NOTE(review): the opening line of this definition (original line 322) is
 // missing from this listing. The cross-reference at the end of the page
 // gives it as `static Value lowerVariadicInvertibleOp(`.
 //
 // Recursively lowers an N-ary invertible op into a balanced tree of unary and
 // binary ops:
 //  * the operands (and their matching `inverts` flags) are split in half;
 //  * each half is lowered;
 //  * the two results are joined with a non-inverting binary op.
 // Inversion flags therefore appear only on the leaves. `loc` and `rewriter`
 // are only forwarded to the recursive calls; op creation is delegated to
 // `createUnary` / `createBinary`.
327 switch (operands.size()) {
328 case 0:
 // Without assertions this breaks out and returns a null Value().
329 assert(0 && "cannot be called with empty operand range");
330 break;
331 case 1:
 // A lone operand only needs an op when it must be inverted.
332 return inverts[0] ? createUnary(operands[0], true) : operands[0];
333 case 2:
334 return createBinary(operands[0], operands[1], inverts[0], inverts[1]);
335 default:
 // The first half gets floor(n/2) operands, the second half the rest.
336 auto firstHalf = operands.size() / 2;
337 auto lhs = lowerVariadicInvertibleOp(loc, operands.take_front(firstHalf),
338 inverts.take_front(firstHalf),
339 rewriter, createUnary, createBinary);
340 auto rhs = lowerVariadicInvertibleOp(loc, operands.drop_front(firstHalf),
341 inverts.drop_front(firstHalf),
342 rewriter, createUnary, createBinary);
 // Inversions were already applied inside each half.
343 return createBinary(lhs, rhs, false, false);
344 }
345 return Value();
346}
347
// Rewrite-pattern body that splits an invertible op with more than two inputs
// into a balanced tree of unary/binary `OpTy` ops (via
// lowerVariadicInvertibleOp). The replacement inherits the original op's
// name hint. Despite its name it is generic over `OpTy`, and it is also
// instantiated for XorInverterOp.
// NOTE(review): the signature line (original line 349) is missing from this
// listing. The cross-reference at the end of the page gives it as
// `LogicalResult lowerVariadicAndInverterOpConversion(OpTy op, PatternRewriter &rewriter)`.
348template <typename OpTy>
350 PatternRewriter &rewriter) {
 // Ops with at most two inputs are already in unary/binary form.
351 if (op.getInputs().size() <= 2)
352 return failure();
353 auto result = lowerVariadicInvertibleOp(
354 op.getLoc(), op.getOperands(), op.getInverted(), rewriter,
355 [&](Value input, bool invert) {
356 return OpTy::create(rewriter, op.getLoc(), input, invert);
357 },
358 [&](Value lhs, Value rhs, bool invertLhs, bool invertRhs) {
359 return OpTy::create(rewriter, op.getLoc(), lhs, rhs, invertLhs,
360 invertRhs);
361 });
 // Replace the original op, propagating its "sv.namehint" attribute.
362 replaceOpAndCopyNamehint(rewriter, op, result);
363 return success();
364}
365
367 RewritePatternSet &patterns) {
 // NOTE(review): the signature line (original line 366) is missing from this
 // listing. Per the cross-reference at the end of the page this is
 // `populateVariadicAndInverterLoweringPatterns`.
 // Registers the pattern that splits AndInverterOps with more than two inputs
 // into a tree of unary/binary AndInverterOps.
368 patterns.add(lowerVariadicAndInverterOpConversion<aig::AndInverterOp>);
369}
370
372 RewritePatternSet &patterns) {
 // NOTE(review): the signature line (original line 371) is missing from this
 // listing. Per the cross-reference at the end of the page this is
 // `populateVariadicXorInverterLoweringPatterns`.
 // Registers the pattern that splits XorInverterOps with more than two inputs
 // into a tree of unary/binary XorInverterOps. It reuses the generic
 // AND-named conversion.
373 patterns.add(lowerVariadicAndInverterOpConversion<XorInverterOp>);
374}
375
377 mlir::Operation *op,
378 llvm::function_ref<bool(mlir::Value, mlir::Operation *)> isOperandReady) {
 // NOTE(review): the signature line (original line 376) is missing from this
 // listing. Per the cross-reference at the end of the page this is
 // `LogicalResult topologicallySortGraphRegionBlocks(`.
379 // Topologically sort every block of each graph region nested under `op`.
380 auto walkResult = op->walk([&](Region *region) {
381 auto regionKindOp =
382 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
 // Only graph regions are reordered. Regions whose parent does not implement
 // RegionKindInterface, or that require SSA dominance, are left untouched.
383 if (!regionKindOp ||
384 regionKindOp.hasSSADominance(region->getRegionNumber()))
385 return WalkResult::advance();
386
387 // Graph region.
388 for (auto &block : *region) {
 // Stop the walk as soon as a block cannot be fully sorted (presumably a
 // dependency cycle under `isOperandReady`; see mlir::sortTopologically).
389 if (!mlir::sortTopologically(&block, isOperandReady))
390 return WalkResult::interrupt();
391 }
392 return WalkResult::advance();
393 });
394
 // Failure iff some block could not be fully sorted.
395 return success(!walkResult.wasInterrupted());
396}
assert(baseType &&"element must be base type")
LogicalResult lowerVariadicAndInverterOpConversion(OpTy op, PatternRewriter &rewriter)
Definition SynthOps.cpp:349
static Value lowerVariadicInvertibleOp(Location loc, ValueRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter, llvm::function_ref< Value(Value, bool)> createUnary, llvm::function_ref< Value(Value, Value, bool, bool)> createBinary)
Definition SynthOps.cpp:322
create(data_type, value)
Definition hw.py:433
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
void populateVariadicXorInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:376
void populateVariadicAndInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).