CIRCT  20.0.0git
DCOps.cpp
Go to the documentation of this file.
1 //===- DCOps.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/FunctionImplementation.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 
17 using namespace circt;
18 using namespace dc;
19 using namespace mlir;
20 
22  auto vt = dyn_cast<ValueType>(t);
23  if (!vt)
24  return false;
25  auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
26  return innerWidth == 1;
27 }
28 
29 namespace circt {
30 namespace dc {
31 
32 // =============================================================================
33 // JoinOp
34 // =============================================================================
35 
36 OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
37  // Fold simple joins (joins with 1 input).
38  if (auto tokens = getTokens(); tokens.size() == 1)
39  return tokens.front();
40 
41  return {};
42 }
43 
44 struct JoinOnBranchPattern : public OpRewritePattern<JoinOp> {
46  LogicalResult matchAndRewrite(JoinOp op,
47  PatternRewriter &rewriter) const override {
48 
49  struct BranchOperandInfo {
50  // Unique operands from the branch op, in case we have the same operand
51  // from the branch op multiple times.
52  SetVector<Value> uniqueOperands;
53  // Indices which the operands are at in the join op.
54  BitVector indices;
55  };
56 
57  DenseMap<BranchOp, BranchOperandInfo> branchOperands;
58  for (auto &opOperand : op->getOpOperands()) {
59  auto branch = opOperand.get().getDefiningOp<BranchOp>();
60  if (!branch)
61  continue;
62 
63  BranchOperandInfo &info = branchOperands[branch];
64  info.uniqueOperands.insert(opOperand.get());
65  info.indices.resize(op->getNumOperands());
66  info.indices.set(opOperand.getOperandNumber());
67  }
68 
69  if (branchOperands.empty())
70  return failure();
71 
72  // Do we have both operands from any given branch op?
73  for (auto &it : branchOperands) {
74  auto branch = it.first;
75  auto &operandInfo = it.second;
76  if (operandInfo.uniqueOperands.size() != 2) {
77  // We don't have both operands from the branch op.
78  continue;
79  }
80 
81  // We have both operands from the branch op. Replace the join op with the
82  // branch op's data operand.
83 
84  // Unpack the !dc.value<i1> input to the branch op
85  auto unpacked =
86  rewriter.create<UnpackOp>(op.getLoc(), branch.getCondition());
87  rewriter.modifyOpInPlace(op, [&]() {
88  op->eraseOperands(operandInfo.indices);
89  op.getTokensMutable().append({unpacked.getToken()});
90  });
91 
92  // Only attempt a single branch at a time - else we'd have to maintain
93  // OpOperand indices during the loop... too complicated, let recursive
94  // pattern application handle this.
95  return success();
96  }
97 
98  return failure();
99  }
100 };
103  LogicalResult matchAndRewrite(JoinOp op,
104  PatternRewriter &rewriter) const override {
105  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
106  auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
107  if (!otherJoin) {
108  // Operand does not originate from a join so it's a valid join input.
109  continue;
110  }
111 
112  // Operand originates from a join. Erase the current join operand and
113  // add all of the otherJoin op's inputs to this join.
114  // DCE will take care of otherJoin in case it's no longer used.
115  rewriter.modifyOpInPlace(op, [&]() {
116  op.getTokensMutable().erase(operand.getOperandNumber());
117  op.getTokensMutable().append(otherJoin.getTokens());
118  });
119  return success();
120  }
121  return failure();
122  }
123 };
124 
127  LogicalResult matchAndRewrite(JoinOp op,
128  PatternRewriter &rewriter) const override {
129  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
130  if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
131  rewriter.modifyOpInPlace(
132  op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
133  return success();
134  }
135  }
136  return failure();
137  }
138 };
139 
142  LogicalResult matchAndRewrite(JoinOp op,
143  PatternRewriter &rewriter) const override {
144  llvm::DenseSet<Value> uniqueOperands;
145  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
146  if (!uniqueOperands.insert(operand.get()).second) {
147  rewriter.modifyOpInPlace(
148  op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
149  return success();
150  }
151  }
152  return failure();
153  }
154 };
155 
156 void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results,
157  MLIRContext *context) {
160 }
161 
162 // =============================================================================
163 // ForkOp
164 // =============================================================================
165 
166 template <typename TInt>
167 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v) {
168  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
169  return failure();
170  return success();
171 }
172 
173 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
174  OpAsmParser::UnresolvedOperand operand;
175  size_t size = 0;
176  if (parseIntInSquareBrackets(parser, size))
177  return failure();
178 
179  if (size == 0)
180  return parser.emitError(parser.getNameLoc(),
181  "fork size must be greater than 0");
182 
183  if (parser.parseOperand(operand) ||
184  parser.parseOptionalAttrDict(result.attributes))
185  return failure();
186 
187  auto tt = dc::TokenType::get(parser.getContext());
188  llvm::SmallVector<Type> operandTypes{tt};
189  SmallVector<Type> resultTypes{size, tt};
190  result.addTypes(resultTypes);
191  if (parser.resolveOperand(operand, tt, result.operands))
192  return failure();
193  return success();
194 }
195 
196 void ForkOp::print(OpAsmPrinter &p) {
197  p << " [" << getNumResults() << "] ";
198  p << getOperand() << " ";
199  auto attrs = (*this)->getAttrs();
200  if (!attrs.empty()) {
201  p << " ";
202  p.printOptionalAttrDict(attrs);
203  }
204 }
205 
207  // Canonicalization of forks where the output is fed into another fork.
208 public:
210  LogicalResult matchAndRewrite(ForkOp fork,
211  PatternRewriter &rewriter) const override {
212  for (auto output : fork.getOutputs()) {
213  for (auto *user : output.getUsers()) {
214  auto userFork = dyn_cast<ForkOp>(user);
215  if (!userFork)
216  continue;
217 
218  // We have a fork feeding into another fork. Replace the output fork by
219  // adding more outputs to the current fork.
220  size_t totalForks = fork.getNumResults() + userFork.getNumResults();
221 
222  auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
223  fork.getToken(), totalForks);
224  rewriter.replaceOp(
225  fork, newFork.getResults().take_front(fork.getNumResults()));
226  rewriter.replaceOp(
227  userFork, newFork.getResults().take_back(userFork.getNumResults()));
228 
229  // Just stop the pattern here instead of trying to do more - let the
230  // canonicalizer recurse if another run of the canonicalization applies.
231  return success();
232  }
233  }
234  return failure();
235  }
236 };
237 
239  // Canonicalizes away forks on source ops, in favor of individual source
240  // operations. Having standalone sources are a better alternative, since other
241  // operations can canonicalize on it (e.g. joins) as well as being very cheap
242  // to implement in hardware, if they do remain.
243 public:
245  LogicalResult matchAndRewrite(ForkOp fork,
246  PatternRewriter &rewriter) const override {
247  auto source = fork.getToken().getDefiningOp<SourceOp>();
248  if (!source)
249  return failure();
250 
251  // We have a source feeding into a fork. Replace the fork by a source for
252  // each output.
253  llvm::SmallVector<Value> sources;
254  for (size_t i = 0; i < fork.getNumResults(); ++i)
255  sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
256 
257  rewriter.replaceOp(fork, sources);
258  return success();
259  }
260 };
261 
264 
265  LogicalResult matchAndRewrite(ForkOp op,
266  PatternRewriter &rewriter) const override {
267  std::set<unsigned> unusedIndexes;
268 
269  for (auto res : llvm::enumerate(op.getResults()))
270  if (res.value().use_empty())
271  unusedIndexes.insert(res.index());
272 
273  if (unusedIndexes.empty())
274  return failure();
275 
276  // Create a new fork op, dropping the unused results.
277  rewriter.setInsertionPoint(op);
278  auto operand = op.getOperand();
279  auto newFork = rewriter.create<ForkOp>(
280  op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
281  unsigned i = 0;
282  for (auto oldRes : llvm::enumerate(op.getResults()))
283  if (unusedIndexes.count(oldRes.index()) == 0)
284  rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
285  rewriter.eraseOp(op);
286  return success();
287  }
288 };
289 
290 void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
291  MLIRContext *context) {
294 }
295 
296 LogicalResult ForkOp::fold(FoldAdaptor adaptor,
297  SmallVectorImpl<OpFoldResult> &results) {
298  // Fold simple forks (forks with 1 output).
299  if (getOutputs().size() == 1) {
300  results.push_back(getToken());
301  return success();
302  }
303 
304  return failure();
305 }
306 
307 // =============================================================================
308 // UnpackOp
309 // =============================================================================
310 
312  // Eliminates unpacks where only the token is used.
314  LogicalResult matchAndRewrite(UnpackOp unpack,
315  PatternRewriter &rewriter) const override {
316  // Is the value-side of the unpack used?
317  if (!unpack.getOutput().use_empty())
318  return failure();
319 
320  auto pack = unpack.getInput().getDefiningOp<PackOp>();
321  if (!pack)
322  return failure();
323 
324  // Replace all uses of the unpack token with the packed token.
325  rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
326  rewriter.eraseOp(unpack);
327  return success();
328  }
329 };
330 
331 void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
332  MLIRContext *context) {
333  results.insert<EliminateRedundantUnpackPattern>(context);
334 }
335 
336 LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
337  SmallVectorImpl<OpFoldResult> &results) {
338  // Unpack of a pack is a no-op.
339  if (auto pack = getInput().getDefiningOp<PackOp>()) {
340  results.push_back(pack.getToken());
341  results.push_back(pack.getInput());
342  return success();
343  }
344 
345  return failure();
346 }
347 
348 LogicalResult UnpackOp::inferReturnTypes(
349  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
350  DictionaryAttr attrs, mlir::OpaqueProperties properties,
351  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
352  auto inputType = cast<ValueType>(operands.front().getType());
353  results.push_back(TokenType::get(context));
354  results.push_back(inputType.getInnerType());
355  return success();
356 }
357 
358 // =============================================================================
359 // PackOp
360 // =============================================================================
361 
362 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
363  auto token = getToken();
364 
365  // Pack of an unpack is a no-op.
366  if (auto unpack = token.getDefiningOp<UnpackOp>()) {
367  if (unpack.getOutput() == getInput())
368  return unpack.getInput();
369  }
370  return {};
371 }
372 
373 LogicalResult PackOp::inferReturnTypes(
374  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
375  DictionaryAttr attrs, mlir::OpaqueProperties properties,
376  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
377  llvm::SmallVector<Type> inputTypes;
378  Type inputType = operands.back().getType();
379  auto valueType = dc::ValueType::get(context, inputType);
380  results.push_back(valueType);
381  return success();
382 }
383 
384 // =============================================================================
385 // SelectOp
386 // =============================================================================
387 
  // Canonicalize away a select whose true and false inputs are both produced
  // by the same branch. Example:
  //   %true, %false = dc.branch %sel1
  //   %0 = dc.select %sel2, %true, %false
  // ->
  //   %0 = dc.join <token unpacked from %sel2>, <token unpacked from %sel1>
