CIRCT 23.0.0git
Loading...
Searching...
No Matches
LinkModules.cpp
Go to the documentation of this file.
1//===- OMLinkModules.cpp - OM Linker pass -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the OM Linker pass.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Threading.h"
21#include "mlir/Pass/Pass.h"
22
23#include <memory>
24
25namespace circt {
26namespace om {
27#define GEN_PASS_DEF_LINKMODULES
28#include "circt/Dialect/OM/OMPasses.h.inc"
29} // namespace om
30} // namespace circt
31
32using namespace circt;
33using namespace om;
34using namespace mlir;
35using namespace llvm;
36
37namespace {
38// A map from a pair of enclosing module op and old symbol to a new symbol.
39using SymMappingTy =
40 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
41
42struct ModuleInfo {
43 ModuleInfo(mlir::ModuleOp module) : module(module) {
44 block = std::make_unique<Block>();
45 }
46
47 // Populate `symbolToClasses`.
48 LogicalResult initialize();
49
50 // Update symbols based on the mapping and erase external classes.
51 void postProcess(const SymMappingTy &symMapping);
52
53 // A map from symbols to classes.
54 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
55
56 // A target module.
57 ModuleOp module;
58
59 // A block to store original operations.
60 std::unique_ptr<Block> block;
61};
62
63struct LinkModulesPass
64 : public circt::om::impl::LinkModulesBase<LinkModulesPass> {
65 using Base::Base;
66 void runOnOperation() override;
67};
68
69} // namespace
70
71LogicalResult ModuleInfo::initialize() {
72 for (auto &op : llvm::make_early_inc_range(module.getOps())) {
73 if (auto classLike = dyn_cast<ClassLike>(op))
74 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
75 else {
76 // Keep the op.
77 op.moveBefore(block.get(), block->end());
78 }
79 }
80 return success();
81}
82
83void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
84 AttrTypeReplacer replacer;
85 replacer.addReplacement(
86 // Update class types when their symbols were renamed.
87 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
88 auto it = symMapping.find({module, classType.getClassName().getAttr()});
89 // No change.
90 if (it == symMapping.end())
91 return {classType, WalkResult::skip()};
92 return {om::ClassType::get(classType.getContext(),
93 FlatSymbolRefAttr::get(it->second)),
94 WalkResult::skip()};
95 });
96
97 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
98 // External modules must be erased.
99 if (isa<ClassExternOp>(op)) {
100 op->erase();
101 // ClassExternFieldOp will be deleted as well.
102 return WalkResult::skip();
103 }
104
105 if (auto classOp = dyn_cast<ClassOp>(op)) {
106 // Update its class name if changed.
107 auto it = symMapping.find({module, classOp.getNameAttr()});
108 if (it != symMapping.end())
109 classOp.setSymNameAttr(it->second);
110 classOp.replaceFieldTypes(replacer);
111 } else if (auto objectOp = dyn_cast<ObjectOp>(op)) {
112 // Update its class name if changed..
113 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
114 if (it != symMapping.end())
115 objectOp.setClassNameAttr(it->second);
116 }
117
118 // Otherwise update om.class types.
119 replacer.replaceElementsIn(op,
120 /*replaceAttrs=*/false,
121 /*replaceLocs=*/false,
122 /*replaceTypes=*/true);
123 return WalkResult::advance();
124 });
125}
126
127// Return a failure if classes cannot be resolved. Return true if
128// it's necessary to rename symbols.
129static FailureOr<bool> resolveClasses(StringAttr name,
130 ArrayRef<ClassLike> classes) {
131 bool existExternalClass = false;
132 size_t countDefinition = 0;
133 ClassOp classOp;
134
135 for (auto op : classes) {
136 if (isa<ClassExternOp>(op))
137 existExternalClass = true;
138 else {
139 classOp = cast<ClassOp>(op);
140 ++countDefinition;
141 }
142 }
143
144 // There must be exactly one definition if the symbol was referred by an
145 // external class.
146 if (existExternalClass && countDefinition != 1) {
147 SmallVector<Location> classExternLocs;
148 SmallVector<Location> classLocs;
149 for (auto op : classes)
150 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
151 .push_back(op.getLoc());
152
153 auto diag = emitError(classExternLocs.front())
154 << "class " << name << " is declared as an external class but "
155 << (countDefinition == 0 ? "there is no definition"
156 : "there are multiple definitions");
157 for (auto loc : ArrayRef(classExternLocs).drop_front())
158 diag.attachNote(loc) << "class " << name << " is declared here as well";
159
160 if (countDefinition != 0) {
161 // There are multiple definitions.
162 for (auto loc : classLocs)
163 diag.attachNote(loc) << "class " << name << " is defined here";
164 }
165 return failure();
166 }
167
168 if (!existExternalClass)
169 return countDefinition != 1;
170
171 assert(classOp && countDefinition == 1);
172
173 // Raise errors if linked external modules are not compatible with the
174 // definition.
175 auto emitError = [&](Operation *op) {
176 auto diag = op->emitError()
177 << "failed to link class " << name
178 << " since declaration doesn't match the definition: ";
179 diag.attachNote(classOp.getLoc()) << "definition is here";
180 return diag;
181 };
182
183 for (auto op : classes) {
184 if (op == classOp)
185 continue;
186
187 if (classOp.getBodyBlock()->getNumArguments() !=
188 op.getBodyBlock()->getNumArguments())
189 return emitError(op) << "the number of arguments is not equal, "
190 << classOp.getBodyBlock()->getNumArguments()
191 << " vs " << op.getBodyBlock()->getNumArguments();
192 unsigned index = 0;
193 for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
194 op.getBodyBlock()->getArgumentTypes())) {
195 if (l != r)
196 return emitError(op) << index << "-th argument type is not equal, " << l
197 << " vs " << r;
198 index++;
199 }
200 // Check declared fields.
201 llvm::DenseSet<StringAttr> declaredFields;
202
203 for (auto nameAttr : op.getFieldNames()) {
204 StringAttr name = cast<StringAttr>(nameAttr);
205 std::optional<Type> opTypeOpt = op.getFieldType(name);
206
207 if (!opTypeOpt.has_value())
208 return emitError(op) << " no type for field " << name;
209 Type opType = opTypeOpt.value();
210
211 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
212
213 // Field not found in its definition.
214 if (!classTypeOpt.has_value())
215 return emitError(op) << "declaration has a field " << name
216 << " but not found in its definition";
217 Type classType = classTypeOpt.value();
218
219 if (classType != opType)
220 return emitError(op)
221 << "declaration has a field " << name
222 << " but types don't match, " << classType << " vs " << opType;
223 declaredFields.insert(name);
224 }
225
226 for (auto fieldName : classOp.getFieldNames())
227 if (!declaredFields.count(cast<StringAttr>(fieldName)))
228 return emitError(op) << "definition has a field " << fieldName
229 << " but not found in this declaration";
230 }
231 return false;
232}
233
234void LinkModulesPass::runOnOperation() {
235 auto toplevelModule = getOperation();
236 // 1. Initialize ModuleInfo.
237 SmallVector<ModuleInfo> modules;
238 size_t counter = 0;
239 for (auto module : toplevelModule.getOps<ModuleOp>()) {
240 auto name = module->getAttrOfType<StringAttr>("om.namespace");
241 // Use `counter` if the namespace is not specified beforehand.
242 if (!name) {
243 name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
244 module->setAttr("om.namespace", name);
245 }
246 modules.emplace_back(module);
247 }
248
249 if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
250 // Collect local information.
251 return info.initialize();
252 })))
253 return signalPassFailure();
254
255 // 2. Symbol resolution. Check that there is exactly single definition for
256 // public symbols and rename private symbols if necessary.
257
258 // Global namespace to get unique names to symbols.
259 Namespace nameSpace;
260 // A map from a pair of enclosing module op and old symbol to a new symbol.
261 SymMappingTy symMapping;
262
263 // Construct a global map from symbols to class operations.
265 for (const auto &info : modules)
266 for (auto &[name, op] : info.symbolToClasses) {
267 symbolToClasses[name].push_back(op);
268 // Add names to avoid collision.
269 (void)nameSpace.newName(name.getValue());
270 }
271
272 // Resolve symbols. We consider a symbol used as an external module to be
273 // "public" thus we cannot rename such symbols when there is collision. We
274 // require a public symbol to have exactly one definition so otherwise raise
275 // an error.
276 for (auto &[name, classes] : symbolToClasses) {
277 // Check if it's legal to link classes. `resolveClasses` returns true if
278 // it's necessary to rename symbols.
279 auto result = resolveClasses(name, classes);
280 if (failed(result))
281 return signalPassFailure();
282
283 // We can resolve symbol collision for symbols not referred by external
284 // classes. Create a new name using `om.namespace` attributes as a
285 // suffix.
286 if (*result)
287 for (auto op : classes) {
288 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
289 auto nameSpaceId =
290 enclosingModule->getAttrOfType<StringAttr>("om.namespace");
291 symMapping[{enclosingModule, name}] = StringAttr::get(
292 &getContext(),
293 nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
294 }
295 }
296
297 // 3. Post-processing. Update class names and erase external classes.
298
299 // Rename private symbols and remove external classes.
300 parallelForEach(&getContext(), modules,
301 [&](auto &info) { info.postProcess(symMapping); });
302
303 // Finally move operations to the toplevel module.
304 auto *block = toplevelModule.getBody();
305 for (auto &info : modules) {
306 block->getOperations().splice(block->end(),
307 info.module.getBody()->getOperations());
308 // Restore non-OM operations.
309 assert(info.module.getBody()->empty());
310 info.module.getBody()->getOperations().splice(
311 info.module.getBody()->begin(), info.block->getOperations());
312 }
313}
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1