CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerWordToBits.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements bit-blasting for logic synthesis operations.
10// It converts multi-bit operations (AIG, MIG, combinatorial) into equivalent
11// single-bit operations, enabling more efficient synthesis and optimization.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/LLVM.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/KnownBits.h"
26#include "llvm/Support/LogicalResult.h"
27
28#define DEBUG_TYPE "synth-lower-word-to-bits"
29
30namespace circt {
31namespace synth {
32#define GEN_PASS_DEF_LOWERWORDTOBITS
33#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
34} // namespace synth
35} // namespace circt
36
37using namespace circt;
38using namespace synth;
39
40//===----------------------------------------------------------------------===//
41// Utilities
42//===----------------------------------------------------------------------===//
43
44/// Check if an operation should be lowered to bit-level operations.
45static bool shouldLowerOperation(Operation *op) {
46 return isa<aig::AndInverterOp, mig::MajorityInverterOp, comb::AndOp,
48}
49
50namespace {
51
52//===----------------------------------------------------------------------===//
53// BitBlaster - Bit-level lowering implementation
54//===----------------------------------------------------------------------===//
55
56/// The BitBlaster class implements the core bit-blasting algorithm.
57/// It manages the lowering of multi-bit operations to single-bit operations
58/// while maintaining correctness and optimizing for constant propagation.
59class BitBlaster {
60public:
61 explicit BitBlaster(hw::HWModuleOp moduleOp) : moduleOp(moduleOp) {}
62
63 /// Run the bit-blasting algorithm on the module.
64 LogicalResult run();
65
66 //===--------------------------------------------------------------------===//
67 // Statistics
68 //===--------------------------------------------------------------------===//
69
70 /// Number of bits that were lowered from multi-bit to single-bit operations
71 size_t numLoweredBits = 0;
72
73 /// Number of constant bits that were identified and optimized
74 size_t numLoweredConstants = 0;
75
76 /// Number of operations that were lowered
77 size_t numLoweredOps = 0;
78
79private:
80 //===--------------------------------------------------------------------===//
81 // Core Lowering Methods
82 //===--------------------------------------------------------------------===//
83
84 /// Lower a multi-bit value to individual bits.
85 /// This is the main entry point for bit-blasting a value.
86 ArrayRef<Value> lowerValueToBits(Value value);
87 template <typename OpTy>
88 ArrayRef<Value> lowerInvertibleOperations(OpTy op);
89 template <typename OpTy>
90 ArrayRef<Value> lowerCombLogicOperations(OpTy op);
91 ArrayRef<Value> lowerCombMux(comb::MuxOp op);
92 ArrayRef<Value>
93 lowerOp(Operation *op,
94 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp);
95
96 /// Extract a specific bit from a value.
97 /// Handles various IR constructs that can represent bit extraction.
98 Value extractBit(Value value, size_t index);
99
100 /// Compute and cache known bits for a value.
101 /// Uses operation-specific logic to determine which bits are constants.
102 const llvm::KnownBits &computeKnownBits(Value value);
103
104 /// Get or create a boolean constant (0 or 1).
105 /// Constants are cached to avoid duplication.
106 Value getBoolConstant(bool value);
107
108 //===--------------------------------------------------------------------===//
109 // Helper Methods
110 //===--------------------------------------------------------------------===//
111
112 /// Insert lowered bits into the cache.
113 ArrayRef<Value> insertBits(Value value, SmallVector<Value> bits) {
114 auto it = loweredValues.insert({value, std::move(bits)});
115 assert(it.second && "value already inserted");
116 return it.first->second;
117 }
118
119 /// Insert computed known bits into the cache.
120 const llvm::KnownBits &insertKnownBits(Value value, llvm::KnownBits bits) {
121 auto it = knownBits.insert({value, std::move(bits)});
122 return it.first->second;
123 }
124
125 /// Cache for lowered values (multi-bit -> vector of single-bit values)
126 llvm::MapVector<Value, SmallVector<Value>> loweredValues;
127
128 /// Cache for computed known bits information
129 llvm::MapVector<Value, llvm::KnownBits> knownBits;
130
131 /// Cached boolean constants (false at index 0, true at index 1)
132 std::array<Value, 2> constants;
133
134 /// Reference to the module being processed
135 hw::HWModuleOp moduleOp;
136};
137
138} // namespace
139
140//===----------------------------------------------------------------------===//
141// BitBlaster Implementation
142//===----------------------------------------------------------------------===//
143
144const llvm::KnownBits &BitBlaster::computeKnownBits(Value value) {
145 // Check cache first
146 auto *it = knownBits.find(value);
147 if (it != knownBits.end())
148 return it->second;
149
150 auto width = hw::getBitWidth(value.getType());
151 auto *op = value.getDefiningOp();
152
153 // For block arguments, return unknown bits
154 if (!op)
155 return insertKnownBits(value, llvm::KnownBits(width));
156
157 llvm::KnownBits result(width);
158 if (auto aig = dyn_cast<aig::AndInverterOp>(op)) {
159 // Initialize to all ones for AND operation
160 result.One = APInt::getAllOnes(width);
161 result.Zero = APInt::getZero(width);
162
163 for (auto [operand, inverted] :
164 llvm::zip(aig.getInputs(), aig.getInverted())) {
165 auto operandKnownBits = computeKnownBits(operand);
166 if (inverted)
167 // Complement the known bits by swapping Zero and One
168 std::swap(operandKnownBits.Zero, operandKnownBits.One);
169 result &= operandKnownBits;
170 }
171 } else if (auto mig = dyn_cast<mig::MajorityInverterOp>(op)) {
172 // Give up if it's not a 3-input majority inverter.
173 if (mig.getNumOperands() == 3) {
174 std::array<llvm::KnownBits, 3> operandsKnownBits;
175 for (auto [i, operand, inverted] :
176 llvm::enumerate(mig.getInputs(), mig.getInverted())) {
177 operandsKnownBits[i] = computeKnownBits(operand);
178 // Complement the known bits by swapping Zero and One
179 if (inverted)
180 std::swap(operandsKnownBits[i].Zero, operandsKnownBits[i].One);
181 }
182
183 result = (operandsKnownBits[0] & operandsKnownBits[1]) |
184 (operandsKnownBits[0] & operandsKnownBits[2]) |
185 (operandsKnownBits[1] & operandsKnownBits[2]);
186 }
187 } else {
188 // For other operations, use the standard known bits computation
189 // TODO: This is not optimal as it has a depth limit and does not check
190 // cached results.
191 result = comb::computeKnownBits(value);
192 }
193
194 return insertKnownBits(value, std::move(result));
195}
196
197Value BitBlaster::extractBit(Value value, size_t index) {
198 if (hw::getBitWidth(value.getType()) <= 1)
199 return value;
200
201 auto *op = value.getDefiningOp();
202
203 // If the value is a block argument, extract the bit.
204 if (!op)
205 return lowerValueToBits(value)[index];
206
207 return TypeSwitch<Operation *, Value>(op)
208 .Case<comb::ConcatOp>([&](comb::ConcatOp op) {
209 for (auto operand : llvm::reverse(op.getOperands())) {
210 auto width = hw::getBitWidth(operand.getType());
211 assert(width >= 0 && "operand has zero width");
212 if (index < static_cast<size_t>(width))
213 return extractBit(operand, index);
214 index -= static_cast<size_t>(width);
215 }
216 llvm_unreachable("index out of bounds");
217 })
218 .Case<comb::ExtractOp>([&](comb::ExtractOp ext) {
219 return extractBit(ext.getInput(),
220 static_cast<size_t>(ext.getLowBit()) + index);
221 })
222 .Case<comb::ReplicateOp>([&](comb::ReplicateOp op) {
223 return extractBit(op.getInput(),
224 index % static_cast<size_t>(hw::getBitWidth(
225 op.getOperand().getType())));
226 })
227 .Case<hw::ConstantOp>([&](hw::ConstantOp op) {
228 auto value = op.getValue();
229 return getBoolConstant(value[index]);
230 })
231 .Default([&](auto op) { return lowerValueToBits(value)[index]; });
232}
233
234ArrayRef<Value> BitBlaster::lowerValueToBits(Value value) {
235 auto *it = loweredValues.find(value);
236 if (it != loweredValues.end())
237 return it->second;
238
239 auto width = hw::getBitWidth(value.getType());
240 if (width <= 1)
241 return insertBits(value, {value});
242
243 auto *op = value.getDefiningOp();
244 if (!op) {
245 SmallVector<Value> results;
246 OpBuilder builder(value.getContext());
247 builder.setInsertionPointAfterValue(value);
248 comb::extractBits(builder, value, results);
249 return insertBits(value, std::move(results));
250 }
251
252 return TypeSwitch<Operation *, ArrayRef<Value>>(op)
253 .Case<aig::AndInverterOp, mig::MajorityInverterOp>(
254 [&](auto op) { return lowerInvertibleOperations(op); })
255 .Case<comb::AndOp, comb::OrOp, comb::XorOp>(
256 [&](auto op) { return lowerCombLogicOperations(op); })
257 .Case<comb::MuxOp>([&](comb::MuxOp op) { return lowerCombMux(op); })
258 .Default([&](auto op) {
259 OpBuilder builder(value.getContext());
260 builder.setInsertionPoint(op);
261 SmallVector<Value> results;
262 comb::extractBits(builder, value, results);
263
264 return insertBits(value, std::move(results));
265 });
266}
267
268LogicalResult BitBlaster::run() {
269 // Topologically sort operations in graph regions so that walk visits them in
270 // the topological order.
272 moduleOp, [](Value value, Operation *op) -> bool {
273 // Otherthan target ops, all other ops are always ready.
274 return !(shouldLowerOperation(op) ||
275 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp,
276 comb::ReplicateOp>(op));
277 }))) {
278 // If we failed to topologically sort operations we cannot proceed.
279 return mlir::emitError(moduleOp.getLoc(), "there is a combinational cycle");
280 }
281
282 // Lower target operations
283 moduleOp.walk([&](Operation *op) {
284 // If the block is in a graph region, topologically sort it first.
285 if (shouldLowerOperation(op))
286 (void)lowerValueToBits(op->getResult(0));
287 });
288
289 // Replace operations with concatenated results if needed
290 for (auto &[value, results] :
291 llvm::make_early_inc_range(llvm::reverse(loweredValues))) {
292 if (hw::getBitWidth(value.getType()) <= 1)
293 continue;
294
295 auto *op = value.getDefiningOp();
296 if (!op)
297 continue;
298
299 if (value.use_empty()) {
300 op->erase();
301 continue;
302 }
303
304 // If a target operation still has an use (e.g. connected to output or
305 // instance), replace the value with the concatenated result.
306 if (shouldLowerOperation(op)) {
307 OpBuilder builder(op);
308 std::reverse(results.begin(), results.end());
309 auto concat = comb::ConcatOp::create(builder, value.getLoc(), results);
310 value.replaceAllUsesWith(concat);
311 op->erase();
312 }
313 }
314
315 return success();
316}
317
318Value BitBlaster::getBoolConstant(bool value) {
319 if (!constants[value]) {
320 auto builder = OpBuilder::atBlockBegin(moduleOp.getBodyBlock());
321 constants[value] = hw::ConstantOp::create(builder, builder.getUnknownLoc(),
322 builder.getI1Type(), value);
323 }
324 return constants[value];
325}
326
327template <typename OpTy>
328ArrayRef<Value> BitBlaster::lowerInvertibleOperations(OpTy op) {
329 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
330 return builder.createOrFold<OpTy>(op.getLoc(), operands, op.getInverted());
331 };
332 return lowerOp(op, createOp);
333}
334
335template <typename OpTy>
336ArrayRef<Value> BitBlaster::lowerCombLogicOperations(OpTy op) {
337 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
338 return builder.createOrFold<OpTy>(op.getLoc(), operands,
339 op.getTwoStateAttr());
340 };
341 return lowerOp(op, createOp);
342}
343
344ArrayRef<Value> BitBlaster::lowerCombMux(comb::MuxOp op) {
345 auto createOp = [&](OpBuilder &builder, ValueRange operands) {
346 return builder.createOrFold<comb::MuxOp>(op.getLoc(), operands[0],
347 operands[1], operands[2],
348 op.getTwoStateAttr());
349 };
350 return lowerOp(op, createOp);
351}
352
353ArrayRef<Value> BitBlaster::lowerOp(
354 Operation *op,
355 llvm::function_ref<Value(OpBuilder &builder, ValueRange)> createOp) {
356 auto value = op->getResult(0);
357 OpBuilder builder(op);
358 auto width = hw::getBitWidth(value.getType());
359 assert(width > 1 && "expected multi-bit operation");
360
361 auto known = computeKnownBits(value);
362 APInt knownMask = known.Zero | known.One;
363
364 // Update statistics
365 numLoweredConstants += knownMask.popcount();
366 numLoweredBits += width;
367 ++numLoweredOps;
368
369 SmallVector<Value> results;
370 results.reserve(width);
371
372 for (int64_t i = 0; i < width; ++i) {
373 SmallVector<Value> operands;
374 operands.reserve(op->getNumOperands());
375 if (knownMask[i]) {
376 // Use known constant value
377 results.push_back(getBoolConstant(known.One[i]));
378 continue;
379 }
380
381 // Extract the i-th bit from each operand
382 for (auto operand : op->getOperands())
383 operands.push_back(extractBit(operand, i));
384
385 // Create the single-bit operation
386 auto result = createOp(builder, operands);
387 results.push_back(result);
388
389 // Add name hint if present
390 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint")) {
391 auto newName = StringAttr::get(
392 op->getContext(), name.getValue() + "[" + std::to_string(i) + "]");
393 if (auto *loweredOp = result.getDefiningOp())
394 loweredOp->setAttr("sv.namehint", newName);
395 }
396 }
397
398 assert(results.size() == static_cast<size_t>(width));
399 return insertBits(value, std::move(results));
400}
401
402//===----------------------------------------------------------------------===//
403// Pass Implementation
404//===----------------------------------------------------------------------===//
405
406namespace {
407struct LowerWordToBitsPass
408 : public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
409 void runOnOperation() override;
410};
411} // namespace
412
413void LowerWordToBitsPass::runOnOperation() {
414 BitBlaster driver(getOperation());
415 if (failed(driver.run()))
416 return signalPassFailure();
417
418 // Update statistics
419 numLoweredBits += driver.numLoweredBits;
420 numLoweredConstants += driver.numLoweredConstants;
421 numLoweredOps += driver.numLoweredOps;
422}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool shouldLowerOperation(Operation *op)
Check if an operation should be lowered to bit-level operations.
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:306
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:121
Definition synth.py:1