Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
LowerSignatures.cpp
Go to the documentation of this file.
1//===- LowerSignatures.cpp - Lower Module Signatures ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the LowerSignatures pass. This pass replaces aggregate
10// types with expanded values in module arguments as specified by the ABI
11// information.
12//
13//===----------------------------------------------------------------------===//
14
21#include "circt/Support/Debug.h"
22#include "mlir/IR/Threading.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "firrtl-lower-signatures"
27
28namespace circt {
29namespace firrtl {
30#define GEN_PASS_DEF_LOWERSIGNATURES
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
32} // namespace firrtl
33} // namespace circt
34
35using namespace circt;
36using namespace firrtl;
37
38//===----------------------------------------------------------------------===//
39// Module Type Lowering
40//===----------------------------------------------------------------------===//
41namespace {
42
43struct AttrCache {
44 AttrCache(MLIRContext *context) {
45 nameAttr = StringAttr::get(context, "name");
46 sPortDirections = StringAttr::get(context, "portDirections");
47 sPortNames = StringAttr::get(context, "portNames");
48 sPortTypes = StringAttr::get(context, "portTypes");
49 sPortLocations = StringAttr::get(context, "portLocations");
50 sPortAnnotations = StringAttr::get(context, "portAnnotations");
51 sInternalPaths = StringAttr::get(context, "internalPaths");
52 }
53 AttrCache(const AttrCache &) = default;
54
55 StringAttr nameAttr, sPortDirections, sPortNames, sPortTypes, sPortLocations,
56 sPortAnnotations, sInternalPaths;
57};
58
59struct FieldMapEntry : public PortInfo {
60 size_t portID;
61 size_t resultID;
62 size_t fieldID;
63};
64
65using PortConversion = SmallVector<FieldMapEntry>;
66
67template <typename T>
68class FieldIDSearch {
69 using E = typename T::ElementType;
70 using V = SmallVector<E>;
71
72public:
73 using const_iterator = typename V::const_iterator;
74
75 template <typename Container>
76 FieldIDSearch(const Container &src) {
77 if constexpr (std::is_convertible_v<Container, Attribute>)
78 if (!src)
79 return;
80 for (auto attr : src)
81 vals.push_back(attr);
82 std::sort(vals.begin(), vals.end(), fieldComp);
83 }
84
85 std::pair<const_iterator, const_iterator> find(uint64_t low,
86 uint64_t high) const {
87 return {std::lower_bound(vals.begin(), vals.end(), low, fieldCompInt2),
88 std::upper_bound(vals.begin(), vals.end(), high, fieldCompInt1)};
89 }
90
91 bool empty(uint64_t low, uint64_t high) const {
92 auto [b, e] = find(low, high);
93 return b == e;
94 }
95
96private:
97 static constexpr auto fieldComp = [](const E &lhs, const E &rhs) {
98 return lhs.getFieldID() < rhs.getFieldID();
99 };
100 static constexpr auto fieldCompInt2 = [](const E &lhs, uint64_t rhs) {
101 return lhs.getFieldID() < rhs;
102 };
103 static constexpr auto fieldCompInt1 = [](uint64_t lhs, const E &rhs) {
104 return lhs < rhs.getFieldID();
105 };
106
107 V vals;
108};
109
110} // namespace
111
112static hw::InnerSymAttr
113symbolsForFieldIDRange(MLIRContext *ctx,
114 const FieldIDSearch<hw::InnerSymAttr> &syms,
115 uint64_t low, uint64_t high) {
116 auto [b, e] = syms.find(low, high);
117 SmallVector<hw::InnerSymPropertiesAttr, 4> newSyms(b, e);
118 if (newSyms.empty())
119 return {};
120 for (auto &sym : newSyms)
121 sym = hw::InnerSymPropertiesAttr::get(
122 ctx, sym.getName(), sym.getFieldID() - low, sym.getSymVisibility());
123 return hw::InnerSymAttr::get(ctx, newSyms);
124}
125
126static AnnotationSet
127annosForFieldIDRange(MLIRContext *ctx,
128 const FieldIDSearch<AnnotationSet> &annos, uint64_t low,
129 uint64_t high) {
130 AnnotationSet newAnnos(ctx);
131 auto [b, e] = annos.find(low, high);
132 for (; b != e; ++b)
133 newAnnos.addAnnotations(Annotation(*b, b->getFieldID() - low));
134 return newAnnos;
135}
136
137static LogicalResult
138computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv,
139 size_t portID, const PortInfo &port, bool isFlip,
140 Twine name, FIRRTLType type, uint64_t fieldID,
141 const FieldIDSearch<hw::InnerSymAttr> &syms,
142 const FieldIDSearch<AnnotationSet> &annos) {
143 auto *ctx = type.getContext();
145 .Case<BundleType>([&](BundleType bundle) -> LogicalResult {
146 // This should be enhanced to be able to handle bundle<all flips of
147 // passive>, or this should be a canonicalizer
148 if (conv != Convention::Scalarized && bundle.isPassive()) {
149 auto lastId = fieldID + bundle.getMaxFieldID();
150 newPorts.push_back(
151 {{StringAttr::get(ctx, name), type,
152 isFlip ? Direction::Out : Direction::In,
153 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
154 annosForFieldIDRange(ctx, annos, fieldID, lastId)},
155 portID,
156 newPorts.size(),
157 fieldID});
158 } else {
159 for (auto [idx, elem] : llvm::enumerate(bundle.getElements())) {
160 if (failed(computeLoweringImpl(
161 mod, newPorts, conv, portID, port, isFlip ^ elem.isFlip,
162 name + "_" + elem.name.getValue(), elem.type,
163 fieldID + bundle.getFieldID(idx), syms, annos)))
164 return failure();
165 if (!syms.empty(fieldID, fieldID))
166 return mod.emitError("Port [")
167 << port.name
168 << "] should be subdivided, but cannot be because of "
169 "symbol ["
170 << port.sym.getSymIfExists(fieldID) << "] on a bundle";
171 if (!annos.empty(fieldID, fieldID)) {
172 auto err = mod.emitError("Port [")
173 << port.name
174 << "] should be subdivided, but cannot be because of "
175 "annotations [";
176 auto [b, e] = annos.find(fieldID, fieldID);
177 err << b->getClass() << "(" << b->getFieldID() << ")";
178 b++;
179 for (; b != e; ++b)
180 err << ", " << b->getClass() << "(" << b->getFieldID() << ")";
181 err << "] on a bundle";
182 return err;
183 }
184 }
185 }
186 return success();
187 })
188 .Case<FVectorType>([&](FVectorType vector) -> LogicalResult {
189 if (conv != Convention::Scalarized &&
190 vector.getElementType().isPassive()) {
191 auto lastId = fieldID + vector.getMaxFieldID();
192 newPorts.push_back(
193 {{StringAttr::get(ctx, name), type,
194 isFlip ? Direction::Out : Direction::In,
195 symbolsForFieldIDRange(ctx, syms, fieldID, lastId), port.loc,
196 annosForFieldIDRange(ctx, annos, fieldID, lastId)},
197 portID,
198 newPorts.size(),
199 fieldID});
200 } else {
201 for (size_t i = 0, e = vector.getNumElements(); i < e; ++i) {
202 if (failed(computeLoweringImpl(
203 mod, newPorts, conv, portID, port, isFlip,
204 name + "_" + Twine(i), vector.getElementType(),
205 fieldID + vector.getFieldID(i), syms, annos)))
206 return failure();
207 if (!syms.empty(fieldID, fieldID))
208 return mod.emitError("Port [")
209 << port.name
210 << "] should be subdivided, but cannot be because of "
211 "symbol ["
212 << port.sym.getSymIfExists(fieldID) << "] on a vector";
213 if (!annos.empty(fieldID, fieldID)) {
214 auto err = mod.emitError("Port [")
215 << port.name
216 << "] should be subdivided, but cannot be because of "
217 "annotations [";
218 auto [b, e] = annos.find(fieldID, fieldID);
219 err << b->getClass();
220 ++b;
221 for (; b != e; ++b)
222 err << ", " << b->getClass();
223 err << "] on a vector";
224 return err;
225 }
226 }
227 }
228 return success();
229 })
230 .Default([&](FIRRTLType type) {
231 // Properties and other types wind up here.
232 newPorts.push_back(
233 {{StringAttr::get(ctx, name), type,
234 isFlip ? Direction::Out : Direction::In,
235 symbolsForFieldIDRange(ctx, syms, fieldID, fieldID), port.loc,
236 annosForFieldIDRange(ctx, annos, fieldID, fieldID)},
237 portID,
238 newPorts.size(),
239 fieldID});
240 return success();
241 });
242}
243
244// compute a new moduletype from an old module type and lowering convention.
245// Also compute a fieldID map from port, fieldID -> port
246static LogicalResult computeLowering(FModuleLike mod, Convention conv,
247 PortConversion &newPorts) {
248 for (auto [idx, port] : llvm::enumerate(mod.getPorts())) {
249 if (failed(computeLoweringImpl(
250 mod, newPorts, conv, idx, port, port.direction == Direction::Out,
251 port.name.getValue(), type_cast<FIRRTLType>(port.type), 0,
252 FieldIDSearch<hw::InnerSymAttr>(port.sym),
253 FieldIDSearch<AnnotationSet>(port.annotations))))
254 return failure();
255 }
256 return success();
257}
258
259static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv,
260 AttrCache &cache,
261 PortConversion &newPorts) {
262 ImplicitLocOpBuilder theBuilder(module.getLoc(), module.getContext());
263 if (computeLowering(module, conv, newPorts).failed())
264 return failure();
265 if (auto mod = dyn_cast<FModuleOp>(module.getOperation())) {
266 Block *body = mod.getBodyBlock();
267 theBuilder.setInsertionPointToStart(body);
268 auto oldNumArgs = body->getNumArguments();
269
270 // Compute the replacement value for old arguments
271 // This creates all the new arguments and produces bounce wires when
272 // necessary
273 SmallVector<Value> bounceWires(oldNumArgs);
274 for (auto &p : newPorts) {
275 auto newArg = body->addArgument(p.type, p.loc);
276 // Get or create a bounce wire for changed ports
277 // For unmodified ports, move the uses to the replacement port
278 if (p.fieldID != 0) {
279 auto &wire = bounceWires[p.portID];
280 if (!wire)
281 wire = WireOp::create(theBuilder, module.getPortType(p.portID),
282 module.getPortNameAttr(p.portID),
283 NameKindEnum::InterestingName)
284 .getResult();
285 } else {
286 bounceWires[p.portID] = newArg;
287 }
288 }
289 // replace old arguments. Somethings get dropped completely, like
290 // zero-length vectors.
291 for (auto idx = 0U; idx < oldNumArgs; ++idx) {
292 if (!bounceWires[idx]) {
293 bounceWires[idx] = WireOp::create(theBuilder, module.getPortType(idx),
294 module.getPortNameAttr(idx))
295 .getResult();
296 }
297 body->getArgument(idx).replaceAllUsesWith(bounceWires[idx]);
298 }
299
300 // Goodbye old ports, now ResultID in the PortInfo is correct.
301 body->eraseArguments(0, oldNumArgs);
302
303 // Connect the bounce wires to the new arguments
304 for (auto &p : newPorts) {
305 if (isa<BlockArgument>(bounceWires[p.portID]))
306 continue;
307 if (p.isOutput())
309 theBuilder, body->getArgument(p.resultID),
310 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID));
311 else
313 theBuilder,
314 getValueByFieldID(theBuilder, bounceWires[p.portID], p.fieldID),
315 body->getArgument(p.resultID));
316 }
317 }
318
319 SmallVector<NamedAttribute, 8> newModuleAttrs;
320
321 // Copy over any attributes that weren't original argument attributes.
322 for (auto attr : module->getAttrDictionary())
323 // Drop old "portNames", directions, and argument attributes. These are
324 // handled differently below.
325 if (attr.getName() != "portNames" && attr.getName() != "portDirections" &&
326 attr.getName() != "portTypes" && attr.getName() != "portAnnotations" &&
327 attr.getName() != "portSymbols" && attr.getName() != "portLocations" &&
328 attr.getName() != "internalPaths")
329 newModuleAttrs.push_back(attr);
330
331 SmallVector<Direction> newPortDirections;
332 SmallVector<Attribute> newPortNames;
333 SmallVector<Attribute> newPortTypes;
334 SmallVector<Attribute> newPortSyms;
335 SmallVector<Attribute> newPortLocations;
336 SmallVector<Attribute, 8> newPortAnnotations;
337 SmallVector<Attribute> newInternalPaths;
338
339 bool hasInternalPaths = false;
340 auto internalPaths = module->getAttrOfType<ArrayAttr>("internalPaths");
341 for (auto p : newPorts) {
342 newPortTypes.push_back(TypeAttr::get(p.type));
343 newPortNames.push_back(p.name);
344 newPortDirections.push_back(p.direction);
345 newPortSyms.push_back(p.sym);
346 newPortLocations.push_back(p.loc);
347 newPortAnnotations.push_back(p.annotations.getArrayAttr());
348 if (internalPaths) {
349 auto internalPath = cast<InternalPathAttr>(internalPaths[p.portID]);
350 newInternalPaths.push_back(internalPath);
351 if (internalPath.getPath())
352 hasInternalPaths = true;
353 }
354 }
355
356 newModuleAttrs.push_back(NamedAttribute(
357 cache.sPortDirections,
358 direction::packAttribute(module.getContext(), newPortDirections)));
359
360 newModuleAttrs.push_back(
361 NamedAttribute(cache.sPortNames, theBuilder.getArrayAttr(newPortNames)));
362
363 newModuleAttrs.push_back(
364 NamedAttribute(cache.sPortTypes, theBuilder.getArrayAttr(newPortTypes)));
365
366 newModuleAttrs.push_back(NamedAttribute(
367 cache.sPortLocations, theBuilder.getArrayAttr(newPortLocations)));
368
369 newModuleAttrs.push_back(NamedAttribute(
370 cache.sPortAnnotations, theBuilder.getArrayAttr(newPortAnnotations)));
371
372 assert(newInternalPaths.empty() ||
373 newInternalPaths.size() == newPorts.size());
374 if (hasInternalPaths) {
375 newModuleAttrs.emplace_back(cache.sInternalPaths,
376 theBuilder.getArrayAttr(newInternalPaths));
377 }
378
379 // Update the module's attributes.
380 module->setAttrs(newModuleAttrs);
381 FModuleLike::fixupPortSymsArray(newPortSyms, theBuilder.getContext());
382 module.setPortSymbols(newPortSyms);
383 return success();
384}
385
386static void lowerModuleBody(FModuleOp mod,
387 const DenseMap<StringAttr, PortConversion> &ports) {
388 mod->walk([&](InstanceOp inst) -> void {
389 ImplicitLocOpBuilder theBuilder(inst.getLoc(), inst);
390 const auto &modPorts = ports.at(inst.getModuleNameAttr().getAttr());
391
392 // Fix up the Instance
393 SmallVector<PortInfo> instPorts; // Oh I wish ArrayRef was polymorphic.
394 for (auto p : modPorts) {
395 p.sym = {};
396 // Might need to partially copy stuff from the old instance.
397 p.annotations = AnnotationSet{mod.getContext()};
398 instPorts.push_back(p);
399 }
400 auto annos = inst.getAnnotations();
401 auto newOp = InstanceOp::create(
402 theBuilder, instPorts, inst.getModuleName(), inst.getName(),
403 inst.getNameKind(), annos.getValue(), inst.getLayers(),
404 inst.getLowerToBind(), inst.getDoNotPrint(), inst.getInnerSymAttr());
405
406 auto oldDict = inst->getDiscardableAttrDictionary();
407 auto newDict = newOp->getDiscardableAttrDictionary();
408 auto oldNames = inst.getPortNamesAttr();
409 SmallVector<NamedAttribute> newAttrs;
410 for (auto na : oldDict)
411 if (!newDict.contains(na.getName()))
412 newOp->setDiscardableAttr(na.getName(), na.getValue());
413
414 // Connect up the old instance users to the new instance
415 SmallVector<WireOp> bounce(inst.getNumResults());
416 for (auto p : modPorts) {
417 // No change? No bounce wire.
418 if (p.fieldID == 0) {
419 inst.getResult(p.portID).replaceAllUsesWith(
420 newOp.getResult(p.resultID));
421 continue;
422 }
423 if (!bounce[p.portID]) {
424 bounce[p.portID] = WireOp::create(
425 theBuilder, inst.getResult(p.portID).getType(),
426 theBuilder.getStringAttr(
427 inst.getName() + "." +
428 cast<StringAttr>(oldNames[p.portID]).getValue()));
429 inst.getResult(p.portID).replaceAllUsesWith(
430 bounce[p.portID].getResult());
431 }
432 // Connect up the Instance to the bounce wires
433 if (p.isInput())
434 emitConnect(theBuilder, newOp.getResult(p.resultID),
435 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
436 p.fieldID));
437 else
438 emitConnect(theBuilder,
439 getValueByFieldID(theBuilder, bounce[p.portID].getResult(),
440 p.fieldID),
441 newOp.getResult(p.resultID));
442 }
443 // Zero Width ports may have dangling connects since they are not preserved
444 // and do not have bounce wires.
445 for (auto *use : llvm::make_early_inc_range(inst->getUsers())) {
446 assert(isa<MatchingConnectOp>(use) || isa<ConnectOp>(use));
447 use->erase();
448 }
449 inst->erase();
450 return;
451 });
452}
453
454//===----------------------------------------------------------------------===//
455// Pass Infrastructure
456//===----------------------------------------------------------------------===//
457
458namespace {
459struct LowerSignaturesPass
460 : public circt::firrtl::impl::LowerSignaturesBase<LowerSignaturesPass> {
461 void runOnOperation() override;
462};
463} // end anonymous namespace
464
465// This is the main entrypoint for the lowering pass.
466void LowerSignaturesPass::runOnOperation() {
467 LLVM_DEBUG(debugPassHeader(this) << "\n");
468 // Cached attr
469 AttrCache cache(&getContext());
470
471 DenseMap<StringAttr, PortConversion> portMap;
472 auto circuit = getOperation();
473
474 for (auto mod : circuit.getOps<FModuleLike>()) {
475 if (lowerModuleSignature(mod, mod.getConvention(), cache,
476 portMap[mod.getNameAttr()])
477 .failed())
478 return signalPassFailure();
479 }
480 parallelForEach(&getContext(), circuit.getOps<FModuleOp>(),
481 [&portMap](FModuleOp mod) { lowerModuleBody(mod, portMap); });
482}
assert(baseType &&"element must be base type")
static LogicalResult computeLowering(FModuleLike mod, Convention conv, PortConversion &newPorts)
static AnnotationSet annosForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< AnnotationSet > &annos, uint64_t low, uint64_t high)
static LogicalResult lowerModuleSignature(FModuleLike module, Convention conv, AttrCache &cache, PortConversion &newPorts)
static LogicalResult computeLoweringImpl(FModuleLike mod, PortConversion &newPorts, Convention conv, size_t portID, const PortInfo &port, bool isFlip, Twine name, FIRRTLType type, uint64_t fieldID, const FieldIDSearch< hw::InnerSymAttr > &syms, const FieldIDSearch< AnnotationSet > &annos)
static void lowerModuleBody(FModuleOp mod, const DenseMap< StringAttr, PortConversion > &ports)
static hw::InnerSymAttr symbolsForFieldIDRange(MLIRContext *ctx, const FieldIDSearch< hw::InnerSymAttr > &syms, uint64_t low, uint64_t high)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
MLIRContext * getContext() const
Return the MLIRContext corresponding to this AnnotationSet.
void addAnnotations(ArrayRef< Annotation > annotations)
Add more annotations to this annotation set.
This class provides a read-only projection of an annotation.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Base class for the port conversion of a particular port.
mlir::DenseBoolArrayAttr packAttribute(MLIRContext *context, ArrayRef< Direction > directions)
Return a DenseBoolArrayAttr containing the packed representation of an array of directions.
Value getValueByFieldID(ImplicitLocOpBuilder builder, Value value, unsigned fieldID)
This gets the value targeted by a field id.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition Debug.cpp:31
This holds the name and type that describes the module's ports.