CIRCT  19.0.0git
SeqOps.cpp
Go to the documentation of this file.
1 //===- SeqOps.cpp - Implement the Seq operations ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements sequential ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
24 
25 using namespace mlir;
26 using namespace circt;
27 using namespace seq;
28 
29 bool circt::seq::isValidIndexValues(Value hlmemHandle, ValueRange addresses) {
30  auto memType = hlmemHandle.getType().cast<seq::HLMemType>();
31  auto shape = memType.getShape();
32  if (shape.size() != addresses.size())
33  return false;
34 
35  for (auto [dim, addr] : llvm::zip(shape, addresses)) {
36  auto addrType = addr.getType().dyn_cast<IntegerType>();
37  if (!addrType)
38  return false;
39  if (addrType.getIntOrFloatBitWidth() != llvm::Log2_64_Ceil(dim))
40  return false;
41  }
42  return true;
43 }
44 
45 // If there was no name specified, check to see if there was a useful name
46 // specified in the asm file.
47 static void setNameFromResult(OpAsmParser &parser, OperationState &result) {
48  if (result.attributes.getNamed("name"))
49  return;
50  // If there is no explicit name attribute, get it from the SSA result name.
51  // If numeric, just use an empty name.
52  StringRef resultName = parser.getResultName(0).first;
53  if (!resultName.empty() && isdigit(resultName[0]))
54  resultName = "";
55  result.addAttribute("name", parser.getBuilder().getStringAttr(resultName));
56 }
57 
58 static bool canElideName(OpAsmPrinter &p, Operation *op) {
59  if (!op->hasAttr("name"))
60  return true;
61 
62  auto name = op->getAttrOfType<StringAttr>("name").getValue();
63  if (name.empty())
64  return true;
65 
66  SmallString<32> resultNameStr;
67  llvm::raw_svector_ostream tmpStream(resultNameStr);
68  p.printOperand(op->getResult(0), tmpStream);
69  auto actualName = tmpStream.str().drop_front();
70  return actualName == name;
71 }
72 
73 static ParseResult
74 parseOptionalTypeMatch(OpAsmParser &parser, Type refType,
75  std::optional<OpAsmParser::UnresolvedOperand> operand,
76  Type &type) {
77  if (operand)
78  type = refType;
79  return success();
80 }
81 
82 static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType,
83  Value operand, Type type) {
84  // Nothing to do - this is strictly an implicit parsing helper.
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // ReadPortOp
89 //===----------------------------------------------------------------------===//
90 
91 ParseResult ReadPortOp::parse(OpAsmParser &parser, OperationState &result) {
92  llvm::SMLoc loc = parser.getCurrentLocation();
93 
94  OpAsmParser::UnresolvedOperand memOperand, rdenOperand;
95  bool hasRdEn = false;
96  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
97  seq::HLMemType memType;
98 
99  if (parser.parseOperand(memOperand) ||
100  parser.parseOperandList(addressOperands, OpAsmParser::Delimiter::Square))
101  return failure();
102 
103  if (succeeded(parser.parseOptionalKeyword("rden"))) {
104  if (failed(parser.parseOperand(rdenOperand)))
105  return failure();
106  hasRdEn = true;
107  }
108 
109  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
110  parser.parseType(memType))
111  return failure();
112 
113  llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
114  operandTypes.insert(operandTypes.begin(), memType);
115 
116  llvm::SmallVector<OpAsmParser::UnresolvedOperand> allOperands = {memOperand};
117  llvm::copy(addressOperands, std::back_inserter(allOperands));
118  if (hasRdEn) {
119  operandTypes.push_back(parser.getBuilder().getI1Type());
120  allOperands.push_back(rdenOperand);
121  }
122 
123  if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
124  return failure();
125 
126  result.addTypes(memType.getElementType());
127 
128  llvm::SmallVector<int32_t, 2> operandSizes;
129  operandSizes.push_back(1); // memory handle
130  operandSizes.push_back(addressOperands.size());
131  operandSizes.push_back(hasRdEn ? 1 : 0);
132  result.addAttribute("operandSegmentSizes",
133  parser.getBuilder().getDenseI32ArrayAttr(operandSizes));
134  return success();
135 }
136 
137 void ReadPortOp::print(OpAsmPrinter &p) {
138  p << " " << getMemory() << "[" << getAddresses() << "]";
139  if (getRdEn())
140  p << " rden " << getRdEn();
141  p.printOptionalAttrDict((*this)->getAttrs(), {"operandSegmentSizes"});
142  p << " : " << getMemory().getType();
143 }
144 
146  auto memName = getMemory().getDefiningOp<seq::HLMemOp>().getName();
147  setNameFn(getReadData(), (memName + "_rdata").str());
148 }
149 
150 void ReadPortOp::build(OpBuilder &builder, OperationState &result, Value memory,
151  ValueRange addresses, Value rdEn, unsigned latency) {
152  auto memType = memory.getType().cast<seq::HLMemType>();
153  ReadPortOp::build(builder, result, memType.getElementType(), memory,
154  addresses, rdEn, latency);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // WritePortOp
159 //===----------------------------------------------------------------------===//
160 
161 ParseResult WritePortOp::parse(OpAsmParser &parser, OperationState &result) {
162  llvm::SMLoc loc = parser.getCurrentLocation();
163  OpAsmParser::UnresolvedOperand memOperand, dataOperand, wrenOperand;
164  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> addressOperands;
165  seq::HLMemType memType;
166 
167  if (parser.parseOperand(memOperand) ||
168  parser.parseOperandList(addressOperands,
169  OpAsmParser::Delimiter::Square) ||
170  parser.parseOperand(dataOperand) || parser.parseKeyword("wren") ||
171  parser.parseOperand(wrenOperand) ||
172  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
173  parser.parseType(memType))
174  return failure();
175 
176  llvm::SmallVector<Type> operandTypes = memType.getAddressTypes();
177  operandTypes.insert(operandTypes.begin(), memType);
178  operandTypes.push_back(memType.getElementType());
179  operandTypes.push_back(parser.getBuilder().getI1Type());
180 
181  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 2> allOperands(
182  addressOperands);
183  allOperands.insert(allOperands.begin(), memOperand);
184  allOperands.push_back(dataOperand);
185  allOperands.push_back(wrenOperand);
186 
187  if (parser.resolveOperands(allOperands, operandTypes, loc, result.operands))
188  return failure();
189 
190  return success();
191 }
192 
193 void WritePortOp::print(OpAsmPrinter &p) {
194  p << " " << getMemory() << "[" << getAddresses() << "] " << getInData()
195  << " wren " << getWrEn();
196  p.printOptionalAttrDict((*this)->getAttrs());
197  p << " : " << getMemory().getType();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // HLMemOp
202 //===----------------------------------------------------------------------===//
203 
205  setNameFn(getHandle(), getName());
206 }
207 
208 void HLMemOp::build(OpBuilder &builder, OperationState &result, Value clk,
209  Value rst, StringRef name, llvm::ArrayRef<int64_t> shape,
210  Type elementType) {
211  HLMemType t = HLMemType::get(builder.getContext(), shape, elementType);
212  HLMemOp::build(builder, result, t, clk, rst, name);
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // FIFOOp
217 //===----------------------------------------------------------------------===//
218 
219 // Flag threshold custom directive
220 static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser,
221  IntegerAttr &threshold,
222  Type &outputFlagType,
223  StringRef directive) {
224  // look for an optional "almost_full $threshold" group.
225  if (succeeded(parser.parseOptionalKeyword(directive))) {
226  int64_t thresholdValue;
227  if (succeeded(parser.parseInteger(thresholdValue))) {
228  threshold = parser.getBuilder().getI64IntegerAttr(thresholdValue);
229  outputFlagType = parser.getBuilder().getI1Type();
230  return success();
231  }
232  return parser.emitError(parser.getNameLoc(),
233  "expected integer value after " + directive +
234  " directive");
235  }
236  return success();
237 }
238 
239 ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold,
240  Type &outputFlagType) {
241  return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
242  "almost_full");
243 }
244 
245 ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold,
246  Type &outputFlagType) {
247  return parseFIFOFlagThreshold(parser, threshold, outputFlagType,
248  "almost_empty");
249 }
250 
251 void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
252  Type outputFlagType) {
253  if (threshold) {
254  p << "almost_full"
255  << " " << threshold.getInt();
256  }
257 }
258 
259 void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold,
260  Type outputFlagType) {
261  if (threshold) {
262  p << "almost_empty"
263  << " " << threshold.getInt();
264  }
265 }
266 
268  setNameFn(getOutput(), "out");
269  setNameFn(getEmpty(), "empty");
270  setNameFn(getFull(), "full");
271  if (auto ae = getAlmostEmpty())
272  setNameFn(ae, "almostEmpty");
273  if (auto af = getAlmostFull())
274  setNameFn(af, "almostFull");
275 }
276 
277 LogicalResult FIFOOp::verify() {
278  auto aet = getAlmostEmptyThreshold();
279  auto aft = getAlmostFullThreshold();
280  size_t depth = getDepth();
281  if (aft.has_value() && aft.value() > depth)
282  return emitOpError("almost full threshold must be <= FIFO depth");
283 
284  if (aet.has_value() && aet.value() > depth)
285  return emitOpError("almost empty threshold must be <= FIFO depth");
286 
287  return success();
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // CompRegOp
292 //===----------------------------------------------------------------------===//
293 
294 /// Suggest a name for each result value based on the saved result names
295 /// attribute.
297  // If the wire has an optional 'name' attribute, use it.
298  if (auto name = getName())
299  setNameFn(getResult(), *name);
300 }
301 
302 LogicalResult CompRegOp::verify() {
303  if ((getReset() == nullptr) ^ (getResetValue() == nullptr))
304  return emitOpError(
305  "either reset and resetValue or neither must be specified");
306  return success();
307 }
308 
309 std::optional<size_t> CompRegOp::getTargetResultIndex() { return 0; }
310 
311 template <typename TOp>
312 LogicalResult verifyResets(TOp op) {
313  if ((op.getReset() == nullptr) ^ (op.getResetValue() == nullptr))
314  return op->emitOpError(
315  "either reset and resetValue or neither must be specified");
316  bool hasReset = op.getReset() != nullptr;
317  if (hasReset && op.getResetValue().getType() != op.getInput().getType())
318  return op->emitOpError("reset value must be the same type as the input");
319 
320  return success();
321 }
322 
323 /// Suggest a name for each result value based on the saved result names
324 /// attribute.
326  // If the wire has an optional 'name' attribute, use it.
327  if (auto name = getName())
328  setNameFn(getResult(), *name);
329 }
330 
331 std::optional<size_t> CompRegClockEnabledOp::getTargetResultIndex() {
332  return 0;
333 }
334 
335 LogicalResult CompRegClockEnabledOp::verify() {
336  if (failed(verifyResets(*this)))
337  return failure();
338  return success();
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // ShiftRegOp
343 //===----------------------------------------------------------------------===//
344 
346  // If the wire has an optional 'name' attribute, use it.
347  if (auto name = getName())
348  setNameFn(getResult(), *name);
349 }
350 
351 std::optional<size_t> ShiftRegOp::getTargetResultIndex() { return 0; }
352 
353 LogicalResult ShiftRegOp::verify() {
354  if (failed(verifyResets(*this)))
355  return failure();
356  return success();
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // FirRegOp
361 //===----------------------------------------------------------------------===//
362 
363 void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
364  Value clk, StringAttr name, hw::InnerSymAttr innerSym) {
365 
366  OpBuilder::InsertionGuard guard(builder);
367 
368  result.addOperands(input);
369  result.addOperands(clk);
370 
371  result.addAttribute(getNameAttrName(result.name), name);
372 
373  if (innerSym)
374  result.addAttribute(getInnerSymAttrName(result.name), innerSym);
375 
376  result.addTypes(input.getType());
377 }
378 
379 void FirRegOp::build(OpBuilder &builder, OperationState &result, Value input,
380  Value clk, StringAttr name, Value reset, Value resetValue,
381  hw::InnerSymAttr innerSym, bool isAsync) {
382 
383  OpBuilder::InsertionGuard guard(builder);
384 
385  result.addOperands(input);
386  result.addOperands(clk);
387  result.addOperands(reset);
388  result.addOperands(resetValue);
389 
390  result.addAttribute(getNameAttrName(result.name), name);
391  if (isAsync)
392  result.addAttribute(getIsAsyncAttrName(result.name), builder.getUnitAttr());
393 
394  if (innerSym)
395  result.addAttribute(getInnerSymAttrName(result.name), innerSym);
396 
397  result.addTypes(input.getType());
398 }
399 
400 ParseResult FirRegOp::parse(OpAsmParser &parser, OperationState &result) {
401  auto &builder = parser.getBuilder();
402  llvm::SMLoc loc = parser.getCurrentLocation();
403 
404  using Op = OpAsmParser::UnresolvedOperand;
405 
406  Op next, clk;
407  if (parser.parseOperand(next) || parser.parseKeyword("clock") ||
408  parser.parseOperand(clk))
409  return failure();
410 
411  if (succeeded(parser.parseOptionalKeyword("sym"))) {
412  hw::InnerSymAttr innerSym;
413  if (parser.parseCustomAttributeWithFallback(innerSym, /*type=*/nullptr,
414  "inner_sym", result.attributes))
415  return failure();
416  }
417 
418  // Parse reset [sync|async] %reset, %value
419  std::optional<std::pair<Op, Op>> resetAndValue;
420  if (succeeded(parser.parseOptionalKeyword("reset"))) {
421  bool isAsync;
422  if (succeeded(parser.parseOptionalKeyword("async")))
423  isAsync = true;
424  else if (succeeded(parser.parseOptionalKeyword("sync")))
425  isAsync = false;
426  else
427  return parser.emitError(loc, "invalid reset, expected 'sync' or 'async'");
428  if (isAsync)
429  result.attributes.append("isAsync", builder.getUnitAttr());
430 
431  resetAndValue = {{}, {}};
432  if (parser.parseOperand(resetAndValue->first) || parser.parseComma() ||
433  parser.parseOperand(resetAndValue->second))
434  return failure();
435  }
436 
437  std::optional<APInt> presetValue;
438  llvm::SMLoc presetValueLoc;
439  if (succeeded(parser.parseOptionalKeyword("preset"))) {
440  presetValueLoc = parser.getCurrentLocation();
441  OptionalParseResult presetIntResult =
442  parser.parseOptionalInteger(presetValue.emplace());
443  if (!presetIntResult.has_value() || failed(*presetIntResult))
444  return parser.emitError(loc, "expected integer value");
445  }
446 
447  Type ty;
448  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
449  parser.parseType(ty))
450  return failure();
451  result.addTypes({ty});
452 
453  if (presetValue) {
454  uint64_t width = 0;
455  if (hw::type_isa<seq::ClockType>(ty)) {
456  width = 1;
457  } else {
458  int64_t maybeWidth = hw::getBitWidth(ty);
459  if (maybeWidth < 0)
460  return parser.emitError(presetValueLoc,
461  "cannot preset register of unknown width");
462  width = maybeWidth;
463  }
464 
465  APInt presetResult = presetValue->sextOrTrunc(width);
466  if (presetResult.zextOrTrunc(presetValue->getBitWidth()) != *presetValue)
467  return parser.emitError(loc, "preset value too large");
468 
469  auto builder = parser.getBuilder();
470  auto presetTy = builder.getIntegerType(width);
471  auto resultAttr = builder.getIntegerAttr(presetTy, presetResult);
472  result.addAttribute("preset", resultAttr);
473  }
474 
475  setNameFromResult(parser, result);
476 
477  if (parser.resolveOperand(next, ty, result.operands))
478  return failure();
479 
480  Type clkTy = ClockType::get(result.getContext());
481  if (parser.resolveOperand(clk, clkTy, result.operands))
482  return failure();
483 
484  if (resetAndValue) {
485  Type i1 = IntegerType::get(result.getContext(), 1);
486  if (parser.resolveOperand(resetAndValue->first, i1, result.operands) ||
487  parser.resolveOperand(resetAndValue->second, ty, result.operands))
488  return failure();
489  }
490 
491  return success();
492 }
493 
494 void FirRegOp::print(::mlir::OpAsmPrinter &p) {
495  SmallVector<StringRef> elidedAttrs = {
496  getInnerSymAttrName(), getIsAsyncAttrName(), getPresetAttrName()};
497 
498  p << ' ' << getNext() << " clock " << getClk();
499 
500  if (auto sym = getInnerSymAttr()) {
501  p << " sym ";
502  sym.print(p);
503  }
504 
505  if (hasReset()) {
506  p << " reset " << (getIsAsync() ? "async" : "sync") << ' ';
507  p << getReset() << ", " << getResetValue();
508  }
509 
510  if (auto preset = getPresetAttr()) {
511  p << " preset " << preset.getValue();
512  }
513 
514  if (canElideName(p, *this))
515  elidedAttrs.push_back("name");
516 
517  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
518  p << " : " << getNext().getType();
519 }
520 
521 /// Verifier for the FIR register op.
522 LogicalResult FirRegOp::verify() {
523  if (getReset() || getResetValue() || getIsAsync()) {
524  if (!getReset() || !getResetValue())
525  return emitOpError("must specify reset and reset value");
526  } else {
527  if (getIsAsync())
528  return emitOpError("register with no reset cannot be async");
529  }
530  if (auto preset = getPresetAttr()) {
531  int64_t presetWidth = hw::getBitWidth(preset.getType());
532  int64_t width = hw::getBitWidth(getType());
533  if (preset.getType() != getType() && presetWidth != width)
534  return emitOpError("preset type width must match register type");
535  }
536  return success();
537 }
538 
539 /// Suggest a name for each result value based on the saved result names
540 /// attribute.
542  // If the register has an optional 'name' attribute, use it.
543  if (!getName().empty())
544  setNameFn(getResult(), getName());
545 }
546 
547 std::optional<size_t> FirRegOp::getTargetResultIndex() { return 0; }
548 
549 LogicalResult FirRegOp::canonicalize(FirRegOp op, PatternRewriter &rewriter) {
550  // If the register has a constant zero reset, drop the reset and reset value
551  // altogether.
552  if (auto reset = op.getReset()) {
553  if (auto constOp = reset.getDefiningOp<hw::ConstantOp>()) {
554  if (constOp.getValue().isZero()) {
555  rewriter.replaceOpWithNewOp<FirRegOp>(op, op.getNext(), op.getClk(),
556  op.getNameAttr(),
557  op.getInnerSymAttr());
558  return success();
559  }
560  }
561  }
562 
563  // If the register has a symbol, we can't optimize it away.
564  if (op.getInnerSymAttr())
565  return failure();
566 
567  // Replace a register with a trivial feedback or constant clock with a
568  // constant zero.
569  // TODO: Once HW aggregate constant values are supported, move this
570  // canonicalization to the folder.
571  auto isConstant = [&]() -> bool {
572  if (op.getNext() == op.getResult())
573  return true;
574  if (auto clk = op.getClk().getDefiningOp<seq::ToClockOp>())
575  return clk.getInput().getDefiningOp<hw::ConstantOp>();
576  return false;
577  };
578 
579  if (isConstant()) {
580  if (auto resetValue = op.getResetValue()) {
581  // If the register has a reset value, we can replace it with that.
582  rewriter.replaceOp(op, resetValue);
583  } else {
584  if (op.getType().isa<seq::ClockType>()) {
585  rewriter.replaceOpWithNewOp<seq::ConstClockOp>(
586  op,
587  seq::ClockConstAttr::get(rewriter.getContext(), ClockConst::Low));
588  } else {
589  auto constant = rewriter.create<hw::ConstantOp>(
590  op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
591  rewriter.replaceOpWithNewOp<hw::BitcastOp>(op, op.getType(), constant);
592  }
593  }
594  return success();
595  }
596 
597  // For reset-less 1d array registers, replace an uninitialized element with
598  // constant zero. For example, let `r` be a 2xi1 register and its next value
599  // be `{foo, r[0]}`. `r[0]` is connected to itself so will never be
600  // initialized. If we don't enable aggregate preservation, `r_0` is replaced
601  // with `0`. Hence this canonicalization replaces 0th element of the next
602  // value with zero to match the behaviour.
603  if (!op.getReset()) {
604  if (auto arrayCreate = op.getNext().getDefiningOp<hw::ArrayCreateOp>()) {
605  // For now only support 1d arrays.
606  // TODO: Support nested arrays and bundles.
607  if (hw::type_cast<hw::ArrayType>(op.getResult().getType())
608  .getElementType()
609  .isa<IntegerType>()) {
610  SmallVector<Value> nextOperands;
611  bool changed = false;
612  for (const auto &[i, value] :
613  llvm::enumerate(arrayCreate.getOperands())) {
614  auto index = arrayCreate.getOperands().size() - i - 1;
615  APInt elementIndex;
616  // Check that the corresponding operand is op's element.
617  if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
618  if (arrayGet.getInput() == op.getResult() &&
619  matchPattern(arrayGet.getIndex(),
620  m_ConstantInt(&elementIndex)) &&
621  elementIndex == index) {
622  nextOperands.push_back(rewriter.create<hw::ConstantOp>(
623  op.getLoc(),
624  APInt::getZero(hw::getBitWidth(arrayGet.getType()))));
625  changed = true;
626  continue;
627  }
628  nextOperands.push_back(value);
629  }
630  // If one of the operands is self loop, update the next value.
631  if (changed) {
632  auto newNextVal = rewriter.create<hw::ArrayCreateOp>(
633  arrayCreate.getLoc(), nextOperands);
634  if (arrayCreate->hasOneUse())
635  // If the original next value has a single use, we can replace the
636  // value directly.
637  rewriter.replaceOp(arrayCreate, newNextVal);
638  else {
639  // Otherwise, replace the entire firreg with a new one.
640  rewriter.replaceOpWithNewOp<FirRegOp>(op, newNextVal, op.getClk(),
641  op.getNameAttr(),
642  op.getInnerSymAttr());
643  }
644 
645  return success();
646  }
647  }
648  }
649  }
650 
651  return failure();
652 }
653 
654 OpFoldResult FirRegOp::fold(FoldAdaptor adaptor) {
655  // If the register has a symbol, we can't optimize it away.
656  if (getInnerSymAttr())
657  return {};
658 
659  // If the register is held in permanent reset, replace it with its reset
660  // value. This works trivially if the reset is asynchronous and therefore
661  // level-sensitive, in which case it will always immediately assume the reset
662  // value in silicon. If it is synchronous, the register value is undefined
663  // until the first clock edge at which point it becomes the reset value, in
664  // which case we simply define the initial value to already be the reset
665  // value.
666  if (auto reset = getReset())
667  if (auto constOp = reset.getDefiningOp<hw::ConstantOp>())
668  if (constOp.getValue().isOne())
669  return getResetValue();
670 
671  // If the register's next value is trivially it's current value, or the
672  // register is never clocked, we can replace the register with a constant
673  // value.
674  bool isTrivialFeedback = (getNext() == getResult());
675  bool isNeverClocked =
676  adaptor.getClk() != nullptr; // clock operand is constant
677  if (!isTrivialFeedback && !isNeverClocked)
678  return {};
679 
680  // If the register has a reset value, we can replace it with that.
681  if (auto resetValue = getResetValue())
682  return resetValue;
683 
684  // Otherwise we want to replace the register with a constant 0. For now this
685  // only works with integer types.
686  auto intType = getType().dyn_cast<IntegerType>();
687  if (!intType)
688  return {};
689  return IntegerAttr::get(intType, 0);
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // ClockGateOp
694 //===----------------------------------------------------------------------===//
695 
696 OpFoldResult ClockGateOp::fold(FoldAdaptor adaptor) {
697  // Forward the clock if one of the enables is always true.
698  if (isConstantOne(adaptor.getEnable()) ||
699  isConstantOne(adaptor.getTestEnable()))
700  return getInput();
701 
702  // Fold to a constant zero clock if the enables are always false.
703  if (isConstantZero(adaptor.getEnable()) &&
704  (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
705  return ClockConstAttr::get(getContext(), ClockConst::Low);
706 
707  // Forward constant zero clocks.
708  if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput()))
709  if (clockAttr.getValue() == ClockConst::Low)
710  return ClockConstAttr::get(getContext(), ClockConst::Low);
711 
712  // Transitive clock gating - eliminate clock gates that are driven by an
713  // identical enable signal somewhere higher in the clock gate hierarchy.
714  auto clockGateInputOp = getInput().getDefiningOp<ClockGateOp>();
715  while (clockGateInputOp) {
716  if (clockGateInputOp.getEnable() == getEnable() &&
717  clockGateInputOp.getTestEnable() == getTestEnable())
718  return getInput();
719  clockGateInputOp = clockGateInputOp.getInput().getDefiningOp<ClockGateOp>();
720  }
721 
722  return {};
723 }
724 
725 LogicalResult ClockGateOp::canonicalize(ClockGateOp op,
726  PatternRewriter &rewriter) {
727  // Remove constant false test enable.
728  if (auto testEnable = op.getTestEnable()) {
729  if (auto constOp = testEnable.getDefiningOp<hw::ConstantOp>()) {
730  if (constOp.getValue().isZero()) {
731  rewriter.modifyOpInPlace(op,
732  [&] { op.getTestEnableMutable().clear(); });
733  return success();
734  }
735  }
736  }
737 
738  return failure();
739 }
740 
741 std::optional<size_t> ClockGateOp::getTargetResultIndex() {
742  return std::nullopt;
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // ClockMuxOp
747 //===----------------------------------------------------------------------===//
748 
749 OpFoldResult ClockMuxOp::fold(FoldAdaptor adaptor) {
750  if (isConstantOne(adaptor.getCond()))
751  return getTrueClock();
752  if (isConstantZero(adaptor.getCond()))
753  return getFalseClock();
754  return {};
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // FirMemOp
759 //===----------------------------------------------------------------------===//
760 
761 LogicalResult FirMemOp::canonicalize(FirMemOp op, PatternRewriter &rewriter) {
762  // Do not change memories if symbols point to them.
763  if (op.getInnerSymAttr())
764  return failure();
765 
766  // If the memory has no read ports, erase it.
767  for (auto *user : op->getUsers()) {
768  if (isa<FirMemReadOp, FirMemReadWriteOp>(user))
769  return failure();
770  assert(isa<FirMemWriteOp>(user) && "invalid seq.firmem user");
771  }
772 
773  for (auto *user : llvm::make_early_inc_range(op->getUsers()))
774  rewriter.eraseOp(user);
775 
776  rewriter.eraseOp(op);
777  return success();
778 }
779 
781  auto nameAttr = (*this)->getAttrOfType<StringAttr>("name");
782  if (!nameAttr.getValue().empty())
783  setNameFn(getResult(), nameAttr.getValue());
784 }
785 
786 std::optional<size_t> FirMemOp::getTargetResultIndex() { return 0; }
787 
788 template <class Op>
789 static LogicalResult verifyFirMemMask(Op op) {
790  if (auto mask = op.getMask()) {
791  auto memType = op.getMemory().getType();
792  if (!memType.getMaskWidth())
793  return op.emitOpError("has mask operand but memory type '")
794  << memType << "' has no mask";
795  auto expected = IntegerType::get(op.getContext(), *memType.getMaskWidth());
796  if (mask.getType() != expected)
797  return op.emitOpError("has mask operand of type '")
798  << mask.getType() << "', but memory type requires '" << expected
799  << "'";
800  }
801  return success();
802 }
803 
804 LogicalResult FirMemWriteOp::verify() { return verifyFirMemMask(*this); }
805 LogicalResult FirMemReadWriteOp::verify() { return verifyFirMemMask(*this); }
806 
807 static bool isConstClock(Value value) {
808  if (!value)
809  return false;
810  return value.getDefiningOp<seq::ConstClockOp>();
811 }
812 
813 static bool isConstZero(Value value) {
814  if (value)
815  if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
816  return constOp.getValue().isZero();
817  return false;
818 }
819 
820 static bool isConstAllOnes(Value value) {
821  if (value)
822  if (auto constOp = value.getDefiningOp<hw::ConstantOp>())
823  return constOp.getValue().isAllOnes();
824  return false;
825 }
826 
827 LogicalResult FirMemReadOp::canonicalize(FirMemReadOp op,
828  PatternRewriter &rewriter) {
829  // Remove the enable if it is constant true.
830  if (isConstAllOnes(op.getEnable())) {
831  rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
832  return success();
833  }
834  return failure();
835 }
836 
837 LogicalResult FirMemWriteOp::canonicalize(FirMemWriteOp op,
838  PatternRewriter &rewriter) {
839  // Remove the write port if it is trivially dead.
840  if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
841  isConstClock(op.getClk())) {
842  rewriter.eraseOp(op);
843  return success();
844  }
845  bool anyChanges = false;
846 
847  // Remove the enable if it is constant true.
848  if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
849  rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
850  anyChanges = true;
851  }
852 
853  // Remove the mask if it is all ones.
854  if (auto mask = op.getMask(); isConstAllOnes(mask)) {
855  rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
856  anyChanges = true;
857  }
858 
859  return success(anyChanges);
860 }
861 
862 LogicalResult FirMemReadWriteOp::canonicalize(FirMemReadWriteOp op,
863  PatternRewriter &rewriter) {
864  // Replace the read-write port with a read port if the write behavior is
865  // trivially disabled.
866  if (isConstZero(op.getEnable()) || isConstZero(op.getMask()) ||
867  isConstClock(op.getClk()) || isConstZero(op.getMode())) {
868  auto opAttrs = op->getAttrs();
869  auto opAttrNames = op.getAttributeNames();
870  auto newOp = rewriter.replaceOpWithNewOp<FirMemReadOp>(
871  op, op.getMemory(), op.getAddress(), op.getClk(), op.getEnable());
872  for (auto namedAttr : opAttrs)
873  if (!llvm::is_contained(opAttrNames, namedAttr.getName()))
874  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
875  return success();
876  }
877  bool anyChanges = false;
878 
879  // Remove the enable if it is constant true.
880  if (auto enable = op.getEnable(); isConstAllOnes(enable)) {
881  rewriter.modifyOpInPlace(op, [&] { op.getEnableMutable().erase(0); });
882  anyChanges = true;
883  }
884 
885  // Remove the mask if it is all ones.
886  if (auto mask = op.getMask(); isConstAllOnes(mask)) {
887  rewriter.modifyOpInPlace(op, [&] { op.getMaskMutable().erase(0); });
888  anyChanges = true;
889  }
890 
891  return success(anyChanges);
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // ConstClockOp
896 //===----------------------------------------------------------------------===//
897 
898 OpFoldResult ConstClockOp::fold(FoldAdaptor adaptor) {
899  return ClockConstAttr::get(getContext(), getValue());
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // ToClockOp/FromClockOp
904 //===----------------------------------------------------------------------===//
905 
906 LogicalResult ToClockOp::canonicalize(ToClockOp op, PatternRewriter &rewriter) {
907  if (auto fromClock = op.getInput().getDefiningOp<FromClockOp>()) {
908  rewriter.replaceOp(op, fromClock.getInput());
909  return success();
910  }
911  return failure();
912 }
913 
914 OpFoldResult ToClockOp::fold(FoldAdaptor adaptor) {
915  if (auto fromClock = getInput().getDefiningOp<FromClockOp>())
916  return fromClock.getInput();
917  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
918  auto value =
919  intAttr.getValue().isZero() ? ClockConst::Low : ClockConst::High;
920  return ClockConstAttr::get(getContext(), value);
921  }
922  return {};
923 }
924 
925 LogicalResult FromClockOp::canonicalize(FromClockOp op,
926  PatternRewriter &rewriter) {
927  if (auto toClock = op.getInput().getDefiningOp<ToClockOp>()) {
928  rewriter.replaceOp(op, toClock.getInput());
929  return success();
930  }
931  return failure();
932 }
933 
934 OpFoldResult FromClockOp::fold(FoldAdaptor adaptor) {
935  if (auto toClock = getInput().getDefiningOp<ToClockOp>())
936  return toClock.getInput();
937  if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
938  auto ty = IntegerType::get(getContext(), 1);
939  return IntegerAttr::get(ty, clockAttr.getValue() == ClockConst::High);
940  }
941  return {};
942 }
943 
944 //===----------------------------------------------------------------------===//
945 // ClockInverterOp
946 //===----------------------------------------------------------------------===//
947 
948 OpFoldResult ClockInverterOp::fold(FoldAdaptor adaptor) {
949  if (auto chainedInv = getInput().getDefiningOp<ClockInverterOp>())
950  return chainedInv.getInput();
951  if (auto clockAttr = dyn_cast_or_null<ClockConstAttr>(adaptor.getInput())) {
952  auto clockIn = clockAttr.getValue() == ClockConst::High;
953  return ClockConstAttr::get(getContext(),
954  clockIn ? ClockConst::Low : ClockConst::High);
955  }
956  return {};
957 }
958 
959 //===----------------------------------------------------------------------===//
960 // FIR memory helper
961 //===----------------------------------------------------------------------===//
962 
963 FirMemory::FirMemory(hw::HWModuleGeneratedOp op) {
964  depth = op->getAttrOfType<IntegerAttr>("depth").getInt();
965  numReadPorts = op->getAttrOfType<IntegerAttr>("numReadPorts").getUInt();
966  numWritePorts = op->getAttrOfType<IntegerAttr>("numWritePorts").getUInt();
967  numReadWritePorts =
968  op->getAttrOfType<IntegerAttr>("numReadWritePorts").getUInt();
969  readLatency = op->getAttrOfType<IntegerAttr>("readLatency").getUInt();
970  writeLatency = op->getAttrOfType<IntegerAttr>("writeLatency").getUInt();
971  dataWidth = op->getAttrOfType<IntegerAttr>("width").getUInt();
972  if (op->hasAttrOfType<IntegerAttr>("maskGran"))
973  maskGran = op->getAttrOfType<IntegerAttr>("maskGran").getUInt();
974  else
975  maskGran = dataWidth;
976  readUnderWrite = op->getAttrOfType<seq::RUWAttr>("readUnderWrite").getValue();
977  writeUnderWrite =
978  op->getAttrOfType<seq::WUWAttr>("writeUnderWrite").getValue();
979  if (auto clockIDsAttr = op->getAttrOfType<ArrayAttr>("writeClockIDs"))
980  for (auto clockID : clockIDsAttr)
981  writeClockIDs.push_back(
982  clockID.cast<IntegerAttr>().getValue().getZExtValue());
983  initFilename = op->getAttrOfType<StringAttr>("initFilename").getValue();
984  initIsBinary = op->getAttrOfType<BoolAttr>("initIsBinary").getValue();
985  initIsInline = op->getAttrOfType<BoolAttr>("initIsInline").getValue();
986 }
987 
988 //===----------------------------------------------------------------------===//
989 // TableGen generated logic.
990 //===----------------------------------------------------------------------===//
991 
992 // Provide the autogenerated implementation guts for the Op classes.
993 #define GET_OP_CLASSES
994 #include "circt/Dialect/Seq/Seq.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
#define isdigit(x)
Definition: FIRLexer.cpp:26
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value for the sake of constant folding.
int32_t width
Definition: FIRRTL.cpp:36
static InstancePath empty
Builder builder
static std::optional< APInt > getInt(Value value)
Helper to convert a value to a constant integer if it is one.
void printFIFOAFThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition: SeqOps.cpp:251
static bool isConstClock(Value value)
Definition: SeqOps.cpp:807
static ParseResult parseFIFOFlagThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType, StringRef directive)
Definition: SeqOps.cpp:220
static void printOptionalTypeMatch(OpAsmPrinter &p, Operation *op, Type refType, Value operand, Type type)
Definition: SeqOps.cpp:82
static bool isConstAllOnes(Value value)
Definition: SeqOps.cpp:820
void printFIFOAEThreshold(OpAsmPrinter &p, Operation *op, IntegerAttr threshold, Type outputFlagType)
Definition: SeqOps.cpp:259
LogicalResult verifyResets(TOp op)
Definition: SeqOps.cpp:312
static bool canElideName(OpAsmPrinter &p, Operation *op)
Definition: SeqOps.cpp:58
ParseResult parseFIFOAEThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition: SeqOps.cpp:245
static bool isConstZero(Value value)
Definition: SeqOps.cpp:813
static LogicalResult verifyFirMemMask(Op op)
Definition: SeqOps.cpp:789
static ParseResult parseOptionalTypeMatch(OpAsmParser &parser, Type refType, std::optional< OpAsmParser::UnresolvedOperand > operand, Type &type)
Definition: SeqOps.cpp:74
static void setNameFromResult(OpAsmParser &parser, OperationState &result)
Definition: SeqOps.cpp:47
ParseResult parseFIFOAFThreshold(OpAsmParser &parser, IntegerAttr &threshold, Type &outputFlagType)
Definition: SeqOps.cpp:239
def create(data_type, value)
Definition: hw.py:393
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4494
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
bool isValidIndexValues(Value hlmemHandle, ValueRange addresses)
Definition: SeqOps.cpp:29
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition: FoldUtils.h:27
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186
Definition: seq.py:1