CIRCT  19.0.0git
LowerLUT.cpp
Go to the documentation of this file.
1 //===- LowerLUT.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "circt/Dialect/HW/HWOps.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "arc-lower-lut"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_LOWERLUT
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 
29 //===----------------------------------------------------------------------===//
30 // Data structures
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 
35 /// Allows to compute the constant lookup-table entries given the LutOp
36 /// operation and caches the result. Also provides additional utility functions
37 /// related to lookup-table materialization.
38 class LutCalculator {
39 public:
40  /// Compute all the lookup-table enties if they haven't already been computed
41  /// and cache the results. Note that calling this function is very expensive
42  /// in terms of runtime as it calls the constant folders of all operations
43  /// inside the LutOp for all possible input values.
44  LogicalResult computeTableEntries(LutOp lut);
45 
46  /// Get a reference to the cached lookup-table entries. `computeTableEntries`
47  /// has to be called before calling this function.
48  ArrayRef<IntegerAttr> getRefToTableEntries();
49  /// Get a copy of the cached lookup-table entries. `computeTableEntries` has
50  /// to be called before calling this function.
51  void getCopyOfTableEntries(SmallVector<IntegerAttr> &tableEntries);
52  /// Materialize uniqued hw::ConstantOp operations for all cached lookup-table
53  /// entries. `computeTableEntries` has to be called before calling this
54  /// function.
55  void getTableEntriesAsConstValues(OpBuilder &builder,
56  SmallVector<Value> &tableEntries);
57  /// Compute and return the total size of the table in bits.
58  uint32_t getTableSize();
59  /// Compute and return the summed up bit-width of all input values.
60  uint32_t getInputBitWidth();
61 
62 private:
63  LutOp lut;
64  SmallVector<IntegerAttr> table;
65 };
66 
/// A wrapper around ConversionPattern that matches specifically on LutOp
/// operations and holds a LutCalculator reference that allows computing the
/// lookup-table entries and caching the result. The calculator is shared by
/// reference between the patterns it is passed to, and is updated from the
/// const matching methods below through that reference.
class LutLoweringPattern : public ConversionPattern {
public:
  LutLoweringPattern(LutCalculator &lutCalculator, MLIRContext *context,
                     mlir::PatternBenefit benefit = 1)
      : ConversionPattern(LutOp::getOperationName(), benefit, context),
        lutCalculator(lutCalculator) {}
  LutLoweringPattern(LutCalculator &lutCalculator, TypeConverter &typeConverter,
                     MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ConversionPattern(typeConverter, LutOp::getOperationName(), benefit,
                          context),
        lutCalculator(lutCalculator) {}

  /// Wrappers around the ConversionPattern methods that pass the LutOp
  /// type. `match` and `matchAndRewrite` first bring the LutCalculator
  /// up-to-date for the matched op and fail if its entries cannot be computed.
  LogicalResult match(Operation *op) const final {
    auto lut = cast<LutOp>(op);
    if (failed(lutCalculator.computeTableEntries(lut)))
      return failure();
    return match(lut);
  }
  /// Note: unlike the other two wrappers this one does not recompute the table
  /// entries; it relies on a preceding successful `match` of the same op.
  void rewrite(Operation *op, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const final {
    rewrite(cast<LutOp>(op), LutOpAdaptor(operands, op->getAttrDictionary()),
            rewriter);
  }
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto lut = cast<LutOp>(op);
    if (failed(lutCalculator.computeTableEntries(lut)))
      return failure();
    return matchAndRewrite(lut, LutOpAdaptor(operands, op->getAttrDictionary()),
                           rewriter);
  }

  /// Rewrite and Match methods that operate on the LutOp type. A derived
  /// pattern must override either `matchAndRewrite` or both `match` and
  /// `rewrite`; the defaults of the latter two abort when reached.
  virtual LogicalResult match(LutOp op) const {
    llvm_unreachable("must override match or matchAndRewrite");
  }
  virtual void rewrite(LutOp op, LutOpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter) const {
    llvm_unreachable("must override matchAndRewrite or a rewrite method");
  }
  virtual LogicalResult
  matchAndRewrite(LutOp op, LutOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    if (failed(match(op)))
      return failure();
    rewrite(op, adaptor, rewriter);
    return success();
  }

protected:
  /// Shared cache of lookup-table entries; see the class comment.
  LutCalculator &lutCalculator;

private:
  // Bring the base-class overloads into scope (as private) so they are not
  // hidden by the LutOp overloads above — presumably to silence
  // -Woverloaded-virtual; NOTE(review): confirm intent.
  using ConversionPattern::matchAndRewrite;
};
129 
130 } // namespace
131 
132 //===----------------------------------------------------------------------===//
133 // Data structure implementations
134 //===----------------------------------------------------------------------===//
135 
136 // Note that this function is very expensive in terms of runtime since it
137 // computes the LUT entries by calling the operation's folders
138 // O(2^inputBitWidth) times.
139 LogicalResult LutCalculator::computeTableEntries(LutOp lut) {
140  // If we already have precomputed the entries for this LUT operation, we don't
141  // need to re-compute it. This is important, because the dialect conversion
142  // framework may try several lowering patterns for the same LutOp after
143  // another and recomputing it every time would be very expensive.
144  if (this->lut == lut && !table.empty())
145  return success();
146 
147  // Cache this LUT to be able to apply above shortcut next time and clear the
148  // currently cached table entries from a previous LUT.
149  this->lut = lut;
150  table.clear();
151 
152  // Allocate memory
153  DenseMap<Value, SmallVector<Attribute>> vals;
154  const uint32_t bw = getInputBitWidth();
155 
156  for (auto arg : lut.getBodyBlock()->getArguments())
157  vals[arg] = SmallVector<Attribute>(1U << bw);
158 
159  for (auto &operation : lut.getBodyBlock()->without_terminator()) {
160  for (auto operand : operation.getResults()) {
161  if (vals.count(operand))
162  continue;
163  vals[operand] = SmallVector<Attribute>(1U << bw);
164  }
165  }
166 
167  // Initialize inputs
168  for (int i = 0; i < (1 << bw); ++i) {
169  const APInt input(bw, i);
170  size_t offset = bw;
171  for (auto arg : lut.getBodyBlock()->getArguments()) {
172  const unsigned argBitWidth = arg.getType().getIntOrFloatBitWidth();
173  offset -= argBitWidth;
174  vals[arg][i] = IntegerAttr::get(arg.getType(),
175  input.extractBits(argBitWidth, offset));
176  }
177  }
178 
179  for (auto &operation : lut.getBodyBlock()->without_terminator()) {
180  // We need to rearange the vectors to use the operation folers. There is
181  // probably still some potential for optimization here.
182  SmallVector<SmallVector<Attribute>, 8> constants(1U << bw);
183  for (size_t j = 0, e = operation.getNumOperands(); j < e; ++j) {
184  SmallVector<Attribute> &tmp = vals[operation.getOperand(j)];
185  for (int i = (1U << bw) - 1; i >= 0; i--)
186  constants[i].push_back(tmp[i]);
187  }
188 
189  // Call the operation folders
190  SmallVector<SmallVector<OpFoldResult>, 8> results(
191  1U << bw, SmallVector<OpFoldResult, 8>());
192  for (int i = (1U << bw) - 1; i >= 0; i--) {
193  if (failed(operation.fold(constants[i], results[i]))) {
194  LLVM_DEBUG(llvm::dbgs() << "Failed to fold operation '";
195  operation.print(llvm::dbgs()); llvm::dbgs() << "'\n");
196  return failure();
197  }
198  }
199 
200  // Store the folder's results in the value map.
201  for (size_t i = 0, e = operation.getNumResults(); i < e; ++i) {
202  SmallVector<Attribute> &ref = vals[operation.getResult(i)];
203  for (int j = (1U << bw) - 1; j >= 0; j--) {
204  Attribute foldAttr;
205  if (!(foldAttr = results[j][i].dyn_cast<Attribute>()))
206  foldAttr = vals[results[j][i].get<Value>()][j];
207  ref[j] = foldAttr;
208  }
209  }
210  }
211 
212  // Store the LUT's output values in the correct order in the table entry
213  // cache.
214  auto outValue = lut.getBodyBlock()->getTerminator()->getOperand(0);
215  for (int j = (1U << bw) - 1; j >= 0; j--)
216  table.push_back(vals[outValue][j].cast<IntegerAttr>());
217 
218  return success();
219 }
220 
221 ArrayRef<IntegerAttr> LutCalculator::getRefToTableEntries() { return table; }
222 
223 void LutCalculator::getCopyOfTableEntries(
224  SmallVector<IntegerAttr> &tableEntries) {
225  tableEntries.append(table);
226 }
227 
228 void LutCalculator::getTableEntriesAsConstValues(
229  OpBuilder &builder, SmallVector<Value> &tableEntries) {
230  // Since LUT entries tend to have a very small bit-width (mostly 1-3 bits),
231  // there are many duplicate constants. Creating a single constant operation
232  // for each unique number saves us a lot of CSE afterwards.
233  DenseMap<IntegerAttr, Value> map;
234  for (auto entry : table) {
235  if (!map.count(entry))
236  map[entry] = builder.create<hw::ConstantOp>(lut.getLoc(), entry);
237 
238  tableEntries.push_back(map[entry]);
239  }
240 }
241 
242 uint32_t LutCalculator::getInputBitWidth() {
243  unsigned bw = 0;
244  for (auto val : lut.getInputs())
245  bw += val.getType().cast<IntegerType>().getWidth();
246  return bw;
247 }
248 
249 uint32_t LutCalculator::getTableSize() {
250  return (1 << getInputBitWidth()) *
251  lut.getOutput().getType().getIntOrFloatBitWidth();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // Conversion patterns
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 
260 /// Lower lookup-tables that have a total size of less than 256 bits to an
261 /// integer that is shifed and truncated according to the lookup/index value.
262 /// Encoding the lookup tables as intermediate values in the instruction stream
263 /// should provide better performnace than loading from some global constant.
264 struct LutToInteger : LutLoweringPattern {
265  using LutLoweringPattern::LutLoweringPattern;
266 
267  LogicalResult
268  matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
269  ConversionPatternRewriter &rewriter) const final {
270 
271  const uint32_t tableSize = lutCalculator.getTableSize();
272  const uint32_t inputBw = lutCalculator.getInputBitWidth();
273 
274  if (tableSize > 256)
275  return failure();
276 
277  // Concatenate the lookup table entries to a single integer.
278  auto constants = lutCalculator.getRefToTableEntries();
279  APInt result(tableSize, 0);
280  unsigned nextInsertion = tableSize;
281 
282  for (auto attr : constants) {
283  auto chunk = attr.getValue();
284  nextInsertion -= chunk.getBitWidth();
285  result.insertBits(chunk, nextInsertion);
286  }
287 
288  Value table = rewriter.create<hw::ConstantOp>(lut.getLoc(), result);
289 
290  // Zero-extend the lookup/index value to the same bit-width as the table,
291  // because the shift operation requires both operands to have the same
292  // bit-width.
293  Value zextValue = rewriter.create<hw::ConstantOp>(
294  lut->getLoc(), rewriter.getIntegerType(tableSize - inputBw), 0);
295  Value entryOffset = rewriter.create<comb::ConcatOp>(lut.getLoc(), zextValue,
296  lut.getInputs());
297  Value resultBitWidth = rewriter.create<hw::ConstantOp>(
298  lut.getLoc(), entryOffset.getType(),
299  lut.getResult().getType().getIntOrFloatBitWidth());
300  Value lookupValue =
301  rewriter.create<comb::MulOp>(lut.getLoc(), entryOffset, resultBitWidth);
302 
303  // Shift the table and truncate to the bitwidth of the output value.
304  Value shiftedTable =
305  rewriter.create<comb::ShrUOp>(lut->getLoc(), table, lookupValue);
306  const Value extracted = rewriter.create<comb::ExtractOp>(
307  lut.getLoc(), shiftedTable, 0,
308  lut.getOutput().getType().getIntOrFloatBitWidth());
309 
310  rewriter.replaceOp(lut, extracted);
311  return success();
312  }
313 };
314 
315 /// Lower lookup-tables with a total size bigger than 256 bits to a constant
316 /// array that is stored as constant global data and thus a lookup consists of a
317 /// memory load at the correct offset of that global data frame.
318 struct LutToArray : LutLoweringPattern {
319  using LutLoweringPattern::LutLoweringPattern;
320 
321  LogicalResult
322  matchAndRewrite(LutOp lut, LutOpAdaptor adaptor,
323  ConversionPatternRewriter &rewriter) const final {
324  auto constants = lutCalculator.getRefToTableEntries();
325  SmallVector<Attribute> constantAttrs(constants.begin(), constants.end());
326  auto tableSize = lutCalculator.getTableSize();
327  auto inputBw = lutCalculator.getInputBitWidth();
328 
329  if (tableSize <= 256)
330  return failure();
331 
332  Value table = rewriter.create<hw::AggregateConstantOp>(
333  lut.getLoc(), hw::ArrayType::get(lut.getType(), constantAttrs.size()),
334  rewriter.getArrayAttr(constantAttrs));
335  Value lookupValue = rewriter.create<comb::ConcatOp>(
336  lut.getLoc(), rewriter.getIntegerType(inputBw), lut.getInputs());
337  const Value extracted =
338  rewriter.create<hw::ArrayGetOp>(lut.getLoc(), table, lookupValue);
339 
340  rewriter.replaceOp(lut, extracted);
341  return success();
342  }
343 };
344 
345 } // namespace
346 
347 //===----------------------------------------------------------------------===//
348 // Lower LUT pass
349 //===----------------------------------------------------------------------===//
350 
351 namespace {
352 
353 /// Lower LutOp operations to comb and hw operations.
354 struct LowerLUTPass : public arc::impl::LowerLUTBase<LowerLUTPass> {
355  void runOnOperation() override;
356 };
357 
358 } // namespace
359 
360 void LowerLUTPass::runOnOperation() {
361  MLIRContext &context = getContext();
362  ConversionTarget target(context);
363  RewritePatternSet patterns(&context);
364  target.addLegalDialect<comb::CombDialect, hw::HWDialect, arc::ArcDialect>();
365  target.addIllegalOp<arc::LutOp>();
366 
367  // TODO: This class could be factored out into an analysis if there is a need
368  // to access precomputed lookup-tables in some other pass.
369  LutCalculator lutCalculator;
370  patterns.add<LutToInteger, LutToArray>(lutCalculator, &context);
371 
372  if (failed(
373  applyPartialConversion(getOperation(), target, std::move(patterns))))
374  signalPassFailure();
375 }
376 
377 std::unique_ptr<Pass> arc::createLowerLUTPass() {
378  return std::make_unique<LowerLUTPass>();
379 }
Builder builder
def create(data_type, value)
Definition: hw.py:393
std::unique_ptr< mlir::Pass > createLowerLUTPass()
Definition: LowerLUT.cpp:377
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21