// CIRCT 18.0.0git
// LowerArcToLLVM.cpp
// (Source listing taken from the generated API documentation for this file.)
1 //===- LowerArcToLLVM.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "../PassDetail.h"
16 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
18 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
19 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
21 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
22 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/SCF/IR/SCF.h"
27 #include "mlir/IR/BuiltinDialect.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "arc-lower-to-llvm"
32 
33 using namespace mlir;
34 using namespace circt;
35 using namespace arc;
36 using namespace hw;
37 
38 //===----------------------------------------------------------------------===//
39 // Lowering Patterns
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 struct DefineOpLowering : public OpConversionPattern<arc::DefineOp> {
45  using OpConversionPattern::OpConversionPattern;
46  LogicalResult
47  matchAndRewrite(arc::DefineOp op, OpAdaptor adaptor,
48  ConversionPatternRewriter &rewriter) const final {
49  auto func = rewriter.create<mlir::func::FuncOp>(op.getLoc(), op.getName(),
50  op.getFunctionType());
51  func->setAttr(
52  "llvm.linkage",
53  LLVM::LinkageAttr::get(getContext(), LLVM::linkage::Linkage::Internal));
54  rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
55  rewriter.eraseOp(op);
56  return success();
57  }
58 };
59 
60 struct OutputOpLowering : public OpConversionPattern<arc::OutputOp> {
61  using OpConversionPattern::OpConversionPattern;
62  LogicalResult
63  matchAndRewrite(arc::OutputOp op, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const final {
65  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOutputs());
66  return success();
67  }
68 };
69 
70 struct CallOpLowering : public OpConversionPattern<arc::CallOp> {
71  using OpConversionPattern::OpConversionPattern;
72  LogicalResult
73  matchAndRewrite(arc::CallOp op, OpAdaptor adaptor,
74  ConversionPatternRewriter &rewriter) const final {
75  SmallVector<Type> newResultTypes;
76  if (failed(
77  typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
78  return failure();
79  rewriter.replaceOpWithNewOp<func::CallOp>(
80  op, newResultTypes, op.getArcAttr(), adaptor.getInputs());
81  return success();
82  }
83 };
84 
85 struct StateOpLowering : public OpConversionPattern<arc::StateOp> {
86  using OpConversionPattern::OpConversionPattern;
87  LogicalResult
88  matchAndRewrite(arc::StateOp op, OpAdaptor adaptor,
89  ConversionPatternRewriter &rewriter) const final {
90  SmallVector<Type> newResultTypes;
91  if (failed(
92  typeConverter->convertTypes(op.getResultTypes(), newResultTypes)))
93  return failure();
94  rewriter.replaceOpWithNewOp<func::CallOp>(
95  op, newResultTypes, op.getArcAttr(), adaptor.getInputs());
96  return success();
97  }
98 };
99 
100 struct AllocStorageOpLowering
101  : public OpConversionPattern<arc::AllocStorageOp> {
102  using OpConversionPattern::OpConversionPattern;
103  LogicalResult
104  matchAndRewrite(arc::AllocStorageOp op, OpAdaptor adaptor,
105  ConversionPatternRewriter &rewriter) const final {
106  auto type = typeConverter->convertType(op.getType());
107  if (!op.getOffset().has_value())
108  return failure();
109  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, type, adaptor.getInput(),
110  LLVM::GEPArg(*op.getOffset()));
111  return success();
112  }
113 };
114 
115 template <class ConcreteOp>
116 struct AllocStateLikeOpLowering : public OpConversionPattern<ConcreteOp> {
119  using OpAdaptor = typename ConcreteOp::Adaptor;
120 
121  LogicalResult
122  matchAndRewrite(ConcreteOp op, OpAdaptor adaptor,
123  ConversionPatternRewriter &rewriter) const final {
124  // Get a pointer to the correct offset in the storage.
125  auto offsetAttr = op->template getAttrOfType<IntegerAttr>("offset");
126  if (!offsetAttr)
127  return failure();
128  Value ptr = rewriter.create<LLVM::GEPOp>(
129  op->getLoc(), adaptor.getStorage().getType(), adaptor.getStorage(),
130  LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
131 
132  // Cast the raw storage pointer to a pointer of the state's actual type.
133  auto type = typeConverter->convertType(op.getType());
134  if (type != ptr.getType())
135  ptr = rewriter.create<LLVM::BitcastOp>(op->getLoc(), type, ptr);
136 
137  rewriter.replaceOp(op, ptr);
138  return success();
139  }
140 };
141 
142 struct StateReadOpLowering : public OpConversionPattern<arc::StateReadOp> {
143  using OpConversionPattern::OpConversionPattern;
144  LogicalResult
145  matchAndRewrite(arc::StateReadOp op, OpAdaptor adaptor,
146  ConversionPatternRewriter &rewriter) const final {
147  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, adaptor.getState());
148  return success();
149  }
150 };
151 
152 struct StateWriteOpLowering : public OpConversionPattern<arc::StateWriteOp> {
153  using OpConversionPattern::OpConversionPattern;
154  LogicalResult
155  matchAndRewrite(arc::StateWriteOp op, OpAdaptor adaptor,
156  ConversionPatternRewriter &rewriter) const final {
157  if (adaptor.getCondition()) {
158  rewriter.replaceOpWithNewOp<scf::IfOp>(
159  op, adaptor.getCondition(), [&](auto &builder, auto loc) {
160  builder.template create<LLVM::StoreOp>(loc, adaptor.getValue(),
161  adaptor.getState());
162  builder.template create<scf::YieldOp>(loc);
163  });
164  } else {
165  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(),
166  adaptor.getState());
167  }
168  return success();
169  }
170 };
171 
172 struct AllocMemoryOpLowering : public OpConversionPattern<arc::AllocMemoryOp> {
173  using OpConversionPattern::OpConversionPattern;
174  LogicalResult
175  matchAndRewrite(arc::AllocMemoryOp op, OpAdaptor adaptor,
176  ConversionPatternRewriter &rewriter) const final {
177  auto offsetAttr = op->getAttrOfType<IntegerAttr>("offset");
178  if (!offsetAttr)
179  return failure();
180  Value ptr = rewriter.create<LLVM::GEPOp>(
181  op.getLoc(), adaptor.getStorage().getType(), adaptor.getStorage(),
182  LLVM::GEPArg(offsetAttr.getValue().getZExtValue()));
183 
184  auto type = typeConverter->convertType(op.getType());
185  if (type != ptr.getType())
186  ptr = rewriter.create<LLVM::BitcastOp>(op.getLoc(), type, ptr);
187 
188  rewriter.replaceOp(op, ptr);
189  return success();
190  }
191 };
192 
193 struct StorageGetOpLowering : public OpConversionPattern<arc::StorageGetOp> {
194  using OpConversionPattern::OpConversionPattern;
195  LogicalResult
196  matchAndRewrite(arc::StorageGetOp op, OpAdaptor adaptor,
197  ConversionPatternRewriter &rewriter) const final {
198  Value offset = rewriter.create<LLVM::ConstantOp>(
199  op.getLoc(), rewriter.getI32Type(), op.getOffsetAttr());
200  Value ptr = rewriter.create<LLVM::GEPOp>(op.getLoc(),
201  adaptor.getStorage().getType(),
202  adaptor.getStorage(), offset);
203  auto type = typeConverter->convertType(op.getType());
204  if (type != ptr.getType())
205  ptr = rewriter.create<LLVM::BitcastOp>(op.getLoc(), type, ptr);
206  rewriter.replaceOp(op, ptr);
207  return success();
208  }
209 };
210 
211 struct MemoryAccess {
212  Value ptr;
213  Value withinBounds;
214 };
215 
216 static MemoryAccess prepareMemoryAccess(Location loc, Value memory,
217  Value address, MemoryType type,
218  ConversionPatternRewriter &rewriter) {
219  auto zextAddrType = rewriter.getIntegerType(
220  address.getType().cast<IntegerType>().getWidth() + 1);
221  Value addr = rewriter.create<LLVM::ZExtOp>(loc, zextAddrType, address);
222  Value addrLimit = rewriter.create<LLVM::ConstantOp>(
223  loc, zextAddrType, rewriter.getI32IntegerAttr(type.getNumWords()));
224  Value withinBounds = rewriter.create<LLVM::ICmpOp>(
225  loc, LLVM::ICmpPredicate::ult, addr, addrLimit);
226  auto ptrType = LLVM::LLVMPointerType::get(type.getWordType());
227  Value ptr =
228  rewriter.create<LLVM::GEPOp>(loc, ptrType, memory, ValueRange{addr});
229  return {ptr, withinBounds};
230 }
231 
232 struct MemoryReadOpLowering : public OpConversionPattern<arc::MemoryReadOp> {
233  using OpConversionPattern::OpConversionPattern;
234  LogicalResult
235  matchAndRewrite(arc::MemoryReadOp op, OpAdaptor adaptor,
236  ConversionPatternRewriter &rewriter) const final {
237  auto type = typeConverter->convertType(op.getType());
238  auto access = prepareMemoryAccess(
239  op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
240  op.getMemory().getType().cast<MemoryType>(), rewriter);
241 
242  // Only attempt to read the memory if the address is within bounds,
243  // otherwise produce a zero value.
244  rewriter.replaceOpWithNewOp<scf::IfOp>(
245  op, access.withinBounds,
246  [&](auto &builder, auto loc) {
247  Value loadOp = builder.template create<LLVM::LoadOp>(loc, access.ptr);
248  builder.template create<scf::YieldOp>(loc, loadOp);
249  },
250  [&](auto &builder, auto loc) {
251  Value zeroValue = builder.template create<LLVM::ConstantOp>(
252  loc, type, builder.getI64IntegerAttr(0));
253  builder.template create<scf::YieldOp>(loc, zeroValue);
254  });
255  return success();
256  }
257 };
258 
259 struct MemoryWriteOpLowering : public OpConversionPattern<arc::MemoryWriteOp> {
260  using OpConversionPattern::OpConversionPattern;
261  LogicalResult
262  matchAndRewrite(arc::MemoryWriteOp op, OpAdaptor adaptor,
263  ConversionPatternRewriter &rewriter) const final {
264  auto access = prepareMemoryAccess(
265  op.getLoc(), adaptor.getMemory(), adaptor.getAddress(),
266  op.getMemory().getType().cast<MemoryType>(), rewriter);
267  auto enable = access.withinBounds;
268  if (adaptor.getEnable())
269  enable = rewriter.create<LLVM::AndOp>(op.getLoc(), adaptor.getEnable(),
270  enable);
271 
272  // Only attempt to write the memory if the address is within bounds.
273  rewriter.replaceOpWithNewOp<scf::IfOp>(
274  op, enable, [&](auto &builder, auto loc) {
275  builder.template create<LLVM::StoreOp>(loc, adaptor.getData(),
276  access.ptr);
277  builder.template create<scf::YieldOp>(loc);
278  });
279  return success();
280  }
281 };
282 
283 /// A dummy lowering for clock gates to an AND gate.
284 struct ClockGateOpLowering : public OpConversionPattern<arc::ClockGateOp> {
285  using OpConversionPattern::OpConversionPattern;
286  LogicalResult
287  matchAndRewrite(arc::ClockGateOp op, OpAdaptor adaptor,
288  ConversionPatternRewriter &rewriter) const final {
289  rewriter.replaceOpWithNewOp<comb::AndOp>(op, adaptor.getInput(),
290  adaptor.getEnable(), true);
291  return success();
292  }
293 };
294 
295 struct ReturnOpLowering : public OpConversionPattern<func::ReturnOp> {
296  using OpConversionPattern::OpConversionPattern;
297  LogicalResult
298  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
299  ConversionPatternRewriter &rewriter) const override {
300  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
301  return success();
302  }
303 };
304 
305 struct FuncCallOpLowering : public OpConversionPattern<func::CallOp> {
306  using OpConversionPattern::OpConversionPattern;
307  LogicalResult
308  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
309  ConversionPatternRewriter &rewriter) const override {
310  SmallVector<Type> newResultTypes;
311  if (failed(
312  typeConverter->convertTypes(op->getResultTypes(), newResultTypes)))
313  return failure();
314  rewriter.replaceOpWithNewOp<func::CallOp>(
315  op, op.getCalleeAttr(), newResultTypes, adaptor.getOperands());
316  return success();
317  }
318 };
319 
320 struct ZeroCountOpLowering : public OpConversionPattern<arc::ZeroCountOp> {
321  using OpConversionPattern::OpConversionPattern;
322  LogicalResult
323  matchAndRewrite(arc::ZeroCountOp op, OpAdaptor adaptor,
324  ConversionPatternRewriter &rewriter) const override {
325  // Use poison when input is zero.
326  IntegerAttr isZeroPoison = rewriter.getBoolAttr(true);
327 
328  if (op.getPredicate() == arc::ZeroCountPredicate::leading) {
329  rewriter.replaceOpWithNewOp<LLVM::CountLeadingZerosOp>(
330  op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
331  return success();
332  }
333 
334  rewriter.replaceOpWithNewOp<LLVM::CountTrailingZerosOp>(
335  op, adaptor.getInput().getType(), adaptor.getInput(), isZeroPoison);
336  return success();
337  }
338 };
339 
340 } // namespace
341 
342 static bool isArcType(Type type) {
343  return type.isa<StorageType>() || type.isa<MemoryType>() ||
344  type.isa<StateType>();
345 }
346 
347 static bool hasArcType(TypeRange types) {
348  return llvm::any_of(types, isArcType);
349 }
350 
351 static bool hasArcType(ValueRange values) {
352  return hasArcType(values.getTypes());
353 }
354 
355 template <typename Op>
356 static void addGenericLegality(ConversionTarget &target) {
357  target.addDynamicallyLegalOp<Op>([](Op op) {
358  return !hasArcType(op->getOperands()) && !hasArcType(op->getResults());
359  });
360 }
361 
362 static void populateLegality(ConversionTarget &target) {
363  target.addLegalDialect<mlir::BuiltinDialect>();
364  target.addLegalDialect<hw::HWDialect>();
365  target.addLegalDialect<comb::CombDialect>();
366  target.addLegalDialect<func::FuncDialect>();
367  target.addLegalDialect<scf::SCFDialect>();
368  target.addLegalDialect<LLVM::LLVMDialect>();
369 
370  target.addIllegalOp<arc::DefineOp>();
371  target.addIllegalOp<arc::OutputOp>();
372  target.addIllegalOp<arc::StateOp>();
373  target.addIllegalOp<arc::ClockTreeOp>();
374  target.addIllegalOp<arc::PassThroughOp>();
375 
376  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
377  auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
378  return hasArcType(block.getArguments());
379  });
380  auto resultsConverted = !hasArcType(op.getResultTypes());
381  return argsConverted && resultsConverted;
382  });
383  addGenericLegality<func::ReturnOp>(target);
384  addGenericLegality<func::CallOp>(target);
385 }
386 
387 static void populateTypeConversion(TypeConverter &typeConverter) {
388  typeConverter.addConversion([&](StorageType type) {
389  return LLVM::LLVMPointerType::get(IntegerType::get(type.getContext(), 8));
390  });
391  typeConverter.addConversion([&](MemoryType type) {
393  IntegerType::get(type.getContext(), type.getStride() * 8));
394  });
395  typeConverter.addConversion([&](StateType type) {
397  typeConverter.convertType(type.getType()));
398  });
399  typeConverter.addConversion([](hw::ArrayType type) { return type; });
400  typeConverter.addConversion([](mlir::IntegerType type) { return type; });
401 }
402 
403 static void populateOpConversion(RewritePatternSet &patterns,
404  TypeConverter &typeConverter) {
405  auto *context = patterns.getContext();
406  // clang-format off
407  patterns.add<
408  AllocMemoryOpLowering,
409  AllocStateLikeOpLowering<arc::AllocStateOp>,
410  AllocStateLikeOpLowering<arc::RootInputOp>,
411  AllocStateLikeOpLowering<arc::RootOutputOp>,
412  AllocStorageOpLowering,
413  CallOpLowering,
414  ClockGateOpLowering,
415  DefineOpLowering,
416  MemoryReadOpLowering,
417  MemoryWriteOpLowering,
418  OutputOpLowering,
419  FuncCallOpLowering,
420  ReturnOpLowering,
421  StateOpLowering,
422  StateReadOpLowering,
423  StateWriteOpLowering,
424  StorageGetOpLowering,
425  ZeroCountOpLowering
426  >(typeConverter, context);
427  // clang-format on
428 
429  mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
430  patterns, typeConverter);
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // Pass Implementation
435 //===----------------------------------------------------------------------===//
436 
437 namespace {
438 struct LowerArcToLLVMPass : public LowerArcToLLVMBase<LowerArcToLLVMPass> {
439  void runOnOperation() override;
440  LogicalResult lowerToMLIR();
441  LogicalResult lowerArcToLLVM();
442 };
443 } // namespace
444 
445 void LowerArcToLLVMPass::runOnOperation() {
446  // Remove the models since we only care about the clock functions at this
447  // point.
448  // NOTE: In the future we may want to have an earlier pass lower the model
449  // into a separate `*_eval` function that checks for rising edges on clocks
450  // and then calls the appropriate function. At that point we won't have to
451  // delete models here anymore.
452  for (auto op : llvm::make_early_inc_range(getOperation().getOps<ModelOp>()))
453  op.erase();
454 
455  if (failed(lowerToMLIR()))
456  return signalPassFailure();
457 
458  if (failed(lowerArcToLLVM()))
459  return signalPassFailure();
460 }
461 
462 /// Perform the lowering to Func and SCF.
463 LogicalResult LowerArcToLLVMPass::lowerToMLIR() {
464  LLVM_DEBUG(llvm::dbgs() << "Lowering arcs to Func/SCF dialects\n");
465  ConversionTarget target(getContext());
466  TypeConverter converter;
467  RewritePatternSet patterns(&getContext());
468  populateLegality(target);
469  populateTypeConversion(converter);
470  populateOpConversion(patterns, converter);
471  return applyPartialConversion(getOperation(), target, std::move(patterns));
472 }
473 
474 /// Perform lowering to LLVM.
475 LogicalResult LowerArcToLLVMPass::lowerArcToLLVM() {
476  LLVM_DEBUG(llvm::dbgs() << "Lowering to LLVM dialect\n");
477 
478  Namespace globals;
479  SymbolCache cache;
480  cache.addDefinitions(getOperation());
481  globals.add(cache);
482 
483  LLVMConversionTarget target(getContext());
484  LLVMTypeConverter converter(&getContext());
485  RewritePatternSet patterns(&getContext());
486  target.addLegalOp<mlir::ModuleOp>();
487  target.addIllegalOp<arc::ModelOp>();
488  populateSCFToControlFlowConversionPatterns(patterns);
489  populateFuncToLLVMConversionPatterns(converter, patterns);
490  cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
491 
492  DenseMap<std::pair<Type, ArrayAttr>, LLVM::GlobalOp> constAggregateGlobalsMap;
493  populateHWToLLVMConversionPatterns(converter, patterns, globals,
494  constAggregateGlobalsMap);
497  arith::populateArithToLLVMConversionPatterns(converter, patterns);
498 
499  return applyFullConversion(getOperation(), target, std::move(patterns));
500 }
501 
502 std::unique_ptr<OperationPass<ModuleOp>> circt::createLowerArcToLLVMPass() {
503  return std::make_unique<LowerArcToLLVMPass>();
504 }
// ---------------------------------------------------------------------------
// Cross-reference index, extracted from the generated documentation page.
// Some descriptions are truncated in the original page. Two entries
// ("Direction get(bool isOutput)" -> CalyxOps.cpp, and "Definition: hw.py:1")
// point at unrelated symbols that merely share a name with something used
// above.
// ---------------------------------------------------------------------------
// static void populateLegality(ConversionTarget &target)
// static bool hasArcType(TypeRange types)
// static void populateOpConversion(RewritePatternSet &patterns, TypeConverter &typeConverter)
// static void addGenericLegality(ConversionTarget &target)
// static void populateTypeConversion(TypeConverter &typeConverter)
// static bool isArcType(Type type)
// Builder builder
// A namespace that is used to store existing names and generate new names in some scope within the IR.
// Definition: Namespace.h:29
// void add(SymbolCache &symCache)
// SymbolCache initializer; initialize from every key that is convertible to a StringAttr in the SymbolC...
// Definition: Namespace.h:43
// void addDefinitions(mlir::Operation *top)
// Populate the symbol cache with all symbol-defining operations within the 'top' operation.
// Definition: SymCache.cpp:23
// Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
// Definition: SymCache.h:85
// Direction get(bool isOutput)
// Returns an output direction if isOutput is true, otherwise returns an input direction.
// Definition: CalyxOps.cpp:53
// This file defines an intermediate representation for circuits acting as an abstraction for constraint...
// void populateHWToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, Namespace &globals, DenseMap< std::pair< Type, ArrayAttr >, mlir::LLVM::GlobalOp > &constAggregateGlobalsMap)
// Get the HW to LLVM conversion patterns.
// void populateCombToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns)
// Get the Comb to LLVM conversion patterns.
// void populateHWToLLVMTypeConversions(mlir::LLVMTypeConverter &converter)
// Get the HW to LLVM type conversions.
// std::unique_ptr< OperationPass< ModuleOp > > createLowerArcToLLVMPass()
// Definition: hw.py:1
// mlir::raw_indented_ostream & dbgs()
// Definition: Utility.h:28