CIRCT 22.0.0git
Loading...
Searching...
No Matches
LowerSignatures.cpp
Go to the documentation of this file.
1//===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerSignatures pass. This pass replaces aggregate
10// types with expanded values in module arguments as specified by the ABI
11// information.
12//
13//===----------------------------------------------------------------------===//
14
21#include "circt/Support/Debug.h"
22#include "mlir/IR/Threading.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "firrtl-lower-signatures"
27
28namespace circt {
29namespace firrtl {
30#define GEN_PASS_DEF_LOWERSIGNATURES
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
32} // namespace firrtl
33} // namespace circt
34
35using namespace circt;
36using namespace firrtl;
37
38//===----------------------------------------------------------------------===//
39// Module Type Lowering
40//===----------------------------------------------------------------------===//
41namespace {
42
43struct AttrCache {
44 AttrCache(MLIRContext *context) {
45 nameAttr = StringAttr::get(context, "name");
46 sPortDirections = StringAttr::get(context, "portDirections");
47 sPortNames = StringAttr::get(context, "portNames");
48 sPortTypes = StringAttr::get(context, "portTypes");
49 sPortLocations = StringAttr::get(context, "portLocations");
50 sPortAnnotations = StringAttr::get(context, "portAnnotations");
51 sPortDomains = StringAttr::get(context, "domainInfo");
52 sInternalPaths = StringAttr::get(context, "internalPaths");
53 aEmpty = ArrayAttr::get(context, {});
54 }
55 AttrCache(const AttrCache &) = default;
56
57 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
58 sPortAnnotations, sPortDomains, sInternalPaths;
59 ArrayAttr aEmpty;
60};
61
62struct FieldMapEntry : public PortInfo {
63 size_t portID;
64 size_t resultID;
65 size_t fieldID;
66};
67
68using PortConversion = SmallVector<FieldMapEntry>;
69
70template <typename T>
71class FieldIDSearch {
72 using E = typename T::ElementType;
73 using V = SmallVector<E>;
74
75public:
76 using const_iterator = typename V::const_iterator;
77
78 template <typename Container>
79 FieldIDSearch(const Container &src) {
80 if constexpr (std::is_convertible_v<Container, Attribute>)
81 if (!src)
82 return;
83 for (auto attr : src)
84 vals.push_back(attr);
85 std::sort(vals.begin(), vals.end(), fieldComp);
86 }
87
88 std::pair<const_iterator, const_iterator> find(uint64_t low,
89 uint64_t high) const {
90 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
91 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
92 }
93
94 bool empty(uint64_t low, uint64_t high) const {
95 auto [b, e] = find(low, high);
96 return b == e;
97 }
98
99private:
100 static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
101 return lhs.getFieldID() < rhs.getFieldID();
102 };
103 static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
104 return lhs.getFieldID() < rhs;
105 };
106 static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
107 return lhs < rhs.getFieldID();
108 };
109
110 V vals;
111};
112
113} // namespace
114
115static hw::InnerSymAttr
116symbolsForFieldIDRange(MLIRContext *ctx,
117 const FieldIDSearch<hw::InnerSymAttr> &syms,
118 uint64_t low, uint64_t high) {
119 auto [b, e] = syms.find(low, high);
120 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
121 if (newSyms.empty())
122 return {};
123 for (auto &sym : newSyms)
124 sym = hw::InnerSymPropertiesAttr::get(
125 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
126 return hw::InnerSymAttr::get(ctx, newSyms);
127}
128
129static AnnotationSet
130annosForFieldIDRange(MLIRContext *ctx,
131 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
132 uint64_t high) {
133 AnnotationSet newAnnos(ctx);
134 auto [b, e] = annos.find(low, high);
135 for (; b != e; ++b)
136 newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
137 return newAnnos;
138}
139
140static LogicalResult
141computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
142 size_t portID, const PortInfo &port, bool isFlip,
143 Twine name, FIRRTLType type, uint64_t fieldID,
144 const FieldIDSearch<hw::InnerSymAttr> &syms,
145 const FieldIDSearch<AnnotationSet> &annos) {
146 auto *ctx = type.getContext();
148 .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
149 // This should be enhanced to be able to handle bundle<all flips of
150 // passive>, or this should be a canonicalizer
151 if (conv != Convention::Scalarized && bundle.isPassive()) {
152 auto lastId = fieldID + bundle.getMaxFieldID();
153 newPorts.push_back(
154 {{StringAttr::get(ctx, name), type,
155 isFlip ? Direction::Out : Direction::In,
156 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
157 annosForFieldIDRange(ctx, annos, fieldID, lastId),
158 port.domains},
159 portID,
160 newPorts.size(),
161 fieldID});
162 } else {
163 for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
164 if (failed(computeLoweringImpl(
165 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
166 name + "_" + elem.name.getValue(), elem.type,
167 fieldID + bundle.getFieldID(idx), syms, annos)))
168 return failure();
169 if (!syms.empty(fieldID, fieldID))
170 return mod.emitError("Port [")
171 << port.name
172 << "] should be subdivided, but cannot be because of "
173 "symbol ["
174 << port.sym.getSymIfExists(fieldID) << "] on a bundle";
175 if (!annos.empty(fieldID, fieldID)) {
176 auto err = mod.emitError("Port [")
177 << port.name
178 << "] should be subdivided, but cannot be because of "
179 "annotations [";
180 auto [b, e] = annos.find(fieldID, fieldID);
181 err << b->getClass() << "(" << b->getFieldID() << ")";
182 b++;
183 for (; b != e; ++b)
184 err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
185 err << "] on a bundle";
186 return err;
187 }
188 }
189 }
190 return success();
191 })
192 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
193 if (conv != Convention::Scalarized &&
194 vector.getElementType().isPassive()) {
195 auto lastId = fieldID + vector.getMaxFieldID();
196 newPorts.push_back(
197 {{StringAttr::get(ctx, name), type,
198 isFlip ? Direction::Out : Direction::In,
199 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
200 annosForFieldIDRange(ctx, annos, fieldID, lastId),
201 port.domains},
202 portID,
203 newPorts.size(),
204 fieldID});
205 } else {
206 for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
207 if (failed(computeLoweringImpl(
208 mod, newPorts, conv, portID, port, isFlip,
209 name + "_" + Twine(i), vector.getElementType(),
210 fieldID + vector.getFieldID(i), syms, annos)))
211 return failure();
212 if (!syms.empty(fieldID, fieldID))
213 return mod.emitError("Port [")
214 << port.name
215 << "] should be subdivided, but cannot be because of "
216 "symbol ["
217 << port.sym.getSymIfExists(fieldID) << "] on a vector";
218 if (!annos.empty(fieldID, fieldID)) {
219 auto err = mod.emitError("Port [")
220 << port.name
221 << "] should be subdivided, but cannot be because of "
222 "annotations [";
223 auto [b, e] = annos.find(fieldID, fieldID);
224 err << b->getClass();
225 ++b;
226 for (; b != e; ++b)
227 err << ", " << b->getClass();
228 err << "] on a vector";
229 return err;
230 }
231 }
232 }
233 return success();
234 })
235 .Default([&](FIRRTLType type) {
236 // Properties and other types wind up here.
237 newPorts.push_back(
238 {{StringAttr::get(ctx, name), type,
239 isFlip ? Direction::Out : Direction::In,
240 symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
241 annosForFieldIDRange(ctx, annos, fieldID, fieldID), port.domains},
242 portID,
243 newPorts.size(),
244 fieldID});
245 return success();
246 });
247}
248
249// compute a new moduletype from an old module type and lowering convention.
250// Also compute a fieldID map from port, fieldID -> port
251static LogicalResult computeLowering(FModuleLike mod, Convention conv,
252 PortConversion &newPorts) {
253 for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
254 if (failed(computeLoweringImpl(
255 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
256 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
257 FieldIDSearch<hw::InnerSymAttr>(port.sym),
258 FieldIDSearch<AnnotationSet>(port.annotations))))
259 return failure();
260 }
261 return success();
262}
263
264static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
265 AttrCache &cache,
266 PortConversion &newPorts) {
267 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
268 if (computeLowering(module, conv, newPorts).failed())
269 return failure();
270 if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
271 Block *body = mod.getBodyBlock();
272 theBuilder.setInsertionPointToStart(body);
273 auto oldNumArgs = body->getNumArguments();
274
275 // Compute the replacement value for old arguments
276 // This creates all the new arguments and produces bounce wires when
277 // necessary
278 SmallVector<Value> bounceWires(oldNumArgs);
279 for (auto &p : newPorts) {
280 auto newArg = body->addArgument(p.type, p.loc);
281 // Get or create a bounce wire for changed ports
282 // For unmodified ports, move the uses to the replacement port
283 if (p.fieldID != 0) {
284 auto &wire = bounceWires[p.portID];
285 if (!wire)
286 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
287 module.getPortNameAttr(p.portID),
288 NameKindEnum::InterestingName)
289 .getResult();
290 } else {
291 bounceWires[p.portID] = newArg;
292 }
293 }
294 // replace old arguments. Somethings get dropped completely, like
295 // zero-length vectors.
296 for (auto idx = 0U; idx < oldNumArgs; ++idx) {
297 if (!bounceWires[idx]) {
298 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
299 module.getPortNameAttr(idx))
300 .getResult();
301 }
302 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
303 }
304
305 // Goodbye old ports, now ResultID in the PortInfo is correct.
306 body->eraseArguments(0, oldNumArgs);
307
308 // Connect the bounce wires to the new arguments
309 for (auto &p : newPorts) {
310 if (isa<BlockArgument>(bounceWires[p.portID]))
311 continue;
312 if (p.isOutput())
314 theBuilder, body->getArgument(p.resultID),
315 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
316 else
318 theBuilder,
319 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
320 body->getArgument(p.resultID));
321 }
322 }
323
324 SmallVector<NamedAttribute, 8> newModuleAttrs;
325
326 // Copy over any attributes that weren't original argument attributes.
327 for (auto attr : module->getAttrDictionary())
328 // Drop old "portNames", directions, and argument attributes. These are
329 // handled differently below.
330 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
331 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
332 attr.getName() != "portSymbols" && attr.getName() != "portLocations" &&
333 attr.getName() != "internalPaths")
334 newModuleAttrs.push_back(attr);
335
336 SmallVector<Direction> newPortDirections;
337 SmallVector<Attribute> newPortNames;
338 SmallVector<Attribute> newPortTypes;
339 SmallVector<Attribute> newPortSyms;
340 SmallVector<Attribute> newPortLocations;
341 SmallVector<Attribute, 8> newPortAnnotations;
342 SmallVector<Attribute> newPortDomains;
343 SmallVector<Attribute> newInternalPaths;
344
345 bool hasInternalPaths = false;
346 auto internalPaths = module->getAttrOfType<ArrayAttr>("internalPaths");
347 for (auto p : newPorts) {
348 newPortTypes.push_back(TypeAttr::get(p.type));
349 newPortNames.push_back(p.name);
350 newPortDirections.push_back(p.direction);
351 newPortSyms.push_back(p.sym);
352 newPortLocations.push_back(p.loc);
353 newPortAnnotations.push_back(p.annotations.getArrayAttr());
354 newPortDomains.push_back(p.domains ? p.domains : cache.aEmpty);
355 if (internalPaths) {
356 auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
357 newInternalPaths.push_back(internalPath);
358 if (internalPath.getPath())
359 hasInternalPaths = true;
360 }
361 }
362
363 newModuleAttrs.push_back(NamedAttribute(
364 cache.sPortDirections,
365 direction::packAttribute(module.getContext(), newPortDirections)));
366
367 newModuleAttrs.push_back(
368 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
369
370 newModuleAttrs.push_back(
371 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
372
373 newModuleAttrs.push_back(NamedAttribute(
374 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
375
376 newModuleAttrs.push_back(NamedAttribute(
377 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
378
379 newModuleAttrs.push_back(NamedAttribute(
380 cache.sPortDomains, theBuilder.getArrayAttr(newPortDomains)));
381
382 assert(newInternalPaths.empty() ||
383 newInternalPaths.size() == newPorts.size());
384 if (hasInternalPaths) {
385 newModuleAttrs.emplace_back(cache.sInternalPaths,
386 theBuilder.getArrayAttr(newInternalPaths));
387 }
388
389 // Update the module's attributes.
390 module->setAttrs(newModuleAttrs);
391 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
392 module.setPortSymbols(newPortSyms);
393 return success();
394}
395
396static void lowerModuleBody(FModuleOp mod,
397 const DenseMap<StringAttr, PortConversion> &ports) {
398 mod->walk([&](InstanceOp inst) -> void {
399 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
400 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
401
402 // Fix up the Instance
403 SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
404 for (auto p : modPorts) {
405 p.sym = {};
406 // Might need to partially copy stuff from the old instance.
407 p.annotations = AnnotationSet{mod.getContext()};
408 instPorts.push_back(p);
409 }
410 auto annos = inst.getAnnotations();
411 auto newOp = InstanceOp::create(
412 theBuilder, instPorts, inst.getModuleName(), inst.getName(),
413 inst.getNameKind(), annos.getValue(), inst.getLayers(),
414 inst.getLowerToBind(), inst.getDoNotPrint(), inst.getInnerSymAttr());
415
416 auto oldDict = inst->getDiscardableAttrDictionary();
417 auto newDict = newOp->getDiscardableAttrDictionary();
418 auto oldNames = inst.getPortNamesAttr();
419 SmallVector<NamedAttribute> newAttrs;
420 for (auto na : oldDict)
421 if (!newDict.contains(na.getName()))
422 newOp->setDiscardableAttr(na.getName(), na.getValue());
423
424 // Connect up the old instance users to the new instance
425 SmallVector<WireOp> bounce(inst.getNumResults());
426 for (auto p : modPorts) {
427 // No change? No bounce wire.
428 if (p.fieldID == 0) {
429 inst.getResult(p.portID).replaceAllUsesWith(
430 newOp.getResult(p.resultID));
431 continue;
432 }
433 if (!bounce[p.portID]) {
434 bounce[p.portID] = WireOp::create(
435 theBuilder, inst.getResult(p.portID).getType(),
436 theBuilder.getStringAttr(
437 inst.getName() + "." +
438 cast<StringAttr>(oldNames[p.portID]).getValue()));
439 inst.getResult(p.portID).replaceAllUsesWith(
440 bounce[p.portID].getResult());
441 }
442 // Connect up the Instance to the bounce wires
443 if (p.isInput())
444 emitConnect(theBuilder, newOp.getResult(p.resultID),
445 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
446 p.fieldID));
447 else
448 emitConnect(theBuilder,
449 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
450 p.fieldID),
451 newOp.getResult(p.resultID));
452 }
453 // Zero Width ports may have dangling connects since they are not preserved
454 // and do not have bounce wires.
455 for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
456 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
457 use->erase();
458 }
459 inst->erase();
460 return;
461 });
462}
463
464//===----------------------------------------------------------------------===//
465// Pass Infrastructure
466//===----------------------------------------------------------------------===//
467
468namespace {
469struct LowerSignaturesPass
470 : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
471 void runOnOperation() override;
472};
473} // end anonymous namespace
474
475// This is the main entrypoint for the lowering pass.
476void LowerSignaturesPass::runOnOperation() {
477 LLVM_DEBUG(debugPassHeader(this) << "\n");
478 // Cached attr
479 AttrCache cache(&getContext());
480
481 DenseMap<StringAttr, PortConversion> portMap;
482 auto circuit = getOperation();
483
484 for (auto mod : circuit.getOps<FModuleLike>()) {
485 if (lowerModuleSignature(mod, mod.getConvention(), cache,
486 portMap[mod.getNameAttr()])
487 .failed())
488 return signalPassFailure();
489 }
490 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
491 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
492}
assert(baseType &&"element must be base type")
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition Debug.cpp:31
This holds the name and type that describes the module's ports.