Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SeqOps.cpp
Go to the documentation of this file.
1//===- SeqOps.cpp - Implement the Seq operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements sequential ops.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/DialectImplementation.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/PatternMatch.h"
24
26#include "llvm/ADT/SmallString.h"
27
28using namespace mlir;
29using namespace circt;
30using namespace seq;
31
32bool circt::seq::isValidIndexValues(Value hlmemHandle, ValueRange addresses) {
33 auto memType = cast<seq::HLMemType>(hlmemHandle.getType());
34 auto shape = memType.getShape();
35 if (shape.size() != addresses.size())
36 return false;
37
38 for (auto [dim, addr] : llvm::zip(shape, addresses)) {
39 auto addrType = dyn_cast<IntegerType>(addr.getType());
40 if (!addrType)
41 return false;
42 if (addrType.getIntOrFloatBitWidth() != llvm::Log2_64_Ceil(dim))
43 return false;
44 }
45 return true;
46}
47
48// If there was no name specified, check to see if there was a useful name
49// specified in the asm file.
50static void setNameFromResult(OpAsmParser &parser, OperationState &result) {
51 if (result.attributes.getNamed("name"))
52 return;
53 // If there is no explicit name attribute, get it from the SSA result name.
54 // If numeric, just use an empty name.
55 StringRef resultName = parser.getResultName(0).first;
56 if (!resultName.empty() && isdigit(resultName[0]))
57 resultName = "";
58 result.addAttribute("name", parser.getBuilder().getStringAttr(resultName));
59}
60
61static bool canElideName(OpAsmPrinter &p, Operation *op) {
62 if (!op->hasAttr("name"))
63 return true;
64
65 auto name = op->getAttrOfType<StringAttr>("name").getValue();
66 if (name.empty())
67 return true;
68
69 SmallString<32> resultNameStr;
70 llvm::raw_svector_ostream tmpStream(resultNameStr);
71 p.printOperand(op->getResult(0), tmpStream);
72 auto actualName = tmpStream.str().drop_front();
73 return actualName == name;
74}
75
76static ParseResult
77parseOptionalTypeMatch(OpAsmParser &parser, Type refType,
78 std::optional<OpAsmParser::UnresolvedOperand> operand,
79 Type &type) {
80 if (operand)
81 type = refType;
82 return success();
83}
84
85static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType,
86 Value operand, Type type) {
87 // Nothing to do - this is strictly an implicit parsing helper.
88}
89
91 OpAsmParser &parser, Type refType,
92 std::optional<OpAsmParser::UnresolvedOperand> operand, Type &type) {
93 if (operand)
94 type = seq::ImmutableType::get(refType);
95 return success();
96}
97
98static void printOptionalImmutableTypeMatch(OpAsmPrinter &p, Operation *op,
99 Type refType, Value operand,
100 Type type) {
101 // Nothing to do - this is strictly an implicit parsing helper.
102}
103
104//===----------------------------------------------------------------------===//
105// ReadPortOp
106//===----------------------------------------------------------------------===//
107
108ParseResult ReadPortOp::parse(OpAsmParser &parser, OperationState &result) {
109 llvm::SMLoc loc = parser.getCurrentLocation();
110
111 OpAsmParser::UnresolvedOperand memOperand, rdenOperand;
112 bool hasRdEn = false;
113 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
114 seq::HLMemType memType;
115
116 if (parser.parseOperand(memOperand) ||
117 parser.parseOperandList(addressOperands, OpAsmParser::Delimiter::Square))
118 return failure();
119
120 if (succeeded(parser.parseOptionalKeyword("rden"))) {
121 if (failed(parser.parseOperand(rdenOperand)))
122 return failure();
123 hasRdEn = true;
124 }
125
126 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
127 parser.parseType(memType))
128 return failure();
129
130 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
131 operandTypes.insert(operandTypes.begin(), memType);
132
133 llvm::SmallVector<OpAsmParser::UnresolvedOperand> allOperands = {memOperand};
134 llvm::copy(addressOperands, std::back_inserter(allOperands));
135 if (hasRdEn) {
136 operandTypes.push_back(parser.getBuilder().getI1Type());
137 allOperands.push_back(rdenOperand);
138 }
139
140 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
141 return failure();
142
143 result.addTypes(memType.getElementType());
144
145 llvm::SmallVector<int32_t, 2> operandSizes;
146 operandSizes.push_back(1); // memory handle
147 operandSizes.push_back(addressOperands.size());
148 operandSizes.push_back(hasRdEn ? 1 : 0);
149 result.addAttribute("operandSegmentSizes",
150 parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
151 return success();
152}
153
154void ReadPortOp::print(OpAsmPrinter &p) {
155 p << " " << getMemory() << "[" << getAddresses() << "]";
156 if (getRdEn())
157 p << " rden " << getRdEn();
158 p.printOptionalAttrDict((*this)->getAttrs(), {"operandSegmentSizes"});
159 p << " : " << getMemory().getType();
160}
161
162void ReadPortOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
163 auto memName = getMemory().getDefiningOp<seq::HLMemOp>().getName();
164 setNameFn(getReadData(), (memName + "_rdata").str());
165}
166
167void ReadPortOp::build(OpBuilder &builder, OperationState &result, Value memory,
168 ValueRange addresses, Value rdEn, unsigned latency) {
169 auto memType = cast<seq::HLMemType>(memory.getType());
170 ReadPortOp::build(builder, result, memType.getElementType(), memory,
171 addresses, rdEn, latency);
172}
173
174//===----------------------------------------------------------------------===//
175// WritePortOp
176//===----------------------------------------------------------------------===//
177
178ParseResult WritePortOp::parse(OpAsmParser &parser, OperationState &result) {
179 llvm::SMLoc loc = parser.getCurrentLocation();
180 OpAsmParser::UnresolvedOperand memOperand, dataOperand, wrenOperand;
181 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
182 seq::HLMemType memType;
183
184 if (parser.parseOperand(memOperand) ||
185 parser.parseOperandList(addressOperands,
186 OpAsmParser::Delimiter::Square) ||
187 parser.parseOperand(dataOperand) || parser.parseKeyword("wren") ||
188 parser.parseOperand(wrenOperand) ||
189 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
190 parser.parseType(memType))
191 return failure();
192
193 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
194 operandTypes.insert(operandTypes.begin(), memType);
195 operandTypes.push_back(memType.getElementType());
196 operandTypes.push_back(parser.getBuilder().getI1Type());
197
198 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> allOperands(
199 addressOperands);
200 allOperands.insert(allOperands.begin(), memOperand);
201 allOperands.push_back(dataOperand);
202 allOperands.push_back(wrenOperand);
203
204 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
205 return failure();
206
207 return success();
208}
209
210void WritePortOp::print(OpAsmPrinter &p) {
211 p << " " << getMemory() << "[" << getAddresses() << "] " << getInData()
212 << " wren " << getWrEn();
213 p.printOptionalAttrDict((*this)->getAttrs());
214 p << " : " << getMemory().getType();
215}
216
217//===----------------------------------------------------------------------===//
218// HLMemOp
219//===----------------------------------------------------------------------===//
220
221void HLMemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
222 setNameFn(getHandle(), getName());
223}
224
225void HLMemOp::build(OpBuilder &builder, OperationState &result, Value clk,
226 Value rst, StringRef name, llvm::ArrayRef<int64_t> shape,
227 Type elementType) {
228 HLMemType t = HLMemType::get(builder.getContext(), shape, elementType);
229 HLMemOp::build(builder, result, t, clk, rst, name);
230}
231
232//===----------------------------------------------------------------------===//
233// FIFOOp
234//===----------------------------------------------------------------------===//
235
236// Flag threshold custom directive
237static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser,
238 IntegerAttr &threshold,
239 Type &outputFlagType,
240 StringRef directive) {
241 // look for an optional "almost_full $threshold" group.
242 if (succeeded(parser.parseOptionalKeyword(directive))) {
243 int64_t thresholdValue;
244 if (succeeded(parser.parseInteger(thresholdValue))) {
245 threshold = parser.getBuilder().getI64IntegerAttr(thresholdValue);
246 outputFlagType = parser.getBuilder().getI1Type();
247 return success();
248 }
249 return parser.emitError(parser.getNameLoc(),
250 "expected integer value after " + directive +
251 " directive");
252 }
253 return success();
254}
255
256ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold,
257 Type &outputFlagType) {
258 return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
259 "almost_full");
260}
261
262ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold,
263 Type &outputFlagType) {
264 return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
265 "almost_empty");
266}
267
268void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
269 Type outputFlagType) {
270 if (threshold)
271 p << "almost_full"
272 << " " << threshold.getInt();
273}
274
275void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
276 Type outputFlagType) {
277 if (threshold)
278 p << "almost_empty"
279 << " " << threshold.getInt();
280}
281
282void FIFOOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
283 setNameFn(getOutput(), "out");
284 setNameFn(getEmpty(), "empty");
285 setNameFn(getFull(), "full");
286 if (auto ae = getAlmostEmpty())
287 setNameFn(ae, "almostEmpty");
288 if (auto af = getAlmostFull())
289 setNameFn(af, "almostFull");
290}
291
292LogicalResult FIFOOp::verify() {
293 auto aet = getAlmostEmptyThreshold();
294 auto aft = getAlmostFullThreshold();
295 size_t depth = getDepth();
296 if (aft.has_value() && aft.value() > depth)
297 return emitOpError("almost full threshold must be <= FIFO depth");
298
299 if (aet.has_value() && aet.value() > depth)
300 return emitOpError("almost empty threshold must be <= FIFO depth");
301
302 return success();
303}
304
305//===----------------------------------------------------------------------===//
306// CompRegOp
307//===----------------------------------------------------------------------===//
308
309/// Suggest a name for each result value based on the saved result names
310/// attribute.
311void CompRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
312 if (auto name = getName())
313 setNameFn(getResult(), *name);
314}
315
316template <typename TOp>
317LogicalResult verifyResets(TOp op) {
318 if ((op.getReset() == nullptr) ^ (op.getResetValue() == nullptr))
319 return op->emitOpError(
320 "either reset and resetValue or neither must be specified");
321 bool hasReset = op.getReset() != nullptr;
322 if (hasReset && op.getResetValue().getType() != op.getInput().getType())
323 return op->emitOpError("reset value must be the same type as the input");
324
325 return success();
326}
327
328std::optional<size_t> CompRegOp::getTargetResultIndex() { return 0; }
329
330LogicalResult CompRegOp::verify() { return verifyResets(*this); }
331
332//===----------------------------------------------------------------------===//
333// CompRegClockEnabledOp
334//===----------------------------------------------------------------------===//
335
336/// Suggest a name for each result value based on the saved result names
337/// attribute.
338void CompRegClockEnabledOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
339 if (auto name = getName())
340 setNameFn(getResult(), *name);
341}
342
343std::optional<size_t> CompRegClockEnabledOp::getTargetResultIndex() {
344 return 0;
345}
346
347LogicalResult CompRegClockEnabledOp::verify() { return verifyResets(*this); }
348
349LogicalResult CompRegClockEnabledOp::canonicalize(CompRegClockEnabledOp op,
350 PatternRewriter &rewriter) {
351 // reg(comb.mux(en, d, ?), en) -> reg(d, en)
352 // reg(arith.select(en, d, ?), en) -> reg(d, en)
353 auto *inputOp = op.getInput().getDefiningOp();
354 if (isa_and_nonnull<comb::MuxOp, arith::SelectOp>(inputOp) &&
355 inputOp->getOperand(0) == op.getClockEnable()) {
356 rewriter.modifyOpInPlace(
357 op, [&] { op.getInputMutable().assign(inputOp->getOperand(1)); });
358 return success();
359 }
360
361 // Match constant clock enable values.
362 APInt en;
363 if (mlir::matchPattern(op.getClockEnable(), mlir::m_ConstantInt(&en))) {
364 if (en.isAllOnes()) {
365 rewriter.replaceOpWithNewOp<seq::CompRegOp>(
366 op, op.getInput(), op.getClk(), op.getNameAttr(), op.getReset(),
367 op.getResetValue(), op.getInitialValue(), op.getInnerSymAttr());
368 return success();
369 }
370 }
371
372 return failure();
373}
374
375//===----------------------------------------------------------------------===//
376// ShiftRegOp
377//===----------------------------------------------------------------------===//
378
379void ShiftRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
380 // If the wire has an optional 'name' attribute, use it.
381 if (auto name = getName())
382 setNameFn(getResult(), *name);
383}
384
385std::optional<size_t> ShiftRegOp::getTargetResultIndex() { return 0; }
386
387LogicalResult ShiftRegOp::verify() {
388 if (failed(verifyResets(*this)))
389 return failure();
390 return success();
391}
392
393//===----------------------------------------------------------------------===//
394// FirRegOp
395//===----------------------------------------------------------------------===//
396
397void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
398 Value clk, StringAttr name, hw::InnerSymAttr innerSym,
399 Attribute preset) {
400
401 OpBuilder::InsertionGuard guard(builder);
402
403 result.addOperands(input);
404 result.addOperands(clk);
405
406 result.addAttribute(getNameAttrName(result.name), name);
407
408 if (innerSym)
409 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
410
411 if (preset)
412 result.addAttribute(getPresetAttrName(result.name), preset);
413
414 result.addTypes(input.getType());
415}
416
417void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
418 Value clk, StringAttr name, Value reset, Value resetValue,
419 hw::InnerSymAttr innerSym, bool isAsync) {
420
421 OpBuilder::InsertionGuard guard(builder);
422
423 result.addOperands(input);
424 result.addOperands(clk);
425 result.addOperands(reset);
426 result.addOperands(resetValue);
427
428 result.addAttribute(getNameAttrName(result.name), name);
429 if (isAsync)
430 result.addAttribute(getIsAsyncAttrName(result.name), builder.getUnitAttr());
431
432 if (innerSym)
433 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
434
435 result.addTypes(input.getType());
436}
437
438ParseResult FirRegOp::parse(OpAsmParser &parser, OperationState &result) {
439 auto &builder = parser.getBuilder();
440 llvm::SMLoc loc = parser.getCurrentLocation();
441
442 using Op = OpAsmParser::UnresolvedOperand;
443
444 Op next, clk;
445 if (parser.parseOperand(next) || parser.parseKeyword("clock") ||
446 parser.parseOperand(clk))
447 return failure();
448
449 if (succeeded(parser.parseOptionalKeyword("sym"))) {
450 hw::InnerSymAttr innerSym;
451 if (parser.parseCustomAttributeWithFallback(innerSym, /*type=*/nullptr,
452 "inner_sym", result.attributes))
453 return failure();
454 }
455
456 // Parse reset [sync|async] %reset, %value
457 std::optional<std::pair<Op, Op>> resetAndValue;
458 if (succeeded(parser.parseOptionalKeyword("reset"))) {
459 bool isAsync;
460 if (succeeded(parser.parseOptionalKeyword("async")))
461 isAsync = true;
462 else if (succeeded(parser.parseOptionalKeyword("sync")))
463 isAsync = false;
464 else
465 return parser.emitError(loc, "invalid reset, expected 'sync' or 'async'");
466 if (isAsync)
467 result.attributes.append("isAsync", builder.getUnitAttr());
468
469 resetAndValue = {{}, {}};
470 if (parser.parseOperand(resetAndValue->first) || parser.parseComma() ||
471 parser.parseOperand(resetAndValue->second))
472 return failure();
473 }
474
475 std::optional<APInt> presetValue;
476 llvm::SMLoc presetValueLoc;
477 if (succeeded(parser.parseOptionalKeyword("preset"))) {
478 presetValueLoc = parser.getCurrentLocation();
479 OptionalParseResult presetIntResult =
480 parser.parseOptionalInteger(presetValue.emplace());
481 if (!presetIntResult.has_value() || failed(*presetIntResult))
482 return parser.emitError(loc, "expected integer value");
483 }
484
485 Type ty;
486 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
487 parser.parseType(ty))
488 return failure();
489 result.addTypes({ty});
490
491 if (presetValue) {
492 uint64_t width = 0;
493 if (hw::type_isa<seq::ClockType>(ty)) {
494 width = 1;
495 } else {
496 int64_t maybeWidth = hw::getBitWidth(ty);
497 if (maybeWidth < 0)
498 return parser.emitError(presetValueLoc,
499 "cannot preset register of unknown width");
500 width = maybeWidth;
501 }
502
503 APInt presetResult = presetValue->sextOrTrunc(width);
504 if (presetResult.zextOrTrunc(presetValue->getBitWidth()) != *presetValue)
505 return parser.emitError(loc, "preset value too large");
506
507 auto builder = parser.getBuilder();
508 auto presetTy = builder.getIntegerType(width);
509 auto resultAttr = builder.getIntegerAttr(presetTy, presetResult);
510 result.addAttribute("preset", resultAttr);
511 }
512
513 setNameFromResult(parser, result);
514
515 if (parser.resolveOperand(next, ty, result.operands))
516 return failure();
517
518 Type clkTy = ClockType::get(result.getContext());
519 if (parser.resolveOperand(clk, clkTy, result.operands))
520 return failure();
521
522 if (resetAndValue) {
523 Type i1 = IntegerType::get(result.getContext(), 1);
524 if (parser.resolveOperand(resetAndValue->first, i1, result.operands) ||
525 parser.resolveOperand(resetAndValue->second, ty, result.operands))
526 return failure();
527 }
528
529 return success();
530}
531
532void FirRegOp::print(::mlir::OpAsmPrinter &p) {
533 SmallVector<StringRef> elidedAttrs = {
534 getInnerSymAttrName(), getIsAsyncAttrName(), getPresetAttrName()};
535
536 p << ' ' << getNext() << " clock " << getClk();
537
538 if (auto sym = getInnerSymAttr()) {
539 p << " sym ";
540 sym.print(p);
541 }
542
543 if (hasReset()) {
544 p << " reset " << (getIsAsync() ? "async" : "sync") << ' ';
545 p << getReset() << ", " << getResetValue();
546 }
547
548 if (auto preset = getPresetAttr()) {
549 p << " preset " << preset.getValue();
550 }
551
552 if (canElideName(p, *this))
553 elidedAttrs.push_back("name");
554
555 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
556 p << " : " << getNext().getType();
557}
558
559/// Verifier for the FIR register op.
560LogicalResult FirRegOp::verify() {
561 if (getReset() || getResetValue() || getIsAsync()) {
562 if (!getReset() || !getResetValue())
563 return emitOpError("must specify reset and reset value");
564 } else {
565 if (getIsAsync())
566 return emitOpError("register with no reset cannot be async");
567 }
568 if (auto preset = getPresetAttr()) {
569 int64_t presetWidth = hw::getBitWidth(preset.getType());
570 int64_t width = hw::getBitWidth(getType());
571 if (preset.getType() != getType() && presetWidth != width)
572 return emitOpError("preset type width must match register type");
573 }
574 return success();
575}
576
577/// Suggest a name for each result value based on the saved result names
578/// attribute.
579void FirRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
580 // If the register has an optional 'name' attribute, use it.
581 if (!getName().empty())
582 setNameFn(getResult(), getName());
583}
584
585std::optional<size_t> FirRegOp::getTargetResultIndex() { return 0; }
586
/// Canonicalization patterns for seq.firreg. They are tried in order, and the
/// first one that rewrites the op returns success:
///  1. A reset driven by a constant-zero hw.constant is dropped. The register
///     is rebuilt with name, inner symbol and preset, but no reset operands.
///  2. Otherwise, registers with an inner symbol are left untouched.
///  3. A register with trivial feedback (next == result), or whose clock is a
///     seq.to_clock of an hw.constant, is replaced with constant zero. This
///     requires no reset value and no non-zero preset.
///  4. A register whose next value is a comb.mux with one arm being the
///     register itself and the other a constant is replaced by that constant.
///     This requires no preset and, if there is a reset value, that it is the
///     same constant.
///  5. Reset-less, preset-less registers of a 1-D integer array type fed by
///     hw.array_create are handled element-wise. Elements fed back from the
///     register's own element at the same index are replaced with constant
///     zero.
LogicalResult FirRegOp::canonicalize(FirRegOp op, PatternRewriter &rewriter) {

  // If the register has a constant zero reset, drop the reset and reset value
  // altogether (And preserve the PresetAttr).
  if (auto reset = op.getReset()) {
    if (auto constOp = reset.getDefiningOp<hw::ConstantOp>()) {
      if (constOp.getValue().isZero()) {
        rewriter.replaceOpWithNewOp<FirRegOp>(
            op, op.getNext(), op.getClk(), op.getNameAttr(),
            op.getInnerSymAttr(), op.getPresetAttr());
        return success();
      }
    }
  }

  // If the register has a symbol, we can't optimize it away.
  if (op.getInnerSymAttr())
    return failure();

  // Replace a register with a trivial feedback or constant clock with a
  // constant zero.
  // TODO: Once HW aggregate constant values are supported, move this
  // canonicalization to the folder.
  // Only a clock produced by seq.to_clock of an hw.constant is recognized
  // here; FirRegOp::fold separately treats any constant-folded clock operand
  // as "never clocked".
  auto isConstant = [&]() -> bool {
    if (op.getNext() == op.getResult())
      return true;
    if (auto clk = op.getClk().getDefiningOp<seq::ToClockOp>())
      return clk.getInput().getDefiningOp<hw::ConstantOp>();
    return false;
  };

  // Preset can block canonicalization only if it is non-zero.
  bool replaceWithConstZero = true;
  if (auto preset = op.getPresetAttr())
    if (!preset.getValue().isZero())
      replaceWithConstZero = false;

  if (isConstant() && !op.getResetValue() && replaceWithConstZero) {
    // Clock-typed registers become a constant low clock; every other type
    // becomes an all-zero hw.constant of the same bit width, bitcast back to
    // the register type.
    if (isa<seq::ClockType>(op.getType())) {
      rewriter.replaceOpWithNewOp<seq::ConstClockOp>(
          op, seq::ClockConstAttr::get(rewriter.getContext(), ClockConst::Low));
    } else {
      auto constant = rewriter.create<hw::ConstantOp>(
          op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
      rewriter.replaceOpWithNewOp<hw::BitcastOp>(op, op.getType(), constant);
    }
    return success();
  }

  // Canonicalize registers with mux-based constant drivers.
  // This pattern matches registers where the next value is a mux with one
  // branch being the register itself (creating a self-loop) and the other
  // branch being a constant. In such cases, the register effectively holds a
  // constant value and can be replaced with that constant.
  if (auto nextMux = op.getNext().getDefiningOp<comb::MuxOp>()) {
    // Reject optimization if register has preset attribute (for simplicity).
    // Note: this ends canonicalization for the op entirely. The array pattern
    // below could not match anyway, since `next` is a mux here.
    if (op.getPresetAttr())
      return failure();

    Attribute value;
    Value replacedValue;

    // Check if true branch is self-loop and false branch is constant
    if (nextMux.getTrueValue() == op.getResult() &&
        matchPattern(nextMux.getFalseValue(), m_Constant(&value))) {
      replacedValue = nextMux.getFalseValue();
    }
    // Check if false branch is self-loop and true branch is constant
    else if (nextMux.getFalseValue() == op.getResult() &&
             matchPattern(nextMux.getTrueValue(), m_Constant(&value))) {
      replacedValue = nextMux.getTrueValue();
    }

    if (!replacedValue)
      return failure();

    // Verify reset value compatibility: if register has reset, it must be
    // a constant that matches the mux constant. The two constants are
    // compared as attributes.
    if (op.getResetValue()) {
      Attribute resetConst;
      if (matchPattern(op.getResetValue(), m_Constant(&resetConst))) {
        if (resetConst != value)
          return failure();
      } else {
        // Non-constant reset value prevents optimization
        return failure();
      }
    }

    assert(replacedValue);
    // Apply the optimization if all conditions are met
    rewriter.replaceOp(op, replacedValue);
    return success();
  }

  // For reset-less 1d array registers, replace an uninitialized element with
  // constant zero. For example, let `r` be a 2xi1 register and its next value
  // be `{foo, r[0]}`. `r[0]` is connected to itself so will never be
  // initialized. If we don't enable aggregate preservation, `r_0` is replaced
  // with `0`. Hence this canonicalization replaces 0th element of the next
  // value with zero to match the behaviour.
  if (!op.getReset() && !op.getPresetAttr()) {
    if (auto arrayCreate = op.getNext().getDefiningOp<hw::ArrayCreateOp>()) {
      // For now only support 1d arrays.
      // TODO: Support nested arrays and bundles.
      if (isa<IntegerType>(
              hw::type_cast<hw::ArrayType>(op.getResult().getType())
                  .getElementType())) {
        SmallVector<Value> nextOperands;
        bool changed = false;
        for (const auto &[i, value] :
             llvm::enumerate(arrayCreate.getOperands())) {
          // hw.array_create operand `i` is mapped to array index
          // (size - i - 1), i.e. operands are matched in reverse index order.
          auto index = arrayCreate.getOperands().size() - i - 1;
          APInt elementIndex;
          // Check that the corresponding operand is op's element.
          if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
            if (arrayGet.getInput() == op.getResult() &&
                matchPattern(arrayGet.getIndex(),
                             m_ConstantInt(&elementIndex)) &&
                elementIndex == index) {
              nextOperands.push_back(rewriter.create<hw::ConstantOp>(
                  op.getLoc(),
                  APInt::getZero(hw::getBitWidth(arrayGet.getType()))));
              changed = true;
              continue;
            }
          nextOperands.push_back(value);
        }
        // If one of the operands is self loop, update the next value.
        if (changed) {
          auto newNextVal = rewriter.create<hw::ArrayCreateOp>(
              arrayCreate.getLoc(), nextOperands);
          if (arrayCreate->hasOneUse())
            // If the original next value has a single use, we can replace the
            // value directly.
            rewriter.replaceOp(arrayCreate, newNextVal);
          else {
            // Otherwise, replace the entire firreg with a new one. The inner
            // symbol is null here (checked above) and there is no preset.
            rewriter.replaceOpWithNewOp<FirRegOp>(op, newNextVal, op.getClk(),
                                                  op.getNameAttr(),
                                                  op.getInnerSymAttr());
          }

          return success();
        }
      }
    }
  }

  return failure();
}
738
/// Folds seq.firreg to an existing value or a constant when its value is
/// statically known. Registers with an inner symbol are never folded; a preset
/// restricts which of the folds below are allowed.
OpFoldResult FirRegOp::fold(FoldAdaptor adaptor) {
  // If the register has a symbol, we can't optimize it away. A preset value
  // does not block folding outright; it is checked by each fold below.
  // TODO: Handle a preset value more generally.
  if (getInnerSymAttr())
    return {};

  auto presetAttr = getPresetAttr();

  // If the register is held in permanent reset, replace it with its reset
  // value. This works trivially if the reset is asynchronous and therefore
  // level-sensitive, in which case it will always immediately assume the reset
  // value in silicon. If it is synchronous, the register value is undefined
  // until the first clock edge at which point it becomes the reset value, in
  // which case we simply define the initial value to already be the reset
  // value. Works only if no preset.
  if (!presetAttr)
    if (auto reset = getReset())
      if (auto constOp = reset.getDefiningOp<hw::ConstantOp>())
        if (constOp.getValue().isOne())
          return getResetValue();

  // If the register's next value is trivially it's current value, or the
  // register is never clocked, we can replace the register with a constant
  // value.
  bool isTrivialFeedback = (getNext() == getResult());
  bool isNeverClocked =
      adaptor.getClk() != nullptr; // clock operand is constant
  if (!isTrivialFeedback && !isNeverClocked)
    return {};

  // With a reset value, the register can only fold to that value, and only
  // when it is a constant. Without a preset, any ConstantLike producer
  // qualifies. With a preset, only an hw.constant whose value equals the
  // preset qualifies. A non-constant reset value blocks folding.
  if (auto resetValue = getResetValue()) {
    if (auto *op = resetValue.getDefiningOp()) {
      if (op->hasTrait<OpTrait::ConstantLike>() && !presetAttr)
        return resetValue;
      // NOTE(review): presetAttr is dereferenced unguarded below. This relies
      // on hw.constant carrying the ConstantLike trait, in which case the
      // no-preset case has already returned above. Confirm.
      if (auto constOp = dyn_cast<hw::ConstantOp>(op))
        if (presetAttr.getValue() == constOp.getValue())
          return resetValue;
    }
    return {};
  }

  // Otherwise we want to replace the register with a constant 0. For now this
  // only works with integer types.
  auto intType = dyn_cast<IntegerType>(getType());
  if (!intType)
    return {};
  // If preset present, then replace with preset.
  if (presetAttr)
    return presetAttr;
  return IntegerAttr::get(intType, 0);
}
792
793//===----------------------------------------------------------------------===//
794// ClockGateOp
795//===----------------------------------------------------------------------===//
796
797OpFoldResult ClockGateOp::fold(FoldAdaptor adaptor) {
798 // Forward the clock if one of the enables is always true.
799 if (isConstantOne(adaptor.getEnable()) ||
800 isConstantOne(adaptor.getTestEnable()))
801 return getInput();
802
803 // Fold to a constant zero clock if the enables are always false.
804 if (isConstantZero(adaptor.getEnable()) &&
805 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
806 return ClockConstAttr::get(getContext(), ClockConst::Low);
807
808 // Forward constant zero clocks.
809 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput()))
810 if (clockAttr.getValue() == ClockConst::Low)
811 return ClockConstAttr::get(getContext(), ClockConst::Low);
812
813 // Transitive clock gating - eliminate clock gates that are driven by an
814 // identical enable signal somewhere higher in the clock gate hierarchy.
815 auto clockGateInputOp = getInput().getDefiningOp<ClockGateOp>();
816 while (clockGateInputOp) {
817 if (clockGateInputOp.getEnable() == getEnable() &&
818 clockGateInputOp.getTestEnable() == getTestEnable())
819 return getInput();
820 clockGateInputOp = clockGateInputOp.getInput().getDefiningOp<ClockGateOp>();
821 }
822
823 return {};
824}
825
826LogicalResult ClockGateOp::canonicalize(ClockGateOp op,
827 PatternRewriter &rewriter) {
828 // Remove constant false test enable.
829 if (auto testEnable = op.getTestEnable()) {
830 if (auto constOp = testEnable.getDefiningOp<hw::ConstantOp>()) {
831 if (constOp.getValue().isZero()) {
832 rewriter.modifyOpInPlace(op,
833 [&] { op.getTestEnableMutable().clear(); });
834 return success();
835 }
836 }
837 }
838
839 return failure();
840}
841
842std::optional<size_t> ClockGateOp::getTargetResultIndex() {
843 return std::nullopt;
844}
845
846//===----------------------------------------------------------------------===//
847// ClockMuxOp
848//===----------------------------------------------------------------------===//
849
850OpFoldResult ClockMuxOp::fold(FoldAdaptor adaptor) {
851 if (isConstantOne(adaptor.getCond()))
852 return getTrueClock();
853 if (isConstantZero(adaptor.getCond()))
854 return getFalseClock();
855 return {};
856}
857
858//===----------------------------------------------------------------------===//
859// FirMemOp
860//===----------------------------------------------------------------------===//
861
862LogicalResult FirMemOp::canonicalize(FirMemOp op, PatternRewriter &rewriter) {
863 // Do not change memories if symbols point to them.
864 if (op.getInnerSymAttr())
865 return failure();
866
867 bool readOnly = true, writeOnly = true;
868
869 // If the memory has no read ports, erase it.
870 for (auto *user : op->getUsers()) {
871 if (isa<FirMemReadOp, FirMemReadWriteOp>(user)) {
872 writeOnly = false;
873 }
874 if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
875 readOnly = false;
876 }
877 assert((isa<FirMemReadOp, FirMemWriteOp, FirMemReadWriteOp>(user)) &&
878 "invalid seq.firmem user");
879 }
880 if (writeOnly) {
881 for (auto *user : llvm::make_early_inc_range(op->getUsers()))
882 rewriter.eraseOp(user);
883
884 rewriter.eraseOp(op);
885 return success();
886 }
887
888 if (readOnly && !op.getInit()) {
889 // Replace all read ports with a constant 0.
890 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
891 auto readOp = cast<FirMemReadOp>(user);
892 Value zero = rewriter.create<hw::ConstantOp>(
893 readOp.getLoc(), APInt::getZero(hw::getBitWidth(readOp.getType())));
894 if (readOp.getType() != zero.getType())
895 zero = rewriter.create<hw::BitcastOp>(readOp.getLoc(), readOp.getType(),
896 zero);
897 rewriter.replaceOp(readOp, zero);
898 }
899 rewriter.eraseOp(op);
900 return success();
901 }
902 return failure();
903}
904
905void FirMemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
906 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
907 if (!nameAttr.getValue().empty())
908 setNameFn(getResult(), nameAttr.getValue());
909}
910
911std::optional<size_t> FirMemOp::getTargetResultIndex() { return 0; }
912
913template <class Op>
914static LogicalResult verifyFirMemMask(Op op) {
915 if (auto mask = op.getMask()) {
916 auto memType = op.getMemory().getType();
917 if (!memType.getMaskWidth())
918 return op.emitOpError("has mask operand but memory type '")
919 << memType << "' has no mask";
920 auto expected = IntegerType::get(op.getContext(), *memType.getMaskWidth());
921 if (mask.getType() != expected)
922 return op.emitOpError("has mask operand of type '")
923 << mask.getType() << "', but memory type requires '" << expected
924 << "'";
925 }
926 return success();
927}
928
929LogicalResult FirMemWriteOp::verify() { return verifyFirMemMask(*this); }
930LogicalResult FirMemReadWriteOp::verify() { return verifyFirMemMask(*this); }
931
932static bool isConstClock(Value value) {
933 if (!value)
934 return false;
935 return value.getDefiningOp<seq::ConstClockOp>();
936}
937
938static bool isConstZero(Value value) {
939 if (value)
940 if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
941 return constOp.getValue().isZero();
942 return false;
943}
944
945static bool isConstAllOnes(Value value) {
946 if (value)
947 if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
948 return constOp.getValue().isAllOnes();
949 return false;
950}
951
952LogicalResult FirMemReadOp::canonicalize(FirMemReadOp op,
953 PatternRewriter &rewriter) {
954 // Remove the enable if it is constant true.
955 if (isConstAllOnes(op.getEnable())) {
956 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
957 return success();
958 }
959 return failure();
960}
961
962LogicalResult FirMemWriteOp::canonicalize(FirMemWriteOp op,
963 PatternRewriter &rewriter) {
964 // Remove the write port if it is trivially dead.
965 if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
966 isConstClock(op.getClk())) {
967 auto memOp = op.getMemory().getDefiningOp<FirMemOp>();
968 if (memOp.getInnerSymAttr())
969 return failure();
970 rewriter.eraseOp(op);
971 return success();
972 }
973 bool anyChanges = false;
974
975 // Remove the enable if it is constant true.
976 if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
977 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
978 anyChanges = true;
979 }
980
981 // Remove the mask if it is all ones.
982 if (auto mask = op.getMask(); isConstAllOnes(mask)) {
983 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
984 anyChanges = true;
985 }
986
987 return success(anyChanges);
988}
989
990LogicalResult FirMemReadWriteOp::canonicalize(FirMemReadWriteOp op,
991 PatternRewriter &rewriter) {
992 // Replace the read-write port with a read port if the write behavior is
993 // trivially disabled.
994 if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
995 isConstClock(op.getClk()) || isConstZero(op.getMode())) {
996 auto opAttrs = op->getAttrs();
997 auto opAttrNames = op.getAttributeNames();
998 auto newOp = rewriter.replaceOpWithNewOp<FirMemReadOp>(
999 op, op.getMemory(), op.getAddress(), op.getClk(), op.getEnable());
1000 for (auto namedAttr : opAttrs)
1001 if (!llvm::is_contained(opAttrNames, namedAttr.getName()))
1002 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1003 return success();
1004 }
1005 bool anyChanges = false;
1006
1007 // Remove the enable if it is constant true.
1008 if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
1009 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
1010 anyChanges = true;
1011 }
1012
1013 // Remove the mask if it is all ones.
1014 if (auto mask = op.getMask(); isConstAllOnes(mask)) {
1015 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
1016 anyChanges = true;
1017 }
1018
1019 return success(anyChanges);
1020}
1021
1022//===----------------------------------------------------------------------===//
1023// ConstClockOp
1024//===----------------------------------------------------------------------===//
1025
1026OpFoldResult ConstClockOp::fold(FoldAdaptor adaptor) {
1027 return ClockConstAttr::get(getContext(), getValue());
1028}
1029
1030//===----------------------------------------------------------------------===//
1031// ToClockOp/FromClockOp
1032//===----------------------------------------------------------------------===//
1033
1034LogicalResult ToClockOp::canonicalize(ToClockOp op, PatternRewriter &rewriter) {
1035 if (auto fromClock = op.getInput().getDefiningOp<FromClockOp>()) {
1036 rewriter.replaceOp(op, fromClock.getInput());
1037 return success();
1038 }
1039 return failure();
1040}
1041
1042OpFoldResult ToClockOp::fold(FoldAdaptor adaptor) {
1043 if (auto fromClock = getInput().getDefiningOp<FromClockOp>())
1044 return fromClock.getInput();
1045 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
1046 auto value =
1047 intAttr.getValue().isZero() ? ClockConst::Low : ClockConst::High;
1048 return ClockConstAttr::get(getContext(), value);
1049 }
1050 return {};
1051}
1052
1053LogicalResult FromClockOp::canonicalize(FromClockOp op,
1054 PatternRewriter &rewriter) {
1055 if (auto toClock = op.getInput().getDefiningOp<ToClockOp>()) {
1056 rewriter.replaceOp(op, toClock.getInput());
1057 return success();
1058 }
1059 return failure();
1060}
1061
1062OpFoldResult FromClockOp::fold(FoldAdaptor adaptor) {
1063 if (auto toClock = getInput().getDefiningOp<ToClockOp>())
1064 return toClock.getInput();
1065 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
1066 auto ty = IntegerType::get(getContext(), 1);
1067 return IntegerAttr::get(ty, clockAttr.getValue() == ClockConst::High);
1068 }
1069 return {};
1070}
1071
1072//===----------------------------------------------------------------------===//
1073// ClockInverterOp
1074//===----------------------------------------------------------------------===//
1075
1076OpFoldResult ClockInverterOp::fold(FoldAdaptor adaptor) {
1077 if (auto chainedInv = getInput().getDefiningOp<ClockInverterOp>())
1078 return chainedInv.getInput();
1079 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
1080 auto clockIn = clockAttr.getValue() == ClockConst::High;
1081 return ClockConstAttr::get(getContext(),
1082 clockIn ? ClockConst::Low : ClockConst::High);
1083 }
1084 return {};
1085}
1086
1087//===----------------------------------------------------------------------===//
1088// FIR memory helper
1089//===----------------------------------------------------------------------===//
1090
1091FirMemory::FirMemory(hw::HWModuleGeneratedOp op) {
1092 depth = op->getAttrOfType<IntegerAttr>("depth").getInt();
1093 numReadPorts = op->getAttrOfType<IntegerAttr>("numReadPorts").getUInt();
1094 numWritePorts = op->getAttrOfType<IntegerAttr>("numWritePorts").getUInt();
1095 numReadWritePorts =
1096 op->getAttrOfType<IntegerAttr>("numReadWritePorts").getUInt();
1097 readLatency = op->getAttrOfType<IntegerAttr>("readLatency").getUInt();
1098 writeLatency = op->getAttrOfType<IntegerAttr>("writeLatency").getUInt();
1099 dataWidth = op->getAttrOfType<IntegerAttr>("width").getUInt();
1100 if (op->hasAttrOfType<IntegerAttr>("maskGran"))
1101 maskGran = op->getAttrOfType<IntegerAttr>("maskGran").getUInt();
1102 else
1103 maskGran = dataWidth;
1104 readUnderWrite = op->getAttrOfType<seq::RUWAttr>("readUnderWrite").getValue();
1105 writeUnderWrite =
1106 op->getAttrOfType<seq::WUWAttr>("writeUnderWrite").getValue();
1107 if (auto clockIDsAttr = op->getAttrOfType<ArrayAttr>("writeClockIDs"))
1108 for (auto clockID : clockIDsAttr)
1109 writeClockIDs.push_back(
1110 cast<IntegerAttr>(clockID).getValue().getZExtValue());
1111 initFilename = op->getAttrOfType<StringAttr>("initFilename").getValue();
1112 initIsBinary = op->getAttrOfType<BoolAttr>("initIsBinary").getValue();
1113 initIsInline = op->getAttrOfType<BoolAttr>("initIsInline").getValue();
1114}
1115
1116LogicalResult InitialOp::verify() {
1117 // Check outputs.
1118 auto *terminator = this->getBody().front().getTerminator();
1119 if (terminator->getOperands().size() != getNumResults())
1120 return emitError() << "result type doesn't match with the terminator";
1121 for (auto [lhs, rhs] :
1122 llvm::zip(terminator->getOperands().getTypes(), getResultTypes())) {
1123 if (cast<seq::ImmutableType>(rhs).getInnerType() != lhs)
1124 return emitError() << cast<seq::ImmutableType>(rhs).getInnerType()
1125 << " is expected but got " << lhs;
1126 }
1127
1128 auto blockArgs = this->getBody().front().getArguments();
1129
1130 if (blockArgs.size() != getNumOperands())
1131 return emitError() << "operand type doesn't match with the block arg";
1132
1133 for (auto [blockArg, operand] : llvm::zip(blockArgs, getOperands())) {
1134 if (blockArg.getType() !=
1135 cast<ImmutableType>(operand.getType()).getInnerType())
1136 return emitError()
1137 << blockArg.getType() << " is expected but got "
1138 << cast<ImmutableType>(operand.getType()).getInnerType();
1139 }
1140 return success();
1141}
1142void InitialOp::build(OpBuilder &builder, OperationState &result,
1143 TypeRange resultTypes, std::function<void()> ctor) {
1144 OpBuilder::InsertionGuard guard(builder);
1145
1146 builder.createBlock(result.addRegion());
1147 SmallVector<Type> types;
1148 for (auto t : resultTypes)
1149 types.push_back(seq::ImmutableType::get(t));
1150
1151 result.addTypes(types);
1152
1153 if (ctor)
1154 ctor();
1155}
1156
1157TypedValue<seq::ImmutableType>
1158circt::seq::createConstantInitialValue(OpBuilder builder, Location loc,
1159 mlir::IntegerAttr attr) {
1160 auto initial = builder.create<seq::InitialOp>(loc, attr.getType(), [&]() {
1161 auto constant = builder.create<hw::ConstantOp>(loc, attr);
1162 builder.create<seq::YieldOp>(loc, ArrayRef<Value>{constant});
1163 });
1164 return cast<TypedValue<seq::ImmutableType>>(initial->getResult(0));
1165}
1166
1167mlir::TypedValue<seq::ImmutableType>
1168circt::seq::createConstantInitialValue(OpBuilder builder, Operation *op) {
1169 assert(op->getNumResults() == 1 &&
1170 op->hasTrait<mlir::OpTrait::ConstantLike>());
1171 auto initial =
1172 builder.create<seq::InitialOp>(op->getLoc(), op->getResultTypes(), [&]() {
1173 auto clonedOp = builder.clone(*op);
1174 builder.create<seq::YieldOp>(op->getLoc(), clonedOp->getResults());
1175 });
1176 return cast<mlir::TypedValue<seq::ImmutableType>>(initial.getResult(0));
1177}
1178
1179Value circt::seq::unwrapImmutableValue(TypedValue<seq::ImmutableType> value) {
1180 auto resultNum = cast<OpResult>(value).getResultNumber();
1181 auto initialOp = value.getDefiningOp<seq::InitialOp>();
1182 assert(initialOp);
1183 return initialOp.getBodyBlock()->getTerminator()->getOperand(resultNum);
1184}
1185
1186FailureOr<seq::InitialOp> circt::seq::mergeInitialOps(Block *block) {
1187 SmallVector<Operation *> initialOps;
1188 for (auto &op : *block)
1189 if (isa<seq::InitialOp>(op))
1190 initialOps.push_back(&op);
1191
1192 if (!mlir::computeTopologicalSorting(initialOps, {}))
1193 return block->getParentOp()->emitError() << "initial ops cannot be "
1194 << "topologically sorted";
1195
1196 // No need to merge if there is only one initial op.
1197 if (initialOps.size() <= 1)
1198 return initialOps.empty() ? seq::InitialOp()
1199 : cast<seq::InitialOp>(initialOps[0]);
1200
1201 auto initialOp = cast<seq::InitialOp>(initialOps.front());
1202 auto yieldOp = cast<seq::YieldOp>(initialOp.getBodyBlock()->getTerminator());
1203
1204 llvm::MapVector<Value, Value>
1205 resultToYieldOperand; // seq.immutable value to operand.
1206
1207 for (auto [result, operand] :
1208 llvm::zip(initialOp.getResults(), yieldOp->getOperands()))
1209 resultToYieldOperand.insert({result, operand});
1210
1211 for (size_t i = 1; i < initialOps.size(); ++i) {
1212 auto currentInitialOp = cast<seq::InitialOp>(initialOps[i]);
1213 auto operands = currentInitialOp->getOperands();
1214 for (auto [blockArg, operand] :
1215 llvm::zip(currentInitialOp.getBodyBlock()->getArguments(), operands)) {
1216 if (auto initOp = operand.getDefiningOp<seq::InitialOp>()) {
1217 assert(resultToYieldOperand.count(operand) &&
1218 "it must be visited already");
1219 blockArg.replaceAllUsesWith(resultToYieldOperand.lookup(operand));
1220 } else {
1221 // Otherwise add the operand to the current block.
1222 initialOp.getBodyBlock()->addArgument(
1223 cast<seq::ImmutableType>(operand.getType()).getInnerType(),
1224 operand.getLoc());
1225 initialOp.getInputsMutable().append(operand);
1226 }
1227 }
1228
1229 auto currentYieldOp =
1230 cast<seq::YieldOp>(currentInitialOp.getBodyBlock()->getTerminator());
1231
1232 for (auto [result, operand] : llvm::zip(currentInitialOp.getResults(),
1233 currentYieldOp->getOperands()))
1234 resultToYieldOperand.insert({result, operand});
1235
1236 // Append the operands of the current yield op to the original yield op.
1237 yieldOp.getOperandsMutable().append(currentYieldOp.getOperands());
1238 currentYieldOp->erase();
1239
1240 // Append the operations of the current initial op to the original initial
1241 // op.
1242 initialOp.getBodyBlock()->getOperations().splice(
1243 initialOp.end(), currentInitialOp.getBodyBlock()->getOperations());
1244 }
1245
1246 // Move the terminator to the end of the block.
1247 yieldOp->moveBefore(initialOp.getBodyBlock(),
1248 initialOp.getBodyBlock()->end());
1249
1250 auto builder = OpBuilder::atBlockBegin(block);
1251 SmallVector<Type> types;
1252 for (auto [result, operand] : resultToYieldOperand)
1253 types.push_back(operand.getType());
1254
1255 // Create a new initial op which accumulates the results of the merged initial
1256 // ops.
1257 auto newInitial = builder.create<seq::InitialOp>(initialOp.getLoc(), types);
1258 newInitial.getInputsMutable().append(initialOp.getInputs());
1259
1260 for (auto [resultAndOperand, newResult] :
1261 llvm::zip(resultToYieldOperand, newInitial.getResults()))
1262 resultAndOperand.first.replaceAllUsesWith(newResult);
1263
1264 // Update the block arguments of the new initial op.
1265 for (auto oldBlockArg : initialOp.getBodyBlock()->getArguments()) {
1266 auto blockArg = newInitial.getBodyBlock()->addArgument(
1267 oldBlockArg.getType(), oldBlockArg.getLoc());
1268 oldBlockArg.replaceAllUsesWith(blockArg);
1269 }
1270
1271 newInitial.getBodyBlock()->getOperations().splice(
1272 newInitial.end(), initialOp.getBodyBlock()->getOperations());
1273
1274 // Clean up.
1275 while (!initialOps.empty())
1276 initialOps.pop_back_val()->erase();
1277
1278 return newInitial;
1279}
1280
1281//===----------------------------------------------------------------------===//
1282// TableGen generated logic.
1283//===----------------------------------------------------------------------===//
1284
1285// Provide the autogenerated implementation guts for the Op classes.
1286#define GET_OP_CLASSES
1287#include "circt/Dialect/Seq/Seq.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
#define isdigit(x)
Definition FIRLexer.cpp:26
static bool isConstZero(Value value)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static Block * getBodyBlock(FModuleLike mod)
void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition SeqOps.cpp:268
static bool isConstClock(Value value)
Definition SeqOps.cpp:932
static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType, StringRef directive)
Definition SeqOps.cpp:237
static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
Definition SeqOps.cpp:85
static bool isConstAllOnes(Value value)
Definition SeqOps.cpp:945
static ParseResult parseOptionalImmutableTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
Definition SeqOps.cpp:90
void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition SeqOps.cpp:275
LogicalResult verifyResets(TOp op)
Definition SeqOps.cpp:317
static bool canElideName(OpAsmPrinter &p, Operation *op)
Definition SeqOps.cpp:61
ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition SeqOps.cpp:262
static LogicalResult verifyFirMemMask(Op op)
Definition SeqOps.cpp:914
static void printOptionalImmutableTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
Definition SeqOps.cpp:98
static ParseResult parseOptionalTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
Definition SeqOps.cpp:77
static void setNameFromResult(OpAsmParser &parser, OperationState &result)
Definition SeqOps.cpp:50
ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition SeqOps.cpp:256
static InstancePath empty
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
FailureOr< seq::InitialOp > mergeInitialOps(Block *block)
Definition SeqOps.cpp:1186
bool isValidIndexValues(Value hlmemHandle, ValueRange addresses)
Definition SeqOps.cpp:32
mlir::TypedValue< seq::ImmutableType > createConstantInitialValue(OpBuilder builder, Location loc, mlir::IntegerAttr attr)
Definition SeqOps.cpp:1158
Value unwrapImmutableValue(mlir::TypedValue< seq::ImmutableType > immutableVal)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:34
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition seq.py:1
FirMemory(hw::HWModuleGeneratedOp op)
Definition SeqOps.cpp:1091