CIRCT 20.0.0git
Loading...
Searching...
No Matches
MakeTables.cpp
Go to the documentation of this file.
1//===- MakeTables.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/ImplicitLocOpBuilder.h"
14#include "mlir/Pass/Pass.h"
15#include "llvm/Support/Debug.h"
16
17#define DEBUG_TYPE "arc-lookup-tables"
18
19namespace circt {
20namespace arc {
21#define GEN_PASS_DEF_MAKETABLES
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
23} // namespace arc
24} // namespace circt
25
26using namespace circt;
27using namespace arc;
28using namespace hw;
29
30namespace {
31
32static constexpr int tableMinOpCount = 20;
33static constexpr int tableMaxSize = 32768; // bits
34
35struct MakeTablesPass : public arc::impl::MakeTablesBase<MakeTablesPass> {
36 void runOnOperation() override;
37 void runOnArc(DefineOp defineOp);
38};
39} // namespace
40
/// Returns a mask with the low `nbits` bits set.
///
/// Widths of 32 or more saturate to an all-ones mask. The shift is performed
/// on an unsigned 32-bit operand so that `nbits == 31` does not overflow a
/// signed `int`.
static inline uint32_t bitsMask(uint32_t nbits) {
  if (nbits >= 32)
    return ~uint32_t{0};
  return (uint32_t{1} << nbits) - 1;
}
46
47static inline uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub) {
48 return (x >> lb) & bitsMask(ub - lb + 1);
49}
50
51void MakeTablesPass::runOnOperation() {
52 auto module = getOperation();
53 for (auto op : module.getOps<DefineOp>())
54 runOnArc(op);
55}
56
57void MakeTablesPass::runOnArc(DefineOp defineOp) {
58 // Determine the number of input bits.
59 unsigned numInputBits = 0;
60 for (auto &type : defineOp.getArgumentTypes()) {
61 auto intType = dyn_cast<IntegerType>(type);
62 if (!intType)
63 return;
64 numInputBits += intType.getWidth();
65 }
66 if (numInputBits == 0)
67 return;
68
69 // Count the number of non-constant operations in the block.
70 unsigned numOps = 0;
71 for (auto &op : defineOp.getBodyBlock().without_terminator())
72 if (!op.hasTrait<OpTrait::ConstantLike>())
73 ++numOps;
74
75 // Determine the number of output bits.
76 unsigned numOutputBits = 0;
77 auto outputOp = cast<arc::OutputOp>(defineOp.getBodyBlock().getTerminator());
78 for (auto type : outputOp.getOperandTypes()) {
79 auto intType = dyn_cast<IntegerType>(type);
80 if (!intType)
81 return;
82 numOutputBits += intType.getWidth();
83 }
84 if (numOutputBits == 0)
85 return;
86
87 LLVM_DEBUG(llvm::dbgs() << "Making lookup tables in `" << defineOp.getName()
88 << "`\n");
89 LLVM_DEBUG(llvm::dbgs() << "- " << numInputBits << " input bits, "
90 << numOutputBits << " output bits, " << numOps
91 << " ops\n");
92
93 // Check whether the table dimensions are within bounds.
94 if (numInputBits >= 31) {
95 LLVM_DEBUG(llvm::dbgs() << "- Skip; too many input bits\n");
96 return;
97 }
98 if (numOps < tableMinOpCount) {
99 LLVM_DEBUG(llvm::dbgs() << "- Skip; not enough ops\n");
100 return;
101 }
102
103 unsigned numTableEntries = 1U << numInputBits;
104 if (numTableEntries > tableMaxSize / numOutputBits) {
105 LLVM_DEBUG(llvm::dbgs() << "- Skip; table too large\n");
106 return;
107 }
108 LLVM_DEBUG(llvm::dbgs() << "- Creating table of "
109 << numTableEntries * numOutputBits << " bits\n");
110
111 // Actually build the table.
112 SmallVector<Operation *, 64> tabularizedOps;
113 for (auto &op : defineOp.getBodyBlock().without_terminator())
114 tabularizedOps.push_back(&op);
115
116 // Concatenate the inputs into a single index value.
117 auto builder = ImplicitLocOpBuilder::atBlockBegin(defineOp.getLoc(),
118 &defineOp.getBodyBlock());
119 SmallVector<Value> inputsToConcat(defineOp.getArguments());
120 std::reverse(inputsToConcat.begin(), inputsToConcat.end());
121 auto concatInputs = inputsToConcat.size() > 1
122 ? builder.create<comb::ConcatOp>(inputsToConcat)
123 : inputsToConcat[0];
124
125 // Compute a lookup table for every output.
126 SmallVector<SmallVector<Attribute, 0>> tables;
127 DenseMap<Value, Attribute> values;
128 tables.resize(outputOp->getNumOperands());
129
130 for (int input = (1U << numInputBits) - 1; input >= 0; input--) {
131 // Assign the input values.
132 values.clear();
133 unsigned bits = 0;
134 for (auto arg : defineOp.getArguments()) {
135 auto w = dyn_cast<IntegerType>(arg.getType()).getWidth();
136 values[arg] = builder.getIntegerAttr(arg.getType(),
137 bitsGet(input, bits, bits + w - 1));
138 bits += w;
139 }
140
141 // Evaluate the operations.
142 SmallVector<Attribute> constants;
143 for (auto *operation : tabularizedOps) {
144 constants.clear();
145 for (auto operand : operation->getOperands())
146 constants.push_back(values[operand]);
147
148 SmallVector<OpFoldResult, 8> resultValues;
149 if (failed(operation->fold(constants, resultValues))) {
150 LLVM_DEBUG(llvm::dbgs() << "- Skip; operation folder failed\n");
151 return;
152 }
153
154 for (auto [result, resultValue] :
155 llvm::zip(operation->getResults(), resultValues)) {
156 auto attr = dyn_cast<Attribute>(resultValue);
157 if (!attr)
158 attr = values[dyn_cast<Value>(resultValue)];
159 values[result] = attr;
160 }
161 }
162
163 // Add the evaluated values to the output tables.
164 for (auto [table, outputOperand] :
165 llvm::zip(tables, outputOp->getOpOperands())) {
166 table.push_back(dyn_cast<Attribute>(values[outputOperand.get()]));
167 }
168 }
169
170 // Create the table lookup ops.
171 for (auto [table, outputOperand] :
172 llvm::zip(tables, outputOp->getOpOperands())) {
173 auto array = builder.create<hw::AggregateConstantOp>(
174 ArrayType::get(outputOperand.get().getType(), numTableEntries),
175 builder.getArrayAttr(table));
176 outputOperand.set(builder.create<hw::ArrayGetOp>(array, concatInputs));
177 }
178
179 for (auto *op : tabularizedOps) {
180 op->dropAllUses();
181 op->erase();
182 }
183}
184
185std::unique_ptr<Pass> arc::createMakeTablesPass() {
186 return std::make_unique<MakeTablesPass>();
187}
static uint32_t bitsGet(uint32_t x, uint32_t lb, uint32_t ub)
static uint32_t bitsMask(uint32_t nbits)
static Block * getBodyBlock(FModuleLike mod)
std::unique_ptr< mlir::Pass > createMakeTablesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1