CIRCT 20.0.0git
Loading...
Searching...
No Matches
LinkModules.cpp
Go to the documentation of this file.
//===- LinkModules.cpp - OM Linker pass -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the OM Linker pass.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/IR/Threading.h"
21#include "mlir/Pass/Pass.h"
22
23#include <memory>
24
25namespace circt {
26namespace om {
27#define GEN_PASS_DEF_LINKMODULES
28#include "circt/Dialect/OM/OMPasses.h.inc"
29} // namespace om
30} // namespace circt
31
32using namespace circt;
33using namespace om;
34using namespace mlir;
35using namespace llvm;
36
37namespace {
38// A map from a pair of enclosing module op and old symbol to a new symbol.
39using SymMappingTy =
40 llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
41
42struct ModuleInfo {
43 ModuleInfo(mlir::ModuleOp module) : module(module) {
44 block = std::make_unique<Block>();
45 }
46
47 // Populate `symbolToClasses`.
48 LogicalResult initialize();
49
50 // Update symbols based on the mapping and erase external classes.
51 void postProcess(const SymMappingTy &symMapping);
52
53 // A map from symbols to classes.
54 llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
55
56 // A target module.
57 ModuleOp module;
58
59 // A block to store original operations.
60 std::unique_ptr<Block> block;
61};
62
63struct LinkModulesPass
64 : public circt::om::impl::LinkModulesBase<LinkModulesPass> {
65 void runOnOperation() override;
66};
67
68} // namespace
69
70LogicalResult ModuleInfo::initialize() {
71 for (auto &op : llvm::make_early_inc_range(module.getOps())) {
72 if (auto classLike = dyn_cast<ClassLike>(op))
73 symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
74 else {
75 // Keep the op.
76 op.moveBefore(block.get(), block->end());
77 }
78 }
79 return success();
80}
81
82void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
83 AttrTypeReplacer replacer;
84 replacer.addReplacement(
85 // Update class types when their symbols were renamed.
86 [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
87 auto it = symMapping.find({module, classType.getClassName().getAttr()});
88 // No change.
89 if (it == symMapping.end())
90 return {classType, WalkResult::skip()};
91 return {om::ClassType::get(classType.getContext(),
92 FlatSymbolRefAttr::get(it->second)),
93 WalkResult::skip()};
94 });
95
96 module.walk<WalkOrder::PreOrder>([&](Operation *op) {
97 // External modules must be erased.
98 if (isa<ClassExternOp>(op)) {
99 op->erase();
100 // ClassExternFieldOp will be deleted as well.
101 return WalkResult::skip();
102 }
103
104 if (auto classOp = dyn_cast<ClassOp>(op)) {
105 // Update its class name if changed.
106 auto it = symMapping.find({module, classOp.getNameAttr()});
107 if (it != symMapping.end())
108 classOp.setSymNameAttr(it->second);
109 classOp.replaceFieldTypes(replacer);
110 } else if (auto objectOp = dyn_cast<ObjectOp>(op)) {
111 // Update its class name if changed..
112 auto it = symMapping.find({module, objectOp.getClassNameAttr()});
113 if (it != symMapping.end())
114 objectOp.setClassNameAttr(it->second);
115 }
116
117 // Otherwise update om.class types.
118 replacer.replaceElementsIn(op,
119 /*replaceAttrs=*/false,
120 /*replaceLocs=*/false,
121 /*replaceTypes=*/true);
122 return WalkResult::advance();
123 });
124}
125
126// Return a failure if classes cannot be resolved. Return true if
127// it's necessary to rename symbols.
128static FailureOr<bool> resolveClasses(StringAttr name,
129 ArrayRef<ClassLike> classes) {
130 bool existExternalClass = false;
131 size_t countDefinition = 0;
132 ClassOp classOp;
133
134 for (auto op : classes) {
135 if (isa<ClassExternOp>(op))
136 existExternalClass = true;
137 else {
138 classOp = cast<ClassOp>(op);
139 ++countDefinition;
140 }
141 }
142
143 // There must be exactly one definition if the symbol was referred by an
144 // external class.
145 if (existExternalClass && countDefinition != 1) {
146 SmallVector<Location> classExternLocs;
147 SmallVector<Location> classLocs;
148 for (auto op : classes)
149 (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
150 .push_back(op.getLoc());
151
152 auto diag = emitError(classExternLocs.front())
153 << "class " << name << " is declared as an external class but "
154 << (countDefinition == 0 ? "there is no definition"
155 : "there are multiple definitions");
156 for (auto loc : ArrayRef(classExternLocs).drop_front())
157 diag.attachNote(loc) << "class " << name << " is declared here as well";
158
159 if (countDefinition != 0) {
160 // There are multiple definitions.
161 for (auto loc : classLocs)
162 diag.attachNote(loc) << "class " << name << " is defined here";
163 }
164 return failure();
165 }
166
167 if (!existExternalClass)
168 return countDefinition != 1;
169
170 assert(classOp && countDefinition == 1);
171
172 // Raise errors if linked external modules are not compatible with the
173 // definition.
174 auto emitError = [&](Operation *op) {
175 auto diag = op->emitError()
176 << "failed to link class " << name
177 << " since declaration doesn't match the definition: ";
178 diag.attachNote(classOp.getLoc()) << "definition is here";
179 return diag;
180 };
181
182 for (auto op : classes) {
183 if (op == classOp)
184 continue;
185
186 if (classOp.getBodyBlock()->getNumArguments() !=
187 op.getBodyBlock()->getNumArguments())
188 return emitError(op) << "the number of arguments is not equal, "
189 << classOp.getBodyBlock()->getNumArguments()
190 << " vs " << op.getBodyBlock()->getNumArguments();
191 unsigned index = 0;
192 for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
193 op.getBodyBlock()->getArgumentTypes())) {
194 if (l != r)
195 return emitError(op) << index << "-th argument type is not equal, " << l
196 << " vs " << r;
197 index++;
198 }
199 // Check declared fields.
200 llvm::DenseSet<StringAttr> declaredFields;
201
202 for (auto nameAttr : op.getFieldNames()) {
203 StringAttr name = cast<StringAttr>(nameAttr);
204 std::optional<Type> opTypeOpt = op.getFieldType(name);
205
206 if (!opTypeOpt.has_value())
207 return emitError(op) << " no type for field " << name;
208 Type opType = opTypeOpt.value();
209
210 std::optional<Type> classTypeOpt = classOp.getFieldType(name);
211
212 // Field not found in its definition.
213 if (!classTypeOpt.has_value())
214 return emitError(op) << "declaration has a field " << name
215 << " but not found in its definition";
216 Type classType = classTypeOpt.value();
217
218 if (classType != opType)
219 return emitError(op)
220 << "declaration has a field " << name
221 << " but types don't match, " << classType << " vs " << opType;
222 declaredFields.insert(name);
223 }
224
225 for (auto fieldName : classOp.getFieldNames())
226 if (!declaredFields.count(cast<StringAttr>(fieldName)))
227 return emitError(op) << "definition has a field " << fieldName
228 << " but not found in this declaration";
229 }
230 return false;
231}
232
233void LinkModulesPass::runOnOperation() {
234 auto toplevelModule = getOperation();
235 // 1. Initialize ModuleInfo.
236 SmallVector<ModuleInfo> modules;
237 size_t counter = 0;
238 for (auto module : toplevelModule.getOps<ModuleOp>()) {
239 auto name = module->getAttrOfType<StringAttr>("om.namespace");
240 // Use `counter` if the namespace is not specified beforehand.
241 if (!name) {
242 name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
243 module->setAttr("om.namespace", name);
244 }
245 modules.emplace_back(module);
246 }
247
248 if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
249 // Collect local information.
250 return info.initialize();
251 })))
252 return signalPassFailure();
253
254 // 2. Symbol resolution. Check that there is exactly single definition for
255 // public symbols and rename private symbols if necessary.
256
257 // Global namespace to get unique names to symbols.
258 Namespace nameSpace;
259 // A map from a pair of enclosing module op and old symbol to a new symbol.
260 SymMappingTy symMapping;
261
262 // Construct a global map from symbols to class operations.
263 llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
264 for (const auto &info : modules)
265 for (auto &[name, op] : info.symbolToClasses) {
266 symbolToClasses[name].push_back(op);
267 // Add names to avoid collision.
268 (void)nameSpace.newName(name.getValue());
269 }
270
271 // Resolve symbols. We consider a symbol used as an external module to be
272 // "public" thus we cannot rename such symbols when there is collision. We
273 // require a public symbol to have exactly one definition so otherwise raise
274 // an error.
275 for (auto &[name, classes] : symbolToClasses) {
276 // Check if it's legal to link classes. `resolveClasses` returns true if
277 // it's necessary to rename symbols.
278 auto result = resolveClasses(name, classes);
279 if (failed(result))
280 return signalPassFailure();
281
282 // We can resolve symbol collision for symbols not referred by external
283 // classes. Create a new name using `om.namespace` attributes as a
284 // suffix.
285 if (*result)
286 for (auto op : classes) {
287 auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
288 auto nameSpaceId =
289 enclosingModule->getAttrOfType<StringAttr>("om.namespace");
290 symMapping[{enclosingModule, name}] = StringAttr::get(
291 &getContext(),
292 nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
293 }
294 }
295
296 // 3. Post-processing. Update class names and erase external classes.
297
298 // Rename private symbols and remove external classes.
299 parallelForEach(&getContext(), modules,
300 [&](auto &info) { info.postProcess(symMapping); });
301
302 // Finally move operations to the toplevel module.
303 auto *block = toplevelModule.getBody();
304 for (auto &info : modules) {
305 block->getOperations().splice(block->end(),
306 info.module.getBody()->getOperations());
307 // Restore non-OM operations.
308 assert(info.module.getBody()->empty());
309 info.module.getBody()->getOperations().splice(
310 info.module.getBody()->begin(), info.block->getOperations());
311 }
312}
313
314std::unique_ptr<mlir::Pass> circt::om::createOMLinkModulesPass() {
315 return std::make_unique<LinkModulesPass>();
316}
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:85
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1