CIRCT  19.0.0git
MakeTables.cpp
Go to the documentation of this file.
1 //===- MakeTables.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "circt/Dialect/HW/HWOps.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/Pass/Pass.h"
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "arc-lookup-tables"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_MAKETABLES
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 using namespace hw;
29 
30 namespace {
31 
32 static constexpr int tableMinOpCount = 20;
33 static constexpr int tableMaxSize = 32768; // bits
34 
35 struct MakeTablesPass : public arc::impl::MakeTablesBase<MakeTablesPass> {
36  void runOnOperation() override;
37  void runOnArc(DefineOp defineOp);
38 };
39 } // namespace
40 
/// Return a 32-bit mask with the low `nbits` bits set.
///
/// `nbits == 0` yields 0; any width of 32 or more yields an all-ones mask.
static inline uint32_t bitsMask(uint32_t nbits) {
  // Shifting a 32-bit operand by 32 or more is undefined, so saturate here.
  if (nbits >= 32)
    return ~uint32_t(0);
  // Compute in unsigned arithmetic: with a signed `1`, `(1 << 31) - 1`
  // overflows `int`, which is undefined behavior.
  return (uint32_t(1) << nbits) - 1U;
}
46 
47 static inline uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub) {
48  return (x >> lb) & bitsMask(ub - lb + 1);
49 }
50 
51 void MakeTablesPass::runOnOperation() {
52  auto module = getOperation();
53  for (auto op : module.getOps<DefineOp>())
54  runOnArc(op);
55 }
56 
57 void MakeTablesPass::runOnArc(DefineOp defineOp) {
58  // Determine the number of input bits.
59  unsigned numInputBits = 0;
60  for (auto &type : defineOp.getArgumentTypes()) {
61  auto intType = dyn_cast<IntegerType>(type);
62  if (!intType)
63  return;
64  numInputBits += intType.getWidth();
65  }
66  if (numInputBits == 0)
67  return;
68 
69  // Count the number of non-constant operations in the block.
70  unsigned numOps = 0;
71  for (auto &op : defineOp.getBodyBlock().without_terminator())
72  if (!op.hasTrait<OpTrait::ConstantLike>())
73  ++numOps;
74 
75  // Determine the number of output bits.
76  unsigned numOutputBits = 0;
77  auto outputOp = cast<arc::OutputOp>(defineOp.getBodyBlock().getTerminator());
78  for (auto type : outputOp.getOperandTypes()) {
79  auto intType = dyn_cast<IntegerType>(type);
80  if (!intType)
81  return;
82  numOutputBits += intType.getWidth();
83  }
84  if (numOutputBits == 0)
85  return;
86 
87  LLVM_DEBUG(llvm::dbgs() << "Making lookup tables in `" << defineOp.getName()
88  << "`\n");
89  LLVM_DEBUG(llvm::dbgs() << "- " << numInputBits << " input bits, "
90  << numOutputBits << " output bits, " << numOps
91  << " ops\n");
92 
93  // Check whether the table dimensions are within bounds.
94  if (numInputBits >= 31) {
95  LLVM_DEBUG(llvm::dbgs() << "- Skip; too many input bits\n");
96  return;
97  }
98  if (numOps < tableMinOpCount) {
99  LLVM_DEBUG(llvm::dbgs() << "- Skip; not enough ops\n");
100  return;
101  }
102 
103  unsigned numTableEntries = 1U << numInputBits;
104  if (numTableEntries > tableMaxSize / numOutputBits) {
105  LLVM_DEBUG(llvm::dbgs() << "- Skip; table too large\n");
106  return;
107  }
108  LLVM_DEBUG(llvm::dbgs() << "- Creating table of "
109  << numTableEntries * numOutputBits << " bits\n");
110 
111  // Actually build the table.
112  SmallVector<Operation *, 64> tabularizedOps;
113  for (auto &op : defineOp.getBodyBlock().without_terminator())
114  tabularizedOps.push_back(&op);
115 
116  // Concatenate the inputs into a single index value.
117  auto builder = ImplicitLocOpBuilder::atBlockBegin(defineOp.getLoc(),
118  &defineOp.getBodyBlock());
119  SmallVector<Value> inputsToConcat(defineOp.getArguments());
120  std::reverse(inputsToConcat.begin(), inputsToConcat.end());
121  auto concatInputs = inputsToConcat.size() > 1
122  ? builder.create<comb::ConcatOp>(inputsToConcat)
123  : inputsToConcat[0];
124 
125  // Compute a lookup table for every output.
126  SmallVector<SmallVector<Attribute, 0>> tables;
127  DenseMap<Value, Attribute> values;
128  tables.resize(outputOp->getNumOperands());
129 
130  for (int input = (1U << numInputBits) - 1; input >= 0; input--) {
131  // Assign the input values.
132  values.clear();
133  unsigned bits = 0;
134  for (auto arg : defineOp.getArguments()) {
135  auto w = dyn_cast<IntegerType>(arg.getType()).getWidth();
136  values[arg] = builder.getIntegerAttr(arg.getType(),
137  bitsGet(input, bits, bits + w - 1));
138  bits += w;
139  }
140 
141  // Evaluate the operations.
142  SmallVector<Attribute> constants;
143  for (auto *operation : tabularizedOps) {
144  constants.clear();
145  for (auto operand : operation->getOperands())
146  constants.push_back(values[operand]);
147 
148  SmallVector<OpFoldResult, 8> resultValues;
149  if (failed(operation->fold(constants, resultValues))) {
150  LLVM_DEBUG(llvm::dbgs() << "- Skip; operation folder failed\n");
151  return;
152  }
153 
154  for (auto [result, resultValue] :
155  llvm::zip(operation->getResults(), resultValues)) {
156  auto attr = dyn_cast<Attribute>(resultValue);
157  if (!attr)
158  attr = values[dyn_cast<Value>(resultValue)];
159  values[result] = attr;
160  }
161  }
162 
163  // Add the evaluated values to the output tables.
164  for (auto [table, outputOperand] :
165  llvm::zip(tables, outputOp->getOpOperands())) {
166  table.push_back(dyn_cast<Attribute>(values[outputOperand.get()]));
167  }
168  }
169 
170  // Create the table lookup ops.
171  for (auto [table, outputOperand] :
172  llvm::zip(tables, outputOp->getOpOperands())) {
173  auto array = builder.create<hw::AggregateConstantOp>(
174  ArrayType::get(outputOperand.get().getType(), numTableEntries),
175  builder.getArrayAttr(table));
176  outputOperand.set(builder.create<hw::ArrayGetOp>(array, concatInputs));
177  }
178 
179  for (auto *op : tabularizedOps) {
180  op->dropAllUses();
181  op->erase();
182  }
183 }
184 
185 std::unique_ptr<Pass> arc::createMakeTablesPass() {
186  return std::make_unique<MakeTablesPass>();
187 }
static uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub)
Definition: MakeTables.cpp:47
static uint32_t bitsMask(uint32_t nbits)
Definition: MakeTables.cpp:41
Builder builder
std::unique_ptr< mlir::Pass > createMakeTablesPass()
Definition: MakeTables.cpp:185
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1