CIRCT 22.0.0git
Loading...
Searching...
No Matches
SynthOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Analysis/TopologicalSortUtils.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/ADT/APInt.h"
19#include "llvm/Support/Casting.h"
20#include "llvm/Support/LogicalResult.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace circt::synth::mig;
25using namespace circt::synth::aig;
26
27#define GET_OP_CLASSES
28#include "circt/Dialect/Synth/Synth.cpp.inc"
29
30LogicalResult MajorityInverterOp::verify() {
31 if (getNumOperands() % 2 != 1)
32 return emitOpError("requires an odd number of operands");
33
34 return success();
35}
36
37llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
38 assert(inputs.size() == getNumOperands() &&
39 "Number of inputs must match number of operands");
40
41 if (inputs.size() == 3) {
42 auto a = (isInverted(0) ? ~inputs[0] : inputs[0]);
43 auto b = (isInverted(1) ? ~inputs[1] : inputs[1]);
44 auto c = (isInverted(2) ? ~inputs[2] : inputs[2]);
45 return (a & b) | (a & c) | (b & c);
46 }
47
48 // General case for odd number of inputs != 3
49 auto width = inputs[0].getBitWidth();
50 APInt result(width, 0);
51
52 for (size_t bit = 0; bit < width; ++bit) {
53 size_t count = 0;
54 for (size_t i = 0; i < inputs.size(); ++i) {
55 // Count the number of 1s, considering inversion.
56 if (isInverted(i) ^ inputs[i][bit])
57 count++;
58 }
59
60 if (count > inputs.size() / 2)
61 result.setBit(bit);
62 }
63
64 return result;
65}
66
67OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
68 // TODO: Implement maj(x, 1, 1) = 1, maj(x, 0, 0) = 0
69
70 SmallVector<APInt, 3> inputValues;
71 for (auto input : adaptor.getInputs()) {
72 auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
73 if (!attr)
74 return {};
75 inputValues.push_back(attr.getValue());
76 }
77
78 auto result = evaluate(inputValues);
79 return IntegerAttr::get(getType(), result);
80}
81
82LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
83 PatternRewriter &rewriter) {
84 if (op.getNumOperands() == 1) {
85 if (op.getInverted()[0])
86 return failure();
87 rewriter.replaceOp(op, op.getOperand(0));
88 return success();
89 }
90
91 // For now, only support 3 operands.
92 if (op.getNumOperands() != 3)
93 return failure();
94
95 // Return if the idx-th operand is a constant (inverted if necessary),
96 // otherwise return std::nullopt.
97 auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
98 APInt value;
99 if (mlir::matchPattern(op.getInputs()[index], mlir::m_ConstantInt(&value)))
100 return op.isInverted(index) ? ~value : value;
101 return std::nullopt;
102 };
103
104 // Replace the op with the idx-th operand (inverted if necessary).
105 auto replaceWithIndex = [&](int index) {
106 bool inverted = op.isInverted(index);
107 if (inverted)
108 rewriter.replaceOpWithNewOp<MajorityInverterOp>(
109 op, op.getType(), op.getOperand(index), true);
110 else
111 rewriter.replaceOp(op, op.getOperand(index));
112 return success();
113 };
114
115 // Pattern match following cases:
116 // maj_inv(x, x, y) -> x
117 // maj_inv(x, y, not y) -> x
118 for (int i = 0; i < 2; ++i) {
119 for (int j = i + 1; j < 3; ++j) {
120 int k = 3 - (i + j);
121 assert(k >= 0 && k < 3);
122 // If we have two identical operands, we can fold.
123 if (op.getOperand(i) == op.getOperand(j)) {
124 // If they are inverted differently, we can fold to the third.
125 if (op.isInverted(i) != op.isInverted(j)) {
126 return replaceWithIndex(k);
127 }
128 rewriter.replaceOp(op, op.getOperand(i));
129 return success();
130 }
131
132 // If i and j are constant.
133 if (auto c1 = getConstant(i)) {
134 if (auto c2 = getConstant(j)) {
135 // If both constants are equal, we can fold.
136 if (*c1 == *c2) {
137 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
138 op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
139 return success();
140 }
141 // If constants are complementary, we can fold.
142 if (*c1 == ~*c2)
143 return replaceWithIndex(k);
144 }
145 }
146 }
147 }
148 return failure();
149}
150
151//===----------------------------------------------------------------------===//
152// AIG Operations
153//===----------------------------------------------------------------------===//
154
155OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
156 if (getNumOperands() == 1 && !isInverted(0))
157 return getOperand(0);
158
159 auto inputs = adaptor.getInputs();
160 if (inputs.size() == 2 && inputs[1]) {
161 auto value = cast<IntegerAttr>(inputs[1]).getValue();
162 if (isInverted(1))
163 value = ~value;
164 if (value.isZero())
165 return IntegerAttr::get(
166 IntegerType::get(getContext(), value.getBitWidth()), value);
167 if (value.isAllOnes()) {
168 if (isInverted(0))
169 return {};
170
171 return getOperand(0);
172 }
173 }
174 return {};
175}
176
177LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
178 PatternRewriter &rewriter) {
180 SmallVector<Value> uniqueValues;
181 SmallVector<bool> uniqueInverts;
182
183 APInt constValue =
184 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
185
186 bool invertedConstFound = false;
187 bool flippedFound = false;
188
189 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
190 bool newInverted = inverted;
191 if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
192 if (inverted) {
193 constValue &= ~constOp.getValue();
194 invertedConstFound = true;
195 } else {
196 constValue &= constOp.getValue();
197 }
198 continue;
199 }
200
201 if (auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
202 if (andInverterOp.getInputs().size() == 1 &&
203 andInverterOp.isInverted(0)) {
204 value = andInverterOp.getOperand(0);
205 newInverted = andInverterOp.isInverted(0) ^ inverted;
206 flippedFound = true;
207 }
208 }
209
210 auto it = seen.find(value);
211 if (it == seen.end()) {
212 seen.insert({value, newInverted});
213 uniqueValues.push_back(value);
214 uniqueInverts.push_back(newInverted);
215 } else if (it->second != newInverted) {
216 // replace with const 0
217 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
218 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
219 return success();
220 }
221 }
222
223 // If the constant is zero, we can just replace with zero.
224 if (constValue.isZero()) {
225 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
226 return success();
227 }
228
229 // No change.
230 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
231 (!constValue.isAllOnes() && !invertedConstFound &&
232 uniqueValues.size() + 1 == op.getInputs().size()))
233 return failure();
234
235 if (!constValue.isAllOnes()) {
236 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
237 uniqueInverts.push_back(false);
238 uniqueValues.push_back(constOp);
239 }
240
241 // It means the input is reduced to all ones.
242 if (uniqueValues.empty()) {
243 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
244 return success();
245 }
246
247 // build new op with reduced input values
248 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
249 rewriter, op, uniqueValues, uniqueInverts);
250 return success();
251}
252
253APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {
254 assert(inputs.size() == getNumOperands() &&
255 "Expected as many inputs as operands");
256 assert(!inputs.empty() && "Expected non-empty input list");
257 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
258 for (auto [idx, input] : llvm::enumerate(inputs)) {
259 if (isInverted(idx))
260 result &= ~input;
261 else
262 result &= input;
263 }
264 return result;
265}
266
267static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands,
268 ArrayRef<bool> inverts,
269 PatternRewriter &rewriter) {
270 switch (operands.size()) {
271 case 0:
272 assert(0 && "cannot be called with empty operand range");
273 break;
274 case 1:
275 if (inverts[0])
276 return AndInverterOp::create(rewriter, op.getLoc(), operands[0], true);
277 else
278 return operands[0];
279 case 2:
280 return AndInverterOp::create(rewriter, op.getLoc(), operands[0],
281 operands[1], inverts[0], inverts[1]);
282 default:
283 auto firstHalf = operands.size() / 2;
284 auto lhs =
285 lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
286 inverts.take_front(firstHalf), rewriter);
287 auto rhs =
288 lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
289 inverts.drop_front(firstHalf), rewriter);
290 return AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
291 }
292 return Value();
293}
294
296 AndInverterOp op, PatternRewriter &rewriter) const {
297 if (op.getInputs().size() <= 2)
298 return failure();
299 // TODO: This is a naive implementation that creates a balanced binary tree.
300 // We can improve by analyzing the dataflow and creating a tree that
301 // improves the critical path or area.
302 rewriter.replaceOp(op, lowerVariadicAndInverterOp(
303 op, op.getOperands(), op.getInverted(), rewriter));
304 return success();
305}
306
308 mlir::Operation *op,
309 llvm::function_ref<bool(mlir::Value, mlir::Operation *)> isOperandReady) {
310 // Sort the operations topologically
311 auto walkResult = op->walk([&](Region *region) {
312 auto regionKindOp =
313 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
314 if (!regionKindOp ||
315 regionKindOp.hasSSADominance(region->getRegionNumber()))
316 return WalkResult::advance();
317
318 // Graph region.
319 for (auto &block : *region) {
320 if (!mlir::sortTopologically(&block, isOperandReady))
321 return WalkResult::interrupt();
322 }
323 return WalkResult::advance();
324 });
325
326 return success(!walkResult.wasInterrupted());
327}
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter)
Definition SynthOps.cpp:267
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:307
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
mlir::LogicalResult matchAndRewrite(aig::AndInverterOp op, mlir::PatternRewriter &rewriter) const override
Definition SynthOps.cpp:295