CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerSignatures.cpp
Go to the documentation of this file.
1//===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerSignatures pass. This pass replaces aggregate
10// types with expanded values in module arguments as specified by the ABI
11// information.
12//
13//===----------------------------------------------------------------------===//
14
21#include "circt/Support/Debug.h"
22#include "mlir/IR/Threading.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "firrtl-lower-signatures"
27
28namespace circt {
29namespace firrtl {
30#define GEN_PASS_DEF_LOWERSIGNATURES
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
32} // namespace firrtl
33} // namespace circt
34
35using namespace circt;
36using namespace firrtl;
37
38//===----------------------------------------------------------------------===//
39// Module Type Lowering
40//===----------------------------------------------------------------------===//
41namespace {
42
43struct AttrCache {
44 AttrCache(MLIRContext *context) {
45 nameAttr = StringAttr::get(context, "name");
46 sPortDirections = StringAttr::get(context, "portDirections");
47 sPortNames = StringAttr::get(context, "portNames");
48 sPortTypes = StringAttr::get(context, "portTypes");
49 sPortLocations = StringAttr::get(context, "portLocations");
50 sPortAnnotations = StringAttr::get(context, "portAnnotations");
51 sPortDomains = StringAttr::get(context, "domainInfo");
52 aEmpty = ArrayAttr::get(context, {});
53 }
54 AttrCache(const AttrCache &) = default;
55
56 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
57 sPortAnnotations, sPortDomains;
58 ArrayAttr aEmpty;
59};
60
61struct FieldMapEntry : public PortInfo {
62 size_t portID;
63 size_t resultID;
64 size_t fieldID;
65};
66
67using PortConversion = SmallVector<FieldMapEntry>;
68
69template <typename T>
70class FieldIDSearch {
71 using E = typename T::ElementType;
72 using V = SmallVector<E>;
73
74public:
75 using const_iterator = typename V::const_iterator;
76
77 template <typename Container>
78 FieldIDSearch(const Container &src) {
79 if constexpr (std::is_convertible_v<Container, Attribute>)
80 if (!src)
81 return;
82 for (auto attr : src)
83 vals.push_back(attr);
84 std::sort(vals.begin(), vals.end(), fieldComp);
85 }
86
87 std::pair<const_iterator, const_iterator> find(uint64_t low,
88 uint64_t high) const {
89 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
90 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
91 }
92
93 bool empty(uint64_t low, uint64_t high) const {
94 auto [b, e] = find(low, high);
95 return b == e;
96 }
97
98private:
99 static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
100 return lhs.getFieldID() < rhs.getFieldID();
101 };
102 static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
103 return lhs.getFieldID() < rhs;
104 };
105 static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
106 return lhs < rhs.getFieldID();
107 };
108
109 V vals;
110};
111
112} // namespace
113
114static hw::InnerSymAttr
115symbolsForFieldIDRange(MLIRContext *ctx,
116 const FieldIDSearch<hw::InnerSymAttr> &syms,
117 uint64_t low, uint64_t high) {
118 auto [b, e] = syms.find(low, high);
119 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
120 if (newSyms.empty())
121 return {};
122 for (auto &sym : newSyms)
123 sym = hw::InnerSymPropertiesAttr::get(
124 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
125 return hw::InnerSymAttr::get(ctx, newSyms);
126}
127
128static AnnotationSet
129annosForFieldIDRange(MLIRContext *ctx,
130 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
131 uint64_t high) {
132 AnnotationSet newAnnos(ctx);
133 auto [b, e] = annos.find(low, high);
134 for (; b != e; ++b)
135 newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
136 return newAnnos;
137}
138
139static LogicalResult
140computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
141 size_t portID, const PortInfo &port, bool isFlip,
142 Twine name, FIRRTLType type, uint64_t fieldID,
143 const FieldIDSearch<hw::InnerSymAttr> &syms,
144 const FieldIDSearch<AnnotationSet> &annos) {
145 auto *ctx = type.getContext();
147 .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
148 // This should be enhanced to be able to handle bundle<all flips of
149 // passive>, or this should be a canonicalizer
150 if (conv != Convention::Scalarized && bundle.isPassive()) {
151 auto lastId = fieldID + bundle.getMaxFieldID();
152 newPorts.push_back(
153 {{StringAttr::get(ctx, name), type,
154 isFlip ? Direction::Out : Direction::In,
155 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
156 annosForFieldIDRange(ctx, annos, fieldID, lastId),
157 port.domains},
158 portID,
159 newPorts.size(),
160 fieldID});
161 } else {
162 for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
163 if (failed(computeLoweringImpl(
164 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
165 name + "_" + elem.name.getValue(), elem.type,
166 fieldID + bundle.getFieldID(idx), syms, annos)))
167 return failure();
168 if (!syms.empty(fieldID, fieldID))
169 return mod.emitError("Port [")
170 << port.name
171 << "] should be subdivided, but cannot be because of "
172 "symbol ["
173 << port.sym.getSymIfExists(fieldID) << "] on a bundle";
174 if (!annos.empty(fieldID, fieldID)) {
175 auto err = mod.emitError("Port [")
176 << port.name
177 << "] should be subdivided, but cannot be because of "
178 "annotations [";
179 auto [b, e] = annos.find(fieldID, fieldID);
180 err << b->getClass() << "(" << b->getFieldID() << ")";
181 b++;
182 for (; b != e; ++b)
183 err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
184 err << "] on a bundle";
185 return err;
186 }
187 }
188 }
189 return success();
190 })
191 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
192 if (conv != Convention::Scalarized &&
193 vector.getElementType().isPassive()) {
194 auto lastId = fieldID + vector.getMaxFieldID();
195 newPorts.push_back(
196 {{StringAttr::get(ctx, name), type,
197 isFlip ? Direction::Out : Direction::In,
198 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
199 annosForFieldIDRange(ctx, annos, fieldID, lastId),
200 port.domains},
201 portID,
202 newPorts.size(),
203 fieldID});
204 } else {
205 for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
206 if (failed(computeLoweringImpl(
207 mod, newPorts, conv, portID, port, isFlip,
208 name + "_" + Twine(i), vector.getElementType(),
209 fieldID + vector.getFieldID(i), syms, annos)))
210 return failure();
211 if (!syms.empty(fieldID, fieldID))
212 return mod.emitError("Port [")
213 << port.name
214 << "] should be subdivided, but cannot be because of "
215 "symbol ["
216 << port.sym.getSymIfExists(fieldID) << "] on a vector";
217 if (!annos.empty(fieldID, fieldID)) {
218 auto err = mod.emitError("Port [")
219 << port.name
220 << "] should be subdivided, but cannot be because of "
221 "annotations [";
222 auto [b, e] = annos.find(fieldID, fieldID);
223 err << b->getClass();
224 ++b;
225 for (; b != e; ++b)
226 err << ", " << b->getClass();
227 err << "] on a vector";
228 return err;
229 }
230 }
231 }
232 return success();
233 })
234 .Default([&](FIRRTLType type) {
235 // Properties and other types wind up here.
236 newPorts.push_back(
237 {{StringAttr::get(ctx, name), type,
238 isFlip ? Direction::Out : Direction::In,
239 symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
240 annosForFieldIDRange(ctx, annos, fieldID, fieldID), port.domains},
241 portID,
242 newPorts.size(),
243 fieldID});
244 return success();
245 });
246}
247
248// compute a new moduletype from an old module type and lowering convention.
249// Also compute a fieldID map from port, fieldID -> port
250static LogicalResult computeLowering(FModuleLike mod, Convention conv,
251 PortConversion &newPorts) {
252 for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
253 if (failed(computeLoweringImpl(
254 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
255 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
256 FieldIDSearch<hw::InnerSymAttr>(port.sym),
257 FieldIDSearch<AnnotationSet>(port.annotations))))
258 return failure();
259 }
260 return success();
261}
262
263static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
264 AttrCache &cache,
265 PortConversion &newPorts) {
266 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
267 if (computeLowering(module, conv, newPorts).failed())
268 return failure();
269 if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
270 Block *body = mod.getBodyBlock();
271 theBuilder.setInsertionPointToStart(body);
272 auto oldNumArgs = body->getNumArguments();
273
274 // Compute the replacement value for old arguments
275 // This creates all the new arguments and produces bounce wires when
276 // necessary
277 SmallVector<Value> bounceWires(oldNumArgs);
278 for (auto &p : newPorts) {
279 auto newArg = body->addArgument(p.type, p.loc);
280 // Get or create a bounce wire for changed ports
281 // For unmodified ports, move the uses to the replacement port
282 if (p.fieldID != 0) {
283 auto &wire = bounceWires[p.portID];
284 if (!wire)
285 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
286 module.getPortNameAttr(p.portID),
287 NameKindEnum::InterestingName)
288 .getResult();
289 } else {
290 bounceWires[p.portID] = newArg;
291 }
292 }
293 // replace old arguments. Somethings get dropped completely, like
294 // zero-length vectors.
295 for (auto idx = 0U; idx < oldNumArgs; ++idx) {
296 if (!bounceWires[idx]) {
297 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
298 module.getPortNameAttr(idx))
299 .getResult();
300 }
301 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
302 }
303
304 // Goodbye old ports, now ResultID in the PortInfo is correct.
305 body->eraseArguments(0, oldNumArgs);
306
307 // Connect the bounce wires to the new arguments
308 for (auto &p : newPorts) {
309 if (isa<BlockArgument>(bounceWires[p.portID]))
310 continue;
311 if (p.isOutput())
313 theBuilder, body->getArgument(p.resultID),
314 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
315 else
317 theBuilder,
318 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
319 body->getArgument(p.resultID));
320 }
321 }
322
323 SmallVector<NamedAttribute, 8> newModuleAttrs;
324
325 // Copy over any attributes that weren't original argument attributes.
326 for (auto attr : module->getAttrDictionary())
327 // Drop old "portNames", directions, and argument attributes. These are
328 // handled differently below.
329 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
330 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
331 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
332 newModuleAttrs.push_back(attr);
333
334 SmallVector<Direction> newPortDirections;
335 SmallVector<Attribute> newPortNames;
336 SmallVector<Attribute> newPortTypes;
337 SmallVector<Attribute> newPortSyms;
338 SmallVector<Attribute> newPortLocations;
339 SmallVector<Attribute, 8> newPortAnnotations;
340 SmallVector<Attribute> newPortDomains;
341
342 for (auto p : newPorts) {
343 newPortTypes.push_back(TypeAttr::get(p.type));
344 newPortNames.push_back(p.name);
345 newPortDirections.push_back(p.direction);
346 newPortSyms.push_back(p.sym);
347 newPortLocations.push_back(p.loc);
348 newPortAnnotations.push_back(p.annotations.getArrayAttr());
349 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
350 }
351
352 newModuleAttrs.push_back(NamedAttribute(
353 cache.sPortDirections,
354 direction::packAttribute(module.getContext(), newPortDirections)));
355
356 newModuleAttrs.push_back(
357 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
358
359 newModuleAttrs.push_back(
360 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
361
362 newModuleAttrs.push_back(NamedAttribute(
363 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
364
365 newModuleAttrs.push_back(NamedAttribute(
366 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
367
368 newModuleAttrs.push_back(NamedAttribute(
369 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
370
371 // Update the module's attributes.
372 module->setAttrs(newModuleAttrs);
373 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
374 module.setPortSymbols(newPortSyms);
375 return success();
376}
377
378static void lowerModuleBody(FModuleOp mod,
379 const DenseMap<StringAttr, PortConversion> &ports) {
380 mod->walk([&](InstanceOp inst) -> void {
381 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
382 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
383
384 // Fix up the Instance
385 SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
386 for (auto p : modPorts) {
387 p.sym = {};
388 // Might need to partially copy stuff from the old instance.
389 p.annotations = AnnotationSet{mod.getContext()};
390 instPorts.push_back(p);
391 }
392 auto annos = inst.getAnnotations();
393 auto newOp = InstanceOp::create(
394 theBuilder, instPorts, inst.getModuleName(), inst.getName(),
395 inst.getNameKind(), annos.getValue(), inst.getLayers(),
396 inst.getLowerToBind(), inst.getDoNotPrint(), inst.getInnerSymAttr());
397
398 auto oldDict = inst->getDiscardableAttrDictionary();
399 auto newDict = newOp->getDiscardableAttrDictionary();
400 auto oldNames = inst.getPortNamesAttr();
401 SmallVector<NamedAttribute> newAttrs;
402 for (auto na : oldDict)
403 if (!newDict.contains(na.getName()))
404 newOp->setDiscardableAttr(na.getName(), na.getValue());
405
406 // Connect up the old instance users to the new instance
407 SmallVector<WireOp> bounce(inst.getNumResults());
408 for (auto p : modPorts) {
409 // No change? No bounce wire.
410 if (p.fieldID == 0) {
411 inst.getResult(p.portID).replaceAllUsesWith(
412 newOp.getResult(p.resultID));
413 continue;
414 }
415 if (!bounce[p.portID]) {
416 bounce[p.portID] = WireOp::create(
417 theBuilder, inst.getResult(p.portID).getType(),
418 theBuilder.getStringAttr(
419 inst.getName() + "." +
420 cast<StringAttr>(oldNames[p.portID]).getValue()));
421 inst.getResult(p.portID).replaceAllUsesWith(
422 bounce[p.portID].getResult());
423 }
424 // Connect up the Instance to the bounce wires
425 if (p.isInput())
426 emitConnect(theBuilder, newOp.getResult(p.resultID),
427 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
428 p.fieldID));
429 else
430 emitConnect(theBuilder,
431 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
432 p.fieldID),
433 newOp.getResult(p.resultID));
434 }
435 // Zero Width ports may have dangling connects since they are not preserved
436 // and do not have bounce wires.
437 for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
438 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
439 use->erase();
440 }
441 inst->erase();
442 return;
443 });
444}
445
446//===----------------------------------------------------------------------===//
447// Pass Infrastructure
448//===----------------------------------------------------------------------===//
449
450namespace {
451struct LowerSignaturesPass
452 : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
453 void runOnOperation() override;
454};
455} // end anonymous namespace
456
457// This is the main entrypoint for the lowering pass.
458void LowerSignaturesPass::runOnOperation() {
460
461 // Cached attr
462 AttrCache cache(&getContext());
463
464 DenseMap<StringAttr, PortConversion> portMap;
465 auto circuit = getOperation();
466
467 for (auto mod : circuit.getOps<FModuleLike>()) {
468 if (lowerModuleSignature(mod, mod.getConvention(), cache,
469 portMap[mod.getNameAttr()])
470 .failed())
471 return signalPassFailure();
472 }
473 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
474 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
475}
assert(baseType &&"element must be base type")
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.