CIRCT  19.0.0git
LinkModules.cpp
Go to the documentation of this file.
1 //===- OMLinkModules.cpp - OM Linker pass -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the OM Linker pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
14 #include "circt/Dialect/OM/OMOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "mlir/IR/Threading.h"
21 
22 #include <memory>
23 
24 using namespace circt;
25 using namespace om;
26 using namespace mlir;
27 using namespace llvm;
28 
29 namespace {
30 // A map from a pair of enclosing module op and old symbol to a new symbol.
31 using SymMappingTy =
32  llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
33 
34 struct ModuleInfo {
35  ModuleInfo(mlir::ModuleOp module) : module(module) {}
36 
37  // Populate `symbolToClasses`.
38  LogicalResult initialize();
39 
40  // Update symbols based on the mapping and erase external classes.
41  void postProcess(const SymMappingTy &symMapping);
42 
43  // A map from symbols to classes.
44  llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
45 
46  // A target module.
47  ModuleOp module;
48 };
49 
50 struct LinkModulesPass : public LinkModulesBase<LinkModulesPass> {
51  void runOnOperation() override;
52 };
53 
54 } // namespace
55 
56 LogicalResult ModuleInfo::initialize() {
57  for (auto &op : llvm::make_early_inc_range(module.getOps())) {
58  if (auto classLike = dyn_cast<ClassLike>(op))
59  symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
60  else
61  op.erase();
62  }
63  return success();
64 }
65 
66 void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
67  AttrTypeReplacer replacer;
68  replacer.addReplacement(
69  // Update class types when their symbols were renamed.
70  [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
71  auto it = symMapping.find({module, classType.getClassName().getAttr()});
72  // No change.
73  if (it == symMapping.end())
74  return {classType, WalkResult::skip()};
75  return {om::ClassType::get(classType.getContext(),
76  FlatSymbolRefAttr::get(it->second)),
77  WalkResult::skip()};
78  });
79 
80  module.walk<WalkOrder::PreOrder>([&](Operation *op) {
81  // External modules must be erased.
82  if (isa<ClassExternOp>(op)) {
83  op->erase();
84  // ClassExternFieldOp will be deleted as well.
85  return WalkResult::skip();
86  }
87 
88  if (auto classOp = dyn_cast<ClassOp>(op)) {
89  // Update its class name if changed.
90  auto it = symMapping.find({module, classOp.getNameAttr()});
91  if (it != symMapping.end())
92  classOp.setSymNameAttr(it->second);
93  } else if (auto objectOp = dyn_cast<ObjectOp>(op)) {
94  // Update its class name if changed..
95  auto it = symMapping.find({module, objectOp.getClassNameAttr()});
96  if (it != symMapping.end())
97  objectOp.setClassNameAttr(it->second);
98  }
99 
100  // Otherwise update om.class types.
101  replacer.replaceElementsIn(op,
102  /*replaceAttrs=*/false,
103  /*replaceLocs=*/false,
104  /*replaceTypes=*/true);
105  return WalkResult::advance();
106  });
107 }
108 
109 // Return a failure if classes cannot be resolved. Return true if
110 // it's necessary to rename symbols.
111 static FailureOr<bool> resolveClasses(StringAttr name,
112  ArrayRef<ClassLike> classes) {
113  bool existExternalClass = false;
114  size_t countDefinition = 0;
115  ClassOp classOp;
116 
117  for (auto op : classes) {
118  if (isa<ClassExternOp>(op))
119  existExternalClass = true;
120  else {
121  classOp = cast<ClassOp>(op);
122  ++countDefinition;
123  }
124  }
125 
126  // There must be exactly one definition if the symbol was referred by an
127  // external class.
128  if (existExternalClass && countDefinition != 1) {
129  SmallVector<Location> classExternLocs;
130  SmallVector<Location> classLocs;
131  for (auto op : classes)
132  (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
133  .push_back(op.getLoc());
134 
135  auto diag = emitError(classExternLocs.front())
136  << "class " << name << " is declared as an external class but "
137  << (countDefinition == 0 ? "there is no definition"
138  : "there are multiple definitions");
139  for (auto loc : ArrayRef(classExternLocs).drop_front())
140  diag.attachNote(loc) << "class " << name << " is declared here as well";
141 
142  if (countDefinition != 0) {
143  // There are multiple definitions.
144  for (auto loc : classLocs)
145  diag.attachNote(loc) << "class " << name << " is defined here";
146  }
147  return failure();
148  }
149 
150  if (!existExternalClass)
151  return countDefinition != 1;
152 
153  assert(classOp && countDefinition == 1);
154 
155  // Raise errors if linked external modules are not compatible with the
156  // definition.
157  auto emitError = [&](Operation *op) {
158  auto diag = op->emitError()
159  << "failed to link class " << name
160  << " since declaration doesn't match the definition: ";
161  diag.attachNote(classOp.getLoc()) << "definition is here";
162  return diag;
163  };
164 
165  llvm::MapVector<StringAttr, Type> classFields;
166  for (auto fieldOp : classOp.getOps<om::ClassFieldOp>())
167  classFields.insert({fieldOp.getNameAttr(), fieldOp.getType()});
168 
169  for (auto op : classes) {
170  if (op == classOp)
171  continue;
172 
173  if (classOp.getBodyBlock()->getNumArguments() !=
174  op.getBodyBlock()->getNumArguments())
175  return emitError(op) << "the number of arguments is not equal, "
176  << classOp.getBodyBlock()->getNumArguments()
177  << " vs " << op.getBodyBlock()->getNumArguments();
178  unsigned index = 0;
179  for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
180  op.getBodyBlock()->getArgumentTypes())) {
181  if (l != r)
182  return emitError(op) << index << "-th argument type is not equal, " << l
183  << " vs " << r;
184  index++;
185  }
186  // Check declared fields.
187  llvm::DenseSet<StringAttr> declaredFields;
188  for (auto fieldOp : op.getBodyBlock()->getOps<om::ClassExternFieldOp>()) {
189  auto it = classFields.find(fieldOp.getNameAttr());
190 
191  // Field not found in its definition.
192  if (it == classFields.end())
193  return emitError(op)
194  << "declaration has a field " << fieldOp.getNameAttr()
195  << " but not found in its definition";
196 
197  if (it->second != fieldOp.getType())
198  return emitError(op)
199  << "declaration has a field " << fieldOp.getNameAttr()
200  << " but types don't match, " << it->second << " vs "
201  << fieldOp.getType();
202  declaredFields.insert(fieldOp.getNameAttr());
203  }
204 
205  for (auto [fieldName, _] : classFields)
206  if (!declaredFields.count(fieldName))
207  return emitError(op) << "definition has a field " << fieldName
208  << " but not found in this declaration";
209  }
210  return false;
211 }
212 
213 void LinkModulesPass::runOnOperation() {
214  auto toplevelModule = getOperation();
215  // 1. Initialize ModuleInfo.
216  SmallVector<ModuleInfo> modules;
217  size_t counter = 0;
218  for (auto module : toplevelModule.getOps<ModuleOp>()) {
219  auto name = module->getAttrOfType<StringAttr>("om.namespace");
220  // Use `counter` if the namespace is not specified beforehand.
221  if (!name) {
222  name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
223  module->setAttr("om.namespace", name);
224  }
225  modules.emplace_back(module);
226  }
227 
228  if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
229  // Collect local information.
230  return info.initialize();
231  })))
232  return signalPassFailure();
233 
234  // 2. Symbol resolution. Check that there is exactly single definition for
235  // public symbols and rename private symbols if necessary.
236 
237  // Global namespace to get unique names to symbols.
238  Namespace nameSpace;
239  // A map from a pair of enclosing module op and old symbol to a new symbol.
240  SymMappingTy symMapping;
241 
242  // Construct a global map from symbols to class operations.
243  llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
244  for (const auto &info : modules)
245  for (auto &[name, op] : info.symbolToClasses) {
246  symbolToClasses[name].push_back(op);
247  // Add names to avoid collision.
248  (void)nameSpace.newName(name.getValue());
249  }
250 
251  // Resolve symbols. We consider a symbol used as an external module to be
252  // "public" thus we cannot rename such symbols when there is collision. We
253  // require a public symbol to have exactly one definition so otherwise raise
254  // an error.
255  for (auto &[name, classes] : symbolToClasses) {
256  // Check if it's legal to link classes. `resolveClasses` returns true if
257  // it's necessary to rename symbols.
258  auto result = resolveClasses(name, classes);
259  if (failed(result))
260  return signalPassFailure();
261 
262  // We can resolve symbol collision for symbols not referred by external
263  // classes. Create a new name using `om.namespace` attributes as a
264  // suffix.
265  if (*result)
266  for (auto op : classes) {
267  auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
268  auto nameSpaceId =
269  enclosingModule->getAttrOfType<StringAttr>("om.namespace");
270  symMapping[{enclosingModule, name}] = StringAttr::get(
271  &getContext(),
272  nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
273  }
274  }
275 
276  // 3. Post-processing. Update class names and erase external classes.
277 
278  // Rename private symbols and remove external classes.
279  parallelForEach(&getContext(), modules,
280  [&](auto &info) { info.postProcess(symMapping); });
281 
282  // Finally move operations to the toplevel module.
283  auto *block = toplevelModule.getBody();
284  for (auto info : modules) {
285  block->getOperations().splice(block->end(),
286  info.module.getBody()->getOperations());
287  // Erase the module.
288  info.module.erase();
289  }
290 }
291 
292 std::unique_ptr<mlir::Pass> circt::om::createOMLinkModulesPass() {
293  return std::make_unique<LinkModulesPass>();
294 }
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3852
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:29
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:63
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: om.py:1