CIRCT  19.0.0git
DCOps.cpp
Go to the documentation of this file.
1 //===- DCOps.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/FunctionImplementation.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 
17 using namespace circt;
18 using namespace dc;
19 using namespace mlir;
20 
22  auto vt = dyn_cast<ValueType>(t);
23  if (!vt)
24  return false;
25  auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
26  return innerWidth == 1;
27 }
28 
29 namespace circt {
30 namespace dc {
31 
32 // =============================================================================
33 // JoinOp
34 // =============================================================================
35 
36 OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
37  // Fold simple joins (joins with 1 input).
38  if (auto tokens = getTokens(); tokens.size() == 1)
39  return tokens.front();
40 
41  // These folders are disabled to work around MLIR bugs when changing
42  // the number of operands. https://github.com/llvm/llvm-project/issues/64280
43  return {};
44 
45  // Remove operands which originate from a dc.source op (redundant).
46  auto *op = getOperation();
47  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
48  if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
49  op->eraseOperand(operand.getOperandNumber());
50  return getOutput();
51  }
52  }
53 
54  // Remove duplicate operands.
55  llvm::DenseSet<Value> uniqueOperands;
56  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
57  if (!uniqueOperands.insert(operand.get()).second) {
58  op->eraseOperand(operand.getOperandNumber());
59  return getOutput();
60  }
61  }
62 
63  // Canonicalization staggered joins where the sink join contains inputs also
64  // found in the source join.
65  for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
66  auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
67  if (!otherJoin) {
68  // Operand does not originate from a join so it's a valid join input.
69  continue;
70  }
71 
72  // Operand originates from a join. Erase the current join operand and add
73  // all of the otherJoin op's inputs to this join.
74  // DCE will take care of otherJoin in case it's no longer used.
75  op->eraseOperand(operand.getOperandNumber());
76  op->insertOperands(getNumOperands(), otherJoin.getTokens());
77  return getOutput();
78  }
79 
80  return {};
81 }
82 
83 // =============================================================================
84 // ForkOp
85 // =============================================================================
86 
87 template <typename TInt>
88 static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v) {
89  if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
90  return failure();
91  return success();
92 }
93 
94 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
95  OpAsmParser::UnresolvedOperand operand;
96  size_t size = 0;
97  if (parseIntInSquareBrackets(parser, size))
98  return failure();
99 
100  if (size == 0)
101  return parser.emitError(parser.getNameLoc(),
102  "fork size must be greater than 0");
103 
104  if (parser.parseOperand(operand) ||
105  parser.parseOptionalAttrDict(result.attributes))
106  return failure();
107 
108  auto tt = dc::TokenType::get(parser.getContext());
109  llvm::SmallVector<Type> operandTypes{tt};
110  SmallVector<Type> resultTypes{size, tt};
111  result.addTypes(resultTypes);
112  if (parser.resolveOperand(operand, tt, result.operands))
113  return failure();
114  return success();
115 }
116 
117 void ForkOp::print(OpAsmPrinter &p) {
118  p << " [" << getNumResults() << "] ";
119  p << getOperand() << " ";
120  auto attrs = (*this)->getAttrs();
121  if (!attrs.empty()) {
122  p << " ";
123  p.printOptionalAttrDict(attrs);
124  }
125 }
126 
128  // Canonicalization of forks where the output is fed into another fork.
129 public:
131  LogicalResult matchAndRewrite(ForkOp fork,
132  PatternRewriter &rewriter) const override {
133  for (auto output : fork.getOutputs()) {
134  for (auto *user : output.getUsers()) {
135  auto userFork = dyn_cast<ForkOp>(user);
136  if (!userFork)
137  continue;
138 
139  // We have a fork feeding into another fork. Replace the output fork by
140  // adding more outputs to the current fork.
141  size_t totalForks = fork.getNumResults() + userFork.getNumResults() - 1;
142 
143  auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
144  fork.getToken(), totalForks);
145  rewriter.replaceOp(
146  fork, newFork.getResults().take_front(fork.getNumResults()));
147  rewriter.replaceOp(
148  userFork, newFork.getResults().take_back(userFork.getNumResults()));
149 
150  // Just stop the pattern here instead of trying to do more - let the
151  // canonicalizer recurse if another run of the canonicalization applies.
152  return success();
153  }
154  }
155  return failure();
156  }
157 };
158 
160  // Canonicalizes away forks on source ops, in favor of individual source
161  // operations. Having standalone sources are a better alternative, since other
162  // operations can canonicalize on it (e.g. joins) as well as being very cheap
163  // to implement in hardware, if they do remain.
164 public:
166  LogicalResult matchAndRewrite(ForkOp fork,
167  PatternRewriter &rewriter) const override {
168  auto source = fork.getToken().getDefiningOp<SourceOp>();
169  if (!source)
170  return failure();
171 
172  // We have a source feeding into a fork. Replace the fork by a source for
173  // each output.
174  llvm::SmallVector<Value> sources;
175  for (size_t i = 0; i < fork.getNumResults(); ++i)
176  sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
177 
178  rewriter.replaceOp(fork, sources);
179  return success();
180  }
181 };
182 
// Registers ForkOp's canonicalization patterns into `results`, constructing
// each one with `context`.
// NOTE(review): source line 185 is missing from this rendering; it is
// presumably the `results.insert<...>(` call naming the fork rewrite patterns
// defined above (fork-to-fork merging and fork-of-source elimination) -
// confirm against the original file.
183 void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
184  MLIRContext *context) {
186  context);
187 }
188 
189 LogicalResult ForkOp::fold(FoldAdaptor adaptor,
190  SmallVectorImpl<OpFoldResult> &results) {
191  // Fold simple forks (forks with 1 output).
192  if (getOutputs().size() == 1) {
193  results.push_back(getToken());
194  return success();
195  }
196 
197  return failure();
198 }
199 
200 // =============================================================================
201 // UnpackOp
202 // =============================================================================
203 
205  // Eliminates unpacks where only the token is used.
207  LogicalResult matchAndRewrite(UnpackOp unpack,
208  PatternRewriter &rewriter) const override {
209  // Is the value-side of the unpack used?
210  if (!unpack.getOutput().use_empty())
211  return failure();
212 
213  auto pack = unpack.getInput().getDefiningOp<PackOp>();
214  if (!pack)
215  return failure();
216 
217  // Replace all uses of the unpack token with the packed token.
218  rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
219  rewriter.eraseOp(unpack);
220  return success();
221  }
222 };
223 
224 void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
225  MLIRContext *context) {
226  results.insert<EliminateRedundantUnpackPattern>(context);
227 }
228 
229 LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
230  SmallVectorImpl<OpFoldResult> &results) {
231  // Unpack of a pack is a no-op.
232  if (auto pack = getInput().getDefiningOp<PackOp>()) {
233  results.push_back(pack.getToken());
234  results.push_back(pack.getInput());
235  return success();
236  }
237 
238  return failure();
239 }
240 
241 LogicalResult UnpackOp::inferReturnTypes(
242  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
243  DictionaryAttr attrs, mlir::OpaqueProperties properties,
244  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
245  auto inputType = cast<ValueType>(operands.front().getType());
246  results.push_back(TokenType::get(context));
247  results.push_back(inputType.getInnerType());
248  return success();
249 }
250 
251 // =============================================================================
252 // PackOp
253 // =============================================================================
254 
255 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
256  auto token = getToken();
257 
258  // Pack of an unpack is a no-op.
259  if (auto unpack = token.getDefiningOp<UnpackOp>()) {
260  if (unpack.getOutput() == getInput())
261  return unpack.getInput();
262  }
263  return {};
264 }
265 
266 LogicalResult PackOp::inferReturnTypes(
267  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
268  DictionaryAttr attrs, mlir::OpaqueProperties properties,
269  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
270  llvm::SmallVector<Type> inputTypes;
271  Type inputType = operands.back().getType();
272  auto valueType = dc::ValueType::get(context, inputType);
273  results.push_back(valueType);
274  return success();
275 }
276 
277 // =============================================================================
278 // SelectOp
279 // =============================================================================
280 
282  // Canonicalize away a select that is fed only by a single branch
283  // example:
284  // %true, %false = dc.branch %sel1 %token
285  // %0 = dc.select %sel2, %true, %false
286  // ->
287  // %0 = dc.join %sel1, %sel2, %token
288 
289 public:
291  LogicalResult matchAndRewrite(SelectOp select,
292  PatternRewriter &rewriter) const override {
293  // Do all the inputs come from a branch?
294  BranchOp branchInput;
295  for (auto operand : {select.getTrueToken(), select.getFalseToken()}) {
296  auto br = operand.getDefiningOp<BranchOp>();
297  if (!br)
298  return failure();
299 
300  if (!branchInput)
301  branchInput = br;
302  else if (branchInput != br)
303  return failure();
304  }
305 
306  // Replace the select with a join (unpack the select conditions).
307  rewriter.replaceOpWithNewOp<JoinOp>(
308  select,
309  llvm::SmallVector<Value>{
310  rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
311  .getToken(),
312  rewriter
313  .create<UnpackOp>(branchInput.getLoc(),
314  branchInput.getCondition())
315  .getToken()});
316 
317  return success();
318  }
319 };
320 
321 void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
322  MLIRContext *context) {
323  results.insert<EliminateBranchToSelectPattern>(context);
324 }
325 
326 // =============================================================================
327 // BufferOp
328 // =============================================================================
329 
330 FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
331  assert(getInitValues() && "initValues attribute not set");
332  SmallVector<int64_t> values;
333  for (auto value : getInitValuesAttr()) {
334  if (auto iValue = dyn_cast<IntegerAttr>(value)) {
335  values.push_back(iValue.getValue().getSExtValue());
336  } else {
337  return emitError() << "initValues attribute must be an array of integers";
338  }
339  }
340  return values;
341 }
342 
343 LogicalResult BufferOp::verify() {
344  // Verify that exactly 'size' number of initial values have been provided, if
345  // an initializer list have been provided.
346  if (auto initVals = getInitValuesAttr()) {
347  auto nInits = initVals.size();
348  if (nInits != getSize())
349  return emitOpError() << "expected " << getSize()
350  << " init values but got " << nInits << ".";
351  }
352 
353  return success();
354 }
355 
356 // =============================================================================
357 // ToESIOp
358 // =============================================================================
359 
360 LogicalResult ToESIOp::inferReturnTypes(
361  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
362  DictionaryAttr attrs, mlir::OpaqueProperties properties,
363  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
364  Type channelEltType;
365  if (auto valueType = dyn_cast<ValueType>(operands.front().getType()))
366  channelEltType = valueType.getInnerType();
367  else {
368  // dc.token => esi.channel<i0>
369  channelEltType = IntegerType::get(context, 0);
370  }
371 
372  results.push_back(esi::ChannelType::get(context, channelEltType));
373  return success();
374 }
375 
376 // =============================================================================
377 // FromESIOp
378 // =============================================================================
379 
380 LogicalResult FromESIOp::inferReturnTypes(
381  MLIRContext *context, std::optional<Location> loc, ValueRange operands,
382  DictionaryAttr attrs, mlir::OpaqueProperties properties,
383  mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
384  auto innerType =
385  cast<esi::ChannelType>(operands.front().getType()).getInner();
386  if (auto intType = dyn_cast<IntegerType>(innerType); intType.getWidth() == 0)
387  results.push_back(dc::TokenType::get(context));
388  else
389  results.push_back(dc::ValueType::get(context, innerType));
390 
391  return success();
392 }
393 
394 } // namespace dc
395 } // namespace circt
396 
397 #define GET_OP_CLASSES
398 #include "circt/Dialect/DC/DC.cpp.inc"
assert(baseType &&"element must be base type")
LogicalResult matchAndRewrite(SelectOp select, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:291
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:166
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:131
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v)
Definition: DCOps.cpp:88
bool isI1ValueType(Type t)
Definition: DCOps.cpp:21
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
LogicalResult matchAndRewrite(UnpackOp unpack, PatternRewriter &rewriter) const override
Definition: DCOps.cpp:207