10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/FunctionImplementation.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 using namespace circt;
22 auto vt = dyn_cast<ValueType>(t);
25 auto innerWidth = vt.getInnerType().getIntOrFloatBitWidth();
26 return innerWidth == 1;
36 OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
38 if (
auto tokens = getTokens(); tokens.size() == 1)
39 return tokens.front();
46 auto *op = getOperation();
47 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
48 if (
auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
49 op->eraseOperand(operand.getOperandNumber());
55 llvm::DenseSet<Value> uniqueOperands;
56 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
57 if (!uniqueOperands.insert(operand.get()).second) {
58 op->eraseOperand(operand.getOperandNumber());
65 for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
66 auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
75 op->eraseOperand(operand.getOperandNumber());
76 op->insertOperands(getNumOperands(), otherJoin.getTokens());
87 template <
typename TInt>
89 if (parser.parseLSquare() || parser.parseInteger(v) || parser.parseRSquare())
94 ParseResult ForkOp::parse(OpAsmParser &parser, OperationState &result) {
95 OpAsmParser::UnresolvedOperand operand;
101 return parser.emitError(parser.getNameLoc(),
102 "fork size must be greater than 0");
104 if (parser.parseOperand(operand) ||
105 parser.parseOptionalAttrDict(result.attributes))
109 llvm::SmallVector<Type> operandTypes{tt};
110 SmallVector<Type> resultTypes{size, tt};
111 result.addTypes(resultTypes);
112 if (parser.resolveOperand(operand, tt, result.operands))
117 void ForkOp::print(OpAsmPrinter &p) {
118 p <<
" [" << getNumResults() <<
"] ";
119 p << getOperand() <<
" ";
120 auto attrs = (*this)->getAttrs();
121 if (!attrs.empty()) {
123 p.printOptionalAttrDict(attrs);
132 PatternRewriter &rewriter)
const override {
133 for (
auto output : fork.getOutputs()) {
134 for (
auto *user : output.getUsers()) {
135 auto userFork = dyn_cast<ForkOp>(user);
141 size_t totalForks = fork.getNumResults() + userFork.getNumResults() - 1;
143 auto newFork = rewriter.create<dc::ForkOp>(fork.getLoc(),
144 fork.getToken(), totalForks);
146 fork, newFork.getResults().take_front(fork.getNumResults()));
148 userFork, newFork.getResults().take_back(userFork.getNumResults()));
167 PatternRewriter &rewriter)
const override {
168 auto source = fork.getToken().getDefiningOp<SourceOp>();
174 llvm::SmallVector<Value> sources;
175 for (
size_t i = 0; i < fork.getNumResults(); ++i)
176 sources.push_back(rewriter.create<dc::SourceOp>(fork.getLoc()));
178 rewriter.replaceOp(fork, sources);
183 void ForkOp::getCanonicalizationPatterns(RewritePatternSet &results,
184 MLIRContext *context) {
189 LogicalResult ForkOp::fold(FoldAdaptor adaptor,
190 SmallVectorImpl<OpFoldResult> &results) {
192 if (getOutputs().size() == 1) {
193 results.push_back(getToken());
208 PatternRewriter &rewriter)
const override {
210 if (!unpack.getOutput().use_empty())
213 auto pack = unpack.getInput().getDefiningOp<PackOp>();
218 rewriter.replaceAllUsesWith(unpack.getToken(), pack.getToken());
219 rewriter.eraseOp(unpack);
224 void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results,
225 MLIRContext *context) {
229 LogicalResult UnpackOp::fold(FoldAdaptor adaptor,
230 SmallVectorImpl<OpFoldResult> &results) {
232 if (
auto pack = getInput().getDefiningOp<PackOp>()) {
233 results.push_back(pack.getToken());
234 results.push_back(pack.getInput());
241 LogicalResult UnpackOp::inferReturnTypes(
242 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
243 DictionaryAttr attrs, mlir::OpaqueProperties properties,
244 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
245 auto inputType = cast<ValueType>(operands.front().getType());
247 results.push_back(inputType.getInnerType());
255 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
256 auto token = getToken();
259 if (
auto unpack = token.getDefiningOp<UnpackOp>()) {
260 if (unpack.getOutput() == getInput())
261 return unpack.getInput();
266 LogicalResult PackOp::inferReturnTypes(
267 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
268 DictionaryAttr attrs, mlir::OpaqueProperties properties,
269 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
270 llvm::SmallVector<Type> inputTypes;
271 Type inputType = operands.back().getType();
273 results.push_back(valueType);
292 PatternRewriter &rewriter)
const override {
294 BranchOp branchInput;
295 for (
auto operand : {select.getTrueToken(), select.getFalseToken()}) {
296 auto br = operand.getDefiningOp<BranchOp>();
302 else if (branchInput != br)
307 rewriter.replaceOpWithNewOp<JoinOp>(
309 llvm::SmallVector<Value>{
310 rewriter.create<UnpackOp>(select.getLoc(), select.getCondition())
313 .create<UnpackOp>(branchInput.getLoc(),
314 branchInput.getCondition())
321 void SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
322 MLIRContext *context) {
330 FailureOr<SmallVector<int64_t>> BufferOp::getInitValueArray() {
331 assert(getInitValues() &&
"initValues attribute not set");
332 SmallVector<int64_t> values;
333 for (
auto value : getInitValuesAttr()) {
334 if (
auto iValue = dyn_cast<IntegerAttr>(value)) {
335 values.push_back(iValue.getValue().getSExtValue());
337 return emitError() <<
"initValues attribute must be an array of integers";
346 if (
auto initVals = getInitValuesAttr()) {
347 auto nInits = initVals.size();
348 if (nInits != getSize())
349 return emitOpError() <<
"expected " << getSize()
350 <<
" init values but got " << nInits <<
".";
360 LogicalResult ToESIOp::inferReturnTypes(
361 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
362 DictionaryAttr attrs, mlir::OpaqueProperties properties,
363 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
365 if (
auto valueType = dyn_cast<ValueType>(operands.front().getType()))
366 channelEltType = valueType.getInnerType();
380 LogicalResult FromESIOp::inferReturnTypes(
381 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
382 DictionaryAttr attrs, mlir::OpaqueProperties properties,
383 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
385 cast<esi::ChannelType>(operands.front().getType()).getInner();
386 if (
auto intType = dyn_cast<IntegerType>(
innerType); intType.getWidth() == 0)
397 #define GET_OP_CLASSES
398 #include "circt/Dialect/DC/DC.cpp.inc"
assert(baseType &&"element must be base type")
LogicalResult matchAndRewrite(SelectOp select, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ForkOp fork, PatternRewriter &rewriter) const override
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
static ParseResult parseIntInSquareBrackets(OpAsmParser &parser, TInt &v)
bool isI1ValueType(Type t)
mlir::Type innerType(mlir::Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult matchAndRewrite(UnpackOp unpack, PatternRewriter &rewriter) const override