11#include "mlir/Pass/Pass.h"
12#include "mlir/Transforms/DialectConversion.h"
13#include "llvm/ADT/TypeSwitch.h"
17#define GEN_PASS_DEF_FLATTENIO
18#include "circt/Dialect/HW/Passes.h.inc"
26 return isa<hw::StructType>(hw::getCanonicalType(type));
30 return dyn_cast<hw::StructType>(hw::getCanonicalType(type));
35 return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(),
40 llvm::SmallVector<Type> inner;
41 t.getInnerTypes(inner);
42 for (
auto [index, innerType] : llvm::enumerate(inner))
43 inner[index] = hw::getCanonicalType(innerType);
50static SmallVector<Value>
flattenValues(ArrayRef<ValueRange> values) {
51 SmallVector<Value> result;
52 for (
const auto &vals : values)
53 llvm::append_range(result, vals);
59 OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context,
60 DenseSet<Operation *> *opVisited)
64 matchAndRewrite(hw::OutputOp op, OpAdaptor adaptor,
65 ConversionPatternRewriter &rewriter)
const override {
66 llvm::SmallVector<Value> convOperands;
69 for (
auto operand : adaptor.getOperands()) {
71 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
73 llvm::copy(explodedStruct.getResults(),
74 std::back_inserter(convOperands));
76 convOperands.push_back(operand);
81 opVisited->insert(op->getParentOp());
82 rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
87 matchAndRewrite(hw::OutputOp op, OneToNOpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter)
const override {
89 llvm::SmallVector<Value> convOperands;
94 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
96 llvm::copy(explodedStruct.getResults(),
97 std::back_inserter(convOperands));
99 convOperands.push_back(operand);
104 opVisited->insert(op->getParentOp());
105 rewriter.replaceOpWithNewOp<hw::OutputOp>(op, convOperands);
108 DenseSet<Operation *> *opVisited;
112 InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context,
113 DenseSet<hw::InstanceOp> *convertedOps,
114 const StringSet<> *externModules)
116 externModules(externModules) {}
119 matchAndRewrite(hw::InstanceOp op, OneToNOpAdaptor adaptor,
120 ConversionPatternRewriter &rewriter)
const override {
121 auto referencedMod = op.getReferencedModuleNameAttr();
124 if (externModules->contains(referencedMod.getValue()))
127 auto loc = op.getLoc();
129 llvm::SmallVector<Value> convOperands;
132 auto explodedStruct = rewriter.create<hw::StructExplodeOp>(
134 llvm::copy(explodedStruct.getResults(),
135 std::back_inserter(convOperands));
137 convOperands.push_back(operand);
142 llvm::SmallVector<Type> newResultTypes;
143 for (
auto oldResultType : op.getResultTypes()) {
145 for (
auto t : structType.getElements())
146 newResultTypes.push_back(t.type);
148 newResultTypes.push_back(oldResultType);
153 auto newInstance = rewriter.create<hw::InstanceOp>(
154 loc, newResultTypes, op.getInstanceNameAttr(),
155 FlatSymbolRefAttr::get(referencedMod), convOperands,
156 op.getArgNamesAttr(), op.getResultNamesAttr(), op.getParametersAttr(),
157 op.getInnerSymAttr(), op.getDoNotPrintAttr());
160 llvm::SmallVector<Value> convResults;
161 size_t oldResultCntr = 0;
162 for (
size_t resIndex = 0; resIndex < newInstance.getNumResults();
164 Type oldResultType = op.getResultTypes()[oldResultCntr];
166 size_t nElements = structType.getElements().size();
169 newInstance.getResults().slice(resIndex, nElements));
170 convResults.push_back(implodedStruct.getResult());
171 resIndex += nElements - 1;
173 convResults.push_back(newInstance.getResult(resIndex));
177 rewriter.replaceOp(op, convResults);
178 convertedOps->insert(newInstance);
182 DenseSet<hw::InstanceOp> *convertedOps;
183 const StringSet<> *externModules;
186using IOTypes = std::pair<TypeRange, TypeRange>;
190 DenseMap<unsigned, hw::StructType> argStructs, resStructs;
193 SmallVector<Type> argTypes, resTypes;
196class FlattenIOTypeConverter :
public TypeConverter {
198 FlattenIOTypeConverter() {
199 addConversion([](Type type, SmallVectorImpl<Type> &results) {
202 results.push_back(type);
204 for (
auto field : structType.getElements())
205 results.push_back(field.type);
212 addTargetMaterialization([](OpBuilder &builder, TypeRange resultTypes,
213 ValueRange inputs, Location loc) {
214 if (inputs.size() != 1 && !
isStructType(inputs[0].getType()))
217 auto explodeOp = builder.create<hw::StructExplodeOp>(loc, inputs[0]);
218 return ValueRange(explodeOp.getResults());
220 addTargetMaterialization([](OpBuilder &builder, hw::StructType type,
221 ValueRange inputs, Location loc) {
223 return result.getResult();
226 addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type,
227 ValueRange inputs, Location loc) {
229 return result.getResult();
238 addSourceMaterialization([](OpBuilder &builder, hw::StructType type,
239 ValueRange inputs, Location loc) {
241 return result.getResult();
248template <
typename... TOp>
250 ConversionTarget &target,
252 FlattenIOTypeConverter &typeConverter) {
253 (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(),
267 target.addDynamicallyLegalOp<TOp...>([&](hw::HWModuleLike moduleLikeOp) {
272 auto ioInfoIt = ioMap.find(moduleLikeOp);
273 if (ioInfoIt == ioMap.end()) {
278 auto ioInfo = ioInfoIt->second;
280 auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) {
281 return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](
auto typePair) {
282 auto oldType = std::get<0>(typePair);
283 auto newType = std::get<1>(typePair);
284 return oldType != newType;
287 auto mtype = moduleLikeOp.getHWModuleType();
288 if (
compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) ||
300 return llvm::any_of(module.getBody()->getOps<T>(),
301 [](T op) { return !isLegalModLikeOp(op); });
306 DenseMap<Operation *, IOTypes> ioMap;
307 for (
auto op :
module.getOps<T>())
308 ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()};
312template <
typename ModTy,
typename T>
313static llvm::SmallVector<Attribute>
315 DenseMap<unsigned, hw::StructType> &structMap, T oldNames,
317 llvm::SmallVector<Attribute> newNames;
318 for (
auto [i, oldName] : llvm::enumerate(oldNames)) {
320 auto it = structMap.find(i);
321 if (it == structMap.end()) {
323 newNames.push_back(StringAttr::get(op->getContext(), oldName));
329 auto structType = it->second;
330 for (
auto field : structType.getElements())
331 newNames.push_back(StringAttr::get(
332 op->getContext(), oldName + Twine(joinChar) + field.name.str()));
337template <
typename ModTy>
342 SmallVector<Attribute> newNames;
343 SmallVector<hw::ModulePort> oldPorts(oldModType.getPorts().begin(),
344 oldModType.getPorts().end());
345 for (
auto oldPort : oldPorts) {
346 auto oldName = oldPort.name;
348 for (
auto field : structType.getElements()) {
349 newNames.push_back(StringAttr::get(
351 oldName.getValue() + Twine(joinChar) + field.name.str()));
354 newNames.push_back(oldName);
356 op.setAllPortNames(newNames);
359static llvm::SmallVector<Location>
361 SmallVectorImpl<Location> &oldLocs) {
362 llvm::SmallVector<Location> newLocs;
363 for (
auto [i, oldLoc] : llvm::enumerate(oldLocs)) {
365 auto it = structMap.find(i);
366 if (it == structMap.end()) {
368 newLocs.push_back(oldLoc);
372 auto structType = it->second;
373 for (
size_t i = 0, e = structType.getElements().size(); i < e; ++i)
374 newLocs.push_back(oldLoc);
384 DenseMap<unsigned, hw::StructType> &structMap) {
385 auto locs = op.getInputLocs();
386 if (locs.empty() || op.getModuleBody().empty())
388 for (
auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), locs))
392static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo) {
393 ioInfo.argTypes = op.getInputTypes();
394 ioInfo.resTypes = op.getOutputTypes();
395 for (
auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) {
397 ioInfo.argStructs[i] = structType;
399 for (
auto [i, res] : llvm::enumerate(ioInfo.resTypes)) {
401 ioInfo.resStructs[i] = structType;
407 DenseMap<Operation *, IOInfo> ioInfoMap;
408 for (
auto op :
module.getOps<T>()) {
411 ioInfoMap[op] = ioInfo;
418 StringSet<> &externModules,
420 auto *ctx =
module.getContext();
421 FlattenIOTypeConverter typeConverter;
426 while (hasUnconvertedOps<T>(module)) {
427 ConversionTarget target(*ctx);
429 target.addLegalDialect<hw::HWDialect>();
433 auto ioInfoMap = populateIOInfoMap<T>(module);
438 llvm::DenseSet<hw::InstanceOp> convertedInstances;
443 DenseSet<Operation *> opVisited;
444 patterns.add<OutputOpConversion>(typeConverter, ctx, &opVisited);
446 patterns.add<InstanceOpConversion>(typeConverter, ctx, &convertedInstances,
448 target.addDynamicallyLegalOp<hw::OutputOp>(
449 [&](
auto op) {
return opVisited.contains(op->getParentOp()); });
450 target.addDynamicallyLegalOp<hw::InstanceOp>([&](hw::InstanceOp op) {
451 auto refName = op.getReferencedModuleName();
452 return externModules.contains(refName) ||
453 (llvm::none_of(op->getOperands(),
455 return isStructType(operand.getType());
457 llvm::none_of(op->getResultTypes(),
458 [](
auto result) { return isStructType(result); }));
461 DenseMap<Operation *, ArrayAttr> oldArgNames, oldResNames;
462 DenseMap<Operation *, SmallVector<Location>> oldArgLocs, oldResLocs;
463 DenseMap<Operation *, hw::ModuleType> oldModTypes;
465 for (
auto op :
module.getOps<T>()) {
466 oldModTypes[op] = op.getHWModuleType();
467 oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames());
469 ArrayAttr::get(module.getContext(), op.getOutputNames());
470 oldArgLocs[op] = op.getInputLocs();
471 oldResLocs[op] = op.getOutputLocs();
475 addSignatureConversion<T>(ioInfoMap, target,
patterns, typeConverter);
477 if (failed(applyPartialConversion(module, target, std::move(
patterns))))
481 for (
auto op :
module.getOps<T>()) {
482 auto ioInfo = ioInfoMap[op];
486 newArgLocs.append(newResLocs.begin(), newResLocs.end());
487 op.setAllPortLocs(newArgLocs);
492 for (
auto instanceOp : convertedInstances) {
494 cast<hw::HWModuleLike>(SymbolTable::lookupNearestSymbolFrom(
495 instanceOp, instanceOp.getReferencedModuleNameAttr()));
498 if (!ioInfoMap.contains(targetModule)) {
501 ioInfoMap[targetModule] = ioInfo;
502 oldArgNames[targetModule] =
503 ArrayAttr::get(module.getContext(), targetModule.getInputNames());
504 oldResNames[targetModule] =
505 ArrayAttr::get(module.getContext(), targetModule.getOutputNames());
506 oldArgLocs[targetModule] = targetModule.getInputLocs();
507 oldResLocs[targetModule] = targetModule.getOutputLocs();
509 ioInfo = ioInfoMap[targetModule];
511 instanceOp.setInputNames(ArrayAttr::get(
512 instanceOp.getContext(),
514 instanceOp,
"argNames", ioInfo.argStructs,
515 oldArgNames[targetModule].template getAsValueRange<StringAttr>(),
517 instanceOp.setOutputNames(ArrayAttr::get(
518 instanceOp.getContext(),
520 instanceOp,
"resultNames", ioInfo.resStructs,
521 oldResNames[targetModule].template getAsValueRange<StringAttr>(),
536template <
typename... TOps>
538 StringSet<> &externModules,
char joinChar) {
539 return (failed(flattenOpsOfType<TOps>(module, recursive, externModules,
546class FlattenIOPass :
public circt::hw::impl::FlattenIOBase<FlattenIOPass> {
548 FlattenIOPass(
bool recursiveFlag,
bool flattenExternFlag,
char join) {
549 recursive = recursiveFlag;
550 flattenExtern = flattenExternFlag;
554 void runOnOperation()
override {
555 ModuleOp
module = getOperation();
556 if (!flattenExtern) {
558 for (
auto m : module.getOps<
hw::HWModuleExternOp>())
559 externModules.insert(m.getModuleName());
560 if (flattenIO<hw::HWModuleOp, hw::HWModuleGeneratedOp>(
561 module, recursive, externModules, joinChar))
567 hw::HWModuleGeneratedOp>(module, recursive, externModules,
573 StringSet<> externModules;
582 bool flattenExternFlag,
584 return std::make_unique<FlattenIOPass>(recursiveFlag, flattenExternFlag,
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static DenseMap< Operation *, IOInfo > populateIOInfoMap(mlir::ModuleOp module)
static DenseMap< Operation *, IOTypes > populateIOMap(mlir::ModuleOp module)
static llvm::SmallVector< Location > updateLocAttribute(DenseMap< unsigned, hw::StructType > &structMap, SmallVectorImpl< Location > &oldLocs)
static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp)
static llvm::SmallVector< Attribute > updateNameAttribute(ModTy op, StringRef attrName, DenseMap< unsigned, hw::StructType > &structMap, T oldNames, char joinChar)
static void updateBlockLocations(hw::HWModuleLike op, DenseMap< unsigned, hw::StructType > &structMap)
The conversion framework seems to throw away block argument locations.
static bool hasUnconvertedOps(mlir::ModuleOp module)
static llvm::SmallVector< Type > getInnerTypes(hw::StructType t)
static void setIOInfo(hw::HWModuleLike op, IOInfo &ioInfo)
static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
static void addSignatureConversion(DenseMap< Operation *, IOInfo > &ioMap, ConversionTarget &target, RewritePatternSet &patterns, FlattenIOTypeConverter &typeConverter)
static bool isStructType(Type type)
static void updateModulePortNames(ModTy op, hw::ModuleType oldModType, char joinChar)
static hw::StructType getStructType(Type type)
static bool flattenIO(ModuleOp module, bool recursive, StringSet<> &externModules, char joinChar)
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten the given value ranges into a single vector of values.
std::unique_ptr< mlir::Pass > createFlattenIOPass(bool recursiveFlag=true, bool flattenExternFlag=false, char joinChar='.')
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.