CIRCT  20.0.0git
CombToAIG.cpp
Go to the documentation of this file.
1 //===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Comb to AIG Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 namespace circt {
21 #define GEN_PASS_DEF_CONVERTCOMBTOAIG
22 #include "circt/Conversion/Passes.h.inc"
23 } // namespace circt
24 
25 using namespace circt;
26 using namespace comb;
27 
28 //===----------------------------------------------------------------------===//
29 // Utility Functions
30 //===----------------------------------------------------------------------===//
31 
32 // Extract individual bits from a value
33 static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
34  Value val) {
35  assert(val.getType().isInteger() && "expected integer");
36  auto width = val.getType().getIntOrFloatBitWidth();
37  SmallVector<Value> bits;
38  bits.reserve(width);
39 
40  // Check if we can reuse concat operands
41  if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
42  if (concat.getNumOperands() == width &&
43  llvm::all_of(concat.getOperandTypes(), [](Type type) {
44  return type.getIntOrFloatBitWidth() == 1;
45  })) {
46  // Reverse the operands to match the bit order
47  bits.append(std::make_reverse_iterator(concat.getOperands().end()),
48  std::make_reverse_iterator(concat.getOperands().begin()));
49  return bits;
50  }
51  }
52 
53  // Extract individual bits
54  for (int64_t i = 0; i < width; ++i)
55  bits.push_back(
56  rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
57 
58  return bits;
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Conversion patterns
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
66 
67 /// Lower a comb::AndOp operation to aig::AndInverterOp
68 struct CombAndOpConversion : OpConversionPattern<AndOp> {
70 
71  LogicalResult
72  matchAndRewrite(AndOp op, OpAdaptor adaptor,
73  ConversionPatternRewriter &rewriter) const override {
74  SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
75  rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
76  nonInverts);
77  return success();
78  }
79 };
80 
81 /// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags
82 struct CombOrOpConversion : OpConversionPattern<OrOp> {
84 
85  LogicalResult
86  matchAndRewrite(OrOp op, OpAdaptor adaptor,
87  ConversionPatternRewriter &rewriter) const override {
88  // Implement Or using And and invert flags: a | b = ~(~a & ~b)
89  SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
90  auto andOp = rewriter.create<aig::AndInverterOp>(
91  op.getLoc(), adaptor.getInputs(), allInverts);
92  rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
93  /*invert=*/true);
94  return success();
95  }
96 };
97 
98 /// Lower a comb::XorOp operation to AIG operations
99 struct CombXorOpConversion : OpConversionPattern<XorOp> {
101 
102  LogicalResult
103  matchAndRewrite(XorOp op, OpAdaptor adaptor,
104  ConversionPatternRewriter &rewriter) const override {
105  if (op.getNumOperands() != 2)
106  return failure();
107  // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)
108 
109  // (a | b) = ~(~a & ~b)
110  // (~a | ~b) = ~(a & b)
111  auto inputs = adaptor.getInputs();
112  SmallVector<bool> allInverts(inputs.size(), true);
113  SmallVector<bool> allNotInverts(inputs.size(), false);
114 
115  auto notAAndNotB =
116  rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
117  auto aAndB =
118  rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
119 
120  rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
121  /*lhs_invert=*/true,
122  /*rhs_invert=*/true);
123  return success();
124  }
125 };
126 
127 template <typename OpTy>
128 struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
130  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
131  LogicalResult
132  matchAndRewrite(OpTy op, OpAdaptor adaptor,
133  ConversionPatternRewriter &rewriter) const override {
134  auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
135  rewriter.replaceOp(op, result);
136  return success();
137  }
138 
139  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
140  ConversionPatternRewriter &rewriter) {
141  Value lhs, rhs;
142  switch (operands.size()) {
143  case 0:
144  assert(false && "cannot be called with empty operand range");
145  break;
146  case 1:
147  return operands[0];
148  case 2:
149  lhs = operands[0];
150  rhs = operands[1];
151  return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
152  default:
153  auto firstHalf = operands.size() / 2;
154  lhs =
155  lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
156  rhs =
157  lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
158  return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
159  }
160  }
161 };
162 
163 // Lower comb::MuxOp to AIG operations.
164 struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
166 
167  LogicalResult
168  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
169  ConversionPatternRewriter &rewriter) const override {
170  // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)
171 
172  Value cond = op.getCond();
173  auto trueVal = op.getTrueValue();
174  auto falseVal = op.getFalseValue();
175 
176  if (!op.getType().isInteger()) {
177  // If the type of the mux is not integer, bitcast the operands first.
178  auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
179  trueVal =
180  rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, trueVal);
181  falseVal =
182  rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, falseVal);
183  }
184 
185  // Replicate condition if needed
186  if (!trueVal.getType().isInteger(1))
187  cond = rewriter.create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
188  cond);
189 
190  // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
191  auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
192  auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
193  true, false);
194 
195  Value result = rewriter.create<comb::OrOp>(op.getLoc(), lhs, rhs);
196  // Insert the bitcast if the type of the mux is not integer.
197  if (result.getType() != op.getType())
198  result =
199  rewriter.create<hw::BitcastOp>(op.getLoc(), op.getType(), result);
200  rewriter.replaceOp(op, result);
201  return success();
202  }
203 };
204 
205 struct CombAddOpConversion : OpConversionPattern<AddOp> {
207  LogicalResult
208  matchAndRewrite(AddOp op, OpAdaptor adaptor,
209  ConversionPatternRewriter &rewriter) const override {
210  auto inputs = adaptor.getInputs();
211  // Lower only when there are two inputs.
212  // Variadic operands must be lowered in a different pattern.
213  if (inputs.size() != 2)
214  return failure();
215 
216  auto width = op.getType().getIntOrFloatBitWidth();
217  // Skip a zero width value.
218  if (width == 0) {
219  rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
220  return success();
221  }
222 
223  // Implement a naive Ripple-carry full adder.
224  Value carry;
225 
226  auto aBits = extractBits(rewriter, inputs[0]);
227  auto bBits = extractBits(rewriter, inputs[1]);
228  SmallVector<Value> results;
229  results.resize(width);
230  for (int64_t i = 0; i < width; ++i) {
231  SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
232  if (carry)
233  xorOperands.push_back(carry);
234 
235  // sum[i] = xor(carry[i-1], a[i], b[i])
236  // NOTE: The result is stored in reverse order.
237  results[width - i - 1] =
238  rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);
239 
240  // If this is the last bit, we are done.
241  if (i == width - 1) {
242  break;
243  }
244 
245  // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
246  Value nextCarry = rewriter.create<comb::AndOp>(
247  op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
248  if (!carry) {
249  // This is the first bit, so the carry is the next carry.
250  carry = nextCarry;
251  continue;
252  }
253 
254  auto aXnorB = rewriter.create<comb::XorOp>(
255  op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
256  auto andOp = rewriter.create<comb::AndOp>(
257  op.getLoc(), ValueRange{carry, aXnorB}, true);
258  carry = rewriter.create<comb::OrOp>(op.getLoc(),
259  ValueRange{andOp, nextCarry}, true);
260  }
261 
262  rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
263  return success();
264  }
265 };
266 
267 struct CombSubOpConversion : OpConversionPattern<SubOp> {
269  LogicalResult
270  matchAndRewrite(SubOp op, OpAdaptor adaptor,
271  ConversionPatternRewriter &rewriter) const override {
272  auto lhs = op.getLhs();
273  auto rhs = op.getRhs();
274  // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
275  // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
276  // => add(lhs, ~rhs, 1)
277  auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
278  /*invert=*/true);
279  auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
280  rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
281  true);
282  return success();
283  }
284 };
285 
286 } // namespace
287 
288 //===----------------------------------------------------------------------===//
289 // Convert Comb to AIG pass
290 //===----------------------------------------------------------------------===//
291 
292 namespace {
293 struct ConvertCombToAIGPass
294  : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
295  void runOnOperation() override;
296  using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
297  using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
298 };
299 } // namespace
300 
301 static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
302  patterns.add<
303  // Bitwise Logical Ops
304  CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
305  CombMuxOpConversion,
306  // Arithmetic Ops
307  CombAddOpConversion, CombSubOpConversion,
308  // Variadic ops that must be lowered to binary operations
309  CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
310  patterns.getContext());
311 }
312 
313 void ConvertCombToAIGPass::runOnOperation() {
314  ConversionTarget target(getContext());
315  target.addIllegalDialect<comb::CombDialect>();
316  // Keep data movement operations like Extract, Concat and Replicate.
317  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
319  target.addLegalDialect<aig::AIGDialect>();
320 
321  // This is a test only option to add logical ops.
322  if (!additionalLegalOps.empty())
323  for (const auto &opName : additionalLegalOps)
324  target.addLegalOp(OperationName(opName, &getContext()));
325 
326  RewritePatternSet patterns(&getContext());
328 
329  if (failed(mlir::applyPartialConversion(getOperation(), target,
330  std::move(patterns))))
331  return signalPassFailure();
332 }
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:540
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns)
Definition: CombToAIG.cpp:301
static SmallVector< Value > extractBits(ConversionPatternRewriter &rewriter, Value val)
Definition: CombToAIG.cpp:33
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
def create(data_type, value)
Definition: hw.py:441
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: comb.py:1