CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerWordToBits.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements bit-blasting for logic synthesis operations.
10// It converts multi-bit operations (AIG, MIG, combinatorial) into equivalent
11// single-bit operations, enabling more efficient synthesis and optimization.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/LLVM.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
27
28#define DEBUG_TYPE "synth-lower-word-to-bits"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace synth;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44/// Check if an operation should be lowered to bit-level operations.
45static bool shouldLowerOperation(Operation *op) {
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp, comb::AndOp,
48}
49
50namespace {
51
52//===----------------------------------------------------------------------===//
53// BitBlaster - Bit-level lowering implementation
54//===----------------------------------------------------------------------===//
55
56/// The BitBlaster class implements the core bit-blasting algorithm.
57/// It manages the lowering of multi-bit operations to single-bit operations
58/// while maintaining correctness and optimizing for constant propagation.
59class BitBlaster {
60public:
61 explicit BitBlaster(hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
62
63 /// Run the bit-blasting algorithm on the module.
64 LogicalResult run();
65
66 //===--------------------------------------------------------------------===//
67 // Statistics
68 //===--------------------------------------------------------------------===//
69
70 /// Number of bits that were lowered from multi-bit to single-bit operations
71 size_t numLoweredBits = 0;
72
73 /// Number of constant bits that were identified and optimized
74 size_t numLoweredConstants = 0;
75
76 /// Number of operations that were lowered
77 size_t numLoweredOps = 0;
78
79private:
80 //===--------------------------------------------------------------------===//
81 // Core Lowering Methods
82 //===--------------------------------------------------------------------===//
83
84 /// Lower a multi-bit value to individual bits.
85 /// This is the main entry point for bit-blasting a value.
86 ArrayRef<Value> lowerValueToBits(Value value);
87 template <typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
89 template <typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
91 ArrayRef<Value> lowerCombMux(comb::MuxOp op);
92 ArrayRef<Value>
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
95
96 /// Extract a specific bit from a value.
97 /// Handles various IR constructs that can represent bit extraction.
98 Value extractBit(Value value, size_t index);
99
100 /// Compute and cache known bits for a value.
101 /// Uses operation-specific logic to determine which bits are constants.
102 const llvm::KnownBits &computeKnownBits(Value value);
103
104 /// Get or create a boolean constant (0 or 1).
105 /// Constants are cached to avoid duplication.
106 Value getBoolConstant(bool value);
107
108 //===--------------------------------------------------------------------===//
109 // Helper Methods
110 //===--------------------------------------------------------------------===//
111
112 /// Insert lowered bits into the cache.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second && "value already inserted");
116 return it.first->second;
117 }
118
119 /// Insert computed known bits into the cache.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
123 }
124
125 /// Cache for lowered values (multi-bit -> vector of single-bit values)
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
127
128 /// Cache for computed known bits information
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
130
131 /// Cached boolean constants (false at index 0, true at index 1)
132 std::array<Value, 2> constants;
133
134 /// Reference to the module being processed
135 hw::HWModuleOp moduleOp;
136};
137
138} // namespace
139
140//===----------------------------------------------------------------------===//
141// BitBlaster Implementation
142//===----------------------------------------------------------------------===//
143
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
145 // Check cache first
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
148 return it->second;
149
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
152
153 // For block arguments, return unknown bits
154 if (!op)
155 return insertKnownBits(value, llvm::KnownBits(width));
156
157 llvm::KnownBits result(width);
158 if (auto aig = dyn_cast<aig::AndInverterOp>(op)) {
159 // Initialize to all ones for AND operation
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
162
163 for (auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
165 auto operandKnownBits = computeKnownBits(operand);
166 if (inverted)
167 // Complement the known bits by swapping Zero and One
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
170 }
171 } else if (auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
172 // Give up if it's not a 3-input majority inverter.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
177 operandsKnownBits[i] = computeKnownBits(operand);
178 // Complement the known bits by swapping Zero and One
179 if (inverted)
180 std::swap(operandsKnownBits[i].Zero, operandsKnownBits[i].One);
181 }
182
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
186 }
187 } else {
188 // For other operations, use the standard known bits computation
189 // TODO: This is not optimal as it has a depth limit and does not check
190 // cached results.
191 result = comb::computeKnownBits(value);
192 }
193
194 return insertKnownBits(value, std::move(result));
195}
196
197Value BitBlaster::extractBit(Value value, size_t index) {
198 if (hw::getBitWidth(value.getType()) <= 1)
199 return value;
200
201 auto *op = value.getDefiningOp();
202
203 // If the value is a block argument, extract the bit.
204 if (!op)
205 return lowerValueToBits(value)[index];
206
207 return TypeSwitch<Operation *, Value>(op)
208 .Case<comb::ConcatOp>([&](comb::ConcatOp op) {
209 for (auto operand : llvm::reverse(op.getOperands())) {
210 auto width = hw::getBitWidth(operand.getType());
211 assert(width >= 0 && "operand has zero width");
212 if (index < static_cast<size_t>(width))
213 return extractBit(operand, index);
214 index -= static_cast<size_t>(width);
215 }
216 llvm_unreachable("index out of bounds");
217 })
218 .Case<comb::ExtractOp>([&](comb::ExtractOp ext) {
219 return extractBit(ext.getInput(),
220 static_cast<size_t>(ext.getLowBit()) + index);
221 })
222 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
223 return extractBit(op.getInput(),
224 index % static_cast<size_t>(hw::getBitWidth(
225 op.getOperand().getType())));
226 })
227 .Case<hw::ConstantOp>([&](hw::ConstantOp op) {
228 auto value = op.getValue();
229 return getBoolConstant(value[index]);
230 })
231 .Default([&](auto op) { return lowerValueToBits(value)[index]; });
232}
233
234ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
235 auto *it = loweredValues.find(value);
236 if (it != loweredValues.end())
237 return it->second;
238
239 auto width = hw::getBitWidth(value.getType());
240 if (width <= 1)
241 return insertBits(value, {value});
242
243 auto *op = value.getDefiningOp();
244 if (!op) {
245 SmallVector<Value> results;
246 OpBuilder builder(value.getContext());
247 builder.setInsertionPointAfterValue(value);
248 comb::extractBits(builder, value, results);
249 return insertBits(value, std::move(results));
250 }
251
252 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
253 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
254 [&](auto op) { return lowerInvertibleOperations(op); })
255 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
256 [&](auto op) { return lowerCombLogicOperations(op); })
257 .Case<comb::MuxOp>([&](comb::MuxOp op) { return lowerCombMux(op); })
258 .Default([&](auto op) {
259 OpBuilder builder(value.getContext());
260 builder.setInsertionPoint(op);
261 SmallVector<Value> results;
262 comb::extractBits(builder, value, results);
263
264 return insertBits(value, std::move(results));
265 });
266}
267
268LogicalResult BitBlaster::run() {
269 // Topologically sort operations in graph regions so that walk visits them in
270 // the topological order.
272 moduleOp, [](Value value, Operation *op) -> bool {
273 // Otherthan target ops, all other ops are always ready.
274 return !(shouldLowerOperation(op) ||
275 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp,
276 comb::ReplicateOp>(op));
277 }))) {
278 // If we failed to topologically sort operations we cannot proceed.
279 return mlir::emitError(moduleOp.getLoc(), "there is a combinational cycle");
280 }
281
282 // Lower target operations
283 moduleOp.walk([&](Operation *op) {
284 // If the block is in a graph region, topologically sort it first.
285 if (shouldLowerOperation(op))
286 (void)lowerValueToBits(op->getResult(0));
287 });
288
289 // Replace operations with concatenated results if needed
290 for (auto &[value, results] :
291 llvm::make_early_inc_range(llvm::reverse(loweredValues))) {
292 if (hw::getBitWidth(value.getType()) <= 1)
293 continue;
294
295 auto *op = value.getDefiningOp();
296 if (!op)
297 continue;
298
299 if (value.use_empty()) {
300 op->erase();
301 continue;
302 }
303
304 // If a target operation still has an use (e.g. connected to output or
305 // instance), replace the value with the concatenated result.
306 if (shouldLowerOperation(op)) {
307 OpBuilder builder(op);
308 std::reverse(results.begin(), results.end());
309 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
310 value.replaceAllUsesWith(concat);
311 op->erase();
312 }
313 }
314
315 return success();
316}
317
318Value BitBlaster::getBoolConstant(bool value) {
319 if (!constants[value]) {
320 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
321 constants[value] = hw::ConstantOp::create(builder, builder.getUnknownLoc(),
322 builder.getI1Type(), value);
323 }
324 return constants[value];
325}
326
327template <typename OpTy>
328ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
329 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
330 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
331 };
332 return lowerOp(op, createOp);
333}
334
335template <typename OpTy>
336ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
337 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
338 return builder.createOrFold<OpTy>(op.getLoc(), operands,
339 op.getTwoStateAttr());
340 };
341 return lowerOp(op, createOp);
342}
343
344ArrayRef<Value> BitBlaster::lowerCombMux(comb::MuxOp op) {
345 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
346 return builder.createOrFold<comb::MuxOp>(op.getLoc(), operands[0],
347 operands[1], operands[2],
348 op.getTwoStateAttr());
349 };
350 return lowerOp(op, createOp);
351}
352
353ArrayRef<Value> BitBlaster::lowerOp(
354 Operation *op,
355 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
356 auto value = op->getResult(0);
357 OpBuilder builder(op);
358 auto width = hw::getBitWidth(value.getType());
359 assert(width > 1 && "expected multi-bit operation");
360
361 auto known = computeKnownBits(value);
362 APInt knownMask = known.Zero | known.One;
363
364 // Update statistics
365 numLoweredConstants += knownMask.popcount();
366 numLoweredBits += width;
367 ++numLoweredOps;
368
369 SmallVector<Value> results;
370 results.reserve(width);
371
372 for (int64_t i = 0; i < width; ++i) {
373 SmallVector<Value> operands;
374 operands.reserve(op->getNumOperands());
375 if (knownMask[i]) {
376 // Use known constant value
377 results.push_back(getBoolConstant(known.One[i]));
378 continue;
379 }
380
381 // Extract the i-th bit from each operand
382 for (auto operand : op->getOperands())
383 operands.push_back(extractBit(operand, i));
384
385 // Create the single-bit operation
386 auto result = createOp(builder, operands);
387 results.push_back(result);
388
389 // Add name hint if present
390 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint")) {
391 auto newName = StringAttr::get(
392 op->getContext(), name.getValue() + "[" + std::to_string(i) + "]");
393 if (auto *loweredOp = result.getDefiningOp())
394 loweredOp->setAttr("sv.namehint", newName);
395 }
396 }
397
398 assert(results.size() == static_cast<size_t>(width));
399 return insertBits(value, std::move(results));
400}
401
402//===----------------------------------------------------------------------===//
403// Pass Implementation
404//===----------------------------------------------------------------------===//
405
406namespace {
407struct LowerWordToBitsPass
408 : public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
409 void runOnOperation() override;
410};
411} // namespace
412
413void LowerWordToBitsPass::runOnOperation() {
414 BitBlaster driver(getOperation());
415 if (failed(driver.run()))
416 return signalPassFailure();
417
418 // Update statistics
419 numLoweredBits += driver.numLoweredBits;
420 numLoweredConstants += driver.numLoweredConstants;
421 numLoweredOps += driver.numLoweredOps;
422}
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:306
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:440
Definition synth.py:1