16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
26 #define GEN_PASS_DEF_LINKMODULES
27 #include "circt/Dialect/OM/OMPasses.h.inc"
31 using namespace circt;
39 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
// Wraps one input module of the link. The per-module symbol table below is
// filled in later by initialize().
42 ModuleInfo(mlir::ModuleOp module) : module(module) {}
// Populates symbolToClasses: every ClassLike op among the ops of `module`'s
// body is recorded under its symbol name. (The failure conditions of this
// LogicalResult are on lines not visible here — TODO confirm.)
45 LogicalResult initialize();
// Rewrites this module after linking decisions are made: ClassOps are renamed,
// ObjectOps are retargeted, and om::ClassType references are rewritten,
// whenever `symMapping` has an entry for (this module, original class name).
// Entries missing from `symMapping` leave the name unchanged.
48 void postProcess(
const SymMappingTy &symMapping);
// Class symbol name -> ClassLike op (class definition or extern declaration)
// found among this module's ops.
51 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
57 struct LinkModulesPass
58 :
public circt::om::impl::LinkModulesBase<LinkModulesPass> {
59 void runOnOperation()
override;
64 LogicalResult ModuleInfo::initialize() {
65 for (
auto &op : llvm::make_early_inc_range(module.getOps())) {
66 if (
auto classLike = dyn_cast<ClassLike>(op))
67 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
74 void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
75 AttrTypeReplacer replacer;
76 replacer.addReplacement(
78 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
79 auto it = symMapping.find({module, classType.getClassName().getAttr()});
81 if (it == symMapping.end())
82 return {classType, WalkResult::skip()};
83 return {om::ClassType::get(classType.getContext(),
84 FlatSymbolRefAttr::get(it->second)),
88 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
90 if (isa<ClassExternOp>(op)) {
93 return WalkResult::skip();
96 if (
auto classOp = dyn_cast<ClassOp>(op)) {
98 auto it = symMapping.find({module, classOp.getNameAttr()});
99 if (it != symMapping.end())
100 classOp.setSymNameAttr(it->second);
101 classOp.replaceFieldTypes(replacer);
102 }
else if (
auto objectOp = dyn_cast<ObjectOp>(op)) {
104 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
105 if (it != symMapping.end())
106 objectOp.setClassNameAttr(it->second);
110 replacer.replaceElementsIn(op,
114 return WalkResult::advance();
121 ArrayRef<ClassLike> classes) {
122 bool existExternalClass =
false;
123 size_t countDefinition = 0;
126 for (
auto op : classes) {
127 if (isa<ClassExternOp>(op))
128 existExternalClass =
true;
130 classOp = cast<ClassOp>(op);
137 if (existExternalClass && countDefinition != 1) {
138 SmallVector<Location> classExternLocs;
139 SmallVector<Location> classLocs;
140 for (
auto op : classes)
141 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
142 .push_back(op.getLoc());
144 auto diag = emitError(classExternLocs.front())
145 <<
"class " << name <<
" is declared as an external class but "
146 << (countDefinition == 0 ?
"there is no definition"
147 :
"there are multiple definitions");
148 for (
auto loc : ArrayRef(classExternLocs).drop_front())
149 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
151 if (countDefinition != 0) {
153 for (
auto loc : classLocs)
154 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
159 if (!existExternalClass)
160 return countDefinition != 1;
162 assert(classOp && countDefinition == 1);
166 auto emitError = [&](Operation *op) {
167 auto diag = op->emitError()
168 <<
"failed to link class " << name
169 <<
" since declaration doesn't match the definition: ";
170 diag.attachNote(classOp.getLoc()) <<
"definition is here";
174 for (
auto op : classes) {
178 if (classOp.getBodyBlock()->getNumArguments() !=
179 op.getBodyBlock()->getNumArguments())
180 return emitError(op) <<
"the number of arguments is not equal, "
181 << classOp.getBodyBlock()->getNumArguments()
182 <<
" vs " << op.getBodyBlock()->getNumArguments();
184 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
185 op.getBodyBlock()->getArgumentTypes())) {
187 return emitError(op) << index <<
"-th argument type is not equal, " << l
192 llvm::DenseSet<StringAttr> declaredFields;
194 for (
auto nameAttr : op.getFieldNames()) {
195 StringAttr name = cast<StringAttr>(nameAttr);
196 std::optional<Type> opTypeOpt = op.getFieldType(name);
198 if (!opTypeOpt.has_value())
199 return emitError(op) <<
" no type for field " << name;
200 Type opType = opTypeOpt.value();
202 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
205 if (!classTypeOpt.has_value())
206 return emitError(op) <<
"declaration has a field " << name
207 <<
" but not found in its definition";
208 Type classType = classTypeOpt.value();
210 if (classType != opType)
212 <<
"declaration has a field " << name
213 <<
" but types don't match, " << classType <<
" vs " << opType;
214 declaredFields.insert(name);
217 for (
auto fieldName : classOp.getFieldNames())
218 if (!declaredFields.count(cast<StringAttr>(fieldName)))
219 return emitError(op) <<
"definition has a field " << fieldName
220 <<
" but not found in this declaration";
225 void LinkModulesPass::runOnOperation() {
226 auto toplevelModule = getOperation();
228 SmallVector<ModuleInfo> modules;
230 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
231 auto name = module->getAttrOfType<StringAttr>(
"om.namespace");
234 name =
StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
235 module->setAttr(
"om.namespace", name);
237 modules.emplace_back(module);
240 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
242 return info.initialize();
244 return signalPassFailure();
252 SymMappingTy symMapping;
255 llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
256 for (
const auto &info : modules)
257 for (
auto &[name, op] : info.symbolToClasses) {
258 symbolToClasses[name].push_back(op);
260 (void)nameSpace.
newName(name.getValue());
267 for (
auto &[name, classes] : symbolToClasses) {
272 return signalPassFailure();
278 for (
auto op : classes) {
279 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
281 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
284 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
291 parallelForEach(&getContext(), modules,
292 [&](
auto &info) { info.postProcess(symMapping); });
295 auto *block = toplevelModule.getBody();
296 for (
auto info : modules) {
297 block->getOperations().splice(block->end(),
298 info.module.getBody()->getOperations());
305 return std::make_unique<LinkModulesPass>();
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.