CIRCT 23.0.0git
Loading...
Searching...
No Matches
SynthOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Analysis/TopologicalSortUtils.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/IR/OpDefinition.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/IR/Value.h"
19#include "llvm/ADT/APInt.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/Support/Casting.h"
22#include "llvm/Support/LogicalResult.h"
23
24using namespace mlir;
25using namespace circt;
26using namespace circt::synth;
27using namespace circt::synth::mig;
28using namespace circt::synth::aig;
29
30#define GET_OP_CLASSES
31#include "circt/Dialect/Synth/Synth.cpp.inc"
32
33LogicalResult ChoiceOp::verify() {
34 if (getNumOperands() < 1)
35 return emitOpError("requires at least one operand");
36 return success();
37}
38
39OpFoldResult ChoiceOp::fold(FoldAdaptor adaptor) {
40 if (adaptor.getInputs().size() == 1)
41 return getOperand(0);
42 return {};
43}
44
45// Canonicalize a network of synth.choice operations by computing their
46// transitive closure and flattening them into a single choice operation.
47// This merges nested choices and deduplicates shared operands.
48// Pattern matched:
49// %0 = synth.choice %x, %y, %z
50// %1 = synth.choice %0, %u
51// %2 = synth.choice %z, %v
52// =>
53// %merged = synth.choice %x, %y, %z, %u, %v
54LogicalResult ChoiceOp::canonicalize(ChoiceOp op, PatternRewriter &rewriter) {
55 llvm::SetVector<Value> worklist;
56 llvm::SmallSetVector<Operation *, 4> visitedChoices;
57
58 auto addToWorklist = [&](ChoiceOp choice) -> bool {
59 if (choice->getBlock() == op->getBlock() && visitedChoices.insert(choice)) {
60 worklist.insert(choice.getInputs().begin(), choice.getInputs().end());
61 return true;
62 }
63 return false;
64 };
65
66 addToWorklist(op);
67
68 bool mergedOtherChoices = false;
69
70 // Look up and down at definitions and users.
71 for (unsigned i = 0; i < worklist.size(); ++i) {
72 Value val = worklist[i];
73 if (auto defOp = val.getDefiningOp<synth::ChoiceOp>()) {
74
75 if (addToWorklist(defOp))
76 mergedOtherChoices = true;
77 }
78
79 for (Operation *user : val.getUsers()) {
80 if (auto userChoice = llvm::dyn_cast<synth::ChoiceOp>(user)) {
81 if (addToWorklist(userChoice)) {
82 mergedOtherChoices = true;
83 }
84 }
85 }
86 }
87
88 llvm::SmallVector<mlir::Value> finalOperands;
89 for (Value v : worklist) {
90 if (!visitedChoices.contains(v.getDefiningOp())) {
91 finalOperands.push_back(v);
92 }
93 }
94
95 if (!mergedOtherChoices && finalOperands.size() == op.getInputs().size())
96 return llvm::failure();
97
98 auto newChoice = synth::ChoiceOp::create(rewriter, op->getLoc(), op.getType(),
99 finalOperands);
100 for (Operation *visited : visitedChoices.takeVector())
101 rewriter.replaceOp(visited, newChoice);
102
103 for (auto value : newChoice.getInputs())
104 rewriter.replaceAllUsesExcept(value, newChoice.getResult(), newChoice);
105
106 return success();
107}
108
109LogicalResult MajorityInverterOp::verify() {
110 if (getNumOperands() % 2 != 1)
111 return emitOpError("requires an odd number of operands");
112
113 return success();
114}
115
116llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
117 assert(inputs.size() == getNumOperands() &&
118 "Number of inputs must match number of operands");
119
120 if (inputs.size() == 3) {
121 auto a = (isInverted(0) ? ~inputs[0] : inputs[0]);
122 auto b = (isInverted(1) ? ~inputs[1] : inputs[1]);
123 auto c = (isInverted(2) ? ~inputs[2] : inputs[2]);
124 return (a & b) | (a & c) | (b & c);
125 }
126
127 // General case for odd number of inputs != 3
128 auto width = inputs[0].getBitWidth();
129 APInt result(width, 0);
130
131 for (size_t bit = 0; bit < width; ++bit) {
132 size_t count = 0;
133 for (size_t i = 0; i < inputs.size(); ++i) {
134 // Count the number of 1s, considering inversion.
135 if (isInverted(i) ^ inputs[i][bit])
136 count++;
137 }
138
139 if (count > inputs.size() / 2)
140 result.setBit(bit);
141 }
142
143 return result;
144}
145
146OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
147
148 SmallVector<APInt, 3> inputValues;
149 SmallVector<size_t, 3> nonConstantValues;
150 for (auto [i, input] : llvm::enumerate(adaptor.getInputs())) {
151 auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
152 if (attr)
153 inputValues.push_back(attr.getValue());
154 else
155 nonConstantValues.push_back(i);
156 }
157
158 if (nonConstantValues.size() == 0)
159 return IntegerAttr::get(getType(), evaluate(inputValues));
160
161 if (getNumOperands() != 3)
162 return {};
163
164 auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
165 APInt value;
166 if (mlir::matchPattern(getInputs()[index], mlir::m_ConstantInt(&value)))
167 return isInverted(index) ? ~value : value;
168 return std::nullopt;
169 };
170 if (nonConstantValues.size() == 1) {
171 auto k = nonConstantValues[0]; // for 3 operands
172 auto i = (k + 1) % 3;
173 auto j = (k + 2) % 3;
174 auto c1 = getConstant(i);
175 auto c2 = getConstant(j);
176 // x c c -> c
177 // x c !c -> x
178 // x ~c ~c -> ~c
179 if (c1 == c2) {
180 return IntegerAttr::get(IntegerType::get(getContext(), c1->getBitWidth()),
181 c1.value());
182 } else {
183 if (isInverted(k)) {
184 (*this)->setOperands({getOperand(i)});
185 (*this).setInverted({true});
186 return getResult();
187 } else
188 return getOperand(k);
189 }
190 }
191 return {};
192}
193
194LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
195 PatternRewriter &rewriter) {
196 if (op.getNumOperands() == 1) {
197 if (op.getInverted()[0])
198 return failure();
199 rewriter.replaceOp(op, op.getOperand(0));
200 return success();
201 }
202
203 // For now, only support 3 operands.
204 if (op.getNumOperands() != 3)
205 return failure();
206
207 // Replace the op with the idx-th operand (inverted if necessary).
208 auto replaceWithIndex = [&](int index) {
209 bool inverted = op.isInverted(index);
210 if (inverted)
211 rewriter.replaceOpWithNewOp<MajorityInverterOp>(
212 op, op.getType(), op.getOperand(index), true);
213 else
214 rewriter.replaceOp(op, op.getOperand(index));
215 return success();
216 };
217
218 // Pattern match following cases:
219 // maj_inv(x, x, y) -> x
220 // maj_inv(x, y, not y) -> x
221 for (int i = 0; i < 2; ++i) {
222 for (int j = i + 1; j < 3; ++j) {
223 int k = 3 - (i + j);
224 assert(k >= 0 && k < 3);
225 // If we have two identical operands, we can fold.
226 if (op.getOperand(i) == op.getOperand(j)) {
227 // If they are inverted differently, we can fold to the third.
228 if (op.isInverted(i) != op.isInverted(j))
229 return replaceWithIndex(k);
230 return replaceWithIndex(i);
231 }
232 }
233 }
234 return failure();
235}
236
237//===----------------------------------------------------------------------===//
238// AIG Operations
239//===----------------------------------------------------------------------===//
240
241OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
242 if (getNumOperands() == 1 && !isInverted(0))
243 return getOperand(0);
244
245 auto inputs = adaptor.getInputs();
246 if (inputs.size() == 2)
247 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
248 auto value = intAttr.getValue();
249 if (isInverted(1))
250 value = ~value;
251 if (value.isZero())
252 return IntegerAttr::get(
253 IntegerType::get(getContext(), value.getBitWidth()), value);
254 if (value.isAllOnes()) {
255 if (isInverted(0))
256 return {};
257
258 return getOperand(0);
259 }
260 }
261 return {};
262}
263
264LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
265 PatternRewriter &rewriter) {
267 SmallVector<Value> uniqueValues;
268 SmallVector<bool> uniqueInverts;
269
270 APInt constValue =
271 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
272
273 bool invertedConstFound = false;
274 bool flippedFound = false;
275
276 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
277 bool newInverted = inverted;
278 if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
279 if (inverted) {
280 constValue &= ~constOp.getValue();
281 invertedConstFound = true;
282 } else {
283 constValue &= constOp.getValue();
284 }
285 continue;
286 }
287
288 if (auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
289 if (andInverterOp.getInputs().size() == 1 &&
290 andInverterOp.isInverted(0)) {
291 value = andInverterOp.getOperand(0);
292 newInverted = andInverterOp.isInverted(0) ^ inverted;
293 flippedFound = true;
294 }
295 }
296
297 auto it = seen.find(value);
298 if (it == seen.end()) {
299 seen.insert({value, newInverted});
300 uniqueValues.push_back(value);
301 uniqueInverts.push_back(newInverted);
302 } else if (it->second != newInverted) {
303 // replace with const 0
304 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
305 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
306 return success();
307 }
308 }
309
310 // If the constant is zero, we can just replace with zero.
311 if (constValue.isZero()) {
312 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
313 return success();
314 }
315
316 // No change.
317 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
318 (!constValue.isAllOnes() && !invertedConstFound &&
319 uniqueValues.size() + 1 == op.getInputs().size()))
320 return failure();
321
322 if (!constValue.isAllOnes()) {
323 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
324 uniqueInverts.push_back(false);
325 uniqueValues.push_back(constOp);
326 }
327
328 // It means the input is reduced to all ones.
329 if (uniqueValues.empty()) {
330 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
331 return success();
332 }
333
334 // build new op with reduced input values
335 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
336 rewriter, op, uniqueValues, uniqueInverts);
337 return success();
338}
339
340APInt AndInverterOp::evaluate(ArrayRef<APInt> inputs) {
341 assert(inputs.size() == getNumOperands() &&
342 "Expected as many inputs as operands");
343 assert(!inputs.empty() && "Expected non-empty input list");
344 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
345 for (auto [idx, input] : llvm::enumerate(inputs)) {
346 if (isInverted(idx))
347 result &= ~input;
348 else
349 result &= input;
350 }
351 return result;
352}
353
354static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands,
355 ArrayRef<bool> inverts,
356 PatternRewriter &rewriter) {
357 switch (operands.size()) {
358 case 0:
359 assert(0 && "cannot be called with empty operand range");
360 break;
361 case 1:
362 if (inverts[0])
363 return AndInverterOp::create(rewriter, op.getLoc(), operands[0], true);
364 else
365 return operands[0];
366 case 2:
367 return AndInverterOp::create(rewriter, op.getLoc(), operands[0],
368 operands[1], inverts[0], inverts[1]);
369 default:
370 auto firstHalf = operands.size() / 2;
371 auto lhs =
372 lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
373 inverts.take_front(firstHalf), rewriter);
374 auto rhs =
375 lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
376 inverts.drop_front(firstHalf), rewriter);
377 return AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
378 }
379 return Value();
380}
381
383 AndInverterOp op, PatternRewriter &rewriter) const {
384 if (op.getInputs().size() <= 2)
385 return failure();
386 // TODO: This is a naive implementation that creates a balanced binary tree.
387 // We can improve by analyzing the dataflow and creating a tree that
388 // improves the critical path or area.
389 rewriter.replaceOp(op, lowerVariadicAndInverterOp(
390 op, op.getOperands(), op.getInverted(), rewriter));
391 return success();
392}
393
395 mlir::Operation *op,
396 llvm::function_ref<bool(mlir::Value, mlir::Operation *)> isOperandReady) {
397 // Sort the operations topologically
398 auto walkResult = op->walk([&](Region *region) {
399 auto regionKindOp =
400 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
401 if (!regionKindOp ||
402 regionKindOp.hasSSADominance(region->getRegionNumber()))
403 return WalkResult::advance();
404
405 // Graph region.
406 for (auto &block : *region) {
407 if (!mlir::sortTopologically(&block, isOperandReady))
408 return WalkResult::interrupt();
409 }
410 return WalkResult::advance();
411 });
412
413 return success(!walkResult.wasInterrupted());
414}
assert(baseType &&"element must be base type")
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter)
Definition SynthOps.cpp:354
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:394
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
mlir::LogicalResult matchAndRewrite(aig::AndInverterOp op, mlir::PatternRewriter &rewriter) const override
Definition SynthOps.cpp:382