CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerSignatures.cpp
Go to the documentation of this file.
1//===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines the LowerSignatures pass. This pass expands aggregate-typed
// module ports into multiple ports, as dictated by each module's lowering
// convention, and rewrites module bodies and instances to match.
12//
13//===----------------------------------------------------------------------===//
14
21#include "circt/Support/Debug.h"
22#include "mlir/IR/Threading.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "firrtl-lower-signatures"
27
28namespace circt {
29namespace firrtl {
30#define GEN_PASS_DEF_LOWERSIGNATURES
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
32} // namespace firrtl
33} // namespace circt
34
35using namespace circt;
36using namespace firrtl;
37
38//===----------------------------------------------------------------------===//
39// Module Type Lowering
40//===----------------------------------------------------------------------===//
41namespace {
42
43struct AttrCache {
44 AttrCache(MLIRContext *context) {
45 nameAttr = StringAttr::get(context, "name");
46 sPortDirections = StringAttr::get(context, "portDirections");
47 sPortNames = StringAttr::get(context, "portNames");
48 sPortTypes = StringAttr::get(context, "portTypes");
49 sPortLocations = StringAttr::get(context, "portLocations");
50 sPortAnnotations = StringAttr::get(context, "portAnnotations");
51 sPortDomains = StringAttr::get(context, "domainInfo");
52 aEmpty = ArrayAttr::get(context, {});
53 }
54 AttrCache(const AttrCache &) = default;
55
56 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
57 sPortAnnotations, sPortDomains;
58 ArrayAttr aEmpty;
59};
60
61struct FieldMapEntry : public PortInfo {
62 size_t portID;
63 size_t resultID;
64 size_t fieldID;
65};
66
67using PortConversion = SmallVector<FieldMapEntry>;
68
69template <typename T>
70class FieldIDSearch {
71 using E = typename T::ElementType;
72 using V = SmallVector<E>;
73
74public:
75 using const_iterator = typename V::const_iterator;
76
77 template <typename Container>
78 FieldIDSearch(const Container &src) {
79 if constexpr (std::is_convertible_v<Container, Attribute>)
80 if (!src)
81 return;
82 for (auto attr : src)
83 vals.push_back(attr);
84 std::sort(vals.begin(), vals.end(), fieldComp);
85 }
86
87 std::pair<const_iterator, const_iterator> find(uint64_t low,
88 uint64_t high) const {
89 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
90 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
91 }
92
93 bool empty(uint64_t low, uint64_t high) const {
94 auto [b, e] = find(low, high);
95 return b == e;
96 }
97
98private:
99 static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
100 return lhs.getFieldID() < rhs.getFieldID();
101 };
102 static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
103 return lhs.getFieldID() < rhs;
104 };
105 static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
106 return lhs < rhs.getFieldID();
107 };
108
109 V vals;
110};
111
112} // namespace
113
114static hw::InnerSymAttr
115symbolsForFieldIDRange(MLIRContext *ctx,
116 const FieldIDSearch<hw::InnerSymAttr> &syms,
117 uint64_t low, uint64_t high) {
118 auto [b, e] = syms.find(low, high);
119 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
120 if (newSyms.empty())
121 return {};
122 for (auto &sym : newSyms)
123 sym = hw::InnerSymPropertiesAttr::get(
124 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
125 return hw::InnerSymAttr::get(ctx, newSyms);
126}
127
128static AnnotationSet
129annosForFieldIDRange(MLIRContext *ctx,
130 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
131 uint64_t high) {
132 AnnotationSet newAnnos(ctx);
133 auto [b, e] = annos.find(low, high);
134 for (; b != e; ++b)
135 newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
136 return newAnnos;
137}
138
139static LogicalResult
140computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
141 size_t portID, const PortInfo &port, bool isFlip,
142 Twine name, FIRRTLType type, uint64_t fieldID,
143 const FieldIDSearch<hw::InnerSymAttr> &syms,
144 const FieldIDSearch<AnnotationSet> &annos) {
145 auto *ctx = type.getContext();
147 .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
148 // This should be enhanced to be able to handle bundle<all flips of
149 // passive>, or this should be a canonicalizer
150 if (conv != Convention::Scalarized && bundle.isPassive()) {
151 auto lastId = fieldID + bundle.getMaxFieldID();
152 newPorts.push_back(
153 {{StringAttr::get(ctx, name), type,
154 isFlip ? Direction::Out : Direction::In,
155 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
156 annosForFieldIDRange(ctx, annos, fieldID, lastId),
157 port.domains},
158 portID,
159 newPorts.size(),
160 fieldID});
161 } else {
162 for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
163 if (failed(computeLoweringImpl(
164 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
165 name + "_" + elem.name.getValue(), elem.type,
166 fieldID + bundle.getFieldID(idx), syms, annos)))
167 return failure();
168 if (!syms.empty(fieldID, fieldID))
169 return mod.emitError("Port [")
170 << port.name
171 << "] should be subdivided, but cannot be because of "
172 "symbol ["
173 << port.sym.getSymIfExists(fieldID) << "] on a bundle";
174 if (!annos.empty(fieldID, fieldID)) {
175 auto err = mod.emitError("Port [")
176 << port.name
177 << "] should be subdivided, but cannot be because of "
178 "annotations [";
179 auto [b, e] = annos.find(fieldID, fieldID);
180 err << b->getClass() << "(" << b->getFieldID() << ")";
181 b++;
182 for (; b != e; ++b)
183 err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
184 err << "] on a bundle";
185 return err;
186 }
187 }
188 }
189 return success();
190 })
191 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
192 if (conv != Convention::Scalarized &&
193 vector.getElementType().isPassive()) {
194 auto lastId = fieldID + vector.getMaxFieldID();
195 newPorts.push_back(
196 {{StringAttr::get(ctx, name), type,
197 isFlip ? Direction::Out : Direction::In,
198 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
199 annosForFieldIDRange(ctx, annos, fieldID, lastId),
200 port.domains},
201 portID,
202 newPorts.size(),
203 fieldID});
204 } else {
205 for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
206 if (failed(computeLoweringImpl(
207 mod, newPorts, conv, portID, port, isFlip,
208 name + "_" + Twine(i), vector.getElementType(),
209 fieldID + vector.getFieldID(i), syms, annos)))
210 return failure();
211 if (!syms.empty(fieldID, fieldID))
212 return mod.emitError("Port [")
213 << port.name
214 << "] should be subdivided, but cannot be because of "
215 "symbol ["
216 << port.sym.getSymIfExists(fieldID) << "] on a vector";
217 if (!annos.empty(fieldID, fieldID)) {
218 auto err = mod.emitError("Port [")
219 << port.name
220 << "] should be subdivided, but cannot be because of "
221 "annotations [";
222 auto [b, e] = annos.find(fieldID, fieldID);
223 err << b->getClass();
224 ++b;
225 for (; b != e; ++b)
226 err << ", " << b->getClass();
227 err << "] on a vector";
228 return err;
229 }
230 }
231 }
232 return success();
233 })
234 .Default([&](FIRRTLType type) {
235 // Properties and other types wind up here.
236 newPorts.push_back(
237 {{StringAttr::get(ctx, name), type,
238 isFlip ? Direction::Out : Direction::In,
239 symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
240 annosForFieldIDRange(ctx, annos, fieldID, fieldID), port.domains},
241 portID,
242 newPorts.size(),
243 fieldID});
244 return success();
245 });
246}
247
248// compute a new moduletype from an old module type and lowering convention.
249// Also compute a fieldID map from port, fieldID -> port
250static LogicalResult computeLowering(FModuleLike mod, Convention conv,
251 PortConversion &newPorts) {
252 for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
253 if (failed(computeLoweringImpl(
254 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
255 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
256 FieldIDSearch<hw::InnerSymAttr>(port.sym),
257 FieldIDSearch<AnnotationSet>(port.annotations))))
258 return failure();
259 }
260 return success();
261}
262
263static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
264 AttrCache &cache,
265 PortConversion &newPorts) {
266 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
267 if (computeLowering(module, conv, newPorts).failed())
268 return failure();
269
270 // Update domain information now that all port expansions are fixed.
271 DenseMap<size_t, size_t> domainMap;
272 for (auto &newPort : newPorts) {
273 if (!type_isa<DomainType>(newPort.type))
274 continue;
275 domainMap[newPort.portID] = newPort.resultID;
276 }
277 for (auto &newPort : newPorts) {
278 if (type_isa<DomainType>(newPort.type))
279 continue;
280 auto oldAssociations = dyn_cast_or_null<ArrayAttr>(newPort.domains);
281 if (!oldAssociations)
282 continue;
283 SmallVector<Attribute> newAssociations;
284 for (auto oldAttr : oldAssociations)
285 newAssociations.push_back(theBuilder.getUI32IntegerAttr(
286 domainMap[cast<IntegerAttr>(oldAttr).getValue().getZExtValue()]));
287 newPort.domains = theBuilder.getArrayAttr(newAssociations);
288 }
289
290 if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
291 Block *body = mod.getBodyBlock();
292 theBuilder.setInsertionPointToStart(body);
293 auto oldNumArgs = body->getNumArguments();
294
295 // Compute the replacement value for old arguments
296 // This creates all the new arguments and produces bounce wires when
297 // necessary
298 SmallVector<Value> bounceWires(oldNumArgs);
299 for (auto &p : newPorts) {
300 auto newArg = body->addArgument(p.type, p.loc);
301 // Get or create a bounce wire for changed ports
302 // For unmodified ports, move the uses to the replacement port
303 if (p.fieldID != 0) {
304 auto &wire = bounceWires[p.portID];
305 if (!wire)
306 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
307 module.getPortNameAttr(p.portID),
308 NameKindEnum::InterestingName)
309 .getResult();
310 } else {
311 bounceWires[p.portID] = newArg;
312 }
313 }
314 // replace old arguments. Somethings get dropped completely, like
315 // zero-length vectors.
316 for (auto idx = 0U; idx < oldNumArgs; ++idx) {
317 if (!bounceWires[idx]) {
318 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
319 module.getPortNameAttr(idx))
320 .getResult();
321 }
322 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
323 }
324
325 // Goodbye old ports, now ResultID in the PortInfo is correct.
326 body->eraseArguments(0, oldNumArgs);
327
328 // Connect the bounce wires to the new arguments
329 for (auto &p : newPorts) {
330 if (isa<BlockArgument>(bounceWires[p.portID]))
331 continue;
332 if (p.isOutput())
334 theBuilder, body->getArgument(p.resultID),
335 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
336 else
338 theBuilder,
339 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
340 body->getArgument(p.resultID));
341 }
342 }
343
344 SmallVector<NamedAttribute, 8> newModuleAttrs;
345
346 // Copy over any attributes that weren't original argument attributes.
347 for (auto attr : module->getAttrDictionary())
348 // Drop old "portNames", directions, and argument attributes. These are
349 // handled differently below.
350 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
351 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
352 attr.getName() != "portSymbols" && attr.getName() != "portLocations")
353 newModuleAttrs.push_back(attr);
354
355 SmallVector<Direction> newPortDirections;
356 SmallVector<Attribute> newPortNames;
357 SmallVector<Attribute> newPortTypes;
358 SmallVector<Attribute> newPortSyms;
359 SmallVector<Attribute> newPortLocations;
360 SmallVector<Attribute, 8> newPortAnnotations;
361 SmallVector<Attribute> newPortDomains;
362
363 for (auto p : newPorts) {
364 newPortTypes.push_back(TypeAttr::get(p.type));
365 newPortNames.push_back(p.name);
366 newPortDirections.push_back(p.direction);
367 newPortSyms.push_back(p.sym);
368 newPortLocations.push_back(p.loc);
369 newPortAnnotations.push_back(p.annotations.getArrayAttr());
370 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
371 }
372
373 newModuleAttrs.push_back(NamedAttribute(
374 cache.sPortDirections,
375 direction::packAttribute(module.getContext(), newPortDirections)));
376
377 newModuleAttrs.push_back(
378 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
379
380 newModuleAttrs.push_back(
381 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
382
383 newModuleAttrs.push_back(NamedAttribute(
384 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
385
386 newModuleAttrs.push_back(NamedAttribute(
387 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
388
389 newModuleAttrs.push_back(NamedAttribute(
390 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
391
392 // Update the module's attributes.
393 module->setAttrs(newModuleAttrs);
394 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
395 module.setPortSymbols(newPortSyms);
396 return success();
397}
398
399static void lowerModuleBody(FModuleOp mod,
400 const DenseMap<StringAttr, PortConversion> &ports) {
401 mod->walk([&](InstanceOp inst) -> void {
402 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
403 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
404
405 // Fix up the Instance
406 SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
407 for (auto p : modPorts) {
408 p.sym = {};
409 // Might need to partially copy stuff from the old instance.
410 p.annotations = AnnotationSet{mod.getContext()};
411 instPorts.push_back(p);
412 }
413 auto annos = inst.getAnnotations();
414 auto newOp = InstanceOp::create(
415 theBuilder, instPorts, inst.getModuleName(), inst.getName(),
416 inst.getNameKind(), annos.getValue(), inst.getLayers(),
417 inst.getLowerToBind(), inst.getDoNotPrint(), inst.getInnerSymAttr());
418
419 auto oldDict = inst->getDiscardableAttrDictionary();
420 auto newDict = newOp->getDiscardableAttrDictionary();
421 auto oldNames = inst.getPortNamesAttr();
422 SmallVector<NamedAttribute> newAttrs;
423 for (auto na : oldDict)
424 if (!newDict.contains(na.getName()))
425 newOp->setDiscardableAttr(na.getName(), na.getValue());
426
427 // Connect up the old instance users to the new instance
428 SmallVector<WireOp> bounce(inst.getNumResults());
429 for (auto p : modPorts) {
430 // No change? No bounce wire.
431 if (p.fieldID == 0) {
432 inst.getResult(p.portID).replaceAllUsesWith(
433 newOp.getResult(p.resultID));
434 continue;
435 }
436 if (!bounce[p.portID]) {
437 bounce[p.portID] = WireOp::create(
438 theBuilder, inst.getResult(p.portID).getType(),
439 theBuilder.getStringAttr(
440 inst.getName() + "." +
441 cast<StringAttr>(oldNames[p.portID]).getValue()));
442 inst.getResult(p.portID).replaceAllUsesWith(
443 bounce[p.portID].getResult());
444 }
445 // Connect up the Instance to the bounce wires
446 if (p.isInput())
447 emitConnect(theBuilder, newOp.getResult(p.resultID),
448 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
449 p.fieldID));
450 else
451 emitConnect(theBuilder,
452 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
453 p.fieldID),
454 newOp.getResult(p.resultID));
455 }
456 // Zero Width ports may have dangling connects since they are not preserved
457 // and do not have bounce wires.
458 for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
459 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
460 use->erase();
461 }
462 inst->erase();
463 return;
464 });
465}
466
467//===----------------------------------------------------------------------===//
468// Pass Infrastructure
469//===----------------------------------------------------------------------===//
470
471namespace {
472struct LowerSignaturesPass
473 : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
474 void runOnOperation() override;
475};
476} // end anonymous namespace
477
478// This is the main entrypoint for the lowering pass.
479void LowerSignaturesPass::runOnOperation() {
481
482 // Cached attr
483 AttrCache cache(&getContext());
484
485 DenseMap<StringAttr, PortConversion> portMap;
486 auto circuit = getOperation();
487
488 for (auto mod : circuit.getOps<FModuleLike>()) {
489 if (lowerModuleSignature(mod, mod.getConvention(), cache,
490 portMap[mod.getNameAttr()])
491 .failed())
492 return signalPassFailure();
493 }
494 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
495 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
496}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.