12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
16 using namespace circt;
28 return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
33 llvm::SmallVector<Type> inner;
34 t.getInnerTypes(inner);
35 for (
auto [index,
innerType] : llvm::enumerate(inner))
44 OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
45 DenseSet<Operation *> *opVisited)
49 matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
50 ConversionPatternRewriter &rewriter)
const override {
51 llvm::SmallVector<Value> convOperands;
54 for (
auto operand : adaptor.getOperands()) {
56 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
58 llvm::copy(explodedStruct.getResults(),
59 std::back_inserter(convOperands));
61 convOperands.push_back(operand);
66 rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
67 opVisited->insert(op->getParentOp());
70 DenseSet<Operation *> *opVisited;
74 InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
75 DenseSet<hw::InstanceOp> *convertedOps)
77 convertedOps(convertedOps) {}
80 matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
81 ConversionPatternRewriter &rewriter)
const override {
82 auto loc = op.getLoc();
84 llvm::SmallVector<Value> convOperands;
85 for (
auto operand : adaptor.getOperands()) {
87 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
89 llvm::copy(explodedStruct.getResults(),
90 std::back_inserter(convOperands));
92 convOperands.push_back(operand);
97 auto newInstance = rewriter.create<hw::InstanceOp>(
98 loc, op.getReferencedModuleSlow(), op.getInstanceName(), convOperands);
101 llvm::SmallVector<Value> convResults;
102 size_t oldResultCntr = 0;
103 for (
size_t resIndex = 0; resIndex < newInstance.getNumResults();
105 Type oldResultType = op.getResultTypes()[oldResultCntr];
107 size_t nElements = structType.getElements().size();
110 newInstance.getResults().slice(resIndex, nElements));
111 convResults.push_back(implodedStruct.getResult());
112 resIndex += nElements - 1;
114 convResults.push_back(newInstance.getResult(resIndex));
118 rewriter.replaceOp(op, convResults);
119 convertedOps->insert(newInstance);
123 DenseSet<hw::InstanceOp> *convertedOps;
126 using IOTypes = std::pair<TypeRange, TypeRange>;
130 DenseMap<unsigned, hw::StructType> argStructs, resStructs;
133 SmallVector<Type> argTypes, resTypes;
136 class FlattenIOTypeConverter :
public TypeConverter {
138 FlattenIOTypeConverter() {
139 addConversion([](Type type, SmallVectorImpl<Type> &results) {
142 results.push_back(type);
144 for (
auto field : structType.getElements())
145 results.push_back(field.type);
150 addTargetMaterialization([](OpBuilder &
builder, hw::StructType type,
151 ValueRange
inputs, Location loc) {
153 return result.getResult();
156 addTargetMaterialization([](OpBuilder &
builder, hw::TypeAliasType type,
157 ValueRange
inputs, Location loc) {
159 assert(structType &&
"expected struct type");
161 return result.getResult();
168 template <
typename... TOp>
170 ConversionTarget &target,
172 FlattenIOTypeConverter &typeConverter) {
187 target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
192 auto ioInfoIt = ioMap.find(moduleLikeOp);
193 if (ioInfoIt == ioMap.end()) {
198 auto ioInfo = ioInfoIt->second;
200 auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
201 return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](
auto typePair) {
202 auto oldType = std::get<0>(typePair);
203 auto newType = std::get<1>(typePair);
204 return oldType != newType;
207 auto mtype = moduleLikeOp.getHWModuleType();
208 if (
compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
218 template <
typename T>
220 return llvm::any_of(module.getBody()->getOps<T>(),
221 [](T op) { return !isLegalModLikeOp(op); });
224 template <
typename T>
225 static DenseMap<Operation *, IOTypes>
populateIOMap(mlir::ModuleOp module) {
226 DenseMap<Operation *, IOTypes> ioMap;
227 for (
auto op : module.getOps<T>())
228 ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
232 template <
typename ModTy,
typename T>
233 static llvm::SmallVector<Attribute>
235 DenseMap<unsigned, hw::StructType> &structMap, T oldNames) {
236 llvm::SmallVector<Attribute> newNames;
237 for (
auto [i, oldName] : llvm::enumerate(oldNames)) {
239 auto it = structMap.find(i);
240 if (it == structMap.end()) {
248 auto structType = it->second;
249 for (
auto field : structType.getElements())
256 static llvm::SmallVector<Attribute>
259 llvm::SmallVector<Attribute> newLocs;
262 for (
auto [i, oldLoc] : llvm::enumerate(oldLocs.getAsRange<Location>())) {
264 auto it = structMap.find(i);
265 if (it == structMap.end()) {
267 newLocs.push_back(oldLoc);
271 auto structType = it->second;
272 for (
size_t i = 0, e = structType.getElements().size(); i < e; ++i)
273 newLocs.push_back(oldLoc);
283 DenseMap<unsigned, hw::StructType> &structMap) {
284 auto locs = op.getInputLocs();
285 if (locs.empty() || op.getModuleBody().empty())
287 for (
auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
291 template <
typename T>
293 DenseMap<Operation *, IOInfo> ioInfoMap;
294 for (
auto op : module.getOps<T>()) {
296 ioInfo.argTypes = op.getInputTypes();
297 ioInfo.resTypes = op.getOutputTypes();
298 for (
auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
300 ioInfo.argStructs[i] = structType;
302 for (
auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
304 ioInfo.resStructs[i] = structType;
306 ioInfoMap[op] = ioInfo;
311 template <
typename T>
313 auto *ctx = module.getContext();
314 FlattenIOTypeConverter typeConverter;
319 while (hasUnconvertedOps<T>(module)) {
320 ConversionTarget target(*ctx);
322 target.addLegalDialect<hw::HWDialect>();
326 auto ioInfoMap = populateIOInfoMap<T>(module);
331 llvm::DenseSet<hw::InstanceOp> convertedInstances;
336 DenseSet<Operation *> opVisited;
337 patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
339 patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances);
340 target.addDynamicallyLegalOp<hw::OutputOp>(
341 [&](
auto op) {
return opVisited.contains(op->getParentOp()); });
342 target.addDynamicallyLegalOp<hw::InstanceOp>([&](
auto op) {
343 return llvm::none_of(op->getOperands(), [](
auto operand) {
344 return isStructType(operand.getType());
348 DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames, oldArgLocs,
350 for (
auto op : module.getOps<T>()) {
351 oldArgNames[op] =
ArrayAttr::get(module.getContext(), op.getInputNames());
354 oldArgLocs[op] = op.getInputLocsAttr();
355 oldResLocs[op] = op.getOutputLocsAttr();
359 addSignatureConversion<T>(ioInfoMap, target,
patterns, typeConverter);
361 if (failed(applyPartialConversion(module, target, std::move(
patterns))))
365 for (
auto op : module.getOps<T>()) {
366 auto ioInfo = ioInfoMap[op];
368 op,
"argNames", ioInfo.argStructs,
369 oldArgNames[op].template getAsValueRange<StringAttr>());
371 op,
"resultNames", ioInfo.resStructs,
372 oldResNames[op].template getAsValueRange<StringAttr>());
373 newArgNames.append(newResNames.begin(), newResNames.end());
374 op.setAllPortNames(newArgNames);
377 newArgLocs.append(newResLocs.begin(), newResLocs.end());
383 for (
auto instanceOp : convertedInstances) {
384 Operation *targetModule = instanceOp.getReferencedModuleSlow();
385 auto ioInfo = ioInfoMap[targetModule];
387 instanceOp.getContext(),
389 oldArgNames[targetModule]
390 .template getAsValueRange<StringAttr>())));
392 instanceOp.getContext(),
394 oldResNames[targetModule]
395 .template getAsValueRange<StringAttr>())));
410 template <
typename... TOps>
411 static bool flattenIO(ModuleOp module,
bool recursive) {
412 return (failed(flattenOpsOfType<TOps>(module, recursive)) || ...);
417 class FlattenIOPass :
public circt::hw::FlattenIOBase<FlattenIOPass> {
419 void runOnOperation()
override {
420 ModuleOp module = getOperation();
422 hw::HWModuleGeneratedOp>(module, recursive))
434 return std::make_unique<FlattenIOPass>();
assert(baseType &&"element must be base type")
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
static bool hasUnconvertedOps(mlir::ModuleOp module)
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive)
static llvm::SmallVector< Attribute > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, ArrayAttr oldLocs)
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames)
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
static bool flattenIO(ModuleOp module, bool recursive)
static bool isStructType(Type type)
static hw::StructType getStructType(Type type)
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::Type innerType(mlir::Type type)
std::unique_ptr< mlir::Pass > createFlattenIOPass()
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...