11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Transforms/DialectConversion.h"
13 #include "llvm/ADT/TypeSwitch.h"
17 #define GEN_PASS_DEF_FLATTENIO
18 #include "circt/Dialect/HW/Passes.h.inc"
23 using namespace circt;
// Yields true when none of the module's port types satisfies the predicate
// handed to llvm::none_of. NOTE(review): the predicate argument is on a line
// that is not part of this excerpt.
35 return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
// Gather the inner types reported by `t` into a local vector, then walk them
// together with their position. NOTE(review): the loop body is not part of
// this excerpt.
40 llvm::SmallVector<Type> inner;
41 t.getInnerTypes(inner);
42 for (
auto [index,
innerType] : llvm::enumerate(inner))
// Conversion pattern for hw::OutputOp.
//
// The constructor receives a type converter, a context and a pointer to a set
// of operations (`opVisited`). matchAndRewrite builds a new operand list,
// records the parent operation of the rewritten output op in `opVisited`, and
// replaces the op with a new hw::OutputOp over the new operands.
51 OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
52 DenseSet<Operation *> *opVisited)
56 matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
57 ConversionPatternRewriter &rewriter)
const override {
// Operand list for the replacement op.
58 llvm::SmallVector<Value> convOperands;
// Each adapted operand takes one of two paths. NOTE(review): the test that
// selects the path is on lines not visible here; presumably it checks whether
// the operand is struct-typed.
61 for (
auto operand : adaptor.getOperands()) {
// Path 1: create a hw::StructExplodeOp and append all of its results.
63 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
65 llvm::copy(explodedStruct.getResults(),
66 std::back_inserter(convOperands));
// Path 2: forward the operand unchanged.
68 convOperands.push_back(operand);
// Remember the enclosing operation before replacing the output op.
73 opVisited->insert(op->getParentOp());
74 rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
// Caller-provided set that matchAndRewrite inserts parent operations into.
77 DenseSet<Operation *> *opVisited;
// Conversion pattern for hw::InstanceOp.
//
// The constructor receives a type converter, a context, a pointer to a set of
// instance ops (`convertedOps`) and a pointer to a set of module names
// (`externModules`). matchAndRewrite creates a replacement instance whose
// operand and result lists are built below, replaces the original op with
// values derived from the new instance's results, and inserts the new
// instance into `convertedOps`.
81 InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
82 DenseSet<hw::InstanceOp> *convertedOps,
83 const StringSet<> *externModules)
85 externModules(externModules) {}
88 matchAndRewrite(hw::InstanceOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter)
const override {
90 auto referencedMod = op.getReferencedModuleNameAttr();
// Instances whose referenced module name is in `externModules` are treated
// specially. NOTE(review): the statement guarded by this test is not visible
// in this excerpt.
93 if (externModules->contains(referencedMod.getValue()))
96 auto loc = op.getLoc();
// Operand list for the replacement instance. Each adapted operand is either
// exploded via hw::StructExplodeOp (all results appended) or forwarded as-is.
// NOTE(review): the selecting condition is on lines not visible here.
98 llvm::SmallVector<Value> convOperands;
99 for (
auto operand : adaptor.getOperands()) {
101 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
103 llvm::copy(explodedStruct.getResults(),
104 std::back_inserter(convOperands));
106 convOperands.push_back(operand);
// Result types for the replacement instance: either every element type of a
// struct (`structType`, defined on a line not visible here) or the original
// result type unchanged.
111 llvm::SmallVector<Type> newResultTypes;
112 for (
auto oldResultType : op.getResultTypes()) {
114 for (
auto t : structType.getElements())
115 newResultTypes.push_back(t.type);
117 newResultTypes.push_back(oldResultType);
// The instance name, argument names, result names, parameters and inner
// symbol are taken from the original op.
122 auto newInstance = rewriter.create<hw::InstanceOp>(
123 loc, newResultTypes, op.getInstanceNameAttr(),
125 op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
126 op.getInnerSymAttr());
// Replacement values for the original op's results. `resIndex` walks the new
// instance's results, while `oldResultCntr` indexes the original result
// types. NOTE(review): the loop increment and the update of `oldResultCntr`
// are on lines not visible here.
129 llvm::SmallVector<Value> convResults;
130 size_t oldResultCntr = 0;
131 for (
size_t resIndex = 0; resIndex < newInstance.getNumResults();
133 Type oldResultType = op.getResultTypes()[oldResultCntr];
135 size_t nElements = structType.getElements().size();
138 newInstance.getResults().slice(resIndex, nElements));
// Struct case: one value built from a slice of `nElements` new results.
139 convResults.push_back(implodedStruct.getResult());
// Advance past all but one of the consumed results; presumably the loop
// increment accounts for the last one — TODO confirm.
140 resIndex += nElements - 1;
// Non-struct case: the new result is used directly.
142 convResults.push_back(newInstance.getResult(resIndex));
146 rewriter.replaceOp(op, convResults);
147 convertedOps->insert(newInstance);
// Caller-provided set that receives every instance created by matchAndRewrite.
151 DenseSet<hw::InstanceOp> *convertedOps;
// Module names consulted at the top of matchAndRewrite.
152 const StringSet<> *externModules;
// Pair of type ranges; populateIOMap stores {argument types, result types}.
155 using IOTypes = std::pair<TypeRange, TypeRange>;
// Struct types of arguments / results, keyed by an unsigned index. setIOInfo
// fills these using the position within argTypes / resTypes as the key.
159 DenseMap<unsigned, hw::StructType> argStructs, resStructs;
// Input and output types of the module, as assigned by setIOInfo.
162 SmallVector<Type> argTypes, resTypes;
// Type converter used by the pass. Its constructor registers one type
// conversion and two target materializations.
165 class FlattenIOTypeConverter :
public TypeConverter {
167 FlattenIOTypeConverter() {
// 1:N conversion: a type is either kept as-is or replaced by the types of a
// struct's fields. NOTE(review): the lines that derive `structType` and pick
// the branch are not visible in this excerpt.
168 addConversion([](Type type, SmallVectorImpl<Type> &results) {
171 results.push_back(type);
173 for (
auto field : structType.getElements())
175 results.push_back(field.type);
// Materialization targeting hw::StructType: returns the result of an op built
// from `inputs` (the op creation is on a line not visible here).
180 addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
181 ValueRange inputs, Location loc) {
183 return result.getResult();
// Materialization targeting hw::TypeAliasType: asserts that a struct type was
// obtained before returning the built result.
186 addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
187 ValueRange inputs, Location loc) {
189 assert(structType &&
"expected struct type");
191 return result.getResult();
// Registers, for every op type in the pack TOp, a dynamic-legality callback on
// `target` that inspects a module-like op against the entry recorded for it in
// `ioMap`. NOTE(review): parts of the signature and any pattern registration
// are on lines not visible in this excerpt.
198 template <
typename... TOp>
200 ConversionTarget &target,
202 FlattenIOTypeConverter &typeConverter) {
217 target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
// Look up the recorded IO information for this module; the handling of a
// missing entry is not visible here.
222 auto ioInfoIt = ioMap.find(moduleLikeOp);
223 if (ioInfoIt == ioMap.end()) {
228 auto ioInfo = ioInfoIt->second;
// True when any pair of corresponding types differs. llvm::zip pairs elements
// only up to the length of the shorter range.
230 auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
231 return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](
auto typePair) {
232 auto oldType = std::get<0>(typePair);
233 auto newType = std::get<1>(typePair);
234 return oldType != newType;
// Compare the module's current output types with the recorded result types;
// the remainder of the condition is not visible here.
237 auto mtype = moduleLikeOp.getHWModuleType();
238 if (
compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
// Reports whether the module body contains at least one op of type T for
// which isLegalModLikeOp returns false.
248 template <
typename T>
250 return llvm::any_of(module.getBody()->getOps<T>(),
251 [](T op) { return !isLegalModLikeOp(op); });
// Builds a map from each op of type T found in `module` to the pair of its
// argument types and result types.
254 template <
typename T>
255 static DenseMap<Operation *, IOTypes>
populateIOMap(mlir::ModuleOp module) {
256 DenseMap<Operation *, IOTypes> ioMap;
257 for (
auto op : module.getOps<T>())
258 ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
// Produces a list of name attributes from `oldNames`. Positions that have an
// entry in `structMap` contribute one name per struct field, formed as
// <old name><joinChar><field name>. NOTE(review): the handling of positions
// without an entry is on lines not visible here; presumably the old name is
// kept.
262 template <
typename ModTy,
typename T>
263 static llvm::SmallVector<Attribute>
265 DenseMap<unsigned, hw::StructType> &structMap, T oldNames,
267 llvm::SmallVector<Attribute> newNames;
268 for (
auto [i, oldName] : llvm::enumerate(oldNames)) {
// Is there a struct type recorded for position i?
270 auto it = structMap.find(i);
271 if (it == structMap.end()) {
// Struct case: expand into one name per field.
279 auto structType = it->second;
280 for (
auto field : structType.getElements())
282 op->getContext(), oldName + Twine(joinChar) + field.name.str()));
// Recomputes all port names of `op` from the ports of `oldModType` and applies
// them with setAllPortNames. A port either contributes one name per struct
// field (<old name><joinChar><field name>) or keeps its old name.
// NOTE(review): the test selecting between the two is not visible here.
287 template <
typename ModTy>
292 SmallVector<Attribute> newNames;
// Work on a copy of the old port list.
293 SmallVector<hw::ModulePort> oldPorts(oldModType.getPorts().begin(),
294 oldModType.getPorts().end());
295 for (
auto oldPort : oldPorts) {
296 auto oldName = oldPort.name;
298 for (
auto field : structType.getElements()) {
301 oldName.getValue() + Twine(joinChar) + field.name.str()));
304 newNames.push_back(oldName);
306 op.setAllPortNames(newNames);
// Expands a list of locations to match flattened ports: a position without an
// entry in `structMap` keeps its single location, while a position with a
// struct type repeats its location once per struct element.
309 static llvm::SmallVector<Location>
311 SmallVectorImpl<Location> &oldLocs) {
312 llvm::SmallVector<Location> newLocs;
313 for (
auto [i, oldLoc] : llvm::enumerate(oldLocs)) {
315 auto it = structMap.find(i);
316 if (it == structMap.end()) {
// Not a struct: carry the location over unchanged.
318 newLocs.push_back(oldLoc);
// Struct: every element gets the location of the original port.
322 auto structType = it->second;
323 for (
size_t i = 0, e = structType.getElements().size(); i < e; ++i)
324 newLocs.push_back(oldLoc);
334 DenseMap<unsigned, hw::StructType> &structMap) {
// Fetch the input locations of `op`. The guard below triggers when there are
// none or when the module body is empty; NOTE(review): the guarded statement
// is not visible here (presumably an early return).
335 auto locs = op.getInputLocs();
336 if (locs.empty() || op.getModuleBody().empty())
// Walk the body block's arguments paired with the locations. The loop body is
// not visible here; presumably it assigns each location to its argument.
338 for (
auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
// Fills `ioInfo` from `op`: stores its input and output types, and records a
// struct type under the index of an argument / result. NOTE(review): the lines
// that obtain `structType` and decide whether to record it are not visible
// here.
342 static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
343 ioInfo.argTypes = op.getInputTypes();
344 ioInfo.resTypes = op.getOutputTypes();
// Arguments: key is the position within argTypes.
345 for (
auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
347 ioInfo.argStructs[i] = structType;
// Results: key is the position within resTypes.
349 for (
auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
351 ioInfo.resStructs[i] = structType;
// Builds a map from each op of type T in `module` to an IOInfo value.
// NOTE(review): the lines that create and fill `ioInfo` are not visible here;
// presumably setIOInfo is used.
355 template <
typename T>
357 DenseMap<Operation *, IOInfo> ioInfoMap;
358 for (
auto op : module.getOps<T>()) {
361 ioInfoMap[op] = ioInfo;
// Flattens the IO of all ops of type T in `module`. While hasUnconvertedOps
// reports remaining work, each round sets up a conversion target and
// patterns, snapshots the current port names / locations / module types, runs
// a partial conversion, and then updates port locations and the name
// attributes of converted instances. NOTE(review): parts of the signature
// (including `recursive` and `joinChar`) and several statements are on lines
// not visible in this excerpt.
366 template <
typename T>
368 StringSet<> &externModules,
370 auto *ctx = module.getContext();
371 FlattenIOTypeConverter typeConverter;
// One iteration per conversion round.
376 while (hasUnconvertedOps<T>(module)) {
377 ConversionTarget target(*ctx);
// Everything in the HW dialect is legal unless overridden dynamically below.
379 target.addLegalDialect<hw::HWDialect>();
// Record the IO information of every op of type T before converting.
383 auto ioInfoMap = populateIOInfoMap<T>(module);
// Filled by InstanceOpConversion with the instances it creates.
388 llvm::DenseSet<hw::InstanceOp> convertedInstances;
// Filled by OutputOpConversion with the parents of rewritten output ops; an
// hw::OutputOp is legal once its parent is in this set.
393 DenseSet<Operation *> opVisited;
394 patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
396 patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
398 target.addDynamicallyLegalOp<hw::OutputOp>(
399 [&](
auto op) {
return opVisited.contains(op->getParentOp()); });
// An instance is legal when it references a module named in `externModules`
// or when none of its operands is struct-typed.
400 target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
401 auto refName = op.getReferencedModuleName();
402 return externModules.contains(refName) ||
403 llvm::none_of(op->getOperands(), [](
auto operand) {
404 return isStructType(operand.getType());
// Snapshots taken before the conversion, keyed by operation.
408 DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames;
409 DenseMap<Operation *, SmallVector<Location>> oldArgLocs, oldResLocs;
410 DenseMap<Operation *, hw::ModuleType> oldModTypes;
412 for (
auto op : module.getOps<T>()) {
413 oldModTypes[op] = op.getHWModuleType();
414 oldArgNames[op] =
ArrayAttr::get(module.getContext(), op.getInputNames());
417 oldArgLocs[op] = op.getInputLocs();
418 oldResLocs[op] = op.getOutputLocs();
// Register the signature legality callback for T, then run the conversion.
// NOTE(review): the failure handling guarded by the `if` is not visible here.
422 addSignatureConversion<T>(ioInfoMap, target,
patterns, typeConverter);
424 if (failed(applyPartialConversion(module, target, std::move(
patterns))))
428 for (
auto op : module.getOps<T>()) {
// Post-conversion fix-up per module: argument locations followed by result
// locations become the module's port locations. (`newArgLocs` and
// `newResLocs` are computed on lines not visible here.)
429 auto ioInfo = ioInfoMap[op];
433 newArgLocs.append(newResLocs.begin(), newResLocs.end());
434 op.setAllPortLocs(newArgLocs);
// Post-conversion fix-up per converted instance.
439 for (
auto instanceOp : convertedInstances) {
441 cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
442 instanceOp, instanceOp.getReferencedModuleNameAttr()));
// The referenced module may not have been recorded yet; in that case store
// its IO information, names and locations now.
445 if (!ioInfoMap.contains(targetModule)) {
448 ioInfoMap[targetModule] = ioInfo;
449 oldArgNames[targetModule] =
450 ArrayAttr::get(module.getContext(), targetModule.getInputNames());
451 oldResNames[targetModule] =
452 ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
453 oldArgLocs[targetModule] = targetModule.getInputLocs();
454 oldResLocs[targetModule] = targetModule.getOutputLocs();
// The arguments below, passed with the attribute names "argNames" and
// "resultNames", use the target module's recorded struct maps and old names.
// NOTE(review): the enclosing calls are on lines not visible here.
456 ioInfo = ioInfoMap[targetModule];
459 instanceOp.getContext(),
461 instanceOp,
"argNames", ioInfo.argStructs,
462 oldArgNames[targetModule].template getAsValueRange<StringAttr>(),
465 instanceOp.getContext(),
467 instanceOp,
"resultNames", ioInfo.resStructs,
468 oldResNames[targetModule].template getAsValueRange<StringAttr>(),
// Invokes flattenOpsOfType for the op types in the pack TOps and bases its
// result on failed(...) of those calls. NOTE(review): the rest of the return
// expression is not visible in this excerpt.
483 template <
typename... TOps>
485 StringSet<> &externModules,
char joinChar) {
486 return (failed(flattenOpsOfType<TOps>(module, recursive, externModules,
// Pass class derived from the generated FlattenIOBase. It keeps a set of
// module names (`externModules`) that is passed on to flattenIO.
493 class FlattenIOPass :
public circt::hw::impl::FlattenIOBase<FlattenIOPass> {
// Copies the constructor flags into the pass's `recursive` and
// `flattenExtern` members. NOTE(review): the use of `join` is on a line not
// visible here.
495 FlattenIOPass(
bool recursiveFlag,
bool flattenExternFlag,
char join) {
496 recursive = recursiveFlag;
497 flattenExtern = flattenExternFlag;
501 void runOnOperation()
override {
502 ModuleOp module = getOperation();
// When `flattenExtern` is false, module names are added to `externModules`.
// NOTE(review): the loop providing `m` is on lines not visible here.
503 if (!flattenExtern) {
506 externModules.insert(m.getModuleName());
// Run the flattening for hw::HWModuleOp and hw::HWModuleGeneratedOp. The
// handling of a true result, and the start of the second call at listing
// line 514, are not visible here.
507 if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(
508 module, recursive, externModules, joinChar))
514 hw::HWModuleGeneratedOp>(module, recursive, externModules,
// Module names collected in runOnOperation.
520 StringSet<> externModules;
529 bool flattenExternFlag,
// Constructs a FlattenIOPass from the given flags and returns it in a
// unique_ptr. NOTE(review): the remaining arguments are on a line not visible
// in this excerpt.
531 return std::make_unique<FlattenIOPass>(recursiveFlag, flattenExternFlag,
assert(baseType &&"element must be base type")
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
static llvm::SmallVector< Location > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, SmallVectorImpl< Location > &oldLocs)
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
static bool hasUnconvertedOps(mlir::ModuleOp module)
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames, char joinChar)
static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo)
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
static bool isStructType(Type type)
static void updateModulePortNames(ModTy op, hw::ModuleType oldModType, char joinChar)
static hw::StructType getStructType(Type type)
static bool flattenIO(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::Type innerType(mlir::Type type)
void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, RewritePatternSet &patterns, TypeConverter &converter)
mlir::Type getCanonicalType(mlir::Type type)
std::unique_ptr< mlir::Pass > createFlattenIOPass(bool recursiveFlag=true, bool flattenExternFlag=false, char joinChar='.')
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.