16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Threading.h"
21#include "mlir/Pass/Pass.h"
22#include "llvm/ADT/TypeSwitch.h"
28#define GEN_PASS_DEF_LINKMODULES
29#include "circt/Dialect/OM/OMPasses.h.inc"
41 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
// Wraps one nested module. `block` starts out empty and receives the ops that
// initialize() takes out of the module.
44 ModuleInfo(mlir::ModuleOp module) : module(module) {
45 block = std::make_unique<Block>();
/// Records every ClassLike op of the module in symbolToClasses (keyed by
/// symbol name) and moves other top-level ops into `block`.
/// NOTE(review): "other" presumably means the non-ClassLike ops; confirm.
49 LogicalResult initialize();
/// Applies symMapping, a map from (module, original class symbol) to the new
/// class symbol, to this module's class definitions, object instantiations
/// and om::ClassType uses.
52 void postProcess(
const SymMappingTy &symMapping);
// ClassLike ops of this module, keyed by their symbol name.
55 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
// Temporary holder for ops moved out of the module by initialize(). The pass
// splices them back into the module after moving its body to the top level.
61 std::unique_ptr<Block> block;
65 :
public circt::om::impl::LinkModulesBase<LinkModulesPass> {
// Links all nested modules into the top-level module.
67 void runOnOperation()
override;
/// Indexes the ClassLike ops directly inside `module` by symbol name and parks
/// other top-level ops in `block`. runOnOperation calls this on every module
/// concurrently via failableParallelForEach.
/// NOTE(review): the branch selecting which ops are parked is only partly
/// visible here; presumably it is the non-ClassLike ones. Confirm.
72LogicalResult ModuleInfo::initialize() {
// make_early_inc_range allows ops to be moved out of the module while iterating.
73 for (
auto &op :
llvm::make_early_inc_range(module.getOps())) {
74 if (
auto classLike = dyn_cast<ClassLike>(op))
75 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
// Park the op in `block`. runOnOperation moves the module body to the top level
// and then splices these ops back into the module.
78 op.moveBefore(block.get(), block->end());
/// Rewrites this module according to symMapping, which maps
/// (module, original class symbol) to the new class symbol. The rewrite:
///  - renames ClassOps that have a mapping and rewrites their field types;
///  - retargets ObjectOp / ElaboratedObjectOp class references;
///  - rewrites om::ClassType occurrences on the remaining ops via an
///    AttrTypeReplacer.
/// The walk does not descend into ClassExternOp. symMapping is only read here.
84void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
85 AttrTypeReplacer replacer;
// Map om::ClassType to the ClassType of the renamed class. Types without a
// mapping are returned unchanged and their sub-elements are skipped.
86 replacer.addReplacement(
88 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
89 auto it = symMapping.find({module, classType.getClassName().getAttr()});
91 if (it == symMapping.end())
92 return {classType, WalkResult::skip()};
93 return {om::ClassType::get(classType.getContext(),
94 FlatSymbolRefAttr::get(it->second)),
// Pre-order walk, so returning skip() for an op prevents visiting its regions.
98 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
// Do not descend into external class declarations.
100 if (isa<ClassExternOp>(op)) {
103 return WalkResult::skip();
106 llvm::TypeSwitch<Operation *>(op)
107 .Case<ClassOp>([&](ClassOp classOp) {
108 auto it = symMapping.find({module, classOp.getNameAttr()});
110 if (it != symMapping.end())
111 classOp.setSymNameAttr(it->second);
// Field types may reference renamed classes as well.
112 classOp.replaceFieldTypes(replacer);
114 .Case<ObjectOp, ElaboratedObjectOp>([&](
auto objectLike) {
115 auto it = symMapping.find(
116 {module, objectLike.getClassNameAttr().getAttr()});
118 if (it != symMapping.end())
119 objectLike.setClassNameAttr(FlatSymbolRefAttr::get(it->second));
// The flags are presumably (replaceAttrs, replaceLocs, replaceTypes). Only
// types are rewritten because the class-name attribute was updated above.
120 replacer.replaceElementsIn(objectLike,
false,
false,
true);
// Any other op (presumably the TypeSwitch default case): rewrite its elements
// through the same replacer.
124 replacer.replaceElementsIn(op,
128 return WalkResult::advance();
135 ArrayRef<ClassLike> classes) {
// resolveClasses(StringAttr name, ArrayRef<ClassLike> classes) -> FailureOr<bool>
// Decides how the ClassLike ops that share symbol `name` across modules are
// linked:
//  - No ClassExternOp in `classes`: succeed and return whether there is
//    anything other than exactly one definition. The caller presumably
//    renames each definition per module when this is true.
//  - At least one ClassExternOp: exactly one ClassOp definition is required.
//    Each declaration must also match it in argument count, argument types,
//    and field names/types. Otherwise an error diagnostic is emitted.
136 bool existExternalClass =
false;
137 size_t countDefinition = 0;
// Classify the group: note whether any external declaration exists and
// remember the ClassOp definition. countDefinition is presumably incremented
// alongside classOp.
140 for (
auto op : classes) {
141 if (isa<ClassExternOp>(op))
142 existExternalClass =
true;
144 classOp = cast<ClassOp>(op);
// An externally declared class needs exactly one definition. Report all
// declaration locations, plus definition locations when there are any.
151 if (existExternalClass && countDefinition != 1) {
152 SmallVector<Location> classExternLocs;
153 SmallVector<Location> classLocs;
154 for (
auto op : classes)
155 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
156 .push_back(op.getLoc());
158 auto diag = emitError(classExternLocs.front())
159 <<
"class " << name <<
" is declared as an external class but "
160 << (countDefinition == 0 ?
"there is no definition"
161 :
"there are multiple definitions");
// The first declaration carries the error; the others become notes.
162 for (
auto loc : ArrayRef(classExternLocs).drop_front())
163 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
165 if (countDefinition != 0) {
167 for (
auto loc : classLocs)
168 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
// Without external declarations no compatibility check is needed. Report
// whether the symbol has anything other than exactly one definition.
173 if (!existExternalClass)
174 return countDefinition != 1;
// From here on there is exactly one definition (classOp) and at least one
// ClassExternOp declaration.
176 assert(classOp && countDefinition == 1);
// Local helper that shadows mlir::emitError for the rest of this function.
// It reports `op` as not matching the definition and attaches a note at the
// definition's location.
180 auto emitError = [&](Operation *op) {
181 auto diag = op->emitError()
182 <<
"failed to link class " << name
183 <<
" since declaration doesn't match the definition: ";
184 diag.attachNote(classOp.getLoc()) <<
"definition is here";
// Check every ClassLike in the group against the definition.
188 for (
auto op : classes) {
// Argument counts must match.
192 if (classOp.getBodyBlock()->getNumArguments() !=
193 op.getBodyBlock()->getNumArguments())
194 return emitError(op) <<
"the number of arguments is not equal, "
195 << classOp.getBodyBlock()->getNumArguments()
196 <<
" vs " << op.getBodyBlock()->getNumArguments();
// Argument types must match pairwise. `index` presumably tracks the
// argument position reported in the error.
198 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
199 op.getBodyBlock()->getArgumentTypes())) {
201 return emitError(op) << index <<
"-th argument type is not equal, " << l
206 llvm::DenseSet<StringAttr> declaredFields;
// Every field the declaration lists must exist in the definition with the
// same type.
208 for (
auto nameAttr : op.getFieldNames()) {
// NOTE(review): this `name` is the field name and shadows the class-name
// parameter for the rest of the loop body. The emitError lambda still
// reports the class name.
209 StringAttr name = cast<StringAttr>(nameAttr);
210 std::optional<Type> opTypeOpt = op.getFieldType(name);
// NOTE(review): the message below starts with a space, so it renders with a
// double space after the ": " suffix added by emitError.
212 if (!opTypeOpt.has_value())
213 return emitError(op) <<
" no type for field " << name;
214 Type opType = opTypeOpt.value();
216 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
219 if (!classTypeOpt.has_value())
220 return emitError(op) <<
"declaration has a field " << name
221 <<
" but not found in its definition";
222 Type classType = classTypeOpt.value();
224 if (classType != opType)
226 <<
"declaration has a field " << name
227 <<
" but types don't match, " << classType <<
" vs " << opType;
228 declaredFields.insert(name);
// Conversely, every field of the definition must be declared.
231 for (
auto fieldName : classOp.getFieldNames())
232 if (!declaredFields.count(cast<StringAttr>(fieldName)))
233 return emitError(op) <<
"definition has a field " << fieldName
234 <<
" but not found in this declaration";
/// Flattens the ModuleOps nested in the top-level module into it.
/// Along the way it resolves ClassExternOp declarations against their
/// definitions and renames classes to unique names (derived from each
/// module's `om.namespace` attribute) where symbols need to be disambiguated.
239void LinkModulesPass::runOnOperation() {
240 auto toplevelModule = getOperation();
// One ModuleInfo per nested module. Each module ends up with an `om.namespace`
// attribute; a generated "module_<N>" name is set (presumably only when the
// attribute is missing).
242 SmallVector<ModuleInfo> modules;
244 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
245 auto name =
module->getAttrOfType<StringAttr>("om.namespace");
248 name = StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
249 module->setAttr("om.namespace", name);
251 modules.emplace_back(module);
// Index each module's classes in parallel. Abort the pass if any module fails.
254 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
256 return info.initialize();
258 return signalPassFailure();
// (enclosing module, original class symbol) -> new unique class symbol.
266 SymMappingTy symMapping;
// Group the ClassLike ops by symbol across all modules. Every original symbol
// is also registered with nameSpace, so names generated later cannot clash
// with it.
270 for (
const auto &info : modules)
271 for (auto &[name, op] :
info.symbolToClasses) {
272 symbolToClasses[name].push_back(op);
274 (void)nameSpace.
newName(name.getValue());
// Resolve each symbol group. A failing group aborts the pass (presumably via
// resolveClasses).
281 for (
auto &[name, classes] : symbolToClasses) {
286 return signalPassFailure();
// Give every class in the group a unique name derived from its symbol and
// its enclosing module's `om.namespace` attribute.
292 for (
auto op : classes) {
293 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
295 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
296 symMapping[{enclosingModule, name}] = StringAttr::get(
298 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
// Apply the renaming in every module in parallel. symMapping is only read
// here, since postProcess takes it by const reference.
305 parallelForEach(&getContext(), modules,
306 [&](
auto &info) {
info.postProcess(symMapping); });
// Move each nested module's body to the end of the top-level block.
309 auto *block = toplevelModule.getBody();
310 for (
auto &info : modules) {
311 block->getOperations().splice(block->end(),
312 info.module.getBody()->getOperations());
// Return the ops parked by ModuleInfo::initialize to the now-empty nested
// module.
315 info.module.getBody()->getOperations().splice(
316 info.module.getBody()->begin(),
info.block->getOperations());
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.