10#include "mlir/IR/Builders.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/OpImplementation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Interfaces/FunctionImplementation.h"
15#include "mlir/Interfaces/SideEffectInterfaces.h"
22 auto vt = dyn_cast<ValueType>(t);
25 auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
26 return innerWidth == 1;
36OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
38 if (
auto tokens = getTokens(); tokens.size() == 1)
39 return tokens.front();
47 PatternRewriter &rewriter)
const override {
49 struct BranchOperandInfo {
52 SetVector<Value> uniqueOperands;
57 DenseMap<BranchOp, BranchOperandInfo> branchOperands;
58 for (
auto &opOperand : op->getOpOperands()) {
59 auto branch = opOperand.get().getDefiningOp<BranchOp>();
63 BranchOperandInfo &info = branchOperands[branch];
64 info.uniqueOperands.insert(opOperand.get());
65 info.indices.resize(op->getNumOperands());
66 info.indices.set(opOperand.getOperandNumber());
69 if (branchOperands.empty())
73 for (
auto &it : branchOperands) {
74 auto branch = it.first;
75 auto &operandInfo = it.second;
76 if (operandInfo.uniqueOperands.size() != 2) {
86 rewriter.create<UnpackOp>(op.getLoc(), branch.getCondition());
87 rewriter.modifyOpInPlace(op, [&]() {
88 op->eraseOperands(operandInfo.indices);
89 op.getTokensMutable().append({unpacked.getToken()});
104 PatternRewriter &rewriter)
const override {
105 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
106 auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
115 rewriter.modifyOpInPlace(op, [&]() {
116 op.getTokensMutable().erase(operand.getOperandNumber());
117 op.getTokensMutable().append(otherJoin.getTokens());
128 PatternRewriter &rewriter)
const override {
129 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
130 if (
auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
131 rewriter.modifyOpInPlace(
132 op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
143 PatternRewriter &rewriter)
const override {
144 llvm::DenseSet<Value> uniqueOperands;
145 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
146 if (!uniqueOperands.insert(operand.get()).second) {
147 rewriter.modifyOpInPlace(
148 op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
156void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results,
157 MLIRContext *context) {
166template <
typename TInt>
168 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
173ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
174 OpAsmParser::UnresolvedOperand operand;
180 return parser.emitError(parser.getNameLoc(),
181 "fork size must be greater than 0");
183 if (parser.parseOperand(operand) ||
184 parser.parseOptionalAttrDict(result.attributes))
187 auto tt = dc::TokenType::get(parser.getContext());
188 llvm::SmallVector<Type> operandTypes{tt};
189 SmallVector<Type> resultTypes{size, tt};
190 result.addTypes(resultTypes);
191 if (parser.resolveOperand(operand, tt, result.operands))
196void ForkOp::print(OpAsmPrinter &p) {
197 p <<
" [" << getNumResults() <<
"] ";
198 p << getOperand() <<
" ";
199 auto attrs = (*this)->getAttrs();
200 if (!attrs.empty()) {
202 p.printOptionalAttrDict(attrs);
211 PatternRewriter &rewriter)
const override {
212 for (
auto output : fork.getOutputs()) {
213 for (
auto *user : output.getUsers()) {
214 auto userFork = dyn_cast<ForkOp>(user);
220 size_t totalForks = fork.getNumResults() + userFork.getNumResults();
222 auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
223 fork.getToken(), totalForks);
225 fork, newFork.getResults().take_front(fork.getNumResults()));
227 userFork, newFork.getResults().take_back(userFork.getNumResults()));
246 PatternRewriter &rewriter)
const override {
247 auto source = fork.getToken().getDefiningOp<SourceOp>();
253 llvm::SmallVector<Value> sources;
254 for (
size_t i = 0; i < fork.getNumResults(); ++i)
255 sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
257 rewriter.replaceOp(fork, sources);
266 PatternRewriter &rewriter)
const override {
267 std::set<unsigned> unusedIndexes;
269 for (
auto res : llvm::enumerate(op.getResults()))
270 if (res.value().use_empty())
271 unusedIndexes.insert(res.index());
273 if (unusedIndexes.empty())
277 rewriter.setInsertionPoint(op);
278 auto operand = op.getOperand();
279 auto newFork = rewriter.create<ForkOp>(
280 op.getLoc(), operand, op.getNumResults() - unusedIndexes.size());
282 for (
auto oldRes : llvm::enumerate(op.getResults()))
283 if (unusedIndexes.count(oldRes.index()) == 0)
284 rewriter.replaceAllUsesWith(oldRes.value(), newFork.getResults()[i++]);
285 rewriter.eraseOp(op);
290void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
291 MLIRContext *context) {
296LogicalResult ForkOp::fold(FoldAdaptor adaptor,
297 SmallVectorImpl<OpFoldResult> &results) {
299 if (getOutputs().size() == 1) {
300 results.push_back(getToken());
315 PatternRewriter &rewriter)
const override {
317 if (!unpack.getOutput().use_empty())
320 auto pack = unpack.getInput().getDefiningOp<PackOp>();
325 rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
326 rewriter.eraseOp(unpack);
331void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
332 MLIRContext *context) {
336LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
337 SmallVectorImpl<OpFoldResult> &results) {
339 if (
auto pack = getInput().getDefiningOp<PackOp>()) {
340 results.push_back(pack.getToken());
341 results.push_back(pack.getInput());
348LogicalResult UnpackOp::inferReturnTypes(
349 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
350 DictionaryAttr attrs, mlir::OpaqueProperties properties,
351 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
352 auto inputType = cast<ValueType>(operands.front().getType());
353 results.push_back(TokenType::get(context));
354 results.push_back(inputType.getInnerType());
362OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
363 auto token = getToken();
366 if (
auto unpack = token.getDefiningOp<UnpackOp>()) {
367 if (unpack.getOutput() == getInput())
368 return unpack.getInput();
373LogicalResult PackOp::inferReturnTypes(
374 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
375 DictionaryAttr attrs, mlir::OpaqueProperties properties,
376 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
377 llvm::SmallVector<Type> inputTypes;
378 Type inputType = operands.back().getType();
379 auto valueType = dc::ValueType::get(context, inputType);
380 results.push_back(valueType);
399 PatternRewriter &rewriter)
const override {
401 BranchOp branchInput;
402 for (
auto operand : {select.getTrueToken(), select.getFalseToken()}) {
403 auto br = operand.getDefiningOp<BranchOp>();
409 else if (branchInput != br)
414 rewriter.replaceOpWithNewOp<JoinOp>(
416 llvm::SmallVector<Value>{
417 rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
420 .create<UnpackOp>(branchInput.getLoc(),
421 branchInput.getCondition())
428void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 MLIRContext *context) {
437FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
438 assert(getInitValues() &&
"initValues attribute not set");
439 SmallVector<int64_t> values;
440 for (
auto value : getInitValuesAttr()) {
441 if (
auto iValue = dyn_cast<IntegerAttr>(value)) {
442 values.push_back(iValue.getValue().getSExtValue());
444 return emitError() <<
"initValues attribute must be an array of integers";
450LogicalResult BufferOp::verify() {
453 if (
auto initVals = getInitValuesAttr()) {
454 auto nInits = initVals.size();
455 if (nInits != getSize())
456 return emitOpError() <<
"expected " << getSize()
457 <<
" init values but got " << nInits <<
".";
467LogicalResult ToESIOp::inferReturnTypes(
468 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
469 DictionaryAttr attrs, mlir::OpaqueProperties properties,
470 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
472 if (
auto valueType = dyn_cast<ValueType>(operands.front().getType()))
473 channelEltType = valueType.getInnerType();
476 channelEltType = IntegerType::get(context, 0);
479 results.push_back(esi::ChannelType::get(context, channelEltType));
487LogicalResult FromESIOp::inferReturnTypes(
488 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
489 DictionaryAttr attrs, mlir::OpaqueProperties properties,
490 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
492 cast<esi::ChannelType>(operands.front().getType()).getInner();
493 if (
auto intType = dyn_cast<IntegerType>(innerType); intType.getWidth() == 0)
494 results.push_back(dc::TokenType::get(context));
496 results.push_back(dc::ValueType::get(context, innerType));
504#define GET_OP_CLASSES
505#include "circt/Dialect/DC/DC.cpp.inc"
assert(baseType &&"element must be base type")
LogicalResult matchAndRewrite(SelectOp select, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v)
bool isI1ValueType(Type t)
mlir::Type innerType(mlir::Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult matchAndRewrite(UnpackOp unpack, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ForkOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(JoinOp op, PatternRewriter &rewriter) const override