16#include "mlir/IR/Builders.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/PatternMatch.h"
20#include "llvm/Support/FormatVariadic.h"
/// Create the ops to zero-extend `value` (a signless integer) to an integer of
/// `targetWidth` bits; `targetWidth` must be equal to or larger than the input
/// width. The extension concatenates a zero value of width
/// `targetWidth - inputWidth` above `value`.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the body of the equal-width branch and the
/// start of the statement that creates `zeros`; restore them from the original
/// file before editing this code.
25Value comb::createZExt(OpBuilder &builder, Location loc, Value value,
26 unsigned targetWidth) {
// Preconditions: signless integer input, never a narrowing conversion.
27 assert(value.getType().isSignlessInteger());
28 auto inputWidth = value.getType().getIntOrFloatBitWidth();
29 assert(inputWidth <= targetWidth);
// Widths already match. The branch body is not visible here; presumably it
// returns `value` unchanged -- TODO confirm.
32 if (inputWidth == targetWidth)
37 builder, loc, builder.getIntegerType(targetWidth - inputWidth), 0);
// The zero bits are the first (high-order) concat operand, `value` the low part.
38 return builder.createOrFold<
ConcatOp>(loc, zeros, value);
/// Create a sign extension of the integer `value` to the equal-or-wider integer
/// type `destTy`. The most significant bit of `value` is extracted, replicated
/// to fill the extra `destTy` width, and concatenated above `value`.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the final signature line, the body of the
/// equal-type branch and the start of the `signBit` declaration.
43Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
45 IntegerType valueType = dyn_cast<IntegerType>(value.getType());
// A zero-width input is rejected: it has no sign bit to replicate.
46 assert(valueType && isa<IntegerType>(destTy) &&
47 valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
48 valueType.getWidth() != 0 &&
"invalid sext operands");
// Same type: nothing to extend. The branch body is not visible here;
// presumably it returns `value` unchanged -- TODO confirm.
50 if (valueType == destTy)
55 builder.createOrFold<
ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
// Copy the 1-bit sign once for every bit that `destTy` adds.
56 auto signBits = builder.createOrFold<ReplicateOp>(
57 loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
58 return builder.createOrFold<
ConcatOp>(loc, signBits, value);
/// Overload of `createOrFoldSExt` that takes an `ImplicitLocOpBuilder` instead
/// of an explicit `Location` and `OpBuilder`.
/// NOTE(review): the body is not visible in this listing; presumably it forwards
/// to the `Location`-taking overload using the builder's location -- confirm.
61Value comb::createOrFoldSExt(Value value, Type destTy,
62 ImplicitLocOpBuilder &builder) {
/// Create a "Not" gate on `value`, expressed as an XOR of `value` with
/// `allOnes`. The `twoState` flag is forwarded to the created XOR.
/// NOTE(review): this listing is missing lines of this function, including the
/// rest of the signature and the definition of `allOnes` (presumably an all-ones
/// constant of `value`'s width -- confirm).
66Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
69 return builder.createOrFold<
XorOp>(loc, value, allOnes, twoState);
/// Overload of `createOrFoldNot` that takes an `ImplicitLocOpBuilder`.
/// NOTE(review): only the first signature line of this overload is visible in
/// this listing; the rest of the signature and the body are missing.
72Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
/// Append the individual bits of the integer `val` to `bits`.
/// Fast path: when `concat` has exactly one operand per bit of `val` and every
/// operand is 1 bit wide, the operands are appended in reverse order. Because a
/// concat's first operand is its high-order part, this appends the bits
/// least-significant first.
/// Fallback: a loop visits every bit position `i` in [0, width); its body is
/// not visible in this listing.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the declaration of `concat`.
78void comb::extractBits(OpBuilder &builder, Value val,
79 SmallVectorImpl<Value> &bits) {
80 assert(val.getType().isInteger() &&
"expected integer");
81 auto width = val.getType().getIntOrFloatBitWidth();
// Reuse the operands of an existing bit-by-bit concat instead of creating
// new ops.
86 if (
concat.getNumOperands() == width &&
87 llvm::all_of(
concat.getOperandTypes(), [](Type type) {
88 return type.getIntOrFloatBitWidth() == 1;
91 bits.append(std::make_reverse_iterator(
concat.getOperands().end()),
92 std::make_reverse_iterator(
concat.getOperands().begin()));
// General case: one iteration per bit position.
98 for (int64_t i = 0; i < width; ++i)
/// Build a balanced binary tree of `comb::MuxOp`s that selects one of
/// `leafNodes` using `selectors`.
/// The tree depth is ceil(log2(leafNodes.size())). A node at `level` is steered
/// by `selectors[level - 1]`: its true child gets index `2 * id + 1` and its
/// false child `2 * id`. The chosen leaf index is therefore the selector values
/// read as a binary number with `selectors[0]` as the least-significant bit.
/// Leaf indices past the end of `leafNodes` yield `outOfBoundsValue`.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the condition that selects the leaf case.
105Value comb::constructMuxTree(OpBuilder &builder, Location loc,
106 ArrayRef<Value> selectors,
107 ArrayRef<Value> leafNodes,
108 Value outOfBoundsValue) {
110 std::function<Value(
size_t,
size_t)> constructTreeHelper =
111 [&](
size_t id,
size_t level) -> Value {
// Leaf: pad the tree with `outOfBoundsValue` past the provided leaves.
116 return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
119 auto selector = selectors[level - 1];
// Recurse one level down; the child index appends this level's select bit.
122 auto trueVal = constructTreeHelper(2 *
id + 1, level - 1);
123 auto falseVal = constructTreeHelper(2 *
id, level - 1);
126 return builder.createOrFold<
comb::MuxOp>(loc, selector, trueVal, falseVal);
// Start at the root: index 0, full tree depth.
129 return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
/// Extract `width` bits from the signless integer `value`, starting at the bit
/// position given by the runtime value `offset`.
/// A constant offset with fewer than 32 active bits takes a static-extract fast
/// path (the call's opening line is not visible in this listing). Otherwise the
/// offset is zero-extended to `value`'s width and `value` is shifted right
/// logically by it.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the declaration of `constOffset` and the
/// code after the shift.
132Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value,
133 Value offset,
unsigned width) {
134 assert(value.getType().isSignlessInteger());
135 auto valueWidth = value.getType().getIntOrFloatBitWidth();
136 assert(width <= valueWidth);
// Constant offsets that fit the `unsigned` offset parameter are handled
// statically.
140 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
141 if (constOffset.getActiveBits() < 32)
143 loc, value, constOffset.getZExtValue(), width);
// The shift amount must have the same width as the shifted value.
147 offset =
createZExt(builder, loc, offset, valueWidth);
// Move the requested bits down to position 0.
148 value = builder.createOrFold<
comb::ShrUOp>(loc, value, offset);
/// Replace the bits of `value` starting at the runtime bit position `offset`
/// with `replacement`, and return the updated value.
/// A constant offset with fewer than 32 active bits is forwarded to
/// `createInject`. Otherwise the code:
///   1. zero-extends the offset to `value`'s width;
///   2. shifts a mask with the low `smallWidth` bits set into place;
///   3. ANDs `value` with the mask;
///   4. ORs in the zero-extended, shifted `replacement`.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the final signature line, the
/// declarations of `constOffset` and `mask`, and whatever happens between the
/// mask shift and the AND. Presumably the mask is inverted there so the AND
/// clears the target bits -- confirm before editing.
152Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value,
153 Value offset, Value replacement,
155 assert(value.getType().isSignlessInteger());
156 assert(replacement.getType().isSignlessInteger());
157 auto largeWidth = value.getType().getIntOrFloatBitWidth();
158 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
159 assert(smallWidth <= largeWidth);
// Constant offsets that fit the `unsigned` offset parameter use the static form.
167 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
168 if (constOffset.getActiveBits() < 32)
169 return createInject(builder, loc, value, constOffset.getZExtValue(),
173 offset =
createZExt(builder, loc, offset, largeWidth);
175 builder, loc, APInt::getLowBitsSet(largeWidth, smallWidth));
// Position the `smallWidth`-bit mask at the injection offset.
176 mask = builder.createOrFold<
comb::ShlOp>(loc, mask, offset);
178 value = builder.createOrFold<
comb::AndOp>(loc, value, mask, twoState);
// Widen the replacement and move it to the injection offset before merging.
182 replacement =
createZExt(builder, loc, replacement, largeWidth);
183 replacement = builder.createOrFold<
comb::ShlOp>(loc, replacement, offset);
184 return builder.createOrFold<
comb::OrOp>(loc, value, replacement, twoState);
/// Replace a range of bits in an integer and return the updated integer value.
/// `replacement` is written into `value` starting at the constant bit position
/// `offset`.
/// NOTE(review): this listing is missing lines of this function (see the gaps
/// in the line numbering), including the bodies of the range checks and the
/// code that combines `fragments`.
187Value comb::createInject(OpBuilder &builder, Location loc, Value value,
188 unsigned offset, Value replacement) {
189 assert(value.getType().isSignlessInteger());
190 assert(replacement.getType().isSignlessInteger());
191 auto largeWidth = value.getType().getIntOrFloatBitWidth();
192 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
193 assert(smallWidth <= largeWidth);
// Offset past the end of `value`: no bit of `value` is covered. The branch
// body is not visible; presumably it returns `value` unchanged -- confirm.
196 if (offset >= largeWidth)
// Up to three pieces of the result are collected here.
205 SmallVector<Value, 3> fragments;
206 auto end = offset + smallWidth;
// Some bits of `value` remain above the replaced range.
207 if (end < largeWidth)
// The replacement fits entirely. Otherwise only its low
// `largeWidth - offset` bits can be kept (the truncation's opening line is
// not visible here).
210 if (end <= largeWidth)
211 fragments.push_back(replacement);
214 largeWidth - offset));
// Rewrite `subOp` (lhs - rhs) into an add of `lhs`, the bitwise NOT of `rhs`,
// and `one`, using the two's-complement identity a - b == a + ~b + 1.
// NOTE(review): the beginning of this signature and several body lines are
// missing from this listing, including the definitions of `notRhs` and `one`
// (presumably a constant 1 of the operand width -- confirm).
222 mlir::PatternRewriter &rewriter) {
223 auto lhs = subOp.getLhs();
224 auto rhs = subOp.getRhs();
229 comb::createOrFoldNot(subOp.getLoc(), rhs, rewriter, subOp.getTwoState());
// The new add keeps the sub's name hint and two-state flag.
232 replaceOpWithNewOpAndCopyNamehint<comb::AddOp>(
233 rewriter, subOp, ValueRange{lhs, notRhs, one}, subOp.getTwoState());
/// Return the predicate that gives the same result when the two comparison
/// operands are swapped (e.g. `a < b` holds exactly when `b > a`).
/// NOTE(review): the `switch` header and closing braces of this function are
/// missing from this listing.
241ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
// Equality and inequality are symmetric: they map to themselves.
243 case ICmpPredicate::eq:
244 return ICmpPredicate::eq;
245 case ICmpPredicate::ne:
246 return ICmpPredicate::ne;
// Ordered comparisons map to their mirror image, keeping strictness and
// signedness.
247 case ICmpPredicate::slt:
248 return ICmpPredicate::sgt;
249 case ICmpPredicate::sle:
250 return ICmpPredicate::sge;
251 case ICmpPredicate::sgt:
252 return ICmpPredicate::slt;
253 case ICmpPredicate::sge:
254 return ICmpPredicate::sle;
255 case ICmpPredicate::ult:
256 return ICmpPredicate::ugt;
257 case ICmpPredicate::ule:
258 return ICmpPredicate::uge;
259 case ICmpPredicate::ugt:
260 return ICmpPredicate::ult;
261 case ICmpPredicate::uge:
262 return ICmpPredicate::ule;
// The ceq/cne/weq/wne equality variants are also symmetric.
263 case ICmpPredicate::ceq:
264 return ICmpPredicate::ceq;
265 case ICmpPredicate::cne:
266 return ICmpPredicate::cne;
267 case ICmpPredicate::weq:
268 return ICmpPredicate::weq;
269 case ICmpPredicate::wne:
270 return ICmpPredicate::wne;
// Only reachable for a predicate not listed above.
272 llvm_unreachable(
"unknown comparison predicate");
/// Classify `predicate` as signed or not. Unsigned orderings and all
/// equality-style predicates form one group; the four signed orderings form
/// the other.
/// NOTE(review): the `switch` header and both `return` statements are missing
/// from this listing; by the function name, presumably the first group returns
/// false and the signed group returns true -- confirm.
275bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
// Unsigned orderings and equality-style predicates.
277 case ICmpPredicate::ult:
278 case ICmpPredicate::ugt:
279 case ICmpPredicate::ule:
280 case ICmpPredicate::uge:
281 case ICmpPredicate::ne:
282 case ICmpPredicate::eq:
283 case ICmpPredicate::cne:
284 case ICmpPredicate::ceq:
285 case ICmpPredicate::wne:
286 case ICmpPredicate::weq:
// Signed orderings.
288 case ICmpPredicate::slt:
289 case ICmpPredicate::sgt:
290 case ICmpPredicate::sle:
291 case ICmpPredicate::sge:
// Only reachable for a predicate not listed above.
294 llvm_unreachable(
"unknown comparison predicate");
/// Return the logical negation of `predicate`: the predicate that holds
/// exactly when `predicate` does not, on the same operand order
/// (e.g. `slt` becomes `sge`).
/// NOTE(review): the `switch` header and closing braces of this function are
/// missing from this listing.
299ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
301 case ICmpPredicate::eq:
302 return ICmpPredicate::ne;
303 case ICmpPredicate::ne:
304 return ICmpPredicate::eq;
// Negating an ordering swaps strict and non-strict and reverses direction.
305 case ICmpPredicate::slt:
306 return ICmpPredicate::sge;
307 case ICmpPredicate::sle:
308 return ICmpPredicate::sgt;
309 case ICmpPredicate::sgt:
310 return ICmpPredicate::sle;
311 case ICmpPredicate::sge:
312 return ICmpPredicate::slt;
313 case ICmpPredicate::ult:
314 return ICmpPredicate::uge;
315 case ICmpPredicate::ule:
316 return ICmpPredicate::ugt;
317 case ICmpPredicate::ugt:
318 return ICmpPredicate::ule;
319 case ICmpPredicate::uge:
320 return ICmpPredicate::ult;
// Each equality variant pairs with its own inequality counterpart.
321 case ICmpPredicate::ceq:
322 return ICmpPredicate::cne;
323 case ICmpPredicate::cne:
324 return ICmpPredicate::ceq;
325 case ICmpPredicate::weq:
326 return ICmpPredicate::wne;
327 case ICmpPredicate::wne:
328 return ICmpPredicate::weq;
// Only reachable for a predicate not listed above.
330 llvm_unreachable(
"unknown comparison predicate");
/// Check whether this compare is an `eq` whose second operand is an
/// `hw::ConstantOp` with all bits set.
/// NOTE(review): several lines are missing from this listing, including the
/// early-exit return, the start of the `op1` declaration and the final return
/// (presumably `false` in both non-matching cases -- confirm).
335bool ICmpOp::isEqualAllOnes() {
// Only `eq` compares qualify.
336 if (getPredicate() != ICmpPredicate::eq)
340 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
341 return op1.getValue().isAllOnes();
/// Check whether this compare is a `ne` whose second operand is an
/// `hw::ConstantOp` equal to zero.
/// NOTE(review): several lines are missing from this listing, including the
/// early-exit return, the start of the `op1` declaration and the final return
/// (presumably `false` in both non-matching cases -- confirm).
347bool ICmpOp::isNotEqualZero() {
// Only `ne` compares qualify.
348 if (getPredicate() != ICmpPredicate::ne)
352 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
353 return op1.getValue().isZero();
/// Verify a replicate op. The operand must not be zero bits wide, the result
/// must be at least as wide as the operand, and the result width must be an
/// exact multiple of the operand width.
/// NOTE(review): several lines are missing from this listing, including the
/// condition guarding the zero-width error, the `failure()` operands of the
/// comma expressions and the final success return.
361LogicalResult ReplicateOp::verify() {
364 auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
365 auto dstWidth = cast<IntegerType>(getType()).getWidth();
367 return emitOpError(
"replicate does not take zero bit integer");
369 if (srcWidth > dstWidth)
370 return emitOpError(
"replicate cannot shrink bitwidth of operand"),
// The modulo requires a non-zero `srcWidth`. The zero-width check above
// (its condition is not visible here) presumably guarantees that.
373 if (dstWidth % srcWidth)
374 return emitOpError(
"replicate must produce integer multiple of operand"),
// Body fragment of the shared binary-op verifier, presumably
// `verifyUTBinOp(Operation *op)`: rejects operations that have no operands.
// NOTE(review): the signature and the rest of the body are missing from this
// listing.
385 if (op->getOperands().empty())
386 return op->emitOpError(
"requires 1 or more args");
390LogicalResult AddOp::verify() {
return verifyUTBinOp(*
this); }
392LogicalResult MulOp::verify() {
return verifyUTBinOp(*
this); }
394LogicalResult AndOp::verify() {
return verifyUTBinOp(*
this); }
398LogicalResult XorOp::verify() {
return verifyUTBinOp(*
this); }
/// Check whether this XOR is a bitwise NOT: exactly two operands, the second
/// being an `hw::ConstantOp` with all bits set. This is the form that
/// `createOrFoldNot` builds.
/// NOTE(review): the return statements are missing from this listing;
/// presumably the non-matching paths return false and the all-ones match
/// returns true -- confirm.
402bool XorOp::isBinaryNot() {
403 if (getNumOperands() != 2)
405 if (
auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
406 if (cst.getValue().isAllOnes())
// Body fragment that sums the integer bit widths of every value in `inputs`.
// Presumably this is `getTotalWidth`, which `ConcatOp::build` uses to size its
// result -- confirm.
// NOTE(review): the signature, loop close and return are missing from this
// listing.
416 unsigned resultWidth = 0;
417 for (
auto input : inputs) {
418 resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
/// Build a concat whose first operand is `hd`, followed by the values in `tl`.
/// The result is an integer as wide as `hd` plus all of `tl` combined.
/// NOTE(review): the rest of the signature (declaring `tl`) and the closing
/// brace are missing from this listing.
423void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
425 result.addOperands(ValueRange{hd});
426 result.addOperands(tl);
427 unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
428 result.addTypes(builder.getIntegerType(
getTotalWidth(tl) + hdWidth));
/// Infer the concat result type as an integer of width `resultWidth`.
/// NOTE(review): the line computing `resultWidth` and the final return are
/// missing from this listing; presumably `resultWidth` is the total width of
/// `operands` -- confirm.
431LogicalResult ConcatOp::inferReturnTypes(
432 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
433 DictionaryAttr attrs, mlir::OpaqueProperties properties,
434 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
436 results.push_back(IntegerType::get(context, resultWidth));
/// Fold a bit reversal of a constant input into a constant with the bit order
/// reversed.
/// NOTE(review): lines are missing from this listing. `dyn_cast_or_null` can
/// yield null, so the missing lines presumably bail out when the input is not
/// a constant `IntegerAttr` -- confirm before relying on `cstInput` below.
446OpFoldResult comb::ReverseOp::fold(FoldAdaptor adaptor) {
448 auto cstInput = llvm::dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getInput());
452 APInt val = cstInput.getValue();
453 APInt reversedVal = val.reverseBits();
455 return mlir::IntegerAttr::get(getType(), reversedVal);
// Canonicalize reverse(reverse(x)) to x: the outer reverse is replaced by the
// inner reverse's input.
// NOTE(review): the enclosing pattern declaration and several body lines are
// missing from this listing. Presumably those lines include a failure return
// when the input is not produced by a `comb::ReverseOp` (so `inputOp` is null)
// -- confirm.
462 LogicalResult matchAndRewrite(comb::ReverseOp op,
463 PatternRewriter &rewriter)
const override {
464 auto inputOp = op.getInput().getDefiningOp<comb::ReverseOp>();
468 rewriter.replaceOp(op, inputOp.getInput());
/// Register the canonicalization patterns for `comb::ReverseOp`: currently only
/// `ReverseOfReverse`.
/// NOTE(review): the closing brace is missing from this listing.
474void comb::ReverseOp::getCanonicalizationPatterns(RewritePatternSet &results,
475 MLIRContext *context) {
476 results.add<ReverseOfReverse>(context);
/// Verify that the extracted range `[lowBit, lowBit + resultWidth)` lies inside
/// the input integer.
/// NOTE(review): the success return and closing brace are missing from this
/// listing.
483LogicalResult ExtractOp::verify() {
484 unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
485 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
// The first check short-circuits before `srcWidth - getLowBit()` could wrap
// around in unsigned arithmetic.
486 if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
487 return emitOpError(
"from bit too large for input"), failure();
/// Verify a truth table. The number of inputs must be smaller than the bit
/// width of `size_t`, and the lookup table must have exactly 2^n entries for
/// n inputs.
/// NOTE(review): the success return and closing brace are missing from this
/// listing.
492LogicalResult TruthTableOp::verify() {
493 size_t numInputs = getInputs().size();
// Capping `numInputs` below the bit width of `size_t` keeps the shift below
// in range wherever `size_t` is no wider than `unsigned long long`.
494 if (numInputs >=
sizeof(
size_t) * 8)
495 return emitOpError(
"Truth tables support a maximum of ")
496 <<
sizeof(size_t) * 8 - 1 <<
" inputs on your platform";
// One table entry per combination of input values.
498 ArrayAttr table = getLookupTable();
499 if (table.size() != (1ull << numInputs))
500 return emitOpError(
"Expected lookup table of 2^n length");
509#define GET_OP_CLASSES
510#include "circt/Dialect/Comb/Comb.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static size_t getTotalWidth(ArrayRef< Value > operands)
static LogicalResult verifyUTBinOp(Operation *op)
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement)
Replace a range of bits in an integer and return the updated integer value.
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Value createZExt(OpBuilder &builder, Location loc, Value value, unsigned targetWidth)
Create the ops to zero-extend a value to an integer of equal or larger type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.