Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CombOps.cpp
Go to the documentation of this file.
1//===- CombOps.cpp - Implement the Comb operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements combinational ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/ImplicitLocOpBuilder.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/PatternMatch.h"
19#include "llvm/Support/FormatVariadic.h"
20
21using namespace circt;
22using namespace comb;
23
24Value comb::createZExt(OpBuilder &builder, Location loc, Value value,
25 unsigned targetWidth) {
26 assert(value.getType().isSignlessInteger());
27 auto inputWidth = value.getType().getIntOrFloatBitWidth();
28 assert(inputWidth <= targetWidth);
29
30 // Nothing to do if the width already matches.
31 if (inputWidth == targetWidth)
32 return value;
33
34 // Create a zero constant for the upper bits.
35 auto zeros = hw::ConstantOp::create(
36 builder, loc, builder.getIntegerType(targetWidth - inputWidth), 0);
37 return builder.createOrFold<ConcatOp>(loc, zeros, value);
38}
39
40/// Create a sign extension operation from a value of integer type to an equal
41/// or larger integer type.
42Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
43 OpBuilder &builder) {
44 IntegerType valueType = dyn_cast<IntegerType>(value.getType());
45 assert(valueType && isa<IntegerType>(destTy) &&
46 valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
47 valueType.getWidth() != 0 && "invalid sext operands");
48 // If already the right size, we are done.
49 if (valueType == destTy)
50 return value;
51
52 // sext is concat with a replicate of the sign bits and the bottom part.
53 auto signBit =
54 builder.createOrFold<ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
55 auto signBits = builder.createOrFold<ReplicateOp>(
56 loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
57 return builder.createOrFold<ConcatOp>(loc, signBits, value);
58}
59
60Value comb::createOrFoldSExt(Value value, Type destTy,
61 ImplicitLocOpBuilder &builder) {
62 return createOrFoldSExt(builder.getLoc(), value, destTy, builder);
63}
64
65Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
66 bool twoState) {
67 auto allOnes = hw::ConstantOp::create(builder, loc, value.getType(), -1);
68 return builder.createOrFold<XorOp>(loc, value, allOnes, twoState);
69}
70
71Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
72 bool twoState) {
73 return createOrFoldNot(builder.getLoc(), value, builder, twoState);
74}
75
76// Extract individual bits from a value
77void comb::extractBits(OpBuilder &builder, Value val,
78 SmallVectorImpl<Value> &bits) {
79 assert(val.getType().isInteger() && "expected integer");
80 auto width = val.getType().getIntOrFloatBitWidth();
81 bits.reserve(width);
82
83 // Check if we can reuse concat operands
84 if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
85 if (concat.getNumOperands() == width &&
86 llvm::all_of(concat.getOperandTypes(), [](Type type) {
87 return type.getIntOrFloatBitWidth() == 1;
88 })) {
89 // Reverse the operands to match the bit order
90 bits.append(std::make_reverse_iterator(concat.getOperands().end()),
91 std::make_reverse_iterator(concat.getOperands().begin()));
92 return;
93 }
94 }
95
96 // Extract individual bits
97 for (int64_t i = 0; i < width; ++i)
98 bits.push_back(
99 builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
100}
101
102// Construct a mux tree for given leaf nodes. `selectors` is the selector for
103// each level of the tree. Currently the selector is tested from MSB to LSB.
104Value comb::constructMuxTree(OpBuilder &builder, Location loc,
105 ArrayRef<Value> selectors,
106 ArrayRef<Value> leafNodes,
107 Value outOfBoundsValue) {
108 // Recursive helper function to construct the mux tree
109 std::function<Value(size_t, size_t)> constructTreeHelper =
110 [&](size_t id, size_t level) -> Value {
111 // Base case: at the lowest level, return the result
112 if (level == 0) {
113 // Return the result for the given index. If the index is out of bounds,
114 // return the out-of-bound value.
115 return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
116 }
117
118 auto selector = selectors[level - 1];
119
120 // Recursive case: create muxes for true and false branches
121 auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
122 auto falseVal = constructTreeHelper(2 * id, level - 1);
123
124 // Combine the results with a mux
125 return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
126 };
127
128 return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
129}
130
131Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value,
132 Value offset, unsigned width) {
133 assert(value.getType().isSignlessInteger());
134 auto valueWidth = value.getType().getIntOrFloatBitWidth();
135 assert(width <= valueWidth);
136
137 // Handle the special case where the offset is constant.
138 APInt constOffset;
139 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
140 if (constOffset.getActiveBits() < 32)
141 return builder.createOrFold<comb::ExtractOp>(
142 loc, value, constOffset.getZExtValue(), width);
143
144 // Zero-extend the offset, shift the value down, and extract the requested
145 // number of bits.
146 offset = createZExt(builder, loc, offset, valueWidth);
147 value = builder.createOrFold<comb::ShrUOp>(loc, value, offset);
148 return builder.createOrFold<comb::ExtractOp>(loc, value, 0, width);
149}
150
151Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value,
152 Value offset, Value replacement,
153 bool twoState) {
154 assert(value.getType().isSignlessInteger());
155 assert(replacement.getType().isSignlessInteger());
156 auto largeWidth = value.getType().getIntOrFloatBitWidth();
157 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
158 assert(smallWidth <= largeWidth);
159
160 // If we're inserting a zero-width value there's nothing to do.
161 if (smallWidth == 0)
162 return value;
163
164 // Handle the special case where the offset is constant.
165 APInt constOffset;
166 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
167 if (constOffset.getActiveBits() < 32)
168 return createInject(builder, loc, value, constOffset.getZExtValue(),
169 replacement);
170
171 // Zero-extend the offset and clear the value bits we are replacing.
172 offset = createZExt(builder, loc, offset, largeWidth);
173 Value mask = hw::ConstantOp::create(
174 builder, loc, APInt::getLowBitsSet(largeWidth, smallWidth));
175 mask = builder.createOrFold<comb::ShlOp>(loc, mask, offset);
176 mask = createOrFoldNot(loc, mask, builder, true);
177 value = builder.createOrFold<comb::AndOp>(loc, value, mask, twoState);
178
179 // Zero-extend the replacement value, shift it up to the offset, and merge it
180 // with the value that has the corresponding bits cleared.
181 replacement = createZExt(builder, loc, replacement, largeWidth);
182 replacement = builder.createOrFold<comb::ShlOp>(loc, replacement, offset);
183 return builder.createOrFold<comb::OrOp>(loc, value, replacement, twoState);
184}
185
186Value comb::createInject(OpBuilder &builder, Location loc, Value value,
187 unsigned offset, Value replacement) {
188 assert(value.getType().isSignlessInteger());
189 assert(replacement.getType().isSignlessInteger());
190 auto largeWidth = value.getType().getIntOrFloatBitWidth();
191 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
192 assert(smallWidth <= largeWidth);
193
194 // If the offset is outside the value there's nothing to do.
195 if (offset >= largeWidth)
196 return value;
197
198 // If we're inserting a zero-width value there's nothing to do.
199 if (smallWidth == 0)
200 return value;
201
202 // Assemble the pieces of the injection as everything below the offset, the
203 // replacement value, and everything above the replacement value.
204 SmallVector<Value, 3> fragments;
205 auto end = offset + smallWidth;
206 if (end < largeWidth)
207 fragments.push_back(
208 comb::ExtractOp::create(builder, loc, value, end, largeWidth - end));
209 if (end <= largeWidth)
210 fragments.push_back(replacement);
211 else
212 fragments.push_back(comb::ExtractOp::create(builder, loc, replacement, 0,
213 largeWidth - offset));
214 if (offset > 0)
215 fragments.push_back(
216 comb::ExtractOp::create(builder, loc, value, 0, offset));
217 return builder.createOrFold<comb::ConcatOp>(loc, fragments);
218}
219
220// Construct a full adder for three 1-bit inputs.
221std::pair<Value, Value> comb::fullAdder(OpBuilder &builder, Location loc,
222 Value a, Value b, Value c) {
223 auto aXorB = builder.createOrFold<comb::XorOp>(loc, a, b, true);
224 Value sum = builder.createOrFold<comb::XorOp>(loc, aXorB, c, true);
225
226 auto carry = builder.createOrFold<comb::OrOp>(
227 loc,
228 ArrayRef<Value>{builder.createOrFold<comb::AndOp>(loc, a, b, true),
229 builder.createOrFold<comb::AndOp>(loc, aXorB, c, true)},
230 true);
231
232 return {sum, carry};
233}
234
235// Perform Wallace tree reduction on partial products.
236// See https://en.wikipedia.org/wiki/Wallace_tree
237SmallVector<Value>
238comb::wallaceReduction(OpBuilder &builder, Location loc, size_t width,
239 size_t targetAddends,
240 SmallVector<SmallVector<Value>> &addends) {
241 auto falseValue = hw::ConstantOp::create(builder, loc, APInt(1, 0));
242 SmallVector<SmallVector<Value>> newAddends;
243 newAddends.reserve(addends.size());
244 // Continue reduction until we have only two rows. The length of
245 // `addends` is reduced by 1/3 in each iteration.
246 while (addends.size() > targetAddends) {
247 newAddends.clear();
248 // Take three rows at a time and reduce to two rows(sum and carry).
249 for (unsigned i = 0; i < addends.size(); i += 3) {
250 if (i + 2 < addends.size()) {
251 // We have three rows to reduce
252 auto &row1 = addends[i];
253 auto &row2 = addends[i + 1];
254 auto &row3 = addends[i + 2];
255
256 assert(row1.size() == width && row2.size() == width &&
257 row3.size() == width);
258
259 SmallVector<Value> sumRow, carryRow;
260 sumRow.reserve(width);
261 carryRow.reserve(width);
262 carryRow.push_back(falseValue);
263
264 // Process each bit position
265 for (unsigned j = 0; j < width; ++j) {
266 // Full adder logic
267 auto [sum, carry] =
268 comb::fullAdder(builder, loc, row1[j], row2[j], row3[j]);
269 sumRow.push_back(sum);
270 if (j + 1 < width)
271 carryRow.push_back(carry);
272 }
273
274 newAddends.push_back(std::move(sumRow));
275 newAddends.push_back(std::move(carryRow));
276 } else {
277 // Add remaining rows as is
278 newAddends.append(addends.begin() + i, addends.end());
279 }
280 }
281 std::swap(newAddends, addends);
282 }
283
284 assert(addends.size() <= targetAddends);
285 SmallVector<Value> carrySave;
286 for (auto &addend : addends) {
287 // Reverse the order of the bits
288 std::reverse(addend.begin(), addend.end());
289 carrySave.push_back(comb::ConcatOp::create(builder, loc, addend));
290 }
291
292 // Pad with zeros
293 auto zero = hw::ConstantOp::create(builder, loc, APInt(width, 0));
294 while (carrySave.size() < targetAddends)
295 carrySave.push_back(zero);
296
297 return carrySave;
298}
299
300//===----------------------------------------------------------------------===//
301// ICmpOp
302//===----------------------------------------------------------------------===//
303
304ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
305 switch (predicate) {
306 case ICmpPredicate::eq:
307 return ICmpPredicate::eq;
308 case ICmpPredicate::ne:
309 return ICmpPredicate::ne;
310 case ICmpPredicate::slt:
311 return ICmpPredicate::sgt;
312 case ICmpPredicate::sle:
313 return ICmpPredicate::sge;
314 case ICmpPredicate::sgt:
315 return ICmpPredicate::slt;
316 case ICmpPredicate::sge:
317 return ICmpPredicate::sle;
318 case ICmpPredicate::ult:
319 return ICmpPredicate::ugt;
320 case ICmpPredicate::ule:
321 return ICmpPredicate::uge;
322 case ICmpPredicate::ugt:
323 return ICmpPredicate::ult;
324 case ICmpPredicate::uge:
325 return ICmpPredicate::ule;
326 case ICmpPredicate::ceq:
327 return ICmpPredicate::ceq;
328 case ICmpPredicate::cne:
329 return ICmpPredicate::cne;
330 case ICmpPredicate::weq:
331 return ICmpPredicate::weq;
332 case ICmpPredicate::wne:
333 return ICmpPredicate::wne;
334 }
335 llvm_unreachable("unknown comparison predicate");
336}
337
338bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
339 switch (predicate) {
340 case ICmpPredicate::ult:
341 case ICmpPredicate::ugt:
342 case ICmpPredicate::ule:
343 case ICmpPredicate::uge:
344 case ICmpPredicate::ne:
345 case ICmpPredicate::eq:
346 case ICmpPredicate::cne:
347 case ICmpPredicate::ceq:
348 case ICmpPredicate::wne:
349 case ICmpPredicate::weq:
350 return false;
351 case ICmpPredicate::slt:
352 case ICmpPredicate::sgt:
353 case ICmpPredicate::sle:
354 case ICmpPredicate::sge:
355 return true;
356 }
357 llvm_unreachable("unknown comparison predicate");
358}
359
360/// Returns the predicate for a logically negated comparison, e.g. mapping
361/// EQ => NE and SLE => SGT.
362ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
363 switch (predicate) {
364 case ICmpPredicate::eq:
365 return ICmpPredicate::ne;
366 case ICmpPredicate::ne:
367 return ICmpPredicate::eq;
368 case ICmpPredicate::slt:
369 return ICmpPredicate::sge;
370 case ICmpPredicate::sle:
371 return ICmpPredicate::sgt;
372 case ICmpPredicate::sgt:
373 return ICmpPredicate::sle;
374 case ICmpPredicate::sge:
375 return ICmpPredicate::slt;
376 case ICmpPredicate::ult:
377 return ICmpPredicate::uge;
378 case ICmpPredicate::ule:
379 return ICmpPredicate::ugt;
380 case ICmpPredicate::ugt:
381 return ICmpPredicate::ule;
382 case ICmpPredicate::uge:
383 return ICmpPredicate::ult;
384 case ICmpPredicate::ceq:
385 return ICmpPredicate::cne;
386 case ICmpPredicate::cne:
387 return ICmpPredicate::ceq;
388 case ICmpPredicate::weq:
389 return ICmpPredicate::wne;
390 case ICmpPredicate::wne:
391 return ICmpPredicate::weq;
392 }
393 llvm_unreachable("unknown comparison predicate");
394}
395
396/// Return true if this is an equality test with -1, which is a "reduction
397/// and" operation in Verilog.
398bool ICmpOp::isEqualAllOnes() {
399 if (getPredicate() != ICmpPredicate::eq)
400 return false;
401
402 if (auto op1 =
403 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
404 return op1.getValue().isAllOnes();
405 return false;
406}
407
408/// Return true if this is a not equal test with 0, which is a "reduction
409/// or" operation in Verilog.
410bool ICmpOp::isNotEqualZero() {
411 if (getPredicate() != ICmpPredicate::ne)
412 return false;
413
414 if (auto op1 =
415 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
416 return op1.getValue().isZero();
417 return false;
418}
419
420//===----------------------------------------------------------------------===//
421// Unary Operations
422//===----------------------------------------------------------------------===//
423
424LogicalResult ReplicateOp::verify() {
425 // The source must be equal or smaller than the dest type, and an even
426 // multiple of it. Both are already known to be signless integers.
427 auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
428 auto dstWidth = cast<IntegerType>(getType()).getWidth();
429 if (srcWidth == 0)
430 return emitOpError("replicate does not take zero bit integer");
431
432 if (srcWidth > dstWidth)
433 return emitOpError("replicate cannot shrink bitwidth of operand"),
434 failure();
435
436 if (dstWidth % srcWidth)
437 return emitOpError("replicate must produce integer multiple of operand"),
438 failure();
439
440 return success();
441}
442
443//===----------------------------------------------------------------------===//
444// Variadic operations
445//===----------------------------------------------------------------------===//
446
447static LogicalResult verifyUTBinOp(Operation *op) {
448 if (op->getOperands().empty())
449 return op->emitOpError("requires 1 or more args");
450 return success();
451}
452
453LogicalResult AddOp::verify() { return verifyUTBinOp(*this); }
454
455LogicalResult MulOp::verify() { return verifyUTBinOp(*this); }
456
457LogicalResult AndOp::verify() { return verifyUTBinOp(*this); }
458
459LogicalResult OrOp::verify() { return verifyUTBinOp(*this); }
460
461LogicalResult XorOp::verify() { return verifyUTBinOp(*this); }
462
463/// Return true if this is a two operand xor with an all ones constant as its
464/// RHS operand.
465bool XorOp::isBinaryNot() {
466 if (getNumOperands() != 2)
467 return false;
468 if (auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
469 if (cst.getValue().isAllOnes())
470 return true;
471 return false;
472}
473
474//===----------------------------------------------------------------------===//
475// ConcatOp
476//===----------------------------------------------------------------------===//
477
478static unsigned getTotalWidth(ValueRange inputs) {
479 unsigned resultWidth = 0;
480 for (auto input : inputs) {
481 resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
482 }
483 return resultWidth;
484}
485
486void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
487 ValueRange tl) {
488 result.addOperands(ValueRange{hd});
489 result.addOperands(tl);
490 unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
491 result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth));
492}
493
494LogicalResult ConcatOp::inferReturnTypes(
495 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
496 DictionaryAttr attrs, mlir::OpaqueProperties properties,
497 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
498 unsigned resultWidth = getTotalWidth(operands);
499 results.push_back(IntegerType::get(context, resultWidth));
500 return success();
501}
502
503//===----------------------------------------------------------------------===//
504// ReverseOp
505//===----------------------------------------------------------------------===//
506
507// Folding of ReverseOp: if the input is constant, compute the reverse at
508// compile time.
509OpFoldResult comb::ReverseOp::fold(FoldAdaptor adaptor) {
510 // Try to cast the input attribute to an IntegerAttr.
511 auto cstInput = llvm::dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getInput());
512 if (!cstInput)
513 return {};
514
515 APInt val = cstInput.getValue();
516 APInt reversedVal = val.reverseBits();
517
518 return mlir::IntegerAttr::get(getType(), reversedVal);
519}
520
521namespace {
522struct ReverseOfReverse : public OpRewritePattern<comb::ReverseOp> {
523 using OpRewritePattern<comb::ReverseOp>::OpRewritePattern;
524
525 LogicalResult matchAndRewrite(comb::ReverseOp op,
526 PatternRewriter &rewriter) const override {
527 auto inputOp = op.getInput().getDefiningOp<comb::ReverseOp>();
528 if (!inputOp)
529 return failure();
530
531 rewriter.replaceOp(op, inputOp.getInput());
532 return success();
533 }
534};
535} // namespace
536
537void comb::ReverseOp::getCanonicalizationPatterns(RewritePatternSet &results,
538 MLIRContext *context) {
539 results.add<ReverseOfReverse>(context);
540}
541
542//===----------------------------------------------------------------------===//
543// Other Operations
544//===----------------------------------------------------------------------===//
545
546LogicalResult ExtractOp::verify() {
547 unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
548 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
549 if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
550 return emitOpError("from bit too large for input"), failure();
551
552 return success();
553}
554
555LogicalResult TruthTableOp::verify() {
556 size_t numInputs = getInputs().size();
557 if (numInputs >= sizeof(size_t) * 8)
558 return emitOpError("Truth tables support a maximum of ")
559 << sizeof(size_t) * 8 - 1 << " inputs on your platform";
560
561 ArrayAttr table = getLookupTable();
562 if (table.size() != (1ull << numInputs))
563 return emitOpError("Expected lookup table of 2^n length");
564 return success();
565}
566
567//===----------------------------------------------------------------------===//
568// TableGen generated logic.
569//===----------------------------------------------------------------------===//
570
571// Provide the autogenerated implementation guts for the Op classes.
572#define GET_OP_CLASSES
573#include "circt/Dialect/Comb/Comb.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static size_t getTotalWidth(ArrayRef< Value > operands)
static LogicalResult verifyUTBinOp(Operation *op)
Definition CombOps.cpp:447
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Definition CombOps.cpp:65
Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement)
Replace a range of bits in an integer and return the updated integer value.
Definition CombOps.cpp:186
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition CombOps.cpp:42
Value createZExt(OpBuilder &builder, Location loc, Value value, unsigned targetWidth)
Create the ops to zero-extend a value to an integer of equal or larger type.
Definition CombOps.cpp:24
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1