14#include "mlir/IR/Matchers.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "llhd-combine-drives"
21#define GEN_PASS_DEF_COMBINEDRIVESPASS
22#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
30using llvm::SmallMapVector;
31using llvm::SmallSetVector;
32using llvm::SpecificBumpPtrAllocator;
/// Determine the number of elements in the element type of an `hw::InOutType`:
/// the bit width of an integer, the element count of an array, the field count
/// of a struct, and 0 for any other element type.
38 return TypeSwitch<Type, unsigned>(cast<hw::InOutType>(type).getElementType())
39 .Case<IntegerType>([](
auto type) {
return type.getWidth(); })
40 .Case<hw::ArrayType>([](
auto type) {
return type.getNumElements(); })
41 .Case<hw::StructType>([](
auto type) {
return type.getElements().size(); })
42 .Default([](
auto) {
return 0; });
87 Signal *signal =
nullptr;
/// A slice converts to `true` only if it refers to a signal (`signal` is
/// non-null).
91 explicit operator bool()
const {
return signal !=
nullptr; }
110 Signal *parent =
nullptr;
112 unsigned indexInParent = 0;
115 SmallVector<Signal *> subsignals;
120 SmallVector<ValueSlice> slices;
123 SmallVector<DrvOp, 2> completeDrives;
/// Create a root signal for `root`. `parent` keeps its default of null.
126 explicit Signal(Value root) : value(root) {}
/// Create a subsignal for `value`, nested in `parent` at position
/// `indexInParent`.
128 Signal(Value value, Signal *parent,
unsigned indexInParent)
129 : value(value), parent(parent), indexInParent(indexInParent) {}
/// Per-module state used while combining drives: the module being processed,
/// the traced signal projections, and the interned signals.
134struct ModuleContext {
/// Create a context for `moduleOp`.
135 ModuleContext(
HWModuleOp moduleOp) : moduleOp(moduleOp) {}
/// Return the signal slice that `value` refers to (memoized in `projections`).
139 SignalSlice traceProjection(Value value);
/// Uncached worker behind `traceProjection`.
140 SignalSlice traceProjectionImpl(Value value);
/// Look up or create the root signal for `root` in `internedSignals`.
141 Signal *internSignal(Value root);
/// Look up or create the subsignal of `parent` at position `index`.
142 Signal *internSignal(Value value, Signal *parent,
unsigned index);
/// Combine the drives of `signal` and, recursively, of its subsignals.
145 void aggregateDrives(Signal &signal);
/// Fill gaps in `slices` with slices of the signal's initial value.
146 void addDefaultDriveSlices(Signal &signal,
147 SmallVectorImpl<DriveSlice> &slices);
/// Combine `slices`, which share one drive delay and enable, into a single
/// drive of the whole `signal`.
148 void aggregateDriveSlices(Signal &signal, Value driveDelay, Value driveEnable,
149 ArrayRef<DriveSlice> slices);
154 DenseMap<Value, SignalSlice> projections;
156 SmallVector<Signal *> rootSignals;
161 using SignalKey = std::pair<PointerUnion<Value, Signal *>,
unsigned>;
162 SpecificBumpPtrAllocator<Signal> signalAlloc;
163 DenseMap<SignalKey, Signal *> internedSignals;
169 const Signal &signal) {
171 return os << *signal.parent <<
"[" << signal.indexInParent <<
"]";
172 signal.value.printAsOperand(os, OpPrintingFlags().useLocalScope());
/// Print a slice as `<signal>[<offset>..<offset + length>]`. A slice without a
/// signal prints as `<null-slice>`.
177static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os, SignalSlice slice) {
179 return os <<
"<null-slice>";
180 return os << *slice.signal <<
"[" << slice.offset <<
".."
181 << (slice.offset + slice.length) <<
"]";
/// Return the signal slice that `value` refers to. Results are cached in
/// `projections`. Each newly traced value is also appended, as a `ValueSlice`,
/// to the `slices` list of the signal it projects into.
191SignalSlice ModuleContext::traceProjection(Value value) {
193 if (
auto it = projections.find(value); it != projections.end())
197 auto projection = traceProjectionImpl(value);
199 projection.signal->slices.push_back(
200 ValueSlice{value, projection.offset, projection.length});
201 projections.insert({value, projection});
202 LLVM_DEBUG(llvm::dbgs() <<
"- Traced " << value <<
" to " << projection
/// Compute, without caching, the signal slice that `value` refers to.
/// - Bit extracts and array slices shift their input's slice by a constant
///   offset.
/// - Array element and struct field accesses become interned subsignals of the
///   input's signal.
/// - Any other value is interned as a root signal of its own.
/// Offsets and indices are matched as constants via `m_Constant`.
208SignalSlice ModuleContext::traceProjectionImpl(Value value) {
213 if (
auto op = value.getDefiningOp<SigExtractOp>()) {
214 auto slice = traceProjection(op.getInput());
217 IntegerAttr offsetAttr;
218 if (!matchPattern(op.getLowBit(), m_Constant(&offsetAttr)))
220 slice.offset += offsetAttr.getValue().getZExtValue();
221 slice.length =
getLength(value.getType());
225 if (
auto op = value.getDefiningOp<SigArraySliceOp>()) {
226 auto slice = traceProjection(op.getInput());
229 IntegerAttr offsetAttr;
230 if (!matchPattern(op.getLowIndex(), m_Constant(&offsetAttr)))
232 slice.offset += offsetAttr.getValue().getZExtValue();
233 slice.length =
getLength(value.getType());
240 if (
auto op = value.getDefiningOp<SigArrayGetOp>()) {
241 auto input = traceProjection(op.getInput());
244 IntegerAttr indexAttr;
245 if (!matchPattern(op.getIndex(), m_Constant(&indexAttr)))
247 unsigned offset = input.offset + indexAttr.getValue().getZExtValue();
249 slice.signal = internSignal(value, input.signal, offset);
250 slice.length =
getLength(value.getType());
254 if (
auto op = value.getDefiningOp<SigStructExtractOp>()) {
255 auto structType = cast<hw::StructType>(
256 cast<hw::InOutType>(op.getInput().getType()).getElementType());
257 auto input = traceProjection(op.getInput());
260 assert(input.offset == 0);
261 assert(input.length == structType.getElements().size());
262 unsigned index = *structType.getFieldIndex(op.getFieldAttr());
264 slice.signal = internSignal(value, input.signal, index);
265 slice.length =
getLength(value.getType());
271 slice.signal = internSignal(value);
272 slice.length =
getLength(value.getType());
279Signal *ModuleContext::internSignal(Value root) {
280 auto &slot = internedSignals[{root, 0}];
282 slot =
new (signalAlloc.Allocate()) Signal(root);
283 rootSignals.push_back(slot);
291Signal *ModuleContext::internSignal(Value value, Signal *parent,
293 auto &slot = internedSignals[{parent, index}];
295 slot =
new (signalAlloc.Allocate()) Signal(value, parent, index);
296 parent->subsignals.push_back(slot);
/// Combine the drives of `signal`.
/// - Subsignals are processed first. Their complete drives are then treated as
///   drives of a single element of `signal`.
/// - Direct drives of projections are collected only when they are in the same
///   block as the projected value.
/// - Drives are grouped by their (time, enable) operands.
/// - If the signal has no uses other than probes, known drives and projections,
///   and all drives fall into one group, gaps may be filled with the signal's
///   initial value.
/// - Each group is sorted by offset and handed to `aggregateDriveSlices`.
310void ModuleContext::aggregateDrives(Signal &signal) {
316 SmallMapVector<std::pair<Value, Value>, SmallVector<DriveSlice>, 2> drives;
317 SmallPtrSet<Operation *, 8> knownDrives;
318 auto addDrive = [&](DrvOp op,
unsigned offset,
unsigned length) {
319 knownDrives.insert(op);
320 drives[{op.getTime(), op.getEnable()}].push_back(
321 DriveSlice{op, op.getValue(), offset, length});
323 for (
auto *subsignal : signal.subsignals) {
324 aggregateDrives(*subsignal);
332 for (
auto driveOp : subsignal->completeDrives)
333 addDrive(driveOp, subsignal->indexInParent, 0);
337 for (
auto slice : signal.slices) {
338 for (
auto &use : slice.value.getUses()) {
339 auto driveOp = dyn_cast<DrvOp>(use.getOwner());
340 if (driveOp && use.getOperandNumber() == 0 &&
341 driveOp->getBlock() == slice.value.getParentBlock())
342 addDrive(driveOp, slice.offset, slice.length);
349 SmallSetVector<Value, 8> worklist;
350 worklist.insert(signal.value);
351 bool hasUnknownUses =
false;
352 while (!worklist.empty() && !hasUnknownUses) {
353 auto value = worklist.pop_back_val();
354 for (
auto *user : value.getUsers()) {
355 if (isa<PrbOp>(user))
357 if (isa<DrvOp>(user) && knownDrives.contains(user))
359 if (isa<SigExtractOp, SigStructExtractOp, SigArrayGetOp, SigArraySliceOp>(
361 worklist.insert(user->getResult(0));
364 hasUnknownUses =
true;
372 if (!hasUnknownUses && drives.size() == 1) {
373 auto &slices = drives.begin()->second;
374 addDefaultDriveSlices(signal, slices);
379 for (
auto &[key, slices] : drives) {
380 llvm::sort(slices, [](
auto &a,
auto &b) {
return a.offset < b.offset; });
381 aggregateDriveSlices(signal, key.first, key.second, slices);
/// Sort `slices` by offset and append slices that cover the gaps between them.
/// The gap slices take their values from the init operand of the `SignalOp`
/// that defines `signal`. For struct-typed signals, each missing field gets its
/// own gap slice.
388void ModuleContext::addDefaultDriveSlices(Signal &signal,
389 SmallVectorImpl<DriveSlice> &slices) {
390 auto type = cast<hw::InOutType>(signal.value.getType()).getElementType();
393 llvm::sort(slices, [](
auto &a,
auto &b) {
return a.offset < b.offset; });
398 bool anyOverlaps =
false;
399 bool needSeparateFields = isa<hw::StructType>(type);
400 SmallVector<DriveSlice> gapSlices;
401 auto fillGap = [&](
unsigned from,
unsigned to) {
408 if (needSeparateFields) {
409 for (
auto idx = from; idx < to; ++idx)
410 gapSlices.push_back(DriveSlice{DrvOp{}, Value{}, idx, 0});
412 gapSlices.push_back(DriveSlice{DrvOp{}, Value{}, from, to - from});
419 unsigned expectedOffset = 0;
420 for (
auto slice : slices) {
421 fillGap(expectedOffset, slice.offset);
422 expectedOffset = slice.offset + std::max<unsigned>(1, slice.length);
426 fillGap(expectedOffset,
getLength(signal.value.getType()));
430 if (anyOverlaps || gapSlices.empty())
445 auto signalOp = signal.value.getDefiningOp<SignalOp>();
448 auto defaultValue = signalOp.getInit();
451 ImplicitLocOpBuilder builder(signal.value.getLoc(),
452 signal.value.getContext());
453 builder.setInsertionPointAfterValue(signal.value);
455 for (
auto &slice : gapSlices) {
456 LLVM_DEBUG(llvm::dbgs()
457 <<
"- Filling gap " << signal <<
"[" << slice.offset <<
".."
458 << (slice.offset + slice.length) <<
"] with initial value\n");
461 if (
auto intType = dyn_cast<IntegerType>(type)) {
465 defaultValue, slice.offset);
470 if (
auto structType = dyn_cast<hw::StructType>(type)) {
471 assert(slice.length == 0);
473 builder, defaultValue, structType.getElements()[slice.offset]);
478 if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
482 APInt(llvm::Log2_64_Ceil(arrayType.getNumElements()), slice.offset));
484 builder, hw::ArrayType::get(arrayType.getElementType(), slice.length),
485 defaultValue, offset);
493 slices.append(gapSlices.begin(), gapSlices.end());
/// Try to combine `slices` into one drive of the whole `signal` with delay
/// `driveDelay` and enable `driveEnable`.
/// - `slices` must be sorted by offset, and every slice must carry a value.
/// - The slices are checked to be contiguous from offset 0 and to cover the
///   full length of the signal.
/// - A single slice that already drives the whole signal is just recorded as a
///   complete drive.
/// - Otherwise the slice values are combined into one value of the signal's
///   type, a new `DrvOp` is created, and the original drives are erased through
///   the pruner.
498void ModuleContext::aggregateDriveSlices(Signal &signal, Value driveDelay,
500 ArrayRef<DriveSlice> slices) {
502 unsigned expectedOffset = 0;
503 for (
auto slice : slices) {
504 assert(slice.value &&
"all slices must have an assigned value");
505 if (slice.offset != expectedOffset) {
512 expectedOffset += std::max<unsigned>(1, slice.length);
514 if (expectedOffset !=
getLength(signal.value.getType())) {
515 LLVM_DEBUG(llvm::dbgs()
516 <<
"- Signal " << signal <<
" not completely driven\n");
523 if (slices.size() == 1 && slices[0].length != 0 && slices[0].op) {
524 signal.completeDrives.push_back(slices[0].op);
528 llvm::dbgs() <<
"- Aggregating " << signal <<
" drives (delay ";
529 driveDelay.printAsOperand(llvm::dbgs(), OpPrintingFlags().useLocalScope());
531 llvm::dbgs() <<
" if ";
532 driveEnable.printAsOperand(llvm::dbgs(),
533 OpPrintingFlags().useLocalScope());
535 llvm::dbgs() <<
")\n";
539 auto type = cast<hw::InOutType>(signal.value.getType()).getElementType();
540 ImplicitLocOpBuilder builder(signal.value.getLoc(),
541 signal.value.getContext());
542 builder.setInsertionPointAfterValue(signal.value);
545 if (
auto intType = dyn_cast<IntegerType>(type)) {
549 SmallVector<Value> operands;
550 for (
auto slice : slices)
551 operands.push_back(slice.value);
552 std::reverse(operands.begin(), operands.end());
553 result = comb::ConcatOp::create(builder, operands);
554 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
558 if (
auto structType = dyn_cast<hw::StructType>(type)) {
561 SmallVector<Value> operands;
562 for (
auto slice : slices)
563 operands.push_back(slice.value);
565 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
569 if (
auto arrayType = dyn_cast<hw::ArrayType>(type)) {
573 SmallVector<Value> scalars;
574 SmallVector<Value> aggregates;
575 auto flushScalars = [&] {
578 std::reverse(scalars.begin(), scalars.end());
580 aggregates.push_back(aggregate);
582 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << aggregate <<
"\n");
584 for (
auto slice : slices) {
585 if (slice.length == 0) {
586 scalars.push_back(slice.value);
589 aggregates.push_back(slice.value);
596 result = aggregates.back();
597 if (aggregates.size() != 1) {
598 std::reverse(aggregates.begin(), aggregates.end());
600 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << result <<
"\n");
607 DrvOp::create(builder, signal.value, result, driveDelay, driveEnable);
608 signal.completeDrives.push_back(driveOp);
609 LLVM_DEBUG(llvm::dbgs() <<
" - Created " << driveOp <<
"\n");
612 for (
auto slice : slices) {
615 LLVM_DEBUG(llvm::dbgs() <<
" - Removed " << slice.op <<
"\n");
616 pruner.eraseNow(slice.op);
/// Pass that combines drives of individual slices of a signal into single
/// drives of the whole signal, processing one HW module at a time.
625struct CombineDrivesPass
626 :
public llhd::impl::CombineDrivesPassBase<CombineDrivesPass> {
627 void runOnOperation()
override;
/// Run the pass on one HW module in three steps:
/// 1. Trace the signal projection ops (`SigExtractOp`, `SigArraySliceOp`,
///    `SigArrayGetOp`, `SigStructExtractOp`).
/// 2. Aggregate the drives of every root signal.
/// 3. Erase the operations collected by the pruner.
631void CombineDrivesPass::runOnOperation() {
632 LLVM_DEBUG(llvm::dbgs() <<
"Combining drives in "
633 << getOperation().getModuleNameAttr() <<
"\n");
634 ModuleContext context(getOperation());
638 if (isa<SigExtractOp, SigArraySliceOp, SigArrayGetOp, SigStructExtractOp>(
640 context.traceProjection(op.getResult(0));
643 for (
auto *signal : context.rootSignals)
644 context.aggregateDrives(*signal);
647 context.pruner.eraseNow();
assert(baseType &&"element must be base type")
static unsigned getLength(Type type)
Determine the number of elements in a type.
static Block * getBodyBlock(FModuleLike mod)
create(array_value, low_index, ret_type)
create(elements, Type result_type=None)
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...