CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerWordToBits.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements bit-blasting for logic synthesis operations.
10// It converts multi-bit operations (AIG, combinatorial) into equivalent
11// single-bit operations, enabling more efficient synthesis and optimization.
12//
13//===----------------------------------------------------------------------===//
14
#include "circt/Support/LLVM.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/LogicalResult.h"
27
28#define DEBUG_TYPE "synth-lower-word-to-bits"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace synth;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44/// Check if an operation should be lowered to bit-level operations.
45static bool shouldLowerOperation(Operation *op) {
46 return isa<ChoiceOp, aig::AndInverterOp, comb::AndOp, comb::OrOp, comb::XorOp,
47 comb::MuxOp>(op);
48}
49
50namespace {
51
52//===----------------------------------------------------------------------===//
53// BitBlaster - Bit-level lowering implementation
54//===----------------------------------------------------------------------===//
55
56/// The BitBlaster class implements the core bit-blasting algorithm.
57/// It manages the lowering of multi-bit operations to single-bit operations
58/// while maintaining correctness and optimizing for constant propagation.
59class BitBlaster {
60public:
61 explicit BitBlaster(hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
62
63 /// Run the bit-blasting algorithm on the module.
64 LogicalResult run();
65
66 //===--------------------------------------------------------------------===//
67 // Statistics
68 //===--------------------------------------------------------------------===//
69
70 /// Number of bits that were lowered from multi-bit to single-bit operations
71 size_t numLoweredBits = 0;
72
73 /// Number of constant bits that were identified and optimized
74 size_t numLoweredConstants = 0;
75
76 /// Number of operations that were lowered
77 size_t numLoweredOps = 0;
78
79private:
80 //===--------------------------------------------------------------------===//
81 // Core Lowering Methods
82 //===--------------------------------------------------------------------===//
83
84 /// Lower a multi-bit value to individual bits.
85 /// This is the main entry point for bit-blasting a value.
86 ArrayRef<Value> lowerValueToBits(Value value);
87 template <typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
89 template <typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
91 ArrayRef<Value> lowerCombMux(comb::MuxOp op);
92 ArrayRef<Value>
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
95
96 /// Extract a specific bit from a value.
97 /// Handles various IR constructs that can represent bit extraction.
98 Value extractBit(Value value, size_t index);
99
100 /// Compute and cache known bits for a value.
101 /// Uses operation-specific logic to determine which bits are constants.
102 const llvm::KnownBits &computeKnownBits(Value value);
103
104 /// Get or create a boolean constant (0 or 1).
105 /// Constants are cached to avoid duplication.
106 Value getBoolConstant(bool value);
107
108 //===--------------------------------------------------------------------===//
109 // Helper Methods
110 //===--------------------------------------------------------------------===//
111
112 /// Insert lowered bits into the cache.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second && "value already inserted");
116 return it.first->second;
117 }
118
119 /// Insert computed known bits into the cache.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
123 }
124
125 /// Cache for lowered values (multi-bit -> vector of single-bit values)
127
128 /// Cache for computed known bits information
130
131 /// Cached boolean constants (false at index 0, true at index 1)
132 std::array<Value, 2> constants;
133
134 /// Reference to the module being processed
135 hw::HWModuleOp moduleOp;
136};
137
138} // namespace
139
140//===----------------------------------------------------------------------===//
141// BitBlaster Implementation
142//===----------------------------------------------------------------------===//
143
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
145 // Check cache first
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
148 return it->second;
149
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
152
153 // For block arguments, return unknown bits
154 if (!op)
155 return insertKnownBits(value, llvm::KnownBits(width));
156
157 llvm::KnownBits result(width);
158 if (auto aig = dyn_cast<aig::AndInverterOp>(op)) {
159 // Initialize to all ones for AND operation
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
162
163 for (auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
165 auto operandKnownBits = computeKnownBits(operand);
166 if (inverted)
167 // Complement the known bits by swapping Zero and One
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
170 }
171 } else if (auto choice = dyn_cast<ChoiceOp>(op)) {
172 result = computeKnownBits(choice.getInputs().front());
173 for (auto input : choice.getInputs().drop_front()) {
174 auto known = computeKnownBits(input);
175 result.One |= known.One;
176 result.Zero |= known.Zero;
177 }
178 } else {
179 // For other operations, use the standard known bits computation
180 // TODO: This is not optimal as it has a depth limit and does not check
181 // cached results.
182 result = comb::computeKnownBits(value);
183 }
184
185 return insertKnownBits(value, std::move(result));
186}
187
188Value BitBlaster::extractBit(Value value, size_t index) {
189 if (hw::getBitWidth(value.getType()) <= 1)
190 return value;
191
192 auto *op = value.getDefiningOp();
193
194 // If the value is a block argument, extract the bit.
195 if (!op)
196 return lowerValueToBits(value)[index];
197
198 return TypeSwitch<Operation *, Value>(op)
199 .Case<comb::ConcatOp>([&](comb::ConcatOp op) {
200 for (auto operand : llvm::reverse(op.getOperands())) {
201 auto width = hw::getBitWidth(operand.getType());
202 assert(width >= 0 && "operand has zero width");
203 if (index < static_cast<size_t>(width))
204 return extractBit(operand, index);
205 index -= static_cast<size_t>(width);
206 }
207 llvm_unreachable("index out of bounds");
208 })
209 .Case<comb::ExtractOp>([&](comb::ExtractOp ext) {
210 return extractBit(ext.getInput(),
211 static_cast<size_t>(ext.getLowBit()) + index);
212 })
213 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
214 return extractBit(op.getInput(),
215 index % static_cast<size_t>(hw::getBitWidth(
216 op.getOperand().getType())));
217 })
218 .Case<hw::ConstantOp>([&](hw::ConstantOp op) {
219 auto value = op.getValue();
220 return getBoolConstant(value[index]);
221 })
222 .Default([&](auto op) { return lowerValueToBits(value)[index]; });
223}
224
225ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
226 auto *it = loweredValues.find(value);
227 if (it != loweredValues.end())
228 return it->second;
229
230 auto width = hw::getBitWidth(value.getType());
231 if (width <= 1)
232 return insertBits(value, {value});
233
234 auto *op = value.getDefiningOp();
235 if (!op) {
236 SmallVector<Value> results;
237 OpBuilder builder(value.getContext());
238 builder.setInsertionPointAfterValue(value);
239 comb::extractBits(builder, value, results);
240 return insertBits(value, std::move(results));
241 }
242
243 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
244 .Case<ChoiceOp>([&](ChoiceOp op) {
245 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
246 return builder.createOrFold<ChoiceOp>(
247 op.getLoc(), operands[0].getType(), operands);
248 };
249 return lowerOp(op, createOp);
250 })
251 .Case<aig::AndInverterOp>(
252 [&](auto op) { return lowerInvertibleOperations(op); })
253 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
254 [&](auto op) { return lowerCombLogicOperations(op); })
255 .Case<comb::MuxOp>([&](comb::MuxOp op) { return lowerCombMux(op); })
256 .Default([&](auto op) {
257 OpBuilder builder(value.getContext());
258 builder.setInsertionPoint(op);
259 SmallVector<Value> results;
260 comb::extractBits(builder, value, results);
261
262 return insertBits(value, std::move(results));
263 });
264}
265
266LogicalResult BitBlaster::run() {
267 // Topologically sort operations in graph regions so that walk visits them in
268 // the topological order.
270 moduleOp, [](Value value, Operation *op) -> bool {
271 // Otherthan target ops, all other ops are always ready.
272 return !(shouldLowerOperation(op) ||
273 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp,
274 comb::ReplicateOp>(op));
275 }))) {
276 // If we failed to topologically sort operations we cannot proceed.
277 return mlir::emitError(moduleOp.getLoc(), "there is a combinational cycle");
278 }
279
280 // Lower target operations
281 moduleOp.walk([&](Operation *op) {
282 // If the block is in a graph region, topologically sort it first.
283 if (shouldLowerOperation(op))
284 (void)lowerValueToBits(op->getResult(0));
285 });
286
287 // Replace operations with concatenated results if needed
288 for (auto &[value, results] :
289 llvm::make_early_inc_range(llvm::reverse(loweredValues))) {
290 if (hw::getBitWidth(value.getType()) <= 1)
291 continue;
292
293 auto *op = value.getDefiningOp();
294 if (!op)
295 continue;
296
297 if (value.use_empty()) {
298 op->erase();
299 continue;
300 }
301
302 // If a target operation still has an use (e.g. connected to output or
303 // instance), replace the value with the concatenated result.
304 if (shouldLowerOperation(op)) {
305 OpBuilder builder(op);
306 std::reverse(results.begin(), results.end());
307 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
308 value.replaceAllUsesWith(concat);
309 op->erase();
310 }
311 }
312
313 return success();
314}
315
316Value BitBlaster::getBoolConstant(bool value) {
317 if (!constants[value]) {
318 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
319 constants[value] = hw::ConstantOp::create(builder, builder.getUnknownLoc(),
320 builder.getI1Type(), value);
321 }
322 return constants[value];
323}
324
325template <typename OpTy>
326ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
327 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
328 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
329 };
330 return lowerOp(op, createOp);
331}
332
333template <typename OpTy>
334ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
335 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
336 return builder.createOrFold<OpTy>(op.getLoc(), operands,
337 op.getTwoStateAttr());
338 };
339 return lowerOp(op, createOp);
340}
341
342ArrayRef<Value> BitBlaster::lowerCombMux(comb::MuxOp op) {
343 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
344 return builder.createOrFold<comb::MuxOp>(op.getLoc(), operands[0],
345 operands[1], operands[2],
346 op.getTwoStateAttr());
347 };
348 return lowerOp(op, createOp);
349}
350
351ArrayRef<Value> BitBlaster::lowerOp(
352 Operation *op,
353 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
354 auto value = op->getResult(0);
355 OpBuilder builder(op);
356 auto width = hw::getBitWidth(value.getType());
357 assert(width > 1 && "expected multi-bit operation");
358
359 auto known = computeKnownBits(value);
360 APInt knownMask = known.Zero | known.One;
361
362 // Update statistics
363 numLoweredConstants += knownMask.popcount();
364 numLoweredBits += width;
365 ++numLoweredOps;
366
367 SmallVector<Value> results;
368 results.reserve(width);
369
370 for (int64_t i = 0; i < width; ++i) {
371 SmallVector<Value> operands;
372 operands.reserve(op->getNumOperands());
373 if (knownMask[i]) {
374 // Use known constant value
375 results.push_back(getBoolConstant(known.One[i]));
376 continue;
377 }
378
379 // Extract the i-th bit from each operand
380 for (auto operand : op->getOperands())
381 operands.push_back(extractBit(operand, i));
382
383 // Create the single-bit operation
384 auto result = createOp(builder, operands);
385 results.push_back(result);
386
387 // Add name hint if present
388 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint")) {
389 auto newName = StringAttr::get(
390 op->getContext(), name.getValue() + "[" + std::to_string(i) + "]");
391 if (auto *loweredOp = result.getDefiningOp())
392 loweredOp->setAttr("sv.namehint", newName);
393 }
394 }
395
396 assert(results.size() == static_cast<size_t>(width));
397 return insertBits(value, std::move(results));
398}
399
400//===----------------------------------------------------------------------===//
401// Pass Implementation
402//===----------------------------------------------------------------------===//
403
404namespace {
405struct LowerWordToBitsPass
406 : public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
407 void runOnOperation() override;
408};
409} // namespace
410
411void LowerWordToBitsPass::runOnOperation() {
412 BitBlaster driver(getOperation());
413 if (failed(driver.run()))
414 return signalPassFailure();
415
416 // Update statistics
417 numLoweredBits += driver.numLoweredBits;
418 numLoweredConstants += driver.numLoweredConstants;
419 numLoweredOps += driver.numLoweredOps;
420}
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:265
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:445
Definition synth.py:1