Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
IntRangeOptimizations.cpp
Go to the documentation of this file.
1//===- IntRangeOptimizations.cpp - Narrow ops in comb ------------*- C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Transforms/DialectConversion.h"
13#include "llvm/ADT/TypeSwitch.h"
14
15#include "mlir/Analysis/DataFlowFramework.h"
16
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21using namespace circt;
22using namespace circt::comb;
23using namespace mlir;
24using namespace mlir::dataflow;
25
26namespace circt {
27namespace comb {
28#define GEN_PASS_DEF_COMBINTRANGENARROWING
29#include "circt/Dialect/Comb/Passes.h.inc"
30} // namespace comb
31} // namespace circt
32
33/// Gather ranges for all the values in `values`. Appends to the existing
34/// vector.
35static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
36 SmallVectorImpl<ConstantIntRanges> &ranges) {
37 for (Value val : values) {
38 auto *maybeInferredRange =
39 solver.lookupState<IntegerValueRangeLattice>(val);
40 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
41 return failure();
42
43 const ConstantIntRanges &inferredRange =
44 maybeInferredRange->getValue().getValue();
45 ranges.push_back(inferredRange);
46 }
47 return success();
48}
49
50namespace {
51template <typename CombOpTy>
52struct CombOpNarrow : public OpRewritePattern<CombOpTy> {
53 CombOpNarrow(MLIRContext *context, DataFlowSolver &s)
54 : OpRewritePattern<CombOpTy>(context), solver(s) {}
55
56 LogicalResult matchAndRewrite(CombOpTy op,
57 PatternRewriter &rewriter) const override {
58
59 auto opWidth = op.getType().getIntOrFloatBitWidth();
60 if (op->getNumOperands() != 2 || op->getNumResults() != 1)
61 return rewriter.notifyMatchFailure(
62 op, "Only support binary operations with one result");
63
64 SmallVector<ConstantIntRanges> ranges;
65 if (failed(collectRanges(solver, op->getOperands(), ranges)))
66 return rewriter.notifyMatchFailure(op, "input without specified range");
67 if (failed(collectRanges(solver, op->getResults(), ranges)))
68 return rewriter.notifyMatchFailure(op, "output without specified range");
69
70 auto removeWidth = ranges[0].umax().countLeadingZeros();
71 for (const ConstantIntRanges &range : ranges) {
72 auto rangeCanRemove = range.umax().countLeadingZeros();
73 removeWidth = std::min(removeWidth, rangeCanRemove);
74 }
75 if (removeWidth == 0)
76 return rewriter.notifyMatchFailure(op, "no bits to remove");
77 if (removeWidth == opWidth)
78 return rewriter.notifyMatchFailure(
79 op, "all bits to remove - replace by zero");
80
81 // Replace operator by narrower version of itself
82 Value lhs = op.getOperand(0);
83 Value rhs = op.getOperand(1);
84
85 Location loc = op.getLoc();
86 auto newWidth = opWidth - removeWidth;
87 // Create a replacement type for the extracted bits
88 auto replaceType = rewriter.getIntegerType(newWidth);
89
90 // Extract the lsbs from each operand
91 auto extractLhsOp =
92 rewriter.create<comb::ExtractOp>(loc, replaceType, lhs, 0);
93 auto extractRhsOp =
94 rewriter.create<comb::ExtractOp>(loc, replaceType, rhs, 0);
95 auto narrowOp = rewriter.create<CombOpTy>(loc, extractLhsOp, extractRhsOp);
96
97 // Concatenate zeros to match the original operator width
98 auto zero =
99 rewriter.create<hw::ConstantOp>(loc, APInt::getZero(removeWidth));
100 auto replaceOp = rewriter.create<comb::ConcatOp>(
101 loc, op.getType(), ValueRange{zero, narrowOp});
102
103 rewriter.replaceOp(op, replaceOp);
104 return success();
105 }
106
107private:
108 DataFlowSolver &solver;
109};
110
111struct CombIntRangeNarrowingPass
112 : comb::impl::CombIntRangeNarrowingBase<CombIntRangeNarrowingPass> {
113
114 using CombIntRangeNarrowingBase::CombIntRangeNarrowingBase;
115 void runOnOperation() override;
116};
117} // namespace
118
119void CombIntRangeNarrowingPass::runOnOperation() {
120 Operation *op = getOperation();
121 MLIRContext *ctx = op->getContext();
122 DataFlowSolver solver;
123 solver.load<DeadCodeAnalysis>();
124 solver.load<IntegerRangeAnalysis>();
125 if (failed(solver.initializeAndRun(op)))
126 return signalPassFailure();
127
128 RewritePatternSet patterns(ctx);
130
131 if (failed(applyPatternsGreedily(op, std::move(patterns))))
132 signalPassFailure();
133}
134
135void comb::populateCombNarrowingPatterns(RewritePatternSet &patterns,
136 DataFlowSolver &solver) {
137 patterns.add<CombOpNarrow<comb::AddOp>, CombOpNarrow<comb::MulOp>,
138 CombOpNarrow<comb::SubOp>>(patterns.getContext(), solver);
139}
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, SmallVectorImpl< ConstantIntRanges > &ranges)
Gather ranges for all the values in `values`.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
void populateCombNarrowingPatterns(mlir::RewritePatternSet &patterns, mlir::DataFlowSolver &solver)
Add patterns for integer-range-based narrowing.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1