CIRCT 22.0.0git
Loading...
Searching...
No Matches
CombOps.cpp
Go to the documentation of this file.
1//===- CombOps.cpp - Implement the Comb operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements combinational ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/ImplicitLocOpBuilder.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/PatternMatch.h"
19#include "llvm/Support/FormatVariadic.h"
20
21using namespace circt;
22using namespace comb;
23
24Value comb::createZExt(OpBuilder &builder, Location loc, Value value,
25 unsigned targetWidth) {
26 assert(value.getType().isSignlessInteger());
27 auto inputWidth = value.getType().getIntOrFloatBitWidth();
28 assert(inputWidth <= targetWidth);
29
30 // Nothing to do if the width already matches.
31 if (inputWidth == targetWidth)
32 return value;
33
34 // Create a zero constant for the upper bits.
35 auto zeros = hw::ConstantOp::create(
36 builder, loc, builder.getIntegerType(targetWidth - inputWidth), 0);
37 return builder.createOrFold<ConcatOp>(loc, zeros, value);
38}
39
40/// Create a sign extension operation from a value of integer type to an equal
41/// or larger integer type.
42Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
43 OpBuilder &builder) {
44 IntegerType valueType = dyn_cast<IntegerType>(value.getType());
45 assert(valueType && isa<IntegerType>(destTy) &&
46 valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
47 valueType.getWidth() != 0 && "invalid sext operands");
48 // If already the right size, we are done.
49 if (valueType == destTy)
50 return value;
51
52 // sext is concat with a replicate of the sign bits and the bottom part.
53 auto signBit =
54 builder.createOrFold<ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
55 auto signBits = builder.createOrFold<ReplicateOp>(
56 loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
57 return builder.createOrFold<ConcatOp>(loc, signBits, value);
58}
59
60Value comb::createOrFoldSExt(Value value, Type destTy,
61 ImplicitLocOpBuilder &builder) {
62 return createOrFoldSExt(builder.getLoc(), value, destTy, builder);
63}
64
65Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
66 bool twoState) {
67 auto allOnes = hw::ConstantOp::create(builder, loc, value.getType(), -1);
68 return builder.createOrFold<XorOp>(loc, value, allOnes, twoState);
69}
70
71Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
72 bool twoState) {
73 return createOrFoldNot(builder.getLoc(), value, builder, twoState);
74}
75
76// Extract individual bits from a value
77void comb::extractBits(OpBuilder &builder, Value val,
78 SmallVectorImpl<Value> &bits) {
79 assert(val.getType().isInteger() && "expected integer");
80 auto width = val.getType().getIntOrFloatBitWidth();
81 bits.reserve(width);
82
83 // Check if we can reuse concat operands
84 if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
85 if (concat.getNumOperands() == width &&
86 llvm::all_of(concat.getOperandTypes(), [](Type type) {
87 return type.getIntOrFloatBitWidth() == 1;
88 })) {
89 // Reverse the operands to match the bit order
90 bits.append(std::make_reverse_iterator(concat.getOperands().end()),
91 std::make_reverse_iterator(concat.getOperands().begin()));
92 return;
93 }
94 }
95
96 // Extract individual bits
97 for (int64_t i = 0; i < width; ++i)
98 bits.push_back(
99 builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
100}
101
102// Construct a mux tree for given leaf nodes. `selectors` is the selector for
103// each level of the tree. Currently the selector is tested from MSB to LSB.
104Value comb::constructMuxTree(OpBuilder &builder, Location loc,
105 ArrayRef<Value> selectors,
106 ArrayRef<Value> leafNodes,
107 Value outOfBoundsValue) {
108 // Recursive helper function to construct the mux tree
109 std::function<Value(size_t, size_t)> constructTreeHelper =
110 [&](size_t id, size_t level) -> Value {
111 // Base case: at the lowest level, return the result
112 if (level == 0) {
113 // Return the result for the given index. If the index is out of bounds,
114 // return the out-of-bound value.
115 return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
116 }
117
118 auto selector = selectors[level - 1];
119
120 // Recursive case: create muxes for true and false branches
121 auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
122 auto falseVal = constructTreeHelper(2 * id, level - 1);
123
124 // Combine the results with a mux
125 return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
126 };
127
128 return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
129}
130
131Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value,
132 Value offset, unsigned width) {
133 assert(value.getType().isSignlessInteger());
134 auto valueWidth = value.getType().getIntOrFloatBitWidth();
135 assert(width <= valueWidth);
136
137 // Handle the special case where the offset is constant.
138 APInt constOffset;
139 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
140 if (constOffset.getActiveBits() < 32)
141 return builder.createOrFold<comb::ExtractOp>(
142 loc, value, constOffset.getZExtValue(), width);
143
144 // Zero-extend the offset, shift the value down, and extract the requested
145 // number of bits.
146 offset = createZExt(builder, loc, offset, valueWidth);
147 value = builder.createOrFold<comb::ShrUOp>(loc, value, offset);
148 return builder.createOrFold<comb::ExtractOp>(loc, value, 0, width);
149}
150
151Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value,
152 Value offset, Value replacement,
153 bool twoState) {
154 assert(value.getType().isSignlessInteger());
155 assert(replacement.getType().isSignlessInteger());
156 auto largeWidth = value.getType().getIntOrFloatBitWidth();
157 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
158 assert(smallWidth <= largeWidth);
159
160 // If we're inserting a zero-width value there's nothing to do.
161 if (smallWidth == 0)
162 return value;
163
164 // Handle the special case where the offset is constant.
165 APInt constOffset;
166 if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
167 if (constOffset.getActiveBits() < 32)
168 return createInject(builder, loc, value, constOffset.getZExtValue(),
169 replacement);
170
171 // Zero-extend the offset and clear the value bits we are replacing.
172 offset = createZExt(builder, loc, offset, largeWidth);
173 Value mask = hw::ConstantOp::create(
174 builder, loc, APInt::getLowBitsSet(largeWidth, smallWidth));
175 mask = builder.createOrFold<comb::ShlOp>(loc, mask, offset);
176 mask = createOrFoldNot(loc, mask, builder, true);
177 value = builder.createOrFold<comb::AndOp>(loc, value, mask, twoState);
178
179 // Zero-extend the replacement value, shift it up to the offset, and merge it
180 // with the value that has the corresponding bits cleared.
181 replacement = createZExt(builder, loc, replacement, largeWidth);
182 replacement = builder.createOrFold<comb::ShlOp>(loc, replacement, offset);
183 return builder.createOrFold<comb::OrOp>(loc, value, replacement, twoState);
184}
185
186Value comb::createInject(OpBuilder &builder, Location loc, Value value,
187 unsigned offset, Value replacement) {
188 assert(value.getType().isSignlessInteger());
189 assert(replacement.getType().isSignlessInteger());
190 auto largeWidth = value.getType().getIntOrFloatBitWidth();
191 auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
192 assert(smallWidth <= largeWidth);
193
194 // If the offset is outside the value there's nothing to do.
195 if (offset >= largeWidth)
196 return value;
197
198 // If we're inserting a zero-width value there's nothing to do.
199 if (smallWidth == 0)
200 return value;
201
202 // Assemble the pieces of the injection as everything below the offset, the
203 // replacement value, and everything above the replacement value.
204 SmallVector<Value, 3> fragments;
205 auto end = offset + smallWidth;
206 if (end < largeWidth)
207 fragments.push_back(
208 comb::ExtractOp::create(builder, loc, value, end, largeWidth - end));
209 if (end <= largeWidth)
210 fragments.push_back(replacement);
211 else
212 fragments.push_back(comb::ExtractOp::create(builder, loc, replacement, 0,
213 largeWidth - offset));
214 if (offset > 0)
215 fragments.push_back(
216 comb::ExtractOp::create(builder, loc, value, 0, offset));
217 return builder.createOrFold<comb::ConcatOp>(loc, fragments);
218}
219
220//===----------------------------------------------------------------------===//
221// ICmpOp
222//===----------------------------------------------------------------------===//
223
224ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
225 switch (predicate) {
226 case ICmpPredicate::eq:
227 return ICmpPredicate::eq;
228 case ICmpPredicate::ne:
229 return ICmpPredicate::ne;
230 case ICmpPredicate::slt:
231 return ICmpPredicate::sgt;
232 case ICmpPredicate::sle:
233 return ICmpPredicate::sge;
234 case ICmpPredicate::sgt:
235 return ICmpPredicate::slt;
236 case ICmpPredicate::sge:
237 return ICmpPredicate::sle;
238 case ICmpPredicate::ult:
239 return ICmpPredicate::ugt;
240 case ICmpPredicate::ule:
241 return ICmpPredicate::uge;
242 case ICmpPredicate::ugt:
243 return ICmpPredicate::ult;
244 case ICmpPredicate::uge:
245 return ICmpPredicate::ule;
246 case ICmpPredicate::ceq:
247 return ICmpPredicate::ceq;
248 case ICmpPredicate::cne:
249 return ICmpPredicate::cne;
250 case ICmpPredicate::weq:
251 return ICmpPredicate::weq;
252 case ICmpPredicate::wne:
253 return ICmpPredicate::wne;
254 }
255 llvm_unreachable("unknown comparison predicate");
256}
257
258bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
259 switch (predicate) {
260 case ICmpPredicate::ult:
261 case ICmpPredicate::ugt:
262 case ICmpPredicate::ule:
263 case ICmpPredicate::uge:
264 case ICmpPredicate::ne:
265 case ICmpPredicate::eq:
266 case ICmpPredicate::cne:
267 case ICmpPredicate::ceq:
268 case ICmpPredicate::wne:
269 case ICmpPredicate::weq:
270 return false;
271 case ICmpPredicate::slt:
272 case ICmpPredicate::sgt:
273 case ICmpPredicate::sle:
274 case ICmpPredicate::sge:
275 return true;
276 }
277 llvm_unreachable("unknown comparison predicate");
278}
279
280/// Returns the predicate for a logically negated comparison, e.g. mapping
281/// EQ => NE and SLE => SGT.
282ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
283 switch (predicate) {
284 case ICmpPredicate::eq:
285 return ICmpPredicate::ne;
286 case ICmpPredicate::ne:
287 return ICmpPredicate::eq;
288 case ICmpPredicate::slt:
289 return ICmpPredicate::sge;
290 case ICmpPredicate::sle:
291 return ICmpPredicate::sgt;
292 case ICmpPredicate::sgt:
293 return ICmpPredicate::sle;
294 case ICmpPredicate::sge:
295 return ICmpPredicate::slt;
296 case ICmpPredicate::ult:
297 return ICmpPredicate::uge;
298 case ICmpPredicate::ule:
299 return ICmpPredicate::ugt;
300 case ICmpPredicate::ugt:
301 return ICmpPredicate::ule;
302 case ICmpPredicate::uge:
303 return ICmpPredicate::ult;
304 case ICmpPredicate::ceq:
305 return ICmpPredicate::cne;
306 case ICmpPredicate::cne:
307 return ICmpPredicate::ceq;
308 case ICmpPredicate::weq:
309 return ICmpPredicate::wne;
310 case ICmpPredicate::wne:
311 return ICmpPredicate::weq;
312 }
313 llvm_unreachable("unknown comparison predicate");
314}
315
316/// Return true if this is an equality test with -1, which is a "reduction
317/// and" operation in Verilog.
318bool ICmpOp::isEqualAllOnes() {
319 if (getPredicate() != ICmpPredicate::eq)
320 return false;
321
322 if (auto op1 =
323 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
324 return op1.getValue().isAllOnes();
325 return false;
326}
327
328/// Return true if this is a not equal test with 0, which is a "reduction
329/// or" operation in Verilog.
330bool ICmpOp::isNotEqualZero() {
331 if (getPredicate() != ICmpPredicate::ne)
332 return false;
333
334 if (auto op1 =
335 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
336 return op1.getValue().isZero();
337 return false;
338}
339
340//===----------------------------------------------------------------------===//
341// Unary Operations
342//===----------------------------------------------------------------------===//
343
344LogicalResult ReplicateOp::verify() {
345 // The source must be equal or smaller than the dest type, and an even
346 // multiple of it. Both are already known to be signless integers.
347 auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
348 auto dstWidth = cast<IntegerType>(getType()).getWidth();
349 if (srcWidth == 0)
350 return emitOpError("replicate does not take zero bit integer");
351
352 if (srcWidth > dstWidth)
353 return emitOpError("replicate cannot shrink bitwidth of operand"),
354 failure();
355
356 if (dstWidth % srcWidth)
357 return emitOpError("replicate must produce integer multiple of operand"),
358 failure();
359
360 return success();
361}
362
363//===----------------------------------------------------------------------===//
364// Variadic operations
365//===----------------------------------------------------------------------===//
366
367static LogicalResult verifyUTBinOp(Operation *op) {
368 if (op->getOperands().empty())
369 return op->emitOpError("requires 1 or more args");
370 return success();
371}
372
373LogicalResult AddOp::verify() { return verifyUTBinOp(*this); }
374
375LogicalResult MulOp::verify() { return verifyUTBinOp(*this); }
376
377LogicalResult AndOp::verify() { return verifyUTBinOp(*this); }
378
379LogicalResult OrOp::verify() { return verifyUTBinOp(*this); }
380
381LogicalResult XorOp::verify() { return verifyUTBinOp(*this); }
382
383/// Return true if this is a two operand xor with an all ones constant as
384/// its RHS operand.
385bool XorOp::isBinaryNot() {
386 if (getNumOperands() != 2)
387 return false;
388 if (auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
389 if (cst.getValue().isAllOnes())
390 return true;
391 return false;
392}
393
394//===----------------------------------------------------------------------===//
395// ConcatOp
396//===----------------------------------------------------------------------===//
397
398static unsigned getTotalWidth(ValueRange inputs) {
399 unsigned resultWidth = 0;
400 for (auto input : inputs) {
401 resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
402 }
403 return resultWidth;
404}
405
406void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
407 ValueRange tl) {
408 result.addOperands(ValueRange{hd});
409 result.addOperands(tl);
410 unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
411 result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth));
412}
413
414LogicalResult ConcatOp::inferReturnTypes(
415 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
416 DictionaryAttr attrs, mlir::OpaqueProperties properties,
417 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
418 unsigned resultWidth = getTotalWidth(operands);
419 results.push_back(IntegerType::get(context, resultWidth));
420 return success();
421}
422
423//===----------------------------------------------------------------------===//
424// ReverseOp
425//===----------------------------------------------------------------------===//
426
427// Folding of ReverseOp: if the input is constant, compute the reverse at
428// compile time.
429OpFoldResult comb::ReverseOp::fold(FoldAdaptor adaptor) {
430 // Try to cast the input attribute to an IntegerAttr.
431 auto cstInput = llvm::dyn_cast_or_null<mlir::IntegerAttr>(adaptor.getInput());
432 if (!cstInput)
433 return {};
434
435 APInt val = cstInput.getValue();
436 APInt reversedVal = val.reverseBits();
437
438 return mlir::IntegerAttr::get(getType(), reversedVal);
439}
440
441namespace {
442struct ReverseOfReverse : public OpRewritePattern<comb::ReverseOp> {
443 using OpRewritePattern<comb::ReverseOp>::OpRewritePattern;
444
445 LogicalResult matchAndRewrite(comb::ReverseOp op,
446 PatternRewriter &rewriter) const override {
447 auto inputOp = op.getInput().getDefiningOp<comb::ReverseOp>();
448 if (!inputOp)
449 return failure();
450
451 rewriter.replaceOp(op, inputOp.getInput());
452 return success();
453 }
454};
455} // namespace
456
457void comb::ReverseOp::getCanonicalizationPatterns(RewritePatternSet &results,
458 MLIRContext *context) {
459 results.add<ReverseOfReverse>(context);
460}
461
462//===----------------------------------------------------------------------===//
463// Other Operations
464//===----------------------------------------------------------------------===//
465
466LogicalResult ExtractOp::verify() {
467 unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
468 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
469 if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
470 return emitOpError("from bit too large for input"), failure();
471
472 return success();
473}
474
475LogicalResult TruthTableOp::verify() {
476 size_t numInputs = getInputs().size();
477 if (numInputs >= sizeof(size_t) * 8)
478 return emitOpError("Truth tables support a maximum of ")
479 << sizeof(size_t) * 8 - 1 << " inputs on your platform";
480
481 ArrayAttr table = getLookupTable();
482 if (table.size() != (1ull << numInputs))
483 return emitOpError("Expected lookup table of 2^n length");
484 return success();
485}
486
487//===----------------------------------------------------------------------===//
488// TableGen generated logic.
489//===----------------------------------------------------------------------===//
490
491// Provide the autogenerated implementation guts for the Op classes.
492#define GET_OP_CLASSES
493#include "circt/Dialect/Comb/Comb.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static size_t getTotalWidth(ArrayRef< Value > operands)
static LogicalResult verifyUTBinOp(Operation *op)
Definition CombOps.cpp:367
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Definition CombOps.cpp:65
Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement)
Replace a range of bits in an integer and return the updated integer value.
Definition CombOps.cpp:186
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition CombOps.cpp:42
Value createZExt(OpBuilder &builder, Location loc, Value value, unsigned targetWidth)
Create the ops to zero-extend a value to an integer of equal or larger type.
Definition CombOps.cpp:24
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1