CIRCT  19.0.0git
HWAttributes.cpp
Go to the documentation of this file.
1 //===- HWAttributes.cpp - Implement HW attributes -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
13 #include "circt/Support/LLVM.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/FileSystem.h"
20 #include "llvm/Support/Path.h"
21 
22 using namespace circt;
23 using namespace circt::hw;
24 using mlir::TypedAttr;
25 
26 // Internal method used for .mlir file parsing, defined below.
27 static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p,
28  Type type);
29 
30 //===----------------------------------------------------------------------===//
31 // ODS Boilerplate
32 //===----------------------------------------------------------------------===//
33 
34 #define GET_ATTRDEF_CLASSES
35 #include "circt/Dialect/HW/HWAttributes.cpp.inc"
36 
/// Register every HW attribute class with the dialect. The class list is
/// expanded from the ODS-generated HWAttributes.cpp.inc under
/// GET_ATTRDEF_LIST.
void HWDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "circt/Dialect/HW/HWAttributes.cpp.inc"
      >();
}
43 
44 Attribute HWDialect::parseAttribute(DialectAsmParser &p, Type type) const {
45  StringRef attrName;
46  Attribute attr;
47  auto parseResult = generatedAttributeParser(p, &attrName, type, attr);
48  if (parseResult.has_value())
49  return attr;
50 
51  // Parse "#hw.param.expr.add" as ParamExprAttr.
52  if (attrName.starts_with(ParamExprAttr::getMnemonic())) {
53  auto string = attrName.drop_front(ParamExprAttr::getMnemonic().size());
54  if (string.front() == '.')
55  return parseParamExprWithOpcode(string.drop_front(), p, type);
56  }
57 
58  p.emitError(p.getNameLoc(), "Unexpected hw attribute '" + attrName + "'");
59  return {};
60 }
61 
62 void HWDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
63  if (succeeded(generatedAttributePrinter(attr, p)))
64  return;
65  llvm_unreachable("Unexpected attribute");
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // OutputFileAttr
70 //===----------------------------------------------------------------------===//
71 
72 static std::string canonicalizeFilename(const Twine &directory,
73  const Twine &filename) {
74 
75  // Convert the filename to a native style path.
76  SmallString<128> nativeFilename;
77  llvm::sys::path::native(filename, nativeFilename);
78 
79  // If the filename is an absolute path, ignore the directory.
80  // e.g. `directory/` + `/etc/filename` -> `/etc/filename`.
81  if (llvm::sys::path::is_absolute(nativeFilename))
82  return std::string(nativeFilename);
83 
84  // Convert the directory to a native style path.
85  SmallString<128> nativeDirectory;
86  llvm::sys::path::native(directory, nativeDirectory);
87 
88  // If the filename component is empty, then ensure that the path ends in a
89  // separator and return it.
90  // e.g. `directory` + `` -> `directory/`.
91  auto separator = llvm::sys::path::get_separator();
92  if (nativeFilename.empty() && !nativeDirectory.ends_with(separator)) {
93  nativeDirectory += separator;
94  return std::string(nativeDirectory);
95  }
96 
97  // Append the directory and filename together.
98  // e.g. `/tmp/` + `out/filename` -> `/tmp/out/filename`.
99  SmallString<128> fullPath;
100  llvm::sys::path::append(fullPath, nativeDirectory, nativeFilename);
101  return std::string(fullPath);
102 }
103 
104 OutputFileAttr OutputFileAttr::getFromFilename(MLIRContext *context,
105  const Twine &filename,
106  bool excludeFromFileList,
107  bool includeReplicatedOps) {
108  return OutputFileAttr::getFromDirectoryAndFilename(
109  context, "", filename, excludeFromFileList, includeReplicatedOps);
110 }
111 
112 OutputFileAttr OutputFileAttr::getFromDirectoryAndFilename(
113  MLIRContext *context, const Twine &directory, const Twine &filename,
114  bool excludeFromFileList, bool includeReplicatedOps) {
115  auto canonicalized = canonicalizeFilename(directory, filename);
116  return OutputFileAttr::get(StringAttr::get(context, canonicalized),
117  BoolAttr::get(context, excludeFromFileList),
118  BoolAttr::get(context, includeReplicatedOps));
119 }
120 
121 OutputFileAttr OutputFileAttr::getAsDirectory(MLIRContext *context,
122  const Twine &directory,
123  bool excludeFromFileList,
124  bool includeReplicatedOps) {
125  return getFromDirectoryAndFilename(context, directory, "",
126  excludeFromFileList, includeReplicatedOps);
127 }
128 
129 bool OutputFileAttr::isDirectory() {
130  return getFilename().getValue().ends_with(llvm::sys::path::get_separator());
131 }
132 
133 /// Option ::= 'excludeFromFileList' | 'includeReplicatedOp'
134 /// OutputFileAttr ::= 'output_file<' directory ',' name (',' Option)* '>'
135 Attribute OutputFileAttr::parse(AsmParser &p, Type type) {
136  StringAttr filename;
137  if (p.parseLess() || p.parseAttribute<StringAttr>(filename))
138  return Attribute();
139 
140  // Parse the additional keyword attributes. Its easier to let people specify
141  // these more than once than to detect the problem and do something about it.
142  bool excludeFromFileList = false;
143  bool includeReplicatedOps = false;
144  while (true) {
145  if (p.parseOptionalComma())
146  break;
147  if (!p.parseOptionalKeyword("excludeFromFileList"))
148  excludeFromFileList = true;
149  else if (!p.parseKeyword("includeReplicatedOps",
150  "or 'excludeFromFileList'"))
151  includeReplicatedOps = true;
152  else
153  return Attribute();
154  }
155 
156  if (p.parseGreater())
157  return Attribute();
158 
159  return OutputFileAttr::getFromFilename(p.getContext(), filename.getValue(),
160  excludeFromFileList,
161  includeReplicatedOps);
162 }
163 
164 void OutputFileAttr::print(AsmPrinter &p) const {
165  p << "<" << getFilename();
166  if (getExcludeFromFilelist().getValue())
167  p << ", excludeFromFileList";
168  if (getIncludeReplicatedOps().getValue())
169  p << ", includeReplicatedOps";
170  p << ">";
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // EnumFieldAttr
175 //===----------------------------------------------------------------------===//
176 
177 Attribute EnumFieldAttr::parse(AsmParser &p, Type) {
178  StringRef field;
179  Type type;
180  if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() ||
181  p.parseType(type) || p.parseGreater())
182  return Attribute();
183  return EnumFieldAttr::get(p.getEncodedSourceLoc(p.getCurrentLocation()),
184  StringAttr::get(p.getContext(), field), type);
185 }
186 
187 void EnumFieldAttr::print(AsmPrinter &p) const {
188  p << "<" << getField().getValue() << ", ";
189  p.printType(getType().getValue());
190  p << ">";
191 }
192 
193 EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value,
194  mlir::Type type) {
195  if (!hw::isHWEnumType(type))
196  emitError(loc) << "expected enum type";
197 
198  // Check whether the provided value is a member of the enum type.
199  EnumType enumType = llvm::cast<EnumType>(getCanonicalType(type));
200  if (!enumType.contains(value.getValue())) {
201  emitError(loc) << "enum value '" << value.getValue()
202  << "' is not a member of enum type " << enumType;
203  return nullptr;
204  }
205 
206  return Base::get(value.getContext(), value, TypeAttr::get(type));
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // InnerRefAttr
211 //===----------------------------------------------------------------------===//
212 
213 Attribute InnerRefAttr::parse(AsmParser &p, Type type) {
214  SymbolRefAttr attr;
215  if (p.parseLess() || p.parseAttribute<SymbolRefAttr>(attr) ||
216  p.parseGreater())
217  return Attribute();
218  if (attr.getNestedReferences().size() != 1)
219  return Attribute();
220  return InnerRefAttr::get(attr.getRootReference(), attr.getLeafReference());
221 }
222 
223 static void printSymbolName(AsmPrinter &p, StringAttr sym) {
224  if (sym)
225  p.printSymbolName(sym.getValue());
226  else
227  p.printSymbolName({});
228 }
229 
230 void InnerRefAttr::print(AsmPrinter &p) const {
231  p << "<";
232  printSymbolName(p, getModule());
233  p << "::";
234  printSymbolName(p, getName());
235  p << ">";
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // InnerSymAttr and InnerSymPropertiesAttr
240 //===----------------------------------------------------------------------===//
241 
242 Attribute InnerSymPropertiesAttr::parse(AsmParser &parser, Type type) {
243  StringAttr name;
244  NamedAttrList dummyList;
245  int64_t fieldId = 0;
246  if (parser.parseLess() || parser.parseSymbolName(name, "name", dummyList) ||
247  parser.parseComma() || parser.parseInteger(fieldId) ||
248  parser.parseComma())
249  return Attribute();
250 
251  StringRef visibility;
252  auto loc = parser.getCurrentLocation();
253  if (parser.parseOptionalKeyword(&visibility,
254  {"public", "private", "nested"})) {
255  parser.emitError(loc, "expected 'public', 'private', or 'nested'");
256  return Attribute();
257  }
258  auto visibilityAttr = parser.getBuilder().getStringAttr(visibility);
259 
260  if (parser.parseGreater())
261  return Attribute();
262 
263  return parser.getChecked<InnerSymPropertiesAttr>(parser.getContext(), name,
264  fieldId, visibilityAttr);
265 }
266 
267 void InnerSymPropertiesAttr::print(AsmPrinter &odsPrinter) const {
268  odsPrinter << "<@" << getName().getValue() << "," << getFieldID() << ","
269  << getSymVisibility().getValue() << ">";
270 }
271 
272 LogicalResult InnerSymPropertiesAttr::verify(
273  ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
274  ::mlir::StringAttr name, uint64_t fieldID,
275  ::mlir::StringAttr symVisibility) {
276  if (!name || name.getValue().empty())
277  return emitError() << "inner symbol cannot have empty name";
278  return success();
279 }
280 
281 StringAttr InnerSymAttr::getSymIfExists(uint64_t fieldId) const {
282  const auto *it =
283  llvm::find_if(getImpl()->props, [&](const InnerSymPropertiesAttr &p) {
284  return p.getFieldID() == fieldId;
285  });
286  if (it != getProps().end())
287  return it->getName();
288  return {};
289 }
290 
291 InnerSymAttr InnerSymAttr::erase(uint64_t fieldID) const {
292  SmallVector<InnerSymPropertiesAttr> syms(getProps());
293  const auto *it = llvm::find_if(syms, [fieldID](InnerSymPropertiesAttr p) {
294  return p.getFieldID() == fieldID;
295  });
296  assert(it != syms.end());
297  syms.erase(it);
298  return InnerSymAttr::get(getContext(), syms);
299 }
300 
301 LogicalResult InnerSymAttr::walkSymbols(
302  llvm::function_ref<LogicalResult(StringAttr)> callback) const {
303  for (auto p : getImpl()->props)
304  if (callback(p.getName()).failed())
305  return failure();
306  return success();
307 }
308 
309 Attribute InnerSymAttr::parse(AsmParser &parser, Type type) {
310  StringAttr sym;
311  NamedAttrList dummyList;
312  SmallVector<InnerSymPropertiesAttr, 4> names;
313  if (!parser.parseOptionalSymbolName(sym, "dummy", dummyList)) {
314  auto prop = parser.getChecked<InnerSymPropertiesAttr>(
315  parser.getContext(), sym, 0,
316  StringAttr::get(parser.getContext(), "public"));
317  if (!prop)
318  return {};
319  names.push_back(prop);
320  } else if (parser.parseCommaSeparatedList(
321  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
322  InnerSymPropertiesAttr prop;
323  if (parser.parseCustomAttributeWithFallback(
324  prop, mlir::Type{}, "dummy", dummyList))
325  return failure();
326 
327  names.push_back(prop);
328 
329  return success();
330  }))
331  return Attribute();
332 
333  std::sort(names.begin(), names.end(),
334  [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
335  return a.getFieldID() < b.getFieldID();
336  });
337 
338  return InnerSymAttr::get(parser.getContext(), names);
339 }
340 
341 void InnerSymAttr::print(AsmPrinter &odsPrinter) const {
342 
343  auto props = getProps();
344  if (props.size() == 1 && props[0].getSymVisibility().getValue() == "public" &&
345  props[0].getFieldID() == 0) {
346  odsPrinter << "@" << props[0].getName().getValue();
347  return;
348  }
349  auto names = props.vec();
350 
351  std::sort(names.begin(), names.end(),
352  [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
353  return a.getFieldID() < b.getFieldID();
354  });
355  odsPrinter << "[";
356  llvm::interleaveComma(names, odsPrinter, [&](InnerSymPropertiesAttr attr) {
357  attr.print(odsPrinter);
358  });
359  odsPrinter << "]";
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // ParamDeclAttr
364 //===----------------------------------------------------------------------===//
365 
366 Attribute ParamDeclAttr::parse(AsmParser &p, Type trailing) {
367  std::string name;
368  Type type;
369  Attribute value;
370  // < "FOO" : i32 > : i32
371  // < "FOO" : i32 = 0 > : i32
372  // < "FOO" : none >
373  if (p.parseLess() || p.parseString(&name) || p.parseColonType(type))
374  return Attribute();
375 
376  if (succeeded(p.parseOptionalEqual())) {
377  if (p.parseAttribute(value, type))
378  return Attribute();
379  }
380 
381  if (p.parseGreater())
382  return Attribute();
383 
384  if (value)
385  return ParamDeclAttr::get(p.getContext(),
386  p.getBuilder().getStringAttr(name), type, value);
387  return ParamDeclAttr::get(name, type);
388 }
389 
390 void ParamDeclAttr::print(AsmPrinter &p) const {
391  p << "<" << getName() << ": " << getType();
392  if (getValue()) {
393  p << " = ";
394  p.printAttributeWithoutType(getValue());
395  }
396  p << ">";
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // ParamDeclRefAttr
401 //===----------------------------------------------------------------------===//
402 
403 Attribute ParamDeclRefAttr::parse(AsmParser &p, Type type) {
404  StringAttr name;
405  if (p.parseLess() || p.parseAttribute(name) || p.parseGreater() ||
406  (!type && (p.parseColon() || p.parseType(type))))
407  return Attribute();
408 
409  return ParamDeclRefAttr::get(name, type);
410 }
411 
412 void ParamDeclRefAttr::print(AsmPrinter &p) const {
413  p << "<" << getName() << ">";
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // ParamVerbatimAttr
418 //===----------------------------------------------------------------------===//
419 
420 Attribute ParamVerbatimAttr::parse(AsmParser &p, Type type) {
421  StringAttr text;
422  if (p.parseLess() || p.parseAttribute(text) || p.parseGreater() ||
423  (!type && (p.parseColon() || p.parseType(type))))
424  return Attribute();
425 
426  return ParamVerbatimAttr::get(p.getContext(), text, type);
427 }
428 
429 void ParamVerbatimAttr::print(AsmPrinter &p) const {
430  p << "<" << getValue() << ">";
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // ParamExprAttr
435 //===----------------------------------------------------------------------===//
436 
437 /// Given a binary function, if the two operands are known constant integers,
438 /// use the specified fold function to compute the result.
439 static TypedAttr foldBinaryOp(
440  ArrayRef<TypedAttr> operands,
441  llvm::function_ref<APInt(const APInt &, const APInt &)> calculate) {
442  assert(operands.size() == 2 && "binary operator always has two operands");
443  if (auto lhs = dyn_cast<IntegerAttr>(operands[0]))
444  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
445  return IntegerAttr::get(lhs.getType(),
446  calculate(lhs.getValue(), rhs.getValue()));
447  return {};
448 }
449 
450 /// Given a unary function, if the operand is a known constant integer,
451 /// use the specified fold function to compute the result.
452 static TypedAttr
453 foldUnaryOp(ArrayRef<TypedAttr> operands,
454  llvm::function_ref<APInt(const APInt &)> calculate) {
455  assert(operands.size() == 1 && "unary operator always has one operand");
456  if (auto intAttr = dyn_cast<IntegerAttr>(operands[0]))
457  return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue()));
458  return {};
459 }
460 
461 /// If the specified attribute is a ParamExprAttr with the specified opcode,
462 /// return it. Otherwise return null.
463 static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) {
464  if (auto expr = dyn_cast<ParamExprAttr>(value))
465  if (expr.getOpcode() == opcode)
466  return expr;
467  return {};
468 }
469 
/// This implements a < comparison for two operands to an associative operation
/// imposing an ordering upon them.
///
/// The ordering provided puts more complex things to the start of the list,
/// from left to right:
/// expressions :: verbatims :: decl.refs :: constant
///
/// Within each category:
///  - constants (IntegerAttr) all compare equivalent; simplifyAssocOp folds
///    them together right after sorting;
///  - decl.refs are ordered lexically by parameter name;
///  - verbatims are ordered lexically by their text;
///  - nested expressions are ordered by opcode name, then by operand count
///    (more operands first), then element-wise using this same predicate.
///
/// NOTE(review): operands are assumed to be IntegerAttr, ParamDeclRefAttr,
/// ParamVerbatimAttr or ParamExprAttr; any other attribute kind reaches the
/// cast<ParamExprAttr> below and asserts.
static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) {
  // Simplify the code below - we never have to care about exactly equal values.
  if (lhs == rhs)
    return false;

  // All expressions are "less than" a constant, since they appear on the right.
  if (isa<IntegerAttr>(rhs)) {
    // We don't bother to order constants w.r.t. each other since they will be
    // folded - they can all compare equal.
    return !isa<IntegerAttr>(lhs);
  }
  if (isa<IntegerAttr>(lhs))
    return false;

  // Next up are named parameters.
  if (auto rhsParam = dyn_cast<ParamDeclRefAttr>(rhs)) {
    // Parameters are sorted lexically w.r.t. each other.
    if (auto lhsParam = dyn_cast<ParamDeclRefAttr>(lhs))
      return lhsParam.getName().getValue() < rhsParam.getName().getValue();
    // They otherwise appear on the right of other things.
    return true;
  }
  if (isa<ParamDeclRefAttr>(lhs))
    return false;

  // Next up are verbatim parameters.
  if (auto rhsParam = dyn_cast<ParamVerbatimAttr>(rhs)) {
    // Verbatims are sorted lexically w.r.t. each other.
    if (auto lhsParam = dyn_cast<ParamVerbatimAttr>(lhs))
      return lhsParam.getValue().getValue() < rhsParam.getValue().getValue();
    // They otherwise appear on the right of other things.
    return true;
  }
  if (isa<ParamVerbatimAttr>(lhs))
    return false;

  // The only thing left are nested expressions.
  auto lhsExpr = cast<ParamExprAttr>(lhs), rhsExpr = cast<ParamExprAttr>(rhs);
  // Sort by the string form of the opcode, e.g. add, .. mul,... then xor.
  if (lhsExpr.getOpcode() != rhsExpr.getOpcode())
    return stringifyPEO(lhsExpr.getOpcode()) <
           stringifyPEO(rhsExpr.getOpcode());

  // If they are the same opcode, then sort by arity: more complex to the left.
  ArrayRef<TypedAttr> lhsOperands = lhsExpr.getOperands(),
                      rhsOperands = rhsExpr.getOperands();
  if (lhsOperands.size() != rhsOperands.size())
    return lhsOperands.size() > rhsOperands.size();

  // We know the two subexpressions are different (they'd otherwise be pointer
  // equivalent) so just go compare all of the elements.
  for (size_t i = 0, e = lhsOperands.size(); i != e; ++i) {
    // First element where one side orders strictly before the other decides.
    if (paramExprOperandSortPredicate(lhsOperands[i], rhsOperands[i]))
      return true;
    if (paramExprOperandSortPredicate(rhsOperands[i], lhsOperands[i]))
      return false;
  }

  llvm_unreachable("expressions should never be equivalent");
  // NOTE(review): not reachable after llvm_unreachable; kept as-is.
  return false;
}
538 
539 /// Given a fully associative variadic integer operation, constant fold any
540 /// constant operands and move them to the right. If the whole expression is
541 /// constant, then return that, otherwise update the operands list.
542 static TypedAttr simplifyAssocOp(
543  PEO opcode, SmallVector<TypedAttr, 4> &operands,
544  llvm::function_ref<APInt(const APInt &, const APInt &)> calculateFn,
545  llvm::function_ref<bool(const APInt &)> identityConstantFn,
546  llvm::function_ref<bool(const APInt &)> destructiveConstantFn = {}) {
547  auto type = operands[0].getType();
548  assert(isHWIntegerType(type));
549  if (operands.size() == 1)
550  return operands[0];
551 
552  // Flatten any of the same operation into the operand list:
553  // `(add x, (add y, z))` => `(add x, y, z)`.
554  for (size_t i = 0, e = operands.size(); i != e; ++i) {
555  if (auto subexpr = dyn_castPE(opcode, operands[i])) {
556  std::swap(operands[i], operands.back());
557  operands.pop_back();
558  --e;
559  --i;
560  operands.append(subexpr.getOperands().begin(),
561  subexpr.getOperands().end());
562  }
563  }
564 
565  // Impose an ordering on the operands, pushing subexpressions to the left and
566  // constants to the right, with verbatims and parameters in the middle - but
567  // predictably ordered w.r.t. each other.
568  llvm::stable_sort(operands, paramExprOperandSortPredicate);
569 
570  // Merge any constants, they will appear at the back of the operand list now.
571  if (isa<IntegerAttr>(operands.back())) {
572  while (operands.size() >= 2 &&
573  isa<IntegerAttr>(operands[operands.size() - 2])) {
574  APInt c1 = cast<IntegerAttr>(operands.pop_back_val()).getValue();
575  APInt c2 = cast<IntegerAttr>(operands.pop_back_val()).getValue();
576  auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2));
577  operands.push_back(resultConstant);
578  }
579 
580  auto resultCst = cast<IntegerAttr>(operands.back());
581 
582  // If the resulting constant is the destructive constant (e.g. `x*0`), then
583  // return it.
584  if (destructiveConstantFn && destructiveConstantFn(resultCst.getValue()))
585  return resultCst;
586 
587  // Remove the constant back to our operand list if it is the identity
588  // constant for this operator (e.g. `x*1`) and there are other operands.
589  if (identityConstantFn(resultCst.getValue()) && operands.size() != 1)
590  operands.pop_back();
591  }
592 
593  return operands.size() == 1 ? operands[0] : TypedAttr();
594 }
595 
596 /// Analyze an operand to an add. If it is a multiplication by a constant (e.g.
597 /// `(a*b*42)` then split it into the non-constant and the constant portions
598 /// (e.g. `a*b` and `42`). Otherwise return the operand as the first value and
599 /// null as the second (standin for "multiplication by 1").
600 static std::pair<TypedAttr, TypedAttr> decomposeAddend(TypedAttr operand) {
601  if (auto mul = dyn_castPE(PEO::Mul, operand))
602  if (auto cst = dyn_cast<IntegerAttr>(mul.getOperands().back())) {
603  auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back());
604  return {nonCst, cst};
605  }
606  return {operand, TypedAttr()};
607 }
608 
609 static TypedAttr getOneOfType(Type type) {
610  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), 1));
611 }
612 
613 static TypedAttr simplifyAdd(SmallVector<TypedAttr, 4> &operands) {
614  if (auto result = simplifyAssocOp(
615  PEO::Add, operands, [](auto a, auto b) { return a + b; },
616  /*identityCst*/ [](auto cst) { return cst.isZero(); }))
617  return result;
618 
619  // Canonicalize the add by splitting all addends into their variable and
620  // constant factors.
621  SmallVector<std::pair<TypedAttr, TypedAttr>> decomposedOperands;
622  llvm::SmallDenseSet<TypedAttr> nonConstantParts;
623  for (auto &op : operands) {
624  decomposedOperands.push_back(decomposeAddend(op));
625 
626  // Keep track of non-constant parts we've already seen. If we see multiple
627  // uses of the same value, then we can fold them together with a multiply.
628  // This handles things like `(a+b+a)` => `(a*2 + b)` and `(a*2 + b + a)` =>
629  // `(a*3 + b)`.
630  if (!nonConstantParts.insert(decomposedOperands.back().first).second) {
631  // The thing we multiply will be the common expression.
632  TypedAttr mulOperand = decomposedOperands.back().first;
633 
634  // Find the index of the first occurrence.
635  size_t i = 0;
636  while (decomposedOperands[i].first != mulOperand)
637  ++i;
638  // Remove both occurrences from the operand list.
639  operands.erase(operands.begin() + (&op - &operands[0]));
640  operands.erase(operands.begin() + i);
641 
642  auto type = mulOperand.getType();
643  auto c1 = decomposedOperands[i].second,
644  c2 = decomposedOperands.back().second;
645  // Fill in missing constant multiplicands with 1.
646  if (!c1)
647  c1 = getOneOfType(type);
648  if (!c2)
649  c2 = getOneOfType(type);
650  // Re-add the "a"*(c1+c2) expression to the operand list and
651  // re-canonicalize.
652  auto constant = ParamExprAttr::get(PEO::Add, c1, c2);
653  auto mulCst = ParamExprAttr::get(PEO::Mul, mulOperand, constant);
654  operands.push_back(mulCst);
655  return ParamExprAttr::get(PEO::Add, operands);
656  }
657  }
658 
659  return {};
660 }
661 
662 static TypedAttr simplifyMul(SmallVector<TypedAttr, 4> &operands) {
663  if (auto result = simplifyAssocOp(
664  PEO::Mul, operands, [](auto a, auto b) { return a * b; },
665  /*identityCst*/ [](auto cst) { return cst.isOne(); },
666  /*destructiveCst*/ [](auto cst) { return cst.isZero(); }))
667  return result;
668 
669  // We always build a sum-of-products representation, so if we see an addition
670  // as a subexpr, we need to pull it out: (a+b)*c*d ==> (a*c*d + b*c*d).
671  for (size_t i = 0, e = operands.size(); i != e; ++i) {
672  if (auto subexpr = dyn_castPE(PEO::Add, operands[i])) {
673  // Pull the `c*d` operands out - it is whatever operands remain after
674  // removing the `(a+b)` term.
675  SmallVector<TypedAttr> mulOperands(operands.begin(), operands.end());
676  mulOperands.erase(mulOperands.begin() + i);
677 
678  // Build each add operand.
679  SmallVector<TypedAttr> addOperands;
680  for (auto addOperand : subexpr.getOperands()) {
681  mulOperands.push_back(addOperand);
682  addOperands.push_back(ParamExprAttr::get(PEO::Mul, mulOperands));
683  mulOperands.pop_back();
684  }
685  // Canonicalize and form the add expression.
686  return ParamExprAttr::get(PEO::Add, addOperands);
687  }
688  }
689 
690  return {};
691 }
692 static TypedAttr simplifyAnd(SmallVector<TypedAttr, 4> &operands) {
693  return simplifyAssocOp(
694  PEO::And, operands, [](auto a, auto b) { return a & b; },
695  /*identityCst*/ [](auto cst) { return cst.isAllOnes(); },
696  /*destructiveCst*/ [](auto cst) { return cst.isZero(); });
697 }
698 
699 static TypedAttr simplifyOr(SmallVector<TypedAttr, 4> &operands) {
700  return simplifyAssocOp(
701  PEO::Or, operands, [](auto a, auto b) { return a | b; },
702  /*identityCst*/ [](auto cst) { return cst.isZero(); },
703  /*destructiveCst*/ [](auto cst) { return cst.isAllOnes(); });
704 }
705 
706 static TypedAttr simplifyXor(SmallVector<TypedAttr, 4> &operands) {
707  return simplifyAssocOp(
708  PEO::Xor, operands, [](auto a, auto b) { return a ^ b; },
709  /*identityCst*/ [](auto cst) { return cst.isZero(); });
710 }
711 
712 static TypedAttr simplifyShl(SmallVector<TypedAttr, 4> &operands) {
713  assert(isHWIntegerType(operands[0].getType()));
714 
715  if (auto rhs = dyn_cast<IntegerAttr>(operands[1])) {
716  // Constant fold simple integers.
717  if (auto lhs = dyn_cast<IntegerAttr>(operands[0]))
718  return IntegerAttr::get(lhs.getType(),
719  lhs.getValue().shl(rhs.getValue()));
720 
721  // Canonicalize `x << cst` => `x * (1<<cst)` to compose correctly with
722  // add/mul canonicalization.
723  auto rhsCst = APInt::getOneBitSet(rhs.getValue().getBitWidth(),
724  rhs.getValue().getZExtValue());
725  return ParamExprAttr::get(PEO::Mul, operands[0],
726  IntegerAttr::get(rhs.getType(), rhsCst));
727  }
728  return {};
729 }
730 
731 static TypedAttr simplifyShrU(SmallVector<TypedAttr, 4> &operands) {
732  assert(isHWIntegerType(operands[0].getType()));
733  // Implement support for identities like `x >> 0`.
734  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
735  if (rhs.getValue().isZero())
736  return operands[0];
737 
738  return foldBinaryOp(operands, [](auto a, auto b) { return a.lshr(b); });
739 }
740 
741 static TypedAttr simplifyShrS(SmallVector<TypedAttr, 4> &operands) {
742  assert(isHWIntegerType(operands[0].getType()));
743  // Implement support for identities like `x >> 0`.
744  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
745  if (rhs.getValue().isZero())
746  return operands[0];
747 
748  return foldBinaryOp(operands, [](auto a, auto b) { return a.ashr(b); });
749 }
750 
751 static TypedAttr simplifyDivU(SmallVector<TypedAttr, 4> &operands) {
752  assert(isHWIntegerType(operands[0].getType()));
753  // Implement support for identities like `x/1`.
754  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
755  if (rhs.getValue().isOne())
756  return operands[0];
757 
758  return foldBinaryOp(operands, [](auto a, auto b) { return a.udiv(b); });
759 }
760 
761 static TypedAttr simplifyDivS(SmallVector<TypedAttr, 4> &operands) {
762  assert(isHWIntegerType(operands[0].getType()));
763  // Implement support for identities like `x/1`.
764  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
765  if (rhs.getValue().isOne())
766  return operands[0];
767 
768  return foldBinaryOp(operands, [](auto a, auto b) { return a.sdiv(b); });
769 }
770 
771 static TypedAttr simplifyModU(SmallVector<TypedAttr, 4> &operands) {
772  assert(isHWIntegerType(operands[0].getType()));
773  // Implement support for identities like `x%1`.
774  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
775  if (rhs.getValue().isOne())
776  return IntegerAttr::get(rhs.getType(), 0);
777 
778  return foldBinaryOp(operands, [](auto a, auto b) { return a.urem(b); });
779 }
780 
781 static TypedAttr simplifyModS(SmallVector<TypedAttr, 4> &operands) {
782  assert(isHWIntegerType(operands[0].getType()));
783  // Implement support for identities like `x%1`.
784  if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
785  if (rhs.getValue().isOne())
786  return IntegerAttr::get(rhs.getType(), 0);
787 
788  return foldBinaryOp(operands, [](auto a, auto b) { return a.srem(b); });
789 }
790 
791 static TypedAttr simplifyCLog2(SmallVector<TypedAttr, 4> &operands) {
792  assert(isHWIntegerType(operands[0].getType()));
793  return foldUnaryOp(operands, [](auto a) {
794  // Following the Verilog spec, clog2(0) is 0
795  return APInt(a.getBitWidth(), a == 0 ? 0 : a.ceilLogBase2());
796  });
797 }
798 
799 static TypedAttr simplifyStrConcat(SmallVector<TypedAttr, 4> &operands) {
800  // Combine all adjacent strings.
801  SmallVector<TypedAttr> newOperands;
802  SmallVector<StringAttr> stringsToCombine;
803  auto combineAndPush = [&]() {
804  if (stringsToCombine.empty())
805  return;
806  // Concatenate buffered strings, push to ops.
807  SmallString<32> newString;
808  for (auto part : stringsToCombine)
809  newString.append(part.getValue());
810  newOperands.push_back(
811  StringAttr::get(stringsToCombine[0].getContext(), newString));
812  stringsToCombine.clear();
813  };
814 
815  for (TypedAttr op : operands) {
816  if (auto strOp = dyn_cast<StringAttr>(op)) {
817  // Queue up adjacent strings.
818  stringsToCombine.push_back(strOp);
819  } else {
820  combineAndPush();
821  newOperands.push_back(op);
822  }
823  }
824  combineAndPush();
825 
826  assert(!newOperands.empty());
827  if (newOperands.size() == 1)
828  return newOperands[0];
829  if (newOperands.size() < operands.size())
830  return ParamExprAttr::get(PEO::StrConcat, newOperands);
831  return {};
832 }
833 
834 /// Build a parameter expression. This automatically canonicalizes and
835 /// folds, so it may not necessarily return a ParamExprAttr.
836 TypedAttr ParamExprAttr::get(PEO opcode, ArrayRef<TypedAttr> operandsIn) {
837  assert(!operandsIn.empty() && "Cannot have expr with no operands");
838  // All operands must have the same type, which is the type of the result.
839  auto type = operandsIn.front().getType();
840  assert(llvm::all_of(operandsIn.drop_front(),
841  [&](auto op) { return op.getType() == type; }));
842 
843  SmallVector<TypedAttr, 4> operands(operandsIn.begin(), operandsIn.end());
844 
845  // Verify and canonicalize parameter expressions.
846  TypedAttr result;
847  switch (opcode) {
848  case PEO::Add:
849  result = simplifyAdd(operands);
850  break;
851  case PEO::Mul:
852  result = simplifyMul(operands);
853  break;
854  case PEO::And:
855  result = simplifyAnd(operands);
856  break;
857  case PEO::Or:
858  result = simplifyOr(operands);
859  break;
860  case PEO::Xor:
861  result = simplifyXor(operands);
862  break;
863  case PEO::Shl:
864  result = simplifyShl(operands);
865  break;
866  case PEO::ShrU:
867  result = simplifyShrU(operands);
868  break;
869  case PEO::ShrS:
870  result = simplifyShrS(operands);
871  break;
872  case PEO::DivU:
873  result = simplifyDivU(operands);
874  break;
875  case PEO::DivS:
876  result = simplifyDivS(operands);
877  break;
878  case PEO::ModU:
879  result = simplifyModU(operands);
880  break;
881  case PEO::ModS:
882  result = simplifyModS(operands);
883  break;
884  case PEO::CLog2:
885  result = simplifyCLog2(operands);
886  break;
887  case PEO::StrConcat:
888  result = simplifyStrConcat(operands);
889  break;
890  }
891 
892  // If we folded to an operand, return it.
893  if (result)
894  return result;
895 
896  return Base::get(operands[0].getContext(), opcode, operands, type);
897 }
898 
899 Attribute ParamExprAttr::parse(AsmParser &p, Type type) {
900  // We require an opcode suffix like `#hw.param.expr.add`, we don't allow
901  // parsing a plain `#hw.param.expr` on its own.
902  p.emitError(p.getNameLoc(), "#hw.param.expr should have opcode suffix");
903  return {};
904 }
905 
906 /// Internal method used for .mlir file parsing when parsing the
907 /// "#hw.param.expr.mul" form of the attribute.
908 static Attribute parseParamExprWithOpcode(StringRef opcodeStr,
909  DialectAsmParser &p, Type type) {
910  SmallVector<TypedAttr> operands;
911  if (p.parseCommaSeparatedList(
912  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
913  operands.push_back({});
914  return p.parseAttribute(operands.back(), type);
915  }))
916  return {};
917 
918  std::optional<PEO> opcode = symbolizePEO(opcodeStr);
919  if (!opcode.has_value()) {
920  p.emitError(p.getNameLoc(), "unknown parameter expr operator name");
921  return {};
922  }
923 
924  return ParamExprAttr::get(*opcode, operands);
925 }
926 
927 void ParamExprAttr::print(AsmPrinter &p) const {
928  p << "." << stringifyPEO(getOpcode()) << '<';
929  llvm::interleaveComma(getOperands(), p.getStream(),
930  [&](Attribute op) { p.printAttributeWithoutType(op); });
931  p << '>';
932 }
933 
934 // Replaces any ParamDeclRefAttr within a parametric expression with its
935 // corresponding value from the map of provided parameters.
936 static FailureOr<Attribute>
937 replaceDeclRefInExpr(Location loc,
938  const std::map<std::string, Attribute> &parameters,
939  Attribute paramAttr, bool emitErrors) {
940  if (dyn_cast<IntegerAttr>(paramAttr)) {
941  // Nothing to do, constant value.
942  return paramAttr;
943  }
944  if (auto paramRefAttr = dyn_cast<hw::ParamDeclRefAttr>(paramAttr)) {
945  // Get the value from the provided parameters.
946  auto it = parameters.find(paramRefAttr.getName().str());
947  if (it == parameters.end()) {
948  if (emitErrors)
949  return emitError(loc)
950  << "Could not find parameter " << paramRefAttr.getName().str()
951  << " in the provided parameters for the expression!";
952  return failure();
953  }
954  return it->second;
955  }
956  if (auto paramExprAttr = dyn_cast<hw::ParamExprAttr>(paramAttr)) {
957  // Recurse into all operands of the expression.
958  llvm::SmallVector<TypedAttr, 4> replacedOperands;
959  for (auto operand : paramExprAttr.getOperands()) {
960  auto res = replaceDeclRefInExpr(loc, parameters, operand, emitErrors);
961  if (failed(res))
962  return {failure()};
963  replacedOperands.push_back(cast<TypedAttr>(*res));
964  }
965  return {
966  hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)};
967  }
968  llvm_unreachable("Unhandled parametric attribute");
969  return {};
970 }
971 
972 FailureOr<TypedAttr> hw::evaluateParametricAttr(Location loc,
973  ArrayAttr parameters,
974  Attribute paramAttr,
975  bool emitErrors) {
976  // Create a map of the provided parameters for faster lookup.
977  std::map<std::string, Attribute> parameterMap;
978  for (auto param : parameters) {
979  auto paramDecl = cast<ParamDeclAttr>(param);
980  parameterMap[paramDecl.getName().str()] = paramDecl.getValue();
981  }
982 
983  // First, replace any ParamDeclRefAttr in the expression with its
984  // corresponding value in 'parameters'.
985  auto paramAttrRes =
986  replaceDeclRefInExpr(loc, parameterMap, paramAttr, emitErrors);
987  if (failed(paramAttrRes))
988  return {failure()};
989  paramAttr = *paramAttrRes;
990 
991  // Then, evaluate the parametric attribute.
992  if (isa<IntegerAttr, hw::ParamDeclRefAttr>(paramAttr))
993  return cast<TypedAttr>(paramAttr);
994  if (auto paramExprAttr = dyn_cast<hw::ParamExprAttr>(paramAttr)) {
995  // Since any ParamDeclRefAttr was replaced within the expression,
996  // we re-evaluate the expression through the existing ParamExprAttr
997  // canonicalizer.
998  return ParamExprAttr::get(paramExprAttr.getOpcode(),
999  paramExprAttr.getOperands());
1000  }
1001 
1002  llvm_unreachable("Unhandled parametric attribute");
1003  return TypedAttr();
1004 }
1005 
1006 template <typename TArray>
1007 FailureOr<Type> evaluateParametricArrayType(Location loc, ArrayAttr parameters,
1008  TArray arrayType, bool emitErrors) {
1009  auto size = evaluateParametricAttr(loc, parameters, arrayType.getSizeAttr(),
1010  emitErrors);
1011  if (failed(size))
1012  return failure();
1013  auto elementType = evaluateParametricType(
1014  loc, parameters, arrayType.getElementType(), emitErrors);
1015  if (failed(elementType))
1016  return failure();
1017 
1018  // If the size was evaluated to a constant, use a 64-bit integer
1019  // attribute version of it
1020  if (auto intAttr = dyn_cast<IntegerAttr>(*size))
1021  return TArray::get(
1022  arrayType.getContext(), *elementType,
1023  IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64),
1024  intAttr.getValue().getSExtValue()));
1025 
1026  // Otherwise parameter references are still involved
1027  return TArray::get(arrayType.getContext(), *elementType, *size);
1028 }
1029 
1030 FailureOr<Type> hw::evaluateParametricType(Location loc, ArrayAttr parameters,
1031  Type type, bool emitErrors) {
1032  return llvm::TypeSwitch<Type, FailureOr<Type>>(type)
1033  .Case<hw::IntType>([&](hw::IntType t) -> FailureOr<Type> {
1034  auto evaluatedWidth =
1035  evaluateParametricAttr(loc, parameters, t.getWidth(), emitErrors);
1036  if (failed(evaluatedWidth))
1037  return {failure()};
1038 
1039  // If the width was evaluated to a constant, return an `IntegerType`
1040  if (auto intAttr = dyn_cast<IntegerAttr>(*evaluatedWidth))
1041  return {IntegerType::get(type.getContext(),
1042  intAttr.getValue().getSExtValue())};
1043 
1044  // Otherwise parameter references are still involved
1045  return hw::IntType::get(cast<TypedAttr>(*evaluatedWidth));
1046  })
1047  .Case<hw::ArrayType, hw::UnpackedArrayType>(
1048  [&](auto arrayType) -> FailureOr<Type> {
1049  return evaluateParametricArrayType(loc, parameters, arrayType,
1050  emitErrors);
1051  })
1052  .Default([&](auto) { return type; });
1053 }
1054 
1055 // Returns true if any part of this parametric attribute contains a reference
1056 // to a parameter declaration.
1057 static bool isParamAttrWithParamRef(Attribute expr) {
1058  return llvm::TypeSwitch<Attribute, bool>(expr)
1059  .Case([](ParamExprAttr attr) {
1060  return llvm::any_of(attr.getOperands(), isParamAttrWithParamRef);
1061  })
1062  .Case([](ParamDeclRefAttr) { return true; })
1063  .Default([](auto) { return false; });
1064 }
1065 
1066 bool hw::isParametricType(mlir::Type t) {
1067  return llvm::TypeSwitch<Type, bool>(t)
1068  .Case<hw::IntType>(
1069  [&](hw::IntType t) { return isParamAttrWithParamRef(t.getWidth()); })
1070  .Case<hw::ArrayType, hw::UnpackedArrayType>([&](auto arrayType) {
1071  return isParametricType(arrayType.getElementType()) ||
1072  isParamAttrWithParamRef(arrayType.getSizeAttr());
1073  })
1074  .Default([](auto) { return false; });
1075 }
assert(baseType &&"element must be base type")
static TypedAttr foldBinaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculate)
Given a binary function, if the two operands are known constant integers, use the specified fold func...
static TypedAttr simplifyDivS(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyAssocOp(PEO opcode, SmallVector< TypedAttr, 4 > &operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculateFn, llvm::function_ref< bool(const APInt &)> identityConstantFn, llvm::function_ref< bool(const APInt &)> destructiveConstantFn={})
Given a fully associative variadic integer operation, constant fold any constant operands and move th...
static TypedAttr simplifyDivU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyCLog2(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyModU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyOr(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyShl(SmallVector< TypedAttr, 4 > &operands)
static void printSymbolName(AsmPrinter &p, StringAttr sym)
static TypedAttr simplifyShrS(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr foldUnaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &)> calculate)
Given a unary function, if the operand is a known constant integer, use the specified fold function t...
static TypedAttr simplifyMul(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr getOneOfType(Type type)
static TypedAttr simplifyAnd(SmallVector< TypedAttr, 4 > &operands)
static std::pair< TypedAttr, TypedAttr > decomposeAddend(TypedAttr operand)
Analyze an operand to an add.
static TypedAttr simplifyModS(SmallVector< TypedAttr, 4 > &operands)
static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p, Type type)
Internal method used for .mlir file parsing when parsing the "#hw.param.expr.mul" form of the attribu...
static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs)
This implements a < comparison for two operands to an associative operation imposing an ordering upon...
static std::string canonicalizeFilename(const Twine &directory, const Twine &filename)
static TypedAttr simplifyShrU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyAdd(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyStrConcat(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyXor(SmallVector< TypedAttr, 4 > &operands)
static ParamExprAttr dyn_castPE(PEO opcode, Attribute value)
If the specified attribute is a ParamExprAttr with the specified opcode, return it.
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getFieldID(Type type, uint64_t index)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition: HWTypes.cpp:52
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition: HWTypes.cpp:65
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21