//===- LowerLUT.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "arc-lower-lut"

namespace circt {
namespace arc {
#define GEN_PASS_DEF_LOWERLUT
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt

using namespace circt;
using namespace arc;

//===----------------------------------------------------------------------===//
// Data structures
//===----------------------------------------------------------------------===//

namespace {

/// Computes the constant lookup-table entries for a given LutOp operation and
/// caches the result. Also provides additional utility functions related to
/// lookup-table materialization.
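///
/// For illustration (textual IR shown schematically, not verbatim), a LUT
/// such as
///
///   %0 = arc.lut(%a, %b) : (i1, i1) -> i1 {
///   ^bb0(%arg0: i1, %arg1: i1):
///     %1 = comb.and %arg0, %arg1 : i1
///     arc.output %1 : i1
///   }
///
/// is evaluated for all four input combinations, yielding the cached entries
/// {1, 0, 0, 0} (entry for the highest index value first, matching the order
/// produced by `computeTableEntries`).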
class LutCalculator {
public:
  /// Compute all the lookup-table entries if they haven't already been
  /// computed and cache the results. Note that calling this function is very
  /// expensive in terms of runtime as it calls the constant folders of all
  /// operations inside the LutOp for all possible input values.
  LogicalResult computeTableEntries(LutOp lut);

  /// Get a reference to the cached lookup-table entries. `computeTableEntries`
  /// has to be called before calling this function.
  ArrayRef<IntegerAttr> getRefToTableEntries();
  /// Get a copy of the cached lookup-table entries. `computeTableEntries` has
  /// to be called before calling this function.
  void getCopyOfTableEntries(SmallVector<IntegerAttr> &tableEntries);
  /// Materialize uniqued hw::ConstantOp operations for all cached lookup-table
  /// entries. `computeTableEntries` has to be called before calling this
  /// function.
  void getTableEntriesAsConstValues(OpBuilder &builder,
                                    SmallVector<Value> &tableEntries);
  /// Compute and return the total size of the table in bits.
  uint32_t getTableSize();
  /// Compute and return the combined bit-width of all input values.
  uint32_t getInputBitWidth();

private:
  LutOp lut;
  SmallVector<IntegerAttr> table;
};

} // namespace

//===----------------------------------------------------------------------===//
// Data structure implementations
//===----------------------------------------------------------------------===//

// Note that this function is very expensive in terms of runtime since it
// computes the LUT entries by calling the operation's folders
// O(2^inputBitWidth) times.
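// For example, a LUT whose input bit-widths sum to 10 bits requires folding
// every operation in the body 1024 times.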
LogicalResult LutCalculator::computeTableEntries(LutOp lut) {
  // If we have already precomputed the entries for this LUT operation, we
  // don't need to recompute them. This is important because the dialect
  // conversion framework may try several lowering patterns for the same LutOp
  // one after another, and recomputing the table every time would be very
  // expensive.
  if (this->lut == lut && !table.empty())
    return success();

  // Cache this LUT to be able to apply the above shortcut next time and clear
  // the currently cached table entries from a previous LUT.
  this->lut = lut;
  table.clear();

  // Allocate storage for the values of all block arguments and operation
  // results.
  DenseMap<Value, SmallVector<Attribute>> vals;
  const uint32_t bw = getInputBitWidth();

  for (auto arg : lut.getBodyBlock()->getArguments())
    vals[arg] = SmallVector<Attribute>(1U << bw);

  for (auto &operation : lut.getBodyBlock()->without_terminator()) {
    for (auto result : operation.getResults()) {
      if (vals.count(result))
        continue;
      vals[result] = SmallVector<Attribute>(1U << bw);
    }
  }

  // Initialize inputs.
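  // The i-th table entry is computed from input index i, whose bits are
  // distributed over the block arguments MSB-first: for example, with
  // arguments (%arg0: i2, %arg1: i1) and bw = 3, index i assigns bits [2:1]
  // to %arg0 and bit [0] to %arg1.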
  for (int i = 0; i < (1 << bw); ++i) {
    const APInt input(bw, i);
    size_t offset = bw;
    for (auto arg : lut.getBodyBlock()->getArguments()) {
      const unsigned argBitWidth = arg.getType().getIntOrFloatBitWidth();
      offset -= argBitWidth;
      vals[arg][i] = IntegerAttr::get(arg.getType(),
                                      input.extractBits(argBitWidth, offset));
    }
  }

  for (auto &operation : lut.getBodyBlock()->without_terminator()) {
    // We need to rearrange the vectors to use the operation folders. There is
    // probably still some potential for optimization here.
    SmallVector<SmallVector<Attribute>, 8> constants(1U << bw);
    for (size_t j = 0, e = operation.getNumOperands(); j < e; ++j) {
      SmallVector<Attribute> &tmp = vals[operation.getOperand(j)];
      for (int i = (1U << bw) - 1; i >= 0; i--)
        constants[i].push_back(tmp[i]);
    }

    // Call the operation folders.
    SmallVector<SmallVector<OpFoldResult>, 8> results(
        1U << bw, SmallVector<OpFoldResult, 8>());
    for (int i = (1U << bw) - 1; i >= 0; i--) {
      if (failed(operation.fold(constants[i], results[i]))) {
        LLVM_DEBUG(llvm::dbgs() << "Failed to fold operation '";
                   operation.print(llvm::dbgs()); llvm::dbgs() << "'\n");
        return failure();
      }
    }

    // Store the folder's results in the value map. A fold may return either an
    // attribute (a constant) or an existing value; in the latter case, reuse
    // that value's already computed entries.
    for (size_t i = 0, e = operation.getNumResults(); i < e; ++i) {
      SmallVector<Attribute> &ref = vals[operation.getResult(i)];
      for (int j = (1U << bw) - 1; j >= 0; j--) {
        Attribute foldAttr;
        if (!(foldAttr = dyn_cast<Attribute>(results[j][i])))
          foldAttr = vals[llvm::cast<Value>(results[j][i])][j];
        ref[j] = foldAttr;
      }
    }
  }

  // Store the LUT's output values in the correct order in the table entry
  // cache.
  auto outValue = lut.getBodyBlock()->getTerminator()->getOperand(0);
  for (int j = (1U << bw) - 1; j >= 0; j--)
    table.push_back(cast<IntegerAttr>(vals[outValue][j]));

  return success();
}

ArrayRef<IntegerAttr> LutCalculator::getRefToTableEntries() { return table; }

void LutCalculator::getCopyOfTableEntries(
    SmallVector<IntegerAttr> &tableEntries) {
  tableEntries.append(table);
}

void LutCalculator::getTableEntriesAsConstValues(
    OpBuilder &builder, SmallVector<Value> &tableEntries) {
  // Since LUT entries tend to have a very small bit-width (mostly 1-3 bits),
  // there are many duplicate constants. Creating a single constant operation
  // per unique value saves a lot of CSE work afterwards.
  DenseMap<IntegerAttr, Value> map;
  for (auto entry : table) {
    if (!map.count(entry))
      map[entry] = builder.create<hw::ConstantOp>(lut.getLoc(), entry);

    tableEntries.push_back(map[entry]);
  }
}

uint32_t LutCalculator::getInputBitWidth() {
  unsigned bw = 0;
  for (auto val : lut.getInputs())
    bw += cast<IntegerType>(val.getType()).getWidth();
  return bw;
}

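// For example, a LUT with three i1 inputs and an i2 result occupies
// (1 << 3) * 2 = 16 bits of table storage.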
uint32_t LutCalculator::getTableSize() {
  return (1 << getInputBitWidth()) *
         lut.getOutput().getType().getIntOrFloatBitWidth();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower lookup-tables that have a total size of at most 256 bits to an
/// integer that is shifted and truncated according to the lookup/index value.
/// Encoding the lookup table as an intermediate value in the instruction
/// stream should provide better performance than loading it from some global
/// constant.
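///
/// For example, a LUT with a 2-bit index, an i1 result, and table entries
/// {1, 0, 0, 0} (entry for the highest index first) is packed into the 4-bit
/// constant 0b1000. A lookup for index value %i then amounts to
///
///   extract(shru(0b1000, zext(%i) * resultWidth), 0, resultWidth)
///
/// i.e. shift the packed constant right by the index times the result width
/// and truncate to the result width. (Pseudo-notation for illustration only.)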
struct LutToInteger : OpConversionPattern<LutOp> {
  LutToInteger(LutCalculator &calculator, MLIRContext *context)
      : OpConversionPattern<LutOp>(context), lutCalculator(calculator) {}

  LogicalResult
  matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(lutCalculator.computeTableEntries(lut)))
      return failure();

    const uint32_t tableSize = lutCalculator.getTableSize();
    const uint32_t inputBw = lutCalculator.getInputBitWidth();

    if (tableSize > 256)
      return failure();

    // Concatenate the lookup table entries into a single integer.
    auto constants = lutCalculator.getRefToTableEntries();
    APInt result(tableSize, 0);
    unsigned nextInsertion = tableSize;

    for (auto attr : constants) {
      auto chunk = attr.getValue();
      nextInsertion -= chunk.getBitWidth();
      result.insertBits(chunk, nextInsertion);
    }
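
    // The entries are packed MSB-first: the entry for the highest index value
    // ends up in the topmost bits and the entry for index 0 at bit 0, so
    // shifting right by `index * resultWidth` moves the selected entry down
    // to bit 0.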
    Value table = rewriter.create<hw::ConstantOp>(lut.getLoc(), result);

    // Zero-extend the lookup/index value to the same bit-width as the table,
    // because the shift operation requires both operands to have the same
    // bit-width.
    Value zextValue = rewriter.create<hw::ConstantOp>(
        lut->getLoc(), rewriter.getIntegerType(tableSize - inputBw), 0);
    Value entryOffset = rewriter.create<comb::ConcatOp>(lut.getLoc(), zextValue,
                                                        lut.getInputs());
    Value resultBitWidth = rewriter.create<hw::ConstantOp>(
        lut.getLoc(), entryOffset.getType(),
        lut.getResult().getType().getIntOrFloatBitWidth());
    Value lookupValue =
        rewriter.create<comb::MulOp>(lut.getLoc(), entryOffset, resultBitWidth);

    // Shift the table and truncate to the bit-width of the output value.
    Value shiftedTable =
        rewriter.create<comb::ShrUOp>(lut->getLoc(), table, lookupValue);
    const Value extracted = rewriter.create<comb::ExtractOp>(
        lut.getLoc(), shiftedTable, 0,
        lut.getOutput().getType().getIntOrFloatBitWidth());

    rewriter.replaceOp(lut, extracted);
    return success();
  }

  LutCalculator &lutCalculator;
};

/// Lower lookup-tables with a total size of more than 256 bits to a constant
/// array that can be stored as constant global data; a lookup then amounts to
/// a memory load at the correct offset into that data.
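///
/// For example, a LUT with a combined 6-bit index and an i8 result (a 512-bit
/// table) becomes an `hw.aggregate_constant` of type `!hw.array<64xi8>` that
/// is indexed by an `hw.array_get` whose index is the concatenation of the
/// LUT's inputs.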
struct LutToArray : OpConversionPattern<LutOp> {
  LutToArray(LutCalculator &calculator, MLIRContext *context)
      : OpConversionPattern<LutOp>(context), lutCalculator(calculator) {}

  LogicalResult
  matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(lutCalculator.computeTableEntries(lut)))
      return failure();

    auto constants = lutCalculator.getRefToTableEntries();
    SmallVector<Attribute> constantAttrs(constants.begin(), constants.end());
    auto tableSize = lutCalculator.getTableSize();
    auto inputBw = lutCalculator.getInputBitWidth();

    if (tableSize <= 256)
      return failure();

    Value table = rewriter.create<hw::AggregateConstantOp>(
        lut.getLoc(), hw::ArrayType::get(lut.getType(), constantAttrs.size()),
        rewriter.getArrayAttr(constantAttrs));
    Value lookupValue = rewriter.create<comb::ConcatOp>(
        lut.getLoc(), rewriter.getIntegerType(inputBw), lut.getInputs());
    const Value extracted =
        rewriter.create<hw::ArrayGetOp>(lut.getLoc(), table, lookupValue);

    rewriter.replaceOp(lut, extracted);
    return success();
  }

  LutCalculator &lutCalculator;
};

} // namespace

//===----------------------------------------------------------------------===//
// Lower LUT pass
//===----------------------------------------------------------------------===//

namespace {

/// Lower LutOp operations to comb and hw operations.
struct LowerLUTPass : public arc::impl::LowerLUTBase<LowerLUTPass> {
  void runOnOperation() override;
};

} // namespace

void LowerLUTPass::runOnOperation() {
  MLIRContext &context = getContext();
  ConversionTarget target(context);
  RewritePatternSet patterns(&context);
  target.addLegalDialect<comb::CombDialect, hw::HWDialect, arc::ArcDialect>();
  target.addIllegalOp<arc::LutOp>();

  // TODO: This class could be factored out into an analysis if there is a need
  // to access precomputed lookup-tables in some other pass.
  LutCalculator lutCalculator;
  patterns.add<LutToInteger, LutToArray>(lutCalculator, &context);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> arc::createLowerLUTPass() {
  return std::make_unique<LowerLUTPass>();
}