13#include "../PassDetails.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
28#define GEN_PASS_DEF_LOWERESITYPES
29#include "circt/Dialect/ESI/ESIPasses.h.inc"
// Pass that lowers ESI window types using LowerTypesConverter. Its base class
// is generated from the ESI pass TableGen definitions pulled in above through
// GEN_PASS_DEF_LOWERESITYPES.
// NOTE(review): this listing has source line numbers fused into the code
// (e.g. "38struct") and is missing lines such as the closing "};"; restore it
// from the upstream source before building.
38struct ESILowerTypesPass
39 :
public circt::esi::impl::LowerESITypesBase<ESILowerTypesPass> {
// Pass entry point; defined out of line as ESILowerTypesPass::runOnOperation.
40 void runOnOperation()
override;
// Type converter that replaces ESI window types with their lowered form.
// It also rebuilds hw arrays, structs, unions, type aliases, and
// esi::ListType whose element or field types change, so windows nested
// inside them get lowered as well.
// NOTE(review): several original lines are missing from this listing. These
// include the lines computing `element`/`lowered` (presumably recursive
// convertType calls), the "unchanged" early returns, and the ListType lambda
// header. Restore them from upstream before building.
46class LowerTypesConverter :
public TypeConverter {
48 LowerTypesConverter() {
// MLIR tries registered conversions most-recent-first. This catch-all
// identity conversion is registered first, so it acts as the fallback for
// every type that no more specific conversion below handles.
49 addConversion([](Type t) {
return t; });
// A window converts directly to the type reported by getLoweredType().
50 addConversion([](WindowType window) {
return window.getLoweredType(); });
// Arrays: `element` is presumably the converted element type (computed on a
// line not shown). If it equals the original element type, the array is
// presumably returned unchanged (return line not shown). Otherwise a new
// array of the same length is built around the new element type.
51 addConversion([&](hw::ArrayType array) -> Type {
55 if (element == array.getElementType())
57 return hw::ArrayType::get(element, array.getNumElements());
// Structs: every field keeps its name and takes its lowered type
// (`lowered`, computed on a line not shown). `changed` records whether any
// field type differed, and the struct is rebuilt from the collected fields.
// An unchanged-struct early return is presumably on a missing line.
59 addConversion([&](hw::StructType structType) -> Type {
60 SmallVector<hw::StructType::FieldInfo> fields;
61 fields.reserve(structType.getElements().size());
63 for (
auto field : structType.getElements()) {
67 changed |= lowered != field.type;
68 fields.push_back({field.name, lowered});
72 return hw::StructType::get(structType.getContext(), fields);
// Unions: same scheme as structs. Each field also keeps its original
// offset.
74 addConversion([&](hw::UnionType unionType) -> Type {
75 SmallVector<hw::UnionType::FieldInfo> fields;
76 fields.reserve(unionType.getElements().size());
78 for (
auto field : unionType.getElements()) {
82 changed |= lowered != field.type;
83 fields.push_back({field.name, lowered, field.offset});
87 return hw::UnionType::get(unionType.getContext(), fields);
// esi::ListType: rebuilt around `element`, presumably the converted element
// type. The enclosing lambda header is not visible in this listing.
95 return esi::ListType::get(listType.getContext(), element);
// Type aliases: if the lowered inner type equals the original, the alias is
// presumably returned unchanged (not shown). Otherwise a new alias with the
// same reference wraps the lowered inner type.
97 addConversion([&](hw::TypeAliasType alias) -> Type {
101 if (lowered == alias.getInnerType())
103 return hw::TypeAliasType::get(alias.getRef(), lowered);
// Materializations bridge converted and unconverted values.
// - Source materialization wraps a lowered value back into a window
//   (WrapWindow).
// - Target materialization unwraps a window into a lowered value
//   (UnwrapWindow).
105 addSourceMaterialization(wrapMaterialization);
106 addTargetMaterialization(unwrapMaterialization);
// Source materialization: rebuilds a value of window type `resultType` from
// a single lowered value, using WrapWindow. createOrFold may fold the new op
// away. For any other number of inputs it returns a null Value, which
// signals to the conversion framework that this materialization failed.
// NOTE(review): the closing brace of this function is not part of this
// listing.
110 static mlir::Value wrapMaterialization(OpBuilder &b, WindowType resultType,
111 ValueRange inputs, Location loc) {
// Only a 1:1 value mapping can be wrapped into a window.
112 if (inputs.size() != 1)
113 return mlir::Value();
114 return b.createOrFold<WrapWindow>(loc, resultType, inputs[0]);
// Target materialization: extracts a value of `resultType` from a single
// window-typed value, using UnwrapWindow. createOrFold may fold the new op
// away. It returns a null Value (materialization failure) in two cases:
// there is not exactly one input, or the input is not a WindowType.
// NOTE(review): the closing brace of this function is not part of this
// listing.
117 static mlir::Value unwrapMaterialization(OpBuilder &b, Type resultType,
118 ValueRange inputs, Location loc) {
// Only a single window-typed value can be unwrapped.
119 if (inputs.size() != 1 || !isa<WindowType>(inputs[0].getType()))
120 return mlir::Value();
121 return b.createOrFold<UnwrapWindow>(loc, resultType, inputs[0]);
// Body of the window-detection predicate. Its signature is not visible
// here; presumably `static bool containsWindowType(Type type)`. It answers
// whether `type` is an ESI WindowType or, for the aggregate cases,
// presumably contains one. Types not matched below are reported as
// window-free.
// NOTE(review): LowerTypesConverter also lowers esi::ListType, but no
// ListType case is visible here. Confirm this predicate handles lists too;
// the case may sit on lines missing from this listing.
127 return TypeSwitch<Type, bool>(type)
// A window itself is always a match.
128 .Case([](WindowType) {
return true; })
// Arrays: body not visible; presumably checks the element type.
129 .Case<hw::ArrayType>([](hw::ArrayType array) {
// Structs: walks every field. The loop body is not visible; presumably it
// reports true once any field type contains a window.
132 .Case<hw::StructType>([](hw::StructType structType) {
133 for (
auto field : structType.getElements())
// Unions: same field walk as structs.
138 .Case<hw::UnionType>([](hw::UnionType unionType) {
139 for (
auto field : unionType.getElements())
// Type aliases: body not visible; presumably checks the aliased inner type.
147 .Case<hw::TypeAliasType>([](hw::TypeAliasType aliasType) {
// Any other type is treated as containing no window.
150 .Default([](Type) {
return false; });
// Lowers ESI window types in two phases.
//  1. A partial dialect conversion driven by LowerTypesConverter. Ops that
//     still mention window types are illegal; WrapWindow/UnwrapWindow stay
//     legal so the converter's materializations can insert them where
//     converted and unconverted values meet.
//  2. A second partial conversion that marks WrapWindow/UnwrapWindow illegal
//     and uses only their canonicalization patterns, with folding attempted
//     before patterns. Partial conversion fails if an op explicitly marked
//     illegal survives, so any cast that cannot be folded or canonicalized
//     away makes this phase fail.
// NOTE(review): pattern population, the `isWindowPort`/`hasWindow` helper
// definitions, and both failure handlers (presumably signalPassFailure) are
// on lines missing from this listing.
153void ESILowerTypesPass::runOnOperation() {
154 ConversionTarget target(getContext());
// Phase 1 keeps the window casts legal so materializations may create them.
156 target.addLegalOp<WrapWindow, UnwrapWindow>();
// Every other op is legal exactly when it no longer refers to window types.
159 target.markUnknownOpDynamicallyLegal([](Operation *op) {
160 return TypeSwitch<Operation *, bool>(op)
// Instances are legal iff no operand or result type contains a window.
161 .Case([](igraph::InstanceOpInterface inst) {
163 return !(llvm::any_of(inst->getOperandTypes(), hasWindow) ||
164 llvm::any_of(inst->getResultTypes(), hasWindow));
// Mutable HW modules are legal iff none of their ports is a window port.
166 .Case([](hw::HWMutableModuleLike mod) {
170 return !(llvm::any_of(mod.getPortList(), isWindowPort));
// Fallback legality rule for all other ops; its body is not visible here.
172 .Default([](Operation *op) {
// Phase 1 converter. It is presumably handed to the patterns populated on
// the lines not shown below.
180 LowerTypesConverter types;
181 RewritePatternSet
patterns(&getContext());
// The enclosing `if (failed(...))` and its failure handling are not visible.
184 applyPartialConversion(getOperation(), target, std::move(
patterns))))
// Phase 2: eliminate the WrapWindow/UnwrapWindow casts left over from
// phase 1 by folding and by the two ops' own canonicalization patterns.
189 mlir::ConversionConfig config;
190 config.foldingMode = mlir::DialectConversionFoldingMode::BeforePatterns;
191 ConversionTarget partialCanonicalizedTarget(getContext());
192 RewritePatternSet partialPatterns(&getContext());
193 partialCanonicalizedTarget.addIllegalOp<WrapWindow, UnwrapWindow>();
194 WrapWindow::getCanonicalizationPatterns(partialPatterns, &getContext());
195 UnwrapWindow::getCanonicalizationPatterns(partialPatterns, &getContext());
// Failure here means some cast could not be removed; the handling is on
// lines not visible in this listing.
196 if (failed(mlir::applyPartialConversion(getOperation(),
197 partialCanonicalizedTarget,
198 std::move(partialPatterns), config)))
// Factory for the ESI type-lowering pass: returns a newly allocated
// ESILowerTypesPass, owned by the returned unique_ptr.
// NOTE(review): the function-name line (presumably
// createESITypeLoweringPass()) and the closing brace are missing from this
// listing.
202std::unique_ptr<OperationPass<ModuleOp>>
204 return std::make_unique<ESILowerTypesPass>();
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
static bool containsWindowType(Type type)
Lists represent variable-length sequences of elements of a single type.
const Type * getElementType() const
std::unique_ptr< OperationPass< ModuleOp > > createESITypeLoweringPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Generic pattern which replaces an operation by one of the same operation name, but with converted attributes, operands, and results.
This holds the name, type, and direction of a module's ports.