16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Threading.h"
21#include "mlir/Pass/Pass.h"
27#define GEN_PASS_DEF_LINKMODULES
28#include "circt/Dialect/OM/OMPasses.h.inc"
// Map keyed by (ModuleOp, StringAttr) pairs with StringAttr values.
// NOTE(review): presumably the right-hand side of the `SymMappingTy` alias
// used by postProcess/runOnOperation — the `using` line is not visible here.
40 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
// Wraps one module op and allocates a fresh, empty side Block owned by this
// object.
43 ModuleInfo(mlir::ModuleOp module) : module(module) {
44 block = std::make_unique<Block>();
// Scans the ops of `module` and fills `symbolToClasses` (defined out of line).
48 LogicalResult initialize();
// Applies `symMapping` to class ops, object ops and class types inside
// `module` (defined out of line).
51 void postProcess(
const SymMappingTy &symMapping);
// Symbol name -> class-like op found in `module`; populated by initialize().
54 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
// Side block owned by this object. initialize() moves ops into it and
// runOnOperation() splices its contents back into the module body.
60 std::unique_ptr<Block> block;
64 :
public circt::om::impl::LinkModulesBase<LinkModulesPass> {
// Pass entry point, overriding the base class hook; defined out of line.
66 void runOnOperation()
override;
// Walks the ops directly inside `module`. Every op implementing ClassLike is
// recorded in `symbolToClasses` under its symbol-name attribute.
// NOTE(review): listing lines 75-76 and 78-82 are missing from this view, so
// the guard on the moveBefore call and the return statement cannot be
// confirmed from here.
71LogicalResult ModuleInfo::initialize() {
// Early-increment iteration: the loop body relocates ops out of the list
// being iterated.
72 for (
auto &op :
llvm::make_early_inc_range(module.getOps())) {
73 if (
auto classLike = dyn_cast<ClassLike>(op))
74 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
// Move the op to the end of the side block owned by this ModuleInfo.
77 op.moveBefore(block.get(), block->end());
// Rewrites symbol references inside `module` according to `symMapping`, which
// is keyed by (module, old symbol name) and yields the new symbol name.
// NOTE(review): several listing lines (86, 89, 94-96, 98, 100-101, 103-104,
// 116-118, 120-122) are missing from this view; comments below describe only
// the visible statements.
83void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
84 AttrTypeReplacer replacer;
// Register a type replacement for om::ClassType.
85 replacer.addReplacement(
87 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
// Look up the class name referenced by the type, scoped to this module.
88 auto it = symMapping.find({module, classType.getClassName().getAttr()});
// No entry in the mapping: return the type unchanged with WalkResult::skip().
90 if (it == symMapping.end())
91 return {classType, WalkResult::skip()};
// Otherwise build a ClassType that refers to the mapped symbol name.
92 return {om::ClassType::get(classType.getContext(),
93 FlatSymbolRefAttr::get(it->second)),
// Pre-order walk over every operation nested in the module.
97 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
// External class declarations: the walk does not descend into them.
// NOTE(review): the statements at listing lines 100-101 are not visible.
99 if (isa<ClassExternOp>(op)) {
102 return WalkResult::skip();
105 if (
auto classOp = dyn_cast<ClassOp>(op)) {
// Rename the class definition when the mapping has an entry for it, then
// rewrite its field types through the replacer.
107 auto it = symMapping.find({module, classOp.getNameAttr()});
108 if (it != symMapping.end())
109 classOp.setSymNameAttr(it->second);
110 classOp.replaceFieldTypes(replacer);
111 }
else if (
auto objectOp = dyn_cast<ObjectOp>(op)) {
// Point the object at the mapped class name when one exists.
113 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
114 if (it != symMapping.end())
115 objectOp.setClassNameAttr(it->second);
// Apply the registered replacements to the elements of `op`.
// NOTE(review): the remaining arguments of this call are not visible here.
119 replacer.replaceElementsIn(op,
123 return WalkResult::advance();
130 ArrayRef<ClassLike> classes) {
// Checks the class-like ops in `classes`, which share the symbol `name`:
//  - if any of them is a ClassExternOp, exactly one definition must exist,
//    otherwise an error with notes is emitted;
//  - if none is a ClassExternOp, the result is `countDefinition != 1`;
//  - otherwise every op in `classes` is compared against the single
//    definition (argument count, argument types, fields).
// NOTE(review): the first line of the signature and listing lines 133-134,
// 138, 140-145, 152, 159, 161, 164-167, 180-182, 184-186, 195, 197-200, 220
// and 230+ are missing from this view; what the caller does with the
// returned bool cannot be confirmed from here.
131 bool existExternalClass =
false;
132 size_t countDefinition = 0;
// First pass: note whether an external declaration is present and remember
// the definition op.
135 for (
auto op : classes) {
136 if (isa<ClassExternOp>(op))
137 existExternalClass =
true;
139 classOp = cast<ClassOp>(op);
// An external declaration without exactly one definition is reported.
146 if (existExternalClass && countDefinition != 1) {
147 SmallVector<Location> classExternLocs;
148 SmallVector<Location> classLocs;
// Split the op locations into external declarations and the rest.
149 for (
auto op : classes)
150 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
151 .push_back(op.getLoc());
// The primary diagnostic is anchored at the first external declaration.
153 auto diag = emitError(classExternLocs.front())
154 <<
"class " << name <<
" is declared as an external class but "
155 << (countDefinition == 0 ?
"there is no definition"
156 :
"there are multiple definitions");
// Every further external declaration gets a note.
157 for (
auto loc : ArrayRef(classExternLocs).drop_front())
158 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
// With multiple definitions, each non-extern location gets a note.
160 if (countDefinition != 0) {
162 for (
auto loc : classLocs)
163 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
// No external declaration: report whether the definition count differs from 1.
168 if (!existExternalClass)
169 return countDefinition != 1;
// From here on there is an external declaration and a single definition.
171 assert(classOp && countDefinition == 1);
// Local diagnostic helper: emits the mismatch error on `op` and attaches a
// note at the definition. It shadows the free emitError used above.
175 auto emitError = [&](Operation *op) {
176 auto diag = op->emitError()
177 <<
"failed to link class " << name
178 <<
" since declaration doesn't match the definition: ";
179 diag.attachNote(classOp.getLoc()) <<
"definition is here";
// Second pass: compare each op in `classes` with the definition.
183 for (
auto op : classes) {
// The number of body block arguments must match.
187 if (classOp.getBodyBlock()->getNumArguments() !=
188 op.getBodyBlock()->getNumArguments())
189 return emitError(op) <<
"the number of arguments is not equal, "
190 << classOp.getBodyBlock()->getNumArguments()
191 <<
" vs " << op.getBodyBlock()->getNumArguments();
// Walk the argument types of definition and declaration in lockstep.
// NOTE(review): the comparison and the declaration of `index` are not visible.
193 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
194 op.getBodyBlock()->getArgumentTypes())) {
196 return emitError(op) << index <<
"-th argument type is not equal, " << l
// Names of the fields seen on this op, used for the reverse check below.
201 llvm::DenseSet<StringAttr> declaredFields;
203 for (
auto nameAttr : op.getFieldNames()) {
// This local `name` (the field name) shadows the class symbol `name`.
204 StringAttr name = cast<StringAttr>(nameAttr);
205 std::optional<Type> opTypeOpt = op.getFieldType(name);
207 if (!opTypeOpt.has_value())
208 return emitError(op) <<
" no type for field " << name;
209 Type opType = opTypeOpt.value();
// The same field must exist on the definition ...
211 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
214 if (!classTypeOpt.has_value())
215 return emitError(op) <<
"declaration has a field " << name
216 <<
" but not found in its definition";
217 Type classType = classTypeOpt.value();
// ... and have the same type.
219 if (classType != opType)
221 <<
"declaration has a field " << name
222 <<
" but types don't match, " << classType <<
" vs " << opType;
223 declaredFields.insert(name);
// Reverse direction: every field of the definition must have been seen above.
226 for (
auto fieldName : classOp.getFieldNames())
227 if (!declaredFields.count(cast<StringAttr>(fieldName)))
228 return emitError(op) <<
"definition has a field " << fieldName
229 <<
" but not found in this declaration";
// Pass entry point. Visible steps: collect the ModuleOps nested directly in
// the top-level module, initialize them in parallel, build a symbol mapping
// for their classes, post-process each module with that mapping in parallel,
// and finally splice the nested modules' ops into the top-level module body.
// NOTE(review): many listing lines are missing from this view (e.g. 236, 238,
// 241-242, 245, 247-248, 252, 254-260, 262-264, 268, 270-275, 277-280,
// 282-286, 289, 292, 294-299, 308-309); the declarations of `counter`,
// `nameSpace`, `symbolToClasses` and `nameSpaceId` and the call that leads to
// the second signalPassFailure are not visible.
234void LinkModulesPass::runOnOperation() {
235 auto toplevelModule = getOperation();
237 SmallVector<ModuleInfo> modules;
// Visit each ModuleOp that is a direct child of the top-level module.
239 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
240 auto name =
module->getAttrOfType<StringAttr>("om.namespace");
// Synthesize a "module_<counter>" name and store it as the "om.namespace"
// attribute.
// NOTE(review): the condition guarding this assignment is not visible;
// presumably it applies only when the attribute is absent — confirm.
243 name = StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
244 module->setAttr("om.namespace", name);
246 modules.emplace_back(module);
// Run ModuleInfo::initialize on every module in parallel; a failure aborts
// the pass.
249 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
251 return info.initialize();
253 return signalPassFailure();
// (enclosing module, old symbol) -> new symbol, filled in below.
261 SymMappingTy symMapping;
// Group the class-like ops of all modules by symbol name.
265 for (
const auto &info : modules)
266 for (auto &[name, op] :
info.symbolToClasses) {
267 symbolToClasses[name].push_back(op);
// newName adds the name to the namespace; the returned name is discarded.
269 (void)nameSpace.
newName(name.getValue());
// Handle each symbol together with all class-like ops that carry it.
276 for (
auto &[name, classes] : symbolToClasses) {
281 return signalPassFailure();
// Record a new name for each op, keyed by its enclosing module and old name.
287 for (
auto op : classes) {
288 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
290 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
// The new name is derived from the old name and the module's namespace id.
291 symMapping[{enclosingModule, name}] = StringAttr::get(
293 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
// Apply the completed mapping to every module in parallel.
300 parallelForEach(&getContext(), modules,
301 [&](
auto &info) {
info.postProcess(symMapping); });
304 auto *block = toplevelModule.getBody();
305 for (
auto &info : modules) {
// Move everything left in the nested module body to the end of the top-level
// module body.
306 block->getOperations().splice(block->end(),
307 info.module.getBody()->getOperations());
// Then move the ops saved in the side block back into the nested module body.
310 info.module.getBody()->getOperations().splice(
311 info.module.getBody()->begin(),
info.block->getOperations());
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.