CIRCT 22.0.0git
Loading...
Searching...
No Matches
DatapathOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements datapath ops.
10//
11//===----------------------------------------------------------------------===//
12
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/KnownBits.h"
19
20#define DEBUG_TYPE "datapath-ops"
21
22using namespace circt;
23using namespace datapath;
24
25LogicalResult CompressOp::verify() {
26 // The compressor must reduce the number of operands by at least 1 otherwise
27 // it fails to perform any reduction.
28 if (getNumOperands() < 3)
29 return emitOpError("requires 3 or more arguments - otherwise use add");
30
31 if (getNumResults() >= getNumOperands())
32 return emitOpError("must reduce the number of operands by at least 1");
33
34 if (getNumResults() < 2)
35 return emitOpError("must produce at least 2 results");
36
37 return success();
38}
39
40// Parser for the custom type format
41// Parser for "<input-type> [<num-inputs> -> <num-outputs>]"
42static ParseResult parseCompressFormat(OpAsmParser &parser,
43 SmallVectorImpl<Type> &inputTypes,
44 SmallVectorImpl<Type> &resultTypes) {
45
46 int64_t inputCount, resultCount;
47 Type inputElementType;
48
49 if (parser.parseType(inputElementType) || parser.parseLSquare() ||
50 parser.parseInteger(inputCount) || parser.parseArrow() ||
51 parser.parseInteger(resultCount) || parser.parseRSquare())
52 return failure();
53
54 // Inputs and results have same type
55 inputTypes.assign(inputCount, inputElementType);
56 resultTypes.assign(resultCount, inputElementType);
57
58 return success();
59}
60
61// Printer for "<input-type> [<num-inputs> -> <num-outputs>]"
62static void printCompressFormat(OpAsmPrinter &printer, Operation *op,
63 TypeRange inputTypes, TypeRange resultTypes) {
64
65 printer << inputTypes[0] << " [" << inputTypes.size() << " -> "
66 << resultTypes.size() << "]";
67}
68
69//===----------------------------------------------------------------------===//
70// Compressor Tree Logic.
71//===----------------------------------------------------------------------===//
72
73// Construct a full adder for three 1-bit inputs.
74std::pair<CompressorBit, CompressorBit>
77
78 auto aXorB = builder.createOrFold<comb::XorOp>(loc, a.val, b.val, true);
79 Value sumVal = builder.createOrFold<comb::XorOp>(loc, aXorB, c.val, true);
80
81 auto carryVal = builder.createOrFold<comb::OrOp>(
82 loc,
83 ArrayRef<Value>{
84 builder.createOrFold<comb::AndOp>(loc, a.val, b.val, true),
85 builder.createOrFold<comb::AndOp>(loc, aXorB, c.val, true)},
86 true);
87
88 auto sumDelay = std::max(std::max(a.delay, b.delay) + 1, c.delay) + 1;
89 auto carryDelay = sumDelay + 1;
90
91 CompressorBit sum = {sumVal, sumDelay};
92 CompressorBit carry = {carryVal, carryDelay};
93 std::pair<CompressorBit, CompressorBit> fa{sum, carry};
95 return fa;
96}
97
98// Construct a half adder for two 1-bit inputs.
99std::pair<CompressorBit, CompressorBit>
101 CompressorBit b) {
102 auto sumVal = builder.createOrFold<comb::XorOp>(loc, a.val, b.val, true);
103 auto carryVal = builder.createOrFold<comb::AndOp>(loc, a.val, b.val, true);
104
105 auto sumDelay = std::max(a.delay, b.delay) + 1;
106 auto carryDelay = sumDelay;
107
108 CompressorBit sum = {sumVal, sumDelay};
109 CompressorBit carry = {carryVal, carryDelay};
110 std::pair<CompressorBit, CompressorBit> ha{sum, carry};
111 return ha;
112}
113
114// Map input rows to column representation
116 const SmallVector<SmallVector<Value>> &addends,
117 Location loc)
118 : columns(width), width(width), numStages(0), numFullAdders(0), loc(loc) {
119 assert(addends.size() > 2);
120
121 // Convert addends rows to columns
122 // Known bits analysis constructs a minimal array - skipping zeros
123 for (auto row : addends) {
124 // Number of bits in a row == bitwidth of input addends
125 // Compressors will be formed of uniform bitwidth addends
126 assert(row.size() == width);
127 for (size_t i = 0; i < width; ++i) {
128 CompressorBit bit = {row[i], 0};
129 // TODO: Fold Constant 1s
130 auto knownBits = comb::computeKnownBits(bit.val);
131 if (knownBits.isZero())
132 continue;
133 // Add non-zero bit to the column
134 columns[i].push_back(bit);
135 }
136 }
137}
138
139// Update the input delays based on longest path analysis
141 llvm::function_ref<FailureOr<int64_t>(Value)> getDelay) {
142 for (auto &column : columns) {
143 for (auto &[value, result] : column) {
144 auto delay = getDelay(value);
145 if (failed(delay))
146 return failure();
147 result = *delay;
148 }
149 }
150 return success();
151}
152
// Return the height of the tallest column, i.e. the largest number of bits
// held in any single bit position (0 when every column is empty).
154 size_t maxSize = 0;
155 for (const auto &column : columns)
156 maxSize = std::max(maxSize, column.size());
157
158 return maxSize;
159}
160
161// Use Dadda's ALAP algorithm to determine the target height of the next stage
162// https://en.wikipedia.org/wiki/Dadda_multiplier
// Walks the Dadda height sequence d_1 = 2, d_{k+1} = floor(1.5 * d_k)
// (2, 3, 4, 6, 9, 13, 19, ...) and returns the largest element that is
// strictly less than the current maximum column height. When no element is
// below it (maximum height <= 2) the result is 2.
164 auto maxHeight = getMaxHeight();
165 size_t mPrev = 2;
166 while (true) {
167 size_t m = static_cast<size_t>(std::floor(1.5 * mPrev));
// Stop at the first sequence element that would not reduce the tree.
168 if (m >= maxHeight)
169 return mPrev;
170 mPrev = m;
171 }
172}
173
174// Convert back to a concatenated addend representation
175SmallVector<Value> CompressorTree::columnsToAddends(OpBuilder &builder,
176 size_t targetHeight) {
177 SmallVector<Value> addend;
178 SmallVector<Value> addends;
179 auto falseValue = hw::ConstantOp::create(builder, loc, APInt(1, 0));
180 for (size_t i = 0; i < targetHeight; ++i) {
181 // Pad with zeros
182 if (i >= getMaxHeight()) {
183 addends.push_back(hw::ConstantOp::create(builder, loc, APInt(width, 0)));
184 continue;
185 }
186 // Otherwise populate a addend formed from a concatenation
187 for (size_t j = 0; j < width; ++j) {
188 if (i < columns[j].size())
189 addend.push_back(columns[j][i].val);
190 else {
191 addend.push_back(falseValue);
192 }
193 }
194 std::reverse(addend.begin(), addend.end());
195 addends.push_back(comb::ConcatOp::create(builder, loc, addend));
196 addend.clear();
197 }
198 return addends;
199}
200
201// Perform recursive compression until reduced to the target height
202SmallVector<Value> CompressorTree::compressToHeight(OpBuilder &builder,
203 size_t targetHeight) {
204
205 auto maxHeight = getMaxHeight();
206
207 if (maxHeight <= targetHeight)
208 return columnsToAddends(builder, targetHeight);
209
210 return compressUsingTiming(builder, targetHeight);
211}
212
213// Perform recursive compression using timing information until reduced to the
214// target height - this currently uses Dadda's algorithm and timing driven
215// signal selection
216// TODO: Dadda's algorithm is redundant here since it assumes uniform arrival so
217// need to implement a more timing driven approach
218SmallVector<Value> CompressorTree::compressUsingTiming(OpBuilder &builder,
219 size_t targetHeight) {
220 while (getMaxHeight() > targetHeight) {
221 LLVM_DEBUG(dump(););
222 // Increment the number of reduction stages for debugging/reporting
223 ++numStages;
224
225 // Use Dadda's algorithm to compute next stage height
226 auto targetStageHeight = getNextStageTargetHeight();
227 // Initialize empty newColumns
228 SmallVector<SmallVector<CompressorBit>> newColumns(width);
229
230 for (size_t i = 0; i < width; ++i) {
231 auto col = columns[i];
232
233 // Sort the column by arrival time - fastest at the end
234 std::stable_sort(
235 col.begin(), col.end(),
236 [](const auto &a, const auto &b) { return a.delay > b.delay; });
237 // Only compress to reach the target stage height - Dadda's Algorithm
238 while (col.size() + newColumns[i].size() > targetStageHeight) {
239 if (col.size() < 2) {
240 llvm::errs() << "CompressorTree: Not enough bits in column " << i
241 << " to compress further.\n New Columns size: "
242 << newColumns[i].size()
243 << ", Current Column size: " << col.size() << "\n";
244 llvm::report_fatal_error(
245 "Expected at least two bits in compressor column");
246 }
247
248 auto bit0 = col.pop_back_val();
249 auto bit1 = col.pop_back_val();
250
251 // If we have an additional bit we can apply a full adder
252 if (col.size() >= 1) {
253 // bit2 can arrive 1 delay unit after bit0 and bit1 without delaying
254 // the full-adder
255 auto targetDelay = std::max(bit0.delay, bit1.delay) + 1;
256 CompressorBit bit2;
257
258 // Find the third bit of the full-adder that satisfies the delay
259 // constraint
260 auto it = std::find_if(col.begin(), col.end(),
261 [targetDelay](const auto &pair) {
262 return pair.delay <= targetDelay;
263 });
264
265 if (it != col.end()) {
266 bit2 = *it;
267 col.erase(it);
268 } else {
269 // If no bit satisfies the delay constraint pick the fastest one
270 bit2 = col.pop_back_val();
271 }
272 auto [sum, carry] = fullAdderWithDelay(builder, bit0, bit1, bit2);
273
274 newColumns[i].push_back(sum);
275 if (i + 1 < newColumns.size())
276 newColumns[i + 1].push_back(carry);
277 } else {
278 // Apply a half adder to bit0 and bit1
279 auto [sum, carry] = halfAdderWithDelay(builder, bit0, bit1);
280
281 newColumns[i].push_back(sum);
282 if (i + 1 < newColumns.size())
283 newColumns[i + 1].push_back(carry);
284 }
285 }
286
287 // Pass through remaining bits
288 newColumns[i].append(col);
289 }
290
291 // Compute another stage of reduction
292 columns = std::move(newColumns);
293 }
294 LLVM_DEBUG(dump(););
295 return columnsToAddends(builder, targetHeight);
296}
297
// Debug helper: print the tree's current state to llvm::dbgs(). The summary
// line reports the maximum column height, the full-adder and stage counters,
// and the next Dadda stage target. Below it is a grid with one column per bit
// position (highest column index on the left, indices printed as headers) and
// one row per height level; each occupied cell shows that bit's delay.
299 llvm::dbgs() << "Compressor Tree: Height = " << getMaxHeight()
300 << ", Number of FA = " << numFullAdders
301 << ", Number of Stages = " << numStages
302 << ", Next Stage Target = " << getNextStageTargetHeight()
303 << "\n";
304 // Print column headers
305 llvm::dbgs() << std::string(9, ' ');
306 for (size_t j = width; j > 0; --j) {
307 if (j < width)
308 llvm::dbgs() << " ";
// NOTE(review): "%02d" expects an int but `j - 1` is a size_t - a printf
// format/argument mismatch; consider "%02zu" or an explicit cast.
309 llvm::dbgs() << llvm::format("%02d", j - 1);
310 }
311 llvm::dbgs() << "\n"
312 << std::string(9, ' ') << std::string(width * 3, '-') << "\n";
313
// One grid row per height level; positions past a column's end print blank.
// NOTE(review): `i` is a size_t, and `delay` may not be an int either, yet both
// are printed with "%02d" - confirm the argument types match the format.
314 for (size_t i = 0; i < getMaxHeight(); ++i) {
315 llvm::dbgs() << " [" << llvm::format("%02d", i) << "]: [";
316 for (size_t j = width; j > 0; --j) {
317 if (j < width)
318 llvm::dbgs() << " ";
319 if (i < columns[j - 1].size())
320 llvm::dbgs() << llvm::format(
321 "%02d",
322 columns[j - 1][i].delay); // Print this bit's recorded delay.
323 else
324 llvm::dbgs() << " ";
325 }
326 llvm::dbgs() << "]\n";
327 }
328}
329
330//===----------------------------------------------------------------------===//
331// TableGen generated logic.
332//===----------------------------------------------------------------------===//
333
334// Provide the autogenerated implementation guts for the Op classes.
335#define GET_OP_CLASSES
336#include "circt/Dialect/Datapath/Datapath.cpp.inc"
assert(baseType &&"element must be base type")
static void printCompressFormat(OpAsmPrinter &printer, Operation *op, TypeRange inputTypes, TypeRange resultTypes)
static ParseResult parseCompressFormat(OpAsmParser &parser, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &resultTypes)
SmallVector< SmallVector< CompressorBit > > columns
Definition DatapathOps.h:58
SmallVector< Value > compressToHeight(OpBuilder &builder, size_t targetHeight)
SmallVector< Value > columnsToAddends(OpBuilder &builder, size_t targetHeight)
LogicalResult withInputDelays(llvm::function_ref< FailureOr< int64_t >(Value)> getDelay)
std::pair< CompressorBit, CompressorBit > halfAdderWithDelay(OpBuilder &builder, CompressorBit a, CompressorBit b)
std::pair< CompressorBit, CompressorBit > fullAdderWithDelay(OpBuilder &builder, CompressorBit a, CompressorBit b, CompressorBit c)
SmallVector< Value > compressUsingTiming(OpBuilder &builder, size_t targetHeight)
CompressorTree(size_t width, const SmallVector< SmallVector< Value > > &addends, Location loc)
create(data_type, value)
Definition hw.py:433
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.