CIRCT  18.0.0git
HWAttributes.cpp
Go to the documentation of this file.
1 //===- HWAttributes.cpp - Implement HW attributes -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
13 #include "circt/Support/LLVM.h"
14 #include "mlir/IR/Diagnostics.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/FileSystem.h"
20 #include "llvm/Support/Path.h"
21 
22 using namespace circt;
23 using namespace circt::hw;
24 using mlir::TypedAttr;
25 
26 // Internal method used for .mlir file parsing, defined below.
27 static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p,
28  Type type);
29 
30 //===----------------------------------------------------------------------===//
31 // ODS Boilerplate
32 //===----------------------------------------------------------------------===//
33 
34 #define GET_ATTRDEF_CLASSES
35 #include "circt/Dialect/HW/HWAttributes.cpp.inc"
36 
37 void HWDialect::registerAttributes() {
38  addAttributes<
39 #define GET_ATTRDEF_LIST
40 #include "circt/Dialect/HW/HWAttributes.cpp.inc"
41  >();
42 }
43 
44 Attribute HWDialect::parseAttribute(DialectAsmParser &p, Type type) const {
45  StringRef attrName;
46  Attribute attr;
47  auto parseResult = generatedAttributeParser(p, &attrName, type, attr);
48  if (parseResult.has_value())
49  return attr;
50 
51  // Parse "#hw.param.expr.add" as ParamExprAttr.
52  if (attrName.startswith(ParamExprAttr::getMnemonic())) {
53  auto string = attrName.drop_front(ParamExprAttr::getMnemonic().size());
54  if (string.front() == '.')
55  return parseParamExprWithOpcode(string.drop_front(), p, type);
56  }
57 
58  p.emitError(p.getNameLoc(), "Unexpected hw attribute '" + attrName + "'");
59  return {};
60 }
61 
62 void HWDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
63  if (succeeded(generatedAttributePrinter(attr, p)))
64  return;
65  llvm_unreachable("Unexpected attribute");
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // OutputFileAttr
70 //===----------------------------------------------------------------------===//
71 
72 static std::string canonicalizeFilename(const Twine &directory,
73  const Twine &filename) {
74 
75  // Convert the filename to a native style path.
76  SmallString<128> nativeFilename;
77  llvm::sys::path::native(filename, nativeFilename);
78 
79  // If the filename is an absolute path, ignore the directory.
80  // e.g. `directory/` + `/etc/filename` -> `/etc/filename`.
81  if (llvm::sys::path::is_absolute(nativeFilename))
82  return std::string(nativeFilename);
83 
84  // Convert the directory to a native style path.
85  SmallString<128> nativeDirectory;
86  llvm::sys::path::native(directory, nativeDirectory);
87 
88  // If the filename component is empty, then ensure that the path ends in a
89  // separator and return it.
90  // e.g. `directory` + `` -> `directory/`.
91  auto separator = llvm::sys::path::get_separator();
92  if (nativeFilename.empty() && !nativeDirectory.endswith(separator)) {
93  nativeDirectory += separator;
94  return std::string(nativeDirectory);
95  }
96 
97  // Append the directory and filename together.
98  // e.g. `/tmp/` + `out/filename` -> `/tmp/out/filename`.
99  SmallString<128> fullPath;
100  llvm::sys::path::append(fullPath, nativeDirectory, nativeFilename);
101  return std::string(fullPath);
102 }
103 
104 OutputFileAttr OutputFileAttr::getFromFilename(MLIRContext *context,
105  const Twine &filename,
106  bool excludeFromFileList,
107  bool includeReplicatedOps) {
108  return OutputFileAttr::getFromDirectoryAndFilename(
109  context, "", filename, excludeFromFileList, includeReplicatedOps);
110 }
111 
112 OutputFileAttr OutputFileAttr::getFromDirectoryAndFilename(
113  MLIRContext *context, const Twine &directory, const Twine &filename,
114  bool excludeFromFileList, bool includeReplicatedOps) {
115  auto canonicalized = canonicalizeFilename(directory, filename);
116  return OutputFileAttr::get(StringAttr::get(context, canonicalized),
117  BoolAttr::get(context, excludeFromFileList),
118  BoolAttr::get(context, includeReplicatedOps));
119 }
120 
121 OutputFileAttr OutputFileAttr::getAsDirectory(MLIRContext *context,
122  const Twine &directory,
123  bool excludeFromFileList,
124  bool includeReplicatedOps) {
125  return getFromDirectoryAndFilename(context, directory, "",
126  excludeFromFileList, includeReplicatedOps);
127 }
128 
129 bool OutputFileAttr::isDirectory() {
130  return getFilename().getValue().endswith(llvm::sys::path::get_separator());
131 }
132 
133 /// Option ::= 'excludeFromFileList' | 'includeReplicatedOp'
134 /// OutputFileAttr ::= 'output_file<' directory ',' name (',' Option)* '>'
135 Attribute OutputFileAttr::parse(AsmParser &p, Type type) {
136  StringAttr filename;
137  if (p.parseLess() || p.parseAttribute<StringAttr>(filename))
138  return Attribute();
139 
140  // Parse the additional keyword attributes. Its easier to let people specify
141  // these more than once than to detect the problem and do something about it.
142  bool excludeFromFileList = false;
143  bool includeReplicatedOps = false;
144  while (true) {
145  if (p.parseOptionalComma())
146  break;
147  if (!p.parseOptionalKeyword("excludeFromFileList"))
148  excludeFromFileList = true;
149  else if (!p.parseKeyword("includeReplicatedOps",
150  "or 'excludeFromFileList'"))
151  includeReplicatedOps = true;
152  else
153  return Attribute();
154  }
155 
156  if (p.parseGreater())
157  return Attribute();
158 
159  return OutputFileAttr::getFromFilename(p.getContext(), filename.getValue(),
160  excludeFromFileList,
161  includeReplicatedOps);
162 }
163 
164 void OutputFileAttr::print(AsmPrinter &p) const {
165  p << "<" << getFilename();
166  if (getExcludeFromFilelist().getValue())
167  p << ", excludeFromFileList";
168  if (getIncludeReplicatedOps().getValue())
169  p << ", includeReplicatedOps";
170  p << ">";
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // FileListAttr
175 //===----------------------------------------------------------------------===//
176 
177 FileListAttr FileListAttr::getFromFilename(MLIRContext *context,
178  const Twine &filename) {
179  auto canonicalized = canonicalizeFilename("", filename);
180  return FileListAttr::get(StringAttr::get(context, canonicalized));
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // EnumFieldAttr
185 //===----------------------------------------------------------------------===//
186 
187 Attribute EnumFieldAttr::parse(AsmParser &p, Type) {
188  StringRef field;
189  Type type;
190  if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() ||
191  p.parseType(type) || p.parseGreater())
192  return Attribute();
193  return EnumFieldAttr::get(p.getEncodedSourceLoc(p.getCurrentLocation()),
194  StringAttr::get(p.getContext(), field), type);
195 }
196 
197 void EnumFieldAttr::print(AsmPrinter &p) const {
198  p << "<" << getField().getValue() << ", ";
199  p.printType(getType().getValue());
200  p << ">";
201 }
202 
203 EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value,
204  mlir::Type type) {
205  if (!hw::isHWEnumType(type))
206  emitError(loc) << "expected enum type";
207 
208  // Check whether the provided value is a member of the enum type.
209  EnumType enumType = getCanonicalType(type).cast<EnumType>();
210  if (!enumType.contains(value.getValue())) {
211  emitError(loc) << "enum value '" << value.getValue()
212  << "' is not a member of enum type " << enumType;
213  return nullptr;
214  }
215 
216  return Base::get(value.getContext(), value, TypeAttr::get(type));
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // InnerRefAttr
221 //===----------------------------------------------------------------------===//
222 
223 Attribute InnerRefAttr::parse(AsmParser &p, Type type) {
224  SymbolRefAttr attr;
225  if (p.parseLess() || p.parseAttribute<SymbolRefAttr>(attr) ||
226  p.parseGreater())
227  return Attribute();
228  if (attr.getNestedReferences().size() != 1)
229  return Attribute();
230  return InnerRefAttr::get(attr.getRootReference(), attr.getLeafReference());
231 }
232 
233 static void printSymbolName(AsmPrinter &p, StringAttr sym) {
234  if (sym)
235  p.printSymbolName(sym.getValue());
236  else
237  p.printSymbolName({});
238 }
239 
240 void InnerRefAttr::print(AsmPrinter &p) const {
241  p << "<";
242  printSymbolName(p, getModule());
243  p << "::";
244  printSymbolName(p, getName());
245  p << ">";
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // InnerSymAttr and InnerSymPropertiesAttr
250 //===----------------------------------------------------------------------===//
251 
252 Attribute InnerSymPropertiesAttr::parse(AsmParser &parser, Type type) {
253  StringAttr name;
254  NamedAttrList dummyList;
255  int64_t fieldId = 0;
256  if (parser.parseLess() || parser.parseSymbolName(name, "name", dummyList) ||
257  parser.parseComma() || parser.parseInteger(fieldId) ||
258  parser.parseComma())
259  return Attribute();
260 
261  StringRef visibility;
262  auto loc = parser.getCurrentLocation();
263  if (parser.parseOptionalKeyword(&visibility,
264  {"public", "private", "nested"})) {
265  parser.emitError(loc, "expected 'public', 'private', or 'nested'");
266  return Attribute();
267  }
268  auto visibilityAttr = parser.getBuilder().getStringAttr(visibility);
269 
270  if (parser.parseGreater())
271  return Attribute();
272 
273  return parser.getChecked<InnerSymPropertiesAttr>(parser.getContext(), name,
274  fieldId, visibilityAttr);
275 }
276 
277 void InnerSymPropertiesAttr::print(AsmPrinter &odsPrinter) const {
278  odsPrinter << "<@" << getName().getValue() << "," << getFieldID() << ","
279  << getSymVisibility().getValue() << ">";
280 }
281 
282 LogicalResult InnerSymPropertiesAttr::verify(
283  ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
284  ::mlir::StringAttr name, uint64_t fieldID,
285  ::mlir::StringAttr symVisibility) {
286  if (!name || name.getValue().empty())
287  return emitError() << "inner symbol cannot have empty name";
288  return success();
289 }
290 
291 StringAttr InnerSymAttr::getSymIfExists(uint64_t fieldId) const {
292  const auto *it =
293  llvm::find_if(getImpl()->props, [&](const InnerSymPropertiesAttr &p) {
294  return p.getFieldID() == fieldId;
295  });
296  if (it != getProps().end())
297  return it->getName();
298  return {};
299 }
300 
301 InnerSymAttr InnerSymAttr::erase(uint64_t fieldID) const {
302  SmallVector<InnerSymPropertiesAttr> syms(getProps());
303  const auto *it = llvm::find_if(syms, [fieldID](InnerSymPropertiesAttr p) {
304  return p.getFieldID() == fieldID;
305  });
306  assert(it != syms.end());
307  syms.erase(it);
308  return InnerSymAttr::get(getContext(), syms);
309 }
310 
311 LogicalResult InnerSymAttr::walkSymbols(
312  llvm::function_ref<LogicalResult(StringAttr)> callback) const {
313  for (auto p : getImpl()->props)
314  if (callback(p.getName()).failed())
315  return failure();
316  return success();
317 }
318 
319 Attribute InnerSymAttr::parse(AsmParser &parser, Type type) {
320  StringAttr sym;
321  NamedAttrList dummyList;
322  SmallVector<InnerSymPropertiesAttr, 4> names;
323  if (!parser.parseOptionalSymbolName(sym, "dummy", dummyList)) {
324  auto prop = parser.getChecked<InnerSymPropertiesAttr>(
325  parser.getContext(), sym, 0,
326  StringAttr::get(parser.getContext(), "public"));
327  if (!prop)
328  return {};
329  names.push_back(prop);
330  } else if (parser.parseCommaSeparatedList(
331  OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
332  InnerSymPropertiesAttr prop;
333  if (parser.parseCustomAttributeWithFallback(
334  prop, mlir::Type{}, "dummy", dummyList))
335  return failure();
336 
337  names.push_back(prop);
338 
339  return success();
340  }))
341  return Attribute();
342 
343  std::sort(names.begin(), names.end(),
344  [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
345  return a.getFieldID() < b.getFieldID();
346  });
347 
348  return InnerSymAttr::get(parser.getContext(), names);
349 }
350 
351 void InnerSymAttr::print(AsmPrinter &odsPrinter) const {
352 
353  auto props = getProps();
354  if (props.size() == 1 &&
355  props[0].getSymVisibility().getValue().equals("public") &&
356  props[0].getFieldID() == 0) {
357  odsPrinter << "@" << props[0].getName().getValue();
358  return;
359  }
360  auto names = props.vec();
361 
362  std::sort(names.begin(), names.end(),
363  [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
364  return a.getFieldID() < b.getFieldID();
365  });
366  odsPrinter << "[";
367  llvm::interleaveComma(names, odsPrinter, [&](InnerSymPropertiesAttr attr) {
368  attr.print(odsPrinter);
369  });
370  odsPrinter << "]";
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // ParamDeclAttr
375 //===----------------------------------------------------------------------===//
376 
377 Attribute ParamDeclAttr::parse(AsmParser &p, Type trailing) {
378  std::string name;
379  Type type;
380  Attribute value;
381  // < "FOO" : i32 > : i32
382  // < "FOO" : i32 = 0 > : i32
383  // < "FOO" : none >
384  if (p.parseLess() || p.parseString(&name) || p.parseColonType(type))
385  return Attribute();
386 
387  if (succeeded(p.parseOptionalEqual())) {
388  if (p.parseAttribute(value, type))
389  return Attribute();
390  }
391 
392  if (p.parseGreater())
393  return Attribute();
394 
395  if (value)
396  return ParamDeclAttr::get(p.getContext(),
397  p.getBuilder().getStringAttr(name), type, value);
398  return ParamDeclAttr::get(name, type);
399 }
400 
401 void ParamDeclAttr::print(AsmPrinter &p) const {
402  p << "<" << getName() << ": " << getType();
403  if (getValue()) {
404  p << " = ";
405  p.printAttributeWithoutType(getValue());
406  }
407  p << ">";
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // ParamDeclRefAttr
412 //===----------------------------------------------------------------------===//
413 
414 Attribute ParamDeclRefAttr::parse(AsmParser &p, Type type) {
415  StringAttr name;
416  if (p.parseLess() || p.parseAttribute(name) || p.parseGreater() ||
417  (!type && (p.parseColon() || p.parseType(type))))
418  return Attribute();
419 
420  return ParamDeclRefAttr::get(name, type);
421 }
422 
423 void ParamDeclRefAttr::print(AsmPrinter &p) const {
424  p << "<" << getName() << ">";
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // ParamVerbatimAttr
429 //===----------------------------------------------------------------------===//
430 
431 Attribute ParamVerbatimAttr::parse(AsmParser &p, Type type) {
432  StringAttr text;
433  if (p.parseLess() || p.parseAttribute(text) || p.parseGreater() ||
434  (!type && (p.parseColon() || p.parseType(type))))
435  return Attribute();
436 
437  return ParamVerbatimAttr::get(p.getContext(), text, type);
438 }
439 
440 void ParamVerbatimAttr::print(AsmPrinter &p) const {
441  p << "<" << getValue() << ">";
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // ParamExprAttr
446 //===----------------------------------------------------------------------===//
447 
448 /// Given a binary function, if the two operands are known constant integers,
449 /// use the specified fold function to compute the result.
450 static TypedAttr foldBinaryOp(
451  ArrayRef<TypedAttr> operands,
452  llvm::function_ref<APInt(const APInt &, const APInt &)> calculate) {
453  assert(operands.size() == 2 && "binary operator always has two operands");
454  if (auto lhs = operands[0].dyn_cast<IntegerAttr>())
455  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
456  return IntegerAttr::get(lhs.getType(),
457  calculate(lhs.getValue(), rhs.getValue()));
458  return {};
459 }
460 
461 /// Given a unary function, if the operand is a known constant integer,
462 /// use the specified fold function to compute the result.
463 static TypedAttr
464 foldUnaryOp(ArrayRef<TypedAttr> operands,
465  llvm::function_ref<APInt(const APInt &)> calculate) {
466  assert(operands.size() == 1 && "unary operator always has one operand");
467  if (auto intAttr = operands[0].dyn_cast<IntegerAttr>())
468  return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue()));
469  return {};
470 }
471 
472 /// If the specified attribute is a ParamExprAttr with the specified opcode,
473 /// return it. Otherwise return null.
474 static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) {
475  if (auto expr = value.dyn_cast<ParamExprAttr>())
476  if (expr.getOpcode() == opcode)
477  return expr;
478  return {};
479 }
480 
481 /// This implements a < comparison for two operands to an associative operation
482 /// imposing an ordering upon them.
483 ///
484 /// The ordering provided puts more complex things to the start of the list,
485 /// from left to right:
486 /// expressions :: verbatims :: decl.refs :: constant
487 ///
488 static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) {
489  // Simplify the code below - we never have to care about exactly equal values.
490  if (lhs == rhs)
491  return false;
492 
493  // All expressions are "less than" a constant, since they appear on the right.
494  if (rhs.isa<IntegerAttr>()) {
495  // We don't bother to order constants w.r.t. each other since they will be
496  // folded - they can all compare equal.
497  return !lhs.isa<IntegerAttr>();
498  }
499  if (lhs.isa<IntegerAttr>())
500  return false;
501 
502  // Next up are named parameters.
503  if (auto rhsParam = rhs.dyn_cast<ParamDeclRefAttr>()) {
504  // Parameters are sorted lexically w.r.t. each other.
505  if (auto lhsParam = lhs.dyn_cast<ParamDeclRefAttr>())
506  return lhsParam.getName().getValue() < rhsParam.getName().getValue();
507  // They otherwise appear on the right of other things.
508  return true;
509  }
510  if (lhs.isa<ParamDeclRefAttr>())
511  return false;
512 
513  // Next up are verbatim parameters.
514  if (auto rhsParam = rhs.dyn_cast<ParamVerbatimAttr>()) {
515  // Verbatims are sorted lexically w.r.t. each other.
516  if (auto lhsParam = lhs.dyn_cast<ParamVerbatimAttr>())
517  return lhsParam.getValue().getValue() < rhsParam.getValue().getValue();
518  // They otherwise appear on the right of other things.
519  return true;
520  }
521  if (lhs.isa<ParamVerbatimAttr>())
522  return false;
523 
524  // The only thing left are nested expressions.
525  auto lhsExpr = lhs.cast<ParamExprAttr>(), rhsExpr = rhs.cast<ParamExprAttr>();
526  // Sort by the string form of the opcode, e.g. add, .. mul,... then xor.
527  if (lhsExpr.getOpcode() != rhsExpr.getOpcode())
528  return stringifyPEO(lhsExpr.getOpcode()) <
529  stringifyPEO(rhsExpr.getOpcode());
530 
531  // If they are the same opcode, then sort by arity: more complex to the left.
532  ArrayRef<TypedAttr> lhsOperands = lhsExpr.getOperands(),
533  rhsOperands = rhsExpr.getOperands();
534  if (lhsOperands.size() != rhsOperands.size())
535  return lhsOperands.size() > rhsOperands.size();
536 
537  // We know the two subexpressions are different (they'd otherwise be pointer
538  // equivalent) so just go compare all of the elements.
539  for (size_t i = 0, e = lhsOperands.size(); i != e; ++i) {
540  if (paramExprOperandSortPredicate(lhsOperands[i], rhsOperands[i]))
541  return true;
542  if (paramExprOperandSortPredicate(rhsOperands[i], lhsOperands[i]))
543  return false;
544  }
545 
546  llvm_unreachable("expressions should never be equivalent");
547  return false;
548 }
549 
550 /// Given a fully associative variadic integer operation, constant fold any
551 /// constant operands and move them to the right. If the whole expression is
552 /// constant, then return that, otherwise update the operands list.
553 static TypedAttr simplifyAssocOp(
554  PEO opcode, SmallVector<TypedAttr, 4> &operands,
555  llvm::function_ref<APInt(const APInt &, const APInt &)> calculateFn,
556  llvm::function_ref<bool(const APInt &)> identityConstantFn,
557  llvm::function_ref<bool(const APInt &)> destructiveConstantFn = {}) {
558  auto type = operands[0].getType();
559  assert(isHWIntegerType(type));
560  if (operands.size() == 1)
561  return operands[0];
562 
563  // Flatten any of the same operation into the operand list:
564  // `(add x, (add y, z))` => `(add x, y, z)`.
565  for (size_t i = 0, e = operands.size(); i != e; ++i) {
566  if (auto subexpr = dyn_castPE(opcode, operands[i])) {
567  std::swap(operands[i], operands.back());
568  operands.pop_back();
569  --e;
570  --i;
571  operands.append(subexpr.getOperands().begin(),
572  subexpr.getOperands().end());
573  }
574  }
575 
576  // Impose an ordering on the operands, pushing subexpressions to the left and
577  // constants to the right, with verbatims and parameters in the middle - but
578  // predictably ordered w.r.t. each other.
579  llvm::stable_sort(operands, paramExprOperandSortPredicate);
580 
581  // Merge any constants, they will appear at the back of the operand list now.
582  if (operands.back().isa<IntegerAttr>()) {
583  while (operands.size() >= 2 &&
584  operands[operands.size() - 2].isa<IntegerAttr>()) {
585  APInt c1 = operands.pop_back_val().cast<IntegerAttr>().getValue();
586  APInt c2 = operands.pop_back_val().cast<IntegerAttr>().getValue();
587  auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2));
588  operands.push_back(resultConstant);
589  }
590 
591  auto resultCst = operands.back().cast<IntegerAttr>();
592 
593  // If the resulting constant is the destructive constant (e.g. `x*0`), then
594  // return it.
595  if (destructiveConstantFn && destructiveConstantFn(resultCst.getValue()))
596  return resultCst;
597 
598  // Remove the constant back to our operand list if it is the identity
599  // constant for this operator (e.g. `x*1`) and there are other operands.
600  if (identityConstantFn(resultCst.getValue()) && operands.size() != 1)
601  operands.pop_back();
602  }
603 
604  return operands.size() == 1 ? operands[0] : TypedAttr();
605 }
606 
607 /// Analyze an operand to an add. If it is a multiplication by a constant (e.g.
608 /// `(a*b*42)` then split it into the non-constant and the constant portions
609 /// (e.g. `a*b` and `42`). Otherwise return the operand as the first value and
610 /// null as the second (standin for "multiplication by 1").
611 static std::pair<TypedAttr, TypedAttr> decomposeAddend(TypedAttr operand) {
612  if (auto mul = dyn_castPE(PEO::Mul, operand))
613  if (auto cst = mul.getOperands().back().dyn_cast<IntegerAttr>()) {
614  auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back());
615  return {nonCst, cst};
616  }
617  return {operand, TypedAttr()};
618 }
619 
620 static TypedAttr getOneOfType(Type type) {
621  return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), 1));
622 }
623 
624 static TypedAttr simplifyAdd(SmallVector<TypedAttr, 4> &operands) {
625  if (auto result = simplifyAssocOp(
626  PEO::Add, operands, [](auto a, auto b) { return a + b; },
627  /*identityCst*/ [](auto cst) { return cst.isZero(); }))
628  return result;
629 
630  // Canonicalize the add by splitting all addends into their variable and
631  // constant factors.
632  SmallVector<std::pair<TypedAttr, TypedAttr>> decomposedOperands;
633  llvm::SmallDenseSet<TypedAttr> nonConstantParts;
634  for (auto &op : operands) {
635  decomposedOperands.push_back(decomposeAddend(op));
636 
637  // Keep track of non-constant parts we've already seen. If we see multiple
638  // uses of the same value, then we can fold them together with a multiply.
639  // This handles things like `(a+b+a)` => `(a*2 + b)` and `(a*2 + b + a)` =>
640  // `(a*3 + b)`.
641  if (!nonConstantParts.insert(decomposedOperands.back().first).second) {
642  // The thing we multiply will be the common expression.
643  TypedAttr mulOperand = decomposedOperands.back().first;
644 
645  // Find the index of the first occurrence.
646  size_t i = 0;
647  while (decomposedOperands[i].first != mulOperand)
648  ++i;
649  // Remove both occurrences from the operand list.
650  operands.erase(operands.begin() + (&op - &operands[0]));
651  operands.erase(operands.begin() + i);
652 
653  auto type = mulOperand.getType();
654  auto c1 = decomposedOperands[i].second,
655  c2 = decomposedOperands.back().second;
656  // Fill in missing constant multiplicands with 1.
657  if (!c1)
658  c1 = getOneOfType(type);
659  if (!c2)
660  c2 = getOneOfType(type);
661  // Re-add the "a"*(c1+c2) expression to the operand list and
662  // re-canonicalize.
663  auto constant = ParamExprAttr::get(PEO::Add, c1, c2);
664  auto mulCst = ParamExprAttr::get(PEO::Mul, mulOperand, constant);
665  operands.push_back(mulCst);
666  return ParamExprAttr::get(PEO::Add, operands);
667  }
668  }
669 
670  return {};
671 }
672 
673 static TypedAttr simplifyMul(SmallVector<TypedAttr, 4> &operands) {
674  if (auto result = simplifyAssocOp(
675  PEO::Mul, operands, [](auto a, auto b) { return a * b; },
676  /*identityCst*/ [](auto cst) { return cst.isOne(); },
677  /*destructiveCst*/ [](auto cst) { return cst.isZero(); }))
678  return result;
679 
680  // We always build a sum-of-products representation, so if we see an addition
681  // as a subexpr, we need to pull it out: (a+b)*c*d ==> (a*c*d + b*c*d).
682  for (size_t i = 0, e = operands.size(); i != e; ++i) {
683  if (auto subexpr = dyn_castPE(PEO::Add, operands[i])) {
684  // Pull the `c*d` operands out - it is whatever operands remain after
685  // removing the `(a+b)` term.
686  SmallVector<TypedAttr> mulOperands(operands.begin(), operands.end());
687  mulOperands.erase(mulOperands.begin() + i);
688 
689  // Build each add operand.
690  SmallVector<TypedAttr> addOperands;
691  for (auto addOperand : subexpr.getOperands()) {
692  mulOperands.push_back(addOperand);
693  addOperands.push_back(ParamExprAttr::get(PEO::Mul, mulOperands));
694  mulOperands.pop_back();
695  }
696  // Canonicalize and form the add expression.
697  return ParamExprAttr::get(PEO::Add, addOperands);
698  }
699  }
700 
701  return {};
702 }
703 static TypedAttr simplifyAnd(SmallVector<TypedAttr, 4> &operands) {
704  return simplifyAssocOp(
705  PEO::And, operands, [](auto a, auto b) { return a & b; },
706  /*identityCst*/ [](auto cst) { return cst.isAllOnes(); },
707  /*destructiveCst*/ [](auto cst) { return cst.isZero(); });
708 }
709 
710 static TypedAttr simplifyOr(SmallVector<TypedAttr, 4> &operands) {
711  return simplifyAssocOp(
712  PEO::Or, operands, [](auto a, auto b) { return a | b; },
713  /*identityCst*/ [](auto cst) { return cst.isZero(); },
714  /*destructiveCst*/ [](auto cst) { return cst.isAllOnes(); });
715 }
716 
717 static TypedAttr simplifyXor(SmallVector<TypedAttr, 4> &operands) {
718  return simplifyAssocOp(
719  PEO::Xor, operands, [](auto a, auto b) { return a ^ b; },
720  /*identityCst*/ [](auto cst) { return cst.isZero(); });
721 }
722 
723 static TypedAttr simplifyShl(SmallVector<TypedAttr, 4> &operands) {
724  assert(isHWIntegerType(operands[0].getType()));
725 
726  if (auto rhs = operands[1].dyn_cast<IntegerAttr>()) {
727  // Constant fold simple integers.
728  if (auto lhs = operands[0].dyn_cast<IntegerAttr>())
729  return IntegerAttr::get(lhs.getType(),
730  lhs.getValue().shl(rhs.getValue()));
731 
732  // Canonicalize `x << cst` => `x * (1<<cst)` to compose correctly with
733  // add/mul canonicalization.
734  auto rhsCst = APInt::getOneBitSet(rhs.getValue().getBitWidth(),
735  rhs.getValue().getZExtValue());
736  return ParamExprAttr::get(PEO::Mul, operands[0],
737  IntegerAttr::get(rhs.getType(), rhsCst));
738  }
739  return {};
740 }
741 
742 static TypedAttr simplifyShrU(SmallVector<TypedAttr, 4> &operands) {
743  assert(isHWIntegerType(operands[0].getType()));
744  // Implement support for identities like `x >> 0`.
745  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
746  if (rhs.getValue().isZero())
747  return operands[0];
748 
749  return foldBinaryOp(operands, [](auto a, auto b) { return a.lshr(b); });
750 }
751 
752 static TypedAttr simplifyShrS(SmallVector<TypedAttr, 4> &operands) {
753  assert(isHWIntegerType(operands[0].getType()));
754  // Implement support for identities like `x >> 0`.
755  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
756  if (rhs.getValue().isZero())
757  return operands[0];
758 
759  return foldBinaryOp(operands, [](auto a, auto b) { return a.ashr(b); });
760 }
761 
762 static TypedAttr simplifyDivU(SmallVector<TypedAttr, 4> &operands) {
763  assert(isHWIntegerType(operands[0].getType()));
764  // Implement support for identities like `x/1`.
765  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
766  if (rhs.getValue().isOne())
767  return operands[0];
768 
769  return foldBinaryOp(operands, [](auto a, auto b) { return a.udiv(b); });
770 }
771 
772 static TypedAttr simplifyDivS(SmallVector<TypedAttr, 4> &operands) {
773  assert(isHWIntegerType(operands[0].getType()));
774  // Implement support for identities like `x/1`.
775  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
776  if (rhs.getValue().isOne())
777  return operands[0];
778 
779  return foldBinaryOp(operands, [](auto a, auto b) { return a.sdiv(b); });
780 }
781 
782 static TypedAttr simplifyModU(SmallVector<TypedAttr, 4> &operands) {
783  assert(isHWIntegerType(operands[0].getType()));
784  // Implement support for identities like `x%1`.
785  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
786  if (rhs.getValue().isOne())
787  return IntegerAttr::get(rhs.getType(), 0);
788 
789  return foldBinaryOp(operands, [](auto a, auto b) { return a.urem(b); });
790 }
791 
792 static TypedAttr simplifyModS(SmallVector<TypedAttr, 4> &operands) {
793  assert(isHWIntegerType(operands[0].getType()));
794  // Implement support for identities like `x%1`.
795  if (auto rhs = operands[1].dyn_cast<IntegerAttr>())
796  if (rhs.getValue().isOne())
797  return IntegerAttr::get(rhs.getType(), 0);
798 
799  return foldBinaryOp(operands, [](auto a, auto b) { return a.srem(b); });
800 }
801 
802 static TypedAttr simplifyCLog2(SmallVector<TypedAttr, 4> &operands) {
803  assert(isHWIntegerType(operands[0].getType()));
804  return foldUnaryOp(operands, [](auto a) {
805  // Following the Verilog spec, clog2(0) is 0
806  return APInt(a.getBitWidth(), a == 0 ? 0 : a.ceilLogBase2());
807  });
808 }
809 
810 static TypedAttr simplifyStrConcat(SmallVector<TypedAttr, 4> &operands) {
811  // Combine all adjacent strings.
812  SmallVector<TypedAttr> newOperands;
813  SmallVector<StringAttr> stringsToCombine;
814  auto combineAndPush = [&]() {
815  if (stringsToCombine.empty())
816  return;
817  // Concatenate buffered strings, push to ops.
818  SmallString<32> newString;
819  for (auto part : stringsToCombine)
820  newString.append(part.getValue());
821  newOperands.push_back(
822  StringAttr::get(stringsToCombine[0].getContext(), newString));
823  stringsToCombine.clear();
824  };
825 
826  for (TypedAttr op : operands) {
827  if (auto strOp = op.dyn_cast<StringAttr>()) {
828  // Queue up adjacent strings.
829  stringsToCombine.push_back(strOp);
830  } else {
831  combineAndPush();
832  newOperands.push_back(op);
833  }
834  }
835  combineAndPush();
836 
837  assert(!newOperands.empty());
838  if (newOperands.size() == 1)
839  return newOperands[0];
840  if (newOperands.size() < operands.size())
841  return ParamExprAttr::get(PEO::StrConcat, newOperands);
842  return {};
843 }
844 
845 /// Build a parameter expression. This automatically canonicalizes and
846 /// folds, so it may not necessarily return a ParamExprAttr.
847 TypedAttr ParamExprAttr::get(PEO opcode, ArrayRef<TypedAttr> operandsIn) {
848  assert(!operandsIn.empty() && "Cannot have expr with no operands");
849  // All operands must have the same type, which is the type of the result.
850  auto type = operandsIn.front().getType();
851  assert(llvm::all_of(operandsIn.drop_front(),
852  [&](auto op) { return op.getType() == type; }));
853 
854  SmallVector<TypedAttr, 4> operands(operandsIn.begin(), operandsIn.end());
855 
856  // Verify and canonicalize parameter expressions.
857  TypedAttr result;
858  switch (opcode) {
859  case PEO::Add:
860  result = simplifyAdd(operands);
861  break;
862  case PEO::Mul:
863  result = simplifyMul(operands);
864  break;
865  case PEO::And:
866  result = simplifyAnd(operands);
867  break;
868  case PEO::Or:
869  result = simplifyOr(operands);
870  break;
871  case PEO::Xor:
872  result = simplifyXor(operands);
873  break;
874  case PEO::Shl:
875  result = simplifyShl(operands);
876  break;
877  case PEO::ShrU:
878  result = simplifyShrU(operands);
879  break;
880  case PEO::ShrS:
881  result = simplifyShrS(operands);
882  break;
883  case PEO::DivU:
884  result = simplifyDivU(operands);
885  break;
886  case PEO::DivS:
887  result = simplifyDivS(operands);
888  break;
889  case PEO::ModU:
890  result = simplifyModU(operands);
891  break;
892  case PEO::ModS:
893  result = simplifyModS(operands);
894  break;
895  case PEO::CLog2:
896  result = simplifyCLog2(operands);
897  break;
898  case PEO::StrConcat:
899  result = simplifyStrConcat(operands);
900  break;
901  }
902 
903  // If we folded to an operand, return it.
904  if (result)
905  return result;
906 
907  return Base::get(operands[0].getContext(), opcode, operands, type);
908 }
909 
910 Attribute ParamExprAttr::parse(AsmParser &p, Type type) {
911  // We require an opcode suffix like `#hw.param.expr.add`, we don't allow
912  // parsing a plain `#hw.param.expr` on its own.
913  p.emitError(p.getNameLoc(), "#hw.param.expr should have opcode suffix");
914  return {};
915 }
916 
917 /// Internal method used for .mlir file parsing when parsing the
918 /// "#hw.param.expr.mul" form of the attribute.
919 static Attribute parseParamExprWithOpcode(StringRef opcodeStr,
920  DialectAsmParser &p, Type type) {
921  SmallVector<TypedAttr> operands;
922  if (p.parseCommaSeparatedList(
923  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
924  operands.push_back({});
925  return p.parseAttribute(operands.back(), type);
926  }))
927  return {};
928 
929  std::optional<PEO> opcode = symbolizePEO(opcodeStr);
930  if (!opcode.has_value()) {
931  p.emitError(p.getNameLoc(), "unknown parameter expr operator name");
932  return {};
933  }
934 
935  return ParamExprAttr::get(*opcode, operands);
936 }
937 
938 void ParamExprAttr::print(AsmPrinter &p) const {
939  p << "." << stringifyPEO(getOpcode()) << '<';
940  llvm::interleaveComma(getOperands(), p.getStream(),
941  [&](Attribute op) { p.printAttributeWithoutType(op); });
942  p << '>';
943 }
944 
945 // Replaces any ParamDeclRefAttr within a parametric expression with its
946 // corresponding value from the map of provided parameters.
947 static FailureOr<Attribute>
948 replaceDeclRefInExpr(Location loc,
949  const std::map<std::string, Attribute> &parameters,
950  Attribute paramAttr, bool emitErrors) {
951  if (paramAttr.dyn_cast<IntegerAttr>()) {
952  // Nothing to do, constant value.
953  return paramAttr;
954  }
955  if (auto paramRefAttr = paramAttr.dyn_cast<hw::ParamDeclRefAttr>()) {
956  // Get the value from the provided parameters.
957  auto it = parameters.find(paramRefAttr.getName().str());
958  if (it == parameters.end()) {
959  if (emitErrors)
960  return emitError(loc)
961  << "Could not find parameter " << paramRefAttr.getName().str()
962  << " in the provided parameters for the expression!";
963  return failure();
964  }
965  return it->second;
966  }
967  if (auto paramExprAttr = paramAttr.dyn_cast<hw::ParamExprAttr>()) {
968  // Recurse into all operands of the expression.
969  llvm::SmallVector<TypedAttr, 4> replacedOperands;
970  for (auto operand : paramExprAttr.getOperands()) {
971  auto res = replaceDeclRefInExpr(loc, parameters, operand, emitErrors);
972  if (failed(res))
973  return {failure()};
974  replacedOperands.push_back(res->cast<TypedAttr>());
975  }
976  return {
977  hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)};
978  }
979  llvm_unreachable("Unhandled parametric attribute");
980  return {};
981 }
982 
983 FailureOr<TypedAttr> hw::evaluateParametricAttr(Location loc,
984  ArrayAttr parameters,
985  Attribute paramAttr,
986  bool emitErrors) {
987  // Create a map of the provided parameters for faster lookup.
988  std::map<std::string, Attribute> parameterMap;
989  for (auto param : parameters) {
990  auto paramDecl = param.cast<ParamDeclAttr>();
991  parameterMap[paramDecl.getName().str()] = paramDecl.getValue();
992  }
993 
994  // First, replace any ParamDeclRefAttr in the expression with its
995  // corresponding value in 'parameters'.
996  auto paramAttrRes =
997  replaceDeclRefInExpr(loc, parameterMap, paramAttr, emitErrors);
998  if (failed(paramAttrRes))
999  return {failure()};
1000  paramAttr = *paramAttrRes;
1001 
1002  // Then, evaluate the parametric attribute.
1003  if (paramAttr.isa<IntegerAttr, hw::ParamDeclRefAttr>())
1004  return paramAttr.cast<TypedAttr>();
1005  if (auto paramExprAttr = paramAttr.dyn_cast<hw::ParamExprAttr>()) {
1006  // Since any ParamDeclRefAttr was replaced within the expression,
1007  // we re-evaluate the expression through the existing ParamExprAttr
1008  // canonicalizer.
1009  return ParamExprAttr::get(paramExprAttr.getOpcode(),
1010  paramExprAttr.getOperands());
1011  }
1012 
1013  llvm_unreachable("Unhandled parametric attribute");
1014  return TypedAttr();
1015 }
1016 
1017 template <typename TArray>
1018 FailureOr<Type> evaluateParametricArrayType(Location loc, ArrayAttr parameters,
1019  TArray arrayType, bool emitErrors) {
1020  auto size = evaluateParametricAttr(loc, parameters, arrayType.getSizeAttr(),
1021  emitErrors);
1022  if (failed(size))
1023  return failure();
1024  auto elementType = evaluateParametricType(
1025  loc, parameters, arrayType.getElementType(), emitErrors);
1026  if (failed(elementType))
1027  return failure();
1028 
1029  // If the size was evaluated to a constant, use a 64-bit integer
1030  // attribute version of it
1031  if (auto intAttr = size->template dyn_cast<IntegerAttr>())
1032  return TArray::get(
1033  arrayType.getContext(), *elementType,
1034  IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64),
1035  intAttr.getValue().getSExtValue()));
1036 
1037  // Otherwise parameter references are still involved
1038  return TArray::get(arrayType.getContext(), *elementType, *size);
1039 }
1040 
1041 FailureOr<Type> hw::evaluateParametricType(Location loc, ArrayAttr parameters,
1042  Type type, bool emitErrors) {
1043  return llvm::TypeSwitch<Type, FailureOr<Type>>(type)
1044  .Case<hw::IntType>([&](hw::IntType t) -> FailureOr<Type> {
1045  auto evaluatedWidth =
1046  evaluateParametricAttr(loc, parameters, t.getWidth(), emitErrors);
1047  if (failed(evaluatedWidth))
1048  return {failure()};
1049 
1050  // If the width was evaluated to a constant, return an `IntegerType`
1051  if (auto intAttr = evaluatedWidth->dyn_cast<IntegerAttr>())
1052  return {IntegerType::get(type.getContext(),
1053  intAttr.getValue().getSExtValue())};
1054 
1055  // Otherwise parameter references are still involved
1056  return hw::IntType::get(evaluatedWidth->cast<TypedAttr>());
1057  })
1058  .Case<hw::ArrayType, hw::UnpackedArrayType>(
1059  [&](auto arrayType) -> FailureOr<Type> {
1060  return evaluateParametricArrayType(loc, parameters, arrayType,
1061  emitErrors);
1062  })
1063  .Default([&](auto) { return type; });
1064 }
1065 
1066 // Returns true if any part of this parametric attribute contains a reference
1067 // to a parameter declaration.
1068 static bool isParamAttrWithParamRef(Attribute expr) {
1069  return llvm::TypeSwitch<Attribute, bool>(expr)
1070  .Case([](ParamExprAttr attr) {
1071  return llvm::any_of(attr.getOperands(), isParamAttrWithParamRef);
1072  })
1073  .Case([](ParamDeclRefAttr) { return true; })
1074  .Default([](auto) { return false; });
1075 }
1076 
1077 bool hw::isParametricType(mlir::Type t) {
1078  return llvm::TypeSwitch<Type, bool>(t)
1079  .Case<hw::IntType>(
1080  [&](hw::IntType t) { return isParamAttrWithParamRef(t.getWidth()); })
1081  .Case<hw::ArrayType, hw::UnpackedArrayType>([&](auto arrayType) {
1082  return isParametricType(arrayType.getElementType()) ||
1083  isParamAttrWithParamRef(arrayType.getSizeAttr());
1084  })
1085  .Default([](auto) { return false; });
1086 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static TypedAttr foldBinaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculate)
Given a binary function, if the two operands are known constant integers, use the specified fold func...
static TypedAttr simplifyDivS(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyAssocOp(PEO opcode, SmallVector< TypedAttr, 4 > &operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculateFn, llvm::function_ref< bool(const APInt &)> identityConstantFn, llvm::function_ref< bool(const APInt &)> destructiveConstantFn={})
Given a fully associative variadic integer operation, constant fold any constant operands and move th...
static TypedAttr simplifyDivU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyCLog2(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyModU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyOr(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyShl(SmallVector< TypedAttr, 4 > &operands)
static void printSymbolName(AsmPrinter &p, StringAttr sym)
static TypedAttr simplifyShrS(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr foldUnaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &)> calculate)
Given a unary function, if the operand is a known constant integer, use the specified fold function t...
static TypedAttr simplifyMul(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr getOneOfType(Type type)
static TypedAttr simplifyAnd(SmallVector< TypedAttr, 4 > &operands)
static std::pair< TypedAttr, TypedAttr > decomposeAddend(TypedAttr operand)
Analyze an operand to an add.
static TypedAttr simplifyModS(SmallVector< TypedAttr, 4 > &operands)
static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p, Type type)
Internal method used for .mlir file parsing when parsing the "#hw.param.expr.mul" form of the attribu...
static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs)
This implements a < comparison for two operands to an associative operation imposing an ordering upon...
static std::string canonicalizeFilename(const Twine &directory, const Twine &filename)
static TypedAttr simplifyShrU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyAdd(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyStrConcat(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyXor(SmallVector< TypedAttr, 4 > &operands)
static ParamExprAttr dyn_castPE(PEO opcode, Attribute value)
If the specified attribute is a ParamExprAttr with the specified opcode, return it.
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
uint64_t getFieldID(Type type, uint64_t index)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition: HWTypes.cpp:52
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition: HWTypes.cpp:65
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21