11#include "mlir/IR/PatternMatch.h"
12#include "mlir/Support/LogicalResult.h"
22static bool isAlways(Attribute attr,
bool expected) {
23 if (
auto enable = dyn_cast_or_null<IntegerAttr>(attr))
24 return enable.getValue().getBoolValue() == expected;
28static bool isAlways(Value value,
bool expected) {
33 return constOp.getValue().getBoolValue() == expected;
42LogicalResult StateOp::fold(FoldAdaptor adaptor,
43 SmallVectorImpl<OpFoldResult> &results) {
45 if (getNumResults() > 0 && !getOperation()->hasAttr(
"name") &&
46 !getOperation()->hasAttr(
"names")) {
47 bool hasExplicitInitials = !getInitials().empty();
48 bool allInitialsConstant =
49 !hasExplicitInitials ||
50 llvm::all_of(adaptor.getInitials(),
51 [&](Attribute attr) { return !!attr; });
52 if (
isAlways(adaptor.getEnable(),
false) && allInitialsConstant) {
56 if (hasExplicitInitials)
57 results.append(adaptor.getInitials().begin(),
58 adaptor.getInitials().end());
60 for (
auto resTy : getResultTypes())
61 results.push_back(IntegerAttr::
get(resTy, 0));
64 if (!hasExplicitInitials &&
isAlways(adaptor.getReset(),
true)) {
67 for (
auto resTy : getResultTypes())
68 results.push_back(IntegerAttr::
get(resTy, 0));
74 if (
isAlways(adaptor.getReset(),
false))
75 return getResetMutable().clear(), success();
78 if (
isAlways(adaptor.getEnable(),
true))
79 return getEnableMutable().clear(), success();
84LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) {
87 if (op->use_empty() && !op->hasAttr(
"name") && !op->hasAttr(
"names")) {
99LogicalResult MemoryWriteOp::fold(FoldAdaptor adaptor,
100 SmallVectorImpl<OpFoldResult> &results) {
101 if (
isAlways(adaptor.getEnable(),
true))
102 return getEnableMutable().clear(), success();
106LogicalResult MemoryWriteOp::canonicalize(MemoryWriteOp op,
107 PatternRewriter &rewriter) {
108 if (
isAlways(op.getEnable(),
false))
109 return rewriter.eraseOp(op), success();
117LogicalResult StorageGetOp::canonicalize(StorageGetOp op,
118 PatternRewriter &rewriter) {
119 if (
auto pred = op.getStorage().getDefiningOp<StorageGetOp>()) {
120 rewriter.modifyOpInPlace(op, [&] {
121 op.getStorageMutable().assign(pred.getStorage());
122 op.setOffset(op.getOffset() + pred.getOffset());
134 PatternRewriter &rewriter) {
135 BitVector toDelete(op.getBodyBlock().getNumArguments());
136 for (
auto arg : llvm::reverse(op.getBodyBlock().getArguments())) {
137 if (arg.use_empty()) {
138 auto i = arg.getArgNumber();
140 rewriter.modifyOpInPlace(op, [&] { op.getInputsMutable().erase(i); });
143 if (toDelete.any()) {
144 rewriter.modifyOpInPlace(
145 op, [&] { op.getBodyBlock().eraseArguments(toDelete); });
152 PatternRewriter &rewriter) {
153 SmallVector<Type> resultTypes;
154 for (
auto res : llvm::reverse(op->getResults())) {
155 if (res.use_empty()) {
156 auto *terminator = op.getBodyBlock().getTerminator();
157 rewriter.modifyOpInPlace(
158 terminator, [&] { terminator->eraseOperand(res.getResultNumber()); });
160 resultTypes.push_back(res.getType());
165 if (resultTypes.size() == op->getNumResults())
168 rewriter.setInsertionPoint(op);
170 auto newDomain = ClockDomainOp::create(rewriter, op.getLoc(), resultTypes,
171 op.getInputs(), op.getClock());
172 rewriter.inlineRegionBefore(op.getBody(), newDomain.getBody(),
173 newDomain->getRegion(0).begin());
175 unsigned currIdx = 0;
176 for (
auto result : op.getOutputs()) {
177 if (!result.use_empty())
178 rewriter.replaceAllUsesWith(result, newDomain->getResult(currIdx++));
181 rewriter.eraseOp(op);
185LogicalResult ClockDomainOp::canonicalize(ClockDomainOp op,
186 PatternRewriter &rewriter) {
187 rewriter.setInsertionPointToStart(&op.getBodyBlock());
190 DenseMap<Value, unsigned> seenArgs;
193 auto i = arg.getArgNumber();
194 auto inputVal = op.getInputs()[i];
200 if (seenArgs.count(inputVal)) {
201 rewriter.replaceAllUsesWith(
202 arg, op.getBodyBlock().getArgument(seenArgs[inputVal]));
208 if (
auto *inputOp = inputVal.getDefiningOp()) {
209 bool isConstant = inputOp->hasTrait<OpTrait::ConstantLike>();
210 bool hasOneUse = inputVal.hasOneUse();
211 if (isConstant || (isa<MemoryOp>(inputOp) && hasOneUse)) {
212 auto resultNumber = cast<OpResult>(inputVal).getResultNumber();
213 auto *clone = rewriter.clone(*inputOp);
214 rewriter.replaceAllUsesWith(arg, clone->getResult(resultNumber));
215 if (hasOneUse && inputOp->getNumResults() == 1) {
216 inputVal.dropAllUses();
217 rewriter.eraseOp(inputOp);
223 seenArgs[op.getInputs()[i]] = i;
229 for (
auto [result, terminatorOperand] :
llvm::zip(
230 op.getOutputs(), op.
getBodyBlock().getTerminator()->getOperands())) {
233 if (isa<BlockArgument>(terminatorOperand))
234 rewriter.replaceAllUsesWith(
235 result, op.getInputs()[cast<BlockArgument>(terminatorOperand)
246 if (
auto *defOp = terminatorOperand.getDefiningOp();
247 defOp && defOp->hasTrait<OpTrait::ConstantLike>() &&
248 !result.use_empty()) {
249 rewriter.setInsertionPointAfter(op);
250 unsigned resultIdx = cast<OpResult>(terminatorOperand).getResultNumber();
251 auto *clone = rewriter.clone(*defOp);
252 if (defOp->hasOneUse()) {
253 defOp->dropAllUses();
254 rewriter.eraseOp(defOp);
256 rewriter.replaceAllUsesWith(result, clone->getResult(resultIdx));
262 return success(didCanonicalizeInput || didCanoncalizeOutput);
static bool isAlways(Attribute attr, bool expected)
static bool removeUnusedClockDomainInputs(ClockDomainOp op, PatternRewriter &rewriter)
static bool removeUnusedClockDomainOutputs(ClockDomainOp op, PatternRewriter &rewriter)
static Block * getBodyBlock(FModuleLike mod)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.