20 #include "mlir/Transforms/DialectConversion.h"
22 using namespace circt;
// Maps a method's symbol (as produced by getSymbol) to the hw struct type
// that packs all of that method's arguments, paired with the Location to use
// for the merged struct argument. Populated in the CallPrepPrecomputed
// constructor and read by both rewrite patterns.
33 DenseMap<SymbolRefAttr, std::pair<hw::StructType, Location>>
argTypes;
// Maps (containing class, instance inner-symbol name) to a ClassOp.
// Presumably the value is the class that the instance instantiates, found via
// classSymbols. The assigned value is on a line not visible in this extract,
// so confirm this. Consumed by lookupNext.
37 DenseMap<std::pair<ClassOp, StringAttr>, ClassOp>
instanceMap;
// Descend one level of an instance path: look up the instance named
// `instSym` inside class `scope` in instanceMap.
// NOTE(review): the not-found branch and the final return are not visible in
// this extract. Presumably a null ClassOp is returned when the pair is
// missing. Confirm this, because resolveInstancePath chains these lookups.
40 ClassOp
lookupNext(ClassOp scope, StringAttr instSym)
const {
41 auto entry = instanceMap.find(std::make_pair(scope, instSym));
42 if (entry == instanceMap.end())
// Resolve `path` relative to the ibis class that encloses `scope`. The path
// is a chain of instance symbols whose leaf is a method name. The result is
// intended to have the same shape as the argTypes keys, with the method name
// as the single nested reference.
48 SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path)
const;
// Body fragment of getSymbol(MethodOp): builds the key used in argTypes for
// `method`. The nested reference is the method's inner name. The root
// reference is on a line not visible here; given `cls` below, it is presumably
// the parent class's symbol name (confirm).
// NOTE(review): `*method.getInnerName()` dereferences the result without a
// check, so a method without an inner name is not handled here.
52 ClassOp cls = method.getParentOp();
55 {FlatSymbolRefAttr::get(method.getContext(), *method.getInnerName())});
// Constructor body fragment of CallPrepPrecomputed(ModuleOp mod). It builds
// the lookup indexes used by the rewrite patterns, and computes for every
// method the struct type that will replace its argument list.
60 auto *ctxt = mod.getContext();
// Index every ClassOp directly inside the module by its symbol name.
63 for (
auto cls : mod.getOps<ClassOp>())
64 classSymbols[cls.getSymNameAttr()] = cls;
// For each class, record each method's merged argument struct and each
// instance's target class.
66 for (
auto cls : mod.getOps<ClassOp>()) {
68 for (
auto method : cls.getOps<MethodOp>()) {
// One struct field per method argument, named after the argument and in
// argument order.
71 SmallVector<hw::StructType::FieldInfo> argFields;
72 for (
auto [argName, argType] :
73 llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
74 cast<MethodLikeOpInterface>(method.getOperation())
76 argFields.push_back({argName, argType});
// If the op has regions, collect the locations of the body's entry-block
// arguments. The lines that turn these into the merged argument's location
// are not visible here (presumably a fused location; confirm).
82 if (method->getNumRegions() > 0) {
83 SmallVector<Location> argLocs;
84 Block *body = &method.getBody().front();
85 for (
auto arg : body->getArguments())
86 argLocs.push_back(arg.getLoc());
// Key the struct type and its location by the method's symbol.
92 std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
// For each instance in this class, record which class it instantiates.
// Only the assert enforces that the target exists in classSymbols. On the
// visible lines, a missing target is not diagnosed in builds without asserts.
96 for (
auto inst : cls.getOps<InstanceOp>()) {
97 auto clsEntry = classSymbols.find(inst.getTargetNameAttr().getAttr());
98 assert(clsEntry != classSymbols.end() &&
99 "class being instantiated doesn't exist");
100 instanceMap[std::make_pair(cls, inst.getInnerSym().getSymName())] =
108 SymbolRefAttr path)
const {
// Definition fragment of CallPrepPrecomputed::resolveInstancePath.
// Start from the ibis class that encloses `scope`. Follow the instance
// symbols in `path` via lookupNext: first the root reference, then every
// nested reference except the last. The final (leaf) reference is treated
// as the method name.
// NOTE(review): two things sit on lines not visible in this extract and
// should be confirmed: how a failed lookupNext is handled, and what the
// returned symbol's root reference is.
109 auto cls = scope->getParentOfType<ClassOp>();
// `scope` must be nested inside an ibis ClassOp. Only an assert checks this.
110 assert(cls &&
"scope outside of ibis class");
// First hop: the root reference names an instance inside `cls`.
113 cls = lookupNext(cls, path.getRootReference());
// Remaining hops use all nested references except the last, because the
// last one is the method name rather than an instance.
119 for (
auto instSym : path.getNestedReferences().drop_back()) {
120 cls = lookupNext(cls, instSym.getAttr());
// The result's single nested reference is the leaf (method name).
127 {FlatSymbolRefAttr::get(path.getLeafReference())});
// Rewrites a CallOp so that its argument operands are replaced by a single
// hw struct value that packs them.
137 void rewrite(CallOp, OpAdaptor adaptor,
138 ConversionPatternRewriter &rewriter)
const final;
// Matches every CallOp unconditionally. Which calls actually get rewritten
// is decided by the ConversionTarget's dynamic legality rule for CallOp.
139 LogicalResult match(CallOp)
const override {
return success(); }
146 void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const {
// Pack the call's operands into a single struct value created just before
// the call, then make that struct the call's only operand.
148 auto loc = call.getLoc();
149 rewriter.setInsertionPoint(call);
// The enclosing method-like op is the caller. Its name is used for the
// name hint below.
// NOTE(review): getParentOfType returns null if the call is not nested in a
// MethodLikeOpInterface op. On the visible lines, `method` is dereferenced
// below without a check.
150 auto method = call->getParentOfType<ibis::MethodLikeOpInterface>();
// Resolve the callee to the symbol under which its argument struct type was
// recorded. The initializer is on a line not visible here.
153 SymbolRefAttr calleeSym =
155 auto argStructEntry = info.argTypes.find(calleeSym);
// NOTE(review): only an assert guards this lookup. Without asserts, a missing
// entry would dereference the end() iterator on the next statement.
156 assert(argStructEntry != info.argTypes.end() &&
"Method symref not found!");
157 auto [argStruct, argLoc] = argStructEntry->second;
// `newArg` is built from all of the call's converted operands. Its
// construction starts on lines not visible here.
161 adaptor.getOperands());
// Attach a naming hint of the form "<callee>_args_called_from_<caller>".
162 newArg->setAttr(
"sv.namehint",
163 rewriter.getStringAttr(
164 call.getCalleeAttr().getLeafReference().getValue() +
165 "_args_called_from_" +
166 method.getMethodName().getValue()));
// Replace the original operand list with the single struct operand, in place.
169 rewriter.updateRootInPlace(call, [&]() {
170 call.getOperandsMutable().clear();
171 call.getOperandsMutable().append(newArg.getResult());
// Replaces a MethodOp with an equivalent method that takes a single struct
// argument.
182 void rewrite(MethodOp, OpAdaptor adaptor,
183 ConversionPatternRewriter &rewriter)
const final;
// Matches every MethodOp unconditionally. Which methods actually get
// rewritten is decided by the ConversionTarget's dynamic legality rule for
// MethodOp.
184 LogicalResult match(MethodOp)
const override {
return success(); }
191 void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
192 ConversionPatternRewriter &rewriter)
const {
// Build a replacement MethodOp whose only argument is the precomputed
// argument struct. Move the old body into it, exploding the struct back into
// the original argument values, then erase the old op.
193 auto loc = func.getLoc();
194 auto *ctxt = getContext();
// Look up this method's precomputed argument struct and location. The
// lookup key is on a line not visible here.
197 auto argStructEntry =
// NOTE(review): only an assert guards this lookup. Without asserts, a missing
// entry would dereference the end() iterator.
199 assert(argStructEntry != info.argTypes.end() &&
"Cannot find symref!");
200 auto [argStruct, argLoc] = argStructEntry->second;
// Derive the new function type from the old one. Its construction is on
// lines not visible here.
203 FunctionType funcType = func.getFunctionType();
204 FunctionType newFuncType =
// Create the new method with the same inner symbol, the new type,
// `newArgNames` (defined on lines not visible here) and two null ArrayAttr
// arguments. Those are presumably the argument/result attribute lists;
// confirm against the MethodOp builder.
208 rewriter.create<MethodOp>(loc, func.getInnerSym(), newFuncType,
209 newArgNames, ArrayAttr(), ArrayAttr());
// Only move a body if the old op has regions.
211 if (func->getNumRegions() > 0) {
// New entry block with a single argument of the struct type, located at
// argLoc.
215 rewriter.createBlock(&newMethod.getRegion(), {}, {argStruct}, {argLoc});
216 rewriter.setInsertionPointToStart(b);
// Unpack the struct into one value per field at the top of the new block.
217 auto replacementArgs =
218 rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));
// Append the old entry block's operations to the new block. The exploded
// field values replace the old block arguments; the fields were built in
// argument order.
221 Block *funcBody = &func.getBody().front();
222 rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
225 rewriter.eraseOp(func);
// The CallPrep pass itself. The base class IbisCallPrepBase is presumably
// generated from the pass's TableGen definition (confirm).
230 struct CallPrepPass :
public IbisCallPrepBase<CallPrepPass> {
231 void runOnOperation()
override;
239 void CallPrepPass::runOnOperation() {
// Build the precomputed indexes (`info`, on lines not visible here), then
// run the conversion. The failure handling is also not visible here;
// presumably it signals pass failure (confirm).
242 if (failed(merge(info))) {
// Conversion setup. The function header is not visible here; this is
// presumably the body of merge(info) called from runOnOperation. Every
// operation is legal except CallOps and MethodOps whose arguments are not
// yet a single hw struct.
250 ConversionTarget target(getContext());
251 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
// A call is legal once it passes exactly one argument operand, and that
// operand has hw struct type (checked via hw::type_isa).
252 target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
253 auto argValues = call.getArgOperands();
254 return argValues.size() == 1 &&
255 hw::type_isa<hw::StructType>(argValues.front().getType());
// A method is legal once its function type has exactly one input, and that
// input has hw struct type.
257 target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
258 ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
259 return argTypes.size() == 1 &&
260 hw::type_isa<hw::StructType>(argTypes.front());
// Both patterns share the same precomputed `info`.
264 RewritePatternSet
patterns(&getContext());
265 patterns.insert<MergeCallArgs>(&getContext(), info);
266 patterns.insert<MergeMethodArgs>(&getContext(), info);
// Partial conversion leaves ops that are already legal untouched. Its result
// is returned to the caller.
268 return applyPartialConversion(getOperation(), target, std::move(
patterns));
// Body fragment of the pass factory (createCallPrepPass). It returns a newly
// allocated CallPrepPass whose ownership passes to the caller through the
// unique_ptr.
272 return std::make_unique<CallPrepPass>();
assert(baseType &&"element must be base type")
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createCallPrepPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Build indexes to make lookups faster. Create the new argument types as well.
SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path) const
DenseMap< StringAttr, ClassOp > classSymbols
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const
DenseMap< std::pair< ClassOp, StringAttr >, ClassOp > instanceMap
DenseMap< SymbolRefAttr, std::pair< hw::StructType, Location > > argTypes
CallPrepPrecomputed(ModuleOp mod)
static SymbolRefAttr getSymbol(MethodOp method)