CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerWordToBits.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements bit-blasting for logic synthesis operations.
10// It converts multi-bit operations (AIG, MIG, combinatorial) into equivalent
11// single-bit operations, enabling more efficient synthesis and optimization.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/LLVM.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
27
28#define DEBUG_TYPE "synth-lower-word-to-bits"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace synth;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44/// Check if an operation should be lowered to bit-level operations.
45static bool shouldLowerOperation(Operation *op) {
46 return isa<ChoiceOp, aig::AndInverterOp, mig::MajorityInverterOp, comb::AndOp,
48}
49
50namespace {
51
52//===----------------------------------------------------------------------===//
53// BitBlaster - Bit-level lowering implementation
54//===----------------------------------------------------------------------===//
55
56/// The BitBlaster class implements the core bit-blasting algorithm.
57/// It manages the lowering of multi-bit operations to single-bit operations
58/// while maintaining correctness and optimizing for constant propagation.
59class BitBlaster {
60public:
61 explicit BitBlaster(hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
62
63 /// Run the bit-blasting algorithm on the module.
64 LogicalResult run();
65
66 //===--------------------------------------------------------------------===//
67 // Statistics
68 //===--------------------------------------------------------------------===//
69
70 /// Number of bits that were lowered from multi-bit to single-bit operations
71 size_t numLoweredBits = 0;
72
73 /// Number of constant bits that were identified and optimized
74 size_t numLoweredConstants = 0;
75
76 /// Number of operations that were lowered
77 size_t numLoweredOps = 0;
78
79private:
80 //===--------------------------------------------------------------------===//
81 // Core Lowering Methods
82 //===--------------------------------------------------------------------===//
83
84 /// Lower a multi-bit value to individual bits.
85 /// This is the main entry point for bit-blasting a value.
86 ArrayRef<Value> lowerValueToBits(Value value);
87 template <typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
89 template <typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
91 ArrayRef<Value> lowerCombMux(comb::MuxOp op);
92 ArrayRef<Value>
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
95
96 /// Extract a specific bit from a value.
97 /// Handles various IR constructs that can represent bit extraction.
98 Value extractBit(Value value, size_t index);
99
100 /// Compute and cache known bits for a value.
101 /// Uses operation-specific logic to determine which bits are constants.
102 const llvm::KnownBits &computeKnownBits(Value value);
103
104 /// Get or create a boolean constant (0 or 1).
105 /// Constants are cached to avoid duplication.
106 Value getBoolConstant(bool value);
107
108 //===--------------------------------------------------------------------===//
109 // Helper Methods
110 //===--------------------------------------------------------------------===//
111
112 /// Insert lowered bits into the cache.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second && "value already inserted");
116 return it.first->second;
117 }
118
119 /// Insert computed known bits into the cache.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
123 }
124
125 /// Cache for lowered values (multi-bit -> vector of single-bit values)
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
127
128 /// Cache for computed known bits information
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
130
131 /// Cached boolean constants (false at index 0, true at index 1)
132 std::array<Value, 2> constants;
133
134 /// Reference to the module being processed
135 hw::HWModuleOp moduleOp;
136};
137
138} // namespace
139
140//===----------------------------------------------------------------------===//
141// BitBlaster Implementation
142//===----------------------------------------------------------------------===//
143
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
145 // Check cache first
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
148 return it->second;
149
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
152
153 // For block arguments, return unknown bits
154 if (!op)
155 return insertKnownBits(value, llvm::KnownBits(width));
156
157 llvm::KnownBits result(width);
158 if (auto aig = dyn_cast<aig::AndInverterOp>(op)) {
159 // Initialize to all ones for AND operation
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
162
163 for (auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
165 auto operandKnownBits = computeKnownBits(operand);
166 if (inverted)
167 // Complement the known bits by swapping Zero and One
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
170 }
171 } else if (auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
172 // Give up if it's not a 3-input majority inverter.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
177 operandsKnownBits[i] = computeKnownBits(operand);
178 // Complement the known bits by swapping Zero and One
179 if (inverted)
180 std::swap(operandsKnownBits[i].Zero, operandsKnownBits[i].One);
181 }
182
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
186 }
187 } else if (auto choice = dyn_cast<ChoiceOp>(op)) {
188 result = computeKnownBits(choice.getInputs().front());
189 for (auto input : choice.getInputs().drop_front()) {
190 auto known = computeKnownBits(input);
191 result.One |= known.One;
192 result.Zero |= known.Zero;
193 }
194 } else {
195 // For other operations, use the standard known bits computation
196 // TODO: This is not optimal as it has a depth limit and does not check
197 // cached results.
198 result = comb::computeKnownBits(value);
199 }
200
201 return insertKnownBits(value, std::move(result));
202}
203
204Value BitBlaster::extractBit(Value value, size_t index) {
205 if (hw::getBitWidth(value.getType()) <= 1)
206 return value;
207
208 auto *op = value.getDefiningOp();
209
210 // If the value is a block argument, extract the bit.
211 if (!op)
212 return lowerValueToBits(value)[index];
213
214 return TypeSwitch<Operation *, Value>(op)
215 .Case<comb::ConcatOp>([&](comb::ConcatOp op) {
216 for (auto operand : llvm::reverse(op.getOperands())) {
217 auto width = hw::getBitWidth(operand.getType());
218 assert(width >= 0 && "operand has zero width");
219 if (index < static_cast<size_t>(width))
220 return extractBit(operand, index);
221 index -= static_cast<size_t>(width);
222 }
223 llvm_unreachable("index out of bounds");
224 })
225 .Case<comb::ExtractOp>([&](comb::ExtractOp ext) {
226 return extractBit(ext.getInput(),
227 static_cast<size_t>(ext.getLowBit()) + index);
228 })
229 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
230 return extractBit(op.getInput(),
231 index % static_cast<size_t>(hw::getBitWidth(
232 op.getOperand().getType())));
233 })
234 .Case<hw::ConstantOp>([&](hw::ConstantOp op) {
235 auto value = op.getValue();
236 return getBoolConstant(value[index]);
237 })
238 .Default([&](auto op) { return lowerValueToBits(value)[index]; });
239}
240
241ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
242 auto *it = loweredValues.find(value);
243 if (it != loweredValues.end())
244 return it->second;
245
246 auto width = hw::getBitWidth(value.getType());
247 if (width <= 1)
248 return insertBits(value, {value});
249
250 auto *op = value.getDefiningOp();
251 if (!op) {
252 SmallVector<Value> results;
253 OpBuilder builder(value.getContext());
254 builder.setInsertionPointAfterValue(value);
255 comb::extractBits(builder, value, results);
256 return insertBits(value, std::move(results));
257 }
258
259 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
260 .Case<ChoiceOp>([&](ChoiceOp op) {
261 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
262 return builder.createOrFold<ChoiceOp>(
263 op.getLoc(), operands[0].getType(), operands);
264 };
265 return lowerOp(op, createOp);
266 })
267 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
268 [&](auto op) { return lowerInvertibleOperations(op); })
269 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
270 [&](auto op) { return lowerCombLogicOperations(op); })
271 .Case<comb::MuxOp>([&](comb::MuxOp op) { return lowerCombMux(op); })
272 .Default([&](auto op) {
273 OpBuilder builder(value.getContext());
274 builder.setInsertionPoint(op);
275 SmallVector<Value> results;
276 comb::extractBits(builder, value, results);
277
278 return insertBits(value, std::move(results));
279 });
280}
281
282LogicalResult BitBlaster::run() {
283 // Topologically sort operations in graph regions so that walk visits them in
284 // the topological order.
286 moduleOp, [](Value value, Operation *op) -> bool {
287 // Otherthan target ops, all other ops are always ready.
288 return !(shouldLowerOperation(op) ||
289 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp,
290 comb::ReplicateOp>(op));
291 }))) {
292 // If we failed to topologically sort operations we cannot proceed.
293 return mlir::emitError(moduleOp.getLoc(), "there is a combinational cycle");
294 }
295
296 // Lower target operations
297 moduleOp.walk([&](Operation *op) {
298 // If the block is in a graph region, topologically sort it first.
299 if (shouldLowerOperation(op))
300 (void)lowerValueToBits(op->getResult(0));
301 });
302
303 // Replace operations with concatenated results if needed
304 for (auto &[value, results] :
305 llvm::make_early_inc_range(llvm::reverse(loweredValues))) {
306 if (hw::getBitWidth(value.getType()) <= 1)
307 continue;
308
309 auto *op = value.getDefiningOp();
310 if (!op)
311 continue;
312
313 if (value.use_empty()) {
314 op->erase();
315 continue;
316 }
317
318 // If a target operation still has an use (e.g. connected to output or
319 // instance), replace the value with the concatenated result.
320 if (shouldLowerOperation(op)) {
321 OpBuilder builder(op);
322 std::reverse(results.begin(), results.end());
323 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
324 value.replaceAllUsesWith(concat);
325 op->erase();
326 }
327 }
328
329 return success();
330}
331
332Value BitBlaster::getBoolConstant(bool value) {
333 if (!constants[value]) {
334 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
335 constants[value] = hw::ConstantOp::create(builder, builder.getUnknownLoc(),
336 builder.getI1Type(), value);
337 }
338 return constants[value];
339}
340
341template <typename OpTy>
342ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
343 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
344 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
345 };
346 return lowerOp(op, createOp);
347}
348
349template <typename OpTy>
350ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
351 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
352 return builder.createOrFold<OpTy>(op.getLoc(), operands,
353 op.getTwoStateAttr());
354 };
355 return lowerOp(op, createOp);
356}
357
358ArrayRef<Value> BitBlaster::lowerCombMux(comb::MuxOp op) {
359 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
360 return builder.createOrFold<comb::MuxOp>(op.getLoc(), operands[0],
361 operands[1], operands[2],
362 op.getTwoStateAttr());
363 };
364 return lowerOp(op, createOp);
365}
366
367ArrayRef<Value> BitBlaster::lowerOp(
368 Operation *op,
369 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
370 auto value = op->getResult(0);
371 OpBuilder builder(op);
372 auto width = hw::getBitWidth(value.getType());
373 assert(width > 1 && "expected multi-bit operation");
374
375 auto known = computeKnownBits(value);
376 APInt knownMask = known.Zero | known.One;
377
378 // Update statistics
379 numLoweredConstants += knownMask.popcount();
380 numLoweredBits += width;
381 ++numLoweredOps;
382
383 SmallVector<Value> results;
384 results.reserve(width);
385
386 for (int64_t i = 0; i < width; ++i) {
387 SmallVector<Value> operands;
388 operands.reserve(op->getNumOperands());
389 if (knownMask[i]) {
390 // Use known constant value
391 results.push_back(getBoolConstant(known.One[i]));
392 continue;
393 }
394
395 // Extract the i-th bit from each operand
396 for (auto operand : op->getOperands())
397 operands.push_back(extractBit(operand, i));
398
399 // Create the single-bit operation
400 auto result = createOp(builder, operands);
401 results.push_back(result);
402
403 // Add name hint if present
404 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint")) {
405 auto newName = StringAttr::get(
406 op->getContext(), name.getValue() + "[" + std::to_string(i) + "]");
407 if (auto *loweredOp = result.getDefiningOp())
408 loweredOp->setAttr("sv.namehint", newName);
409 }
410 }
411
412 assert(results.size() == static_cast<size_t>(width));
413 return insertBits(value, std::move(results));
414}
415
416//===----------------------------------------------------------------------===//
417// Pass Implementation
418//===----------------------------------------------------------------------===//
419
420namespace {
421struct LowerWordToBitsPass
422 : public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
423 void runOnOperation() override;
424};
425} // namespace
426
427void LowerWordToBitsPass::runOnOperation() {
428 BitBlaster driver(getOperation());
429 if (failed(driver.run()))
430 return signalPassFailure();
431
432 // Update statistics
433 numLoweredBits += driver.numLoweredBits;
434 numLoweredConstants += driver.numLoweredConstants;
435 numLoweredOps += driver.numLoweredOps;
436}
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:313
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:445
Definition synth.py:1