395 
396 public:
398  LogicalResult matchAndRewrite(SelectOp select,
399  PatternRewriter &rewriter) const override {
400  // Do all the inputs come from a branch?
401  BranchOp branchInput;
402  for (auto operand : {select.getTrueToken(), select.getFalseToken()}) {
403  auto br = operand.getDefiningOp<BranchOp>();
404  if (!br)
405  return failure();
406 
407  if (!branchInput)
408  branchInput = br;
409  else if (branchInput != br)
410  return failure();
411  }
412 
413  // Replace the select with a join (unpack the select conditions).
414  rewriter.replaceOpWithNewOp<JoinOp>(
415  select,
416  llvm::SmallVector<Value>{
417  rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
418  .getToken(),
419  rewriter
420  .create<UnpackOp>(branchInput.getLoc(),
421  branchInput.getCondition())
422  .getToken()});
423 
424  return success();
425  }
426 };
427 
428 void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
429  MLIRContext *context) {
430  results.insert<EliminateBranchToSelectPattern>(context);
431 }
432 
433 // =============================================================================
434 // BufferOp
435 // =============================================================================
436 
437 FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
438  assert(getInitValues() && "initValues attribute not set");
439  SmallVector<int64_t> values;
440  for (auto value : getInitValuesAttr()) {
441  if (auto iValue = dyn_cast<IntegerAttr>(value)) {
442  values.push_back(iValue.getValue().getSExtValue());
443  } else {
444  return emitError() << "initValues attribute must be an array of integers";
445  }
446  }
447  return values;
448 }
449 
450 LogicalResult BufferOp::verify() {
451  // Verify that exactly 'size' number of initial values have been provided, if
452  // an initializer list have been provided.
453  if (auto initVals = getInitValuesAttr()) {
454  auto nInits = initVals.size();
455  if (nInits != getSize())
456  return emitOpError() << "expected " << getSize()
457  << " init values but got " << nInits << ".";
458  }
459 
460  return success();
461 }
462 
463 // =============================================================================
464 // ToESIOp
465 // =============================================================================
466 
467 LogicalResult ToESIOp::inferReturnTypes(
468  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
469  DictionaryAttr attrs, mlir::OpaqueProperties properties,
470  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
471  Type channelEltType;
472  if (auto valueType = dyn_cast<ValueType>(operands.front().getType()))
473  channelEltType = valueType.getInnerType();
474  else {
475  // dc.token => esi.channel<i0>
476  channelEltType = IntegerType::get(context, 0);
477  }
478 
479  results.push_back(esi::ChannelType::get(context, channelEltType));
480  return success();
481 }
482 
483 // =============================================================================
484 // FromESIOp
485 // =============================================================================
486 
487 LogicalResult FromESIOp::inferReturnTypes(
488  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
489  DictionaryAttr attrs, mlir::OpaqueProperties properties,
490  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
491  auto innerType =
492  cast<esi::ChannelType>(operands.front().getType()).getInner();
493  if (auto intType = dyn_cast<IntegerType>(innerType); intType.getWidth() == 0)
494  results.push_back(dc::TokenType::get(context));
495  else
496  results.push_back(dc::ValueType::get(context, innerType));
497 
498  return success();
499 }
500 
501 } // namespace dc
502 } // namespace circt
503 
504 #define GET_OP_CLASSES
505 #include "circt/Dialect/DC/DC.cpp.inc"
assert(baseType &&"element must be base type")
LogicalResult matchAndRewrite(SelectOp select, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:398
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:245
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:210
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2467
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v)
Definition: DCOps.cpp:167
bool isI1ValueType(Type t)
Definition: DCOps.cpp:21
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
LogicalResult matchAndRewrite(UnpackOp unpack, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:314
LogicalResult matchAndRewrite(ForkOp op, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:265
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:46
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:142
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:127
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:103