14#include "mlir/IR/Matchers.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "llhd-combine-drives"
21#define GEN_PASS_DEF_COMBINEDRIVESPASS
22#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
30using llvm::SmallMapVector;
31using llvm::SmallSetVector;
32using llvm::SpecificBumpPtrAllocator;
// Body fragment of what appears to be the `getLength(Type)` helper (its
// signature line is not part of this excerpt). The incoming type is cast to
// `hw::InOutType`, and the wrapped element type decides what is counted.
38 return TypeSwitch<Type, unsigned>(cast<hw::InOutType>(type).getElementType())
// Integers: the bit width.
39 .Case<IntegerType>([](
auto type) {
return type.getWidth(); })
// Arrays: the number of elements.
40 .Case<hw::ArrayType>([](
auto type) {
return type.getNumElements(); })
// Structs: the number of fields.
41 .Case<hw::StructType>([](
auto type) {
return type.getElements().size(); })
// Every other element type reports a length of 0.
42 .Default([](
auto) {
return 0; });
// Member fragment of a slice type (presumably `SignalSlice`, whose header and
// remaining fields are outside this excerpt).
// The signal the slice refers to; null by default.
87 Signal *signal =
nullptr;
// A slice tests as true only when it refers to a signal.
91 explicit operator bool()
const {
return signal !=
nullptr; }
// Member fragment of the `Signal` type (its header and the `value` field are
// outside this excerpt).
// Enclosing signal; stays null for signals built with the root constructor.
110 Signal *parent =
nullptr;
// Position of this signal inside `parent`; 0 unless set by the constructor.
112 unsigned indexInParent = 0;
// Child signals; `internSignal(value, parent, index)` appends to the parent's
// list when it creates a child.
115 SmallVector<Signal *> subsignals;
// Value slices recorded for this signal by `traceProjection`.
120 SmallVector<ValueSlice> slices;
// Drive ops recorded by `aggregateDriveSlices` for this signal.
123 SmallVector<DrvOp, 2> completeDrives;
// Root constructor: only the value is set; parent and index keep their
// defaults.
126 explicit Signal(Value root) : value(root) {}
// Sub-signal constructor: stores the value together with its parent and the
// index inside that parent.
128 Signal(Value value, Signal *parent,
unsigned indexInParent)
129 : value(value), parent(parent), indexInParent(indexInParent) {}
// State and helpers used while combining drives within a single module.
// NOTE(review): some members referenced elsewhere in this file (`moduleOp`,
// `pruner`) are declared on lines that are not part of this excerpt.
134struct ModuleContext {
// Stores the module handle passed in.
135 ModuleContext(
HWModuleOp moduleOp) : moduleOp(moduleOp) {}
// Projection tracing: the first consults/fills the `projections` cache, the
// second inspects the op defining the value.
139 SignalSlice traceProjection(Value value);
140 SignalSlice traceProjectionImpl(Value value);
// Get-or-create helpers for root signals and for sub-signals at an index
// within a parent.
141 Signal *internSignal(Value root);
142 Signal *internSignal(Value value, Signal *parent,
unsigned index);
// Drive aggregation entry point and its two helpers.
145 void aggregateDrives(Signal &signal);
146 void addDefaultDriveSlices(Signal &signal,
147 SmallVectorImpl<DriveSlice> &slices);
148 void aggregateDriveSlices(Signal &signal, Value driveDelay, Value driveEnable,
149 ArrayRef<DriveSlice> slices);
// Cache of traced projections, keyed by the projected value.
154 DenseMap<Value, SignalSlice> projections;
// Root signals in the order `internSignal(Value)` created them.
156 SmallVector<Signal *> rootSignals;
// Interning key: either a root value (paired with 0) or a parent signal
// paired with the child's index.
161 using SignalKey = std::pair<PointerUnion<Value, Signal *>,
unsigned>;
// Allocator that `Signal` objects are placement-constructed into.
162 SpecificBumpPtrAllocator<Signal> signalAlloc;
// Lookup table from interning key to the corresponding `Signal`.
163 DenseMap<SignalKey, Signal *> internedSignals;
// Tail of the stream operator for `Signal`. The first signature line and the
// condition that selects between the two forms below are outside this excerpt.
169 const Signal &signal) {
// Sub-signal form: the parent's printout followed by "[indexInParent]".
171 return os << *signal.parent <<
"[" << signal.indexInParent <<
"]";
// When the return above is not taken, the signal's value is printed as an
// operand using local scope.
172 signal.value.printAsOperand(os, OpPrintingFlags().useLocalScope());
// Stream operator for `SignalSlice`. The guard in front of the first return
// is outside this excerpt.
177static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os, SignalSlice slice) {
// Placeholder text emitted for a slice without a signal (presumably guarded
// by a null check -- TODO confirm against the full file).
179 return os <<
"<null-slice>";
// Regular form: the signal followed by "[offset..offset+length]".
180 return os << *slice.signal <<
"[" << slice.offset <<
".."
181 << (slice.offset + slice.length) <<
"]";
// NOTE(review): the three lines below repeat the signature above with an
// elided body; this looks like an artifact of how the listing was captured
// rather than a second definition.
177static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os, SignalSlice slice) {
…}
// Traces `value` to the signal slice it projects into, using `projections`
// as a cache. Several guard/return lines are outside this excerpt.
191SignalSlice ModuleContext::traceProjection(Value value) {
// Cache lookup: has `value` been traced before? (The statement executed on a
// hit is not visible here.)
193 if (
auto it = projections.find(value); it != projections.end())
// Not cached: inspect the defining op.
197 auto projection = traceProjectionImpl(value);
// Record on the target signal which value covers which offset/length range.
// NOTE(review): a null check on `projection` presumably precedes this
// dereference on a line that is not visible -- confirm.
199 projection.signal->slices.push_back(
200 ValueSlice{value, projection.offset, projection.length});
// Remember the result for later queries on the same value.
201 projections.insert({value, projection});
202 LLVM_DEBUG(llvm::dbgs() <<
"- Traced " << value <<
" to " << projection
// Uncached projection tracing: dispatches on the op that defines `value`.
// The early-exit statements after each failed check and the `return`
// statements of each branch are outside this excerpt.
208SignalSlice ModuleContext::traceProjectionImpl(Value value) {
// Bit-range extract: stays within the input's signal; the constant low bit is
// added to the offset and the length comes from the result type.
213 if (
auto op = value.getDefiningOp<SigExtractOp>()) {
214 auto slice = traceProjection(op.getInput());
217 IntegerAttr offsetAttr;
// The low bit must match a constant for the offset to be known.
218 if (!matchPattern(op.getLowBit(), m_Constant(&offsetAttr)))
220 slice.offset += offsetAttr.getValue().getZExtValue();
221 slice.length =
getLength(value.getType());
// Array slice: same scheme as above, using the constant low index.
225 if (
auto op = value.getDefiningOp<SigArraySliceOp>()) {
226 auto slice = traceProjection(op.getInput());
229 IntegerAttr offsetAttr;
230 if (!matchPattern(op.getLowIndex(), m_Constant(&offsetAttr)))
232 slice.offset += offsetAttr.getValue().getZExtValue();
233 slice.length =
getLength(value.getType());
// Single array element: becomes a sub-signal of the input's signal, interned
// at position (input offset + constant index).
240 if (
auto op = value.getDefiningOp<SigArrayGetOp>()) {
241 auto input = traceProjection(op.getInput());
244 IntegerAttr indexAttr;
245 if (!matchPattern(op.getIndex(), m_Constant(&indexAttr)))
247 unsigned offset = input.offset + indexAttr.getValue().getZExtValue();
249 slice.signal = internSignal(value, input.signal, offset);
250 slice.length =
getLength(value.getType());
// Struct field: becomes a sub-signal of the input's signal, interned at the
// field's index.
254 if (
auto op = value.getDefiningOp<SigStructExtractOp>()) {
255 auto structType = cast<hw::StructType>(
256 cast<hw::InOutType>(op.getInput().getType()).getElementType());
257 auto input = traceProjection(op.getInput());
// The input slice is expected to span the entire struct.
260 assert(input.offset == 0);
261 assert(input.length == structType.getElements().size());
262 unsigned index = *structType.getFieldIndex(op.getFieldAttr());
264 slice.signal = internSignal(value, input.signal, index);
265 slice.length =
getLength(value.getType());
// No projection op matched above: the value itself is interned as a root
// signal.
271 slice.signal = internSignal(value);
272 slice.length =
getLength(value.getType());
// Root-signal overload: looks up the slot keyed by {root, 0} in
// `internedSignals`. The guard around the allocation and the final return
// are outside this excerpt.
279Signal *ModuleContext::internSignal(Value root) {
280 auto &slot = internedSignals[{root, 0}];
// Placement-construct a new root `Signal` in the bump allocator
// (presumably only when the slot was still empty -- TODO confirm).
282 slot =
new (signalAlloc.Allocate()) Signal(root);
// Newly created roots are also listed in `rootSignals`.
283 rootSignals.push_back(slot);
// Sub-signal overload: looks up the slot keyed by {parent, index}. The rest
// of the signature, the guard around the allocation, and the final return are
// outside this excerpt.
291Signal *ModuleContext::internSignal(Value value, Signal *parent,
293 auto &slot = internedSignals[{parent, index}];
// Placement-construct the child in the bump allocator (presumably only when
// the slot was still empty -- TODO confirm).
295 slot =
new (signalAlloc.Allocate()) Signal(value, parent, index);
// Register the new child with its parent.
296 parent->subsignals.push_back(slot);
// Collects the drives that target `signal` (directly through its slices, or
// via complete drives of its sub-signals), groups them, and hands each group
// to `aggregateDriveSlices`. Several closing braces and `continue`/`break`
// style statements are outside this excerpt.
310void ModuleContext::aggregateDrives(Signal &signal) {
// Drive slices grouped by the (time, enable) operand pair of the drive op.
316 SmallMapVector<std::pair<Value, Value>, SmallVector<DriveSlice>, 2> drives;
// Every drive op that went through `addDrive`.
317 SmallPtrSet<Operation *, 8> knownDrives;
318 auto addDrive = [&](DrvOp op,
unsigned offset,
unsigned length) {
319 knownDrives.insert(op);
320 drives[{op.getTime(), op.getEnable()}].push_back(
321 DriveSlice{op, op.getValue(), offset, length});
// Sub-signals are processed first; each of their complete drives is then
// added here at the sub-signal's index with a length of 0.
323 for (
auto *subsignal : signal.subsignals) {
324 aggregateDrives(*subsignal);
332 for (
auto driveOp : subsignal->completeDrives)
333 addDrive(driveOp, subsignal->indexInParent, 0);
// Direct drives: users of each recorded value slice that are `DrvOp`s, use
// the value as operand 0, and sit in the same block as the value.
337 for (
auto slice : signal.slices) {
338 for (
auto &use : slice.value.getUses()) {
339 auto driveOp = dyn_cast<DrvOp>(use.getOwner());
340 if (driveOp && use.getOperandNumber() == 0 &&
341 driveOp->getBlock() == slice.value.getParentBlock())
342 addDrive(driveOp, slice.offset, slice.length);
// Walk the users reachable from the signal value through projection ops to
// find out whether any use is neither a probe, a known drive, nor a
// projection.
349 SmallSetVector<Value, 8> worklist;
350 worklist.insert(signal.value);
351 bool hasUnknownUses =
false;
352 while (!worklist.empty() && !hasUnknownUses) {
353 auto value = worklist.pop_back_val();
354 for (
auto *user : value.getUsers()) {
355 if (isa<PrbOp>(user))
357 if (isa<DrvOp>(user) && knownDrives.contains(user))
// Projection ops: their result is queued so its users get inspected too.
359 if (isa<SigExtractOp, SigStructExtractOp, SigArrayGetOp, SigArraySliceOp>(
361 worklist.insert(user->getResult(0));
364 hasUnknownUses =
true;
// Gap filling is attempted only when all uses are understood and there is
// exactly one (time, enable) group.
372 if (!hasUnknownUses && drives.size() == 1) {
373 auto &slices = drives.begin()->second;
374 addDefaultDriveSlices(signal, slices);
// Each group is sorted by offset and aggregated; the key's first element is
// passed as the delay and the second as the enable.
379 for (
auto &[key, slices] : drives) {
380 llvm::sort(slices, [](
auto &a,
auto &b) {
return a.offset < b.offset; });
381 aggregateDriveSlices(signal, key.first, key.second, slices);
// Looks for ranges of `signal` not covered by `slices` and appends slices for
// them, with values derived from the signal's initial value. Overlap
// detection, early returns, and the op-creation statements are partly outside
// this excerpt.
388void ModuleContext::addDefaultDriveSlices(Signal &signal,
389 SmallVectorImpl<DriveSlice> &slices) {
390 auto type = cast<hw::InOutType>(signal.value.getType()).getElementType();
// Work on the slices in ascending offset order.
393 llvm::sort(slices, [](
auto &a,
auto &b) {
return a.offset < b.offset; });
398 bool anyOverlaps =
false;
// For struct types a gap is split into one slice per index.
399 bool needSeparateFields = isa<hw::StructType>(type);
400 SmallVector<DriveSlice> gapSlices;
// Records the uncovered range [from, to): one zero-length slice per index when
// separate fields are needed, otherwise a single slice of length `to - from`.
// Gap slices carry no op and no value at this point.
401 auto fillGap = [&](
unsigned from,
unsigned to) {
408 if (needSeparateFields) {
409 for (
auto idx = from; idx < to; ++idx)
410 gapSlices.push_back(DriveSlice{DrvOp{}, Value{}, idx, 0});
412 gapSlices.push_back(DriveSlice{DrvOp{}, Value{}, from, to - from});
// Sweep the sorted slices; a zero-length slice still advances the expected
// offset by one position.
419 unsigned expectedOffset = 0;
420 for (
auto slice : slices) {
421 fillGap(expectedOffset, slice.offset);
422 expectedOffset = slice.offset + std::max<unsigned>(1, slice.length);
// Trailing gap up to the signal's total length.
426 fillGap(expectedOffset,
getLength(signal.value.getType()));
// Check for overlapping slices or the absence of gaps (the statement this
// condition guards is not visible here).
430 if (anyOverlaps || gapSlices.empty())
// The signal's initializer supplies the default value.
// NOTE(review): a null check on `signalOp` may exist on a line that is not
// visible -- confirm.
445 auto signalOp = signal.value.getDefiningOp<SignalOp>();
448 auto defaultValue = signalOp.getInit();
// New ops are inserted right after the signal's definition.
451 ImplicitLocOpBuilder builder(signal.value.getLoc(),
452 signal.value.getContext());
453 builder.setInsertionPointAfterValue(signal.value);
455 for (
auto &slice : gapSlices) {
456 LLVM_DEBUG(llvm::dbgs()
457 <<
"- Filling gap " << signal <<
"[" << slice.offset <<
".."
458 << (slice.offset + slice.length) <<
"] with initial value\n");
// Integer signal: a value of `slice.length` bits is taken from the default
// value at `slice.offset` (the creating call is only partly visible).
461 if (
auto intType = dyn_cast<IntegerType>(type)) {
464 builder.getIntegerType(slice.length), defaultValue, slice.offset);
// Struct signal: one field per gap slice, selected by `slice.offset`.
469 if (
auto structType = dyn_cast<hw::StructType>(type)) {
470 assert(slice.length == 0);
472 defaultValue, structType.getElements()[slice.offset]);
// Array signal: the index constant is as wide as the ceiling of log2 of the
// element count, and the extracted value is an array of `slice.length`
// elements.
477 if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
480 APInt(llvm::Log2_64_Ceil(arrayType.getNumElements()), slice.offset));
482 hw::ArrayType::get(arrayType.getElementType(), slice.length),
483 defaultValue, offset);
// Hand the gap slices back to the caller alongside the original ones.
491 slices.append(gapSlices.begin(), gapSlices.end());
// Combines a group of drive slices sharing one delay/enable pair into a single
// drive of the whole `signal`, provided the slices tile the signal without
// gaps. Early returns, op-creation statements, and several closing braces are
// outside this excerpt.
496void ModuleContext::aggregateDriveSlices(Signal &signal, Value driveDelay,
498 ArrayRef<DriveSlice> slices) {
// Check that the slices are contiguous starting at offset 0; a zero-length
// slice counts as one position.
500 unsigned expectedOffset = 0;
501 for (
auto slice : slices) {
502 assert(slice.value &&
"all slices must have an assigned value");
503 if (slice.offset != expectedOffset) {
510 expectedOffset += std::max<unsigned>(1, slice.length);
// The covered length must equal the signal's length.
512 if (expectedOffset !=
getLength(signal.value.getType())) {
513 LLVM_DEBUG(llvm::dbgs()
514 <<
"- Signal " << signal <<
" not completely driven\n");
// A single non-zero-length slice backed by an existing op is recorded as a
// complete drive as-is.
521 if (slices.size() == 1 && slices[0].length != 0 && slices[0].op) {
522 signal.completeDrives.push_back(slices[0].op);
// Debug trace of what is being aggregated, including delay and enable.
526 llvm::dbgs() <<
"- Aggregating " << signal <<
" drives (delay ";
527 driveDelay.printAsOperand(llvm::dbgs(), OpPrintingFlags().useLocalScope());
529 llvm::dbgs() <<
" if ";
530 driveEnable.printAsOperand(llvm::dbgs(),
531 OpPrintingFlags().useLocalScope());
533 llvm::dbgs() <<
")\n";
// New ops are inserted right after the signal's definition.
537 auto type = cast<hw::InOutType>(signal.value.getType()).getElementType();
538 ImplicitLocOpBuilder builder(signal.value.getLoc(),
539 signal.value.getContext());
540 builder.setInsertionPointAfterValue(signal.value);
// Integer signal: slice values are gathered in slice order and then reversed
// (presumably because the combining op takes the most significant operand
// first -- TODO confirm; the creating call is not visible).
543 if (
auto intType = dyn_cast<IntegerType>(type)) {
547 SmallVector<Value> operands;
548 for (
auto slice : slices)
549 operands.push_back(slice.value);
550 std::reverse(operands.begin(), operands.end());
552 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
// Struct signal: slice values are gathered in slice order, without reversal.
556 if (
auto structType = dyn_cast<hw::StructType>(type)) {
559 SmallVector<Value> operands;
560 for (
auto slice : slices)
561 operands.push_back(slice.value);
563 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
// Array signal: zero-length slices are single elements batched in `scalars`;
// other slices are already aggregates.
567 if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
571 SmallVector<Value> scalars;
572 SmallVector<Value> aggregates;
// Turns the pending scalars (reversed) into one aggregate and queues it.
573 auto flushScalars = [&] {
576 std::reverse(scalars.begin(), scalars.end());
578 aggregates.push_back(aggregate);
580 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << aggregate <<
"\n");
582 for (
auto slice : slices) {
583 if (slice.length == 0) {
584 scalars.push_back(slice.value);
587 aggregates.push_back(slice.value);
// With a single aggregate it is the result; otherwise the aggregates are
// reversed before being combined (the combining call is not visible).
594 result = aggregates.back();
595 if (aggregates.size() != 1) {
596 std::reverse(aggregates.begin(), aggregates.end());
598 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
// Drive the whole signal with the combined value and record the new op as a
// complete drive.
605 builder.create<DrvOp>(signal.value, result, driveDelay, driveEnable);
606 signal.completeDrives.push_back(driveOp);
607 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << driveOp <<
"\n");
// The drives that were replaced are erased through the pruner.
610 for (
auto slice : slices) {
613 LLVM_DEBUG(llvm::dbgs() <<
" - Removed " << slice.op <<
"\n");
614 pruner.eraseNow(slice.op);
// Pass class built on the generated `CombineDrivesPassBase`; it only
// overrides `runOnOperation`.
623struct CombineDrivesPass
624 :
public llhd::impl::CombineDrivesPassBase<CombineDrivesPass> {
625 void runOnOperation()
override;
// Pass entry point: builds a `ModuleContext` for the current operation, traces
// projections, aggregates drives per root signal, and finally runs the pruner.
// The loop header that visits the ops is outside this excerpt.
629void CombineDrivesPass::runOnOperation() {
630 LLVM_DEBUG(llvm::dbgs() <<
"Combining drives in "
631 << getOperation().getModuleNameAttr() <<
"\n");
632 ModuleContext context(getOperation());
// Trace the first result of every projection op that is encountered.
636 if (isa<SigExtractOp, SigArraySliceOp, SigArrayGetOp, SigStructExtractOp>(
638 context.traceProjection(op.getResult(0));
// Aggregate drives starting from each root signal.
641 for (
auto *signal : context.rootSignals)
642 context.aggregateDrives(*signal);
// Final pruner pass over ops it is tracking.
645 context.pruner.eraseNow();
assert(baseType &&"element must be base type")
static unsigned getLength(Type type)
Determine the number of elements in a type.
static Block * getBodyBlock(FModuleLike mod)
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...