CIRCT  19.0.0git
LinkModules.cpp
Go to the documentation of this file.
1 //===- OMLinkModules.cpp - OM Linker pass -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the OM Linker pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/OM/OMOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
21 
22 #include <memory>
23 
24 namespace circt {
25 namespace om {
26 #define GEN_PASS_DEF_LINKMODULES
27 #include "circt/Dialect/OM/OMPasses.h.inc"
28 } // namespace om
29 } // namespace circt
30 
31 using namespace circt;
32 using namespace om;
33 using namespace mlir;
34 using namespace llvm;
35 
36 namespace {
37 // A map from a pair of enclosing module op and old symbol to a new symbol.
38 using SymMappingTy =
39  llvm::DenseMap<std::pair<ModuleOp, StringAttr>, StringAttr>;
40 
41 struct ModuleInfo {
42  ModuleInfo(mlir::ModuleOp module) : module(module) {}
43 
44  // Populate `symbolToClasses`.
45  LogicalResult initialize();
46 
47  // Update symbols based on the mapping and erase external classes.
48  void postProcess(const SymMappingTy &symMapping);
49 
50  // A map from symbols to classes.
51  llvm::DenseMap<StringAttr, ClassLike> symbolToClasses;
52 
53  // A target module.
54  ModuleOp module;
55 };
56 
57 struct LinkModulesPass
58  : public circt::om::impl::LinkModulesBase<LinkModulesPass> {
59  void runOnOperation() override;
60 };
61 
62 } // namespace
63 
64 LogicalResult ModuleInfo::initialize() {
65  for (auto &op : llvm::make_early_inc_range(module.getOps())) {
66  if (auto classLike = dyn_cast<ClassLike>(op))
67  symbolToClasses.insert({classLike.getSymNameAttr(), classLike});
68  else
69  op.erase();
70  }
71  return success();
72 }
73 
74 void ModuleInfo::postProcess(const SymMappingTy &symMapping) {
75  AttrTypeReplacer replacer;
76  replacer.addReplacement(
77  // Update class types when their symbols were renamed.
78  [&](om::ClassType classType) -> std::pair<mlir::Type, WalkResult> {
79  auto it = symMapping.find({module, classType.getClassName().getAttr()});
80  // No change.
81  if (it == symMapping.end())
82  return {classType, WalkResult::skip()};
83  return {om::ClassType::get(classType.getContext(),
84  FlatSymbolRefAttr::get(it->second)),
85  WalkResult::skip()};
86  });
87 
88  module.walk<WalkOrder::PreOrder>([&](Operation *op) {
89  // External modules must be erased.
90  if (isa<ClassExternOp>(op)) {
91  op->erase();
92  // ClassExternFieldOp will be deleted as well.
93  return WalkResult::skip();
94  }
95 
96  if (auto classOp = dyn_cast<ClassOp>(op)) {
97  // Update its class name if changed.
98  auto it = symMapping.find({module, classOp.getNameAttr()});
99  if (it != symMapping.end())
100  classOp.setSymNameAttr(it->second);
101  } else if (auto objectOp = dyn_cast<ObjectOp>(op)) {
102  // Update its class name if changed..
103  auto it = symMapping.find({module, objectOp.getClassNameAttr()});
104  if (it != symMapping.end())
105  objectOp.setClassNameAttr(it->second);
106  }
107 
108  // Otherwise update om.class types.
109  replacer.replaceElementsIn(op,
110  /*replaceAttrs=*/false,
111  /*replaceLocs=*/false,
112  /*replaceTypes=*/true);
113  return WalkResult::advance();
114  });
115 }
116 
117 // Return a failure if classes cannot be resolved. Return true if
118 // it's necessary to rename symbols.
119 static FailureOr<bool> resolveClasses(StringAttr name,
120  ArrayRef<ClassLike> classes) {
121  bool existExternalClass = false;
122  size_t countDefinition = 0;
123  ClassOp classOp;
124 
125  for (auto op : classes) {
126  if (isa<ClassExternOp>(op))
127  existExternalClass = true;
128  else {
129  classOp = cast<ClassOp>(op);
130  ++countDefinition;
131  }
132  }
133 
134  // There must be exactly one definition if the symbol was referred by an
135  // external class.
136  if (existExternalClass && countDefinition != 1) {
137  SmallVector<Location> classExternLocs;
138  SmallVector<Location> classLocs;
139  for (auto op : classes)
140  (isa<ClassExternOp>(op) ? classExternLocs : classLocs)
141  .push_back(op.getLoc());
142 
143  auto diag = emitError(classExternLocs.front())
144  << "class " << name << " is declared as an external class but "
145  << (countDefinition == 0 ? "there is no definition"
146  : "there are multiple definitions");
147  for (auto loc : ArrayRef(classExternLocs).drop_front())
148  diag.attachNote(loc) << "class " << name << " is declared here as well";
149 
150  if (countDefinition != 0) {
151  // There are multiple definitions.
152  for (auto loc : classLocs)
153  diag.attachNote(loc) << "class " << name << " is defined here";
154  }
155  return failure();
156  }
157 
158  if (!existExternalClass)
159  return countDefinition != 1;
160 
161  assert(classOp && countDefinition == 1);
162 
163  // Raise errors if linked external modules are not compatible with the
164  // definition.
165  auto emitError = [&](Operation *op) {
166  auto diag = op->emitError()
167  << "failed to link class " << name
168  << " since declaration doesn't match the definition: ";
169  diag.attachNote(classOp.getLoc()) << "definition is here";
170  return diag;
171  };
172 
173  llvm::MapVector<StringAttr, Type> classFields;
174  for (auto fieldOp : classOp.getOps<om::ClassFieldOp>())
175  classFields.insert({fieldOp.getNameAttr(), fieldOp.getType()});
176 
177  for (auto op : classes) {
178  if (op == classOp)
179  continue;
180 
181  if (classOp.getBodyBlock()->getNumArguments() !=
182  op.getBodyBlock()->getNumArguments())
183  return emitError(op) << "the number of arguments is not equal, "
184  << classOp.getBodyBlock()->getNumArguments()
185  << " vs " << op.getBodyBlock()->getNumArguments();
186  unsigned index = 0;
187  for (auto [l, r] : llvm::zip(classOp.getBodyBlock()->getArgumentTypes(),
188  op.getBodyBlock()->getArgumentTypes())) {
189  if (l != r)
190  return emitError(op) << index << "-th argument type is not equal, " << l
191  << " vs " << r;
192  index++;
193  }
194  // Check declared fields.
195  llvm::DenseSet<StringAttr> declaredFields;
196  for (auto fieldOp : op.getBodyBlock()->getOps<om::ClassExternFieldOp>()) {
197  auto it = classFields.find(fieldOp.getNameAttr());
198 
199  // Field not found in its definition.
200  if (it == classFields.end())
201  return emitError(op)
202  << "declaration has a field " << fieldOp.getNameAttr()
203  << " but not found in its definition";
204 
205  if (it->second != fieldOp.getType())
206  return emitError(op)
207  << "declaration has a field " << fieldOp.getNameAttr()
208  << " but types don't match, " << it->second << " vs "
209  << fieldOp.getType();
210  declaredFields.insert(fieldOp.getNameAttr());
211  }
212 
213  for (auto [fieldName, _] : classFields)
214  if (!declaredFields.count(fieldName))
215  return emitError(op) << "definition has a field " << fieldName
216  << " but not found in this declaration";
217  }
218  return false;
219 }
220 
221 void LinkModulesPass::runOnOperation() {
222  auto toplevelModule = getOperation();
223  // 1. Initialize ModuleInfo.
224  SmallVector<ModuleInfo> modules;
225  size_t counter = 0;
226  for (auto module : toplevelModule.getOps<ModuleOp>()) {
227  auto name = module->getAttrOfType<StringAttr>("om.namespace");
228  // Use `counter` if the namespace is not specified beforehand.
229  if (!name) {
230  name = StringAttr::get(module.getContext(), "module_" + Twine(counter++));
231  module->setAttr("om.namespace", name);
232  }
233  modules.emplace_back(module);
234  }
235 
236  if (failed(failableParallelForEach(&getContext(), modules, [](auto &info) {
237  // Collect local information.
238  return info.initialize();
239  })))
240  return signalPassFailure();
241 
242  // 2. Symbol resolution. Check that there is exactly single definition for
243  // public symbols and rename private symbols if necessary.
244 
245  // Global namespace to get unique names to symbols.
246  Namespace nameSpace;
247  // A map from a pair of enclosing module op and old symbol to a new symbol.
248  SymMappingTy symMapping;
249 
250  // Construct a global map from symbols to class operations.
251  llvm::MapVector<StringAttr, SmallVector<ClassLike>> symbolToClasses;
252  for (const auto &info : modules)
253  for (auto &[name, op] : info.symbolToClasses) {
254  symbolToClasses[name].push_back(op);
255  // Add names to avoid collision.
256  (void)nameSpace.newName(name.getValue());
257  }
258 
259  // Resolve symbols. We consider a symbol used as an external module to be
260  // "public" thus we cannot rename such symbols when there is collision. We
261  // require a public symbol to have exactly one definition so otherwise raise
262  // an error.
263  for (auto &[name, classes] : symbolToClasses) {
264  // Check if it's legal to link classes. `resolveClasses` returns true if
265  // it's necessary to rename symbols.
266  auto result = resolveClasses(name, classes);
267  if (failed(result))
268  return signalPassFailure();
269 
270  // We can resolve symbol collision for symbols not referred by external
271  // classes. Create a new name using `om.namespace` attributes as a
272  // suffix.
273  if (*result)
274  for (auto op : classes) {
275  auto enclosingModule = cast<mlir::ModuleOp>(op->getParentOp());
276  auto nameSpaceId =
277  enclosingModule->getAttrOfType<StringAttr>("om.namespace");
278  symMapping[{enclosingModule, name}] = StringAttr::get(
279  &getContext(),
280  nameSpace.newName(name.getValue(), nameSpaceId.getValue()));
281  }
282  }
283 
284  // 3. Post-processing. Update class names and erase external classes.
285 
286  // Rename private symbols and remove external classes.
287  parallelForEach(&getContext(), modules,
288  [&](auto &info) { info.postProcess(symMapping); });
289 
290  // Finally move operations to the toplevel module.
291  auto *block = toplevelModule.getBody();
292  for (auto info : modules) {
293  block->getOperations().splice(block->end(),
294  info.module.getBody()->getOperations());
295  // Erase the module.
296  info.module.erase();
297  }
298 }
299 
300 std::unique_ptr<mlir::Pass> circt::om::createOMLinkModulesPass() {
301  return std::make_unique<LinkModulesPass>();
302 }
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3929
static FailureOr< bool > resolveClasses(StringAttr name, ArrayRef< ClassLike > classes)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:72
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
std::unique_ptr< mlir::Pass > createOMLinkModulesPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: om.py:1