20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Pass/Pass.h"
22 #include "llvm/Support/Debug.h"
24 #define DEBUG_TYPE "arc-mux-to-control-flow"
28 #define GEN_PASS_DEF_MUXTOCONTROLFLOW
29 #include "circt/Dialect/Arc/ArcPasses.h.inc"
33 using namespace circt;
43 struct MuxToControlFlowPass
44 :
public arc::impl::MuxToControlFlowBase<MuxToControlFlowPass> {
45 MuxToControlFlowPass() =
default;
46 MuxToControlFlowPass(
const MuxToControlFlowPass &pass)
47 : MuxToControlFlowPass() {}
49 void runOnOperation()
override;
51 Statistic numMuxesConverted{
52 this,
"num-muxes-converted",
53 "Number of muxes that were converted to if-statements"};
54 Statistic numMuxesRetained{
this,
"num-muxes-retained",
55 "Number of muxes that were not converted"};
60 BranchInfo() =
default;
61 BranchInfo(Value condition, Value trueValue, Value falseValue)
62 : condition(condition), trueValue(trueValue), falseValue(falseValue) {}
68 operator bool() {
return condition && trueValue && falseValue; }
81 SmallPtrSetImpl<Operation *> &visited) {
82 for (
auto res : curr->getResults()) {
83 for (
auto *user : res.getUsers()) {
87 if (!visited.contains(user) && user != mux)
92 if (user == mux && res != useValue)
95 if (user->getBlock() != curr->getBlock())
106 SmallPtrSetImpl<Operation *> &visited) {
107 auto *op = useValue.getDefiningOp();
111 SmallVector<Operation *> worklist{op};
113 while (!worklist.empty()) {
114 auto *curr = worklist.front();
115 worklist.erase(worklist.begin());
117 if (visited.contains(curr))
123 visited.insert(curr);
125 for (
auto val : curr->getOperands()) {
126 if (
auto *defOp = val.getDefiningOp())
127 worklist.push_back(defOp);
140 [[maybe_unused]]
static void
144 for (
auto &op : llvm::reverse(*ifOp->getBlock())) {
145 if (op.getNumResults() == 0)
149 SmallVector<Operation *> users;
150 for (
auto result : op.getResults())
151 users.append(llvm::to_vector(result.getUsers()));
153 auto parentsOfUsers =
154 llvm::map_range(users, [](
auto user) {
return user->getParentOp(); });
156 auto allUsersNestedInIf = llvm::any_of(parentsOfUsers, [&](
auto *parent) {
157 return !(isa<mlir::scf::IfOp>(parent) &&
158 parent->getBlock() == op.getBlock());
163 if (allUsersNestedInIf || !llvm::all_equal(parentsOfUsers))
166 DenseMap<Region *, Value> cloneMap;
167 for (
auto &use : llvm::make_early_inc_range(op.getUses())) {
168 auto *parentRegion = use.getOwner()->getParentRegion();
169 if (!cloneMap.count(parentRegion)) {
170 OpBuilder builder(&parentRegion->front().front());
171 cloneMap[parentRegion] = builder.clone(op)->getResult(0);
173 use.set(cloneMap[parentRegion]);
181 const SmallPtrSetImpl<Operation *> &thenOps,
182 const SmallPtrSetImpl<Operation *> &elseOps) {
183 if (op->getNumResults() != 1)
187 ImplicitLocOpBuilder builder(op->getLoc(), op);
188 mlir::scf::IfOp ifOp = builder.create<mlir::scf::IfOp>(
190 [&](OpBuilder &builder, Location loc) {
191 builder.create<mlir::scf::YieldOp>(loc, info.trueValue);
193 [&](OpBuilder &builder, Location loc) {
194 builder.create<mlir::scf::YieldOp>(loc, info.falseValue);
197 op->getResult(0).replaceAllUsesWith(ifOp.getResult(0));
200 llvm::make_early_inc_range(op->getParentRegion()->getOps())) {
203 if (thenOps.contains(&ops))
204 ops.moveBefore(ifOp.thenBlock()->getTerminator());
208 if (elseOps.contains(&ops))
209 ops.moveBefore(ifOp.elseBlock()->getTerminator());
223 for (
auto *op : ops) {
224 if (
auto *runtimeCostIF =
225 dyn_cast<RuntimeCostEstimateDialectInterface>(op->getDialect())) {
226 cost += runtimeCostIF->getCostEstimate(op);
228 LLVM_DEBUG(llvm::dbgs() <<
"No runtime cost estimate was provided for '"
229 << op->getName() <<
"', using default of 10\n");
246 if (
auto mux = dyn_cast<comb::MuxOp>(op))
247 return BranchInfo{mux.getCond(), mux.getTrueValue(), mux.getFalseValue()};
258 const SmallPtrSetImpl<Operation *> &thenOps,
259 const SmallPtrSetImpl<Operation *> &elseOps) {
267 if (
auto parent = op->getParentOfType<mlir::scf::IfOp>()) {
268 SmallPtrSet<Operation *, 32> ifBranchOps;
270 for (
auto &nestedOp : *op->getBlock()) {
271 if (!thenOps.contains(&nestedOp) && !elseOps.contains(&nestedOp))
272 ifBranchOps.insert(&nestedOp);
280 return (thenCost >= 100 || thenCost == 0) &&
281 (elseCost >= 100 || elseCost == 0) &&
282 std::abs((
int)thenCost - (int)elseCost) >= 100;
292 void MuxToControlFlowPass::runOnOperation() {
298 SmallVector<Operation *> supportedOps;
299 getOperation()->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
303 if (isa<hw::HWModuleOp, arc::ModelOp>(op))
304 return WalkResult::skip();
307 supportedOps.push_back(op);
309 return WalkResult::advance();
315 for (
auto *op : llvm::reverse(supportedOps)) {
321 llvm::SmallPtrSet<Operation *, 32> thenOps, elseOps;
337 return std::make_unique<MuxToControlFlowPass>();
static void doConversion(Operation *op, BranchInfo info, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Perform the actual conversion.
static uint32_t getCostEstimate(const SmallPtrSetImpl< Operation * > &ops)
Simple helper to invoke the runtime cost interface for every operation in a set and sum up the costs.
static bool isBeneficialToConvert(Operation *op, const SmallPtrSetImpl< Operation * > &thenOps, const SmallPtrSetImpl< Operation * > &elseOps)
Use the cost measure of each branch to heuristically decide whether to actually perform the conversion.
static void computeFanIn(Operation *mux, Value useValue, SmallPtrSetImpl< Operation * > &visited)
Compute the set of operations that would only be used in the branch represented by useValue.
static bool isValidToProceedTraversal(Operation *mux, Operation *curr, Value useValue, SmallPtrSetImpl< Operation * > &visited)
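Check whether the fan-in traversal may proceed through curr, i.e., whether curr is used only within the branch represented by useValue (and in the same block).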
static BranchInfo getConversionInfo(Operation *op)
Convert concrete operations that should be converted to if-statements to a more abstract representation (BranchInfo) that the rest of the pass operates on.
static void cloneOpsIntoBranchesWhenUsedInBoth(mlir::scf::IfOp ifOp)
Clone ops that are used in both branches of an if-statement but not outside of it.
std::unique_ptr< mlir::Pass > createMuxToControlFlowPass()
Create a new instance of the mux-to-control-flow pass.