18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/PatternMatch.h"
25#include "llvm/ADT/SmallString.h"
// isValidIndexValues (fragment): checks whether `addresses` can index the
// HLMem value `hlmemHandle`: one address per memory dimension, each an integer
// of the width derived from that dimension. The function header and its
// return statements are not part of this listing.
32 auto memType = cast<seq::HLMemType>(hlmemHandle.getType());
33 auto shape = memType.getShape();
// The number of address operands must equal the memory's rank.
34 if (shape.size() != addresses.size())
// Pair each dimension size with the address value that indexes it.
37 for (
auto [dim, addr] : llvm::zip(shape, addresses)) {
38 auto addrType = dyn_cast<IntegerType>(addr.getType());
// NOTE(review): dyn_cast yields null for non-integer addresses. The null check
// (source lines 39-40) is not visible here, but it must precede the
// dereference below.
// The address width is compared with Log2_64_Ceil(dim), presumably rejecting
// any mismatch. A dimension of size 1 therefore expects a zero-width address.
41 if (addrType.getIntOrFloatBitWidth() != llvm::Log2_64_Ceil(dim))
// setNameFromResult (fragment): if the parsed op has no explicit "name"
// attribute, derive one from the SSA result name written in the assembly.
50 if (result.attributes.getNamed(
"name"))
54 StringRef resultName = parser.getResultName(0).first;
// A result name starting with a digit (e.g. `%0`) is presumably a
// printer-generated number rather than a user-chosen name. The statement it
// guards (source line 56) is not part of this listing.
// NOTE(review): isdigit() on a plain `char` is undefined for negative values
// (non-ASCII bytes on signed-char targets). Prefer llvm::isDigit() or a cast
// to unsigned char.
55 if (!resultName.empty() &&
isdigit(resultName[0]))
57 result.addAttribute(
"name", parser.getBuilder().getStringAttr(resultName));
// canElideName (fragment): reports whether the op's "name" attribute is
// redundant in the printed form, i.e. it equals the SSA name the printer
// assigns to result 0.
61 if (!op->hasAttr(
"name"))
64 auto name = op->getAttrOfType<StringAttr>(
"name").getValue();
// Print result 0 into a scratch buffer to learn the SSA name the printer
// chose for it.
68 SmallString<32> resultNameStr;
69 llvm::raw_svector_ostream tmpStream(resultNameStr);
70 p.printOperand(op->getResult(0), tmpStream);
// drop_front() strips the leading sigil (presumably '%') before comparing.
71 auto actualName = tmpStream.str().drop_front();
72 return actualName == name;
// Fragments that appear to belong to the custom assembly-format helpers
// parseOptionalTypeMatch, printOptionalTypeMatch,
// parseOptionalImmutableTypeMatch and printOptionalImmutableTypeMatch.
// Only parts of their parameter lists and one statement are in this listing.
77 std::optional<OpAsmParser::UnresolvedOperand> operand,
85 Value operand, Type type) {
90 OpAsmParser &parser, Type refType,
91 std::optional<OpAsmParser::UnresolvedOperand> operand, Type &type) {
// The output type is computed from refType by wrapping it in
// seq::ImmutableType.
93 type = seq::ImmutableType::get(refType);
98 Type refType, Value operand,
// ReadPortOp::parse (fragment). Accepts
//   %mem[%addr0, ...] (rden %en)? attr-dict : <hlmem type>
// resolves the memory, address and optional read-enable operands, and records
// operandSegmentSizes for the three variadic operand groups.
107ParseResult ReadPortOp::parse(OpAsmParser &parser, OperationState &result) {
108 llvm::SMLoc loc = parser.getCurrentLocation();
110 OpAsmParser::UnresolvedOperand memOperand, rdenOperand;
111 bool hasRdEn =
false;
112 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
113 seq::HLMemType memType;
115 if (parser.parseOperand(memOperand) ||
116 parser.parseOperandList(addressOperands, OpAsmParser::Delimiter::Square))
// Optional `rden %en` clause.
// NOTE(review): the statement that sets hasRdEn to true (source line 121) is
// not part of this listing.
119 if (succeeded(parser.parseOptionalKeyword(
"rden"))) {
120 if (failed(parser.parseOperand(rdenOperand)))
125 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
126 parser.parseType(memType))
// Operand types, in operand order: the memory handle first, then the types
// returned by memType.getAddressTypes().
129 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
130 operandTypes.insert(operandTypes.begin(), memType);
132 llvm::SmallVector<OpAsmParser::UnresolvedOperand> allOperands = {memOperand};
133 llvm::copy(addressOperands, std::back_inserter(allOperands));
// The read enable is an i1. Presumably both pushes are guarded by hasRdEn on
// a line not shown, so the operand and type lists stay the same length.
135 operandTypes.push_back(parser.getBuilder().getI1Type());
136 allOperands.push_back(rdenOperand);
139 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
// The single result is the memory's element type.
142 result.addTypes(memType.getElementType());
// Segment sizes: {memory, addresses, rden}.
144 llvm::SmallVector<int32_t, 2> operandSizes;
145 operandSizes.push_back(1);
146 operandSizes.push_back(addressOperands.size());
147 operandSizes.push_back(hasRdEn ? 1 : 0);
148 result.addAttribute(
"operandSegmentSizes",
149 parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
// ReadPortOp::print (fragment): prints the form accepted by ReadPortOp::parse.
// operandSegmentSizes is elided because the parser reconstructs it.
153void ReadPortOp::print(OpAsmPrinter &p) {
154 p <<
" " << getMemory() <<
"[" << getAddresses() <<
"]";
// NOTE(review): source line 155, presumably the guard that prints this clause
// only when a read enable is present, is not part of this listing.
156 p <<
" rden " << getRdEn();
157 p.printOptionalAttrDict((*this)->getAttrs(), {
"operandSegmentSizes"});
158 p <<
" : " << getMemory().getType();
// ReadPortOp result naming (fragment): the read-data result is named after the
// defining HLMemOp, as "<mem>_rdata".
// NOTE(review): getDefiningOp<seq::HLMemOp>() returns a null op when the
// memory is not produced by an HLMemOp (e.g. a block argument). Calling
// getName() on it would then crash, so consider guarding this.
162 auto memName = getMemory().getDefiningOp<seq::HLMemOp>().
getName();
163 setNameFn(getReadData(), (memName +
"_rdata").str());
// Convenience builder: infers the result type from the memory's HLMem element
// type and forwards to the builder that takes an explicit result type.
166void ReadPortOp::build(OpBuilder &builder, OperationState &result, Value memory,
167 ValueRange addresses, Value rdEn,
unsigned latency) {
168 auto memType = cast<seq::HLMemType>(memory.getType());
169 ReadPortOp::build(builder, result, memType.getElementType(), memory,
170 addresses, rdEn, latency);
// WritePortOp::parse (fragment). Accepts
//   %mem[%addr0, ...] %data wren %en attr-dict : <hlmem type>
// Unlike the read port, the `wren` clause is mandatory (parseKeyword, not
// parseOptionalKeyword).
177ParseResult WritePortOp::parse(OpAsmParser &parser, OperationState &result) {
178 llvm::SMLoc loc = parser.getCurrentLocation();
179 OpAsmParser::UnresolvedOperand memOperand, dataOperand, wrenOperand;
180 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
181 seq::HLMemType memType;
183 if (parser.parseOperand(memOperand) ||
184 parser.parseOperandList(addressOperands,
185 OpAsmParser::Delimiter::Square) ||
186 parser.parseOperand(dataOperand) || parser.parseKeyword(
"wren") ||
187 parser.parseOperand(wrenOperand) ||
188 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
189 parser.parseType(memType))
// Operand types, in operand order: memory handle, address types, the element
// type (for the written data), then i1 (for the write enable).
192 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
193 operandTypes.insert(operandTypes.begin(), memType);
194 operandTypes.push_back(memType.getElementType());
195 operandTypes.push_back(parser.getBuilder().getI1Type());
// Operands in the same order: memory, addresses, data, enable. The
// initializer of allOperands (source line 198) is not part of this listing.
197 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> allOperands(
199 allOperands.insert(allOperands.begin(), memOperand);
200 allOperands.push_back(dataOperand);
201 allOperands.push_back(wrenOperand);
203 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
// WritePortOp::print (fragment): prints the form accepted by
// WritePortOp::parse. No attributes are elided from the attribute dictionary.
209void WritePortOp::print(OpAsmPrinter &p) {
210 p <<
" " << getMemory() <<
"[" << getAddresses() <<
"] " << getInData()
211 <<
" wren " << getWrEn();
212 p.printOptionalAttrDict((*this)->getAttrs());
213 p <<
" : " << getMemory().getType();
// HLMemOp result naming (fragment): the memory handle takes the op's name.
221 setNameFn(getHandle(),
getName());
// Convenience builder: constructs the HLMem type from `shape` and
// `elementType` and forwards to the builder that takes the type explicitly.
// The parameter declaring `elementType` (source line 226) is not shown.
224void HLMemOp::build(OpBuilder &builder, OperationState &result, Value clk,
225 Value rst, StringRef name, llvm::ArrayRef<int64_t> shape,
227 HLMemType t = HLMemType::get(builder.getContext(), shape,
elementType);
228 HLMemOp::build(builder, result, t, clk, rst, name);
// parseFIFOFlagThreshold (fragment): parses an optional `<directive> <int>`
// clause. On success the threshold is stored as an i64 IntegerAttr and the
// matching output-flag result type is set to i1.
237 IntegerAttr &threshold,
238 Type &outputFlagType,
239 StringRef directive) {
241 if (succeeded(parser.parseOptionalKeyword(directive))) {
242 int64_t thresholdValue;
243 if (succeeded(parser.parseInteger(thresholdValue))) {
244 threshold = parser.getBuilder().getI64IntegerAttr(thresholdValue);
245 outputFlagType = parser.getBuilder().getI1Type();
// Presumably reached when the directive is present but no integer follows.
// The error is reported at the op's name location.
248 return parser.emitError(parser.getNameLoc(),
249 "expected integer value after " + directive +
// Fragments of the almost-empty / almost-full threshold wrappers
// (parseFIFOAEThreshold, parseFIFOAFThreshold, printFIFOAEThreshold,
// printFIFOAFThreshold). The visible printer lines emit the threshold as a
// signed integer via IntegerAttr::getInt().
256 Type &outputFlagType) {
262 Type &outputFlagType) {
268 Type outputFlagType) {
271 <<
" " << threshold.getInt();
275 Type outputFlagType) {
278 <<
" " << threshold.getInt();
// FIFOOp result naming (fragment): fixed names for the data output and the
// empty/full flags. The almost-empty/almost-full flags are optional results
// and are named only when present.
282 setNameFn(getOutput(),
"out");
283 setNameFn(getEmpty(),
"empty");
284 setNameFn(getFull(),
"full");
285 if (
auto ae = getAlmostEmpty())
286 setNameFn(ae,
"almostEmpty");
287 if (
auto af = getAlmostFull())
288 setNameFn(af,
"almostFull");
// FIFOOp::verify (fragment): each optional threshold, when present, must not
// exceed the FIFO depth. A threshold equal to the depth is accepted.
291LogicalResult FIFOOp::verify() {
292 auto aet = getAlmostEmptyThreshold();
293 auto aft = getAlmostFullThreshold();
294 size_t depth = getDepth();
295 if (aft.has_value() && aft.value() > depth)
296 return emitOpError(
"almost full threshold must be <= FIFO depth");
298 if (aet.has_value() && aet.value() > depth)
299 return emitOpError(
"almost empty threshold must be <= FIFO depth");
// Tail of CompRegOp's result-naming hook: the result is named from `*name`.
// The declaration of `name` and its presence check are not part of this
// listing.
313 setNameFn(getResult(), *name);
// CompRegOp::verify (fragment): reset and resetValue must be given together.
// The XOR of the two null checks is true exactly when only one is present.
316LogicalResult CompRegOp::verify() {
317 if ((getReset() ==
nullptr) ^ (getResetValue() ==
nullptr))
319 "either reset and resetValue or neither must be specified");
323std::optional<size_t> CompRegOp::getTargetResultIndex() {
return 0; }
// verifyResets<TOp> (fragment): reset checks shared by any op type providing
// getReset(), getResetValue() and getInput(). reset and resetValue must be
// present together, and when present the reset value must have the same type
// as the input.
325template <
typename TOp>
327 if ((op.getReset() ==
nullptr) ^ (op.getResetValue() ==
nullptr))
328 return op->emitOpError(
329 "either reset and resetValue or neither must be specified");
330 bool hasReset = op.getReset() !=
nullptr;
331 if (hasReset && op.getResetValue().getType() != op.getInput().getType())
332 return op->emitOpError(
"reset value must be the same type as the input");
// Fragments: the tail of CompRegClockEnabledOp's result-naming hook, the
// headers of its getTargetResultIndex and verify (bodies not in this listing),
// and the tail of a further naming hook (presumably ShiftRegOp's).
342 setNameFn(getResult(), *name);
345std::optional<size_t> CompRegClockEnabledOp::getTargetResultIndex() {
349LogicalResult CompRegClockEnabledOp::verify() {
362 setNameFn(getResult(), *name);
365std::optional<size_t> ShiftRegOp::getTargetResultIndex() {
return 0; }
// ShiftRegOp::verify (header only; the body is not part of this listing).
367LogicalResult ShiftRegOp::verify() {
// FirRegOp builder without reset (fragment). Operands are (input, clk). The
// name is always attached. The inner symbol and preset are attached as
// attributes, presumably only when non-null (guards on lines not shown). The
// result type equals the input type.
377void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
378 Value clk, StringAttr name, hw::InnerSymAttr innerSym,
// Restores the builder's insertion point when the builder returns.
381 OpBuilder::InsertionGuard guard(builder);
383 result.addOperands(input);
384 result.addOperands(clk);
386 result.addAttribute(getNameAttrName(result.name), name);
389 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
392 result.addAttribute(getPresetAttrName(result.name), preset);
394 result.addTypes(input.getType());
// FirRegOp builder with reset (fragment). Operands are (input, clk, reset,
// resetValue). The isAsync unit attribute is presumably added only when
// `isAsync` is true (the guard, source line 409, is not shown). The result
// type equals the input type.
397void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
398 Value clk, StringAttr name, Value reset, Value resetValue,
399 hw::InnerSymAttr innerSym,
bool isAsync) {
401 OpBuilder::InsertionGuard guard(builder);
403 result.addOperands(input);
404 result.addOperands(clk);
405 result.addOperands(reset);
406 result.addOperands(resetValue);
408 result.addAttribute(getNameAttrName(result.name), name);
410 result.addAttribute(getIsAsyncAttrName(result.name), builder.getUnitAttr());
413 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
415 result.addTypes(input.getType());
// FirRegOp::parse (fragment). Accepts
//   %next clock %clk (sym <inner_sym>)? (reset (sync|async) %rst, %rstval)?
//   (preset <int>)? attr-dict : <type>
// Operands are resolved in the order next, clk, then reset/resetValue.
418ParseResult FirRegOp::parse(OpAsmParser &parser, OperationState &result) {
419 auto &builder = parser.getBuilder();
420 llvm::SMLoc loc = parser.getCurrentLocation();
422 using Op = OpAsmParser::UnresolvedOperand;
425 if (parser.parseOperand(next) || parser.parseKeyword(
"clock") ||
426 parser.parseOperand(clk))
// Optional inner symbol, stored under the "inner_sym" attribute name.
429 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
430 hw::InnerSymAttr innerSym;
431 if (parser.parseCustomAttributeWithFallback(innerSym,
nullptr,
432 "inner_sym", result.attributes))
// Optional reset clause. The reset kind keyword is required, and `async`
// presumably leads to the isAsync unit attribute being added below.
437 std::optional<std::pair<Op, Op>> resetAndValue;
438 if (succeeded(parser.parseOptionalKeyword(
"reset"))) {
440 if (succeeded(parser.parseOptionalKeyword(
"async")))
442 else if (succeeded(parser.parseOptionalKeyword(
"sync")))
445 return parser.emitError(loc,
"invalid reset, expected 'sync' or 'async'");
447 result.attributes.append(
"isAsync", builder.getUnitAttr());
449 resetAndValue = {{}, {}};
450 if (parser.parseOperand(resetAndValue->first) || parser.parseComma() ||
451 parser.parseOperand(resetAndValue->second))
// Optional preset literal. It is kept as an APInt of whatever width the
// parser chose until the register width is known.
455 std::optional<APInt> presetValue;
456 llvm::SMLoc presetValueLoc;
457 if (succeeded(parser.parseOptionalKeyword(
"preset"))) {
458 presetValueLoc = parser.getCurrentLocation();
459 OptionalParseResult presetIntResult =
460 parser.parseOptionalInteger(presetValue.emplace());
// NOTE(review): this diagnostic and the "preset value too large" one below
// report at `loc` (start of the op), while "cannot preset register of unknown
// width" uses presetValueLoc. Consider using presetValueLoc consistently.
461 if (!presetIntResult.has_value() || failed(*presetIntResult))
462 return parser.emitError(loc,
"expected integer value");
466 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
467 parser.parseType(ty))
469 result.addTypes({ty});
// A preset needs the register's bit width. The width computation (source
// lines 474-482) is only partially visible.
473 if (hw::type_isa<seq::ClockType>(ty)) {
478 return parser.emitError(presetValueLoc,
479 "cannot preset register of unknown width");
// Fit the literal to the register width. If zero-extending back to the
// literal's own width does not reproduce it, bits were lost and the preset
// is rejected.
483 APInt presetResult = presetValue->sextOrTrunc(width);
484 if (presetResult.zextOrTrunc(presetValue->getBitWidth()) != *presetValue)
485 return parser.emitError(loc,
"preset value too large");
// NOTE(review): this `builder` (a copy) shadows the `auto &builder` declared
// at the top of the function.
487 auto builder = parser.getBuilder();
488 auto presetTy = builder.getIntegerType(width);
489 auto resultAttr = builder.getIntegerAttr(presetTy, presetResult);
490 result.addAttribute(
"preset", resultAttr);
// Resolve operands in declaration order: next (register type), clk (clock).
495 if (parser.resolveOperand(next, ty, result.operands))
498 Type clkTy = ClockType::get(result.getContext());
499 if (parser.resolveOperand(clk, clkTy, result.operands))
// The reset signal is an i1; the reset value has the register's type.
503 Type i1 = IntegerType::get(result.getContext(), 1);
504 if (parser.resolveOperand(resetAndValue->first, i1, result.operands) ||
505 parser.resolveOperand(resetAndValue->second, ty, result.operands))
// FirRegOp::print (fragment): prints the form accepted by FirRegOp::parse.
// Attributes rendered inline (inner_sym, isAsync, preset) are elided from the
// trailing attribute dictionary. "name" is also elided under a condition on
// lines not shown (presumably when canElideName holds).
512void FirRegOp::print(::mlir::OpAsmPrinter &p) {
513 SmallVector<StringRef> elidedAttrs = {
514 getInnerSymAttrName(), getIsAsyncAttrName(), getPresetAttrName()};
516 p <<
' ' << getNext() <<
" clock " << getClk();
518 if (
auto sym = getInnerSymAttr()) {
// Reset clause; presumably printed only when a reset is present (guard not
// shown).
524 p <<
" reset " << (getIsAsync() ?
"async" :
"sync") <<
' ';
525 p << getReset() <<
", " << getResetValue();
528 if (
auto preset = getPresetAttr()) {
529 p <<
" preset " << preset.getValue();
533 elidedAttrs.push_back(
"name");
535 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
536 p <<
" : " << getNext().getType();
// FirRegOp::verify (fragment). Checks three things:
//  - a reset needs both reset and resetValue;
//  - an async flag without a reset is rejected;
//  - a preset must match the register width.
540LogicalResult FirRegOp::verify() {
541 if (getReset() || getResetValue() || getIsAsync()) {
542 if (!getReset() || !getResetValue())
543 return emitOpError(
"must specify reset and reset value");
// NOTE(review): the condition on source line 541 already includes
// getIsAsync(). If this check sits in the matching else-branch (guard not
// visible), it is unreachable: an async register without a reset gets the
// message above instead.
546 return emitOpError(
"register with no reset cannot be async");
548 if (
auto preset = getPresetAttr()) {
// presetWidth and width are computed on lines not shown.
551 if (preset.getType() != getType() && presetWidth != width)
552 return emitOpError(
"preset type width must match register type");
// FirRegOp result naming (fragment): the result takes the op's name.
562 setNameFn(getResult(),
getName());
565std::optional<size_t> FirRegOp::getTargetResultIndex() {
return 0; }
// FirRegOp::canonicalize (fragment). The visible rewrites are:
//  1. a reset driven by a constant zero is dropped (register rebuilt without
//     reset);
//  2. a register with trivial feedback or a constant clock is replaced by a
//     constant;
//  3. an array-typed register whose next value re-reads its own elements is
//     simplified.
// Many guards and returns are on lines not part of this listing.
567LogicalResult FirRegOp::canonicalize(FirRegOp op, PatternRewriter &rewriter) {
// (1) A reset that is never asserted: rebuild without reset/resetValue,
// keeping the name, inner symbol and preset. `constOp` is obtained on a line
// not shown.
571 if (
auto reset = op.getReset()) {
573 if (constOp.getValue().isZero()) {
574 rewriter.replaceOpWithNewOp<FirRegOp>(
575 op, op.getNext(), op.getClk(), op.getNameAttr(),
576 op.getInnerSymAttr(), op.getPresetAttr());
// Presumably a register with an inner symbol is not rewritten further (the
// consequence, source line 584, is not shown).
583 if (op.getInnerSymAttr())
// (2) A register whose next value is its own output, or whose clock comes
// from a ToClockOp (presumably of a constant; checks partly not shown).
591 if (op.getNext() == op.getResult())
593 if (
auto clk = op.getClk().getDefiningOp<seq::ToClockOp>())
// A non-zero preset prevents replacing the register with zero.
599 bool replaceWithConstZero =
true;
600 if (
auto preset = op.getPresetAttr())
601 if (!preset.getValue().isZero())
602 replaceWithConstZero =
false;
// `isConstant` is defined on lines not shown (presumably a local lambda).
604 if (
isConstant() && !op.getResetValue() && replaceWithConstZero) {
// Clock-typed registers become a constant-low clock.
605 if (isa<seq::ClockType>(op.getType())) {
606 rewriter.replaceOpWithNewOp<seq::ConstClockOp>(
607 op, seq::ClockConstAttr::get(rewriter.getContext(), ClockConst::Low));
// Other registers: bitcast a constant (built on lines not shown, presumably
// zero) to the register type.
611 rewriter.replaceOpWithNewOp<
hw::BitcastOp>(op, op.getType(), constant);
// (3) Only for registers with neither reset nor preset, whose type is an array
// of integers.
622 if (!op.getReset() && !op.getPresetAttr()) {
626 if (isa<IntegerType>(
627 hw::type_cast<hw::ArrayType>(op.getResult().getType())
628 .getElementType())) {
629 SmallVector<Value> nextOperands;
630 bool changed =
false;
631 for (
const auto &[i, value] :
632 llvm::enumerate(arrayCreate.getOperands())) {
// Operand i of the array_create corresponds to element (size - 1 - i),
// presumably because array_create lists elements from the highest index down.
633 auto index = arrayCreate.getOperands().size() - i - 1;
// Detect an element that just reads the same element of this register
// (handling of that case is on lines not shown).
637 if (arrayGet.getInput() == op.getResult() &&
638 matchPattern(arrayGet.getIndex(),
639 m_ConstantInt(&elementIndex)) &&
640 elementIndex == index) {
647 nextOperands.push_back(value);
652 arrayCreate.getLoc(), nextOperands);
// The old array_create is replaced directly only when this register is its
// sole user.
653 if (arrayCreate->hasOneUse())
656 rewriter.replaceOp(arrayCreate, newNextVal);
659 rewriter.replaceOpWithNewOp<FirRegOp>(op, newNextVal, op.getClk(),
661 op.getInnerSymAttr());
// FirRegOp::fold (fragment). Registers with an inner symbol are presumably
// left unfolded (consequence not shown). Otherwise the register folds to:
//  - its reset value when held in permanent reset;
//  - a constant when it has trivial feedback or is never clocked.
676 if (getInnerSymAttr())
679 auto presetAttr = getPresetAttr();
// Permanent reset: a reset that is a constant one makes the register always
// equal its reset value. `constOp` is obtained, and presumably a no-preset
// requirement checked, on lines not shown.
689 if (
auto reset = getReset())
691 if (constOp.getValue().isOne())
692 return getResetValue();
// A non-null adaptor clock means the clock operand folded to a constant
// attribute, i.e. the register never sees a clock edge.
697 bool isTrivialFeedback = (getNext() == getResult());
698 bool isNeverClocked =
699 adaptor.getClk() !=
nullptr;
700 if (!isTrivialFeedback && !isNeverClocked)
// With a constant reset value: without a preset, fold to the reset value
// (source line 708, not shown). With a preset, fold only when the preset
// equals the constant reset value.
705 if (
auto resetValue = getResetValue()) {
706 if (
auto *op = resetValue.getDefiningOp()) {
707 if (op->hasTrait<OpTrait::ConstantLike>() && !presetAttr)
// presetAttr is non-null here only because a ConstantLike op with no preset
// was handled by the early exit above.
709 if (
auto constOp = dyn_cast<hw::ConstantOp>(op))
710 if (presetAttr.getValue() == constOp.getValue())
// Fallback: a zero of the register's integer type (guards, presumably on the
// preset and non-integer types, are not shown).
718 auto intType = dyn_cast<IntegerType>(getType());
724 return IntegerAttr::get(intType, 0);
// ClockGateOp::fold (fragment).
//  - Folds to a constant-low clock when gating is provably off (checks on
//    lines not shown) or the input clock is constant low.
//  - Otherwise walks the chain of ClockGateOps feeding the input, looking for
//    one with the same enable and test enable (the action on a match, source
//    line 753, is not shown).
731OpFoldResult ClockGateOp::fold(FoldAdaptor adaptor) {
740 return ClockConstAttr::get(getContext(), ClockConst::Low);
743 if (
auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput()))
744 if (clockAttr.getValue() == ClockConst::Low)
745 return ClockConstAttr::get(getContext(), ClockConst::Low);
749 auto clockGateInputOp = getInput().getDefiningOp<ClockGateOp>();
750 while (clockGateInputOp) {
751 if (clockGateInputOp.getEnable() == getEnable() &&
752 clockGateInputOp.getTestEnable() == getTestEnable())
754 clockGateInputOp = clockGateInputOp.getInput().getDefiningOp<ClockGateOp>();
// ClockGateOp::canonicalize (fragment): a test-enable operand that is a
// constant zero has no effect, so it is removed in place. `constOp` is
// obtained on a line not shown.
760LogicalResult ClockGateOp::canonicalize(ClockGateOp op,
761 PatternRewriter &rewriter) {
763 if (
auto testEnable = op.getTestEnable()) {
765 if (constOp.getValue().isZero()) {
766 rewriter.modifyOpInPlace(op,
767 [&] { op.getTestEnableMutable().clear(); });
// ClockGateOp::getTargetResultIndex (header only; body not in this listing).
776std::optional<size_t> ClockGateOp::getTargetResultIndex() {
// ClockMuxOp::fold (fragment): with a constant select condition the mux folds
// to the chosen clock. The conditions (source lines 785 and 787) are not part
// of this listing; presumably all-ones selects the true clock and zero selects
// the false clock.
784OpFoldResult ClockMuxOp::fold(FoldAdaptor adaptor) {
786 return getTrueClock();
788 return getFalseClock();
// FirMemOp::canonicalize (fragment): a memory whose users are all write ports
// (no read or read-write ports) is erased together with those write ports.
// Memories with an inner symbol are presumably kept (consequence of the check
// below is not shown).
796LogicalResult FirMemOp::canonicalize(FirMemOp op, PatternRewriter &rewriter) {
798 if (op.getInnerSymAttr())
// Any reading user keeps the memory alive. Every other user must be a write
// port.
802 for (
auto *user : op->getUsers()) {
803 if (isa<FirMemReadOp, FirMemReadWriteOp>(user))
805 assert(isa<FirMemWriteOp>(user) &&
"invalid seq.firmem user");
// make_early_inc_range keeps the iteration valid while users are erased.
808 for (
auto *user :
llvm::make_early_inc_range(op->getUsers()))
809 rewriter.eraseOp(user);
811 rewriter.eraseOp(op);
// FirMemOp result naming (fragment): the result takes the non-empty "name"
// attribute.
// NOTE(review): getAttrOfType returns null if "name" is absent or not a
// StringAttr. This code assumes it is always present; confirm it is a
// required attribute.
816 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
817 if (!nameAttr.getValue().empty())
818 setNameFn(getResult(), nameAttr.getValue());
821std::optional<size_t> FirMemOp::getTargetResultIndex() {
return 0; }
// verifyFirMemMask<Op> (fragment): when the port has a mask operand, the
// memory type must declare a mask width, and the mask must be an integer of
// exactly that width.
825 if (
auto mask = op.getMask()) {
826 auto memType = op.getMemory().getType();
827 if (!memType.getMaskWidth())
828 return op.emitOpError(
"has mask operand but memory type '")
829 << memType <<
"' has no mask";
830 auto expected = IntegerType::get(op.getContext(), *memType.getMaskWidth());
831 if (mask.getType() != expected)
832 return op.emitOpError(
"has mask operand of type '")
833 << mask.getType() <<
"', but memory type requires '" << expected
840LogicalResult FirMemReadWriteOp::verify() {
return verifyFirMemMask(*
this); }
// Bodies of the constant-matching helpers isConstClock, isConstZero and
// isConstAllOnes (fragments; signatures and constOp lookups not shown).
// isConstClock: true when the value is produced by a seq.const_clock op.
845 return value.getDefiningOp<seq::ConstClockOp>();
// isConstZero: the matched constant is all zeros.
851 return constOp.getValue().isZero();
// isConstAllOnes: the matched constant is all ones.
858 return constOp.getValue().isAllOnes();
// FirMemReadOp::canonicalize (fragment): removes the enable operand in place.
// The condition (presumably an always-true enable) is on lines not shown.
862LogicalResult FirMemReadOp::canonicalize(FirMemReadOp op,
863 PatternRewriter &rewriter) {
866 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
// FirMemWriteOp::canonicalize (fragment).
//  - Erases the write when it can never take effect (condition not shown).
//  - Otherwise removes the enable operand and/or the mask operand in place
//    (conditions not shown; presumably always-true enable / all-ones mask).
//  - Succeeds only if something changed.
872LogicalResult FirMemWriteOp::canonicalize(FirMemWriteOp op,
873 PatternRewriter &rewriter) {
877 rewriter.eraseOp(op);
880 bool anyChanges =
false;
884 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
890 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
894 return success(anyChanges);
// FirMemReadWriteOp::canonicalize (fragment). Under a condition not shown
// (presumably a write that can never happen), it demotes the port to a
// FirMemReadOp, copying over attributes that are not among the op's declared
// attribute names. Otherwise it removes the enable and/or mask operands in
// place, succeeding only if something changed.
897LogicalResult FirMemReadWriteOp::canonicalize(FirMemReadWriteOp op,
898 PatternRewriter &rewriter) {
// opAttrs is read after `op` is erased by replaceOpWithNewOp below. This relies
// on the attribute storage being owned by the MLIRContext rather than by the
// op.
903 auto opAttrs = op->getAttrs();
904 auto opAttrNames = op.getAttributeNames();
905 auto newOp = rewriter.replaceOpWithNewOp<FirMemReadOp>(
906 op, op.getMemory(), op.getAddress(), op.getClk(), op.getEnable());
907 for (
auto namedAttr : opAttrs)
908 if (!
llvm::is_contained(opAttrNames, namedAttr.
getName()))
909 newOp->setAttr(namedAttr.
getName(), namedAttr.getValue());
912 bool anyChanges =
false;
916 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
922 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
926 return success(anyChanges);
// ConstClockOp::fold: always folds to a ClockConstAttr holding the op's value.
933OpFoldResult ConstClockOp::fold(FoldAdaptor adaptor) {
934 return ClockConstAttr::get(getContext(), getValue());
// ToClockOp::canonicalize (fragment): to_clock(from_clock(x)) -> x.
941LogicalResult ToClockOp::canonicalize(ToClockOp op, PatternRewriter &rewriter) {
942 if (
auto fromClock = op.getInput().getDefiningOp<FromClockOp>()) {
943 rewriter.replaceOp(op, fromClock.getInput());
// ToClockOp::fold (fragment): folds the same round trip. It also turns a
// constant integer input into a constant clock: zero -> Low, anything else ->
// High.
949OpFoldResult ToClockOp::fold(FoldAdaptor adaptor) {
950 if (
auto fromClock = getInput().getDefiningOp<FromClockOp>())
951 return fromClock.getInput();
952 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
954 intAttr.getValue().isZero() ? ClockConst::Low : ClockConst::High;
955 return ClockConstAttr::get(getContext(), value);
// FromClockOp::canonicalize (fragment): from_clock(to_clock(x)) -> x.
960LogicalResult FromClockOp::canonicalize(FromClockOp op,
961 PatternRewriter &rewriter) {
962 if (
auto toClock = op.getInput().getDefiningOp<ToClockOp>()) {
963 rewriter.replaceOp(op, toClock.getInput());
// FromClockOp::fold (fragment): folds the same round trip. It also turns a
// constant clock into an i1 constant: High -> 1, Low -> 0.
969OpFoldResult FromClockOp::fold(FoldAdaptor adaptor) {
970 if (
auto toClock = getInput().getDefiningOp<ToClockOp>())
971 return toClock.getInput();
972 if (
auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
973 auto ty = IntegerType::get(getContext(), 1);
974 return IntegerAttr::get(ty, clockAttr.getValue() == ClockConst::High);
// ClockInverterOp::fold (fragment): two chained inversions cancel, and a
// constant clock input folds to the opposite constant.
983OpFoldResult ClockInverterOp::fold(FoldAdaptor adaptor) {
984 if (
auto chainedInv = getInput().getDefiningOp<ClockInverterOp>())
985 return chainedInv.getInput();
986 if (
auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
987 auto clockIn = clockAttr.getValue() == ClockConst::High;
988 return ClockConstAttr::get(getContext(),
989 clockIn ? ClockConst::Low : ClockConst::High);
// FirMemory(hw::HWModuleGeneratedOp) (fragment): loads the memory
// configuration from the generated module's attributes. Every attribute except
// maskGran and writeClockIDs is read without a presence check, so each is
// assumed to exist with the expected type; a missing one would dereference a
// null attribute.
// NOTE(review): depth is read with getInt() (signed) while the other counts
// use getUInt(). Confirm the difference is intended.
999 depth = op->getAttrOfType<IntegerAttr>(
"depth").
getInt();
1000 numReadPorts = op->getAttrOfType<IntegerAttr>(
"numReadPorts").getUInt();
1001 numWritePorts = op->getAttrOfType<IntegerAttr>(
"numWritePorts").getUInt();
1003 op->getAttrOfType<IntegerAttr>(
"numReadWritePorts").getUInt();
1004 readLatency = op->getAttrOfType<IntegerAttr>(
"readLatency").getUInt();
1005 writeLatency = op->getAttrOfType<IntegerAttr>(
"writeLatency").getUInt();
1006 dataWidth = op->getAttrOfType<IntegerAttr>(
"width").getUInt();
// maskGran is optional; without it the mask granularity falls back to the
// full data width.
1007 if (op->hasAttrOfType<IntegerAttr>(
"maskGran"))
1008 maskGran = op->getAttrOfType<IntegerAttr>(
"maskGran").getUInt();
1010 maskGran = dataWidth;
1011 readUnderWrite = op->getAttrOfType<seq::RUWAttr>(
"readUnderWrite").getValue();
1013 op->getAttrOfType<seq::WUWAttr>(
"writeUnderWrite").getValue();
// writeClockIDs is optional; each entry is read as a zero-extended value.
1014 if (
auto clockIDsAttr = op->getAttrOfType<ArrayAttr>(
"writeClockIDs"))
1015 for (
auto clockID : clockIDsAttr)
1016 writeClockIDs.push_back(
1017 cast<IntegerAttr>(clockID).getValue().getZExtValue());
1018 initFilename = op->getAttrOfType<StringAttr>(
"initFilename").getValue();
1019 initIsBinary = op->getAttrOfType<BoolAttr>(
"initIsBinary").getValue();
1020 initIsInline = op->getAttrOfType<BoolAttr>(
"initIsInline").getValue();
// InitialOp::verify (fragment). Checks two correspondences:
//  - the body's terminator yields one value per op result, each matching that
//    result's !seq.immutable inner type;
//  - the body block has one argument per op operand, each matching that
//    operand's inner type.
1023LogicalResult InitialOp::verify() {
1025 auto *terminator = this->getBody().front().getTerminator();
// NOTE(review): this is a count check, but the message talks about types.
1026 if (terminator->getOperands().size() != getNumResults())
1027 return emitError() <<
"result type doesn't match with the terminator";
1028 for (
auto [lhs, rhs] :
1029 llvm::zip(terminator->getOperands().getTypes(), getResultTypes())) {
1030 if (cast<seq::ImmutableType>(rhs).getInnerType() != lhs)
1031 return emitError() << cast<seq::ImmutableType>(rhs).getInnerType()
1032 <<
" is expected but got " << lhs;
1035 auto blockArgs = this->getBody().front().getArguments();
1037 if (blockArgs.size() != getNumOperands())
1038 return emitError() <<
"operand type doesn't match with the block arg";
// NOTE(review): unlike the result check above, this message reports the
// body-side type as "expected" and the op-side type as "got".
1040 for (
auto [blockArg, operand] :
llvm::zip(blockArgs, getOperands())) {
1041 if (blockArg.getType() !=
1042 cast<ImmutableType>(operand.getType()).getInnerType())
1044 << blockArg.getType() <<
" is expected but got "
1045 << cast<ImmutableType>(operand.getType()).getInnerType();
// InitialOp::build (fragment): creates the body block and declares one
// !seq.immutable result per entry of `resultTypes`. How `ctor` is used
// (presumably invoked to populate the body) is on lines not shown.
1049void InitialOp::build(OpBuilder &builder, OperationState &result,
1050 TypeRange resultTypes, std::function<
void()> ctor) {
// Restores the caller's insertion point after createBlock moves it.
1051 OpBuilder::InsertionGuard guard(builder);
1053 builder.createBlock(result.addRegion());
1054 SmallVector<Type> types;
1055 for (
auto t : resultTypes)
1056 types.push_back(
seq::ImmutableType::
get(t));
1058 result.addTypes(types);
// createConstantInitialValue(builder, loc, attr) (fragment): wraps a constant
// in a seq.initial op whose body yields it, and returns that op's first result
// as a !seq.immutable value. The creation of `constant` (source line 1068) is
// not shown.
1064TypedValue<seq::ImmutableType>
1066 mlir::IntegerAttr attr) {
1067 auto initial = builder.create<seq::InitialOp>(loc, attr.getType(), [&]() {
1069 builder.
create<seq::YieldOp>(loc, ArrayRef<Value>{constant});
1071 return cast<TypedValue<seq::ImmutableType>>(initial->getResult(0));
// createConstantInitialValue(builder, op) (fragment): clones a single-result
// ConstantLike op into a new seq.initial body, yields the clone's results, and
// returns the initial op's first result.
1074mlir::TypedValue<seq::ImmutableType>
1076 assert(op->getNumResults() == 1 &&
1077 op->hasTrait<mlir::OpTrait::ConstantLike>());
1079 builder.create<seq::InitialOp>(op->getLoc(), op->getResultTypes(), [&]() {
1080 auto clonedOp = builder.clone(*op);
1081 builder.create<seq::YieldOp>(op->getLoc(), clonedOp->getResults());
1083 return cast<mlir::TypedValue<seq::ImmutableType>>(initial.getResult(0));
// unwrapImmutableValue (fragment): maps a !seq.immutable result of a
// seq.initial op to the value that op's terminator yields at the same index.
// cast<OpResult> asserts if the value is a block argument. The handling of a
// non-InitialOp producer (source line 1089) is not shown.
1087 auto resultNum = cast<OpResult>(value).getResultNumber();
1088 auto initialOp = value.getDefiningOp<seq::InitialOp>();
1090 return initialOp.getBodyBlock()->getTerminator()->getOperand(resultNum);
// mergeInitialOps (fragment): merges every seq.initial op directly in `block`
// into one new seq.initial op placed at the start of the block.
//  - Zero or one op: returned as-is (null op for zero).
//  - Otherwise the ops are topologically sorted and folded into the first one.
//    Results of earlier initial ops feeding later ones are forwarded inside
//    the merged body.
//  - Finally a fresh InitialOp with all yielded types replaces them.
// Errors if the initial ops cannot be topologically sorted.
1094 SmallVector<Operation *> initialOps;
1095 for (
auto &op : *block)
1096 if (isa<seq::InitialOp>(op))
1097 initialOps.push_back(&op);
// Sorts initialOps in place so producers precede consumers.
1099 if (!mlir::computeTopologicalSorting(initialOps, {}))
1100 return block->getParentOp()->emitError() <<
"initial ops cannot be "
1101 <<
"topologically sorted";
1104 if (initialOps.size() <= 1)
1105 return initialOps.empty() ? seq::InitialOp()
1106 : cast<seq::InitialOp>(initialOps[0]);
// The first op in topological order absorbs all others.
1108 auto initialOp = cast<seq::InitialOp>(initialOps.front());
1109 auto yieldOp = cast<seq::YieldOp>(initialOp.getBodyBlock()->getTerminator());
// Maps each original initial-op result to the value yielded for it. MapVector
// keeps insertion order, which matches the order yield operands are appended
// below, so the new op's results line up with this map.
1111 llvm::MapVector<Value, Value>
1112 resultToYieldOperand;
1114 for (
auto [result, operand] :
1115 llvm::zip(initialOp.getResults(), yieldOp->getOperands()))
1116 resultToYieldOperand.insert({result, operand});
1118 for (
size_t i = 1; i < initialOps.size(); ++i) {
1119 auto currentInitialOp = cast<seq::InitialOp>(initialOps[i]);
1120 auto operands = currentInitialOp->getOperands();
// For each input: if it comes from an earlier initial op, use the yielded
// value directly; otherwise it becomes a new block argument and input of the
// merged op.
1121 for (
auto [blockArg, operand] :
1123 if (
auto initOp = operand.getDefiningOp<seq::InitialOp>()) {
1124 assert(resultToYieldOperand.count(operand) &&
1125 "it must be visited already");
1126 blockArg.replaceAllUsesWith(resultToYieldOperand.lookup(operand));
1129 initialOp.getBodyBlock()->addArgument(
1130 cast<seq::ImmutableType>(operand.getType()).getInnerType(),
1132 initialOp.getInputsMutable().append(operand);
// Record this op's results and append its yielded values to the merged yield,
// then drop its own terminator and move its body into the merged op.
1136 auto currentYieldOp =
1137 cast<seq::YieldOp>(currentInitialOp.getBodyBlock()->getTerminator());
1139 for (
auto [result, operand] :
llvm::zip(currentInitialOp.getResults(),
1140 currentYieldOp->getOperands()))
1141 resultToYieldOperand.insert({result, operand});
1144 yieldOp.getOperandsMutable().append(currentYieldOp.getOperands());
1145 currentYieldOp->erase();
1149 initialOp.getBodyBlock()->getOperations().splice(
1150 initialOp.end(), currentInitialOp.getBodyBlock()->getOperations());
// Keep the merged yield as the block terminator after the splices.
1154 yieldOp->moveBefore(initialOp.getBodyBlock(),
1155 initialOp.getBodyBlock()->end());
// Build a fresh InitialOp whose result types are the yielded (inner) types.
1157 auto builder = OpBuilder::atBlockBegin(block);
1158 SmallVector<Type> types;
1159 for (
auto [result, operand] : resultToYieldOperand)
1160 types.push_back(operand.getType());
1164 auto newInitial = builder.create<seq::InitialOp>(initialOp.getLoc(), types);
1165 newInitial.getInputsMutable().append(initialOp.getInputs());
// Redirect every old result to the corresponding new result.
1167 for (
auto [resultAndOperand, newResult] :
1168 llvm::zip(resultToYieldOperand, newInitial.getResults()))
1169 resultAndOperand.first.replaceAllUsesWith(newResult);
// Recreate the block arguments on the new op and move the merged body over.
1172 for (
auto oldBlockArg : initialOp.
getBodyBlock()->getArguments()) {
1173 auto blockArg = newInitial.getBodyBlock()->addArgument(
1174 oldBlockArg.getType(), oldBlockArg.getLoc());
1175 oldBlockArg.replaceAllUsesWith(blockArg);
1178 newInitial.getBodyBlock()->getOperations().splice(
1179 newInitial.end(), initialOp.getBodyBlock()->getOperations());
// All uses were redirected above, so the old initial ops can be erased.
1182 while (!initialOps.empty())
1183 initialOps.pop_back_val()->erase();
1193#define GET_OP_CLASSES
1194#include "circt/Dialect/Seq/Seq.cpp.inc"
assert(baseType &&"element must be base type")
static bool isConstZero(Value value)
static InstancePath empty
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static Block * getBodyBlock(FModuleLike mod)
void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
static bool isConstClock(Value value)
static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType, StringRef directive)
static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
static bool isConstAllOnes(Value value)
static ParseResult parseOptionalImmutableTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
LogicalResult verifyResets(TOp op)
static bool canElideName(OpAsmPrinter &p, Operation *op)
ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
static LogicalResult verifyFirMemMask(Op op)
static void printOptionalImmutableTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
static ParseResult parseOptionalTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
static void setNameFromResult(OpAsmParser &parser, OperationState &result)
ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
FailureOr< seq::InitialOp > mergeInitialOps(Block *block)
bool isValidIndexValues(Value hlmemHandle, ValueRange addresses)
mlir::TypedValue< seq::ImmutableType > createConstantInitialValue(OpBuilder builder, Location loc, mlir::IntegerAttr attr)
Value unwrapImmutableValue(mlir::TypedValue< seq::ImmutableType > immutableVal)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
FirMemory(hw::HWModuleGeneratedOp op)