Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWAttributes.cpp
Go to the documentation of this file.
1//===- HWAttributes.cpp - Implement HW attributes -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "circt/Support/LLVM.h"
14#include "mlir/IR/Diagnostics.h"
15#include "mlir/IR/DialectImplementation.h"
16#include "llvm/ADT/SmallString.h"
17#include "llvm/ADT/StringExtras.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/FileSystem.h"
20#include "llvm/Support/Path.h"
21
22using namespace circt;
23using namespace circt::hw;
24using mlir::TypedAttr;
25
26// Internal method used for .mlir file parsing, defined below.
27static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p,
28 Type type);
29
30//===----------------------------------------------------------------------===//
31// ODS Boilerplate
32//===----------------------------------------------------------------------===//
33
34#define GET_ATTRDEF_CLASSES
35#include "circt/Dialect/HW/HWAttributes.cpp.inc"
36
/// Register every attribute class generated by ODS for the HW dialect. The
/// list of attribute classes is spliced into the template argument list from
/// the generated HWAttributes.cpp.inc when GET_ATTRDEF_LIST is defined.
void HWDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "circt/Dialect/HW/HWAttributes.cpp.inc"
      >();
}
43
44Attribute HWDialect::parseAttribute(DialectAsmParser &p, Type type) const {
45 StringRef attrName;
46 Attribute attr;
47 auto parseResult = generatedAttributeParser(p, &attrName, type, attr);
48 if (parseResult.has_value())
49 return attr;
50
51 // Parse "#hw.param.expr.add" as ParamExprAttr.
52 if (attrName.starts_with(ParamExprAttr::getMnemonic())) {
53 auto string = attrName.drop_front(ParamExprAttr::getMnemonic().size());
54 if (string.front() == '.')
55 return parseParamExprWithOpcode(string.drop_front(), p, type);
56 }
57
58 p.emitError(p.getNameLoc(), "Unexpected hw attribute '" + attrName + "'");
59 return {};
60}
61
62void HWDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
63 if (succeeded(generatedAttributePrinter(attr, p)))
64 return;
65 llvm_unreachable("Unexpected attribute");
66}
67
68//===----------------------------------------------------------------------===//
69// OutputFileAttr
70//===----------------------------------------------------------------------===//
71
72static std::string canonicalizeFilename(const Twine &directory,
73 const Twine &filename) {
74
75 // Convert the filename to a native style path.
76 SmallString<128> nativeFilename;
77 llvm::sys::path::native(filename, nativeFilename);
78
79 // If the filename is an absolute path, ignore the directory.
80 // e.g. `directory/` + `/etc/filename` -> `/etc/filename`.
81 if (llvm::sys::path::is_absolute(nativeFilename))
82 return std::string(nativeFilename);
83
84 // Convert the directory to a native style path.
85 SmallString<128> nativeDirectory;
86 llvm::sys::path::native(directory, nativeDirectory);
87
88 // If the filename component is empty, then ensure that the path ends in a
89 // separator and return it.
90 // e.g. `directory` + `` -> `directory/`.
91 auto separator = llvm::sys::path::get_separator();
92 if (nativeFilename.empty() && !nativeDirectory.ends_with(separator)) {
93 nativeDirectory += separator;
94 return std::string(nativeDirectory);
95 }
96
97 // Append the directory and filename together.
98 // e.g. `/tmp/` + `out/filename` -> `/tmp/out/filename`.
99 SmallString<128> fullPath;
100 llvm::sys::path::append(fullPath, nativeDirectory, nativeFilename);
101 return std::string(fullPath);
102}
103
104OutputFileAttr OutputFileAttr::getFromFilename(MLIRContext *context,
105 const Twine &filename,
106 bool excludeFromFileList,
107 bool includeReplicatedOps) {
108 return OutputFileAttr::getFromDirectoryAndFilename(
109 context, "", filename, excludeFromFileList, includeReplicatedOps);
110}
111
112OutputFileAttr OutputFileAttr::getFromDirectoryAndFilename(
113 MLIRContext *context, const Twine &directory, const Twine &filename,
114 bool excludeFromFileList, bool includeReplicatedOps) {
115 auto canonicalized = canonicalizeFilename(directory, filename);
116 return OutputFileAttr::get(StringAttr::get(context, canonicalized),
117 BoolAttr::get(context, excludeFromFileList),
118 BoolAttr::get(context, includeReplicatedOps));
119}
120
121OutputFileAttr OutputFileAttr::getAsDirectory(MLIRContext *context,
122 const Twine &directory,
123 bool excludeFromFileList,
124 bool includeReplicatedOps) {
125 return getFromDirectoryAndFilename(context, directory, "",
126 excludeFromFileList, includeReplicatedOps);
127}
128
129bool OutputFileAttr::isDirectory() {
130 return getFilename().getValue().ends_with(llvm::sys::path::get_separator());
131}
132
133StringRef OutputFileAttr::getDirectory() {
134 auto dir = getFilename().getValue();
135 for (unsigned i = 0, e = dir.size(); i < e; ++i) {
136 if (dir.ends_with(llvm::sys::path::get_separator()))
137 break;
138 dir = dir.drop_back();
139 }
140 return dir;
141}
142
143/// Option ::= 'excludeFromFileList' | 'includeReplicatedOp'
144/// OutputFileAttr ::= 'output_file<' directory ',' name (',' Option)* '>'
145Attribute OutputFileAttr::parse(AsmParser &p, Type type) {
146 StringAttr filename;
147 if (p.parseLess() || p.parseAttribute<StringAttr>(filename))
148 return Attribute();
149
150 // Parse the additional keyword attributes. Its easier to let people specify
151 // these more than once than to detect the problem and do something about it.
152 bool excludeFromFileList = false;
153 bool includeReplicatedOps = false;
154 while (true) {
155 if (p.parseOptionalComma())
156 break;
157 if (!p.parseOptionalKeyword("excludeFromFileList"))
158 excludeFromFileList = true;
159 else if (!p.parseKeyword("includeReplicatedOps",
160 "or 'excludeFromFileList'"))
161 includeReplicatedOps = true;
162 else
163 return Attribute();
164 }
165
166 if (p.parseGreater())
167 return Attribute();
168
169 return OutputFileAttr::getFromFilename(p.getContext(), filename.getValue(),
170 excludeFromFileList,
171 includeReplicatedOps);
172}
173
174void OutputFileAttr::print(AsmPrinter &p) const {
175 p << "<" << getFilename();
176 if (getExcludeFromFilelist().getValue())
177 p << ", excludeFromFileList";
178 if (getIncludeReplicatedOps().getValue())
179 p << ", includeReplicatedOps";
180 p << ">";
181}
182
183//===----------------------------------------------------------------------===//
184// EnumFieldAttr
185//===----------------------------------------------------------------------===//
186
187Attribute EnumFieldAttr::parse(AsmParser &p, Type) {
188 StringRef field;
189 Type type;
190 if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() ||
191 p.parseType(type) || p.parseGreater())
192 return Attribute();
193 return EnumFieldAttr::get(p.getEncodedSourceLoc(p.getCurrentLocation()),
194 StringAttr::get(p.getContext(), field), type);
195}
196
197void EnumFieldAttr::print(AsmPrinter &p) const {
198 p << "<" << getField().getValue() << ", ";
199 p.printType(getType().getValue());
200 p << ">";
201}
202
203EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value,
204 mlir::Type type) {
205 if (!hw::isHWEnumType(type))
206 emitError(loc) << "expected enum type";
207
208 // Check whether the provided value is a member of the enum type.
209 EnumType enumType = llvm::cast<EnumType>(getCanonicalType(type));
210 if (!enumType.contains(value.getValue())) {
211 emitError(loc) << "enum value '" << value.getValue()
212 << "' is not a member of enum type " << enumType;
213 return nullptr;
214 }
215
216 return Base::get(value.getContext(), value, TypeAttr::get(type));
217}
218
219//===----------------------------------------------------------------------===//
220// InnerRefAttr
221//===----------------------------------------------------------------------===//
222
223Attribute InnerRefAttr::parse(AsmParser &p, Type type) {
224 SymbolRefAttr attr;
225 if (p.parseLess() || p.parseAttribute<SymbolRefAttr>(attr) ||
226 p.parseGreater())
227 return Attribute();
228 if (attr.getNestedReferences().size() != 1) {
229 p.emitError(p.getNameLoc(),
230 "inner reference must have exactly one nested reference");
231 return Attribute();
232 }
233 return InnerRefAttr::get(attr.getRootReference(), attr.getLeafReference());
234}
235
236static void printSymbolName(AsmPrinter &p, StringAttr sym) {
237 if (sym)
238 p.printSymbolName(sym.getValue());
239 else
240 p.printSymbolName({});
241}
242
243void InnerRefAttr::print(AsmPrinter &p) const {
244 p << "<";
245 printSymbolName(p, getModule());
246 p << "::";
248 p << ">";
249}
250
251//===----------------------------------------------------------------------===//
252// InnerSymAttr and InnerSymPropertiesAttr
253//===----------------------------------------------------------------------===//
254
255Attribute InnerSymPropertiesAttr::parse(AsmParser &parser, Type type) {
256 StringAttr name;
257 NamedAttrList dummyList;
258 int64_t fieldId = 0;
259 if (parser.parseLess() || parser.parseSymbolName(name, "name", dummyList) ||
260 parser.parseComma() || parser.parseInteger(fieldId) ||
261 parser.parseComma())
262 return Attribute();
263
264 StringRef visibility;
265 auto loc = parser.getCurrentLocation();
266 if (parser.parseOptionalKeyword(&visibility,
267 {"public", "private", "nested"})) {
268 parser.emitError(loc, "expected 'public', 'private', or 'nested'");
269 return Attribute();
270 }
271 auto visibilityAttr = parser.getBuilder().getStringAttr(visibility);
272
273 if (parser.parseGreater())
274 return Attribute();
275
276 return parser.getChecked<InnerSymPropertiesAttr>(parser.getContext(), name,
277 fieldId, visibilityAttr);
278}
279
280void InnerSymPropertiesAttr::print(AsmPrinter &odsPrinter) const {
281 odsPrinter << "<@" << getName().getValue() << "," << getFieldID() << ","
282 << getSymVisibility().getValue() << ">";
283}
284
285LogicalResult InnerSymPropertiesAttr::verify(
286 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
287 ::mlir::StringAttr name, uint64_t fieldID,
288 ::mlir::StringAttr symVisibility) {
289 if (!name || name.getValue().empty())
290 return emitError() << "inner symbol cannot have empty name";
291 return success();
292}
293
294StringAttr InnerSymAttr::getSymIfExists(uint64_t fieldId) const {
295 const auto *it =
296 llvm::find_if(getImpl()->props, [&](const InnerSymPropertiesAttr &p) {
297 return p.getFieldID() == fieldId;
298 });
299 if (it != getProps().end())
300 return it->getName();
301 return {};
302}
303
304InnerSymAttr InnerSymAttr::erase(uint64_t fieldID) const {
305 SmallVector<InnerSymPropertiesAttr> syms(getProps());
306 const auto *it = llvm::find_if(syms, [fieldID](InnerSymPropertiesAttr p) {
307 return p.getFieldID() == fieldID;
308 });
309 assert(it != syms.end());
310 syms.erase(it);
311 return InnerSymAttr::get(getContext(), syms);
312}
313
314LogicalResult InnerSymAttr::walkSymbols(
315 llvm::function_ref<LogicalResult(StringAttr)> callback) const {
316 for (auto p : getImpl()->props)
317 if (callback(p.getName()).failed())
318 return failure();
319 return success();
320}
321
322Attribute InnerSymAttr::parse(AsmParser &parser, Type type) {
323 StringAttr sym;
324 NamedAttrList dummyList;
325 SmallVector<InnerSymPropertiesAttr, 4> names;
326 if (!parser.parseOptionalSymbolName(sym, "dummy", dummyList)) {
327 auto prop = parser.getChecked<InnerSymPropertiesAttr>(
328 parser.getContext(), sym, 0,
329 StringAttr::get(parser.getContext(), "public"));
330 if (!prop)
331 return {};
332 names.push_back(prop);
333 } else if (parser.parseCommaSeparatedList(
334 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
335 InnerSymPropertiesAttr prop;
336 if (parser.parseCustomAttributeWithFallback(
337 prop, mlir::Type{}, "dummy", dummyList))
338 return failure();
339
340 names.push_back(prop);
341
342 return success();
343 }))
344 return Attribute();
345
346 std::sort(names.begin(), names.end(),
347 [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
348 return a.getFieldID() < b.getFieldID();
349 });
350
351 return InnerSymAttr::get(parser.getContext(), names);
352}
353
354void InnerSymAttr::print(AsmPrinter &odsPrinter) const {
355
356 auto props = getProps();
357 if (props.size() == 1 && props[0].getSymVisibility().getValue() == "public" &&
358 props[0].getFieldID() == 0) {
359 odsPrinter << "@" << props[0].getName().getValue();
360 return;
361 }
362 auto names = props.vec();
363
364 std::sort(names.begin(), names.end(),
365 [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) {
366 return a.getFieldID() < b.getFieldID();
367 });
368 odsPrinter << "[";
369 llvm::interleaveComma(names, odsPrinter, [&](InnerSymPropertiesAttr attr) {
370 attr.print(odsPrinter);
371 });
372 odsPrinter << "]";
373}
374
375//===----------------------------------------------------------------------===//
376// ParamDeclAttr
377//===----------------------------------------------------------------------===//
378
379Attribute ParamDeclAttr::parse(AsmParser &p, Type trailing) {
380 std::string name;
381 Type type;
382 Attribute value;
383 // < "FOO" : i32 > : i32
384 // < "FOO" : i32 = 0 > : i32
385 // < "FOO" : none >
386 if (p.parseLess() || p.parseString(&name) || p.parseColonType(type))
387 return Attribute();
388
389 if (succeeded(p.parseOptionalEqual())) {
390 if (p.parseAttribute(value, type))
391 return Attribute();
392 }
393
394 if (p.parseGreater())
395 return Attribute();
396
397 if (value)
398 return ParamDeclAttr::get(p.getContext(),
399 p.getBuilder().getStringAttr(name), type, value);
400 return ParamDeclAttr::get(name, type);
401}
402
403void ParamDeclAttr::print(AsmPrinter &p) const {
404 p << "<" << getName() << ": " << getType();
405 if (getValue()) {
406 p << " = ";
407 p.printAttributeWithoutType(getValue());
408 }
409 p << ">";
410}
411
412//===----------------------------------------------------------------------===//
413// ParamDeclRefAttr
414//===----------------------------------------------------------------------===//
415
416Attribute ParamDeclRefAttr::parse(AsmParser &p, Type type) {
417 StringAttr name;
418 if (p.parseLess() || p.parseAttribute(name) || p.parseGreater() ||
419 (!type && (p.parseColon() || p.parseType(type))))
420 return Attribute();
421
422 return ParamDeclRefAttr::get(name, type);
423}
424
425void ParamDeclRefAttr::print(AsmPrinter &p) const {
426 p << "<" << getName() << ">";
427}
428
429//===----------------------------------------------------------------------===//
430// ParamVerbatimAttr
431//===----------------------------------------------------------------------===//
432
433Attribute ParamVerbatimAttr::parse(AsmParser &p, Type type) {
434 StringAttr text;
435 if (p.parseLess() || p.parseAttribute(text) || p.parseGreater() ||
436 (!type && (p.parseColon() || p.parseType(type))))
437 return Attribute();
438
439 return ParamVerbatimAttr::get(p.getContext(), text, type);
440}
441
442void ParamVerbatimAttr::print(AsmPrinter &p) const {
443 p << "<" << getValue() << ">";
444}
445
446//===----------------------------------------------------------------------===//
447// ParamExprAttr
448//===----------------------------------------------------------------------===//
449
450/// Given a binary function, if the two operands are known constant integers,
451/// use the specified fold function to compute the result.
452static TypedAttr foldBinaryOp(
453 ArrayRef<TypedAttr> operands,
454 llvm::function_ref<APInt(const APInt &, const APInt &)> calculate) {
455 assert(operands.size() == 2 && "binary operator always has two operands");
456 if (auto lhs = dyn_cast<IntegerAttr>(operands[0]))
457 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
458 return IntegerAttr::get(lhs.getType(),
459 calculate(lhs.getValue(), rhs.getValue()));
460 return {};
461}
462
463/// Given a unary function, if the operand is a known constant integer,
464/// use the specified fold function to compute the result.
465static TypedAttr
466foldUnaryOp(ArrayRef<TypedAttr> operands,
467 llvm::function_ref<APInt(const APInt &)> calculate) {
468 assert(operands.size() == 1 && "unary operator always has one operand");
469 if (auto intAttr = dyn_cast<IntegerAttr>(operands[0]))
470 return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue()));
471 return {};
472}
473
474/// If the specified attribute is a ParamExprAttr with the specified opcode,
475/// return it. Otherwise return null.
476static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) {
477 if (auto expr = dyn_cast<ParamExprAttr>(value))
478 if (expr.getOpcode() == opcode)
479 return expr;
480 return {};
481}
482
483/// This implements a < comparison for two operands to an associative operation
484/// imposing an ordering upon them.
485///
486/// The ordering provided puts more complex things to the start of the list,
487/// from left to right:
488/// expressions :: verbatims :: decl.refs :: constant
489///
490static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) {
491 // Simplify the code below - we never have to care about exactly equal values.
492 if (lhs == rhs)
493 return false;
494
495 // All expressions are "less than" a constant, since they appear on the right.
496 if (isa<IntegerAttr>(rhs)) {
497 // We don't bother to order constants w.r.t. each other since they will be
498 // folded - they can all compare equal.
499 return !isa<IntegerAttr>(lhs);
500 }
501 if (isa<IntegerAttr>(lhs))
502 return false;
503
504 // Next up are named parameters.
505 if (auto rhsParam = dyn_cast<ParamDeclRefAttr>(rhs)) {
506 // Parameters are sorted lexically w.r.t. each other.
507 if (auto lhsParam = dyn_cast<ParamDeclRefAttr>(lhs))
508 return lhsParam.getName().getValue() < rhsParam.getName().getValue();
509 // They otherwise appear on the right of other things.
510 return true;
511 }
512 if (isa<ParamDeclRefAttr>(lhs))
513 return false;
514
515 // Next up are verbatim parameters.
516 if (auto rhsParam = dyn_cast<ParamVerbatimAttr>(rhs)) {
517 // Verbatims are sorted lexically w.r.t. each other.
518 if (auto lhsParam = dyn_cast<ParamVerbatimAttr>(lhs))
519 return lhsParam.getValue().getValue() < rhsParam.getValue().getValue();
520 // They otherwise appear on the right of other things.
521 return true;
522 }
523 if (isa<ParamVerbatimAttr>(lhs))
524 return false;
525
526 // The only thing left are nested expressions.
527 auto lhsExpr = cast<ParamExprAttr>(lhs), rhsExpr = cast<ParamExprAttr>(rhs);
528 // Sort by the string form of the opcode, e.g. add, .. mul,... then xor.
529 if (lhsExpr.getOpcode() != rhsExpr.getOpcode())
530 return stringifyPEO(lhsExpr.getOpcode()) <
531 stringifyPEO(rhsExpr.getOpcode());
532
533 // If they are the same opcode, then sort by arity: more complex to the left.
534 ArrayRef<TypedAttr> lhsOperands = lhsExpr.getOperands(),
535 rhsOperands = rhsExpr.getOperands();
536 if (lhsOperands.size() != rhsOperands.size())
537 return lhsOperands.size() > rhsOperands.size();
538
539 // We know the two subexpressions are different (they'd otherwise be pointer
540 // equivalent) so just go compare all of the elements.
541 for (size_t i = 0, e = lhsOperands.size(); i != e; ++i) {
542 if (paramExprOperandSortPredicate(lhsOperands[i], rhsOperands[i]))
543 return true;
544 if (paramExprOperandSortPredicate(rhsOperands[i], lhsOperands[i]))
545 return false;
546 }
547
548 llvm_unreachable("expressions should never be equivalent");
549 return false;
550}
551
552/// Given a fully associative variadic integer operation, constant fold any
553/// constant operands and move them to the right. If the whole expression is
554/// constant, then return that, otherwise update the operands list.
555static TypedAttr simplifyAssocOp(
556 PEO opcode, SmallVector<TypedAttr, 4> &operands,
557 llvm::function_ref<APInt(const APInt &, const APInt &)> calculateFn,
558 llvm::function_ref<bool(const APInt &)> identityConstantFn,
559 llvm::function_ref<bool(const APInt &)> destructiveConstantFn = {}) {
560 auto type = operands[0].getType();
561 assert(isHWIntegerType(type));
562 if (operands.size() == 1)
563 return operands[0];
564
565 // Flatten any of the same operation into the operand list:
566 // `(add x, (add y, z))` => `(add x, y, z)`.
567 for (size_t i = 0, e = operands.size(); i != e; ++i) {
568 if (auto subexpr = dyn_castPE(opcode, operands[i])) {
569 std::swap(operands[i], operands.back());
570 operands.pop_back();
571 --e;
572 --i;
573 operands.append(subexpr.getOperands().begin(),
574 subexpr.getOperands().end());
575 }
576 }
577
578 // Impose an ordering on the operands, pushing subexpressions to the left and
579 // constants to the right, with verbatims and parameters in the middle - but
580 // predictably ordered w.r.t. each other.
581 llvm::stable_sort(operands, paramExprOperandSortPredicate);
582
583 // Merge any constants, they will appear at the back of the operand list now.
584 if (isa<IntegerAttr>(operands.back())) {
585 while (operands.size() >= 2 &&
586 isa<IntegerAttr>(operands[operands.size() - 2])) {
587 APInt c1 = cast<IntegerAttr>(operands.pop_back_val()).getValue();
588 APInt c2 = cast<IntegerAttr>(operands.pop_back_val()).getValue();
589 auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2));
590 operands.push_back(resultConstant);
591 }
592
593 auto resultCst = cast<IntegerAttr>(operands.back());
594
595 // If the resulting constant is the destructive constant (e.g. `x*0`), then
596 // return it.
597 if (destructiveConstantFn && destructiveConstantFn(resultCst.getValue()))
598 return resultCst;
599
600 // Remove the constant back to our operand list if it is the identity
601 // constant for this operator (e.g. `x*1`) and there are other operands.
602 if (identityConstantFn(resultCst.getValue()) && operands.size() != 1)
603 operands.pop_back();
604 }
605
606 return operands.size() == 1 ? operands[0] : TypedAttr();
607}
608
609/// Analyze an operand to an add. If it is a multiplication by a constant (e.g.
610/// `(a*b*42)` then split it into the non-constant and the constant portions
611/// (e.g. `a*b` and `42`). Otherwise return the operand as the first value and
612/// null as the second (standin for "multiplication by 1").
613static std::pair<TypedAttr, TypedAttr> decomposeAddend(TypedAttr operand) {
614 if (auto mul = dyn_castPE(PEO::Mul, operand))
615 if (auto cst = dyn_cast<IntegerAttr>(mul.getOperands().back())) {
616 auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back());
617 return {nonCst, cst};
618 }
619 return {operand, TypedAttr()};
620}
621
622static TypedAttr getOneOfType(Type type) {
623 return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), 1));
624}
625
626static TypedAttr simplifyAdd(SmallVector<TypedAttr, 4> &operands) {
627 if (auto result = simplifyAssocOp(
628 PEO::Add, operands, [](auto a, auto b) { return a + b; },
629 /*identityCst*/ [](auto cst) { return cst.isZero(); }))
630 return result;
631
632 // Canonicalize the add by splitting all addends into their variable and
633 // constant factors.
634 SmallVector<std::pair<TypedAttr, TypedAttr>> decomposedOperands;
635 llvm::SmallDenseSet<TypedAttr> nonConstantParts;
636 for (auto &op : operands) {
637 decomposedOperands.push_back(decomposeAddend(op));
638
639 // Keep track of non-constant parts we've already seen. If we see multiple
640 // uses of the same value, then we can fold them together with a multiply.
641 // This handles things like `(a+b+a)` => `(a*2 + b)` and `(a*2 + b + a)` =>
642 // `(a*3 + b)`.
643 if (!nonConstantParts.insert(decomposedOperands.back().first).second) {
644 // The thing we multiply will be the common expression.
645 TypedAttr mulOperand = decomposedOperands.back().first;
646
647 // Find the index of the first occurrence.
648 size_t i = 0;
649 while (decomposedOperands[i].first != mulOperand)
650 ++i;
651 // Remove both occurrences from the operand list.
652 operands.erase(operands.begin() + (&op - &operands[0]));
653 operands.erase(operands.begin() + i);
654
655 auto type = mulOperand.getType();
656 auto c1 = decomposedOperands[i].second,
657 c2 = decomposedOperands.back().second;
658 // Fill in missing constant multiplicands with 1.
659 if (!c1)
660 c1 = getOneOfType(type);
661 if (!c2)
662 c2 = getOneOfType(type);
663 // Re-add the "a"*(c1+c2) expression to the operand list and
664 // re-canonicalize.
665 auto constant = ParamExprAttr::get(PEO::Add, c1, c2);
666 auto mulCst = ParamExprAttr::get(PEO::Mul, mulOperand, constant);
667 operands.push_back(mulCst);
668 return ParamExprAttr::get(PEO::Add, operands);
669 }
670 }
671
672 return {};
673}
674
675static TypedAttr simplifyMul(SmallVector<TypedAttr, 4> &operands) {
676 if (auto result = simplifyAssocOp(
677 PEO::Mul, operands, [](auto a, auto b) { return a * b; },
678 /*identityCst*/ [](auto cst) { return cst.isOne(); },
679 /*destructiveCst*/ [](auto cst) { return cst.isZero(); }))
680 return result;
681
682 // We always build a sum-of-products representation, so if we see an addition
683 // as a subexpr, we need to pull it out: (a+b)*c*d ==> (a*c*d + b*c*d).
684 for (size_t i = 0, e = operands.size(); i != e; ++i) {
685 if (auto subexpr = dyn_castPE(PEO::Add, operands[i])) {
686 // Pull the `c*d` operands out - it is whatever operands remain after
687 // removing the `(a+b)` term.
688 SmallVector<TypedAttr> mulOperands(operands.begin(), operands.end());
689 mulOperands.erase(mulOperands.begin() + i);
690
691 // Build each add operand.
692 SmallVector<TypedAttr> addOperands;
693 for (auto addOperand : subexpr.getOperands()) {
694 mulOperands.push_back(addOperand);
695 addOperands.push_back(ParamExprAttr::get(PEO::Mul, mulOperands));
696 mulOperands.pop_back();
697 }
698 // Canonicalize and form the add expression.
699 return ParamExprAttr::get(PEO::Add, addOperands);
700 }
701 }
702
703 return {};
704}
705static TypedAttr simplifyAnd(SmallVector<TypedAttr, 4> &operands) {
706 return simplifyAssocOp(
707 PEO::And, operands, [](auto a, auto b) { return a & b; },
708 /*identityCst*/ [](auto cst) { return cst.isAllOnes(); },
709 /*destructiveCst*/ [](auto cst) { return cst.isZero(); });
710}
711
712static TypedAttr simplifyOr(SmallVector<TypedAttr, 4> &operands) {
713 return simplifyAssocOp(
714 PEO::Or, operands, [](auto a, auto b) { return a | b; },
715 /*identityCst*/ [](auto cst) { return cst.isZero(); },
716 /*destructiveCst*/ [](auto cst) { return cst.isAllOnes(); });
717}
718
719static TypedAttr simplifyXor(SmallVector<TypedAttr, 4> &operands) {
720 return simplifyAssocOp(
721 PEO::Xor, operands, [](auto a, auto b) { return a ^ b; },
722 /*identityCst*/ [](auto cst) { return cst.isZero(); });
723}
724
725static TypedAttr simplifyShl(SmallVector<TypedAttr, 4> &operands) {
726 assert(isHWIntegerType(operands[0].getType()));
727
728 if (auto rhs = dyn_cast<IntegerAttr>(operands[1])) {
729 // Constant fold simple integers.
730 if (auto lhs = dyn_cast<IntegerAttr>(operands[0]))
731 return IntegerAttr::get(lhs.getType(),
732 lhs.getValue().shl(rhs.getValue()));
733
734 // Canonicalize `x << cst` => `x * (1<<cst)` to compose correctly with
735 // add/mul canonicalization.
736 auto rhsCst = APInt::getOneBitSet(rhs.getValue().getBitWidth(),
737 rhs.getValue().getZExtValue());
738 return ParamExprAttr::get(PEO::Mul, operands[0],
739 IntegerAttr::get(rhs.getType(), rhsCst));
740 }
741 return {};
742}
743
744static TypedAttr simplifyShrU(SmallVector<TypedAttr, 4> &operands) {
745 assert(isHWIntegerType(operands[0].getType()));
746 // Implement support for identities like `x >> 0`.
747 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
748 if (rhs.getValue().isZero())
749 return operands[0];
750
751 return foldBinaryOp(operands, [](auto a, auto b) { return a.lshr(b); });
752}
753
754static TypedAttr simplifyShrS(SmallVector<TypedAttr, 4> &operands) {
755 assert(isHWIntegerType(operands[0].getType()));
756 // Implement support for identities like `x >> 0`.
757 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
758 if (rhs.getValue().isZero())
759 return operands[0];
760
761 return foldBinaryOp(operands, [](auto a, auto b) { return a.ashr(b); });
762}
763
764static TypedAttr simplifyDivU(SmallVector<TypedAttr, 4> &operands) {
765 assert(isHWIntegerType(operands[0].getType()));
766 // Implement support for identities like `x/1`.
767 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
768 if (rhs.getValue().isOne())
769 return operands[0];
770
771 return foldBinaryOp(operands, [](auto a, auto b) { return a.udiv(b); });
772}
773
774static TypedAttr simplifyDivS(SmallVector<TypedAttr, 4> &operands) {
775 assert(isHWIntegerType(operands[0].getType()));
776 // Implement support for identities like `x/1`.
777 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
778 if (rhs.getValue().isOne())
779 return operands[0];
780
781 return foldBinaryOp(operands, [](auto a, auto b) { return a.sdiv(b); });
782}
783
784static TypedAttr simplifyModU(SmallVector<TypedAttr, 4> &operands) {
785 assert(isHWIntegerType(operands[0].getType()));
786 // Implement support for identities like `x%1`.
787 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
788 if (rhs.getValue().isOne())
789 return IntegerAttr::get(rhs.getType(), 0);
790
791 return foldBinaryOp(operands, [](auto a, auto b) { return a.urem(b); });
792}
793
794static TypedAttr simplifyModS(SmallVector<TypedAttr, 4> &operands) {
795 assert(isHWIntegerType(operands[0].getType()));
796 // Implement support for identities like `x%1`.
797 if (auto rhs = dyn_cast<IntegerAttr>(operands[1]))
798 if (rhs.getValue().isOne())
799 return IntegerAttr::get(rhs.getType(), 0);
800
801 return foldBinaryOp(operands, [](auto a, auto b) { return a.srem(b); });
802}
803
804static TypedAttr simplifyCLog2(SmallVector<TypedAttr, 4> &operands) {
805 assert(isHWIntegerType(operands[0].getType()));
806 return foldUnaryOp(operands, [](auto a) {
807 // Following the Verilog spec, clog2(0) is 0
808 return APInt(a.getBitWidth(), a == 0 ? 0 : a.ceilLogBase2());
809 });
810}
811
812static TypedAttr simplifyStrConcat(SmallVector<TypedAttr, 4> &operands) {
813 // Combine all adjacent strings.
814 SmallVector<TypedAttr> newOperands;
815 SmallVector<StringAttr> stringsToCombine;
816 auto combineAndPush = [&]() {
817 if (stringsToCombine.empty())
818 return;
819 // Concatenate buffered strings, push to ops.
820 SmallString<32> newString;
821 for (auto part : stringsToCombine)
822 newString.append(part.getValue());
823 newOperands.push_back(
824 StringAttr::get(stringsToCombine[0].getContext(), newString));
825 stringsToCombine.clear();
826 };
827
828 for (TypedAttr op : operands) {
829 if (auto strOp = dyn_cast<StringAttr>(op)) {
830 // Queue up adjacent strings.
831 stringsToCombine.push_back(strOp);
832 } else {
833 combineAndPush();
834 newOperands.push_back(op);
835 }
836 }
837 combineAndPush();
838
839 assert(!newOperands.empty());
840 if (newOperands.size() == 1)
841 return newOperands[0];
842 if (newOperands.size() < operands.size())
843 return ParamExprAttr::get(PEO::StrConcat, newOperands);
844 return {};
845}
846
847/// Build a parameter expression. This automatically canonicalizes and
848/// folds, so it may not necessarily return a ParamExprAttr.
849TypedAttr ParamExprAttr::get(PEO opcode, ArrayRef<TypedAttr> operandsIn) {
850 assert(!operandsIn.empty() && "Cannot have expr with no operands");
851 // All operands must have the same type, which is the type of the result.
852 auto type = operandsIn.front().getType();
853 assert(llvm::all_of(operandsIn.drop_front(),
854 [&](auto op) { return op.getType() == type; }));
855
856 SmallVector<TypedAttr, 4> operands(operandsIn.begin(), operandsIn.end());
857
858 // Verify and canonicalize parameter expressions.
859 TypedAttr result;
860 switch (opcode) {
861 case PEO::Add:
862 result = simplifyAdd(operands);
863 break;
864 case PEO::Mul:
865 result = simplifyMul(operands);
866 break;
867 case PEO::And:
868 result = simplifyAnd(operands);
869 break;
870 case PEO::Or:
871 result = simplifyOr(operands);
872 break;
873 case PEO::Xor:
874 result = simplifyXor(operands);
875 break;
876 case PEO::Shl:
877 result = simplifyShl(operands);
878 break;
879 case PEO::ShrU:
880 result = simplifyShrU(operands);
881 break;
882 case PEO::ShrS:
883 result = simplifyShrS(operands);
884 break;
885 case PEO::DivU:
886 result = simplifyDivU(operands);
887 break;
888 case PEO::DivS:
889 result = simplifyDivS(operands);
890 break;
891 case PEO::ModU:
892 result = simplifyModU(operands);
893 break;
894 case PEO::ModS:
895 result = simplifyModS(operands);
896 break;
897 case PEO::CLog2:
898 result = simplifyCLog2(operands);
899 break;
900 case PEO::StrConcat:
901 result = simplifyStrConcat(operands);
902 break;
903 }
904
905 // If we folded to an operand, return it.
906 if (result)
907 return result;
908
909 return Base::get(operands[0].getContext(), opcode, operands, type);
910}
911
912Attribute ParamExprAttr::parse(AsmParser &p, Type type) {
913 // We require an opcode suffix like `#hw.param.expr.add`, we don't allow
914 // parsing a plain `#hw.param.expr` on its own.
915 p.emitError(p.getNameLoc(), "#hw.param.expr should have opcode suffix");
916 return {};
917}
918
919/// Internal method used for .mlir file parsing when parsing the
920/// "#hw.param.expr.mul" form of the attribute.
921static Attribute parseParamExprWithOpcode(StringRef opcodeStr,
922 DialectAsmParser &p, Type type) {
923 SmallVector<TypedAttr> operands;
924 if (p.parseCommaSeparatedList(
925 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
926 operands.push_back({});
927 return p.parseAttribute(operands.back(), type);
928 }))
929 return {};
930
931 std::optional<PEO> opcode = symbolizePEO(opcodeStr);
932 if (!opcode.has_value()) {
933 p.emitError(p.getNameLoc(), "unknown parameter expr operator name");
934 return {};
935 }
936
937 return ParamExprAttr::get(*opcode, operands);
938}
939
940void ParamExprAttr::print(AsmPrinter &p) const {
941 p << "." << stringifyPEO(getOpcode()) << '<';
942 llvm::interleaveComma(getOperands(), p.getStream(),
943 [&](Attribute op) { p.printAttributeWithoutType(op); });
944 p << '>';
945}
946
947// Replaces any ParamDeclRefAttr within a parametric expression with its
948// corresponding value from the map of provided parameters.
949static FailureOr<Attribute>
951 const std::map<std::string, Attribute> &parameters,
952 Attribute paramAttr, bool emitErrors) {
953 if (dyn_cast<IntegerAttr>(paramAttr)) {
954 // Nothing to do, constant value.
955 return paramAttr;
956 }
957 if (auto paramRefAttr = dyn_cast<hw::ParamDeclRefAttr>(paramAttr)) {
958 // Get the value from the provided parameters.
959 auto it = parameters.find(paramRefAttr.getName().str());
960 if (it == parameters.end()) {
961 if (emitErrors)
962 return emitError(loc)
963 << "Could not find parameter " << paramRefAttr.getName().str()
964 << " in the provided parameters for the expression!";
965 return failure();
966 }
967 return it->second;
968 }
969 if (auto paramExprAttr = dyn_cast<hw::ParamExprAttr>(paramAttr)) {
970 // Recurse into all operands of the expression.
971 llvm::SmallVector<TypedAttr, 4> replacedOperands;
972 for (auto operand : paramExprAttr.getOperands()) {
973 auto res = replaceDeclRefInExpr(loc, parameters, operand, emitErrors);
974 if (failed(res))
975 return {failure()};
976 replacedOperands.push_back(cast<TypedAttr>(*res));
977 }
978 return {
979 hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)};
980 }
981 llvm_unreachable("Unhandled parametric attribute");
982 return {};
983}
984
985FailureOr<TypedAttr> hw::evaluateParametricAttr(Location loc,
986 ArrayAttr parameters,
987 Attribute paramAttr,
988 bool emitErrors) {
989 // Create a map of the provided parameters for faster lookup.
990 std::map<std::string, Attribute> parameterMap;
991 for (auto param : parameters) {
992 auto paramDecl = cast<ParamDeclAttr>(param);
993 parameterMap[paramDecl.getName().str()] = paramDecl.getValue();
994 }
995
996 // First, replace any ParamDeclRefAttr in the expression with its
997 // corresponding value in 'parameters'.
998 auto paramAttrRes =
999 replaceDeclRefInExpr(loc, parameterMap, paramAttr, emitErrors);
1000 if (failed(paramAttrRes))
1001 return {failure()};
1002 paramAttr = *paramAttrRes;
1003
1004 // Then, evaluate the parametric attribute.
1005 if (isa<IntegerAttr, hw::ParamDeclRefAttr>(paramAttr))
1006 return cast<TypedAttr>(paramAttr);
1007 if (auto paramExprAttr = dyn_cast<hw::ParamExprAttr>(paramAttr)) {
1008 // Since any ParamDeclRefAttr was replaced within the expression,
1009 // we re-evaluate the expression through the existing ParamExprAttr
1010 // canonicalizer.
1011 return ParamExprAttr::get(paramExprAttr.getOpcode(),
1012 paramExprAttr.getOperands());
1013 }
1014
1015 llvm_unreachable("Unhandled parametric attribute");
1016 return TypedAttr();
1017}
1018
1019template <typename TArray>
1020FailureOr<Type> evaluateParametricArrayType(Location loc, ArrayAttr parameters,
1021 TArray arrayType, bool emitErrors) {
1022 auto size = evaluateParametricAttr(loc, parameters, arrayType.getSizeAttr(),
1023 emitErrors);
1024 if (failed(size))
1025 return failure();
1027 loc, parameters, arrayType.getElementType(), emitErrors);
1028 if (failed(elementType))
1029 return failure();
1030
1031 // If the size was evaluated to a constant, use a 64-bit integer
1032 // attribute version of it
1033 if (auto intAttr = dyn_cast<IntegerAttr>(*size))
1034 return TArray::get(
1035 arrayType.getContext(), *elementType,
1036 IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64),
1037 intAttr.getValue().getSExtValue()));
1038
1039 // Otherwise parameter references are still involved
1040 return TArray::get(arrayType.getContext(), *elementType, *size);
1041}
1042
1043FailureOr<Type> hw::evaluateParametricType(Location loc, ArrayAttr parameters,
1044 Type type, bool emitErrors) {
1045 return llvm::TypeSwitch<Type, FailureOr<Type>>(type)
1046 .Case<hw::IntType>([&](hw::IntType t) -> FailureOr<Type> {
1047 auto evaluatedWidth =
1048 evaluateParametricAttr(loc, parameters, t.getWidth(), emitErrors);
1049 if (failed(evaluatedWidth))
1050 return {failure()};
1051
1052 // If the width was evaluated to a constant, return an `IntegerType`
1053 if (auto intAttr = dyn_cast<IntegerAttr>(*evaluatedWidth))
1054 return {IntegerType::get(type.getContext(),
1055 intAttr.getValue().getSExtValue())};
1056
1057 // Otherwise parameter references are still involved
1058 return hw::IntType::get(cast<TypedAttr>(*evaluatedWidth));
1059 })
1060 .Case<hw::ArrayType, hw::UnpackedArrayType>(
1061 [&](auto arrayType) -> FailureOr<Type> {
1062 return evaluateParametricArrayType(loc, parameters, arrayType,
1063 emitErrors);
1064 })
1065 .Default([&](auto) { return type; });
1066}
1067
1068// Returns true if any part of this parametric attribute contains a reference
1069// to a parameter declaration.
1070static bool isParamAttrWithParamRef(Attribute expr) {
1071 return llvm::TypeSwitch<Attribute, bool>(expr)
1072 .Case([](ParamExprAttr attr) {
1073 return llvm::any_of(attr.getOperands(), isParamAttrWithParamRef);
1074 })
1075 .Case([](ParamDeclRefAttr) { return true; })
1076 .Default([](auto) { return false; });
1077}
1078
1079bool hw::isParametricType(mlir::Type t) {
1080 return llvm::TypeSwitch<Type, bool>(t)
1081 .Case<hw::IntType>(
1082 [&](hw::IntType t) { return isParamAttrWithParamRef(t.getWidth()); })
1083 .Case<hw::ArrayType, hw::UnpackedArrayType>([&](auto arrayType) {
1084 return isParametricType(arrayType.getElementType()) ||
1085 isParamAttrWithParamRef(arrayType.getSizeAttr());
1086 })
1087 .Default([](auto) { return false; });
1088}
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static TypedAttr foldBinaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculate)
Given a binary function, if the two operands are known constant integers, use the specified fold func...
static TypedAttr simplifyDivS(SmallVector< TypedAttr, 4 > &operands)
static std::pair< TypedAttr, TypedAttr > decomposeAddend(TypedAttr operand)
Analyze an operand to an add.
static TypedAttr simplifyAssocOp(PEO opcode, SmallVector< TypedAttr, 4 > &operands, llvm::function_ref< APInt(const APInt &, const APInt &)> calculateFn, llvm::function_ref< bool(const APInt &)> identityConstantFn, llvm::function_ref< bool(const APInt &)> destructiveConstantFn={})
Given a fully associative variadic integer operation, constant fold any constant operands and move th...
static TypedAttr simplifyDivU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyCLog2(SmallVector< TypedAttr, 4 > &operands)
static bool isParamAttrWithParamRef(Attribute expr)
static TypedAttr simplifyModU(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyOr(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyShl(SmallVector< TypedAttr, 4 > &operands)
static void printSymbolName(AsmPrinter &p, StringAttr sym)
static TypedAttr simplifyShrS(SmallVector< TypedAttr, 4 > &operands)
static FailureOr< Attribute > replaceDeclRefInExpr(Location loc, const std::map< std::string, Attribute > &parameters, Attribute paramAttr, bool emitErrors)
static TypedAttr foldUnaryOp(ArrayRef< TypedAttr > operands, llvm::function_ref< APInt(const APInt &)> calculate)
Given a unary function, if the operand is a known constant integer, use the specified fold function t...
static TypedAttr simplifyMul(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr getOneOfType(Type type)
static TypedAttr simplifyAnd(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyModS(SmallVector< TypedAttr, 4 > &operands)
static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p, Type type)
Internal method used for .mlir file parsing when parsing the "#hw.param.expr.mul" form of the attribu...
static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs)
This implements a < comparison for two operands to an associative operation imposing an ordering upon...
static std::string canonicalizeFilename(const Twine &directory, const Twine &filename)
static TypedAttr simplifyShrU(SmallVector< TypedAttr, 4 > &operands)
FailureOr< Type > evaluateParametricArrayType(Location loc, ArrayAttr parameters, TArray arrayType, bool emitErrors)
static TypedAttr simplifyAdd(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyStrConcat(SmallVector< TypedAttr, 4 > &operands)
static TypedAttr simplifyXor(SmallVector< TypedAttr, 4 > &operands)
static ParamExprAttr dyn_castPE(PEO opcode, Attribute value)
If the specified attribute is a ParamExprAttr with the specified opcode, return it.
static unsigned getFieldID(BundleType type, unsigned index)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
mlir::FailureOr< mlir::TypedAttr > evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr, bool emitErrors=true)
Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set of provided parameter val...
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition HWTypes.cpp:60
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
bool isParametricType(mlir::Type t)
Returns true if any part of t is parametric.
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.