CIRCT 20.0.0git
Loading...
Searching...
No Matches
CombOps.cpp
Go to the documentation of this file.
1//===- CombOps.cpp - Implement the Comb operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements combinational ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/ImplicitLocOpBuilder.h"
17#include "mlir/IR/PatternMatch.h"
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace circt;
21using namespace comb;
22
23/// Create a sign extension operation from a value of integer type to an equal
24/// or larger integer type.
25Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
26 OpBuilder &builder) {
27 IntegerType valueType = dyn_cast<IntegerType>(value.getType());
28 assert(valueType && isa<IntegerType>(destTy) &&
29 valueType.getWidth() <= destTy.getIntOrFloatBitWidth() &&
30 valueType.getWidth() != 0 && "invalid sext operands");
31 // If already the right size, we are done.
32 if (valueType == destTy)
33 return value;
34
35 // sext is concat with a replicate of the sign bits and the bottom part.
36 auto signBit =
37 builder.createOrFold<ExtractOp>(loc, value, valueType.getWidth() - 1, 1);
38 auto signBits = builder.createOrFold<ReplicateOp>(
39 loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth());
40 return builder.createOrFold<ConcatOp>(loc, signBits, value);
41}
42
43Value comb::createOrFoldSExt(Value value, Type destTy,
44 ImplicitLocOpBuilder &builder) {
45 return createOrFoldSExt(builder.getLoc(), value, destTy, builder);
46}
47
48Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder,
49 bool twoState) {
50 auto allOnes = builder.create<hw::ConstantOp>(loc, value.getType(), -1);
51 return builder.createOrFold<XorOp>(loc, value, allOnes, twoState);
52}
53
54Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder,
55 bool twoState) {
56 return createOrFoldNot(builder.getLoc(), value, builder, twoState);
57}
58
59// Extract individual bits from a value
60void comb::extractBits(OpBuilder &builder, Value val,
61 SmallVectorImpl<Value> &bits) {
62 assert(val.getType().isInteger() && "expected integer");
63 auto width = val.getType().getIntOrFloatBitWidth();
64 bits.reserve(width);
65
66 // Check if we can reuse concat operands
67 if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
68 if (concat.getNumOperands() == width &&
69 llvm::all_of(concat.getOperandTypes(), [](Type type) {
70 return type.getIntOrFloatBitWidth() == 1;
71 })) {
72 // Reverse the operands to match the bit order
73 bits.append(std::make_reverse_iterator(concat.getOperands().end()),
74 std::make_reverse_iterator(concat.getOperands().begin()));
75 return;
76 }
77 }
78
79 // Extract individual bits
80 for (int64_t i = 0; i < width; ++i)
81 bits.push_back(
82 builder.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
83}
84
85// Construct a mux tree for given leaf nodes. `selectors` is the selector for
86// each level of the tree. Currently the selector is tested from MSB to LSB.
87Value comb::constructMuxTree(OpBuilder &builder, Location loc,
88 ArrayRef<Value> selectors,
89 ArrayRef<Value> leafNodes,
90 Value outOfBoundsValue) {
91 // Recursive helper function to construct the mux tree
92 std::function<Value(size_t, size_t)> constructTreeHelper =
93 [&](size_t id, size_t level) -> Value {
94 // Base case: at the lowest level, return the result
95 if (level == 0) {
96 // Return the result for the given index. If the index is out of bounds,
97 // return the out-of-bound value.
98 return id < leafNodes.size() ? leafNodes[id] : outOfBoundsValue;
99 }
100
101 auto selector = selectors[level - 1];
102
103 // Recursive case: create muxes for true and false branches
104 auto trueVal = constructTreeHelper(2 * id + 1, level - 1);
105 auto falseVal = constructTreeHelper(2 * id, level - 1);
106
107 // Combine the results with a mux
108 return builder.createOrFold<comb::MuxOp>(loc, selector, trueVal, falseVal);
109 };
110
111 return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
112}
113
114//===----------------------------------------------------------------------===//
115// ICmpOp
116//===----------------------------------------------------------------------===//
117
118ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) {
119 switch (predicate) {
120 case ICmpPredicate::eq:
121 return ICmpPredicate::eq;
122 case ICmpPredicate::ne:
123 return ICmpPredicate::ne;
124 case ICmpPredicate::slt:
125 return ICmpPredicate::sgt;
126 case ICmpPredicate::sle:
127 return ICmpPredicate::sge;
128 case ICmpPredicate::sgt:
129 return ICmpPredicate::slt;
130 case ICmpPredicate::sge:
131 return ICmpPredicate::sle;
132 case ICmpPredicate::ult:
133 return ICmpPredicate::ugt;
134 case ICmpPredicate::ule:
135 return ICmpPredicate::uge;
136 case ICmpPredicate::ugt:
137 return ICmpPredicate::ult;
138 case ICmpPredicate::uge:
139 return ICmpPredicate::ule;
140 case ICmpPredicate::ceq:
141 return ICmpPredicate::ceq;
142 case ICmpPredicate::cne:
143 return ICmpPredicate::cne;
144 case ICmpPredicate::weq:
145 return ICmpPredicate::weq;
146 case ICmpPredicate::wne:
147 return ICmpPredicate::wne;
148 }
149 llvm_unreachable("unknown comparison predicate");
150}
151
152bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) {
153 switch (predicate) {
154 case ICmpPredicate::ult:
155 case ICmpPredicate::ugt:
156 case ICmpPredicate::ule:
157 case ICmpPredicate::uge:
158 case ICmpPredicate::ne:
159 case ICmpPredicate::eq:
160 case ICmpPredicate::cne:
161 case ICmpPredicate::ceq:
162 case ICmpPredicate::wne:
163 case ICmpPredicate::weq:
164 return false;
165 case ICmpPredicate::slt:
166 case ICmpPredicate::sgt:
167 case ICmpPredicate::sle:
168 case ICmpPredicate::sge:
169 return true;
170 }
171 llvm_unreachable("unknown comparison predicate");
172}
173
174/// Returns the predicate for a logically negated comparison, e.g. mapping
175/// EQ => NE and SLE => SGT.
176ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) {
177 switch (predicate) {
178 case ICmpPredicate::eq:
179 return ICmpPredicate::ne;
180 case ICmpPredicate::ne:
181 return ICmpPredicate::eq;
182 case ICmpPredicate::slt:
183 return ICmpPredicate::sge;
184 case ICmpPredicate::sle:
185 return ICmpPredicate::sgt;
186 case ICmpPredicate::sgt:
187 return ICmpPredicate::sle;
188 case ICmpPredicate::sge:
189 return ICmpPredicate::slt;
190 case ICmpPredicate::ult:
191 return ICmpPredicate::uge;
192 case ICmpPredicate::ule:
193 return ICmpPredicate::ugt;
194 case ICmpPredicate::ugt:
195 return ICmpPredicate::ule;
196 case ICmpPredicate::uge:
197 return ICmpPredicate::ult;
198 case ICmpPredicate::ceq:
199 return ICmpPredicate::cne;
200 case ICmpPredicate::cne:
201 return ICmpPredicate::ceq;
202 case ICmpPredicate::weq:
203 return ICmpPredicate::wne;
204 case ICmpPredicate::wne:
205 return ICmpPredicate::weq;
206 }
207 llvm_unreachable("unknown comparison predicate");
208}
209
210/// Return true if this is an equality test with -1, which is a "reduction
211/// and" operation in Verilog.
212bool ICmpOp::isEqualAllOnes() {
213 if (getPredicate() != ICmpPredicate::eq)
214 return false;
215
216 if (auto op1 =
217 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
218 return op1.getValue().isAllOnes();
219 return false;
220}
221
222/// Return true if this is a not equal test with 0, which is a "reduction
223/// or" operation in Verilog.
224bool ICmpOp::isNotEqualZero() {
225 if (getPredicate() != ICmpPredicate::ne)
226 return false;
227
228 if (auto op1 =
229 dyn_cast_or_null<hw::ConstantOp>(getOperand(1).getDefiningOp()))
230 return op1.getValue().isZero();
231 return false;
232}
233
234//===----------------------------------------------------------------------===//
235// Unary Operations
236//===----------------------------------------------------------------------===//
237
238LogicalResult ReplicateOp::verify() {
239 // The source must be equal or smaller than the dest type, and an even
240 // multiple of it. Both are already known to be signless integers.
241 auto srcWidth = cast<IntegerType>(getOperand().getType()).getWidth();
242 auto dstWidth = cast<IntegerType>(getType()).getWidth();
243 if (srcWidth == 0)
244 return emitOpError("replicate does not take zero bit integer");
245
246 if (srcWidth > dstWidth)
247 return emitOpError("replicate cannot shrink bitwidth of operand"),
248 failure();
249
250 if (dstWidth % srcWidth)
251 return emitOpError("replicate must produce integer multiple of operand"),
252 failure();
253
254 return success();
255}
256
257//===----------------------------------------------------------------------===//
258// Variadic operations
259//===----------------------------------------------------------------------===//
260
261static LogicalResult verifyUTBinOp(Operation *op) {
262 if (op->getOperands().empty())
263 return op->emitOpError("requires 1 or more args");
264 return success();
265}
266
267LogicalResult AddOp::verify() { return verifyUTBinOp(*this); }
268
269LogicalResult MulOp::verify() { return verifyUTBinOp(*this); }
270
271LogicalResult AndOp::verify() { return verifyUTBinOp(*this); }
272
273LogicalResult OrOp::verify() { return verifyUTBinOp(*this); }
274
275LogicalResult XorOp::verify() { return verifyUTBinOp(*this); }
276
277/// Return true if this is a two operand xor with an all ones constant as its
278/// RHS operand.
279bool XorOp::isBinaryNot() {
280 if (getNumOperands() != 2)
281 return false;
282 if (auto cst = getOperand(1).getDefiningOp<hw::ConstantOp>())
283 if (cst.getValue().isAllOnes())
284 return true;
285 return false;
286}
287
288//===----------------------------------------------------------------------===//
289// ConcatOp
290//===----------------------------------------------------------------------===//
291
292static unsigned getTotalWidth(ValueRange inputs) {
293 unsigned resultWidth = 0;
294 for (auto input : inputs) {
295 resultWidth += hw::type_cast<IntegerType>(input.getType()).getWidth();
296 }
297 return resultWidth;
298}
299
300LogicalResult ConcatOp::verify() {
301 unsigned tyWidth = cast<IntegerType>(getType()).getWidth();
302 unsigned operandsTotalWidth = getTotalWidth(getInputs());
303 if (tyWidth != operandsTotalWidth)
304 return emitOpError("ConcatOp requires operands total width to "
305 "match type width. operands "
306 "totalWidth is")
307 << operandsTotalWidth << ", but concatOp type width is " << tyWidth;
308
309 return success();
310}
311
312void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd,
313 ValueRange tl) {
314 result.addOperands(ValueRange{hd});
315 result.addOperands(tl);
316 unsigned hdWidth = cast<IntegerType>(hd.getType()).getWidth();
317 result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth));
318}
319
320LogicalResult ConcatOp::inferReturnTypes(
321 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
322 DictionaryAttr attrs, mlir::OpaqueProperties properties,
323 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
324 unsigned resultWidth = getTotalWidth(operands);
325 results.push_back(IntegerType::get(context, resultWidth));
326 return success();
327}
328
329//===----------------------------------------------------------------------===//
330// Other Operations
331//===----------------------------------------------------------------------===//
332
333LogicalResult ExtractOp::verify() {
334 unsigned srcWidth = cast<IntegerType>(getInput().getType()).getWidth();
335 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
336 if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth)
337 return emitOpError("from bit too large for input"), failure();
338
339 return success();
340}
341
342LogicalResult TruthTableOp::verify() {
343 size_t numInputs = getInputs().size();
344 if (numInputs >= sizeof(size_t) * 8)
345 return emitOpError("Truth tables support a maximum of ")
346 << sizeof(size_t) * 8 - 1 << " inputs on your platform";
347
348 ArrayAttr table = getLookupTable();
349 if (table.size() != (1ull << numInputs))
350 return emitOpError("Expected lookup table of 2^n length");
351 return success();
352}
353
354//===----------------------------------------------------------------------===//
355// TableGen generated logic.
356//===----------------------------------------------------------------------===//
357
358// Provide the autogenerated implementation guts for the Op classes.
359#define GET_OP_CLASSES
360#include "circt/Dialect/Comb/Comb.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static size_t getTotalWidth(ArrayRef< Value > operands)
static LogicalResult verifyUTBinOp(Operation *op)
Definition CombOps.cpp:261
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
Definition CombOps.cpp:48
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition CombOps.cpp:25
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1