19 #include "mlir/Pass/Pass.h"
23 #define GEN_PASS_DEF_FREEZEPATHS
24 #include "circt/Dialect/OM/OMPasses.h.inc"
28 using namespace circt;
34 std::function<StringAttr(Operation *)> &getOpName)
35 : instanceGraph(instanceGraph), irn(irn), getOpNameFallback(getOpName) {}
// NOTE(review): the constructor binds `getOpName` by non-const reference but
// only copies it into `getOpNameFallback`; taking it by value and moving it
// into the member would also accept temporaries.
//
// Resolves `hierPathOp`'s namepath into the pieces stored on a frozen path:
// the instance path (`targetPath`), the bottom module, and the targeted
// component and field names. Diagnostics are reported at `loc`.
38 LogicalResult processPath(Location loc, hw::HierPathOp hierPathOp,
39 PathAttr &targetPath, StringAttr &bottomModule,
40 StringAttr &component, StringAttr &field);
// Per-op rewrites. Each replaces the given op with an equivalent whose
// path-typed parts are frozen: Frozen*PathCreateOp for the path creators, and
// a rebuilt op with an updated result type for list creators and object
// field accesses.
41 LogicalResult process(BasePathCreateOp pathOp);
42 LogicalResult process(PathCreateOp pathOp);
43 LogicalResult process(EmptyPathOp pathOp);
44 LogicalResult processListCreator(Operation *listCreateOp);
45 LogicalResult process(ObjectFieldOp objectFieldOp);
// Applies the rewrites above to every ClassLike op in `module`, stopping at
// the first rewrite that fails.
46 LogicalResult
run(ModuleOp module);
// Resolves hw::InnerRefAttr references to their target ops/ports. Held by
// reference, so the namespace must outlive this visitor.
48 hw::InnerRefNamespace &irn;
// Optional callback used to name an op that has no "hw.verilogName"
// attribute; may be empty (callers check it before use).
49 std::function<StringAttr(Operation *)> getOpNameFallback;
// Builds a textual access path naming the sub-field `fieldId` of a value of
// type `type` (used by processPath, e.g. for port targets that select a
// field). Per its declaration the path is returned through a
// `StringAttr &result` out-parameter.
//   - hw::TypeAliasType is looked through via its canonical type.
//   - hw::StructType contributes the selected element's name.
//   - hw::ArrayType contributes the selected element's index and descends
//     into the element type.
//   - Otherwise an error naming the remaining fieldId and type is emitted at
//     `loc`.
// NOTE(review): the loop that repeats these steps, any separator characters
// between segments, and the final assignment to `result` are not visible in
// this listing. Confirm against the full source.
53 static LogicalResult
getAccessPath(Location loc, Type type,
size_t fieldId,
55 SmallString<64> field;
// Look through type aliases to the underlying type.
57 if (
auto aliasType = dyn_cast<hw::TypeAliasType>(type))
58 type = aliasType.getCanonicalType();
// Struct: append the element name, then consume that element's field IDs.
59 if (
auto structType = dyn_cast<hw::StructType>(type)) {
60 auto index = structType.getIndexForFieldID(fieldId);
61 auto &element = structType.getElements()[index];
63 llvm::append_range(field, element.name.getValue());
65 fieldId -= structType.getFieldID(index);
66 }
else if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
// Array: append the element index and continue into the element type.
67 auto index = arrayType.getIndexForFieldID(fieldId);
69 Twine(index).toVector(field);
71 type = arrayType.getElementType();
72 fieldId -= arrayType.getFieldID(index);
74 return emitError(loc) <<
"can't create access path with fieldID "
75 << fieldId <<
" in type " << type;
// Part of hasPathType(Type): tracks whether a BasePathType or PathType is
// encountered.
// NOTE(review): the traversal that binds `innerType` (presumably a walk over
// `type` and its nested types) and the return are not visible in this
// listing.
84 bool isPathType =
false;
86 if (isa<BasePathType, PathType>(
innerType))
// Type replacer for path types: registers replacements for BasePathType and
// PathType.
// NOTE(review): the replacement bodies are not visible in this listing;
// presumably they map to the frozen path types. Confirm.
94 mlir::AttrTypeReplacer replacer;
95 replacer.addReplacement([](BasePathType
innerType) {
98 replacer.addReplacement([](PathType
innerType) {
// processType(Type): applies the registered replacements to `type`
// (AttrTypeReplacer also recurses into nested types/attributes).
107 return replacer.replace(type);
// Resolves the namepath of `hierPathOp` into the pieces stored on a frozen
// path op.
//   - Every namepath element except the last must be an hw::InnerRefAttr to
//     an operation (not a port). Each contributes a PathElement of
//     (owning module, Verilog name) to `modules`.
//   - The last element is either:
//       * an hw::InnerRefAttr to a port,
//       * an hw::InnerRefAttr to an instance,
//       * an hw::InnerRefAttr to another inner-symbol op, or
//       * a FlatSymbolRefAttr whose symbol becomes `bottomModule`.
// Outputs are written through `bottomModule`, `component` and `field`.
// Verilog names come from the "hw.verilogName" attribute, falling back to
// `getOpNameFallback` when it is set. Ops that cannot be named are reported
// at `loc` with a note pointing at the op.
// NOTE(review): this listing omits several lines, including:
//   - the `field` parameter declaration,
//   - the conditions guarding the diagnostics and early returns,
//   - the construction of `targetPath` from `modules`.
// The above reflects only the visible statements. Confirm against the full
// source.
110 LogicalResult PathVisitor::processPath(Location loc, hw::HierPathOp hierPathOp,
111 PathAttr &targetPath,
112 StringAttr &bottomModule,
113 StringAttr &component,
115 auto *context = hierPathOp->getContext();
117 auto namepath = hierPathOp.getNamepathAttr().getValue();
118 SmallVector<PathElement> modules;
// Walk every namepath element except the last; these must name operations
// (instances along the path), never ports.
121 for (
auto attr : namepath.drop_back()) {
122 auto innerRef = cast<hw::InnerRefAttr>(attr);
123 auto target = irn.lookup(innerRef);
124 assert(target.isOpOnly() &&
"can't target a port the middle of a namepath");
125 auto *op = target.getOp();
127 auto verilogName = op->getAttrOfType<StringAttr>(
"hw.verilogName");
128 if (!verilogName && getOpNameFallback)
129 verilogName = getOpNameFallback(op);
131 auto diag = emitError(loc,
"component does not have verilog name");
132 diag.attachNote(op->getLoc()) <<
"component here";
135 modules.emplace_back(innerRef.getModule(), verilogName);
// The final namepath element determines what the path ultimately targets.
139 auto &
end = namepath.back();
140 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(end)) {
141 auto target = irn.lookup(innerRef);
// Port target: the path ends at a port of the module, possibly at a field
// of that port (resolved through getAccessPath).
142 if (target.isPort()) {
144 auto module = cast<hw::HWModuleLike>(target.getOp());
145 auto index = target.getPort();
146 bottomModule = module.getModuleNameAttr();
148 auto loc = module.getPortLoc(index);
149 auto type = module.getPortTypes()[index];
150 if (failed(
getAccessPath(loc, type, target.getField(), field)))
153 auto *op = target.getOp();
154 assert(op &&
"innerRef should be targeting something");
156 auto currentModule = innerRef.getModule();
158 auto verilogName = op->getAttrOfType<StringAttr>(
"hw.verilogName");
159 if (!verilogName && getOpNameFallback)
160 verilogName = getOpNameFallback(op);
162 auto diag = emitError(loc,
"component does not have verilog name");
163 diag.attachNote(op->getLoc()) <<
"component here";
// Instance target: record the instance in the path and make the first
// referenced module the bottom module.
// NOTE(review): the condition that triggers "unsupported instance
// operation" is not visible here; presumably more than one referenced
// module.
169 if (
auto inst = dyn_cast<hw::HWInstanceLike>(op)) {
171 auto mods = inst.getReferencedModuleNamesAttr();
173 return op->emitError(
"unsupported instance operation");
175 modules.emplace_back(currentModule, verilogName);
176 bottomModule = cast<StringAttr>(mods[0]);
// Any other inner-symbol op is the targeted component inside the current
// module.
181 bottomModule = currentModule;
182 component = verilogName;
183 auto innerSym = cast<hw::InnerSymbolOpInterface>(op);
184 auto value = innerSym.getTargetResult();
186 target.getField(), field)))
192 auto symbolRef = cast<FlatSymbolRefAttr>(end);
193 bottomModule = symbolRef.getAttr();
// Freezes a PathCreateOp.
//   - Looks up its hw::HierPathOp target by symbol and resolves it with
//     processPath.
//   - Builds a FrozenPathCreateOp immediately before the original op. It
//     carries the original target kind, the op's first operand, the resolved
//     path, the bottom module, and `ref`/`field` (presumably processPath's
//     component/field outputs).
//   - Redirects all uses of the original result to the new op.
// NOTE(review): erasing the original op and the return are not visible in
// this listing.
203 LogicalResult PathVisitor::process(PathCreateOp path) {
205 irn.symTable.lookup<hw::HierPathOp>(path.getTargetAttr().getAttr());
207 StringAttr bottomModule;
210 if (failed(processPath(path.getLoc(), hierPathOp, targetPath, bottomModule,
215 OpBuilder builder(path);
216 auto frozenPath = builder.create<FrozenPathCreateOp>(
217 path.getLoc(), path.getTargetKindAttr(), path->getOperand(0), targetPath,
218 bottomModule, ref, field);
219 path.replaceAllUsesWith(frozenPath.getResult());
// Freezes a BasePathCreateOp. Its hw::HierPathOp target is resolved with
// processPath.
// The base path is rejected with one of these errors:
//   - "basepath must target an instance"
//   - "basepath must not target a field"
// The guarding conditions are not visible in this listing; presumably they
// test the component/field outputs of processPath.
// On success, a FrozenBasePathCreateOp is built from the op's first operand
// and the resolved path, and all uses are redirected to it.
225 LogicalResult PathVisitor::process(BasePathCreateOp path) {
227 irn.symTable.lookup<hw::HierPathOp>(path.getTargetAttr().getAttr());
229 StringAttr bottomModule;
232 if (failed(processPath(path.getLoc(), hierPathOp, targetPath, bottomModule,
237 return path->emitError(
"basepath must target an instance");
239 return path->emitError(
"basepath must not target a field");
// Only the instance path is passed to the frozen op; the bottom module is not.
242 OpBuilder builder(path);
243 auto frozenPath = builder.create<FrozenBasePathCreateOp>(
244 path.getLoc(), path->getOperand(0), targetPath);
245 path.replaceAllUsesWith(frozenPath.getResult());
// Freezes an EmptyPathOp. It inserts a FrozenEmptyPathOp immediately before
// the original op (same location) and redirects all uses of the original
// result to it.
// NOTE(review): erasing the original op and the return are not visible in
// this listing.
251 LogicalResult PathVisitor::process(EmptyPathOp path) {
252 OpBuilder builder(path);
253 auto frozenPath = builder.create<FrozenEmptyPathOp>(path.getLoc());
254 path.replaceAllUsesWith(frozenPath.getResult());
// Rebuilds a list-producing op (run dispatches ListCreateOp and ListConcatOp
// here) with an updated result type.
//   - The new op is created generically by operation name, so one routine
//     handles both op classes.
//   - It keeps the same name and operands but has result type `newListType`,
//     and is inserted before the original op.
//   - Uses are redirected to it, and the original op is erased.
// NOTE(review): how `newListType` is derived from `listType` (and any early
// exit when no path type is involved) is not visible in this listing.
// NOTE(review): the generic create passes no attributes, so any attributes
// on the original op are not carried over.
260 LogicalResult PathVisitor::processListCreator(Operation *listCreateOp) {
261 ListType listType = cast<ListType>(listCreateOp->getResult(0).getType());
271 OpBuilder builder(listCreateOp);
272 auto *newListCreateOp = builder.create(
273 listCreateOp->getLoc(), listCreateOp->getName().getIdentifier(),
274 listCreateOp->getOperands(), {newListType});
275 listCreateOp->getResult(0).replaceAllUsesWith(newListCreateOp->getResult(0));
276 listCreateOp->erase();
// Rebuilds an ObjectFieldOp with an updated result type.
//   - The new ObjectFieldOp uses the same object operand and field path, with
//     result type `newResultType`, and is inserted before the original op.
//   - Uses are redirected to it, and the original op is erased.
// NOTE(review): how `newResultType` is derived from `resultType` (and any
// early exit when it is unchanged) is not visible in this listing.
281 LogicalResult PathVisitor::process(ObjectFieldOp objectFieldOp) {
282 Type resultType = objectFieldOp.getResult().getType();
292 OpBuilder builder(objectFieldOp);
293 auto newObjectFieldOp = builder.create<ObjectFieldOp>(
294 objectFieldOp.getLoc(), newResultType, objectFieldOp.getObject(),
295 objectFieldOp.getFieldPath());
296 objectFieldOp.replaceAllUsesWith(newObjectFieldOp.getResult());
297 objectFieldOp->erase();
// Body of PathVisitor::run(ModuleOp). For every ClassLike op in the module:
//   1. Its body-block arguments are visited (presumably passed to
//      `updatePathType`, whose body is not visible in this listing).
//   2. All nested ops are walked, and each path-related op is dispatched to
//      the matching rewrite.
// The walk is interrupted on the first failed rewrite.
// NOTE(review): several rewrites erase the op being visited. This relies on
// Operation::walk's default post-order traversal, which MLIR documents as
// safe for erasing the current op.
302 auto updatePathType = [&](Value value) {
306 for (
auto classLike : module.getOps<ClassLike>()) {
308 for (
auto arg : classLike.getBodyBlock()->getArguments())
310 auto result = classLike->walk([&](Operation *op) -> WalkResult {
// Dispatch on op kind; ops of any other kind are left untouched.
311 if (
auto basePath = dyn_cast<BasePathCreateOp>(op)) {
312 if (failed(process(basePath)))
313 return WalkResult::interrupt();
314 }
else if (
auto path = dyn_cast<PathCreateOp>(op)) {
315 if (failed(process(path)))
316 return WalkResult::interrupt();
317 }
else if (
auto path = dyn_cast<EmptyPathOp>(op)) {
318 if (failed(process(path)))
319 return WalkResult::interrupt();
320 }
else if (isa<ListCreateOp, ListConcatOp>(op)) {
321 if (failed(processListCreator(op)))
322 return WalkResult::interrupt();
323 }
else if (
auto objectField = dyn_cast<ObjectFieldOp>(op)) {
324 if (failed(process(objectField)))
325 return WalkResult::interrupt();
327 return WalkResult::advance();
329 if (result.wasInterrupted())
// The FreezePaths pass. It runs PathVisitor over the module, replacing
// path-creation ops with their frozen equivalents.
339 struct FreezePathsPass
340 :
public circt::om::impl::FreezePathsBase<FreezePathsPass> {
// `getOpName` is an optional callback that names ops lacking
// "hw.verilogName"; it is forwarded to PathVisitor.
341 FreezePathsPass(std::function<StringAttr(Operation *)> getOpName)
342 : getOpName(std::move(getOpName)) {}
343 void runOnOperation()
override;
// Fallback naming callback; may be empty.
345 std::function<StringAttr(Operation *)> getOpName;
// Builds the symbol-resolution state PathVisitor needs, then runs the visitor
// over the module. That state is:
//   - the instance graph,
//   - the top-level symbol table,
//   - the inner-symbol tables, combined into an hw::InnerRefNamespace.
// NOTE(review): the failure handling after the `run` call is not visible in
// this listing.
349 void FreezePathsPass::runOnOperation() {
350 auto module = getOperation();
351 auto &instanceGraph = getAnalysis<hw::InstanceGraph>();
// NOTE(review): `symbolTableCollection` is not used in the visible lines;
// confirm whether it can be removed.
352 mlir::SymbolTableCollection symbolTableCollection;
353 auto &symbolTable = getAnalysis<SymbolTable>();
354 hw::InnerSymbolTableCollection collection(module);
355 hw::InnerRefNamespace irn{symbolTable, collection};
356 if (failed(PathVisitor(instanceGraph, irn, getOpName).
run(module)))
361 std::function<StringAttr(Operation *)> getOpName) {
// Factory for FreezePathsPass. `getOpName` (empty by default per the
// declaration) names ops that lack "hw.verilogName".
// NOTE(review): `getOpName` is copied here; `std::move(getOpName)` would
// avoid copying the callback.
362 return std::make_unique<FreezePathsPass>(getOpName);
assert(baseType &&"element must be base type")
static bool hasPathType(Type type)
static LogicalResult getAccessPath(Location loc, Type type, size_t fieldId, StringAttr &result)
mlir::AttrTypeReplacer makeReplacer()
static Type processType(Type type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
mlir::Type innerType(mlir::Type type)
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
std::unique_ptr< mlir::Pass > createFreezePathsPass(std::function< StringAttr(Operation *)> getOpNameFallback={})
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)