CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerVariadic.cpp
Go to the documentation of this file.
1//===- LowerVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers variadic operations to binary operations using a
10// delay-aware algorithm for commutative operations.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/Block.h"
23#include "mlir/IR/OpDefinition.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Support/LLVM.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/Error.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/raw_ostream.h"
32#include <iterator>
33#include <vector>
34
35#define DEBUG_TYPE "synth-lower-variadic"
36
37namespace circt {
38namespace synth {
39#define GEN_PASS_DEF_LOWERVARIADIC
40#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
41} // namespace synth
42} // namespace circt
43
44using namespace circt;
45using namespace synth;
46
47//===----------------------------------------------------------------------===//
48// Lower Variadic pass
49//===----------------------------------------------------------------------===//
50
51namespace {
52
/// Pass implementation that lowers variadic operations in a single-block
/// module body to trees of binary operations. The actual work happens in
/// runOnOperation(); options (timingAware, opNames, reuseSubsets) come from
/// the tablegen-generated LowerVariadicBase.
struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
  using LowerVariadicBase::LowerVariadicBase;
  void runOnOperation() override;
};
57
58} // namespace
59
60/// Construct a balanced binary tree from a variadic operation using a
61/// delay-aware algorithm. This function builds the tree by repeatedly combining
62/// the two values with the earliest arrival times, which minimizes the critical
63/// path delay.
64static LogicalResult replaceWithBalancedTree(
65 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
66 Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
67 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
68 createBinaryOp) {
69 // Collect all operands with their arrival times and inversion flags
70 SmallVector<ValueWithArrivalTime> operands;
71 size_t valueNumber = 0;
72
73 for (size_t i = 0, e = op->getNumOperands(); i < e; ++i) {
74 int64_t delay = 0;
75 // If analysis is available, use it to compute the delay.
76 // If not available, use zero delay and `valueNumber` will be used instead.
77 if (analysis) {
78 auto result = analysis->getMaxDelay(op->getOperand(i));
79 if (failed(result))
80 return failure();
81 delay = *result;
82 }
83 operands.push_back(ValueWithArrivalTime(op->getOperand(i), delay,
84 isInverted(op->getOpOperand(i)),
85 valueNumber++));
86 }
87
88 // Use shared tree building utility
89 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
90 operands,
91 // Combine: create binary operation and compute new arrival time
92 [&](const ValueWithArrivalTime &lhs, const ValueWithArrivalTime &rhs) {
93 Value combined = createBinaryOp(lhs, rhs);
94 int64_t newDelay = 0;
95 if (analysis) {
96 auto delayResult = analysis->getMaxDelay(combined);
97 if (succeeded(delayResult))
98 newDelay = *delayResult;
99 }
100 return ValueWithArrivalTime(combined, newDelay, false, valueNumber++);
101 });
102
103 rewriter.replaceOp(op, result.getValue());
104 return success();
105}
106
107using OperandKey = llvm::SmallVector<std::pair<mlir::Value, bool>>;
108
109namespace llvm {
110template <>
113 // Return a vector containing the mlir::Value empty key
114 return {{DenseMapInfo<mlir::Value>::getEmptyKey(), false}};
115 }
116
118 // Return a vector containing the mlir::Value tombstone key
120 }
121
122 static unsigned getHashValue(const OperandKey &val) {
123 llvm::hash_code hash = 0;
124 // Iteratively combine the hash of each pair in the vector
125 for (const auto &pair : val) {
126 hash = llvm::hash_combine(
128 pair.second);
129 }
130 return static_cast<unsigned>(hash);
131 }
132
133 static bool isEqual(const OperandKey &lhs, const OperandKey &rhs) {
134 // std::vector and std::pair already implement operator==,
135 // which does a deep equality check of the elements.
136 return lhs == rhs;
137 }
138};
139} // namespace llvm
140
141// Struct for ordering the andInverterOp operations we have already seen
143 bool operator()(const std::pair<mlir::Value, bool> &lhs,
144 const std::pair<mlir::Value, bool> &rhs) const {
145 if (lhs.first != rhs.first) {
146 auto lhsArg = llvm::dyn_cast<mlir::BlockArgument>(lhs.first);
147 auto rhsArg = llvm::dyn_cast<mlir::BlockArgument>(rhs.first);
148 if (lhsArg && rhsArg)
149 return lhsArg.getArgNumber() < rhsArg.getArgNumber();
150 if (lhsArg)
151 return true;
152 if (rhsArg)
153 return false;
154
155 auto *lhsOp = lhs.first.getDefiningOp();
156 auto *rhsOp = rhs.first.getDefiningOp();
157 return lhsOp->isBeforeInBlock(rhsOp);
158 }
159 return lhs.second < rhs.second;
160 }
161};
162
163static OperandKey getSortedOperandKey(aig::AndInverterOp op) {
164 OperandKey key;
165 for (size_t i = 0, e = op.getNumOperands(); i < e; ++i)
166 key.emplace_back(op.getOperand(i), op.isInverted(i));
167
168 std::sort(key.begin(), key.end(), OperandPairLess());
169 return key;
170}
171
173 aig::AndInverterOp op, mlir::IRRewriter &rewriter,
174 llvm::DenseMap<OperandKey, mlir::Value> &seenExpressions) {
175
176 if (op.getNumOperands() <= 2)
177 return;
178
179 OperandKey allOperands = getSortedOperandKey(op);
180 mlir::SmallVector<Value> newValues;
181 mlir::SmallVector<bool> newInversions;
182
183 for (auto it = allOperands.begin(); it != allOperands.end(); ++it) {
184 // Look at the remaining operands from 'it' to the end
185 OperandKey remaining(it, allOperands.end());
186
187 auto match = seenExpressions.find(remaining);
188 if (match != seenExpressions.end() && match->second != op.getResult()) {
189 newValues.push_back(match->second);
190 newInversions.push_back(false);
191
192 // We found a match that covers everything from 'it' to the end,
193 // so we can stop searching.
194 break;
195 }
196
197 // No match, add it to the new list of values and inversions.
198 newValues.push_back(it->first);
199 newInversions.push_back(it->second);
200 }
201
202 if (newValues.size() < allOperands.size()) {
203 rewriter.modifyOpInPlace(op, [&]() {
204 op.getOperation()->setOperands(newValues);
205 op.setInverted(newInversions);
206 });
207 }
208}
209
210void LowerVariadicPass::runOnOperation() {
211 if (getOperation()->getNumRegions() != 1 ||
212 getOperation()->getRegion(0).getBlocks().size() != 1)
213 return;
214 // Topologically sort operations in graph regions to ensure operands are
215 // defined before uses.
216 mlir::Block &bodyBlock = getOperation()->getRegion(0).getBlocks().front();
217 auto *moduleOp = getOperation();
218
219 if (!mlir::sortTopologically(
220 &bodyBlock, [](Value val, Operation *op) -> bool {
221 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
222 return isa<hw::InstanceOp>(op);
223 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
224 op->getDialect());
225 })) {
226 mlir::emitError(moduleOp->getLoc())
227 << "Failed to topologically sort graph region blocks";
228 return signalPassFailure();
229 }
230
231 // Get longest path analysis if timing-aware lowering is enabled.
232 synth::IncrementalLongestPathAnalysis *analysis = nullptr;
233 if (timingAware.getValue()) {
234 if (!dyn_cast<hw::HWModuleOp>(moduleOp)) {
235 moduleOp->emitWarning(
236 "Longest Path Analysis failed: expected 'hw.module', but found '")
237 << moduleOp->getName().getStringRef()
238 << "'. Only HWModuleOps are currently supported.";
239 } else {
240 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
241 }
242 }
243
244 // Build set of operation names to lower if specified.
245 SmallVector<OperationName> names;
246 for (const auto &name : opNames)
247 names.push_back(OperationName(name, &getContext()));
248
249 // Return true if the operation should be lowered.
250 auto shouldLower = [&](Operation *op) {
251 // If no names specified, lower all variadic ops.
252 if (names.empty())
253 return true;
254 return llvm::find(names, op->getName()) != names.end();
255 };
256
257 mlir::IRRewriter rewriter(&getContext());
258 rewriter.setListener(analysis);
259
260 // Simplify exising andInverterOps by reusing operations.
261 if (reuseSubsets) {
262 llvm::DenseMap<OperandKey, mlir::Value> seenExpressions;
263 // First collect all the andInverterOp operations in the block.
264 for (auto &op : bodyBlock.getOperations()) {
265 if (auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
266 OperandKey key = getSortedOperandKey(andInverterOp);
267 seenExpressions[key] = andInverterOp.getResult();
268 }
269 }
270 // Now try to replace operations with subsets.
271 for (auto &op : bodyBlock.getOperations()) {
272 if (auto andInverterOp = llvm::dyn_cast<aig::AndInverterOp>(op)) {
273 simplifyWithExistingOperations(andInverterOp, rewriter,
274 seenExpressions);
275 }
276 }
277 }
278
279 // FIXME: Currently only top-level operations are lowered due to the lack of
280 // topological sorting in across nested regions.
281 for (auto &opRef : llvm::make_early_inc_range(bodyBlock.getOperations())) {
282 auto *op = &opRef;
283 // Skip operations that don't need lowering or are already binary.
284 if (!shouldLower(op) || op->getNumOperands() <= 2)
285 continue;
286
287 rewriter.setInsertionPoint(op);
288
289 // Handle invertible Synth ops specially to preserve inversion flags.
290 auto result =
291 mlir::TypeSwitch<Operation *, LogicalResult>(op)
292 .Case<aig::AndInverterOp, XorInverterOp>([&](auto op) {
294 analysis, rewriter, op,
295 // Check if each operand is inverted.
296 [&](OpOperand &operand) {
297 return op.isInverted(operand.getOperandNumber());
298 },
299 // Create binary AndInverterOp with inversion flags.
300 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
301 return decltype(op)::create(
302 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
303 lhs.isInverted(), rhs.isInverted());
304 });
305 })
306 .Default([&](Operation *op) {
307 // Handle commutative operations (and, or, xor, mul, add, etc.)
308 // using delay-aware lowering to minimize critical path.
309 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
310 op->hasTrait<OpTrait::IsCommutative>())
312 analysis, rewriter, op,
313 // No inversion flags for standard commutative operations.
314 [](OpOperand &) { return false; },
315 // Create binary operation with the same operation type.
316 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
317 OperationState state(op->getLoc(), op->getName());
318 state.addOperands(
319 ValueRange{lhs.getValue(), rhs.getValue()});
320 state.addTypes(op->getResult(0).getType());
321 auto *newOp = Operation::create(state);
322 rewriter.insert(newOp);
323 return newOp->getResult(0);
324 });
325 return success();
326 });
327 if (failed(result))
328 return signalPassFailure();
329 }
330}
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static void simplifyWithExistingOperations(aig::AndInverterOp op, mlir::IRRewriter &rewriter, llvm::DenseMap< OperandKey, mlir::Value > &seenExpressions)
llvm::SmallVector< std::pair< mlir::Value, bool > > OperandKey
static OperandKey getSortedOperandKey(aig::AndInverterOp op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1
bool operator()(const std::pair< mlir::Value, bool > &lhs, const std::pair< mlir::Value, bool > &rhs) const
static unsigned getHashValue(const OperandKey &val)
static bool isEqual(const OperandKey &lhs, const OperandKey &rhs)