21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/OperationSupport.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
28 using namespace firrtl;
29 using namespace chirrtl;
/// Pass object that lowers CHIRRTL memory operations. It declares the visitor
/// entry points for the CHIRRTL/FIRRTL ops it handles and owns the caches and
/// value maps shared between those visitors.
/// NOTE(review): this listing is incomplete — the remaining base classes, the
/// `WDataInfo` declaration and the closing brace of the struct are not visible.
32 struct LowerCHIRRTLPass :
public LowerCHIRRTLPassBase<LowerCHIRRTLPass>,
// Visitor hooks for the CHIRRTL memory declarations, ports and port accesses.
40 void visitCHIRRTL(CombMemOp op);
41 void visitCHIRRTL(SeqMemOp op);
42 void visitCHIRRTL(MemoryPortOp op);
43 void visitCHIRRTL(MemoryDebugPortOp op);
44 void visitCHIRRTL(MemoryPortAccessOp op);
// Visitor hooks for aggregate-access expressions.
45 void visitExpr(SubaccessOp op);
46 void visitExpr(SubfieldOp op);
47 void visitExpr(SubindexOp op);
// Visitor hooks for connect statements.
48 void visitStmt(ConnectOp op);
49 void visitStmt(StrictConnectOp op);
// Fallback used for every operation without a dedicated visitor.
50 void visitUnhandledOp(Operation *op);
// Operations that are not CHIRRTL ops are re-dispatched through
// `dispatchVisitor`.
53 void visitInvalidCHIRRTL(Operation *op) { dispatchVisitor(op); }
// CHIRRTL ops without a dedicated visitor go to the generic fallback.
54 void visitUnhandledCHIRRTL(Operation *op) { visitUnhandledOp(op); }
/// Looks up the `constCache` slot for `c`. The visible code fills the slot
/// with a 1-bit `ConstantOp` (type `u1Type`, value `c`) created at the start
/// of the module body.
/// NOTE(review): the cache-miss guard and the return statement are presumably
/// on lines missing from this listing — confirm against the full file.
58 Value getConst(
unsigned c) {
// Reference into the cache so the assignment below populates the entry.
59 auto &
value = constCache[c];
61 auto module = getOperation();
62 auto builder = OpBuilder::atBlockBegin(module.getBodyBlock());
64 value =
builder.create<ConstantOp>(module.getLoc(), u1Type, APInt(1, c));
/// Helper that works with `invalidCache` for the type of `value`.
79 void emitInvalid(ImplicitLocOpBuilder &
builder, Value
value);
/// Computes the direction of a CHIRRTL memory port from the uses of its data.
81 MemDirAttr inferMemoryPortKind(MemoryPortOp memPort);
/// Shared lowering for `CombMemOp` and `SeqMemOp` (asserted in the definition).
83 void replaceMem(Operation *op, StringRef name,
bool isSequential, RUWAttr ruw,
84 ArrayAttr annotations);
/// Re-creates an aggregate-access op of type `OpType` on the replacement
/// values recorded in `rdataValues` / `wdataValues`.
86 template <
typename OpType,
typename... T>
87 void cloneSubindexOpForMemory(OpType op, Value input, T... operands);
/// Pass entry point.
89 void runOnOperation()
override;
// Cache used by `getConst`, keyed by the constant's numeric value.
92 DenseMap<unsigned, Value> constCache;
// Cache used by `emitInvalid`, keyed by type.
93 DenseMap<Type, Value> invalidCache;
// Operations queued by the visitors; erased at the end of `runOnOperation`.
96 SmallVector<Operation *> opsToDelete;
// Direction recorded per defining op by `inferMemoryPortKind`; consulted by
// `cloneSubindexOpForMemory`.
101 DenseMap<Operation *, MemDirAttr> subfieldDirs;
// Maps an original value to the value that replaces it when it is read.
106 DenseMap<Value, Value> rdataValues;
// Maps an original value to the record (`data`, `mask`, `mode` fields are
// used in this file) that replaces it when it is written.
118 DenseMap<Value, WDataInfo> wdataValues;
126 llvm::function_ref<
void(Value)> func) {
127 auto type =
value.getType();
128 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
129 for (
size_t i = 0, e = bundleType.getNumElements(); i < e; ++i)
131 }
else if (
auto vectorType = type_dyn_cast<FVectorType>(type)) {
132 for (
size_t i = 0, e = vectorType.getNumElements(); i != e; ++i)
/// Works with an invalid value of `value`'s type, using `invalidCache` so the
/// `InvalidValueOp` for a given type is stored once. The visible code creates
/// the op at the start of the module body, at the module's location.
/// NOTE(review): the cache-miss guard (which would also scope the local
/// `builder` that shadows the parameter) and whatever consumes `invalid`
/// afterwards are not visible in this listing — confirm against the full file.
150 void LowerCHIRRTLPass::emitInvalid(ImplicitLocOpBuilder &
builder, Value
value) {
151 auto type =
value.getType();
// Reference into the cache so the assignment below populates the entry.
152 auto &invalid = invalidCache[type];
154 auto builder = OpBuilder::atBlockBegin(getOperation().getBodyBlock());
155 invalid =
builder.create<InvalidValueOp>(getOperation().getLoc(), type);
164 case MemDirAttr::Read:
165 return MemOp::PortKind::Read;
166 case MemDirAttr::Write:
167 return MemOp::PortKind::Write;
168 case MemDirAttr::ReadWrite:
169 return MemOp::PortKind::ReadWrite;
172 "Unhandled MemDirAttr, was the port direction not inferred?");
/// Computes a `MemDirAttr` for `memPort` by walking the use chains of its data
/// value with an explicit stack, and records the direction computed for each
/// visited value's defining op in `subfieldDirs`.
/// NOTE(review): several lines (some `else`/`continue`/closing-brace lines,
/// the stack pop and the final return) are missing from this listing; the
/// branch structure described below is partly presumed.
187 MemDirAttr LowerCHIRRTLPass::inferMemoryPortKind(MemoryPortOp memPort) {
// One frame of the traversal: the value being examined, the position in its
// use list, and the direction accumulated for it so far.
194 struct StackElement {
195 StackElement(Value
value, Value::use_iterator iterator, MemDirAttr mode)
196 :
value(
value), iterator(iterator), mode(mode) {}
198 Value::use_iterator iterator;
// Seed the traversal with the port's data value and its declared direction.
202 SmallVector<StackElement> stack;
203 stack.emplace_back(memPort.getData(), memPort.getData().use_begin(),
204 memPort.getDirection());
// `mode` carries the direction of the most recently processed frame.
205 MemDirAttr mode = MemDirAttr::Infer;
207 while (!stack.empty()) {
208 auto *iter = &stack.back().iterator;
209 auto end = stack.back().value.use_end();
// Merge the carried direction into the frame now on top of the stack.
210 stack.back().mode |= mode;
212 while (*iter != end) {
213 auto &element = stack.back();
214 auto &use = *(*iter);
215 auto *user = use.getOwner();
// Subindex/subfield of the value: push a frame for the result and restart
// the inner loop state on it with a fresh `Infer` direction.
217 if (isa<SubindexOp, SubfieldOp>(user)) {
219 auto output = user->getResult(0);
220 stack.emplace_back(output, output.use_begin(), MemDirAttr::Infer);
221 mode = MemDirAttr::Infer;
222 iter = &stack.back().iterator;
223 end = output.use_end();
226 if (
auto subaccessOp = dyn_cast<SubaccessOp>(user)) {
// When the value is the subaccess input, descend into the result just like
// the subindex/subfield case above.
230 auto input = subaccessOp.getInput();
231 if (use.get() == input) {
232 auto output = subaccessOp.getResult();
233 stack.emplace_back(output, output.use_begin(), MemDirAttr::Infer);
234 mode = MemDirAttr::Infer;
235 iter = &stack.back().iterator;
236 end = output.use_end();
// Presumably reached when the value is used as the subaccess index rather
// than its input: that use counts as a read.
240 element.mode |= MemDirAttr::Read;
241 }
else if (
auto connectOp = dyn_cast<ConnectOp>(user)) {
// Connect destination => write. The `Read` line below is presumably the
// else branch (value used as the connect source).
242 if (use.get() == connectOp.getDest()) {
243 element.mode |= MemDirAttr::Write;
245 element.mode |= MemDirAttr::Read;
247 }
else if (
auto connectOp = dyn_cast<StrictConnectOp>(user)) {
// Same destination/source rule for strict connects.
248 if (use.get() == connectOp.getDest()) {
249 element.mode |= MemDirAttr::Write;
251 element.mode |= MemDirAttr::Read;
// Presumably the catch-all for every other kind of user: treated as a read.
255 element.mode |= MemDirAttr::Read;
// Uses of the top frame are exhausted: carry its direction forward and record
// it for the op that defines the frame's value.
258 mode = stack.back().mode;
263 subfieldDirs[stack.back().value.getDefiningOp()] = mode;
/// Replaces a CHIRRTL memory (`CombMemOp` or `SeqMemOp`, asserted below) with
/// a `MemOp` that has one result per port user of the original memory, and
/// records in `rdataValues` / `wdataValues` how the old port data values map
/// onto fields of the new memory ports.
///
/// \param cmem          the CHIRRTL memory operation being replaced.
/// \param name          name passed through to the created `MemOp`.
/// \param isSequential  selects a read latency of 1 (true) or 0 (false) and
///                      takes part in the enable-inference decision.
/// \param ruw           passed through to the created `MemOp`.
/// \param annotations   passed through to the created `MemOp`.
///
/// NOTE(review): this listing is missing many lines (the `PortInfo`
/// declaration header, several `else`/`continue`/closing-brace lines, the
/// sort comparator header, and some connects); comments on those spots are
/// hedged accordingly.
270 void LowerCHIRRTLPass::replaceMem(Operation *cmem, StringRef name,
271 bool isSequential, RUWAttr ruw,
272 ArrayAttr annotations) {
273 assert(isa<CombMemOp>(cmem) || isa<SeqMemOp>(cmem));
// The original memory op is queued for erasure.
277 opsToDelete.push_back(cmem);
// Depth and element type come from the memory's CMemoryType result.
280 auto cmemType = type_cast<CMemoryType>(cmem->getResult(0).getType());
281 auto depth = cmemType.getNumElements();
282 auto type = cmemType.getElementType();
// Presumably fields of the local `PortInfo` record (its header is not visible
// here); later code uses `name`, `type`, `annotations`, `portKind`, `cmemPort`.
288 Attribute annotations;
289 MemOp::PortKind portKind;
// Collect one PortInfo per user of the memory.
292 SmallVector<PortInfo, 4> ports;
293 for (
auto *user : cmem->getUsers()) {
294 MemOp::PortKind portKind;
297 if (
auto cmemoryPort = dyn_cast<MemoryPortOp>(user)) {
// Regular port: the direction is inferred from how the port data is used.
299 auto portDirection = inferMemoryPortKind(cmemoryPort);
// NOTE(review): the handling of a still-`Infer` direction and the conversion
// to `portKind` are on lines missing from this listing.
304 if (portDirection == MemDirAttr::Infer)
307 portName = cmemoryPort.getNameAttr();
308 portAnnos = cmemoryPort.getAnnotationsAttr();
309 }
else if (
auto dPort = dyn_cast<MemoryDebugPortOp>(user)) {
// Debug port: the kind is fixed to Debug.
310 portKind = MemOp::PortKind::Debug;
311 portName = dPort.getNameAttr();
312 portAnnos = dPort.getAnnotationsAttr();
// Presumably the fallback branch for any other kind of user.
314 user->emitOpError(
"unhandled user of chirrtl memory");
319 ports.push_back({portName, MemOp::getTypeForPort(depth, type, portKind),
320 portAnnos, portKind, user});
// Order the ports by comparing their names.
330 llvm::array_pod_sort(ports.begin(), ports.end(),
332 return lhs->name.getValue().compare(
333 rhs->name.getValue());
// Flatten the sorted ports into the parallel arrays the MemOp builder takes.
336 SmallVector<Attribute, 4> resultNames;
337 SmallVector<Type, 4> resultTypes;
338 SmallVector<Attribute, 4> portAnnotations;
339 for (
auto port : ports) {
340 resultNames.push_back(port.name);
341 resultTypes.push_back(port.type);
342 portAnnotations.push_back(port.annotations);
// Read latency is 1 only for sequential memories; write latency is always 1.
347 auto readLatency = isSequential ? 1 : 0;
348 auto writeLatency = 1;
// Build the replacement MemOp with a builder anchored at the original memory,
// carrying over its name kind, inner symbol and "init" attribute.
351 ImplicitLocOpBuilder memBuilder(cmem->getLoc(), cmem);
352 auto symOp = cast<hw::InnerSymbolOpInterface>(cmem);
353 auto memory = memBuilder.create<MemOp>(
354 resultTypes, readLatency, writeLatency, depth, ruw,
355 memBuilder.getArrayAttr(resultNames), name,
356 cmem->getAttrOfType<firrtl::NameKindEnumAttr>(
"nameKind").getValue(),
357 annotations, memBuilder.getArrayAttr(portAnnotations),
358 symOp.getInnerSymAttr(),
359 cmem->getAttrOfType<firrtl::MemoryInitAttr>(
"init"), StringAttr());
// Wire up each result of the new memory; result i corresponds to ports[i].
364 for (
unsigned i = 0, e = memory.getNumResults(); i < e; ++i) {
365 auto memoryPort = memory.getResult(i);
366 auto portKind = ports[i].portKind;
// Debug ports: reads of the old debug port result map straight to the new
// memory result.
367 if (portKind == MemOp::PortKind::Debug) {
368 rdataValues[ports[i].cmemPort->getResult(0)] = memoryPort;
// Everything below applies to regular ports (the cast requires MemoryPortOp).
371 auto cmemoryPort = cast<MemoryPortOp>(ports[i].cmemPort);
372 auto cmemoryPortAccess = cmemoryPort.getAccess();
// A second builder located at the port access is used for the connects that
// drive the address and clock.
379 ImplicitLocOpBuilder portBuilder(cmemoryPortAccess.getLoc(),
// Address, enable and clock fields of the new port. `emitInvalid` is called
// on the address and clock here.
381 auto address = memBuilder.create<SubfieldOp>(memoryPort,
"addr");
382 emitInvalid(memBuilder, address);
383 auto enable = memBuilder.create<SubfieldOp>(memoryPort,
"en");
385 auto clock = memBuilder.create<SubfieldOp>(memoryPort,
"clk");
386 emitInvalid(memBuilder, clock);
// The address is driven from the index of the port access.
389 emitConnect(portBuilder, address, cmemoryPortAccess.getIndex());
// Enable inference is only considered for read ports of sequential memories
// whose address has no defining op or is defined by a wire, node or register.
391 auto useEnableInference = isSequential && portKind == MemOp::PortKind::Read;
392 auto *addressOp = cmemoryPortAccess.getIndex().getDefiningOp();
395 useEnableInference &=
396 !addressOp || isa<WireOp, NodeOp, RegOp, RegResetOp>(addressOp);
// NOTE(review): the statement guarded by this `if` is missing from the
// listing; the clock connect below is a separate statement in the original.
399 if (!useEnableInference)
402 emitConnect(portBuilder, clock, cmemoryPortAccess.getClock());
404 if (portKind == MemOp::PortKind::Read) {
// Read port: reads of the old data value map to the new "data" field.
406 auto data = memBuilder.create<SubfieldOp>(memoryPort,
"data");
407 rdataValues[cmemoryPort.getData()] =
data;
408 }
else if (portKind == MemOp::PortKind::Write) {
// Write port: writes to the old data value map to the new "data"/"mask"
// fields; the mode slot is null.
410 auto data = memBuilder.create<SubfieldOp>(memoryPort,
"data");
411 emitInvalid(memBuilder, data);
412 auto mask = memBuilder.create<SubfieldOp>(memoryPort,
"mask");
413 emitInvalid(memBuilder, mask);
419 wdataValues[cmemoryPort.getData()] = {
data,
mask,
nullptr};
420 }
else if (portKind == MemOp::PortKind::ReadWrite) {
// Read-write port: separate read and write fields plus the write-mode field.
// NOTE(review): the wdataValues registration for this case is not visible.
422 auto rdata = memBuilder.create<SubfieldOp>(memoryPort,
"rdata");
423 auto wmode = memBuilder.create<SubfieldOp>(memoryPort,
"wmode");
425 auto wdata = memBuilder.create<SubfieldOp>(memoryPort,
"wdata");
426 emitInvalid(memBuilder, wdata);
427 auto wmask = memBuilder.create<SubfieldOp>(memoryPort,
"wmask");
428 emitInvalid(memBuilder, wmask);
435 rdataValues[cmemoryPort.getData()] =
rdata;
// Enable inference: drive the enable at the places where the index is driven.
443 if (useEnableInference) {
444 auto *indexOp = cmemoryPortAccess.getIndex().getDefiningOp();
445 bool success =
false;
450 }
else if (isa<WireOp, RegResetOp, RegOp>(indexOp)) {
// Wire/register index: keep only the users that are connects whose
// destination is the index and whose source is not an InvalidValueOp.
457 make_filter_range(indexOp->getUsers(), [&](Operation *op) {
458 if (auto connectOp = dyn_cast<ConnectOp>(op)) {
459 if (cmemoryPortAccess.getIndex() == connectOp.getDest())
460 return !dyn_cast_or_null<InvalidValueOp>(
461 connectOp.getSrc().getDefiningOp());
462 }
else if (
auto connectOp = dyn_cast<StrictConnectOp>(op)) {
463 if (cmemoryPortAccess.getIndex() == connectOp.getDest())
464 return !dyn_cast_or_null<InvalidValueOp>(
465 connectOp.getSrc().getDefiningOp());
// A strict connect to the enable is created at each such driver.
471 for (
auto *driver : drivers) {
472 OpBuilder(driver).create<StrictConnectOp>(driver->getLoc(), enable,
476 }
else if (isa<NodeOp>(indexOp)) {
// Node index: a single strict connect to the enable is created at the node.
479 OpBuilder(indexOp).create<StrictConnectOp>(indexOp->getLoc(), enable,
// Presumably emitted when no enable driver was created (`success` stayed
// false) — the guarding condition is not visible in this listing.
486 cmemoryPort.emitWarning(
"memory port is never enabled");
/// Lowers a combinational CHIRRTL memory by forwarding to `replaceMem` with
/// `isSequential` = false and an undefined read-under-write attribute.
491 void LowerCHIRRTLPass::visitCHIRRTL(CombMemOp combmem) {
492 replaceMem(combmem, combmem.getName(),
false,
493 RUWAttr::Undefined, combmem.getAnnotations());
/// Lowers a sequential CHIRRTL memory by forwarding to `replaceMem` with
/// `isSequential` = true and the op's own read-under-write attribute.
496 void LowerCHIRRTLPass::visitCHIRRTL(SeqMemOp seqmem) {
497 replaceMem(seqmem, seqmem.getName(),
true, seqmem.getRuw(),
498 seqmem.getAnnotations());
/// Queues the memory port op for erasure; the visible code does nothing else
/// with it here.
501 void LowerCHIRRTLPass::visitCHIRRTL(MemoryPortOp memPort) {
503 opsToDelete.push_back(memPort);
/// Queues the memory debug port op for erasure; the visible code does nothing
/// else with it here.
506 void LowerCHIRRTLPass::visitCHIRRTL(MemoryDebugPortOp memPort) {
508 opsToDelete.push_back(memPort);
/// Queues the memory port access op for erasure; the visible code does
/// nothing else with it here.
511 void LowerCHIRRTLPass::visitCHIRRTL(MemoryPortAccessOp memPortAccess) {
513 opsToDelete.push_back(memPortAccess);
/// Retargets a connect whose destination or source is a value tracked in
/// `wdataValues` / `rdataValues` onto the recorded replacement value.
/// NOTE(review): lines between the two lookups (original lines 523-531) are
/// missing from this listing, so any additional handling on the write path is
/// not documented here.
516 void LowerCHIRRTLPass::visitStmt(ConnectOp
connect) {
// Destination is a tracked write value: redirect the connect to its `data`.
519 auto writeIt = wdataValues.find(
connect.getDest());
520 if (writeIt != wdataValues.end()) {
521 auto writeData = writeIt->second;
522 connect.getDestMutable().assign(writeData.data);
// Source is a tracked read value: redirect the connect to read from the
// replacement.
532 auto readIt = rdataValues.find(
connect.getSrc());
533 if (readIt != rdataValues.end()) {
534 auto newSource = readIt->second;
535 connect.getSrcMutable().assign(newSource);
/// Strict-connect counterpart of the `ConnectOp` visitor: retargets a connect
/// whose destination or source is tracked in `wdataValues` / `rdataValues`
/// onto the recorded replacement value.
/// NOTE(review): lines between the two lookups (original lines 546-554) are
/// missing from this listing, so any additional handling on the write path is
/// not documented here.
539 void LowerCHIRRTLPass::visitStmt(StrictConnectOp
connect) {
// Destination is a tracked write value: redirect the connect to its `data`.
542 auto writeIt = wdataValues.find(
connect.getDest());
543 if (writeIt != wdataValues.end()) {
544 auto writeData = writeIt->second;
545 connect.getDestMutable().assign(writeData.data);
// Source is a tracked read value: redirect the connect to read from the
// replacement.
555 auto readIt = rdataValues.find(
connect.getSrc());
556 if (readIt != rdataValues.end()) {
557 auto newSource = readIt->second;
558 connect.getSrcMutable().assign(newSource);
/// Re-creates the aggregate-access op `op` (of type `OpType`) on top of the
/// replacement values recorded for `input`, registers the new results in
/// `rdataValues` / `wdataValues` under `op`, and queues `op` for erasure.
///
/// \param op        the original access op; also the key for the new entries.
/// \param input     the original aggregate value `op` reads from.
/// \param operands  extra arguments forwarded to `OpType`'s builder.
///
/// NOTE(review): the end of the signature and some `return`/closing-brace
/// lines are missing from this listing.
569 template <
typename OpType,
typename... T>
570 void LowerCHIRRTLPass::cloneSubindexOpForMemory(OpType op, Value input,
// No direction was recorded for this op by `inferMemoryPortKind`.
574 auto it = subfieldDirs.find(op);
575 if (it == subfieldDirs.end()) {
// If the input still has a read replacement, clone the access on top of it
// and track the clone as the read replacement for `op`.
578 auto iter = rdataValues.find(input);
579 if (iter != rdataValues.end()) {
580 opsToDelete.push_back(op);
581 ImplicitLocOpBuilder
builder(op->getLoc(), op);
582 rdataValues[op] =
builder.create<OpType>(rdataValues[input], operands...);
// A direction was recorded: the original op is always replaced.
591 opsToDelete.push_back(op);
593 auto direction = it->second;
594 ImplicitLocOpBuilder
builder(op->getLoc(), op);
// Read side: clone the access on the read replacement of the input.
598 if (direction == MemDirAttr::Read || direction == MemDirAttr::ReadWrite) {
599 rdataValues[op] =
builder.create<OpType>(rdataValues[input], operands...);
// Write side: clone the access on both the data and the mask of the input's
// write record; the mode is carried over unchanged.
604 if (direction == MemDirAttr::Write || direction == MemDirAttr::ReadWrite) {
605 auto writeData = wdataValues[input];
606 auto write =
builder.create<OpType>(writeData.data, operands...);
607 auto mask =
builder.create<OpType>(writeData.mask, operands...);
608 wdataValues[op] = {write,
mask, writeData.mode};
/// Handles a subaccess: first swaps its index operand for the read replacement
/// if the index is tracked in `rdataValues`, then clones the op for the memory
/// via `cloneSubindexOpForMemory`, forwarding the index as the extra operand.
612 void LowerCHIRRTLPass::visitExpr(SubaccessOp subaccess) {
615 auto readIt = rdataValues.find(subaccess.getIndex());
616 if (readIt != rdataValues.end()) {
617 subaccess.getIndexMutable().assign(readIt->second);
620 cloneSubindexOpForMemory(subaccess, subaccess.getInput(),
621 subaccess.getIndex());
/// Handles a subfield by cloning it for the memory via
/// `cloneSubindexOpForMemory`, forwarding the field index as the extra operand.
624 void LowerCHIRRTLPass::visitExpr(SubfieldOp subfield) {
625 cloneSubindexOpForMemory<SubfieldOp>(subfield, subfield.getInput(),
626 subfield.getFieldIndex());
/// Handles a subindex by cloning it for the memory via
/// `cloneSubindexOpForMemory`, forwarding the element index as the extra
/// operand.
629 void LowerCHIRRTLPass::visitExpr(SubindexOp subindex) {
630 cloneSubindexOpForMemory<SubindexOp>(subindex, subindex.getInput(),
631 subindex.getIndex());
/// Generic fallback: for every operand of `op` that is tracked in
/// `rdataValues`, replaces the operand with the recorded read replacement.
634 void LowerCHIRRTLPass::visitUnhandledOp(Operation *op) {
637 for (
auto &operand : op->getOpOperands()) {
638 auto it = rdataValues.find(operand.get());
639 if (it != rdataValues.end()) {
640 operand.set(it->second);
/// Pass entry point: dispatches every operation in the module body to the
/// CHIRRTL visitor, then erases the operations the visitors queued.
645 void LowerCHIRRTLPass::runOnOperation() {
649 getOperation().getBodyBlock()->walk(
650 [&](Operation *op) { dispatchCHIRRTLVisitor(op); });
// Nothing was queued for deletion: report that all analyses are preserved.
654 if (opsToDelete.empty())
655 markAllAnalysesPreserved();
// Erase queued ops starting from the most recently queued one.
658 while (!opsToDelete.empty())
659 opsToDelete.pop_back_val()->erase();
666 return std::make_unique<LowerCHIRRTLPass>();
assert(baseType &&"element must be base type")
static MemOp::PortKind memDirAttrToPortKind(MemDirAttr direction)
Converts a CHIRRTL memory port direction to a MemoryOp port type.
static void connectLeafsTo(ImplicitLocOpBuilder &builder, Value bundle, Value value)
Drive a value to all leaves of the input aggregate value.
static void forEachLeaf(ImplicitLocOpBuilder &builder, Value value, llvm::function_ref< void(Value)> func)
Performs the callback for each leaf element of a value.
CHIRRTLVisitor is a visitor for CHIRRTL operations.
FIRRTLVisitor allows you to visit all of the expr/stmt/decls with one class declaration.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
std::unique_ptr< mlir::Pass > createLowerCHIRRTLPass()
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
This holds the name and type that describes the module's ports.