17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/IR/Threading.h"
24 using namespace circt;
32 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
35 ModuleInfo(mlir::ModuleOp module) : module(module) {}
38 LogicalResult initialize();
41 void postProcess(
const SymMappingTy &symMapping);
44 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
// Pass that merges the ModuleOps nested directly inside the top-level module
// into the top-level module's body (see LinkModulesPass::runOnOperation).
50 struct LinkModulesPass :
public LinkModulesBase<LinkModulesPass> {
51 void runOnOperation()
override;
// Record every ClassLike op found at the top level of `module` in
// `symbolToClasses`, keyed by the op's symbol name attribute.
// NOTE(review): the end of this function (its return statement and any other
// branches of the loop body) is not shown in this listing. The use of
// make_early_inc_range suggests the loop body may erase the current op while
// iterating -- confirm against the full source.
56 LogicalResult ModuleInfo::initialize() {
57 for (
auto &op : llvm::make_early_inc_range(module.getOps())) {
// Only ops implementing the ClassLike interface are indexed here.
// DenseMap::insert keeps the existing entry if the name is already present.
58 if (
auto classLike = dyn_cast<ClassLike>(op))
59 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
// Apply the class renaming computed for this module.
// `symMapping` maps (module, original class name) to the new class name.
// Class definitions, object instantiations and om::ClassType references
// inside `module` are all rewritten to use the new names.
66 void ModuleInfo::postProcess(
const SymMappingTy &symMapping) {
// Type replacer for om::ClassType. A class type whose name has a mapping for
// this module is rebuilt with the new name. An unmapped type is returned
// unchanged with WalkResult::skip(), so its sub-elements are not visited.
67 AttrTypeReplacer replacer;
68 replacer.addReplacement(
70 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
71 auto it = symMapping.find({module, classType.getClassName().
getAttr()});
73 if (it == symMapping.end())
74 return {classType, WalkResult::skip()};
75 return {om::ClassType::get(classType.getContext(),
76 FlatSymbolRefAttr::get(it->second)),
// Walk in pre-order so that returning WalkResult::skip() for an op prevents
// its nested regions from being visited.
80 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
// External class declarations are not descended into. The statements between
// this check and the skip are not shown in this listing.
82 if (isa<ClassExternOp>(op)) {
85 return WalkResult::skip();
// A class definition with a mapping gets its symbol name replaced.
88 if (
auto classOp = dyn_cast<ClassOp>(op)) {
90 auto it = symMapping.find({module, classOp.getNameAttr()});
91 if (it != symMapping.end())
92 classOp.setSymNameAttr(it->second);
93 }
else if (
auto objectOp = dyn_cast<ObjectOp>(op)) {
// An object instantiation is redirected to the renamed class.
95 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
96 if (it != symMapping.end())
97 objectOp.setClassNameAttr(it->second);
// Run the ClassType replacer over this op. The remaining arguments of this
// call are not shown in this listing.
101 replacer.replaceElementsIn(op,
105 return WalkResult::advance();
112 ArrayRef<ClassLike> classes) {
113 bool existExternalClass =
false;
114 size_t countDefinition = 0;
117 for (
auto op : classes) {
118 if (isa<ClassExternOp>(op))
119 existExternalClass =
true;
121 classOp = cast<ClassOp>(op);
128 if (existExternalClass && countDefinition != 1) {
129 SmallVector<Location> classExternLocs;
130 SmallVector<Location> classLocs;
131 for (
auto op : classes)
132 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
133 .push_back(op.getLoc());
135 auto diag = emitError(classExternLocs.front())
136 <<
"class " << name <<
" is declared as an external class but "
137 << (countDefinition == 0 ?
"there is no definition"
138 :
"there are multiple definitions");
139 for (
auto loc : ArrayRef(classExternLocs).drop_front())
140 diag.attachNote(loc) <<
"class " << name <<
" is declared here as well";
142 if (countDefinition != 0) {
144 for (
auto loc : classLocs)
145 diag.attachNote(loc) <<
"class " << name <<
" is defined here";
150 if (!existExternalClass)
151 return countDefinition != 1;
153 assert(classOp && countDefinition == 1);
157 auto emitError = [&](Operation *op) {
158 auto diag = op->emitError()
159 <<
"failed to link class " << name
160 <<
" since declaration doesn't match the definition: ";
161 diag.attachNote(classOp.getLoc()) <<
"definition is here";
165 llvm::MapVector<StringAttr, Type> classFields;
166 for (
auto fieldOp : classOp.getOps<om::ClassFieldOp>())
167 classFields.insert({fieldOp.getNameAttr(), fieldOp.getType()});
169 for (
auto op : classes) {
173 if (classOp.getBodyBlock()->getNumArguments() !=
174 op.getBodyBlock()->getNumArguments())
175 return emitError(op) <<
"the number of arguments is not equal, "
176 << classOp.getBodyBlock()->getNumArguments()
177 <<
" vs " << op.getBodyBlock()->getNumArguments();
179 for (
auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
180 op.getBodyBlock()->getArgumentTypes())) {
182 return emitError(op) << index <<
"-th argument type is not equal, " << l
187 llvm::DenseSet<StringAttr> declaredFields;
188 for (
auto fieldOp : op.getBodyBlock()->getOps<om::ClassExternFieldOp>()) {
189 auto it = classFields.find(fieldOp.getNameAttr());
192 if (it == classFields.end())
194 <<
"declaration has a field " << fieldOp.getNameAttr()
195 <<
" but not found in its definition";
197 if (it->second != fieldOp.getType())
199 <<
"declaration has a field " << fieldOp.getNameAttr()
200 <<
" but types don't match, " << it->second <<
" vs "
201 << fieldOp.getType();
202 declaredFields.insert(fieldOp.getNameAttr());
205 for (
auto [fieldName, _] : classFields)
206 if (!declaredFields.count(fieldName))
207 return emitError(op) <<
"definition has a field " << fieldName
208 <<
" but not found in this declaration";
// Link all ModuleOps nested in the top-level module into a single module.
// Steps:
//   1. Give each nested module a namespace ("om.namespace").
//   2. Index the classes of each module (in parallel).
//   3. Resolve same-named classes across modules and compute renamings.
//   4. Apply the renamings (in parallel).
//   5. Splice every nested module's ops into the top-level body.
// Any failure calls signalPassFailure().
213 void LinkModulesPass::runOnOperation() {
214 auto toplevelModule = getOperation();
216 SmallVector<ModuleInfo> modules;
// Every nested module gets an "om.namespace" string attribute, which is later
// used to derive unique class names.
218 for (
auto module : toplevelModule.getOps<ModuleOp>()) {
219 auto name = module->getAttrOfType<StringAttr>(
"om.namespace");
// Fall back to a generated "module_<N>" namespace. The guarding condition and
// the declaration of `counter` are not shown in this listing.
222 name =
StringAttr::get(module.getContext(),
"module_" + Twine(counter++));
223 module->setAttr(
"om.namespace", name);
225 modules.emplace_back(module);
// Index each module's classes concurrently. Any failure aborts the pass.
228 if (failed(failableParallelForEach(&getContext(), modules, [](
auto &info) {
230 return info.initialize();
232 return signalPassFailure();
// (module, original class name) -> new class name.
240 SymMappingTy symMapping;
// Group classes from all modules by name. MapVector iterates in insertion
// order, which keeps the resolution below deterministic.
243 llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
244 for (
const auto &info : modules)
245 for (
auto &[name, op] : info.symbolToClasses) {
246 symbolToClasses[name].push_back(op);
// Reserve each existing class name in the namespace so that names generated
// later cannot collide with it.
248 (void)nameSpace.
newName(name.getValue());
// Resolve each class name across modules. A resolution failure aborts the pass.
255 for (
auto &[name, classes] : symbolToClasses) {
260 return signalPassFailure();
// Derive a unique name for each class from its name and its enclosing module's
// namespace. The assignment target is not shown here; presumably it is
// recorded in symMapping -- confirm against the full source.
266 for (
auto op : classes) {
267 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
269 enclosingModule->getAttrOfType<StringAttr>(
"om.namespace");
272 nameSpace.
newName(name.getValue(), nameSpaceId.getValue()));
// Apply the renaming inside every module concurrently. symMapping is only
// read here (postProcess takes it by const reference).
279 parallelForEach(&getContext(), modules,
280 [&](
auto &info) { info.postProcess(symMapping); });
// Move every nested module's ops into the top-level module's body.
// Iterate by reference, like the other loops over `modules`. Iterating by
// value copied each ModuleInfo, including its symbolToClasses DenseMap,
// on every iteration for no benefit.
283 auto *block = toplevelModule.getBody();
284 for (
auto &info : modules) {
285 block->getOperations().splice(block->end(),
286 info.module.getBody()->getOperations());
293 return std::make_unique<LinkModulesPass>();
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...