CIRCT 23.0.0git
Loading...
Searching...
No Matches
SynthOps.cpp
Go to the documentation of this file.
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
13#include "mlir/Analysis/TopologicalSortUtils.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/ADT/APInt.h"
19#include "llvm/Support/Casting.h"
20#include "llvm/Support/LogicalResult.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace circt::synth;
25using namespace circt::synth::mig;
26using namespace circt::synth::aig;
27
28#define GET_OP_CLASSES
29#include "circt/Dialect/Synth/Synth.cpp.inc"
30
31LogicalResult ChoiceOp::verify() {
32 if (getNumOperands() < 1)
33 return emitOpError("requires at least one operand");
34 return success();
35}
36
37LogicalResult MajorityInverterOp::verify() {
38 if (getNumOperands() % 2 != 1)
39 return emitOpError("requires an odd number of operands");
40
41 return success();
42}
43
44llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
45 assert(inputs.size() == getNumOperands() &&
46 "Number of inputs must match number of operands");
47
48 if (inputs.size() == 3) {
49 auto a = (isInverted(0) ? ~inputs[0] : inputs[0]);
50 auto b = (isInverted(1) ? ~inputs[1] : inputs[1]);
51 auto c = (isInverted(2) ? ~inputs[2] : inputs[2]);
52 return (a & b) | (a & c) | (b & c);
53 }
54
55 // General case for odd number of inputs != 3
56 auto width = inputs[0].getBitWidth();
57 APInt result(width, 0);
58
59 for (size_t bit = 0; bit < width; ++bit) {
60 size_t count = 0;
61 for (size_t i = 0; i < inputs.size(); ++i) {
62 // Count the number of 1s, considering inversion.
63 if (isInverted(i) ^ inputs[i][bit])
64 count++;
65 }
66
67 if (count > inputs.size() / 2)
68 result.setBit(bit);
69 }
70
71 return result;
72}
73
74OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
75 // TODO: Implement maj(x, 1, 1) = 1, maj(x, 0, 0) = 0
76
77 SmallVector<APInt, 3> inputValues;
78 for (auto input : adaptor.getInputs()) {
79 auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
80 if (!attr)
81 return {};
82 inputValues.push_back(attr.getValue());
83 }
84
85 auto result = evaluate(inputValues);
86 return IntegerAttr::get(getType(), result);
87}
88
89LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
90 PatternRewriter &rewriter) {
91 if (op.getNumOperands() == 1) {
92 if (op.getInverted()[0])
93 return failure();
94 rewriter.replaceOp(op, op.getOperand(0));
95 return success();
96 }
97
98 // For now, only support 3 operands.
99 if (op.getNumOperands() != 3)
100 return failure();
101
102 // Return if the idx-th operand is a constant (inverted if necessary),
103 // otherwise return std::nullopt.
104 auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
105 APInt value;
106 if (mlir::matchPattern(op.getInputs()[index], mlir::m_ConstantInt(&value)))
107 return op.isInverted(index) ? ~value : value;
108 return std::nullopt;
109 };
110
111 // Replace the op with the idx-th operand (inverted if necessary).
112 auto replaceWithIndex = [&](int index) {
113 bool inverted = op.isInverted(index);
114 if (inverted)
115 rewriter.replaceOpWithNewOp<MajorityInverterOp>(
116 op, op.getType(), op.getOperand(index), true);
117 else
118 rewriter.replaceOp(op, op.getOperand(index));
119 return success();
120 };
121
122 // Pattern match following cases:
123 // maj_inv(x, x, y) -> x
124 // maj_inv(x, y, not y) -> x
125 for (int i = 0; i < 2; ++i) {
126 for (int j = i + 1; j < 3; ++j) {
127 int k = 3 - (i + j);
128 assert(k >= 0 && k < 3);
129 // If we have two identical operands, we can fold.
130 if (op.getOperand(i) == op.getOperand(j)) {
131 // If they are inverted differently, we can fold to the third.
132 if (op.isInverted(i) != op.isInverted(j))
133 return replaceWithIndex(k);
134 return replaceWithIndex(i);
135 }
136
137 // If i and j are constant.
138 if (auto c1 = getConstant(i)) {
139 if (auto c2 = getConstant(j)) {
140 // If both constants are equal, we can fold.
141 if (*c1 == *c2) {
142 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
143 op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
144 return success();
145 }
146 // If constants are complementary, we can fold.
147 if (*c1 == ~*c2)
148 return replaceWithIndex(k);
149 }
150 }
151 }
152 }
153 return failure();
154}
155
//===----------------------------------------------------------------------===//
// AIG Operations
//===----------------------------------------------------------------------===//
159
160OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
161 if (getNumOperands() == 1 && !isInverted(0))
162 return getOperand(0);
163
164 auto inputs = adaptor.getInputs();
165 if (inputs.size() == 2)
166 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
167 auto value = intAttr.getValue();
168 if (isInverted(1))
169 value = ~value;
170 if (value.isZero())
171 return IntegerAttr::get(
172 IntegerType::get(getContext(), value.getBitWidth()), value);
173 if (value.isAllOnes()) {
174 if (isInverted(0))
175 return {};
176
177 return getOperand(0);
178 }
179 }
180 return {};
181}
182
183LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
184 PatternRewriter &rewriter) {
186 SmallVector<Value> uniqueValues;
187 SmallVector<bool> uniqueInverts;
188
189 APInt constValue =
190 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
191
192 bool invertedConstFound = false;
193 bool flippedFound = false;
194
195 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
196 bool newInverted = inverted;
197 if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
198 if (inverted) {
199 constValue &= ~constOp.getValue();
200 invertedConstFound = true;
201 } else {
202 constValue &= constOp.getValue();
203 }
204 continue;
205 }
206
207 if (auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
208 if (andInverterOp.getInputs().size() == 1 &&
209 andInverterOp.isInverted(0)) {
210 value = andInverterOp.getOperand(0);
211 newInverted = andInverterOp.isInverted(0) ^ inverted;
212 flippedFound = true;
213 }
214 }
215
216 auto it = seen.find(value);
217 if (it == seen.end()) {
218 seen.insert({value, newInverted});
219 uniqueValues.push_back(value);
220 uniqueInverts.push_back(newInverted);
221 } else if (it->second != newInverted) {
222 // replace with const 0
223 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
224 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
225 return success();
226 }
227 }
228
229 // If the constant is zero, we can just replace with zero.
230 if (constValue.isZero()) {
231 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
232 return success();
233 }
234
235 // No change.
236 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
237 (!constValue.isAllOnes() && !invertedConstFound &&
238 uniqueValues.size() + 1 == op.getInputs().size()))
239 return failure();
240
241 if (!constValue.isAllOnes()) {
242 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
243 uniqueInverts.push_back(false);
244 uniqueValues.push_back(constOp);
245 }
246
247 // It means the input is reduced to all ones.
248 if (uniqueValues.empty()) {
249 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
250 return success();
251 }
252
253 // build new op with reduced input values
254 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
255 rewriter, op, uniqueValues, uniqueInverts);
256 return success();
257}
258
259APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {
260 assert(inputs.size() == getNumOperands() &&
261 "Expected as many inputs as operands");
262 assert(!inputs.empty() && "Expected non-empty input list");
263 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
264 for (auto [idx, input] : llvm::enumerate(inputs)) {
265 if (isInverted(idx))
266 result &= ~input;
267 else
268 result &= input;
269 }
270 return result;
271}
272
273static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands,
274 ArrayRef<bool> inverts,
275 PatternRewriter &rewriter) {
276 switch (operands.size()) {
277 case 0:
278 assert(0 && "cannot be called with empty operand range");
279 break;
280 case 1:
281 if (inverts[0])
282 return AndInverterOp::create(rewriter, op.getLoc(), operands[0], true);
283 else
284 return operands[0];
285 case 2:
286 return AndInverterOp::create(rewriter, op.getLoc(), operands[0],
287 operands[1], inverts[0], inverts[1]);
288 default:
289 auto firstHalf = operands.size() / 2;
290 auto lhs =
291 lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
292 inverts.take_front(firstHalf), rewriter);
293 auto rhs =
294 lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
295 inverts.drop_front(firstHalf), rewriter);
296 return AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
297 }
298 return Value();
299}
300
302 AndInverterOp op, PatternRewriter &rewriter) const {
303 if (op.getInputs().size() <= 2)
304 return failure();
305 // TODO: This is a naive implementation that creates a balanced binary tree.
306 // We can improve by analyzing the dataflow and creating a tree that
307 // improves the critical path or area.
308 rewriter.replaceOp(op, lowerVariadicAndInverterOp(
309 op, op.getOperands(), op.getInverted(), rewriter));
310 return success();
311}
312
314 mlir::Operation *op,
315 llvm::function_ref<bool(mlir::Value, mlir::Operation *)> isOperandReady) {
316 // Sort the operations topologically
317 auto walkResult = op->walk([&](Region *region) {
318 auto regionKindOp =
319 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
320 if (!regionKindOp ||
321 regionKindOp.hasSSADominance(region->getRegionNumber()))
322 return WalkResult::advance();
323
324 // Graph region.
325 for (auto &block : *region) {
326 if (!mlir::sortTopologically(&block, isOperandReady))
327 return WalkResult::interrupt();
328 }
329 return WalkResult::advance();
330 });
331
332 return success(!walkResult.wasInterrupted());
333}
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter)
Definition SynthOps.cpp:273
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:313
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
mlir::LogicalResult matchAndRewrite(aig::AndInverterOp op, mlir::PatternRewriter &rewriter) const override
Definition SynthOps.cpp:301