16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
26 #define GEN_PASS_DEF_LINKMODULES
27 #include "circt/Dialect/OM/OMPasses.h.inc"
31 using namespace circt;
// NOTE(review): the start of this alias declaration (presumably
// `using SymMappingTy =`) is not visible in this listing. Judging by its uses
// in postProcess/runOnOperation, it maps a pair of (enclosing module, original
// class symbol) to the new symbol that class is renamed to.
39 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
// Per-module linking state; the constructor only records the target module.
42 ModuleInfo(mlir::ModuleOp module) : module(module) {}
// Records the class-like ops of `module` in `symbolToClasses`.
45 LogicalResult initialize();
// Applies the symbol renames in `symMapping` that belong to `module`.
48 void postProcess(
const SymMappingTy &symMapping);
// Symbol name -> class-like op (definition or external declaration) in `module`.
51 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
// Pass that merges the module ops nested in the top-level module into the
// top-level module, renaming conflicting class symbols along the way. All of
// the work happens in runOnOperation().
57 struct LinkModulesPass
58 :
public circt::om::impl::LinkModulesBase<LinkModulesPass> {
59 void runOnOperation()
override;
// Populates `symbolToClasses` with every ClassLike op found directly in
// `module`, keyed by its symbol name.
// NOTE(review): the rest of this function (handling of non-class ops and the
// returned LogicalResult) is not visible in this listing.
64 LogicalResult ModuleInfo::initialize() {
65 for (
auto &op : llvm::make_early_inc_range(module.getOps())) {
// Only class definitions / external class declarations are recorded here.
66 if (
auto classLike = dyn_cast<ClassLike>(op))
67 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
// Applies the renames recorded in `symMapping` for this module:
//  - ClassOp symbol names are replaced with their new names,
//  - ObjectOp class-name references are retargeted,
//  - !om.class types that mention a renamed class are rewritten via an
//    AttrTypeReplacer.
// ClassExternOp ops are not renamed; the walk returns skip() for them.
// NOTE(review): lines between the isa<ClassExternOp> check and that skip() are
// not visible here and may do more (e.g. erase the declaration); confirm.
74 void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
75 AttrTypeReplacer replacer;
// Rewrites an om::ClassType whose class was renamed in this module.
76 replacer.addReplacement(
78 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
// Look the class up by (this module, referenced class symbol).
79 auto it = symMapping.find({module, classType.getClassName().
getAttr()});
// Not renamed: keep the type as is and tell the replacer not to recurse.
81 if (it == symMapping.end())
82 return {classType, WalkResult::skip()};
83 return {om::ClassType::get(classType.getContext(),
84 FlatSymbolRefAttr::get(it->second)),
// Pre-order walk, so returning skip() for an op also skips its nested ops.
88 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
90 if (isa<ClassExternOp>(op)) {
93 return WalkResult::skip();
// Rename class definitions and object instantiations that were assigned a
// new symbol for this module.
96 if (
auto classOp = dyn_cast<ClassOp>(op)) {
98 auto it = symMapping.find({module, classOp.getNameAttr()});
99 if (it != symMapping.end())
100 classOp.setSymNameAttr(it->second);
101 }
else if (
auto objectOp = dyn_cast<ObjectOp>(op)) {
103 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
104 if (it != symMapping.end())
105 objectOp.setClassNameAttr(it->second);
// Rewrite ClassType occurrences inside `op` using the replacer above.
// NOTE(review): the remaining arguments of this call are not visible here.
109 replacer.replaceElementsIn(op,
// Continue into nested ops (e.g. the body of a ClassOp).
113 return WalkResult::advance();
// NOTE(review): the first line of this signature is not visible here. The
// listing index gives it as
// `static FailureOr<bool> resolveClasses(StringAttr name, ArrayRef<ClassLike> classes)`.
//
// Validates all class-like ops that share the symbol `name` across modules.
//  - No external declaration: returns whether there is more than one
//    definition (presumably meaning "these classes need per-module renaming";
//    the caller's use is not fully visible).
//  - External declarations present: exactly one definition is required,
//    otherwise an error is emitted. Every declaration is then checked against
//    that definition:
//      * argument count and argument types,
//      * declared fields vs. the definition's fields.
// NOTE(review): several return statements (failure/success paths) fall on
// lines not visible in this listing.
120 ArrayRef<ClassLike> classes) {
// Tally definitions and note whether any external declaration exists.
121 bool existExternalClass =
false;
122 size_t countDefinition = 0;
125 for (
auto op : classes) {
126 if (isa<ClassExternOp>(op))
127 existExternalClass =
true;
129 classOp = cast<ClassOp>(op);
// An external declaration must bind to exactly one definition.
136 if (existExternalClass && countDefinition != 1) {
137 SmallVector<Location> classExternLocs;
138 SmallVector<Location> classLocs;
// Split locations so the diagnostic can point at every declaration and
// every definition.
139 for (
auto op : classes)
140 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
141 .push_back(op.getLoc());
// Anchor the error at the first external declaration (one exists here).
143 auto diag = emitError(classExternLocs.front())
144 <<
"class " << name <<
" is declared as an external class but "
145 << (countDefinition == 0 ?
"there is no definition"
146 :
"there are multiple definitions");
// Point at the remaining external declarations.
147 for (
auto loc : ArrayRef(classExternLocs).drop_front())
148 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
// With multiple definitions, point at each of them.
150 if (countDefinition != 0) {
152 for (
auto loc : classLocs)
153 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
// Only definitions: more than one means the symbol is ambiguous across
// modules.
158 if (!existExternalClass)
159 return countDefinition != 1;
// From here on there is exactly one definition, held in `classOp`.
161 assert(classOp && countDefinition == 1);
// Local helper (shadows the free emitError): starts an error on a mismatching
// declaration and attaches a note pointing at the definition.
165 auto emitError = [&](Operation *op) {
166 auto diag = op->emitError()
167 <<
"failed to link class " << name
168 <<
" since declaration doesn't match the definition: ";
169 diag.attachNote(classOp.getLoc()) <<
"definition is here";
// Field name -> type of the definition, kept in definition order.
173 llvm::MapVector<StringAttr, Type> classFields;
174 for (
auto fieldOp : classOp.getOps<om::ClassFieldOp>())
175 classFields.insert({fieldOp.getNameAttr(), fieldOp.getType()});
// Compare every class-like op against the definition.
// NOTE(review): the lines that presumably skip `classOp` itself are not visible.
177 for (
auto op : classes) {
// Argument counts must match...
181 if (classOp.getBodyBlock()->getNumArguments() !=
182 op.getBodyBlock()->getNumArguments())
183 return emitError(op) <<
"the number of arguments is not equal, "
184 << classOp.getBodyBlock()->getNumArguments()
185 <<
" vs " << op.getBodyBlock()->getNumArguments();
// ...and so must each positional argument type.
187 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
188 op.getBodyBlock()->getArgumentTypes())) {
190 return emitError(op) << index <<
"-th argument type is not equal, " << l
// Every field the declaration lists must exist in the definition with the
// same type.
195 llvm::DenseSet<StringAttr> declaredFields;
196 for (
auto fieldOp : op.getBodyBlock()->getOps<om::ClassExternFieldOp>()) {
197 auto it = classFields.find(fieldOp.getNameAttr());
200 if (it == classFields.end())
202 <<
"declaration has a field " << fieldOp.getNameAttr()
203 <<
" but not found in its definition";
205 if (it->second != fieldOp.getType())
207 <<
"declaration has a field " << fieldOp.getNameAttr()
208 <<
" but types don't match, " << it->second <<
" vs "
209 << fieldOp.getType();
210 declaredFields.insert(fieldOp.getNameAttr());
// Conversely, every field of the definition must be declared.
213 for (
auto [fieldName, _] : classFields)
214 if (!declaredFields.count(fieldName))
215 return emitError(op) <<
"definition has a field " << fieldName
216 <<
" but not found in this declaration";
// Links the module ops nested directly in the top-level module. Steps:
//  1. Give each nested module an "om.namespace" name.
//  2. Collect its classes in parallel.
//  3. Validate/resolve each class symbol across modules and pick new names
//     where needed.
//  4. Apply the renames in parallel.
//  5. Splice every module's body into the top-level block.
// NOTE(review): several lines of this function are not visible in this listing.
221 void LinkModulesPass::runOnOperation() {
222 auto toplevelModule = getOperation();
// One ModuleInfo per directly nested module op.
224 SmallVector<ModuleInfo> modules;
// Reuse a module's "om.namespace" attribute, or generate "module_<N>".
// The generated name is presumably only used when the attribute is absent;
// that condition is on lines not visible here.
226 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
227 auto name = module->getAttrOfType<StringAttr>(
"om.namespace");
230 name =
StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
231 module->setAttr(
"om.namespace", name);
233 modules.emplace_back(module);
// Collect classes of every module in parallel; fail the pass if any fails.
236 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
238 return info.initialize();
240 return signalPassFailure();
// (module, original symbol) -> new symbol for every class that gets renamed.
248 SymMappingTy symMapping;
// Group class-like ops by symbol across all modules; MapVector keeps first-seen
// order so the iteration (and naming) below is deterministic.
251 llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
252 for (
const auto &info : modules)
253 for (
auto &[name, op] : info.symbolToClasses) {
254 symbolToClasses[name].push_back(op);
// Reserve the original symbol in `nameSpace` (declared on lines not
// visible here) so generated names cannot collide with it.
256 (void)nameSpace.
newName(name.getValue());
// Resolve each symbol; any resolution error fails the pass.
263 for (
auto &[name, classes] : symbolToClasses) {
268 return signalPassFailure();
// Give each class sharing this symbol a fresh name derived from the symbol and
// its enclosing module's "om.namespace" attribute (presumably stored in
// `nameSpaceId` on a line not visible here).
274 for (
auto op : classes) {
275 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
277 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
280 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
// Apply the renames inside every module in parallel.
287 parallelForEach(&getContext(), modules,
288 [&](
auto &info) { info.postProcess(symMapping); });
// Move the contents of every nested module into the top-level block.
291 auto *block = toplevelModule.getBody();
// NOTE(review): `auto info` copies each ModuleInfo, including its
// `symbolToClasses` DenseMap, on every iteration; `auto &info` would avoid
// that copy without changing behavior.
292 for (
auto info : modules) {
293 block->getOperations().splice(block->end(),
294 info.module.getBody()->getOperations());
// Body of createOMLinkModulesPass() (signature not visible in this listing):
// returns a new LinkModulesPass instance.
301 return std::make_unique<LinkModulesPass>();
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.