16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Threading.h"
21#include "mlir/Pass/Pass.h"
27#define GEN_PASS_DEF_LINKMODULES
28#include "circt/Dialect/OM/OMPasses.h.inc"
// Map type keyed by a (ModuleOp, StringAttr) pair and yielding a StringAttr.
// NOTE(review): presumably the right-hand side of the `SymMappingTy` alias
// used below (enclosing module + old symbol -> new symbol) — confirm.
40 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
// Stores the given module and allocates a fresh Block owned by this object.
43 ModuleInfo(mlir::ModuleOp module) : module(module) {
44 block = std::make_unique<Block>();
// Fills `symbolToClasses` from the ops of `module` (defined out of line).
48 LogicalResult initialize();
// Applies `symMapping` to the ops of `module` (defined out of line).
51 void postProcess(
const SymMappingTy &symMapping);
// Class-like ops found in `module`, keyed by their symbol name.
54 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
// Side block; `initialize` moves ops into it and the pass later splices them
// back into `module`.
60 std::unique_ptr<Block> block;
// Pass class deriving from the generated LinkModulesBase.
64 :
public circt::om::impl::LinkModulesBase<LinkModulesPass> {
// Pass entry point; the implementation is defined out of line.
65 void runOnOperation()
override;
// Iterates over the ops of `module`. Every op implementing ClassLike is
// recorded in `symbolToClasses` under its symbol name attribute. The visible
// `moveBefore` call appends an op to the end of the side `block`.
// NOTE(review): the lines between the insert and the move are not visible
// here; presumably only non-class-like ops are moved — confirm.
70LogicalResult ModuleInfo::initialize() {
// The early-inc range lets the body relocate `op` while the loop is running.
71 for (
auto &op :
llvm::make_early_inc_range(module.getOps())) {
72 if (
auto classLike = dyn_cast<ClassLike>(op))
73 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
76 op.moveBefore(block.get(), block->end());
// Rewrites class names inside `module` according to `symMapping`, whose keys
// are (module, old class name) pairs and whose values are the new names.
82void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
// Type replacer: an om::ClassType whose class name has an entry for this
// module is rebuilt to reference the mapped name; a type without an entry is
// returned unchanged together with WalkResult::skip().
83 AttrTypeReplacer replacer;
84 replacer.addReplacement(
86 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
87 auto it = symMapping.find({module, classType.getClassName().getAttr()});
89 if (it == symMapping.end())
90 return {classType, WalkResult::skip()};
91 return {om::ClassType::get(classType.getContext(),
92 FlatSymbolRefAttr::get(it->second)),
// Pre-order walk over every operation nested in `module`.
96 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
// External class declarations end with a skip, so the walk does not descend
// into them.
// NOTE(review): the statements between the check and the skip are not
// visible here — confirm what is done to the op.
98 if (isa<ClassExternOp>(op)) {
101 return WalkResult::skip();
// Class definition: rename its symbol when a mapping exists, then update the
// field types through the replacer.
104 if (
auto classOp = dyn_cast<ClassOp>(op)) {
106 auto it = symMapping.find({module, classOp.getNameAttr()});
107 if (it != symMapping.end())
108 classOp.setSymNameAttr(it->second);
109 classOp.replaceFieldTypes(replacer);
110 }
// Object instantiation: retarget the referenced class name when a mapping
// exists.
else if (
auto objectOp = dyn_cast<ObjectOp>(op)) {
112 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
113 if (it != symMapping.end())
114 objectOp.setClassNameAttr(it->second);
// Apply the type replacer to the elements of `op`, then continue the walk.
118 replacer.replaceElementsIn(op,
122 return WalkResult::advance();
129 ArrayRef<ClassLike> classes) {
// Checks the class-like ops in `classes`, all of which share the symbol
// `name`: notes whether any is an external declaration (ClassExternOp) and
// tracks definitions (ClassOp) via `classOp` / `countDefinition`.
// NOTE(review): the declaration of `classOp` and the increment of
// `countDefinition` are not visible here — presumably the cast below runs
// only for non-extern ops and bumps the counter; confirm.
130 bool existExternalClass =
false;
131 size_t countDefinition = 0;
134 for (
auto op : classes) {
135 if (isa<ClassExternOp>(op))
136 existExternalClass =
true;
138 classOp = cast<ClassOp>(op);
// An external declaration requires exactly one definition. Otherwise emit an
// error at the first declaration, with notes for the remaining declarations
// and, if any exist, for every definition.
145 if (existExternalClass && countDefinition != 1) {
146 SmallVector<Location> classExternLocs;
147 SmallVector<Location> classLocs;
148 for (
auto op : classes)
149 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
150 .push_back(op.getLoc());
152 auto diag = emitError(classExternLocs.front())
153 <<
"class " << name <<
" is declared as an external class but "
154 << (countDefinition == 0 ?
"there is no definition"
155 :
"there are multiple definitions");
156 for (
auto loc : ArrayRef(classExternLocs).drop_front())
157 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
159 if (countDefinition != 0) {
161 for (
auto loc : classLocs)
162 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
// Without any external declaration the result is whether the number of
// definitions differs from one.
167 if (!existExternalClass)
168 return countDefinition != 1;
// From here on there is exactly one definition plus at least one declaration.
170 assert(classOp && countDefinition == 1);
// Local diagnostic helper (shadows the free function `emitError` used above):
// reports the mismatch on `op` and attaches a note at the definition.
174 auto emitError = [&](Operation *op) {
175 auto diag = op->emitError()
176 <<
"failed to link class " << name
177 <<
" since declaration doesn't match the definition: ";
178 diag.attachNote(classOp.getLoc()) <<
"definition is here";
// Compare each op in `classes` against the definition.
// NOTE(review): any filtering at the top of this loop body is not visible
// here — confirm whether the definition itself is skipped.
182 for (
auto op : classes) {
// The body block argument counts must agree.
186 if (classOp.getBodyBlock()->getNumArguments() !=
187 op.getBodyBlock()->getNumArguments())
188 return emitError(op) <<
"the number of arguments is not equal, "
189 << classOp.getBodyBlock()->getNumArguments()
190 <<
" vs " << op.getBodyBlock()->getNumArguments();
// The body block argument types are compared pairwise.
192 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
193 op.getBodyBlock()->getArgumentTypes())) {
195 return emitError(op) << index <<
"-th argument type is not equal, " << l
// Every field of the declaration must exist in the definition with the same
// type; the names seen are collected for the reverse check below.
200 llvm::DenseSet<StringAttr> declaredFields;
202 for (
auto nameAttr : op.getFieldNames()) {
// This local `name` is the field name; it shadows the class name used in the
// diagnostics above.
203 StringAttr name = cast<StringAttr>(nameAttr);
204 std::optional<Type> opTypeOpt = op.getFieldType(name);
206 if (!opTypeOpt.has_value())
207 return emitError(op) <<
" no type for field " << name;
208 Type opType = opTypeOpt.value();
210 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
213 if (!classTypeOpt.has_value())
214 return emitError(op) <<
"declaration has a field " << name
215 <<
" but not found in its definition";
216 Type classType = classTypeOpt.value();
218 if (classType != opType)
220 <<
"declaration has a field " << name
221 <<
" but types don't match, " << classType <<
" vs " << opType;
222 declaredFields.insert(name);
// Reverse check: every field of the definition must have been declared.
225 for (
auto fieldName : classOp.getFieldNames())
226 if (!declaredFields.count(cast<StringAttr>(fieldName)))
227 return emitError(op) <<
"definition has a field " << fieldName
228 <<
" but not found in this declaration";
// Links the ModuleOps nested directly in the top-level module: collects
// their class-like ops, computes a symbol mapping, rewrites each nested
// module, and finally moves the nested modules' ops into the top-level body.
233void LinkModulesPass::runOnOperation() {
234 auto toplevelModule = getOperation();
236 SmallVector<ModuleInfo> modules;
// Gather the nested modules and read their "om.namespace" attribute.
// NOTE(review): the condition guarding the generated "module_<N>" name is
// not visible here; presumably it applies only when the attribute is
// missing — confirm.
238 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
239 auto name =
module->getAttrOfType<StringAttr>("om.namespace");
242 name = StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
243 module->setAttr("om.namespace", name);
245 modules.emplace_back(module);
// Run ModuleInfo::initialize on every module in parallel; any failure fails
// the pass.
248 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
250 return info.initialize();
252 return signalPassFailure();
260 SymMappingTy symMapping;
// Group the class-like ops of all modules by symbol name and register each
// name in `nameSpace`.
263 llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
264 for (
const auto &info : modules)
265 for (auto &[name, op] : info.symbolToClasses) {
266 symbolToClasses[name].push_back(op);
268 (void)nameSpace.
newName(name.getValue());
// Per symbol: the pass fails on an error; where renaming happens, each op's
// (enclosing module, name) pair maps to a fresh name built from the original
// name and the module's "om.namespace" value.
// NOTE(review): the check preceding signalPassFailure and the condition
// guarding the renaming loop are not visible here — confirm.
275 for (
auto &[name, classes] : symbolToClasses) {
280 return signalPassFailure();
286 for (
auto op : classes) {
287 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
289 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
290 symMapping[{enclosingModule, name}] = StringAttr::get(
292 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
// Apply the mapping to every module in parallel.
299 parallelForEach(&getContext(), modules,
300 [&](
auto &info) { info.postProcess(symMapping); });
// Move each nested module's ops to the end of the top-level body, then
// splice the ops held in the side block back into the now-empty nested
// module.
303 auto *block = toplevelModule.getBody();
304 for (
auto &info : modules) {
305 block->getOperations().splice(block->end(),
306 info.module.getBody()->getOperations());
308 assert(info.module.getBody()->empty());
309 info.module.getBody()->getOperations().splice(
310 info.module.getBody()->begin(), info.block->getOperations());
// Hands back a newly allocated LinkModulesPass.
315 return std::make_unique<LinkModulesPass>();
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.