CIRCT 23.0.0git
Loading...
Searching...
No Matches
SeqOps.cpp
Go to the documentation of this file.
1//===- SeqOps.cpp - Implement the Seq operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements sequential ops.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/Analysis/TopologicalSortUtils.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/DialectImplementation.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/PatternMatch.h"
24
26#include "llvm/ADT/SmallString.h"
27
28using namespace mlir;
29using namespace circt;
30using namespace seq;
31
32bool circt::seq::isValidIndexValues(Value hlmemHandle, ValueRange addresses) {
33 auto memType = cast<seq::HLMemType>(hlmemHandle.getType());
34 auto shape = memType.getShape();
35 if (shape.size() != addresses.size())
36 return false;
37
38 for (auto [dim, addr] : llvm::zip(shape, addresses)) {
39 auto addrType = dyn_cast<IntegerType>(addr.getType());
40 if (!addrType)
41 return false;
42 if (addrType.getIntOrFloatBitWidth() != llvm::Log2_64_Ceil(dim))
43 return false;
44 }
45 return true;
46}
47
48// If there was no name specified, check to see if there was a useful name
49// specified in the asm file.
50static void setNameFromResult(OpAsmParser &parser, OperationState &result) {
51 if (result.attributes.getNamed("name"))
52 return;
53 // If there is no explicit name attribute, get it from the SSA result name.
54 // If numeric, just use an empty name.
55 StringRef resultName = parser.getResultName(0).first;
56 if (!resultName.empty() && isdigit(resultName[0]))
57 resultName = "";
58 result.addAttribute("name", parser.getBuilder().getStringAttr(resultName));
59}
60
61static bool canElideName(OpAsmPrinter &p, Operation *op) {
62 if (!op->hasAttr("name"))
63 return true;
64
65 auto name = op->getAttrOfType<StringAttr>("name").getValue();
66 if (name.empty())
67 return true;
68
69 SmallString<32> resultNameStr;
70 llvm::raw_svector_ostream tmpStream(resultNameStr);
71 p.printOperand(op->getResult(0), tmpStream);
72 auto actualName = tmpStream.str().drop_front();
73 return actualName == name;
74}
75
76static ParseResult
77parseOptionalTypeMatch(OpAsmParser &parser, Type refType,
78 std::optional<OpAsmParser::UnresolvedOperand> operand,
79 Type &type) {
80 if (operand)
81 type = refType;
82 return success();
83}
84
85static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType,
86 Value operand, Type type) {
87 // Nothing to do - this is strictly an implicit parsing helper.
88}
89
91 OpAsmParser &parser, Type refType,
92 std::optional<OpAsmParser::UnresolvedOperand> operand, Type &type) {
93 if (operand)
94 type = seq::ImmutableType::get(refType);
95 return success();
96}
97
98static void printOptionalImmutableTypeMatch(OpAsmPrinter &p, Operation *op,
99 Type refType, Value operand,
100 Type type) {
101 // Nothing to do - this is strictly an implicit parsing helper.
102}
103
104//===----------------------------------------------------------------------===//
105// ReadPortOp
106//===----------------------------------------------------------------------===//
107
108ParseResult ReadPortOp::parse(OpAsmParser &parser, OperationState &result) {
109 llvm::SMLoc loc = parser.getCurrentLocation();
110
111 OpAsmParser::UnresolvedOperand memOperand, rdenOperand;
112 bool hasRdEn = false;
113 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
114 seq::HLMemType memType;
115
116 if (parser.parseOperand(memOperand) ||
117 parser.parseOperandList(addressOperands, OpAsmParser::Delimiter::Square))
118 return failure();
119
120 if (succeeded(parser.parseOptionalKeyword("rden"))) {
121 if (failed(parser.parseOperand(rdenOperand)))
122 return failure();
123 hasRdEn = true;
124 }
125
126 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
127 parser.parseType(memType))
128 return failure();
129
130 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
131 operandTypes.insert(operandTypes.begin(), memType);
132
133 llvm::SmallVector<OpAsmParser::UnresolvedOperand> allOperands = {memOperand};
134 llvm::copy(addressOperands, std::back_inserter(allOperands));
135 if (hasRdEn) {
136 operandTypes.push_back(parser.getBuilder().getI1Type());
137 allOperands.push_back(rdenOperand);
138 }
139
140 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
141 return failure();
142
143 result.addTypes(memType.getElementType());
144
145 llvm::SmallVector<int32_t, 2> operandSizes;
146 operandSizes.push_back(1); // memory handle
147 operandSizes.push_back(addressOperands.size());
148 operandSizes.push_back(hasRdEn ? 1 : 0);
149 result.addAttribute("operandSegmentSizes",
150 parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
151 return success();
152}
153
154void ReadPortOp::print(OpAsmPrinter &p) {
155 p << " " << getMemory() << "[" << getAddresses() << "]";
156 if (getRdEn())
157 p << " rden " << getRdEn();
158 p.printOptionalAttrDict((*this)->getAttrs(), {"operandSegmentSizes"});
159 p << " : " << getMemory().getType();
160}
161
162void ReadPortOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
163 auto memName = getMemory().getDefiningOp<seq::HLMemOp>().getName();
164 setNameFn(getReadData(), (memName + "_rdata").str());
165}
166
167void ReadPortOp::build(OpBuilder &builder, OperationState &result, Value memory,
168 ValueRange addresses, Value rdEn, unsigned latency) {
169 auto memType = cast<seq::HLMemType>(memory.getType());
170 ReadPortOp::build(builder, result, memType.getElementType(), memory,
171 addresses, rdEn, latency);
172}
173
174//===----------------------------------------------------------------------===//
175// WritePortOp
176//===----------------------------------------------------------------------===//
177
178ParseResult WritePortOp::parse(OpAsmParser &parser, OperationState &result) {
179 llvm::SMLoc loc = parser.getCurrentLocation();
180 OpAsmParser::UnresolvedOperand memOperand, dataOperand, wrenOperand;
181 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
182 seq::HLMemType memType;
183
184 if (parser.parseOperand(memOperand) ||
185 parser.parseOperandList(addressOperands,
186 OpAsmParser::Delimiter::Square) ||
187 parser.parseOperand(dataOperand) || parser.parseKeyword("wren") ||
188 parser.parseOperand(wrenOperand) ||
189 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
190 parser.parseType(memType))
191 return failure();
192
193 llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
194 operandTypes.insert(operandTypes.begin(), memType);
195 operandTypes.push_back(memType.getElementType());
196 operandTypes.push_back(parser.getBuilder().getI1Type());
197
198 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> allOperands(
199 addressOperands);
200 allOperands.insert(allOperands.begin(), memOperand);
201 allOperands.push_back(dataOperand);
202 allOperands.push_back(wrenOperand);
203
204 if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
205 return failure();
206
207 return success();
208}
209
210void WritePortOp::print(OpAsmPrinter &p) {
211 p << " " << getMemory() << "[" << getAddresses() << "] " << getInData()
212 << " wren " << getWrEn();
213 p.printOptionalAttrDict((*this)->getAttrs());
214 p << " : " << getMemory().getType();
215}
216
217//===----------------------------------------------------------------------===//
218// HLMemOp
219//===----------------------------------------------------------------------===//
220
221void HLMemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
222 setNameFn(getHandle(), getName());
223}
224
225void HLMemOp::build(OpBuilder &builder, OperationState &result, Value clk,
226 Value rst, StringRef name, llvm::ArrayRef<int64_t> shape,
227 Type elementType) {
228 HLMemType t = HLMemType::get(builder.getContext(), shape, elementType);
229 HLMemOp::build(builder, result, t, clk, rst, name);
230}
231
232//===----------------------------------------------------------------------===//
233// FIFOOp
234//===----------------------------------------------------------------------===//
235
236// Flag threshold custom directive
237static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser,
238 IntegerAttr &threshold,
239 Type &outputFlagType,
240 StringRef directive) {
241 // look for an optional "almost_full $threshold" group.
242 if (succeeded(parser.parseOptionalKeyword(directive))) {
243 int64_t thresholdValue;
244 if (succeeded(parser.parseInteger(thresholdValue))) {
245 threshold = parser.getBuilder().getI64IntegerAttr(thresholdValue);
246 outputFlagType = parser.getBuilder().getI1Type();
247 return success();
248 }
249 return parser.emitError(parser.getNameLoc(),
250 "expected integer value after " + directive +
251 " directive");
252 }
253 return success();
254}
255
256ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold,
257 Type &outputFlagType) {
258 return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
259 "almost_full");
260}
261
262ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold,
263 Type &outputFlagType) {
264 return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
265 "almost_empty");
266}
267
268void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
269 Type outputFlagType) {
270 if (threshold)
271 p << "almost_full"
272 << " " << threshold.getInt();
273}
274
275void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
276 Type outputFlagType) {
277 if (threshold)
278 p << "almost_empty"
279 << " " << threshold.getInt();
280}
281
282void FIFOOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
283 setNameFn(getOutput(), "out");
284 setNameFn(getEmpty(), "empty");
285 setNameFn(getFull(), "full");
286 if (auto ae = getAlmostEmpty())
287 setNameFn(ae, "almostEmpty");
288 if (auto af = getAlmostFull())
289 setNameFn(af, "almostFull");
290}
291
292LogicalResult FIFOOp::verify() {
293 auto aet = getAlmostEmptyThreshold();
294 auto aft = getAlmostFullThreshold();
295 size_t depth = getDepth();
296 if (aft.has_value() && aft.value() > depth)
297 return emitOpError("almost full threshold must be <= FIFO depth");
298
299 if (aet.has_value() && aet.value() > depth)
300 return emitOpError("almost empty threshold must be <= FIFO depth");
301
302 return success();
303}
304
305//===----------------------------------------------------------------------===//
306// CompRegOp
307//===----------------------------------------------------------------------===//
308
309/// Suggest a name for each result value based on the saved result names
310/// attribute.
311void CompRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
312 if (auto name = getName())
313 setNameFn(getResult(), *name);
314}
315
316template <typename TOp>
317LogicalResult verifyResets(TOp op) {
318 if ((op.getReset() == nullptr) ^ (op.getResetValue() == nullptr))
319 return op->emitOpError(
320 "either reset and resetValue or neither must be specified");
321 bool hasReset = op.getReset() != nullptr;
322 if (hasReset && op.getResetValue().getType() != op.getInput().getType())
323 return op->emitOpError("reset value must be the same type as the input");
324
325 return success();
326}
327
328std::optional<size_t> CompRegOp::getTargetResultIndex() { return 0; }
329
330LogicalResult CompRegOp::verify() { return verifyResets(*this); }
331
332//===----------------------------------------------------------------------===//
333// CompRegClockEnabledOp
334//===----------------------------------------------------------------------===//
335
336/// Suggest a name for each result value based on the saved result names
337/// attribute.
338void CompRegClockEnabledOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
339 if (auto name = getName())
340 setNameFn(getResult(), *name);
341}
342
343std::optional<size_t> CompRegClockEnabledOp::getTargetResultIndex() {
344 return 0;
345}
346
347LogicalResult CompRegClockEnabledOp::verify() { return verifyResets(*this); }
348
349LogicalResult CompRegClockEnabledOp::canonicalize(CompRegClockEnabledOp op,
350 PatternRewriter &rewriter) {
351 // reg(comb.mux(en, d, ?), en) -> reg(d, en)
352 // reg(arith.select(en, d, ?), en) -> reg(d, en)
353 auto *inputOp = op.getInput().getDefiningOp();
354 if (isa_and_nonnull<comb::MuxOp, arith::SelectOp>(inputOp) &&
355 inputOp->getOperand(0) == op.getClockEnable()) {
356 rewriter.modifyOpInPlace(
357 op, [&] { op.getInputMutable().assign(inputOp->getOperand(1)); });
358 return success();
359 }
360
361 // Match constant clock enable values.
362 APInt en;
363 if (mlir::matchPattern(op.getClockEnable(), mlir::m_ConstantInt(&en))) {
364 if (en.isAllOnes()) {
365 rewriter.replaceOpWithNewOp<seq::CompRegOp>(
366 op, op.getInput(), op.getClk(), op.getNameAttr(), op.getReset(),
367 op.getResetValue(), op.getInitialValue(), op.getInnerSymAttr());
368 return success();
369 }
370 }
371
372 return failure();
373}
374
375//===----------------------------------------------------------------------===//
376// ShiftRegOp
377//===----------------------------------------------------------------------===//
378
379void ShiftRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
380 // If the wire has an optional 'name' attribute, use it.
381 if (auto name = getName())
382 setNameFn(getResult(), *name);
383}
384
385std::optional<size_t> ShiftRegOp::getTargetResultIndex() { return 0; }
386
387LogicalResult ShiftRegOp::verify() {
388 if (failed(verifyResets(*this)))
389 return failure();
390 return success();
391}
392
393//===----------------------------------------------------------------------===//
394// FirRegOp
395//===----------------------------------------------------------------------===//
396
397void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
398 Value clk, StringAttr name, hw::InnerSymAttr innerSym,
399 Attribute preset) {
400
401 OpBuilder::InsertionGuard guard(builder);
402
403 result.addOperands(input);
404 result.addOperands(clk);
405
406 result.addAttribute(getNameAttrName(result.name), name);
407
408 if (innerSym)
409 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
410
411 if (preset)
412 result.addAttribute(getPresetAttrName(result.name), preset);
413
414 result.addTypes(input.getType());
415}
416
417void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
418 Value clk, StringAttr name, Value reset, Value resetValue,
419 hw::InnerSymAttr innerSym, bool isAsync) {
420
421 OpBuilder::InsertionGuard guard(builder);
422
423 result.addOperands(input);
424 result.addOperands(clk);
425 result.addOperands(reset);
426 result.addOperands(resetValue);
427
428 result.addAttribute(getNameAttrName(result.name), name);
429 if (isAsync)
430 result.addAttribute(getIsAsyncAttrName(result.name), builder.getUnitAttr());
431
432 if (innerSym)
433 result.addAttribute(getInnerSymAttrName(result.name), innerSym);
434
435 result.addTypes(input.getType());
436}
437
/// Parses the custom form of `seq.firreg`:
///   %next clock %clk (sym <inner-sym>)? (reset (sync|async) %rst, %val)?
///   (preset <non-negative integer>)? attr-dict : type
/// If `name` is not given explicitly, it is derived from the SSA result name.
/// The preset literal is range-checked against the register's bit width and
/// stored as an integer attribute of exactly that width.
ParseResult FirRegOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  llvm::SMLoc loc = parser.getCurrentLocation();

  using Op = OpAsmParser::UnresolvedOperand;

  Op next, clk;
  if (parser.parseOperand(next) || parser.parseKeyword("clock") ||
      parser.parseOperand(clk))
    return failure();

  // Optional inner symbol. parseCustomAttributeWithFallback records it in
  // result.attributes under "inner_sym", so the local is not used afterwards.
  if (succeeded(parser.parseOptionalKeyword("sym"))) {
    hw::InnerSymAttr innerSym;
    if (parser.parseCustomAttributeWithFallback(innerSym, /*type=*/nullptr,
                                                "inner_sym", result.attributes))
      return failure();
  }

  // Parse reset [sync|async] %reset, %value
  // NOTE(review): the "invalid reset" diagnostic below is anchored at the
  // start of the op (`loc`), not at the offending token.
  std::optional<std::pair<Op, Op>> resetAndValue;
  if (succeeded(parser.parseOptionalKeyword("reset"))) {
    bool isAsync;
    if (succeeded(parser.parseOptionalKeyword("async")))
      isAsync = true;
    else if (succeeded(parser.parseOptionalKeyword("sync")))
      isAsync = false;
    else
      return parser.emitError(loc, "invalid reset, expected 'sync' or 'async'");
    if (isAsync)
      result.attributes.append("isAsync", builder.getUnitAttr());

    resetAndValue = {{}, {}};
    if (parser.parseOperand(resetAndValue->first) || parser.parseComma() ||
        parser.parseOperand(resetAndValue->second))
      return failure();
  }

  // Optional preset literal. Negative values are rejected; the width check
  // happens after the register type is known.
  std::optional<APInt> presetValue;
  llvm::SMLoc presetValueLoc;
  if (succeeded(parser.parseOptionalKeyword("preset"))) {
    presetValueLoc = parser.getCurrentLocation();
    OptionalParseResult presetIntResult =
        parser.parseOptionalInteger(presetValue.emplace());
    if (!presetIntResult.has_value() || failed(*presetIntResult))
      return parser.emitError(presetValueLoc, "expected integer value");
    if (presetValue->isNegative())
      return parser.emitError(presetValueLoc,
                              "preset value must not be negative");
  }

  Type ty;
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(ty))
    return failure();
  result.addTypes({ty});

  // Convert the preset literal into an attribute whose width equals the
  // register's bit width. Clock-typed registers count as 1 bit.
  if (presetValue) {
    uint64_t width = 0;
    if (hw::type_isa<seq::ClockType>(ty)) {
      width = 1;
    } else {
      int64_t maybeWidth = hw::getBitWidth(ty);
      if (maybeWidth < 0)
        return parser.emitError(presetValueLoc,
                                "cannot preset register of unknown width");
      width = maybeWidth;
    }

    // Round-trip through the target width. If widening back does not
    // reproduce the original literal, bits were lost and the value does not
    // fit in the register.
    APInt presetResult = presetValue->sextOrTrunc(width);
    if (presetResult.zextOrTrunc(presetValue->getBitWidth()) != *presetValue)
      return parser.emitError(presetValueLoc, "preset value too large");

    // NOTE(review): this `builder` is a copy that shadows the reference
    // declared at the top of the function.
    auto builder = parser.getBuilder();
    auto presetTy = builder.getIntegerType(width);
    auto resultAttr = builder.getIntegerAttr(presetTy, presetResult);
    result.addAttribute("preset", resultAttr);
  }

  setNameFromResult(parser, result);

  // Resolve operands in order: next (register type), clock, and optionally
  // reset (i1) and reset value (register type).
  if (parser.resolveOperand(next, ty, result.operands))
    return failure();

  Type clkTy = ClockType::get(result.getContext());
  if (parser.resolveOperand(clk, clkTy, result.operands))
    return failure();

  if (resetAndValue) {
    Type i1 = IntegerType::get(result.getContext(), 1);
    if (parser.resolveOperand(resetAndValue->first, i1, result.operands) ||
        parser.resolveOperand(resetAndValue->second, ty, result.operands))
      return failure();
  }

  return success();
}
534
535void FirRegOp::print(::mlir::OpAsmPrinter &p) {
536 SmallVector<StringRef> elidedAttrs = {
537 getInnerSymAttrName(), getIsAsyncAttrName(), getPresetAttrName()};
538
539 p << ' ' << getNext() << " clock " << getClk();
540
541 if (auto sym = getInnerSymAttr()) {
542 p << " sym ";
543 sym.print(p);
544 }
545
546 if (hasReset()) {
547 p << " reset " << (getIsAsync() ? "async" : "sync") << ' ';
548 p << getReset() << ", " << getResetValue();
549 }
550
551 if (auto preset = getPresetAttr()) {
552 p << " preset ";
553
554 // Don't emit negative integers to match the parsing logic.
555 const auto &presetVal = preset.getValue();
556 if (presetVal.isNonNegative())
557 p << presetVal;
558 else
559 p << presetVal.zext(presetVal.getBitWidth() + 1);
560 }
561
562 if (canElideName(p, *this))
563 elidedAttrs.push_back("name");
564
565 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
566 p << " : " << getNext().getType();
567}
568
569/// Verifier for the FIR register op.
570LogicalResult FirRegOp::verify() {
571 if (getReset() || getResetValue() || getIsAsync()) {
572 if (!getReset() || !getResetValue())
573 return emitOpError("must specify reset and reset value");
574 } else {
575 if (getIsAsync())
576 return emitOpError("register with no reset cannot be async");
577 }
578 if (auto preset = getPresetAttr()) {
579 int64_t presetWidth = hw::getBitWidth(preset.getType());
580 int64_t width = hw::getBitWidth(getType());
581 if (preset.getType() != getType() && presetWidth != width)
582 return emitOpError("preset type width must match register type");
583 }
584 return success();
585}
586
587/// Suggest a name for each result value based on the saved result names
588/// attribute.
589void FirRegOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
590 // If the register has an optional 'name' attribute, use it.
591 if (!getName().empty())
592 setNameFn(getResult(), getName());
593}
594
595std::optional<size_t> FirRegOp::getTargetResultIndex() { return 0; }
596
/// Canonicalization patterns for `seq.firreg`, tried in order:
///  1. A reset driven by a constant-zero `hw.constant` is dropped together
///     with the reset value. Name, inner symbol and preset are kept.
///  2. Registers with an inner symbol are not optimized further.
///  3. A register with trivial feedback (next == result), or whose clock is a
///     `seq.to_clock` of an `hw.constant`, is replaced by a constant zero.
///     This requires no reset value and either no preset or a zero preset.
///  4. A register whose next value is a `comb.mux` with the register on one
///     side and a constant on the other is replaced by that constant. This
///     requires no preset, and any reset value must be the same constant.
///  5. For reset-less, preset-less arrays of integers, array_create operands
///     that read back the same element of the register are replaced by zero.
LogicalResult FirRegOp::canonicalize(FirRegOp op, PatternRewriter &rewriter) {

  // If the register has a constant zero reset, drop the reset and reset value
  // altogether (and preserve the PresetAttr).
  if (auto reset = op.getReset()) {
    if (auto constOp = reset.getDefiningOp<hw::ConstantOp>()) {
      if (constOp.getValue().isZero()) {
        rewriter.replaceOpWithNewOp<FirRegOp>(
            op, op.getNext(), op.getClk(), op.getNameAttr(),
            op.getInnerSymAttr(), op.getPresetAttr());
        return success();
      }
    }
  }

  // If the register has a symbol, we can't optimize it away.
  if (op.getInnerSymAttr())
    return failure();

  // Replace a register with a trivial feedback or constant clock with a
  // constant zero.
  // TODO: Once HW aggregate constant values are supported, move this
  // canonicalization to the folder.
  auto isConstant = [&]() -> bool {
    if (op.getNext() == op.getResult())
      return true;
    if (auto clk = op.getClk().getDefiningOp<seq::ToClockOp>())
      return clk.getInput().getDefiningOp<hw::ConstantOp>();
    return false;
  };

  // Preset can block canonicalization only if it is non-zero.
  bool replaceWithConstZero = true;
  if (auto preset = op.getPresetAttr())
    if (!preset.getValue().isZero())
      replaceWithConstZero = false;

  if (isConstant() && !op.getResetValue() && replaceWithConstZero) {
    // Clock-typed registers become a constant-low clock. Other types become a
    // zero integer of the same bit width, bitcast to the register type.
    if (isa<seq::ClockType>(op.getType())) {
      rewriter.replaceOpWithNewOp<seq::ConstClockOp>(
          op, seq::ClockConstAttr::get(rewriter.getContext(), ClockConst::Low));
    } else {
      auto constant = hw::ConstantOp::create(
          rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
      rewriter.replaceOpWithNewOp<hw::BitcastOp>(op, op.getType(), constant);
    }
    return success();
  }

  // Canonicalize registers with mux-based constant drivers.
  // This pattern matches registers where the next value is a mux with one
  // branch being the register itself (creating a self-loop) and the other
  // branch being a constant. In such cases, the register effectively holds a
  // constant value and can be replaced with that constant.
  // If the next value is a mux but the pattern does not apply, this returns
  // failure immediately. The array pattern further below requires an
  // hw.array_create next value, so it could not match in that case anyway.
  if (auto nextMux = op.getNext().getDefiningOp<comb::MuxOp>()) {
    // Reject optimization if register has preset attribute (for simplicity)
    if (op.getPresetAttr())
      return failure();

    Attribute value;
    Value replacedValue;

    // Check if true branch is self-loop and false branch is constant
    if (nextMux.getTrueValue() == op.getResult() &&
        matchPattern(nextMux.getFalseValue(), m_Constant(&value))) {
      replacedValue = nextMux.getFalseValue();
    }
    // Check if false branch is self-loop and true branch is constant
    else if (nextMux.getFalseValue() == op.getResult() &&
             matchPattern(nextMux.getTrueValue(), m_Constant(&value))) {
      replacedValue = nextMux.getTrueValue();
    }

    if (!replacedValue)
      return failure();

    // Verify reset value compatibility: if register has reset, it must be
    // a constant that matches the mux constant
    if (op.getResetValue()) {
      Attribute resetConst;
      if (matchPattern(op.getResetValue(), m_Constant(&resetConst))) {
        if (resetConst != value)
          return failure();
      } else {
        // Non-constant reset value prevents optimization
        return failure();
      }
    }

    assert(replacedValue);
    // Apply the optimization if all conditions are met
    rewriter.replaceOp(op, replacedValue);
    return success();
  }

  // For reset-less 1d array registers, replace an uninitialized element with
  // constant zero. For example, let `r` be a 2xi1 register and its next value
  // be `{foo, r[0]}`. `r[0]` is connected to itself so will never be
  // initialized. If we don't enable aggregate preservation, `r_0` is replaced
  // with `0`. Hence this canonicalization replaces 0th element of the next
  // value with zero to match the behaviour.
  if (!op.getReset() && !op.getPresetAttr()) {
    if (auto arrayCreate = op.getNext().getDefiningOp<hw::ArrayCreateOp>()) {
      // For now only support 1d arrays.
      // TODO: Support nested arrays and bundles.
      if (isa<IntegerType>(
              hw::type_cast<hw::ArrayType>(op.getResult().getType())
                  .getElementType())) {
        SmallVector<Value> nextOperands;
        bool changed = false;
        for (const auto &[i, value] :
             llvm::enumerate(arrayCreate.getOperands())) {
          // Operand i is compared against element index (size - i - 1), i.e.
          // array_create operands are treated as listed from the highest
          // index down.
          auto index = arrayCreate.getOperands().size() - i - 1;
          APInt elementIndex;
          // Check that the corresponding operand is op's element.
          if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
            if (arrayGet.getInput() == op.getResult() &&
                matchPattern(arrayGet.getIndex(),
                             m_ConstantInt(&elementIndex)) &&
                elementIndex == index) {
              nextOperands.push_back(hw::ConstantOp::create(
                  rewriter, op.getLoc(),
                  APInt::getZero(hw::getBitWidth(arrayGet.getType()))));
              changed = true;
              continue;
            }
          nextOperands.push_back(value);
        }
        // If one of the operands is self loop, update the next value.
        if (changed) {
          auto newNextVal = hw::ArrayCreateOp::create(
              rewriter, arrayCreate.getLoc(), nextOperands);
          if (arrayCreate->hasOneUse())
            // If the original next value has a single use, we can replace the
            // value directly.
            rewriter.replaceOp(arrayCreate, newNextVal);
          else {
            // Otherwise, replace the entire firreg with a new one.
            rewriter.replaceOpWithNewOp<FirRegOp>(op, newNextVal, op.getClk(),
                                                  op.getNameAttr(),
                                                  op.getInnerSymAttr());
          }

          return success();
        }
      }
    }
  }

  return failure();
}
748
749OpFoldResult FirRegOp::fold(FoldAdaptor adaptor) {
750 // If the register has a symbol or preset value, we can't optimize it away.
751 // TODO: Handle a preset value.
752 if (getInnerSymAttr())
753 return {};
754
755 auto presetAttr = getPresetAttr();
756
757 // If the register is held in permanent reset, replace it with its reset
758 // value. This works trivially if the reset is asynchronous and therefore
759 // level-sensitive, in which case it will always immediately assume the reset
760 // value in silicon. If it is synchronous, the register value is undefined
761 // until the first clock edge at which point it becomes the reset value, in
762 // which case we simply define the initial value to already be the reset
763 // value. Works only if no preset.
764 if (!presetAttr)
765 if (auto reset = getReset())
766 if (auto constOp = reset.getDefiningOp<hw::ConstantOp>())
767 if (constOp.getValue().isOne())
768 return getResetValue();
769
770 // If the register's next value is trivially it's current value, or the
771 // register is never clocked, we can replace the register with a constant
772 // value.
773 bool isTrivialFeedback = (getNext() == getResult());
774 bool isNeverClocked =
775 adaptor.getClk() != nullptr; // clock operand is constant
776 if (!isTrivialFeedback && !isNeverClocked)
777 return {};
778
779 // If the register has a const reset value, and no preset, we can replace it
780 // with the const reset. We cannot replace it with a non-constant reset value.
781 if (auto resetValue = getResetValue()) {
782 if (auto *op = resetValue.getDefiningOp()) {
783 if (op->hasTrait<OpTrait::ConstantLike>() && !presetAttr)
784 return resetValue;
785 if (auto constOp = dyn_cast<hw::ConstantOp>(op))
786 if (presetAttr.getValue() == constOp.getValue())
787 return resetValue;
788 }
789 return {};
790 }
791
792 // Otherwise we want to replace the register with a constant 0. For now this
793 // only works with integer types.
794 auto intType = dyn_cast<IntegerType>(getType());
795 if (!intType)
796 return {};
797 // If preset present, then replace with preset.
798 if (presetAttr)
799 return presetAttr;
800 return IntegerAttr::get(intType, 0);
801}
802
803//===----------------------------------------------------------------------===//
804// ClockGateOp
805//===----------------------------------------------------------------------===//
806
807OpFoldResult ClockGateOp::fold(FoldAdaptor adaptor) {
808 // Forward the clock if one of the enables is always true.
809 if (isConstantOne(adaptor.getEnable()) ||
810 isConstantOne(adaptor.getTestEnable()))
811 return getInput();
812
813 // Fold to a constant zero clock if the enables are always false.
814 if (isConstantZero(adaptor.getEnable()) &&
815 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
816 return ClockConstAttr::get(getContext(), ClockConst::Low);
817
818 // Forward constant zero clocks.
819 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput()))
820 if (clockAttr.getValue() == ClockConst::Low)
821 return ClockConstAttr::get(getContext(), ClockConst::Low);
822
823 // Transitive clock gating - eliminate clock gates that are driven by an
824 // identical enable signal somewhere higher in the clock gate hierarchy.
825 auto clockGateInputOp = getInput().getDefiningOp<ClockGateOp>();
826 while (clockGateInputOp) {
827 if (clockGateInputOp.getEnable() == getEnable() &&
828 clockGateInputOp.getTestEnable() == getTestEnable())
829 return getInput();
830 clockGateInputOp = clockGateInputOp.getInput().getDefiningOp<ClockGateOp>();
831 }
832
833 return {};
834}
835
836LogicalResult ClockGateOp::canonicalize(ClockGateOp op,
837 PatternRewriter &rewriter) {
838 // Remove constant false test enable.
839 if (auto testEnable = op.getTestEnable()) {
840 if (auto constOp = testEnable.getDefiningOp<hw::ConstantOp>()) {
841 if (constOp.getValue().isZero()) {
842 rewriter.modifyOpInPlace(op,
843 [&] { op.getTestEnableMutable().clear(); });
844 return success();
845 }
846 }
847 }
848
849 return failure();
850}
851
852std::optional<size_t> ClockGateOp::getTargetResultIndex() {
853 return std::nullopt;
854}
855
856//===----------------------------------------------------------------------===//
857// ClockMuxOp
858//===----------------------------------------------------------------------===//
859
860OpFoldResult ClockMuxOp::fold(FoldAdaptor adaptor) {
861 if (isConstantOne(adaptor.getCond()))
862 return getTrueClock();
863 if (isConstantZero(adaptor.getCond()))
864 return getFalseClock();
865 return {};
866}
867
868//===----------------------------------------------------------------------===//
869// ClockDividerOp
870//===----------------------------------------------------------------------===//
871
872LogicalResult ClockDividerOp::canonicalize(ClockDividerOp op,
873 PatternRewriter &rewriter) {
874 // clock_div(clock_div(clock, a), b) -> clock_div(clock, a + b)
875 if (auto innerDiv = op.getInput().getDefiningOp<ClockDividerOp>()) {
876 auto outerPow2 = op.getPow2();
877 auto innerPow2 = innerDiv.getPow2();
878 auto combinedPow2 = outerPow2 + innerPow2;
879
880 rewriter.replaceOpWithNewOp<ClockDividerOp>(op, innerDiv.getInput(),
881 combinedPow2);
882 return success();
883 }
884 return failure();
885}
886
887//===----------------------------------------------------------------------===//
888// FirMemOp
889//===----------------------------------------------------------------------===//
890
891LogicalResult FirMemOp::canonicalize(FirMemOp op, PatternRewriter &rewriter) {
892 // Do not change memories if symbols point to them.
893 if (op.getInnerSymAttr())
894 return failure();
895
896 bool readOnly = true, writeOnly = true;
897
898 // If the memory has no read ports, erase it.
899 for (auto *user : op->getUsers()) {
900 if (isa<FirMemReadOp, FirMemReadWriteOp>(user)) {
901 writeOnly = false;
902 }
903 if (isa<FirMemWriteOp, FirMemReadWriteOp>(user)) {
904 readOnly = false;
905 }
906 assert((isa<FirMemReadOp, FirMemWriteOp, FirMemReadWriteOp>(user)) &&
907 "invalid seq.firmem user");
908 }
909 if (writeOnly) {
910 for (auto *user : llvm::make_early_inc_range(op->getUsers()))
911 rewriter.eraseOp(user);
912
913 rewriter.eraseOp(op);
914 return success();
915 }
916
917 if (readOnly && !op.getInit()) {
918 // Replace all read ports with a constant 0.
919 for (auto *user : llvm::make_early_inc_range(op->getUsers())) {
920 auto readOp = cast<FirMemReadOp>(user);
921 Value zero = hw::ConstantOp::create(
922 rewriter, readOp.getLoc(),
923 APInt::getZero(hw::getBitWidth(readOp.getType())));
924 if (readOp.getType() != zero.getType())
925 zero = hw::BitcastOp::create(rewriter, readOp.getLoc(),
926 readOp.getType(), zero);
927 rewriter.replaceOp(readOp, zero);
928 }
929 rewriter.eraseOp(op);
930 return success();
931 }
932 return failure();
933}
934
935void FirMemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
936 auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
937 if (!nameAttr.getValue().empty())
938 setNameFn(getResult(), nameAttr.getValue());
939}
940
941std::optional<size_t> FirMemOp::getTargetResultIndex() { return 0; }
942
943template <class Op>
944static LogicalResult verifyFirMemMask(Op op) {
945 if (auto mask = op.getMask()) {
946 auto memType = op.getMemory().getType();
947 if (!memType.getMaskWidth())
948 return op.emitOpError("has mask operand but memory type '")
949 << memType << "' has no mask";
950 auto expected = IntegerType::get(op.getContext(), *memType.getMaskWidth());
951 if (mask.getType() != expected)
952 return op.emitOpError("has mask operand of type '")
953 << mask.getType() << "', but memory type requires '" << expected
954 << "'";
955 }
956 return success();
957}
958
959LogicalResult FirMemWriteOp::verify() { return verifyFirMemMask(*this); }
960LogicalResult FirMemReadWriteOp::verify() { return verifyFirMemMask(*this); }
961
962static bool isConstClock(Value value) {
963 if (!value)
964 return false;
965 return value.getDefiningOp<seq::ConstClockOp>();
966}
967
968static bool isConstZero(Value value) {
969 if (value)
970 if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
971 return constOp.getValue().isZero();
972 return false;
973}
974
975static bool isConstAllOnes(Value value) {
976 if (value)
977 if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
978 return constOp.getValue().isAllOnes();
979 return false;
980}
981
982LogicalResult FirMemReadOp::canonicalize(FirMemReadOp op,
983 PatternRewriter &rewriter) {
984 // Remove the enable if it is constant true.
985 if (isConstAllOnes(op.getEnable())) {
986 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
987 return success();
988 }
989 return failure();
990}
991
992LogicalResult FirMemWriteOp::canonicalize(FirMemWriteOp op,
993 PatternRewriter &rewriter) {
994 // Remove the write port if it is trivially dead.
995 if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
996 isConstClock(op.getClk())) {
997 auto memOp = op.getMemory().getDefiningOp<FirMemOp>();
998 if (memOp.getInnerSymAttr())
999 return failure();
1000 rewriter.eraseOp(op);
1001 return success();
1002 }
1003 bool anyChanges = false;
1004
1005 // Remove the enable if it is constant true.
1006 if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
1007 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
1008 anyChanges = true;
1009 }
1010
1011 // Remove the mask if it is all ones.
1012 if (auto mask = op.getMask(); isConstAllOnes(mask)) {
1013 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
1014 anyChanges = true;
1015 }
1016
1017 return success(anyChanges);
1018}
1019
1020LogicalResult FirMemReadWriteOp::canonicalize(FirMemReadWriteOp op,
1021 PatternRewriter &rewriter) {
1022 // Replace the read-write port with a read port if the write behavior is
1023 // trivially disabled.
1024 if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
1025 isConstClock(op.getClk()) || isConstZero(op.getMode())) {
1026 auto opAttrs = op->getAttrs();
1027 auto opAttrNames = op.getAttributeNames();
1028 auto newOp = rewriter.replaceOpWithNewOp<FirMemReadOp>(
1029 op, op.getMemory(), op.getAddress(), op.getClk(), op.getEnable());
1030 for (auto namedAttr : opAttrs)
1031 if (!llvm::is_contained(opAttrNames, namedAttr.getName()))
1032 newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1033 return success();
1034 }
1035 bool anyChanges = false;
1036
1037 // Remove the enable if it is constant true.
1038 if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
1039 rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
1040 anyChanges = true;
1041 }
1042
1043 // Remove the mask if it is all ones.
1044 if (auto mask = op.getMask(); isConstAllOnes(mask)) {
1045 rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
1046 anyChanges = true;
1047 }
1048
1049 return success(anyChanges);
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// ConstClockOp
1054//===----------------------------------------------------------------------===//
1055
1056OpFoldResult ConstClockOp::fold(FoldAdaptor adaptor) {
1057 return ClockConstAttr::get(getContext(), getValue());
1058}
1059
1060//===----------------------------------------------------------------------===//
1061// ToClockOp/FromClockOp
1062//===----------------------------------------------------------------------===//
1063
1064LogicalResult ToClockOp::canonicalize(ToClockOp op, PatternRewriter &rewriter) {
1065 if (auto fromClock = op.getInput().getDefiningOp<FromClockOp>()) {
1066 rewriter.replaceOp(op, fromClock.getInput());
1067 return success();
1068 }
1069 return failure();
1070}
1071
1072OpFoldResult ToClockOp::fold(FoldAdaptor adaptor) {
1073 if (auto fromClock = getInput().getDefiningOp<FromClockOp>())
1074 return fromClock.getInput();
1075 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
1076 auto value =
1077 intAttr.getValue().isZero() ? ClockConst::Low : ClockConst::High;
1078 return ClockConstAttr::get(getContext(), value);
1079 }
1080 return {};
1081}
1082
1083LogicalResult FromClockOp::canonicalize(FromClockOp op,
1084 PatternRewriter &rewriter) {
1085 if (auto toClock = op.getInput().getDefiningOp<ToClockOp>()) {
1086 rewriter.replaceOp(op, toClock.getInput());
1087 return success();
1088 }
1089 return failure();
1090}
1091
1092OpFoldResult FromClockOp::fold(FoldAdaptor adaptor) {
1093 if (auto toClock = getInput().getDefiningOp<ToClockOp>())
1094 return toClock.getInput();
1095 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
1096 auto ty = IntegerType::get(getContext(), 1);
1097 return IntegerAttr::get(ty, clockAttr.getValue() == ClockConst::High);
1098 }
1099 return {};
1100}
1101
1102//===----------------------------------------------------------------------===//
1103// ClockInverterOp
1104//===----------------------------------------------------------------------===//
1105
1106OpFoldResult ClockInverterOp::fold(FoldAdaptor adaptor) {
1107 if (auto chainedInv = getInput().getDefiningOp<ClockInverterOp>())
1108 return chainedInv.getInput();
1109 if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
1110 auto clockIn = clockAttr.getValue() == ClockConst::High;
1111 return ClockConstAttr::get(getContext(),
1112 clockIn ? ClockConst::Low : ClockConst::High);
1113 }
1114 return {};
1115}
1116
1117//===----------------------------------------------------------------------===//
1118// FIR memory helper
1119//===----------------------------------------------------------------------===//
1120
1121FirMemory::FirMemory(hw::HWModuleGeneratedOp op) {
1122 depth = op->getAttrOfType<IntegerAttr>("depth").getInt();
1123 numReadPorts = op->getAttrOfType<IntegerAttr>("numReadPorts").getUInt();
1124 numWritePorts = op->getAttrOfType<IntegerAttr>("numWritePorts").getUInt();
1125 numReadWritePorts =
1126 op->getAttrOfType<IntegerAttr>("numReadWritePorts").getUInt();
1127 readLatency = op->getAttrOfType<IntegerAttr>("readLatency").getUInt();
1128 writeLatency = op->getAttrOfType<IntegerAttr>("writeLatency").getUInt();
1129 dataWidth = op->getAttrOfType<IntegerAttr>("width").getUInt();
1130 if (op->hasAttrOfType<IntegerAttr>("maskGran"))
1131 maskGran = op->getAttrOfType<IntegerAttr>("maskGran").getUInt();
1132 else
1133 maskGran = dataWidth;
1134 readUnderWrite = op->getAttrOfType<seq::RUWAttr>("readUnderWrite").getValue();
1135 writeUnderWrite =
1136 op->getAttrOfType<seq::WUWAttr>("writeUnderWrite").getValue();
1137 if (auto clockIDsAttr = op->getAttrOfType<ArrayAttr>("writeClockIDs"))
1138 for (auto clockID : clockIDsAttr)
1139 writeClockIDs.push_back(
1140 cast<IntegerAttr>(clockID).getValue().getZExtValue());
1141 initFilename = op->getAttrOfType<StringAttr>("initFilename").getValue();
1142 initIsBinary = op->getAttrOfType<BoolAttr>("initIsBinary").getValue();
1143 initIsInline = op->getAttrOfType<BoolAttr>("initIsInline").getValue();
1144}
1145
1146LogicalResult InitialOp::verify() {
1147 // Check outputs.
1148 auto *terminator = this->getBody().front().getTerminator();
1149 if (terminator->getOperands().size() != getNumResults())
1150 return emitError() << "result type doesn't match with the terminator";
1151 for (auto [lhs, rhs] :
1152 llvm::zip(terminator->getOperands().getTypes(), getResultTypes())) {
1153 if (cast<seq::ImmutableType>(rhs).getInnerType() != lhs)
1154 return emitError() << cast<seq::ImmutableType>(rhs).getInnerType()
1155 << " is expected but got " << lhs;
1156 }
1157
1158 auto blockArgs = this->getBody().front().getArguments();
1159
1160 if (blockArgs.size() != getNumOperands())
1161 return emitError() << "operand type doesn't match with the block arg";
1162
1163 for (auto [blockArg, operand] : llvm::zip(blockArgs, getOperands())) {
1164 if (blockArg.getType() !=
1165 cast<ImmutableType>(operand.getType()).getInnerType())
1166 return emitError()
1167 << blockArg.getType() << " is expected but got "
1168 << cast<ImmutableType>(operand.getType()).getInnerType();
1169 }
1170 return success();
1171}
1172void InitialOp::build(OpBuilder &builder, OperationState &result,
1173 TypeRange resultTypes, std::function<void()> ctor) {
1174 OpBuilder::InsertionGuard guard(builder);
1175
1176 builder.createBlock(result.addRegion());
1177 SmallVector<Type> types;
1178 for (auto t : resultTypes)
1179 types.push_back(seq::ImmutableType::get(t));
1180
1181 result.addTypes(types);
1182
1183 if (ctor)
1184 ctor();
1185}
1186
1187TypedValue<seq::ImmutableType>
1188circt::seq::createConstantInitialValue(OpBuilder builder, Location loc,
1189 mlir::IntegerAttr attr) {
1190 auto initial = seq::InitialOp::create(builder, loc, attr.getType(), [&]() {
1191 auto constant = hw::ConstantOp::create(builder, loc, attr);
1192 seq::YieldOp::create(builder, loc, ArrayRef<Value>{constant});
1193 });
1194 return cast<TypedValue<seq::ImmutableType>>(initial->getResult(0));
1195}
1196
1197mlir::TypedValue<seq::ImmutableType>
1198circt::seq::createConstantInitialValue(OpBuilder builder, Operation *op) {
1199 assert(op->getNumResults() == 1 &&
1200 op->hasTrait<mlir::OpTrait::ConstantLike>());
1201 auto initial = seq::InitialOp::create(
1202 builder, op->getLoc(), op->getResultTypes(), [&]() {
1203 auto clonedOp = builder.clone(*op);
1204 seq::YieldOp::create(builder, op->getLoc(), clonedOp->getResults());
1205 });
1206 return cast<mlir::TypedValue<seq::ImmutableType>>(initial.getResult(0));
1207}
1208
1209Value circt::seq::unwrapImmutableValue(TypedValue<seq::ImmutableType> value) {
1210 auto resultNum = cast<OpResult>(value).getResultNumber();
1211 auto initialOp = value.getDefiningOp<seq::InitialOp>();
1212 assert(initialOp);
1213 return initialOp.getBodyBlock()->getTerminator()->getOperand(resultNum);
1214}
1215
1216FailureOr<seq::InitialOp> circt::seq::mergeInitialOps(Block *block) {
1217 SmallVector<Operation *> initialOps;
1218 for (auto &op : *block)
1219 if (isa<seq::InitialOp>(op))
1220 initialOps.push_back(&op);
1221
1222 if (!mlir::computeTopologicalSorting(initialOps, {}))
1223 return block->getParentOp()->emitError() << "initial ops cannot be "
1224 << "topologically sorted";
1225
1226 // No need to merge if there is only one initial op.
1227 if (initialOps.size() <= 1)
1228 return initialOps.empty() ? seq::InitialOp()
1229 : cast<seq::InitialOp>(initialOps[0]);
1230
1231 auto initialOp = cast<seq::InitialOp>(initialOps.front());
1232 auto yieldOp = cast<seq::YieldOp>(initialOp.getBodyBlock()->getTerminator());
1233
1234 llvm::MapVector<Value, Value>
1235 resultToYieldOperand; // seq.immutable value to operand.
1236
1237 for (auto [result, operand] :
1238 llvm::zip(initialOp.getResults(), yieldOp->getOperands()))
1239 resultToYieldOperand.insert({result, operand});
1240
1241 for (size_t i = 1; i < initialOps.size(); ++i) {
1242 auto currentInitialOp = cast<seq::InitialOp>(initialOps[i]);
1243 auto operands = currentInitialOp->getOperands();
1244 for (auto [blockArg, operand] :
1245 llvm::zip(currentInitialOp.getBodyBlock()->getArguments(), operands)) {
1246 if (auto initOp = operand.getDefiningOp<seq::InitialOp>()) {
1247 assert(resultToYieldOperand.count(operand) &&
1248 "it must be visited already");
1249 blockArg.replaceAllUsesWith(resultToYieldOperand.lookup(operand));
1250 } else {
1251 // Otherwise add the operand to the current block.
1252 initialOp.getBodyBlock()->addArgument(
1253 cast<seq::ImmutableType>(operand.getType()).getInnerType(),
1254 operand.getLoc());
1255 initialOp.getInputsMutable().append(operand);
1256 }
1257 }
1258
1259 auto currentYieldOp =
1260 cast<seq::YieldOp>(currentInitialOp.getBodyBlock()->getTerminator());
1261
1262 for (auto [result, operand] : llvm::zip(currentInitialOp.getResults(),
1263 currentYieldOp->getOperands()))
1264 resultToYieldOperand.insert({result, operand});
1265
1266 // Append the operands of the current yield op to the original yield op.
1267 yieldOp.getOperandsMutable().append(currentYieldOp.getOperands());
1268 currentYieldOp->erase();
1269
1270 // Append the operations of the current initial op to the original initial
1271 // op.
1272 initialOp.getBodyBlock()->getOperations().splice(
1273 initialOp.end(), currentInitialOp.getBodyBlock()->getOperations());
1274 }
1275
1276 // Move the terminator to the end of the block.
1277 yieldOp->moveBefore(initialOp.getBodyBlock(),
1278 initialOp.getBodyBlock()->end());
1279
1280 auto builder = OpBuilder::atBlockBegin(block);
1281 SmallVector<Type> types;
1282 for (auto [result, operand] : resultToYieldOperand)
1283 types.push_back(operand.getType());
1284
1285 // Create a new initial op which accumulates the results of the merged initial
1286 // ops.
1287 auto newInitial = seq::InitialOp::create(builder, initialOp.getLoc(), types);
1288 newInitial.getInputsMutable().append(initialOp.getInputs());
1289
1290 for (auto [resultAndOperand, newResult] :
1291 llvm::zip(resultToYieldOperand, newInitial.getResults()))
1292 resultAndOperand.first.replaceAllUsesWith(newResult);
1293
1294 // Update the block arguments of the new initial op.
1295 for (auto oldBlockArg : initialOp.getBodyBlock()->getArguments()) {
1296 auto blockArg = newInitial.getBodyBlock()->addArgument(
1297 oldBlockArg.getType(), oldBlockArg.getLoc());
1298 oldBlockArg.replaceAllUsesWith(blockArg);
1299 }
1300
1301 newInitial.getBodyBlock()->getOperations().splice(
1302 newInitial.end(), initialOp.getBodyBlock()->getOperations());
1303
1304 // Clean up.
1305 while (!initialOps.empty())
1306 initialOps.pop_back_val()->erase();
1307
1308 return newInitial;
1309}
1310
1311//===----------------------------------------------------------------------===//
1312// TableGen generated logic.
1313//===----------------------------------------------------------------------===//
1314
1315// Provide the autogenerated implementation guts for the Op classes.
1316#define GET_OP_CLASSES
1317#include "circt/Dialect/Seq/Seq.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
#define isdigit(x)
Definition FIRLexer.cpp:26
static bool isConstZero(Value value)
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
static Block * getBodyBlock(FModuleLike mod)
void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition SeqOps.cpp:268
static bool isConstClock(Value value)
Definition SeqOps.cpp:962
static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType, StringRef directive)
Definition SeqOps.cpp:237
static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
Definition SeqOps.cpp:85
static bool isConstAllOnes(Value value)
Definition SeqOps.cpp:975
static ParseResult parseOptionalImmutableTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
Definition SeqOps.cpp:90
void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition SeqOps.cpp:275
LogicalResult verifyResets(TOp op)
Definition SeqOps.cpp:317
static bool canElideName(OpAsmPrinter &p, Operation *op)
Definition SeqOps.cpp:61
ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition SeqOps.cpp:262
static LogicalResult verifyFirMemMask(Op op)
Definition SeqOps.cpp:944
static void printOptionalImmutableTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
Definition SeqOps.cpp:98
static ParseResult parseOptionalTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
Definition SeqOps.cpp:77
static void setNameFromResult(OpAsmParser &parser, OperationState &result)
Definition SeqOps.cpp:50
ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition SeqOps.cpp:256
static InstancePath empty
create(elements, Type result_type=None)
Definition hw.py:483
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
FailureOr< seq::InitialOp > mergeInitialOps(Block *block)
Definition SeqOps.cpp:1216
bool isValidIndexValues(Value hlmemHandle, ValueRange addresses)
Definition SeqOps.cpp:32
mlir::TypedValue< seq::ImmutableType > createConstantInitialValue(OpBuilder builder, Location loc, mlir::IntegerAttr attr)
Definition SeqOps.cpp:1188
Value unwrapImmutableValue(mlir::TypedValue< seq::ImmutableType > immutableVal)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition seq.py:1
FirMemory(hw::HWModuleGeneratedOp op)
Definition SeqOps.cpp:1121