CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerVariadic.cpp
Go to the documentation of this file.
//===- LowerVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers variadic operations to binary operations using a
// delay-aware algorithm for commutative operations.
//
//===----------------------------------------------------------------------===//
13
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>
33
34#define DEBUG_TYPE "synth-lower-variadic"
35
36namespace circt {
37namespace synth {
38#define GEN_PASS_DEF_LOWERVARIADIC
39#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
40} // namespace synth
41} // namespace circt
42
43using namespace circt;
44using namespace synth;
45
46//===----------------------------------------------------------------------===//
47// Lower Variadic pass
48//===----------------------------------------------------------------------===//
49
50namespace {
51
52struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
53 using LowerVariadicBase::LowerVariadicBase;
54 void runOnOperation() override;
55};
56
57} // namespace
58
59/// Construct a balanced binary tree from a variadic operation using a
60/// delay-aware algorithm. This function builds the tree by repeatedly combining
61/// the two values with the earliest arrival times, which minimizes the critical
62/// path delay.
63static LogicalResult replaceWithBalancedTree(
64 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
65 Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
66 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
67 createBinaryOp) {
68 // Collect all operands with their arrival times and inversion flags
69 SmallVector<ValueWithArrivalTime> operands;
70 size_t valueNumber = 0;
71
72 for (size_t i = 0, e = op->getNumOperands(); i < e; ++i) {
73 int64_t delay = 0;
74 // If analysis is available, use it to compute the delay.
75 // If not available, use zero delay and `valueNumber` will be used instead.
76 if (analysis) {
77 auto result = analysis->getMaxDelay(op->getOperand(i));
78 if (failed(result))
79 return failure();
80 delay = *result;
81 }
82 operands.push_back(ValueWithArrivalTime(op->getOperand(i), delay,
83 isInverted(op->getOpOperand(i)),
84 valueNumber++));
85 }
86
87 // Use shared tree building utility
88 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
89 operands,
90 // Combine: create binary operation and compute new arrival time
91 [&](const ValueWithArrivalTime &lhs, const ValueWithArrivalTime &rhs) {
92 Value combined = createBinaryOp(lhs, rhs);
93 int64_t newDelay = 0;
94 if (analysis) {
95 auto delayResult = analysis->getMaxDelay(combined);
96 if (succeeded(delayResult))
97 newDelay = *delayResult;
98 }
99 return ValueWithArrivalTime(combined, newDelay, false, valueNumber++);
100 });
101
102 rewriter.replaceOp(op, result.getValue());
103 return success();
104}
105
106using OperandKey = llvm::SmallVector<std::pair<mlir::Value, bool>>;
107
108namespace llvm {
109template <>
112 // Return a vector containing the mlir::Value empty key
113 return {{DenseMapInfo<mlir::Value>::getEmptyKey(), false}};
114 }
115
117 // Return a vector containing the mlir::Value tombstone key
119 }
120
121 static unsigned getHashValue(const OperandKey &val) {
122 llvm::hash_code hash = 0;
123 // Iteratively combine the hash of each pair in the vector
124 for (const auto &pair : val) {
125 hash = llvm::hash_combine(
127 pair.second);
128 }
129 return static_cast<unsigned>(hash);
130 }
131
132 static bool isEqual(const OperandKey &lhs, const OperandKey &rhs) {
133 // std::vector and std::pair already implement operator==,
134 // which does a deep equality check of the elements.
135 return lhs == rhs;
136 }
137};
138} // namespace llvm
139
140// Struct for ordering the andInverterOp operations we have already seen
142 bool operator()(const std::pair<mlir::Value, bool> &lhs,
143 const std::pair<mlir::Value, bool> &rhs) const {
144 if (lhs.first != rhs.first) {
145 auto lhsArg = llvm::dyn_cast<mlir::BlockArgument>(lhs.first);
146 auto rhsArg = llvm::dyn_cast<mlir::BlockArgument>(rhs.first);
147 if (lhsArg && rhsArg)
148 return lhsArg.getArgNumber() < rhsArg.getArgNumber();
149 if (lhsArg)
150 return true;
151 if (rhsArg)
152 return false;
153
154 auto *lhsOp = lhs.first.getDefiningOp();
155 auto *rhsOp = rhs.first.getDefiningOp();
156 return lhsOp->isBeforeInBlock(rhsOp);
157 }
158 return lhs.second < rhs.second;
159 }
160};
161
162static OperandKey getSortedOperandKey(aig::AndInverterOp op) {
163 OperandKey key;
164 for (size_t i = 0, e = op.getNumOperands(); i < e; ++i)
165 key.emplace_back(op.getOperand(i), op.isInverted(i));
166
167 std::sort(key.begin(), key.end(), OperandPairLess());
168 return key;
169}
170
172 aig::AndInverterOp op, mlir::IRRewriter &rewriter,
173 llvm::DenseMap<OperandKey, mlir::Value> &seenExpressions) {
174
175 if (op.getNumOperands() <= 2)
176 return;
177
178 OperandKey allOperands = getSortedOperandKey(op);
179 mlir::SmallVector<Value> newValues;
180 mlir::SmallVector<bool> newInversions;
181
182 for (auto it = allOperands.begin(); it != allOperands.end(); ++it) {
183 // Look at the remaining operands from 'it' to the end
184 OperandKey remaining(it, allOperands.end());
185
186 auto match = seenExpressions.find(remaining);
187 if (match != seenExpressions.end() && match->second != op.getResult()) {
188 newValues.push_back(match->second);
189 newInversions.push_back(false);
190
191 // We found a match that covers everything from 'it' to the end,
192 // so we can stop searching.
193 break;
194 }
195
196 // No match, add it to the new list of values and inversions.
197 newValues.push_back(it->first);
198 newInversions.push_back(it->second);
199 }
200
201 if (newValues.size() < allOperands.size()) {
202 rewriter.modifyOpInPlace(op, [&]() {
203 op.getOperation()->setOperands(newValues);
204 op.setInverted(newInversions);
205 });
206 }
207}
208
209void LowerVariadicPass::runOnOperation() {
210 // Topologically sort operations in graph regions to ensure operands are
211 // defined before uses.
212 if (!mlir::sortTopologically(
213 getOperation().getBodyBlock(), [](Value val, Operation *op) -> bool {
214 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
215 return isa<hw::InstanceOp>(op);
216 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
217 op->getDialect());
218 })) {
219 mlir::emitError(getOperation().getLoc())
220 << "Failed to topologically sort graph region blocks";
221 return signalPassFailure();
222 }
223
224 // Get longest path analysis if timing-aware lowering is enabled.
225 synth::IncrementalLongestPathAnalysis *analysis = nullptr;
226 if (timingAware.getValue())
227 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
228
229 auto moduleOp = getOperation();
230
231 // Build set of operation names to lower if specified.
232 SmallVector<OperationName> names;
233 for (const auto &name : opNames)
234 names.push_back(OperationName(name, &getContext()));
235
236 // Return true if the operation should be lowered.
237 auto shouldLower = [&](Operation *op) {
238 // If no names specified, lower all variadic ops.
239 if (names.empty())
240 return true;
241 return llvm::find(names, op->getName()) != names.end();
242 };
243
244 mlir::IRRewriter rewriter(&getContext());
245 rewriter.setListener(analysis);
246
247 // Simplify exising andInverterOps by reusing operations.
248 if (reuseSubsets) {
249 llvm::DenseMap<OperandKey, mlir::Value> seenExpressions;
250 // First collect all the andInverterOp operations in the block.
251 for (auto &op : moduleOp.getBodyBlock()->getOperations()) {
252 if (auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
253 OperandKey key = getSortedOperandKey(andInverterOp);
254 seenExpressions[key] = andInverterOp.getResult();
255 }
256 }
257 // Now try to replace operations with subsets.
258 for (auto &op : moduleOp.getBodyBlock()->getOperations()) {
259 if (auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
260 simplifyWithExistingOperations(andInverterOp, rewriter,
261 seenExpressions);
262 }
263 }
264 }
265
266 // FIXME: Currently only top-level operations are lowered due to the lack of
267 // topological sorting in across nested regions.
268 for (auto &opRef :
269 llvm::make_early_inc_range(moduleOp.getBodyBlock()->getOperations())) {
270 auto *op = &opRef;
271 // Skip operations that don't need lowering or are already binary.
272 if (!shouldLower(op) || op->getNumOperands() <= 2)
273 continue;
274
275 rewriter.setInsertionPoint(op);
276
277 // Handle AndInverterOp specially to preserve inversion flags.
278 if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
279 auto result = replaceWithBalancedTree(
280 analysis, rewriter, op,
281 // Check if each operand is inverted.
282 [&](OpOperand &operand) {
283 return andInverterOp.isInverted(operand.getOperandNumber());
284 },
285 // Create binary AndInverterOp with inversion flags.
286 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
287 return aig::AndInverterOp::create(
288 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
289 lhs.isInverted(), rhs.isInverted());
290 });
291 if (failed(result))
292 return signalPassFailure();
293 continue;
294 }
295
296 // Handle commutative operations (and, or, xor, mul, add, etc.) using
297 // delay-aware lowering to minimize critical path.
298 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
299 op->hasTrait<OpTrait::IsCommutative>()) {
300 auto result = replaceWithBalancedTree(
301 analysis, rewriter, op,
302 // No inversion flags for standard commutative operations.
303 [](OpOperand &) { return false; },
304 // Create binary operation with the same operation type.
305 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
306 OperationState state(op->getLoc(), op->getName());
307 state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
308 state.addTypes(op->getResult(0).getType());
309 auto *newOp = Operation::create(state);
310 rewriter.insert(newOp);
311 return newOp->getResult(0);
312 });
313 if (failed(result))
314 return signalPassFailure();
315 }
316 }
317}
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static void simplifyWithExistingOperations(aig::AndInverterOp op, mlir::IRRewriter &rewriter, llvm::DenseMap< OperandKey, mlir::Value > &seenExpressions)
llvm::SmallVector< std::pair< mlir::Value, bool > > OperandKey
static OperandKey getSortedOperandKey(aig::AndInverterOp op)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
bool operator()(const std::pair< mlir::Value, bool > &lhs, const std::pair< mlir::Value, bool > &rhs) const
static unsigned getHashValue(const OperandKey &val)
static bool isEqual(const OperandKey &lhs, const OperandKey &rhs)