CIRCT 23.0.0git
Loading...
Searching...
No Matches
SOPBalancing.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements SOP (Sum-of-Products) balancing for delay optimization
10// based on "Delay Optimization Using SOP Balancing" by Mishchenko et al.
11// (ICCAD 2011).
12//
13// NOTE: Currently supports AIG but should be extended to other logic forms
14// (MIG, comb or/and) in the future.
15//
16//===----------------------------------------------------------------------===//
17
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/DenseMap.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Debug.h"
29
30#define DEBUG_TYPE "synth-sop-balancing"
31
32using namespace circt;
33using namespace circt::synth;
34using namespace mlir;
35
36namespace circt {
37namespace synth {
38#define GEN_PASS_DEF_SOPBALANCING
39#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
40} // namespace synth
41} // namespace circt
42
43using namespace circt::synth;
44
45//===----------------------------------------------------------------------===//
46// SOP Cache
47//===----------------------------------------------------------------------===//
48
49namespace {
50
/// Typical upper bound on the number of ISOP inputs. It is only used as the
/// inline-capacity hint of SmallVectors so that common cut sizes avoid a heap
/// allocation; larger inputs are still handled, they simply spill to the heap.
constexpr unsigned expectedISOPInputs{8};
54
55/// Cache for SOP extraction results, keyed by truth table.
56class SOPCache {
57public:
58 const SOPForm &getOrCompute(const APInt &truthTable, unsigned numVars) {
59 auto it = cache.find(truthTable);
60 if (it != cache.end())
61 return it->second;
62 return cache
63 .try_emplace(truthTable, circt::extractISOP(truthTable, numVars))
64 .first->second;
65 }
66
67private:
68 DenseMap<APInt, SOPForm> cache;
69};
70
71//===----------------------------------------------------------------------===//
72// Tree Building Helpers
73//===----------------------------------------------------------------------===//
74
75/// Simulate balanced tree and return output arrival time.
76static DelayType simulateBalancedTree(ArrayRef<DelayType> arrivalTimes) {
77 if (arrivalTimes.empty())
78 return 0;
79 return buildBalancedTreeWithArrivalTimes<DelayType>(
80 arrivalTimes, [](auto a, auto b) { return std::max(a, b) + 1; });
81}
82
83/// Build balanced AND tree.
85buildBalancedAndTree(OpBuilder &builder, Location loc,
86 SmallVectorImpl<ValueWithArrivalTime> &nodes,
87 size_t &valueNumbering) {
88 assert(!nodes.empty());
89
90 if (nodes.size() == 1)
91 return nodes[0];
92
93 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
94 nodes, [&](const auto &n1, const auto &n2) {
95 Value v = aig::AndInverterOp::create(builder, loc, n1.getValue(),
96 n2.getValue(), n1.isInverted(),
97 n2.isInverted());
99 v, std::max(n1.getArrivalTime(), n2.getArrivalTime()) + 1, false,
100 valueNumbering++);
101 });
102 return result;
103}
104
105/// Build balanced SOP structure.
106Value buildBalancedSOP(OpBuilder &builder, Location loc, const SOPForm &sop,
107 ArrayRef<Value> inputs,
108 ArrayRef<DelayType> inputArrivalTimes) {
109 SmallVector<ValueWithArrivalTime, expectedISOPInputs> productTerms, literals;
110 size_t valueNumbering = 0;
111
112 for (const auto &cube : sop.cubes) {
113 for (unsigned i = 0; i < sop.numVars; ++i)
114 if (cube.hasLiteral(i))
115 literals.push_back(ValueWithArrivalTime(
116 inputs[i], inputArrivalTimes[i], cube.isLiteralInverted(i),
117 /*valueNumbering=*/valueNumbering++));
118
119 if (literals.empty())
120 continue;
121
122 // Get product term, and flip the inversion to construct OR afterwards.
123 productTerms.push_back(
124 buildBalancedAndTree(builder, loc, literals, valueNumbering)
125 .flipInversion());
126
127 literals.clear();
128 }
129
130 assert(!productTerms.empty() && "No product terms");
131
132 auto andOfInverted =
133 buildBalancedAndTree(builder, loc, productTerms, valueNumbering)
134 .flipInversion();
135 // Let's invert the output.
136 if (andOfInverted.isInverted())
137 return aig::AndInverterOp::create(builder, loc, andOfInverted.getValue(),
138 true);
139 return andOfInverted.getValue();
140}
141
142/// Compute SOP delays for cost estimation.
143void computeSOPDelays(const SOPForm &sop, ArrayRef<DelayType> inputArrivalTimes,
144 SmallVectorImpl<DelayType> &delays) {
145 SmallVector<DelayType, expectedISOPInputs> productArrivalTimes, literalTimes;
146 for (const auto &cube : sop.cubes) {
147 for (unsigned i = 0; i < sop.numVars; ++i)
148 // No need to consider inverted literals separately for delay.
149 if (cube.hasLiteral(i))
150 literalTimes.push_back(inputArrivalTimes[i]);
151 if (!literalTimes.empty()) {
152 productArrivalTimes.push_back(simulateBalancedTree(literalTimes));
153 literalTimes.clear();
154 }
155 }
156
157 DelayType outputTime = simulateBalancedTree(productArrivalTimes);
158
159 delays.resize(sop.numVars, 0);
160 // Compute the delay contribution of each input to the output for cost
161 // estimation. The CutRewriter framework requires per-input delays, even
162 // though this is somewhat artificial for SOP balancing. This may be
163 // improved in future framework improvements.
164 //
165 // First, determine which variables are actually used in the SOP by
166 // collecting a bitmask from all cubes.
167 uint64_t mask = 0;
168 for (auto &cube : sop.cubes)
169 mask |= cube.mask;
170
171 // Compute delay for each used input variable.
172 for (unsigned i = 0; i < sop.numVars; ++i)
173 if (mask & (1u << i))
174 delays[i] = outputTime - inputArrivalTimes[i];
175}
176
177//===----------------------------------------------------------------------===//
178// SOP Balancing Pattern
179//===----------------------------------------------------------------------===//
180
181/// Pattern that performs SOP balancing on cuts.
182struct SOPBalancingPattern : public CutRewritePattern {
183 SOPBalancingPattern(MLIRContext *context) : CutRewritePattern(context) {}
184
185 std::optional<MatchResult> match(CutEnumerator &enumerator,
186 const Cut &cut) const override {
187 const auto &network = enumerator.getLogicNetwork();
188 if (cut.isTrivialCut() || cut.getOutputSize(network) != 1)
189 return std::nullopt;
190
191 const auto &tt = *cut.getTruthTable();
192 const SOPForm &sop = sopCache.getOrCompute(tt.table, tt.numInputs);
193 if (sop.cubes.empty())
194 return std::nullopt;
195
196 SmallVector<DelayType, expectedISOPInputs> arrivalTimes;
197 if (failed(cut.getInputArrivalTimes(enumerator, arrivalTimes)))
198 return std::nullopt;
199
200 // Compute area estimate
201 unsigned totalGates = 0;
202 for (const auto &cube : sop.cubes)
203 if (cube.size() > 1)
204 totalGates += cube.size() - 1;
205 if (sop.cubes.size() > 1)
206 totalGates += sop.cubes.size() - 1;
207
208 SmallVector<DelayType, expectedISOPInputs> delays;
209 computeSOPDelays(sop, arrivalTimes, delays);
210
211 MatchResult result;
212 result.area = static_cast<double>(totalGates);
213 result.setOwnedDelays(std::move(delays));
214 return result;
215 }
216
217 FailureOr<Operation *> rewrite(OpBuilder &builder, CutEnumerator &enumerator,
218 const Cut &cut) const override {
219 const auto &network = enumerator.getLogicNetwork();
220 const auto &tt = *cut.getTruthTable();
221 const SOPForm &sop = sopCache.getOrCompute(tt.table, tt.numInputs);
222 LLVM_DEBUG({
223 llvm::dbgs() << "Rewriting SOP:\n";
224 sop.dump(llvm::dbgs());
225 });
226
227 SmallVector<DelayType, expectedISOPInputs> arrivalTimes;
228 if (failed(cut.getInputArrivalTimes(enumerator, arrivalTimes)))
229 return failure();
230
231 // Construct the fused location.
232 SetVector<Location> inputLocs;
233 auto *rootOp = network.getGate(cut.getRootIndex()).getOperation();
234 assert(rootOp && "cut root must be a valid operation");
235 inputLocs.insert(rootOp->getLoc());
236
237 SmallVector<Value> inputValues;
238 network.getValues(cut.inputs, inputValues);
239 for (auto input : inputValues)
240 inputLocs.insert(input.getLoc());
241
242 auto loc = builder.getFusedLoc(inputLocs.getArrayRef());
243
244 Value result =
245 buildBalancedSOP(builder, loc, sop, inputValues, arrivalTimes);
246
247 auto *op = result.getDefiningOp();
248 if (!op)
249 op = aig::AndInverterOp::create(builder, loc, result, false);
250 return op;
251 }
252
253 unsigned getNumOutputs() const override { return 1; }
254 StringRef getPatternName() const override { return "sop-balancing"; }
255
256private:
257 // Cache for SOP extraction results. Hence the pattern is stateful and must
258 // not be used in parallelly.
259 mutable SOPCache sopCache;
260};
261
262} // namespace
263
264//===----------------------------------------------------------------------===//
265// SOP Balancing Pass
266//===----------------------------------------------------------------------===//
267
269 : public circt::synth::impl::SOPBalancingBase<SOPBalancingPass> {
270 using SOPBalancingBase::SOPBalancingBase;
271
272 void runOnOperation() override {
273 auto module = getOperation();
274
275 CutRewriterOptions options;
276 options.strategy = strategy;
277 options.maxCutInputSize = maxCutInputSize;
278 options.maxCutSizePerRoot = maxCutsPerRoot;
279 options.allowNoMatch = true;
280
281 SmallVector<std::unique_ptr<CutRewritePattern>, 1> patterns;
282 patterns.push_back(
283 std::make_unique<SOPBalancingPattern>(module->getContext()));
284
285 CutRewritePatternSet patternSet(std::move(patterns));
286 CutRewriter rewriter(options, patternSet);
287 if (failed(rewriter.run(module)))
288 return signalPassFailure();
289 }
290};
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
Strategy strategy
Cut enumeration engine for combinational logic networks.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
Manages a collection of rewriting patterns for combinational logic optimization.
Main cut-based rewriting algorithm for combinational logic optimization.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
Represents a cut in the combinational logic network.
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
uint32_t getRootIndex() const
Get the root index in the LogicNetwork.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Helper class for delay-aware tree building.
Definition SynthOps.h:58
int64_t DelayType
Definition CutRewriter.h:39
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
SOPForm extractISOP(const llvm::APInt &truthTable, unsigned numVars)
Extract ISOP (Irredundant Sum-of-Products) from a truth table.
Definition synth.py:1
void runOnOperation() override
Represents a sum-of-products expression.
Definition TruthTable.h:250
unsigned numVars
Definition TruthTable.h:252
void dump(llvm::raw_ostream &os=llvm::errs()) const
Debug dump method for SOP forms.
llvm::SmallVector< Cube > cubes
Definition TruthTable.h:251
Base class for cut rewriting patterns used in combinational logic optimization.
virtual StringRef getPatternName() const
Get the name of this pattern. Used for debugging.
virtual FailureOr< Operation * > rewrite(mlir::OpBuilder &builder, CutEnumerator &enumerator, const Cut &cut) const =0
Return a new operation that replaces the matched cut.
virtual std::optional< MatchResult > match(CutEnumerator &enumerator, const Cut &cut) const =0
Check if a cut matches this pattern and compute area/delay metrics.
virtual unsigned getNumOutputs() const =0
Get the number of outputs this pattern produces.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when false, such an operation makes the rewrite fail.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
Result of matching a cut against a pattern.
void setOwnedDelays(SmallVector< DelayType, 6 > delays)
Set delays by transferring ownership (for dynamically computed delays).
double area
Area cost of implementing this cut with the pattern.