CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerVariadic.cpp
Go to the documentation of this file.
1//===- LowerVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass lowers variadic operations to binary operations using a
10// delay-aware algorithm for commutative operations.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/Analysis/TopologicalSortUtils.h"
22#include "mlir/IR/OpDefinition.h"
23
24#define DEBUG_TYPE "synth-lower-variadic"
25
26namespace circt {
27namespace synth {
28#define GEN_PASS_DEF_LOWERVARIADIC
29#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
30} // namespace synth
31} // namespace circt
32
33using namespace circt;
34using namespace synth;
35
36//===----------------------------------------------------------------------===//
37// Lower Variadic pass
38//===----------------------------------------------------------------------===//
39
40namespace {
41
42struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
43 using LowerVariadicBase::LowerVariadicBase;
44 void runOnOperation() override;
45};
46
47} // namespace
48
49/// Construct a balanced binary tree from a variadic operation using a
50/// delay-aware algorithm. This function builds the tree by repeatedly combining
51/// the two values with the earliest arrival times, which minimizes the critical
52/// path delay.
53static LogicalResult replaceWithBalancedTree(
54 IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
55 Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
56 llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
57 createBinaryOp) {
58 // Collect all operands with their arrival times and inversion flags
59 SmallVector<ValueWithArrivalTime> operands;
60 size_t valueNumber = 0;
61
62 for (size_t i = 0, e = op->getNumOperands(); i < e; ++i) {
63 int64_t delay = 0;
64 // If analysis is available, use it to compute the delay.
65 // If not available, use zero delay and `valueNumber` will be used instead.
66 if (analysis) {
67 auto result = analysis->getMaxDelay(op->getOperand(i));
68 if (failed(result))
69 return failure();
70 delay = *result;
71 }
72 operands.push_back(ValueWithArrivalTime(op->getOperand(i), delay,
73 isInverted(op->getOpOperand(i)),
74 valueNumber++));
75 }
76
77 // Use shared tree building utility
78 auto result = buildBalancedTreeWithArrivalTimes<ValueWithArrivalTime>(
79 operands,
80 // Combine: create binary operation and compute new arrival time
81 [&](const ValueWithArrivalTime &lhs, const ValueWithArrivalTime &rhs) {
82 Value combined = createBinaryOp(lhs, rhs);
83 int64_t newDelay = 0;
84 if (analysis) {
85 auto delayResult = analysis->getMaxDelay(combined);
86 if (succeeded(delayResult))
87 newDelay = *delayResult;
88 }
89 return ValueWithArrivalTime(combined, newDelay, false, valueNumber++);
90 });
91
92 rewriter.replaceOp(op, result.getValue());
93 return success();
94}
95
96void LowerVariadicPass::runOnOperation() {
97 // Topologically sort operations in graph regions to ensure operands are
98 // defined before uses.
99 if (!mlir::sortTopologically(
100 getOperation().getBodyBlock(), [](Value val, Operation *op) -> bool {
101 if (isa_and_nonnull<hw::HWDialect>(op->getDialect()))
102 return isa<hw::InstanceOp>(op);
103 return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
104 op->getDialect());
105 })) {
106 mlir::emitError(getOperation().getLoc())
107 << "Failed to topologically sort graph region blocks";
108 return signalPassFailure();
109 }
110
111 // Get longest path analysis if timing-aware lowering is enabled.
112 synth::IncrementalLongestPathAnalysis *analysis = nullptr;
113 if (timingAware.getValue())
114 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
115
116 auto moduleOp = getOperation();
117
118 // Build set of operation names to lower if specified.
119 SmallVector<OperationName> names;
120 for (const auto &name : opNames)
121 names.push_back(OperationName(name, &getContext()));
122
123 // Return true if the operation should be lowered.
124 auto shouldLower = [&](Operation *op) {
125 // If no names specified, lower all variadic ops.
126 if (names.empty())
127 return true;
128 return llvm::find(names, op->getName()) != names.end();
129 };
130
131 mlir::IRRewriter rewriter(&getContext());
132 rewriter.setListener(analysis);
133
134 // FIXME: Currently only top-level operations are lowered due to the lack of
135 // topological sorting in across nested regions.
136 for (auto &opRef :
137 llvm::make_early_inc_range(moduleOp.getBodyBlock()->getOperations())) {
138 auto *op = &opRef;
139 // Skip operations that don't need lowering or are already binary.
140 if (!shouldLower(op) || op->getNumOperands() <= 2)
141 continue;
142
143 rewriter.setInsertionPoint(op);
144
145 // Handle AndInverterOp specially to preserve inversion flags.
146 if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
147 auto result = replaceWithBalancedTree(
148 analysis, rewriter, op,
149 // Check if each operand is inverted.
150 [&](OpOperand &operand) {
151 return andInverterOp.isInverted(operand.getOperandNumber());
152 },
153 // Create binary AndInverterOp with inversion flags.
154 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
155 return aig::AndInverterOp::create(
156 rewriter, op->getLoc(), lhs.getValue(), rhs.getValue(),
157 lhs.isInverted(), rhs.isInverted());
158 });
159 if (failed(result))
160 return signalPassFailure();
161 continue;
162 }
163
164 // Handle commutative operations (and, or, xor, mul, add, etc.) using
165 // delay-aware lowering to minimize critical path.
166 if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
167 op->hasTrait<OpTrait::IsCommutative>()) {
168 auto result = replaceWithBalancedTree(
169 analysis, rewriter, op,
170 // No inversion flags for standard commutative operations.
171 [](OpOperand &) { return false; },
172 // Create binary operation with the same operation type.
173 [&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
174 OperationState state(op->getLoc(), op->getName());
175 state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
176 state.addTypes(op->getResult(0).getType());
177 auto *newOp = Operation::create(state);
178 rewriter.insert(newOp);
179 return newOp->getResult(0);
180 });
181 if (failed(result))
182 return signalPassFailure();
183 }
184 }
185}
static LogicalResult replaceWithBalancedTree(IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter, Operation *op, llvm::function_ref< bool(OpOperand &)> isInverted, llvm::function_ref< Value(ValueWithArrivalTime, ValueWithArrivalTime)> createBinaryOp)
Construct a balanced binary tree from a variadic operation using a delay-aware algorithm.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition synth.py:1