CIRCT  20.0.0git
LinkModules.cpp
Go to the documentation of this file.
1 //===- OMLinkModules.cpp - OM Linker pass -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the OM Linker pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/OM/OMOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
21 
22 #include <memory>
23 
24 namespace circt {
25 namespace om {
26 #define GEN_PASS_DEF_LINKMODULES
27 #include "circt/Dialect/OM/OMPasses.h.inc"
28 } // namespace om
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace om;
33 using namespace mlir;
34 using namespace llvm;
35 
36 namespace {
37 // A map from a pair of enclosing module op and old symbol to a new symbol.
38 using SymMappingTy =
39  llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
40 
41 struct ModuleInfo {
42  ModuleInfo(mlir::ModuleOp module) : module(module) {}
43 
44  // Populate `symbolToClasses`.
45  LogicalResult initialize();
46 
47  // Update symbols based on the mapping and erase external classes.
48  void postProcess(const SymMappingTy &symMapping);
49 
50  // A map from symbols to classes.
51  llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
52 
53  // A target module.
54  ModuleOp module;
55 };
56 
57 struct LinkModulesPass
58  : public circt::om::impl::LinkModulesBase<LinkModulesPass> {
59  void runOnOperation() override;
60 };
61 
62 } // namespace
63 
64 LogicalResult ModuleInfo::initialize() {
65  for (auto &op : llvm::make_early_inc_range(module.getOps())) {
66  if (auto classLike = dyn_cast<ClassLike>(op))
67  symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
68  else
69  op.erase();
70  }
71  return success();
72 }
73 
74 void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
75  AttrTypeReplacer replacer;
76  replacer.addReplacement(
77  // Update class types when their symbols were renamed.
78  [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
79  auto it = symMapping.find({module, classType.getClassName().getAttr()});
80  // No change.
81  if (it == symMapping.end())
82  return {classType, WalkResult::skip()};
83  return {om::ClassType::get(classType.getContext(),
84  FlatSymbolRefAttr::get(it->second)),
85  WalkResult::skip()};
86  });
87 
88  module.walk<WalkOrder::PreOrder>([&](Operation *op) {
89  // External modules must be erased.
90  if (isa<ClassExternOp>(op)) {
91  op->erase();
92  // ClassExternFieldOp will be deleted as well.
93  return WalkResult::skip();
94  }
95 
96  if (auto classOp = dyn_cast<ClassOp>(op)) {
97  // Update its class name if changed.
98  auto it = symMapping.find({module, classOp.getNameAttr()});
99  if (it != symMapping.end())
100  classOp.setSymNameAttr(it->second);
101  classOp.replaceFieldTypes(replacer);
102  } else if (auto objectOp = dyn_cast<ObjectOp>(op)) {
103  // Update its class name if changed..
104  auto it = symMapping.find({module, objectOp.getClassNameAttr()});
105  if (it != symMapping.end())
106  objectOp.setClassNameAttr(it->second);
107  }
108 
109  // Otherwise update om.class types.
110  replacer.replaceElementsIn(op,
111  /*replaceAttrs=*/false,
112  /*replaceLocs=*/false,
113  /*replaceTypes=*/true);
114  return WalkResult::advance();
115  });
116 }
117 
118 // Return a failure if classes cannot be resolved. Return true if
119 // it's necessary to rename symbols.
120 static FailureOr<bool> resolveClasses(StringAttr name,
121  ArrayRef<ClassLike> classes) {
122  bool existExternalClass = false;
123  size_t countDefinition = 0;
124  ClassOp classOp;
125 
126  for (auto op : classes) {
127  if (isa<ClassExternOp>(op))
128  existExternalClass = true;
129  else {
130  classOp = cast<ClassOp>(op);
131  ++countDefinition;
132  }
133  }
134 
135  // There must be exactly one definition if the symbol was referred by an
136  // external class.
137  if (existExternalClass && countDefinition != 1) {
138  SmallVector<Location> classExternLocs;
139  SmallVector<Location> classLocs;
140  for (auto op : classes)
141  (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
142  .push_back(op.getLoc());
143 
144  auto diag = emitError(classExternLocs.front())
145  << "class " << name << " is declared as an external class but "
146  << (countDefinition == 0 ? "there is no definition"
147  : "there are multiple definitions");
148  for (auto loc : ArrayRef(classExternLocs).drop_front())
149  diag.attachNote(loc) << "class " << name << " is declared here as well";
150 
151  if (countDefinition != 0) {
152  // There are multiple definitions.
153  for (auto loc : classLocs)
154  diag.attachNote(loc) << "class " << name << " is defined here";
155  }
156  return failure();
157  }
158 
159  if (!existExternalClass)
160  return countDefinition != 1;
161 
162  assert(classOp && countDefinition == 1);
163 
164  // Raise errors if linked external modules are not compatible with the
165  // definition.
166  auto emitError = [&](Operation *op) {
167  auto diag = op->emitError()
168  << "failed to link class " << name
169  << " since declaration doesn't match the definition: ";
170  diag.attachNote(classOp.getLoc()) << "definition is here";
171  return diag;
172  };
173 
174  for (auto op : classes) {
175  if (op == classOp)
176  continue;
177 
178  if (classOp.getBodyBlock()->getNumArguments() !=
179  op.getBodyBlock()->getNumArguments())
180  return emitError(op) << "the number of arguments is not equal, "
181  << classOp.getBodyBlock()->getNumArguments()
182  << " vs " << op.getBodyBlock()->getNumArguments();
183  unsigned index = 0;
184  for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
185  op.getBodyBlock()->getArgumentTypes())) {
186  if (l != r)
187  return emitError(op) << index << "-th argument type is not equal, " << l
188  << " vs " << r;
189  index++;
190  }
191  // Check declared fields.
192  llvm::DenseSet<StringAttr> declaredFields;
193 
194  for (auto nameAttr : op.getFieldNames()) {
195  StringAttr name = cast<StringAttr>(nameAttr);
196  std::optional<Type> opTypeOpt = op.getFieldType(name);
197 
198  if (!opTypeOpt.has_value())
199  return emitError(op) << " no type for field " << name;
200  Type opType = opTypeOpt.value();
201 
202  std::optional<Type> classTypeOpt = classOp.getFieldType(name);
203 
204  // Field not found in its definition.
205  if (!classTypeOpt.has_value())
206  return emitError(op) << "declaration has a field " << name
207  << " but not found in its definition";
208  Type classType = classTypeOpt.value();
209 
210  if (classType != opType)
211  return emitError(op)
212  << "declaration has a field " << name
213  << " but types don't match, " << classType << " vs " << opType;
214  declaredFields.insert(name);
215  }
216 
217  for (auto fieldName : classOp.getFieldNames())
218  if (!declaredFields.count(cast<StringAttr>(fieldName)))
219  return emitError(op) << "definition has a field " << fieldName
220  << " but not found in this declaration";
221  }
222  return false;
223 }
224 
225 void LinkModulesPass::runOnOperation() {
226  auto toplevelModule = getOperation();
227  // 1. Initialize ModuleInfo.
228  SmallVector<ModuleInfo> modules;
229  size_t counter = 0;
230  for (auto module : toplevelModule.getOps<ModuleOp>()) {
231  auto name = module->getAttrOfType<StringAttr>("om.namespace");
232  // Use `counter` if the namespace is not specified beforehand.
233  if (!name) {
234  name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
235  module->setAttr("om.namespace", name);
236  }
237  modules.emplace_back(module);
238  }
239 
240  if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
241  // Collect local information.
242  return info.initialize();
243  })))
244  return signalPassFailure();
245 
246  // 2. Symbol resolution. Check that there is exactly single definition for
247  // public symbols and rename private symbols if necessary.
248 
249  // Global namespace to get unique names to symbols.
250  Namespace nameSpace;
251  // A map from a pair of enclosing module op and old symbol to a new symbol.
252  SymMappingTy symMapping;
253 
254  // Construct a global map from symbols to class operations.
255  llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
256  for (const auto &info : modules)
257  for (auto &[name, op] : info.symbolToClasses) {
258  symbolToClasses[name].push_back(op);
259  // Add names to avoid collision.
260  (void)nameSpace.newName(name.getValue());
261  }
262 
263  // Resolve symbols. We consider a symbol used as an external module to be
264  // "public" thus we cannot rename such symbols when there is collision. We
265  // require a public symbol to have exactly one definition so otherwise raise
266  // an error.
267  for (auto &[name, classes] : symbolToClasses) {
268  // Check if it's legal to link classes. `resolveClasses` returns true if
269  // it's necessary to rename symbols.
270  auto result = resolveClasses(name, classes);
271  if (failed(result))
272  return signalPassFailure();
273 
274  // We can resolve symbol collision for symbols not referred by external
275  // classes. Create a new name using `om.namespace` attributes as a
276  // suffix.
277  if (*result)
278  for (auto op : classes) {
279  auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
280  auto nameSpaceId =
281  enclosingModule->getAttrOfType<StringAttr>("om.namespace");
282  symMapping[{enclosingModule, name}] = StringAttr::get(
283  &getContext(),
284  nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
285  }
286  }
287 
288  // 3. Post-processing. Update class names and erase external classes.
289 
290  // Rename private symbols and remove external classes.
291  parallelForEach(&getContext(), modules,
292  [&](auto &info) { info.postProcess(symMapping); });
293 
294  // Finally move operations to the toplevel module.
295  auto *block = toplevelModule.getBody();
296  for (auto info : modules) {
297  block->getOperations().splice(block->end(),
298  info.module.getBody()->getOperations());
299  // Erase the module.
300  info.module.erase();
301  }
302 }
303 
304 std::unique_ptr<mlir::Pass> circt::om::createOMLinkModulesPass() {
305  return std::make_unique<LinkModulesPass>();
306 }
assert(baseType &&"element must be base type")
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:85
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: om.py:1