17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/SmallString.h"
26 using namespace circt;
30 auto memType = hlmemHandle.getType().cast<seq::HLMemType>();
31 auto shape = memType.getShape();
32 if (shape.size() != addresses.size())
35 for (
auto [dim, addr] : llvm::zip(shape, addresses)) {
36 auto addrType = addr.getType().dyn_cast<IntegerType>();
39 if (addrType.getIntOrFloatBitWidth() != llvm::Log2_64_Ceil(dim))
48 if (result.attributes.getNamed(
"name"))
52 StringRef resultName = parser.getResultName(0).first;
53 if (!resultName.empty() &&
isdigit(resultName[0]))
55 result.addAttribute(
"name", parser.getBuilder().getStringAttr(resultName));
59 if (!op->hasAttr(
"name"))
62 auto name = op->getAttrOfType<StringAttr>(
"name").getValue();
66 SmallString<32> resultNameStr;
67 llvm::raw_svector_ostream tmpStream(resultNameStr);
68 p.printOperand(op->getResult(0), tmpStream);
69 auto actualName = tmpStream.str().drop_front();
70 return actualName == name;
75 std::optional<OpAsmParser::UnresolvedOperand> operand,
83 Value operand, Type type) {
91 ParseResult ReadPortOp::parse(OpAsmParser &parser, OperationState &result) {
92 llvm::SMLoc loc = parser.getCurrentLocation();
94 OpAsmParser::UnresolvedOperand memOperand, rdenOperand;
96 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
97 seq::HLMemType memType;
99 if (parser.parseOperand(memOperand) ||
100 parser.parseOperandList(addressOperands, OpAsmParser::Delimiter::Square))
103 if (succeeded(parser.parseOptionalKeyword(
"rden"))) {
104 if (failed(parser.parseOperand(rdenOperand)))
109 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
110 parser.parseType(memType))
113 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
114 operandTypes.insert(operandTypes.begin(), memType);
116 llvm::SmallVector<OpAsmParser::UnresolvedOperand> allOperands = {memOperand};
117 llvm::copy(addressOperands, std::back_inserter(allOperands));
119 operandTypes.push_back(parser.getBuilder().getI1Type());
120 allOperands.push_back(rdenOperand);
123 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
126 result.addTypes(memType.getElementType());
128 llvm::SmallVector<int32_t, 2> operandSizes;
129 operandSizes.push_back(1);
130 operandSizes.push_back(addressOperands.size());
131 operandSizes.push_back(hasRdEn ? 1 : 0);
132 result.addAttribute(
"operandSegmentSizes",
133 parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
137 void ReadPortOp::print(OpAsmPrinter &p) {
138 p <<
" " << getMemory() <<
"[" << getAddresses() <<
"]";
140 p <<
" rden " << getRdEn();
141 p.printOptionalAttrDict((*this)->getAttrs(), {
"operandSegmentSizes"});
142 p <<
" : " << getMemory().getType();
146 auto memName = getMemory().getDefiningOp<seq::HLMemOp>().
getName();
147 setNameFn(getReadData(), (memName +
"_rdata").str());
150 void ReadPortOp::build(OpBuilder &
builder, OperationState &result, Value memory,
151 ValueRange addresses, Value rdEn,
unsigned latency) {
152 auto memType = memory.getType().cast<seq::HLMemType>();
153 ReadPortOp::build(
builder, result, memType.getElementType(), memory,
154 addresses, rdEn, latency);
161 ParseResult WritePortOp::parse(OpAsmParser &parser, OperationState &result) {
162 llvm::SMLoc loc = parser.getCurrentLocation();
163 OpAsmParser::UnresolvedOperand memOperand, dataOperand, wrenOperand;
164 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
165 seq::HLMemType memType;
167 if (parser.parseOperand(memOperand) ||
168 parser.parseOperandList(addressOperands,
169 OpAsmParser::Delimiter::Square) ||
170 parser.parseOperand(dataOperand) || parser.parseKeyword(
"wren") ||
171 parser.parseOperand(wrenOperand) ||
172 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
173 parser.parseType(memType))
176 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
177 operandTypes.insert(operandTypes.begin(), memType);
178 operandTypes.push_back(memType.getElementType());
179 operandTypes.push_back(parser.getBuilder().getI1Type());
181 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> allOperands(
183 allOperands.insert(allOperands.begin(), memOperand);
184 allOperands.push_back(dataOperand);
185 allOperands.push_back(wrenOperand);
187 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
193 void WritePortOp::print(OpAsmPrinter &p) {
194 p <<
" " << getMemory() <<
"[" << getAddresses() <<
"] " << getInData()
195 <<
" wren " << getWrEn();
196 p.printOptionalAttrDict((*this)->getAttrs());
197 p <<
" : " << getMemory().getType();
205 setNameFn(getHandle(),
getName());
208 void HLMemOp::build(OpBuilder &
builder, OperationState &result, Value clk,
209 Value rst, StringRef symName, llvm::ArrayRef<int64_t> shape,
212 HLMemOp::build(
builder, result, t, clk, rst, symName);
221 IntegerAttr &threshold,
222 Type &outputFlagType,
223 StringRef directive) {
225 if (succeeded(parser.parseOptionalKeyword(directive))) {
226 int64_t thresholdValue;
227 if (succeeded(parser.parseInteger(thresholdValue))) {
228 threshold = parser.getBuilder().getI64IntegerAttr(thresholdValue);
229 outputFlagType = parser.getBuilder().getI1Type();
232 return parser.emitError(parser.getNameLoc(),
233 "expected integer value after " + directive +
240 Type &outputFlagType) {
246 Type &outputFlagType) {
252 Type outputFlagType) {
255 <<
" " << threshold.getInt();
260 Type outputFlagType) {
263 <<
" " << threshold.getInt();
268 setNameFn(getOutput(),
"out");
269 setNameFn(getEmpty(),
"empty");
270 setNameFn(getFull(),
"full");
271 if (
auto ae = getAlmostEmpty())
272 setNameFn(ae,
"almostEmpty");
273 if (
auto af = getAlmostFull())
274 setNameFn(af,
"almostFull");
277 LogicalResult FIFOOp::verify() {
278 auto aet = getAlmostEmptyThreshold();
279 auto aft = getAlmostFullThreshold();
280 size_t depth = getDepth();
281 if (aft.has_value() && aft.value() > depth)
282 return emitOpError(
"almost full threshold must be <= FIFO depth");
284 if (aet.has_value() && aet.value() > depth)
285 return emitOpError(
"almost empty threshold must be <= FIFO depth");
299 setNameFn(getResult(), *name);
302 LogicalResult CompRegOp::verify() {
303 if ((getReset() ==
nullptr) ^ (getResetValue() ==
nullptr))
305 "either reset and resetValue or neither must be specified");
// Unconditionally reports result index 0 for CompRegOp.
// NOTE(review): presumably this marks the op's first result as the "target"
// result for whichever interface declares getTargetResultIndex — confirm
// against that interface's documentation.
309 std::optional<size_t> CompRegOp::getTargetResultIndex() {
return 0; }
311 template <
typename TOp>
313 if ((op.getReset() ==
nullptr) ^ (op.getResetValue() ==
nullptr))
314 return op->emitOpError(
315 "either reset and resetValue or neither must be specified");
316 bool hasReset = op.getReset() !=
nullptr;
317 if (hasReset && op.getResetValue().getType() != op.getInput().getType())
318 return op->emitOpError(
"reset value must be the same type as the input");
328 setNameFn(getResult(), *name);
331 std::optional<size_t> CompRegClockEnabledOp::getTargetResultIndex() {
335 LogicalResult CompRegClockEnabledOp::verify() {
348 setNameFn(getResult(), *name);
// Unconditionally reports result index 0 for ShiftRegOp.
// NOTE(review): presumably this marks the op's first result as the "target"
// result for whichever interface declares getTargetResultIndex — confirm
// against that interface's documentation.
351 std::optional<size_t> ShiftRegOp::getTargetResultIndex() {
return 0; }
353 LogicalResult ShiftRegOp::verify() {
363 void FirRegOp::build(OpBuilder &
builder, OperationState &result, Value input,
364 Value clk, StringAttr name, hw::InnerSymAttr innerSym) {
366 OpBuilder::InsertionGuard guard(
builder);
368 result.addOperands(input);
369 result.addOperands(clk);
371 result.addAttribute(getNameAttrName(result.name), name);
374 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
376 result.addTypes(input.getType());
379 void FirRegOp::build(OpBuilder &
builder, OperationState &result, Value input,
380 Value clk, StringAttr name, Value reset, Value resetValue,
381 hw::InnerSymAttr innerSym,
bool isAsync) {
383 OpBuilder::InsertionGuard guard(
builder);
385 result.addOperands(input);
386 result.addOperands(clk);
387 result.addOperands(reset);
388 result.addOperands(resetValue);
390 result.addAttribute(getNameAttrName(result.name), name);
392 result.addAttribute(getIsAsyncAttrName(result.name),
builder.getUnitAttr());
395 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
397 result.addTypes(input.getType());
400 ParseResult FirRegOp::parse(OpAsmParser &parser, OperationState &result) {
401 auto &
builder = parser.getBuilder();
402 llvm::SMLoc loc = parser.getCurrentLocation();
404 using Op = OpAsmParser::UnresolvedOperand;
407 if (parser.parseOperand(next) || parser.parseKeyword(
"clock") ||
408 parser.parseOperand(clk))
411 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
412 hw::InnerSymAttr innerSym;
413 if (parser.parseCustomAttributeWithFallback(innerSym,
nullptr,
414 "inner_sym", result.attributes))
419 std::optional<std::pair<Op, Op>> resetAndValue;
420 if (succeeded(parser.parseOptionalKeyword(
"reset"))) {
422 if (succeeded(parser.parseOptionalKeyword(
"async")))
424 else if (succeeded(parser.parseOptionalKeyword(
"sync")))
427 return parser.emitError(loc,
"invalid reset, expected 'sync' or 'async'");
429 result.attributes.append(
"isAsync",
builder.getUnitAttr());
431 resetAndValue = {{}, {}};
432 if (parser.parseOperand(resetAndValue->first) || parser.parseComma() ||
433 parser.parseOperand(resetAndValue->second))
438 if (succeeded(parser.parseOptionalKeyword(
"preset"))) {
440 if (parser.parseAttribute(preset,
"preset", result.attributes) ||
441 parser.parseOptionalAttrDict(result.attributes))
443 ty = preset.getType();
445 if (parser.parseOptionalAttrDict(result.attributes) ||
446 parser.parseColon() || parser.parseType(ty))
449 result.addTypes({ty});
453 if (parser.resolveOperand(next, ty, result.operands))
457 if (parser.resolveOperand(clk, clkTy, result.operands))
462 if (parser.resolveOperand(resetAndValue->first, i1, result.operands) ||
463 parser.resolveOperand(resetAndValue->second, ty, result.operands))
470 void FirRegOp::print(::mlir::OpAsmPrinter &p) {
471 SmallVector<StringRef> elidedAttrs = {
472 getInnerSymAttrName(), getIsAsyncAttrName(), getPresetAttrName()};
474 p <<
' ' << getNext() <<
" clock " << getClk();
476 if (
auto sym = getInnerSymAttr()) {
482 p <<
" reset " << (getIsAsync() ?
"async" :
"sync") <<
' ';
483 p << getReset() <<
", " << getResetValue();
486 if (
auto preset = getPresetAttr()) {
487 p <<
" preset " << preset.getValue();
491 elidedAttrs.push_back(
"name");
493 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
494 p <<
" : " << getNext().getType();
498 LogicalResult FirRegOp::verify() {
499 if (getReset() || getResetValue() || getIsAsync()) {
500 if (!getReset() || !getResetValue())
501 return emitOpError(
"must specify reset and reset value");
504 return emitOpError(
"register with no reset cannot be async");
506 if (
auto preset = getPresetAttr()) {
507 if (preset.getType() != getType())
508 return emitOpError(
"preset type must match register type");
518 setNameFn(getResult(),
getName());
// Unconditionally reports result index 0 for FirRegOp.
// NOTE(review): presumably this marks the op's first result as the "target"
// result for whichever interface declares getTargetResultIndex — confirm
// against that interface's documentation.
521 std::optional<size_t> FirRegOp::getTargetResultIndex() {
return 0; }
523 LogicalResult FirRegOp::canonicalize(FirRegOp op, PatternRewriter &rewriter) {
526 if (
auto reset = op.getReset()) {
528 if (constOp.getValue().isZero()) {
529 rewriter.replaceOpWithNewOp<FirRegOp>(op, op.getNext(), op.getClk(),
531 op.getInnerSymAttr());
538 if (op.getInnerSymAttr())
546 if (op.getNext() == op.getResult())
548 if (
auto clk = op.getClk().getDefiningOp<seq::ToClockOp>())
554 if (
auto resetValue = op.getResetValue()) {
556 rewriter.replaceOp(op, resetValue);
558 if (op.getType().isa<seq::ClockType>()) {
559 rewriter.replaceOpWithNewOp<seq::ConstClockOp>(
565 rewriter.replaceOpWithNewOp<
hw::BitcastOp>(op, op.getType(), constant);
577 if (!op.getReset()) {
581 if (hw::type_cast<hw::ArrayType>(op.getResult().getType())
583 .isa<IntegerType>()) {
584 SmallVector<Value> nextOperands;
585 bool changed =
false;
586 for (
const auto &[i,
value] :
587 llvm::enumerate(arrayCreate.getOperands())) {
588 auto index = arrayCreate.getOperands().size() - i - 1;
592 if (arrayGet.getInput() == op.getResult() &&
593 matchPattern(arrayGet.getIndex(),
594 m_ConstantInt(&elementIndex)) &&
595 elementIndex == index) {
602 nextOperands.push_back(
value);
607 arrayCreate.getLoc(), nextOperands);
608 if (arrayCreate->hasOneUse())
611 rewriter.replaceOp(arrayCreate, newNextVal);
614 rewriter.replaceOpWithNewOp<FirRegOp>(op, newNextVal, op.getClk(),
616 op.getInnerSymAttr());
628 OpFoldResult FirRegOp::fold(FoldAdaptor adaptor) {
630 if (getInnerSymAttr())
640 if (
auto reset = getReset())
642 if (constOp.getValue().isOne())
643 return getResetValue();
648 bool isTrivialFeedback = (getNext() == getResult());
649 bool isNeverClocked =
650 adaptor.getClk() !=
nullptr;
651 if (!isTrivialFeedback && !isNeverClocked)
655 if (
auto resetValue = getResetValue())
660 auto intType = getType().dyn_cast<IntegerType>();
670 OpFoldResult ClockGateOp::fold(FoldAdaptor adaptor) {
682 if (
auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput()))
683 if (clockAttr.getValue() == ClockConst::Low)
688 auto clockGateInputOp = getInput().getDefiningOp<ClockGateOp>();
689 while (clockGateInputOp) {
690 if (clockGateInputOp.getEnable() == getEnable() &&
691 clockGateInputOp.getTestEnable() == getTestEnable())
693 clockGateInputOp = clockGateInputOp.getInput().getDefiningOp<ClockGateOp>();
699 LogicalResult ClockGateOp::canonicalize(ClockGateOp op,
700 PatternRewriter &rewriter) {
702 if (
auto testEnable = op.getTestEnable()) {
704 if (constOp.getValue().isZero()) {
705 rewriter.updateRootInPlace(op,
706 [&] { op.getTestEnableMutable().clear(); });
715 std::optional<size_t> ClockGateOp::getTargetResultIndex() {
723 OpFoldResult ClockMuxOp::fold(FoldAdaptor adaptor) {
725 return getTrueClock();
727 return getFalseClock();
736 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
737 if (!nameAttr.getValue().empty())
738 setNameFn(getResult(), nameAttr.getValue());
// Unconditionally reports result index 0 for FirMemOp.
// NOTE(review): presumably this marks the op's first result as the "target"
// result for whichever interface declares getTargetResultIndex — confirm
// against that interface's documentation.
741 std::optional<size_t> FirMemOp::getTargetResultIndex() {
return 0; }
745 if (
auto mask = op.getMask()) {
746 auto memType = op.getMemory().getType();
747 if (!memType.getMaskWidth())
748 return op.emitOpError(
"has mask operand but memory type '")
749 << memType <<
"' has no mask";
751 if (mask.getType() != expected)
752 return op.emitOpError(
"has mask operand of type '")
753 << mask.getType() <<
"', but memory type requires '" << expected
// Verifier for FirMemReadWriteOp: passes this op to verifyFirMemMask and
// returns that helper's result unchanged; no other checks are done here.
760 LogicalResult FirMemReadWriteOp::verify() {
return verifyFirMemMask(*
this); }
765 return value.getDefiningOp<seq::ConstClockOp>();
771 return constOp.getValue().isZero();
778 return constOp.getValue().isAllOnes();
782 LogicalResult FirMemReadOp::canonicalize(FirMemReadOp op,
783 PatternRewriter &rewriter) {
786 rewriter.updateRootInPlace(op, [&] { op.getEnableMutable().erase(0); });
792 LogicalResult FirMemWriteOp::canonicalize(FirMemWriteOp op,
793 PatternRewriter &rewriter) {
797 rewriter.eraseOp(op);
800 bool anyChanges =
false;
804 rewriter.updateRootInPlace(op, [&] { op.getEnableMutable().erase(0); });
810 rewriter.updateRootInPlace(op, [&] { op.getMaskMutable().erase(0); });
814 return success(anyChanges);
817 LogicalResult FirMemReadWriteOp::canonicalize(FirMemReadWriteOp op,
818 PatternRewriter &rewriter) {
823 auto opAttrs = op->getAttrs();
824 auto opAttrNames = op.getAttributeNames();
825 auto newOp = rewriter.replaceOpWithNewOp<FirMemReadOp>(
826 op, op.getMemory(), op.getAddress(), op.getClk(), op.getEnable());
827 for (
auto namedAttr : opAttrs)
828 if (!llvm::is_contained(opAttrNames, namedAttr.getName()))
829 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
832 bool anyChanges =
false;
836 rewriter.updateRootInPlace(op, [&] { op.getEnableMutable().erase(0); });
842 rewriter.updateRootInPlace(op, [&] { op.getMaskMutable().erase(0); });
846 return success(anyChanges);
853 OpFoldResult ConstClockOp::fold(FoldAdaptor adaptor) {
861 LogicalResult ToClockOp::canonicalize(ToClockOp op, PatternRewriter &rewriter) {
862 if (
auto fromClock = op.getInput().getDefiningOp<FromClockOp>()) {
863 rewriter.replaceOp(op, fromClock.getInput());
869 OpFoldResult ToClockOp::fold(FoldAdaptor adaptor) {
870 if (
auto fromClock = getInput().getDefiningOp<FromClockOp>())
871 return fromClock.getInput();
872 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
874 intAttr.getValue().isZero() ? ClockConst::Low : ClockConst::High;
880 LogicalResult FromClockOp::canonicalize(FromClockOp op,
881 PatternRewriter &rewriter) {
882 if (
auto toClock = op.getInput().getDefiningOp<ToClockOp>()) {
883 rewriter.replaceOp(op, toClock.getInput());
889 OpFoldResult FromClockOp::fold(FoldAdaptor adaptor) {
890 if (
auto toClock = getInput().getDefiningOp<ToClockOp>())
891 return toClock.getInput();
892 if (
auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
903 FirMemory::FirMemory(hw::HWModuleGeneratedOp op) {
904 depth = op->getAttrOfType<IntegerAttr>(
"depth").
getInt();
905 numReadPorts = op->getAttrOfType<IntegerAttr>(
"numReadPorts").getUInt();
906 numWritePorts = op->getAttrOfType<IntegerAttr>(
"numWritePorts").getUInt();
908 op->getAttrOfType<IntegerAttr>(
"numReadWritePorts").getUInt();
909 readLatency = op->getAttrOfType<IntegerAttr>(
"readLatency").getUInt();
910 writeLatency = op->getAttrOfType<IntegerAttr>(
"writeLatency").getUInt();
911 dataWidth = op->getAttrOfType<IntegerAttr>(
"width").getUInt();
912 if (op->hasAttrOfType<IntegerAttr>(
"maskGran"))
913 maskGran = op->getAttrOfType<IntegerAttr>(
"maskGran").getUInt();
915 maskGran = dataWidth;
916 readUnderWrite = op->getAttrOfType<seq::RUWAttr>(
"readUnderWrite").getValue();
918 op->getAttrOfType<seq::WUWAttr>(
"writeUnderWrite").getValue();
919 if (
auto clockIDsAttr = op->getAttrOfType<ArrayAttr>(
"writeClockIDs"))
920 for (
auto clockID : clockIDsAttr)
921 writeClockIDs.push_back(
922 clockID.cast<IntegerAttr>().getValue().getZExtValue());
923 initFilename = op->getAttrOfType<StringAttr>(
"initFilename").getValue();
924 initIsBinary = op->getAttrOfType<BoolAttr>(
"initIsBinary").getValue();
925 initIsInline = op->getAttrOfType<BoolAttr>(
"initIsInline").getValue();
933 #define GET_OP_CLASSES
934 #include "circt/Dialect/Seq/Seq.cpp.inc"
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value for the sake of constant folding.
static InstancePath empty
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
static bool isConstClock(Value value)
static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType, StringRef directive)
static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
static bool isConstAllOnes(Value value)
void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
LogicalResult verifyResets(TOp op)
static bool canElideName(OpAsmPrinter &p, Operation *op)
ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
static bool isConstZero(Value value)
static LogicalResult verifyFirMemMask(Op op)
static ParseResult parseOptionalTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
static void setNameFromResult(OpAsmParser &parser, OperationState &result)
ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
def create(data_type, value)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
bool isValidIndexValues(Value hlmemHandle, ValueRange addresses)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn