15#include "mlir/IR/Builders.h"
16#include "mlir/IR/ImplicitLocOpBuilder.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/PatternMatch.h"
19#include "llvm/Support/FormatVariadic.h"
24Value comb::createZExt(OpBuilder &builder, Location loc, Value value,
25 unsigned targetWidth) {
26 assert(value.getType().isSignlessInteger());
27 auto inputWidth = value.getType().getIntOrFloatBitWidth();
28 assert(inputWidth <= targetWidth);
31 if (inputWidth == targetWidth)
36 builder, loc, builder.getIntegerType(targetWidth - inputWidth), 0);
37 return builder.createOrFold<
ConcatOp>(loc, zeros, value);
24Value comb::createZExt(OpBuilder &builder, Location loc, Value value, {
…}
42Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
44 IntegerType valueType = dyn_cast<IntegerType>(value.getType());
45 assert(valueType && isa<IntegerType>(destTy) &&
46 valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
47 valueType.getWidth() != 0 &&
"invalid sext operands");
49 if (valueType == destTy)
54 builder.createOrFold<
ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
55 auto signBits = builder.createOrFold<ReplicateOp>(
56 loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
57 return builder.createOrFold<
ConcatOp>(loc, signBits, value);
42Value comb::createOrFoldSExt(Location loc, Value value, Type destTy, {
…}
60Value comb::createOrFoldSExt(Value value, Type destTy,
61 ImplicitLocOpBuilder &builder) {
60Value comb::createOrFoldSExt(Value value, Type destTy, {
…}
65Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
68 return builder.createOrFold<
XorOp>(loc, value, allOnes, twoState);
65Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder, {
…}
71Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
71Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, {
…}
77void comb::extractBits(OpBuilder &builder, Value val,
78 SmallVectorImpl<Value> &bits) {
79 assert(val.getType().isInteger() &&
"expected integer");
80 auto width = val.getType().getIntOrFloatBitWidth();
85 if (
concat.getNumOperands() == width &&
86 llvm::all_of(
concat.getOperandTypes(), [](Type type) {
87 return type.getIntOrFloatBitWidth() == 1;
90 bits.append(std::make_reverse_iterator(
concat.getOperands().end()),
91 std::make_reverse_iterator(
concat.getOperands().begin()));
97 for (int64_t i = 0; i < width; ++i)
77void comb::extractBits(OpBuilder &builder, Value val, {
…}
104Value comb::constructMuxTree(OpBuilder &builder, Location loc,
105 ArrayRef<Value> selectors,
106 ArrayRef<Value> leafNodes,
107 Value outOfBoundsValue) {
109 std::function<Value(
size_t,
size_t)> constructTreeHelper =
110 [&](
size_t id,
size_t level) -> Value {
115 return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
118 auto selector = selectors[level - 1];
121 auto trueVal = constructTreeHelper(2 *
id + 1, level - 1);
122 auto falseVal = constructTreeHelper(2 *
id, level - 1);
125 return builder.createOrFold<
comb::MuxOp>(loc, selector, trueVal, falseVal);
128 return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
104Value comb::constructMuxTree(OpBuilder &builder, Location loc, {
…}
131Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value,
132 Value offset,
unsigned width) {
133 assert(value.getType().isSignlessInteger());
134 auto valueWidth = value.getType().getIntOrFloatBitWidth();
135 assert(width <= valueWidth);
139 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
140 if (constOffset.getActiveBits() < 32)
142 loc, value, constOffset.getZExtValue(), width);
146 offset =
createZExt(builder, loc, offset, valueWidth);
147 value = builder.createOrFold<
comb::ShrUOp>(loc, value, offset);
131Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value, {
…}
151Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value,
152 Value offset, Value replacement,
154 assert(value.getType().isSignlessInteger());
155 assert(replacement.getType().isSignlessInteger());
156 auto largeWidth = value.getType().getIntOrFloatBitWidth();
157 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
158 assert(smallWidth <= largeWidth);
166 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
167 if (constOffset.getActiveBits() < 32)
168 return createInject(builder, loc, value, constOffset.getZExtValue(),
172 offset =
createZExt(builder, loc, offset, largeWidth);
174 builder, loc, APInt::getLowBitsSet(largeWidth, smallWidth));
175 mask = builder.createOrFold<
comb::ShlOp>(loc, mask, offset);
177 value = builder.createOrFold<
comb::AndOp>(loc, value, mask, twoState);
181 replacement =
createZExt(builder, loc, replacement, largeWidth);
182 replacement = builder.createOrFold<
comb::ShlOp>(loc, replacement, offset);
183 return builder.createOrFold<
comb::OrOp>(loc, value, replacement, twoState);
151Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value, {
…}
186Value comb::createInject(OpBuilder &builder, Location loc, Value value,
187 unsigned offset, Value replacement) {
188 assert(value.getType().isSignlessInteger());
189 assert(replacement.getType().isSignlessInteger());
190 auto largeWidth = value.getType().getIntOrFloatBitWidth();
191 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
192 assert(smallWidth <= largeWidth);
195 if (offset >= largeWidth)
204 SmallVector<Value, 3> fragments;
205 auto end = offset + smallWidth;
206 if (end < largeWidth)
209 if (end <= largeWidth)
210 fragments.push_back(replacement);
213 largeWidth - offset));
186Value comb::createInject(OpBuilder &builder, Location loc, Value value, {
…}
221std::pair<Value, Value> comb::fullAdder(OpBuilder &builder, Location loc,
222 Value a, Value b, Value c) {
223 auto aXorB = builder.createOrFold<
comb::XorOp>(loc, a, b,
true);
224 Value sum = builder.createOrFold<
comb::XorOp>(loc, aXorB, c,
true);
226 auto carry = builder.createOrFold<
comb::OrOp>(
228 ArrayRef<Value>{builder.createOrFold<
comb::AndOp>(loc, a, b,
true),
229 builder.createOrFold<
comb::AndOp>(loc, aXorB, c,
true)},
221std::pair<Value, Value> comb::fullAdder(OpBuilder &builder, Location loc, {
…}
238comb::wallaceReduction(OpBuilder &builder, Location loc,
size_t width,
239 size_t targetAddends,
240 SmallVector<SmallVector<Value>> &addends) {
242 SmallVector<SmallVector<Value>> newAddends;
243 newAddends.reserve(addends.size());
246 while (addends.size() > targetAddends) {
249 for (
unsigned i = 0; i < addends.size(); i += 3) {
250 if (i + 2 < addends.size()) {
252 auto &row1 = addends[i];
253 auto &row2 = addends[i + 1];
254 auto &row3 = addends[i + 2];
256 assert(row1.size() == width && row2.size() == width &&
257 row3.size() == width);
259 SmallVector<Value> sumRow, carryRow;
260 sumRow.reserve(width);
261 carryRow.reserve(width);
262 carryRow.push_back(falseValue);
265 for (
unsigned j = 0; j < width; ++j) {
268 comb::fullAdder(builder, loc, row1[j], row2[j], row3[j]);
269 sumRow.push_back(sum);
271 carryRow.push_back(carry);
274 newAddends.push_back(std::move(sumRow));
275 newAddends.push_back(std::move(carryRow));
278 newAddends.append(addends.begin() + i, addends.end());
281 std::swap(newAddends, addends);
284 assert(addends.size() <= targetAddends);
285 SmallVector<Value> carrySave;
286 for (
auto &addend : addends) {
288 std::reverse(addend.begin(), addend.end());
289 carrySave.push_back(comb::ConcatOp::create(builder, loc, addend));
294 while (carrySave.size() < targetAddends)
295 carrySave.push_back(zero);
238comb::wallaceReduction(OpBuilder &builder, Location loc,
size_t width, {
…}
304ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
306 case ICmpPredicate::eq:
307 return ICmpPredicate::eq;
308 case ICmpPredicate::ne:
309 return ICmpPredicate::ne;
310 case ICmpPredicate::slt:
311 return ICmpPredicate::sgt;
312 case ICmpPredicate::sle:
313 return ICmpPredicate::sge;
314 case ICmpPredicate::sgt:
315 return ICmpPredicate::slt;
316 case ICmpPredicate::sge:
317 return ICmpPredicate::sle;
318 case ICmpPredicate::ult:
319 return ICmpPredicate::ugt;
320 case ICmpPredicate::ule:
321 return ICmpPredicate::uge;
322 case ICmpPredicate::ugt:
323 return ICmpPredicate::ult;
324 case ICmpPredicate::uge:
325 return ICmpPredicate::ule;
326 case ICmpPredicate::ceq:
327 return ICmpPredicate::ceq;
328 case ICmpPredicate::cne:
329 return ICmpPredicate::cne;
330 case ICmpPredicate::weq:
331 return ICmpPredicate::weq;
332 case ICmpPredicate::wne:
333 return ICmpPredicate::wne;
335 llvm_unreachable(
"unknown comparison predicate");
338bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
340 case ICmpPredicate::ult:
341 case ICmpPredicate::ugt:
342 case ICmpPredicate::ule:
343 case ICmpPredicate::uge:
344 case ICmpPredicate::ne:
345 case ICmpPredicate::eq:
346 case ICmpPredicate::cne:
347 case ICmpPredicate::ceq:
348 case ICmpPredicate::wne:
349 case ICmpPredicate::weq:
351 case ICmpPredicate::slt:
352 case ICmpPredicate::sgt:
353 case ICmpPredicate::sle:
354 case ICmpPredicate::sge:
357 llvm_unreachable(
"unknown comparison predicate");
362ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
364 case ICmpPredicate::eq:
365 return ICmpPredicate::ne;
366 case ICmpPredicate::ne:
367 return ICmpPredicate::eq;
368 case ICmpPredicate::slt:
369 return ICmpPredicate::sge;
370 case ICmpPredicate::sle:
371 return ICmpPredicate::sgt;
372 case ICmpPredicate::sgt:
373 return ICmpPredicate::sle;
374 case ICmpPredicate::sge:
375 return ICmpPredicate::slt;
376 case ICmpPredicate::ult:
377 return ICmpPredicate::uge;
378 case ICmpPredicate::ule:
379 return ICmpPredicate::ugt;
380 case ICmpPredicate::ugt:
381 return ICmpPredicate::ule;
382 case ICmpPredicate::uge:
383 return ICmpPredicate::ult;
384 case ICmpPredicate::ceq:
385 return ICmpPredicate::cne;
386 case ICmpPredicate::cne:
387 return ICmpPredicate::ceq;
388 case ICmpPredicate::weq:
389 return ICmpPredicate::wne;
390 case ICmpPredicate::wne:
391 return ICmpPredicate::weq;
393 llvm_unreachable(
"unknown comparison predicate");
398bool ICmpOp::isEqualAllOnes() {
399 if (getPredicate() != ICmpPredicate::eq)
403 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
404 return op1.getValue().isAllOnes();
410bool ICmpOp::isNotEqualZero() {
411 if (getPredicate() != ICmpPredicate::ne)
415 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
416 return op1.getValue().isZero();
424LogicalResult ReplicateOp::verify() {
427 auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
428 auto dstWidth = cast<IntegerType>(getType()).getWidth();
430 return emitOpError(
"replicate does not take zero bit integer");
432 if (srcWidth > dstWidth)
433 return emitOpError(
"replicate cannot shrink bitwidth of operand"),
436 if (dstWidth % srcWidth)
437 return emitOpError(
"replicate must produce integer multiple of operand"),
448 if (op->getOperands().empty())
449 return op->emitOpError(
"requires 1 or more args");
453LogicalResult AddOp::verify() {
return verifyUTBinOp(*
this); }
455LogicalResult MulOp::verify() {
return verifyUTBinOp(*
this); }
457LogicalResult AndOp::verify() {
return verifyUTBinOp(*
this); }
461LogicalResult XorOp::verify() {
return verifyUTBinOp(*
this); }
465bool XorOp::isBinaryNot() {
466 if (getNumOperands() != 2)
468 if (
auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
469 if (cst.getValue().isAllOnes())
479 unsigned resultWidth = 0;
480 for (
auto input : inputs) {
481 resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
486void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
488 result.addOperands(ValueRange{hd});
489 result.addOperands(tl);
490 unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
491 result.addTypes(builder.getIntegerType(
getTotalWidth(tl) + hdWidth));
494LogicalResult ConcatOp::inferReturnTypes(
495 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
496 DictionaryAttr attrs, mlir::OpaqueProperties properties,
497 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
499 results.push_back(IntegerType::get(context, resultWidth));
509OpFoldResult comb::ReverseOp::fold(FoldAdaptor adaptor) {
511 auto cstInput = llvm::dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getInput());
515 APInt val = cstInput.getValue();
516 APInt reversedVal = val.reverseBits();
518 return mlir::IntegerAttr::get(getType(), reversedVal);
525 LogicalResult matchAndRewrite(comb::ReverseOp op,
526 PatternRewriter &rewriter)
const override {
527 auto inputOp = op.getInput().getDefiningOp<comb::ReverseOp>();
531 rewriter.replaceOp(op, inputOp.getInput());
537void comb::ReverseOp::getCanonicalizationPatterns(RewritePatternSet &results,
538 MLIRContext *context) {
539 results.add<ReverseOfReverse>(context);
546LogicalResult ExtractOp::verify() {
547 unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
548 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
549 if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
550 return emitOpError(
"from bit too large for input"), failure();
555LogicalResult TruthTableOp::verify() {
556 size_t numInputs = getInputs().size();
557 if (numInputs >=
sizeof(
size_t) * 8)
558 return emitOpError(
"Truth tables support a maximum of ")
559 <<
sizeof(size_t) * 8 - 1 <<
" inputs on your platform";
561 ArrayAttr table = getLookupTable();
562 if (table.size() != (1ull << numInputs))
563 return emitOpError(
"Expected lookup table of 2^n length");
572#define GET_OP_CLASSES
573#include "circt/Dialect/Comb/Comb.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static size_t getTotalWidth(ArrayRef< Value > operands)
static LogicalResult verifyUTBinOp(Operation *op)
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement)
Replace a range of bits in an integer and return the updated integer value.
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Value createZExt(OpBuilder &builder, Location loc, Value value, unsigned targetWidth)
Create the ops to zero-extend a value to an integer of equal or larger type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.