CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerWordToBits.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements bit-blasting for logic synthesis operations.
10// It converts multi-bit operations (AIG, MIG, combinatorial) into equivalent
11// single-bit operations, enabling more efficient synthesis and optimization.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/LLVM.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
27
28#define DEBUG_TYPE "synth-lower-word-to-bits"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace synth;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44/// Check if an operation should be lowered to bit-level operations.
45static bool shouldLowerOperation(Operation *op) {
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp, comb::AndOp,
48}
49
50namespace {
51
52//===----------------------------------------------------------------------===//
53// BitBlaster - Bit-level lowering implementation
54//===----------------------------------------------------------------------===//
55
56/// The BitBlaster class implements the core bit-blasting algorithm.
57/// It manages the lowering of multi-bit operations to single-bit operations
58/// while maintaining correctness and optimizing for constant propagation.
59class BitBlaster {
60public:
61 explicit BitBlaster(hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
62
63 /// Run the bit-blasting algorithm on the module.
64 LogicalResult run();
65
66 //===--------------------------------------------------------------------===//
67 // Statistics
68 //===--------------------------------------------------------------------===//
69
70 /// Number of bits that were lowered from multi-bit to single-bit operations
71 size_t numLoweredBits = 0;
72
73 /// Number of constant bits that were identified and optimized
74 size_t numLoweredConstants = 0;
75
76 /// Number of operations that were lowered
77 size_t numLoweredOps = 0;
78
79private:
80 //===--------------------------------------------------------------------===//
81 // Core Lowering Methods
82 //===--------------------------------------------------------------------===//
83
84 /// Lower a multi-bit value to individual bits.
85 /// This is the main entry point for bit-blasting a value.
86 ArrayRef<Value> lowerValueToBits(Value value);
87 template <typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
89 template <typename OpTy>
90 ArrayRef<Value> lowerCombOperations(OpTy op);
91 ArrayRef<Value>
92 lowerOp(Operation *op,
93 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
94
95 /// Extract a specific bit from a value.
96 /// Handles various IR constructs that can represent bit extraction.
97 Value extractBit(Value value, size_t index);
98
99 /// Compute and cache known bits for a value.
100 /// Uses operation-specific logic to determine which bits are constants.
101 const llvm::KnownBits &computeKnownBits(Value value);
102
103 /// Get or create a boolean constant (0 or 1).
104 /// Constants are cached to avoid duplication.
105 Value getBoolConstant(bool value);
106
107 //===--------------------------------------------------------------------===//
108 // Helper Methods
109 //===--------------------------------------------------------------------===//
110
111 /// Insert lowered bits into the cache.
112 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
113 auto it = loweredValues.insert({value, std::move(bits)});
114 assert(it.second && "value already inserted");
115 return it.first->second;
116 }
117
118 /// Insert computed known bits into the cache.
119 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
120 auto it = knownBits.insert({value, std::move(bits)});
121 return it.first->second;
122 }
123
124 /// Cache for lowered values (multi-bit -> vector of single-bit values)
125 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
126
127 /// Cache for computed known bits information
128 llvm::MapVector<Value, llvm::KnownBits> knownBits;
129
130 /// Cached boolean constants (false at index 0, true at index 1)
131 std::array<Value, 2> constants;
132
133 /// Reference to the module being processed
134 hw::HWModuleOp moduleOp;
135};
136
137} // namespace
138
139//===----------------------------------------------------------------------===//
140// BitBlaster Implementation
141//===----------------------------------------------------------------------===//
142
143const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
144 // Check cache first
145 auto *it = knownBits.find(value);
146 if (it != knownBits.end())
147 return it->second;
148
149 auto width = hw::getBitWidth(value.getType());
150 auto *op = value.getDefiningOp();
151
152 // For block arguments, return unknown bits
153 if (!op)
154 return insertKnownBits(value, llvm::KnownBits(width));
155
156 llvm::KnownBits result(width);
157 if (auto aig = dyn_cast<aig::AndInverterOp>(op)) {
158 // Initialize to all ones for AND operation
159 result.One = APInt::getAllOnes(width);
160 result.Zero = APInt::getZero(width);
161
162 for (auto [operand, inverted] :
163 llvm::zip(aig.getInputs(), aig.getInverted())) {
164 auto operandKnownBits = computeKnownBits(operand);
165 if (inverted)
166 // Complement the known bits by swapping Zero and One
167 std::swap(operandKnownBits.Zero, operandKnownBits.One);
168 result &= operandKnownBits;
169 }
170 } else if (auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
171 // Give up if it's not a 3-input majority inverter.
172 if (mig.getNumOperands() == 3) {
173 std::array<llvm::KnownBits, 3> operandsKnownBits;
174 for (auto [i, operand, inverted] :
175 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
176 operandsKnownBits[i] = computeKnownBits(operand);
177 // Complement the known bits by swapping Zero and One
178 if (inverted)
179 std::swap(operandsKnownBits[i].Zero, operandsKnownBits[i].One);
180 }
181
182 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
183 (operandsKnownBits[0] & operandsKnownBits[2]) |
184 (operandsKnownBits[1] & operandsKnownBits[2]);
185 }
186 } else {
187 // For other operations, use the standard known bits computation
188 // TODO: This is not optimal as it has a depth limit and does not check
189 // cached results.
190 result = comb::computeKnownBits(value);
191 }
192
193 return insertKnownBits(value, std::move(result));
194}
195
196Value BitBlaster::extractBit(Value value, size_t index) {
197 if (hw::getBitWidth(value.getType()) <= 1)
198 return value;
199
200 auto *op = value.getDefiningOp();
201
202 // If the value is a block argument, extract the bit.
203 if (!op)
204 return lowerValueToBits(value)[index];
205
206 return TypeSwitch<Operation *, Value>(op)
207 .Case<comb::ConcatOp>([&](comb::ConcatOp op) {
208 for (auto operand : llvm::reverse(op.getOperands())) {
209 auto width = hw::getBitWidth(operand.getType());
210 assert(width >= 0 && "operand has zero width");
211 if (index < static_cast<size_t>(width))
212 return extractBit(operand, index);
213 index -= static_cast<size_t>(width);
214 }
215 llvm_unreachable("index out of bounds");
216 })
217 .Case<comb::ExtractOp>([&](comb::ExtractOp ext) {
218 return extractBit(ext.getInput(),
219 static_cast<size_t>(ext.getLowBit()) + index);
220 })
221 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
222 return extractBit(op.getInput(),
223 index % static_cast<size_t>(hw::getBitWidth(
224 op.getOperand().getType())));
225 })
226 .Case<hw::ConstantOp>([&](hw::ConstantOp op) {
227 auto value = op.getValue();
228 return getBoolConstant(value[index]);
229 })
230 .Default([&](auto op) { return lowerValueToBits(value)[index]; });
231}
232
233ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
234 auto *it = loweredValues.find(value);
235 if (it != loweredValues.end())
236 return it->second;
237
238 auto width = hw::getBitWidth(value.getType());
239 if (width <= 1)
240 return insertBits(value, {value});
241
242 auto *op = value.getDefiningOp();
243 if (!op) {
244 SmallVector<Value> results;
245 OpBuilder builder(value.getContext());
246 builder.setInsertionPointAfterValue(value);
247 comb::extractBits(builder, value, results);
248 return insertBits(value, std::move(results));
249 }
250
251 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
252 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
253 [&](auto op) { return lowerInvertibleOperations(op); })
254 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
255 [&](auto op) { return lowerCombOperations(op); })
256 .Default([&](auto op) {
257 OpBuilder builder(value.getContext());
258 builder.setInsertionPoint(op);
259 SmallVector<Value> results;
260 comb::extractBits(builder, value, results);
261
262 return insertBits(value, std::move(results));
263 });
264}
265
266LogicalResult BitBlaster::run() {
267 // Topologically sort operations in graph regions so that walk visits them in
268 // the topological order.
270 moduleOp, [](Value value, Operation *op) -> bool {
271 // Otherthan target ops, all other ops are always ready.
272 return !(shouldLowerOperation(op) ||
273 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp,
274 comb::ReplicateOp>(op));
275 }))) {
276 // If we failed to topologically sort operations we cannot proceed.
277 return mlir::emitError(moduleOp.getLoc(), "there is a combinational cycle");
278 }
279
280 // Lower target operations
281 moduleOp.walk([&](Operation *op) {
282 // If the block is in a graph region, topologically sort it first.
283 if (shouldLowerOperation(op))
284 (void)lowerValueToBits(op->getResult(0));
285 });
286
287 // Replace operations with concatenated results if needed
288 for (auto &[value, results] :
289 llvm::make_early_inc_range(llvm::reverse(loweredValues))) {
290 if (hw::getBitWidth(value.getType()) <= 1)
291 continue;
292
293 auto *op = value.getDefiningOp();
294 if (!op)
295 continue;
296
297 if (value.use_empty()) {
298 op->erase();
299 continue;
300 }
301
302 // If a target operation still has an use (e.g. connected to output or
303 // instance), replace the value with the concatenated result.
304 if (shouldLowerOperation(op)) {
305 OpBuilder builder(op);
306 std::reverse(results.begin(), results.end());
307 auto concat = builder.create<comb::ConcatOp>(value.getLoc(), results);
308 value.replaceAllUsesWith(concat);
309 op->erase();
310 }
311 }
312
313 return success();
314}
315
316Value BitBlaster::getBoolConstant(bool value) {
317 if (!constants[value]) {
318 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
319 constants[value] = builder.create<hw::ConstantOp>(
320 builder.getUnknownLoc(), builder.getI1Type(), value);
321 }
322 return constants[value];
323}
324
325template <typename OpTy>
326ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
327 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
328 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
329 };
330 return lowerOp(op, createOp);
331}
332
333template <typename OpTy>
334ArrayRef<Value> BitBlaster::lowerCombOperations(OpTy op) {
335 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
336 return builder.createOrFold<OpTy>(op.getLoc(), operands,
337 op.getTwoStateAttr());
338 };
339 return lowerOp(op, createOp);
340}
341
342ArrayRef<Value> BitBlaster::lowerOp(
343 Operation *op,
344 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
345 auto value = op->getResult(0);
346 OpBuilder builder(op);
347 auto width = hw::getBitWidth(value.getType());
348 assert(width > 1 && "expected multi-bit operation");
349
350 auto known = computeKnownBits(value);
351 APInt knownMask = known.Zero | known.One;
352
353 // Update statistics
354 numLoweredConstants += knownMask.popcount();
355 numLoweredBits += width;
356 ++numLoweredOps;
357
358 SmallVector<Value> results;
359 results.reserve(width);
360
361 for (int64_t i = 0; i < width; ++i) {
362 SmallVector<Value> operands;
363 operands.reserve(op->getNumOperands());
364 if (knownMask[i]) {
365 // Use known constant value
366 results.push_back(getBoolConstant(known.One[i]));
367 continue;
368 }
369
370 // Extract the i-th bit from each operand
371 for (auto operand : op->getOperands())
372 operands.push_back(extractBit(operand, i));
373
374 // Create the single-bit operation
375 auto result = createOp(builder, operands);
376 results.push_back(result);
377
378 // Add name hint if present
379 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint")) {
380 auto newName = StringAttr::get(
381 op->getContext(), name.getValue() + "[" + std::to_string(i) + "]");
382 if (auto *loweredOp = result.getDefiningOp())
383 loweredOp->setAttr("sv.namehint", newName);
384 }
385 }
386
387 assert(results.size() == static_cast<size_t>(width));
388 return insertBits(value, std::move(results));
389}
390
391//===----------------------------------------------------------------------===//
392// Pass Implementation
393//===----------------------------------------------------------------------===//
394
395namespace {
396struct LowerWordToBitsPass
397 : public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
398 void runOnOperation() override;
399};
400} // namespace
401
402void LowerWordToBitsPass::runOnOperation() {
403 BitBlaster driver(getOperation());
404 if (failed(driver.run()))
405 return signalPassFailure();
406
407 // Update statistics
408 numLoweredBits += driver.numLoweredBits;
409 numLoweredConstants += driver.numLoweredConstants;
410 numLoweredOps += driver.numLoweredOps;
411}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:305
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:121
Definition synth.py:1