16#include "llvm/Support/Debug.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/KnownBits.h"
20#define DEBUG_TYPE "datapath-ops"
23using namespace datapath;
// Verifier for CompressOp. The visible checks reject, in order:
//   * fewer than three operands,
//   * a result count that is not strictly smaller than the operand count,
//   * fewer than two results.
// NOTE(review): this listing has source lines elided; the final return and the
// closing brace of the function are not visible here.
25LogicalResult CompressOp::verify() {
// With fewer than three inputs the diagnostic points the user at `add`.
28 if (getNumOperands() < 3)
29 return emitOpError(
"requires 3 or more arguments - otherwise use add");
// The op must produce strictly fewer results than it has operands.
31 if (getNumResults() >= getNumOperands())
32 return emitOpError(
"must reduce the number of operands by at least 1");
// A result count below two is rejected.
34 if (getNumResults() < 2)
35 return emitOpError(
"must produce at least 2 results");
// Tail of the parameter list of the custom-format parser (the first line of
// the signature is elided from this listing). The visible body parses
//   <element-type> `[` <inputCount> `->` <resultCount> `]`
// and fills both output type lists with copies of the single parsed type.
43 SmallVectorImpl<Type> &inputTypes,
44 SmallVectorImpl<Type> &resultTypes) {
// Both counts are signed 64-bit and are declared without an initializer; the
// only visible writes are the parseInteger calls below.
46 int64_t inputCount, resultCount;
47 Type inputElementType;
// The parse steps are chained with ||, so evaluation stops at the first step
// that reports failure. The statement guarded by this condition is elided.
49 if (parser.parseType(inputElementType) || parser.parseLSquare() ||
50 parser.parseInteger(inputCount) || parser.parseArrow() ||
51 parser.parseInteger(resultCount) || parser.parseRSquare())
// NOTE(review): no range check on the counts is visible in this listing. A
// negative count would convert to a very large unsigned size when passed to
// assign() - confirm whether elided code or the op verifier guards this.
55 inputTypes.assign(inputCount, inputElementType);
56 resultTypes.assign(resultCount, inputElementType);
// Tail of the parameter list of the custom-format printer (the first line of
// the signature is elided from this listing).
// Emits: <inputTypes[0]> " [" <number of inputs> " -> " <number of results> "]".
63 TypeRange inputTypes, TypeRange resultTypes) {
// NOTE(review): inputTypes[0] is read unconditionally, so this assumes at
// least one input type is present - confirm the printer is never reached
// with an empty operand list.
65 printer << inputTypes[0] <<
" [" << inputTypes.size() <<
" -> "
66 << resultTypes.size() <<
"]";
// Fragment of a helper that returns a (CompressorBit, CompressorBit) pair -
// presumably fullAdderWithDelay; its parameter list and most of its body are
// elided from this listing.
74std::pair<CompressorBit, CompressorBit>
// The carry value is produced through createOrFold with a comb::OrOp (the
// operands of the OR are not visible here).
81 auto carryVal = builder.createOrFold<
comb::OrOp>(
// The carry output is assigned a delay one unit larger than the sum output.
89 auto carryDelay = sumDelay + 1;
// The pair is ordered {sum, carry}.
93 std::pair<CompressorBit, CompressorBit> fa{sum, carry};
// Fragment of a helper that returns a (CompressorBit, CompressorBit) pair -
// presumably halfAdderWithDelay; its parameter list and the creation of the
// sum/carry values are elided from this listing.
99std::pair<CompressorBit, CompressorBit>
// The sum delay is one unit after the later of the two inputs.
105 auto sumDelay = std::max(a.
delay, b.
delay) + 1;
// Here the carry is given the same delay as the sum.
106 auto carryDelay = sumDelay;
// The pair is ordered {sum, carry}.
110 std::pair<CompressorBit, CompressorBit> ha{sum, carry};
// Fragment of the CompressorTree constructor (the first line of the signature
// and several body lines are elided from this listing).
116 const SmallVector<SmallVector<Value>> &addends,
// `columns` is created with `width` entries; the stage and full-adder counters
// start at zero.
118 : columns(width), width(width), numStages(0), numFullAdders(0), loc(loc) {
// At least three addend rows are required (checked only when asserts are on).
119 assert(addends.size() > 2);
// NOTE(review): `auto row` copies each inner SmallVector<Value> on every
// iteration; a const reference would avoid the copy if the elided body does
// not modify `row` - confirm against the full source.
123 for (
auto row : addends) {
// Walks the bit positions 0 .. width-1 of the current row.
127 for (
size_t i = 0; i <
width; ++i) {
// Branches on the known-bits result being all-zero; the action taken is not
// visible in this listing.
131 if (knownBits.isZero())
// Fragment of withInputDelays (first line of the signature and the outer loop
// are elided). `getDelay` is a caller-supplied callback that maps a Value to
// a FailureOr<int64_t>.
141 llvm::function_ref<FailureOr<int64_t>(Value)> getDelay) {
// Each column entry is bound by reference and destructured into two members.
143 for (
auto &[value, result] : column) {
// Queries the callback for this entry's value; how a failed lookup is handled
// is not visible in this listing.
144 auto delay = getDelay(value);
// Fragment, presumably of getMaxHeight (signature elided): tracks the largest
// column size seen across all columns in `maxSize`.
155 for (
const auto &column :
columns)
156 maxSize = std::max(maxSize, column.size());
// Fragment, presumably of getNextStageTargetHeight (surrounding lines elided):
// computes m as 1.5 * mPrev rounded down, converted back to size_t.
167 size_t m =
static_cast<size_t>(std::floor(1.5 * mPrev));
// Fragment of a function taking `targetHeight` - presumably columnsToAddends
// (first line of the signature elided). The visible code builds one
// concatenated value per row index i in [0, targetHeight), taking one bit from
// every column j in [0, width).
176 size_t targetHeight) {
// `addend` collects the bits of one row; `addends` collects the finished rows.
// NOTE(review): `addend` is declared outside the row loop and no reset is
// visible in this listing - confirm it is cleared per row in the elided lines.
177 SmallVector<Value> addend;
178 SmallVector<Value> addends;
180 for (
size_t i = 0; i < targetHeight; ++i) {
187 for (
size_t j = 0; j <
width; ++j) {
// Either the i-th bit of column j is used ...
189 addend.push_back(
columns[j][i].val);
// ... or `falseValue` is pushed instead (the selecting condition is elided;
// presumably columns shorter than i+1 are padded this way).
191 addend.push_back(falseValue);
// Bits were collected from column 0 upwards; the order is reversed before the
// concat is created.
194 std::reverse(addend.begin(), addend.end());
195 addends.push_back(comb::ConcatOp::create(builder,
loc, addend));
// Fragment of a function taking `targetHeight` - presumably compressToHeight
// (first line of the signature and most of the body are elided).
203 size_t targetHeight) {
// Guard for the case where the current maximum height already meets the
// target; the guarded statement is not visible in this listing.
207 if (maxHeight <= targetHeight)
// Fragment of a function taking `targetHeight` - presumably
// compressUsingTiming (first line of the signature and many body lines are
// elided). The visible code rebuilds the columns one stage at a time: bits are
// removed from each column, combined by adders, and the sum/carry outputs are
// placed into `newColumns`, which then replaces `columns`.
219 size_t targetHeight) {
// One output column per bit position.
228 SmallVector<SmallVector<CompressorBit>> newColumns(
width);
230 for (
size_t i = 0; i <
width; ++i) {
// Arguments of a call whose first line is elided (presumably a sort): the
// comparator orders entries by descending delay, which would leave the
// smallest delays at the back of `col`.
235 col.begin(), col.end(),
236 [](
const auto &a,
const auto &b) { return a.delay > b.delay; });
// Keep compressing while the bits still in `col` plus those already placed in
// the output column exceed `targetStageHeight` (defined in elided lines).
238 while (col.size() + newColumns[i].size() > targetStageHeight) {
// Fewer than two remaining bits is treated as unrecoverable: the sizes are
// written to stderr and the process is aborted through report_fatal_error.
239 if (col.size() < 2) {
240 llvm::errs() <<
"CompressorTree: Not enough bits in column " << i
241 <<
" to compress further.\n New Columns size: "
242 << newColumns[i].size()
243 <<
", Current Column size: " << col.size() <<
"\n";
244 llvm::report_fatal_error(
245 "Expected at least two bits in compressor column");
// Remove the two bits at the back of the column.
248 auto bit0 = col.pop_back_val();
249 auto bit1 = col.pop_back_val();
// A third bit is only considered when one remains in the column.
252 if (col.size() >= 1) {
// Delay threshold: one unit after the later of the two removed bits.
255 auto targetDelay = std::max(bit0.delay, bit1.delay) + 1;
// Find the first remaining entry whose delay does not exceed that threshold.
260 auto it = std::find_if(col.begin(), col.end(),
261 [targetDelay](
const auto &pair) {
262 return pair.delay <= targetDelay;
265 if (it != col.end()) {
// Alternative path (its condition is elided): take the bit at the back.
270 bit2 = col.pop_back_val();
// The sum stays in column i; the carry goes to column i + 1 and is dropped
// when that column would fall outside `newColumns`.
274 newColumns[i].push_back(sum);
275 if (i + 1 < newColumns.size())
276 newColumns[i + 1].push_back(carry);
// Same placement on a second path (presumably the two-input adder case).
281 newColumns[i].push_back(sum);
282 if (i + 1 < newColumns.size())
283 newColumns[i + 1].push_back(carry);
// Bits still left in `col` are carried over to the output column unchanged.
288 newColumns[i].append(col);
// Install the newly built stage as the current column set.
292 columns = std::move(newColumns);
// Fragment of a debug-printing routine (signature elided; presumably a dump
// method of the compressor tree). It writes to llvm::dbgs(): the current
// maximum height, a header row of column indices, a dashed rule, and then one
// bracketed row per index i.
299 llvm::dbgs() <<
"Compressor Tree: Height = " <<
getMaxHeight()
// Nine spaces of left padding before the header row.
305 llvm::dbgs() << std::string(9,
' ');
// Counts j down from width to 1 and prints j - 1, so the highest column index
// is printed first without decrementing an unsigned value below zero.
306 for (
size_t j =
width; j > 0; --j) {
// NOTE(review): "%02d" is used with a size_t argument (j - 1); a matching
// specifier such as "%02zu" may be intended - confirm.
309 llvm::dbgs() << llvm::format(
"%02d", j - 1);
// Rule under the header: nine spaces, then three dashes per column.
312 << std::string(9,
' ') << std::string(
width * 3,
'-') <<
"\n";
// Row prefix with the row index i (the declaration of i is elided).
315 llvm::dbgs() <<
" [" << llvm::format(
"%02d", i) <<
"]: [";
// Same high-to-low column order as the header row.
316 for (
size_t j =
width; j > 0; --j) {
320 llvm::dbgs() << llvm::format(
// Closes the row.
326 llvm::dbgs() <<
"]\n";
335#define GET_OP_CLASSES
336#include "circt/Dialect/Datapath/Datapath.cpp.inc"
assert(baseType &&"element must be base type")
static void printCompressFormat(OpAsmPrinter &printer, Operation *op, TypeRange inputTypes, TypeRange resultTypes)
static ParseResult parseCompressFormat(OpAsmParser &parser, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &resultTypes)
size_t getNextStageTargetHeight() const
SmallVector< SmallVector< CompressorBit > > columns
SmallVector< Value > compressToHeight(OpBuilder &builder, size_t targetHeight)
SmallVector< Value > columnsToAddends(OpBuilder &builder, size_t targetHeight)
LogicalResult withInputDelays(llvm::function_ref< FailureOr< int64_t >(Value)> getDelay)
std::pair< CompressorBit, CompressorBit > halfAdderWithDelay(OpBuilder &builder, CompressorBit a, CompressorBit b)
std::pair< CompressorBit, CompressorBit > fullAdderWithDelay(OpBuilder &builder, CompressorBit a, CompressorBit b, CompressorBit c)
SmallVector< Value > compressUsingTiming(OpBuilder &builder, size_t targetHeight)
size_t getMaxHeight() const
CompressorTree(size_t width, const SmallVector< SmallVector< Value > > &addends, Location loc)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.