CIRCT 20.0.0git
Loading...
Searching...
No Matches
DCOps.cpp
Go to the documentation of this file.
1//===- DCOps.cpp ----------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/OpImplementation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/FunctionImplementation.h"
15#include "mlir/Interfaces/SideEffectInterfaces.h"
16
17using namespace circt;
18using namespace dc;
19using namespace mlir;
20
22 auto vt = dyn_cast<ValueType>(t);
23 if (!vt)
24 return false;
25 auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
26 return innerWidth == 1;
27}
28
29namespace circt {
30namespace dc {
31
32// =============================================================================
33// JoinOp
34// =============================================================================
35
36OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
37 // Fold simple joins (joins with 1 input).
38 if (auto tokens = getTokens(); tokens.size() == 1)
39 return tokens.front();
40
41 return {};
42}
43
44struct JoinOnBranchPattern : public OpRewritePattern<JoinOp> {
46 LogicalResult matchAndRewrite(JoinOp op,
47 PatternRewriter &rewriter) const override {
48
49 struct BranchOperandInfo {
50 // Unique operands from the branch op, in case we have the same operand
51 // from the branch op multiple times.
52 SetVector<Value> uniqueOperands;
53 // Indices which the operands are at in the join op.
54 BitVector indices;
55 };
56
57 DenseMap<BranchOp, BranchOperandInfo> branchOperands;
58 for (auto &opOperand : op->getOpOperands()) {
59 auto branch = opOperand.get().getDefiningOp<BranchOp>();
60 if (!branch)
61 continue;
62
63 BranchOperandInfo &info = branchOperands[branch];
64 info.uniqueOperands.insert(opOperand.get());
65 info.indices.resize(op->getNumOperands());
66 info.indices.set(opOperand.getOperandNumber());
67 }
68
69 if (branchOperands.empty())
70 return failure();
71
72 // Do we have both operands from any given branch op?
73 for (auto &it : branchOperands) {
74 auto branch = it.first;
75 auto &operandInfo = it.second;
76 if (operandInfo.uniqueOperands.size() != 2) {
77 // We don't have both operands from the branch op.
78 continue;
79 }
80
81 // We have both operands from the branch op. Replace the join op with the
82 // branch op's data operand.
83
84 // Unpack the !dc.value<i1> input to the branch op
85 auto unpacked =
86 rewriter.create<UnpackOp>(op.getLoc(), branch.getCondition());
87 rewriter.modifyOpInPlace(op, [&]() {
88 op->eraseOperands(operandInfo.indices);
89 op.getTokensMutable().append({unpacked.getToken()});
90 });
91
92 // Only attempt a single branch at a time - else we'd have to maintain
93 // OpOperand indices during the loop... too complicated, let recursive
94 // pattern application handle this.
95 return success();
96 }
97
98 return failure();
99 }
100};
103 LogicalResult matchAndRewrite(JoinOp op,
104 PatternRewriter &rewriter) const override {
105 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
106 auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
107 if (!otherJoin) {
108 // Operand does not originate from a join so it's a valid join input.
109 continue;
110 }
111
112 // Operand originates from a join. Erase the current join operand and
113 // add all of the otherJoin op's inputs to this join.
114 // DCE will take care of otherJoin in case it's no longer used.
115 rewriter.modifyOpInPlace(op, [&]() {
116 op.getTokensMutable().erase(operand.getOperandNumber());
117 op.getTokensMutable().append(otherJoin.getTokens());
118 });
119 return success();
120 }
121 return failure();
122 }
123};
124
127 LogicalResult matchAndRewrite(JoinOp op,
128 PatternRewriter &rewriter) const override {
129 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
130 if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
131 rewriter.modifyOpInPlace(
132 op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
133 return success();
134 }
135 }
136 return failure();
137 }
138};
139
142 LogicalResult matchAndRewrite(JoinOp op,
143 PatternRewriter &rewriter) const override {
144 llvm::DenseSet<Value> uniqueOperands;
145 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
146 if (!uniqueOperands.insert(operand.get()).second) {
147 rewriter.modifyOpInPlace(
148 op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
149 return success();
150 }
151 }
152 return failure();
153 }
154};
155
156void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results,
157 MLIRContext *context) {
160}
161
162// =============================================================================
163// ForkOp
164// =============================================================================
165
166template <typename TInt>
167static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v) {
168 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
169 return failure();
170 return success();
171}
172
173ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
174 OpAsmParser::UnresolvedOperand operand;
175 size_t size = 0;
176 if (parseIntInSquareBrackets(parser, size))
177 return failure();
178
179 if (size == 0)
180 return parser.emitError(parser.getNameLoc(),
181 "fork size must be greater than 0");
182
183 if (parser.parseOperand(operand) ||
184 parser.parseOptionalAttrDict(result.attributes))
185 return failure();
186
187 auto tt = dc::TokenType::get(parser.getContext());
188 llvm::SmallVector<Type> operandTypes{tt};
189 SmallVector<Type> resultTypes{size, tt};
190 result.addTypes(resultTypes);
191 if (parser.resolveOperand(operand, tt, result.operands))
192 return failure();
193 return success();
194}
195
196void ForkOp::print(OpAsmPrinter &p) {
197 p << " [" << getNumResults() << "] ";
198 p << getOperand() << " ";
199 auto attrs = (*this)->getAttrs();
200 if (!attrs.empty()) {
201 p << " ";
202 p.printOptionalAttrDict(attrs);
203 }
204}
205
207 // Canonicalization of forks where the output is fed into another fork.
208public:
210 LogicalResult matchAndRewrite(ForkOp fork,
211 PatternRewriter &rewriter) const override {
212 for (auto output : fork.getOutputs()) {
213 for (auto *user : output.getUsers()) {
214 auto userFork = dyn_cast<ForkOp>(user);
215 if (!userFork)
216 continue;
217
218 // We have a fork feeding into another fork. Replace the output fork by
219 // adding more outputs to the current fork.
220 size_t totalForks = fork.getNumResults() + userFork.getNumResults();
221
222 auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
223 fork.getToken(), totalForks);
224 rewriter.replaceOp(
225 fork, newFork.getResults().take_front(fork.getNumResults()));
226 rewriter.replaceOp(
227 userFork, newFork.getResults().take_back(userFork.getNumResults()));
228
229 // Just stop the pattern here instead of trying to do more - let the
230 // canonicalizer recurse if another run of the canonicalization applies.
231 return success();
232 }
233 }
234 return failure();
235 }
236};
237
239 // Canonicalizes away forks on source ops, in favor of individual source
240 // operations. Having standalone sources are a better alternative, since other
241 // operations can canonicalize on it (e.g. joins) as well as being very cheap
242 // to implement in hardware, if they do remain.
243public:
245 LogicalResult matchAndRewrite(ForkOp fork,
246 PatternRewriter &rewriter) const override {
247 auto source = fork.getToken().getDefiningOp<SourceOp>();
248 if (!source)
249 return failure();
250
251 // We have a source feeding into a fork. Replace the fork by a source for
252 // each output.
253 llvm::SmallVector<Value> sources;
254 for (size_t i = 0; i < fork.getNumResults(); ++i)
255 sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
256
257 rewriter.replaceOp(fork, sources);
258 return success();
259 }
260};
261
264
265 LogicalResult matchAndRewrite(ForkOp op,
266 PatternRewriter &rewriter) const override {
267 std::set<unsigned> unusedIndexes;
268
269 for (auto res : llvm::enumerate(op.getResults()))
270 if (res.value().use_empty())
271 unusedIndexes.insert(res.index());
272
273 if (unusedIndexes.empty())
274 return failure();
275
276 // Create a new fork op, dropping the unused results.
277 rewriter.setInsertionPoint(op);
278 auto operand = op.getOperand();
279 auto newFork = rewriter.create<ForkOp>(
280 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
281 unsigned i = 0;
282 for (auto oldRes : llvm::enumerate(op.getResults()))
283 if (unusedIndexes.count(oldRes.index()) == 0)
284 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
285 rewriter.eraseOp(op);
286 return success();
287 }
288};
289
290void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
291 MLIRContext *context) {
294}
295
296LogicalResult ForkOp::fold(FoldAdaptor adaptor,
297 SmallVectorImpl<OpFoldResult> &results) {
298 // Fold simple forks (forks with 1 output).
299 if (getOutputs().size() == 1) {
300 results.push_back(getToken());
301 return success();
302 }
303
304 return failure();
305}
306
307// =============================================================================
308// UnpackOp
309// =============================================================================
310
312 // Eliminates unpacks where only the token is used.
314 LogicalResult matchAndRewrite(UnpackOp unpack,
315 PatternRewriter &rewriter) const override {
316 // Is the value-side of the unpack used?
317 if (!unpack.getOutput().use_empty())
318 return failure();
319
320 auto pack = unpack.getInput().getDefiningOp<PackOp>();
321 if (!pack)
322 return failure();
323
324 // Replace all uses of the unpack token with the packed token.
325 rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
326 rewriter.eraseOp(unpack);
327 return success();
328 }
329};
330
331void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
332 MLIRContext *context) {
333 results.insert<EliminateRedundantUnpackPattern>(context);
334}
335
336LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
337 SmallVectorImpl<OpFoldResult> &results) {
338 // Unpack of a pack is a no-op.
339 if (auto pack = getInput().getDefiningOp<PackOp>()) {
340 results.push_back(pack.getToken());
341 results.push_back(pack.getInput());
342 return success();
343 }
344
345 return failure();
346}
347
348LogicalResult UnpackOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
350 DictionaryAttr attrs, mlir::OpaqueProperties properties,
351 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
352 auto inputType = cast<ValueType>(operands.front().getType());
353 results.push_back(TokenType::get(context));
354 results.push_back(inputType.getInnerType());
355 return success();
356}
357
358// =============================================================================
359// PackOp
360// =============================================================================
361
362OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
363 auto token = getToken();
364
365 // Pack of an unpack is a no-op.
366 if (auto unpack = token.getDefiningOp<UnpackOp>()) {
367 if (unpack.getOutput() == getInput())
368 return unpack.getInput();
369 }
370 return {};
371}
372
373LogicalResult PackOp::inferReturnTypes(
374 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
375 DictionaryAttr attrs, mlir::OpaqueProperties properties,
376 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
377 llvm::SmallVector<Type> inputTypes;
378 Type inputType = operands.back().getType();
379 auto valueType = dc::ValueType::get(context, inputType);
380 results.push_back(valueType);
381 return success();
382}
383
384// =============================================================================
385// SelectOp
386// =============================================================================
387
389 // Canonicalize away a select that is fed only by a single branch
390 // example:
391 // %true, %false = dc.branch %sel1 %token
392 // %0 = dc.select %sel2, %true, %false
393 // ->
394 // %0 = dc.join %sel1, %sel2, %token
395
396public:
398 LogicalResult matchAndRewrite(SelectOp select,
399 PatternRewriter &rewriter) const override {
400 // Do all the inputs come from a branch?
401 BranchOp branchInput;
402 for (auto operand : {select.getTrueToken(), select.getFalseToken()}) {
403 auto br = operand.getDefiningOp<BranchOp>();
404 if (!br)
405 return failure();
406
407 if (!branchInput)
408 branchInput = br;
409 else if (branchInput != br)
410 return failure();
411 }
412
413 // Replace the select with a join (unpack the select conditions).
414 rewriter.replaceOpWithNewOp<JoinOp>(
415 select,
416 llvm::SmallVector<Value>{
417 rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
418 .getToken(),
419 rewriter
420 .create<UnpackOp>(branchInput.getLoc(),
421 branchInput.getCondition())
422 .getToken()});
423
424 return success();
425 }
426};
427
428void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 MLIRContext *context) {
430 results.insert<EliminateBranchToSelectPattern>(context);
431}
432
433// =============================================================================
434// BufferOp
435// =============================================================================
436
437FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
438 assert(getInitValues() && "initValues attribute not set");
439 SmallVector<int64_t> values;
440 for (auto value : getInitValuesAttr()) {
441 if (auto iValue = dyn_cast<IntegerAttr>(value)) {
442 values.push_back(iValue.getValue().getSExtValue());
443 } else {
444 return emitError() << "initValues attribute must be an array of integers";
445 }
446 }
447 return values;
448}
449
450LogicalResult BufferOp::verify() {
451 // Verify that exactly 'size' number of initial values have been provided, if
452 // an initializer list have been provided.
453 if (auto initVals = getInitValuesAttr()) {
454 auto nInits = initVals.size();
455 if (nInits != getSize())
456 return emitOpError() << "expected " << getSize()
457 << " init values but got " << nInits << ".";
458 }
459
460 return success();
461}
462
463// =============================================================================
464// ToESIOp
465// =============================================================================
466
467LogicalResult ToESIOp::inferReturnTypes(
468 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
469 DictionaryAttr attrs, mlir::OpaqueProperties properties,
470 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
471 Type channelEltType;
472 if (auto valueType = dyn_cast<ValueType>(operands.front().getType()))
473 channelEltType = valueType.getInnerType();
474 else {
475 // dc.token => esi.channel<i0>
476 channelEltType = IntegerType::get(context, 0);
477 }
478
479 results.push_back(esi::ChannelType::get(context, channelEltType));
480 return success();
481}
482
483// =============================================================================
484// FromESIOp
485// =============================================================================
486
487LogicalResult FromESIOp::inferReturnTypes(
488 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
489 DictionaryAttr attrs, mlir::OpaqueProperties properties,
490 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
491 auto innerType =
492 cast<esi::ChannelType>(operands.front().getType()).getInner();
493 if (auto intType = dyn_cast<IntegerType>(innerType); intType.getWidth() == 0)
494 results.push_back(dc::TokenType::get(context));
495 else
496 results.push_back(dc::ValueType::get(context, innerType));
497
498 return success();
499}
500
501} // namespace dc
502} // namespace circt
503
504#define GET_OP_CLASSES
505#include "circt/Dialect/DC/DC.cpp.inc"
assert(baseType &&"element must be base type")
LogicalResult matchAndRewrite(SelectOp select, PatternRewriter &rewriter) const override
Definition DCOps.cpp:398
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition DCOps.cpp:245
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
Definition DCOps.cpp:210
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v)
Definition DCOps.cpp:167
bool isI1ValueType(Type t)
Definition DCOps.cpp:21
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult matchAndRewrite(UnpackOp unpack, PatternRewriter &rewriter) const override
Definition DCOps.cpp:314
LogicalResult matchAndRewrite(ForkOp op, PatternRewriter &rewriter) const override
Definition DCOps.cpp:265
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition DCOps.cpp:46
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition DCOps.cpp:142
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition DCOps.cpp:127
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
Definition DCOps.cpp:103