CIRCT 23.0.0git
Loading...
Searching...
No Matches
LinkModules.cpp
Go to the documentation of this file.
1//===- OMLinkModules.cpp - OM Linker pass -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the OM Linker pass.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <memory>
25
26namespace circt {
27namespace om {
28#define GEN_PASS_DEF_LINKMODULES
29#include "circt/Dialect/OM/OMPasses.h.inc"
30} // namespace om
31} // namespace circt
32
33using namespace circt;
34using namespace om;
35using namespace mlir;
36using namespace llvm;
37
38namespace {
39// A map from a pair of enclosing module op and old symbol to a new symbol.
40using SymMappingTy =
41 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
42
43struct ModuleInfo {
44 ModuleInfo(mlir::ModuleOp module) : module(module) {
45 block = std::make_unique<Block>();
46 }
47
48 // Populate `symbolToClasses`.
49 LogicalResult initialize();
50
51 // Update symbols based on the mapping and erase external classes.
52 void postProcess(const SymMappingTy &symMapping);
53
54 // A map from symbols to classes.
55 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
56
57 // A target module.
58 ModuleOp module;
59
60 // A block to store original operations.
61 std::unique_ptr<Block> block;
62};
63
64struct LinkModulesPass
65 : public circt::om::impl::LinkModulesBase<LinkModulesPass> {
66 using Base::Base;
67 void runOnOperation() override;
68};
69
70} // namespace
71
72LogicalResult ModuleInfo::initialize() {
73 for (auto &op : llvm::make_early_inc_range(module.getOps())) {
74 if (auto classLike = dyn_cast<ClassLike>(op))
75 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
76 else {
77 // Keep the op.
78 op.moveBefore(block.get(), block->end());
79 }
80 }
81 return success();
82}
83
84void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
85 AttrTypeReplacer replacer;
86 replacer.addReplacement(
87 // Update class types when their symbols were renamed.
88 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
89 auto it = symMapping.find({module, classType.getClassName().getAttr()});
90 // No change.
91 if (it == symMapping.end())
92 return {classType, WalkResult::skip()};
93 return {om::ClassType::get(classType.getContext(),
94 FlatSymbolRefAttr::get(it->second)),
95 WalkResult::skip()};
96 });
97
98 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
99 // External modules must be erased.
100 if (isa<ClassExternOp>(op)) {
101 op->erase();
102 // ClassExternFieldOp will be deleted as well.
103 return WalkResult::skip();
104 }
105
106 llvm::TypeSwitch<Operation *>(op)
107 .Case<ClassOp>([&](ClassOp classOp) {
108 auto it = symMapping.find({module, classOp.getNameAttr()});
109 // Update its class name if changed.
110 if (it != symMapping.end())
111 classOp.setSymNameAttr(it->second);
112 classOp.replaceFieldTypes(replacer);
113 })
114 .Case<ObjectOp, ElaboratedObjectOp>([&](auto objectLike) {
115 auto it = symMapping.find(
116 {module, objectLike.getClassNameAttr().getAttr()});
117 // Update its class name if changed.
118 if (it != symMapping.end())
119 objectLike.setClassNameAttr(FlatSymbolRefAttr::get(it->second));
120 replacer.replaceElementsIn(objectLike, false, false, true);
121 });
122
123 // Otherwise update om.class types.
124 replacer.replaceElementsIn(op,
125 /*replaceAttrs=*/false,
126 /*replaceLocs=*/false,
127 /*replaceTypes=*/true);
128 return WalkResult::advance();
129 });
130}
131
132// Return a failure if classes cannot be resolved. Return true if
133// it's necessary to rename symbols.
134static FailureOr<bool> resolveClasses(StringAttr name,
135 ArrayRef<ClassLike> classes) {
136 bool existExternalClass = false;
137 size_t countDefinition = 0;
138 ClassOp classOp;
139
140 for (auto op : classes) {
141 if (isa<ClassExternOp>(op))
142 existExternalClass = true;
143 else {
144 classOp = cast<ClassOp>(op);
145 ++countDefinition;
146 }
147 }
148
149 // There must be exactly one definition if the symbol was referred by an
150 // external class.
151 if (existExternalClass && countDefinition != 1) {
152 SmallVector<Location> classExternLocs;
153 SmallVector<Location> classLocs;
154 for (auto op : classes)
155 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
156 .push_back(op.getLoc());
157
158 auto diag = emitError(classExternLocs.front())
159 << "class " << name << " is declared as an external class but "
160 << (countDefinition == 0 ? "there is no definition"
161 : "there are multiple definitions");
162 for (auto loc : ArrayRef(classExternLocs).drop_front())
163 diag.attachNote(loc) << "class " << name << " is declared here as well";
164
165 if (countDefinition != 0) {
166 // There are multiple definitions.
167 for (auto loc : classLocs)
168 diag.attachNote(loc) << "class " << name << " is defined here";
169 }
170 return failure();
171 }
172
173 if (!existExternalClass)
174 return countDefinition != 1;
175
176 assert(classOp && countDefinition == 1);
177
178 // Raise errors if linked external modules are not compatible with the
179 // definition.
180 auto emitError = [&](Operation *op) {
181 auto diag = op->emitError()
182 << "failed to link class " << name
183 << " since declaration doesn't match the definition: ";
184 diag.attachNote(classOp.getLoc()) << "definition is here";
185 return diag;
186 };
187
188 for (auto op : classes) {
189 if (op == classOp)
190 continue;
191
192 if (classOp.getBodyBlock()->getNumArguments() !=
193 op.getBodyBlock()->getNumArguments())
194 return emitError(op) << "the number of arguments is not equal, "
195 << classOp.getBodyBlock()->getNumArguments()
196 << " vs " << op.getBodyBlock()->getNumArguments();
197 unsigned index = 0;
198 for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
199 op.getBodyBlock()->getArgumentTypes())) {
200 if (l != r)
201 return emitError(op) << index << "-th argument type is not equal, " << l
202 << " vs " << r;
203 index++;
204 }
205 // Check declared fields.
206 llvm::DenseSet<StringAttr> declaredFields;
207
208 for (auto nameAttr : op.getFieldNames()) {
209 StringAttr name = cast<StringAttr>(nameAttr);
210 std::optional<Type> opTypeOpt = op.getFieldType(name);
211
212 if (!opTypeOpt.has_value())
213 return emitError(op) << " no type for field " << name;
214 Type opType = opTypeOpt.value();
215
216 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
217
218 // Field not found in its definition.
219 if (!classTypeOpt.has_value())
220 return emitError(op) << "declaration has a field " << name
221 << " but not found in its definition";
222 Type classType = classTypeOpt.value();
223
224 if (classType != opType)
225 return emitError(op)
226 << "declaration has a field " << name
227 << " but types don't match, " << classType << " vs " << opType;
228 declaredFields.insert(name);
229 }
230
231 for (auto fieldName : classOp.getFieldNames())
232 if (!declaredFields.count(cast<StringAttr>(fieldName)))
233 return emitError(op) << "definition has a field " << fieldName
234 << " but not found in this declaration";
235 }
236 return false;
237}
238
239void LinkModulesPass::runOnOperation() {
240 auto toplevelModule = getOperation();
241 // 1. Initialize ModuleInfo.
242 SmallVector<ModuleInfo> modules;
243 size_t counter = 0;
244 for (auto module : toplevelModule.getOps<ModuleOp>()) {
245 auto name = module->getAttrOfType<StringAttr>("om.namespace");
246 // Use `counter` if the namespace is not specified beforehand.
247 if (!name) {
248 name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
249 module->setAttr("om.namespace", name);
250 }
251 modules.emplace_back(module);
252 }
253
254 if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
255 // Collect local information.
256 return info.initialize();
257 })))
258 return signalPassFailure();
259
260 // 2. Symbol resolution. Check that there is exactly single definition for
261 // public symbols and rename private symbols if necessary.
262
263 // Global namespace to get unique names to symbols.
264 Namespace nameSpace;
265 // A map from a pair of enclosing module op and old symbol to a new symbol.
266 SymMappingTy symMapping;
267
268 // Construct a global map from symbols to class operations.
270 for (const auto &info : modules)
271 for (auto &[name, op] : info.symbolToClasses) {
272 symbolToClasses[name].push_back(op);
273 // Add names to avoid collision.
274 (void)nameSpace.newName(name.getValue());
275 }
276
277 // Resolve symbols. We consider a symbol used as an external module to be
278 // "public" thus we cannot rename such symbols when there is collision. We
279 // require a public symbol to have exactly one definition so otherwise raise
280 // an error.
281 for (auto &[name, classes] : symbolToClasses) {
282 // Check if it's legal to link classes. `resolveClasses` returns true if
283 // it's necessary to rename symbols.
284 auto result = resolveClasses(name, classes);
285 if (failed(result))
286 return signalPassFailure();
287
288 // We can resolve symbol collision for symbols not referred by external
289 // classes. Create a new name using `om.namespace` attributes as a
290 // suffix.
291 if (*result)
292 for (auto op : classes) {
293 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
294 auto nameSpaceId =
295 enclosingModule->getAttrOfType<StringAttr>("om.namespace");
296 symMapping[{enclosingModule, name}] = StringAttr::get(
297 &getContext(),
298 nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
299 }
300 }
301
302 // 3. Post-processing. Update class names and erase external classes.
303
304 // Rename private symbols and remove external classes.
305 parallelForEach(&getContext(), modules,
306 [&](auto &info) { info.postProcess(symMapping); });
307
308 // Finally move operations to the toplevel module.
309 auto *block = toplevelModule.getBody();
310 for (auto &info : modules) {
311 block->getOperations().splice(block->end(),
312 info.module.getBody()->getOperations());
313 // Restore non-OM operations.
314 assert(info.module.getBody()->empty());
315 info.module.getBody()->getOperations().splice(
316 info.module.getBody()->begin(), info.block->getOperations());
317 }
318}
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1