12#include "mlir/Pass/Pass.h" 
   13#include "mlir/Transforms/DialectConversion.h" 
   14#include "llvm/ADT/APInt.h" 
   18#define GEN_PASS_DEF_HWAGGREGATETOCOMB 
   19#include "circt/Dialect/HW/Passes.h.inc" 
   29template <
typename OpTy>
 
   34  matchAndRewrite(OpTy op, OpAdaptor adaptor,
 
   35                  ConversionPatternRewriter &rewriter)
 const override {
 
   36    rewriter.replaceOpWithNewOp<
comb::ConcatOp>(op, adaptor.getInputs());
 
   41struct HWAggregateConstantOpConversion
 
   45  static LogicalResult peelAttribute(Location loc, Attribute attr,
 
   46                                     ConversionPatternRewriter &rewriter,
 
   48    SmallVector<Attribute> worklist;
 
   49    worklist.push_back(attr);
 
   50    unsigned nextInsertion = intVal.getBitWidth();
 
   52    while (!worklist.empty()) {
 
   53      auto current = worklist.pop_back_val();
 
   54      if (
auto innerArray = dyn_cast<ArrayAttr>(current)) {
 
   55        for (
auto elem : 
llvm::reverse(innerArray))
 
   56          worklist.push_back(elem);
 
   60      if (
auto intAttr = dyn_cast<IntegerAttr>(current)) {
 
   61        auto chunk = intAttr.getValue();
 
   62        nextInsertion -= chunk.getBitWidth();
 
   63        intVal.insertBits(chunk, nextInsertion);
 
   74  matchAndRewrite(hw::AggregateConstantOp op, OpAdaptor adaptor,
 
   75                  ConversionPatternRewriter &rewriter)
 const override {
 
   77    SmallVector<Value> results;
 
   78    auto bitWidth = hw::getBitWidth(op.getType());
 
   79    assert(bitWidth >= 0 && 
"bit width must be known for constant");
 
   80    APInt intVal(bitWidth, 0);
 
   81    if (failed(peelAttribute(op.getLoc(), adaptor.getFieldsAttr(), rewriter,
 
   94                  ConversionPatternRewriter &rewriter)
 const override {
 
   95    SmallVector<Value> results;
 
   96    auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
 
   97    auto elemType = arrayType.getElementType();
 
   99    auto elemWidth = hw::getBitWidth(elemType);
 
  101      return rewriter.notifyMatchFailure(op.getLoc(), 
"unknown element width");
 
  103    auto lowered = adaptor.getInput();
 
  106          op.getLoc(), lowered, i * elemWidth, elemWidth));
 
  108    SmallVector<Value> bits;
 
  109    comb::extractBits(rewriter, op.getIndex(), bits);
 
  110    auto result = comb::constructMuxTree(rewriter, op.getLoc(), bits, results,
 
  113    rewriter.replaceOp(op, result);
 
  122  matchAndRewrite(hw::ArrayInjectOp op, OpAdaptor adaptor,
 
  123                  ConversionPatternRewriter &rewriter)
 const override {
 
  124    auto arrayType = cast<hw::ArrayType>(op.getInput().getType());
 
  125    auto elemType = arrayType.getElementType();
 
  127    auto elemWidth = hw::getBitWidth(elemType);
 
  129      return rewriter.notifyMatchFailure(op.getLoc(), 
"unknown element width");
 
  131    Location loc = op.getLoc();
 
  134    SmallVector<Value> originalElements;
 
  135    auto inputArray = adaptor.getInput();
 
  138          loc, inputArray, i * elemWidth, elemWidth));
 
  143    SmallVector<Value> arrayRows;
 
  145    for (
int injectIdx = 
numElements - 1; injectIdx >= 0; --injectIdx) {
 
  146      SmallVector<Value> rowElements;
 
  151      for (
int originalIdx = 
numElements - 1; originalIdx >= 0; --originalIdx) {
 
  152        if (originalIdx == injectIdx) {
 
  153          rowElements.push_back(adaptor.getElement());
 
  155          rowElements.push_back(originalElements[originalIdx]);
 
  161      arrayRows.push_back(row);
 
  173    rewriter.replaceOp(op, arrayGetOp);
 
  180class AggregateTypeConverter : 
public TypeConverter {
 
  182  AggregateTypeConverter() {
 
  183    addConversion([](Type type) -> Type { 
return type; });
 
  184    addConversion([](hw::ArrayType t) -> Type {
 
  185      return IntegerType::get(t.getContext(), hw::getBitWidth(t));
 
  187    addConversion([](hw::StructType t) -> Type {
 
  188      return IntegerType::get(t.getContext(), hw::getBitWidth(t));
 
  190    addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
 
  191                                mlir::ValueRange inputs,
 
  192                                mlir::Location loc) -> mlir::Value {
 
  193      if (inputs.size() != 1)
 
  200    addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type resultType,
 
  201                                mlir::ValueRange inputs,
 
  202                                mlir::Location loc) -> mlir::Value {
 
  203      if (inputs.size() != 1)
 
  214    RewritePatternSet &
patterns, AggregateTypeConverter &typeConverter) {
 
  215  patterns.add<HWArrayGetOpConversion,
 
  216               HWArrayCreateLikeOpConversion<hw::ArrayCreateOp>,
 
  217               HWArrayCreateLikeOpConversion<hw::ArrayConcatOp>,
 
  218               HWAggregateConstantOpConversion, HWArrayInjectOpConversion>(
 
  219      typeConverter, 
patterns.getContext());
 
 
  223struct HWAggregateToCombPass
 
  224    : 
public hw::impl::HWAggregateToCombBase<HWAggregateToCombPass> {
 
  225  void runOnOperation() 
override;
 
  226  using HWAggregateToCombBase<HWAggregateToCombPass>::HWAggregateToCombBase;
 
  230void HWAggregateToCombPass::runOnOperation() {
 
  231  ConversionTarget target(getContext());
 
  235                      hw::AggregateConstantOp, hw::ArrayInjectOp>();
 
  237  target.addLegalDialect<hw::HWDialect, comb::CombDialect>();
 
  239  RewritePatternSet 
patterns(&getContext());
 
  240  AggregateTypeConverter typeConverter;
 
  243  if (failed(mlir::applyPartialConversion(getOperation(), target,
 
  245    return signalPassFailure();
 
assert(baseType &&"element must be base type")
 
MlirType uint64_t numElements
 
static void populateHWAggregateToCombOpConversionPatterns(RewritePatternSet &patterns, AggregateTypeConverter &typeConverter)
 
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.