CIRCT  18.0.0git
HWArithToHW.cpp
Go to the documentation of this file.
1 //===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main HWArith to HW Lowering Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "../PassDetail.h"
17 #include "circt/Dialect/HW/HWOps.h"
20 #include "circt/Dialect/SV/SVOps.h"
22 
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace circt;
28 using namespace hwarith;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 // Function for setting the 'sv.namehint' attribute on 'newOp' based on any
35 // currently existing 'sv.namehint' attached to the source operation of 'value'.
36 // The user provides a callback which returns a new namehint based on the old
37 // namehint.
38 static void
39 improveNamehint(Value oldValue, Operation *newOp,
40  llvm::function_ref<std::string(StringRef)> namehintCallback) {
41  if (auto *sourceOp = oldValue.getDefiningOp()) {
42  if (auto namehint =
43  sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
44  auto newNamehint = namehintCallback(namehint.strref());
45  newOp->setAttr("sv.namehint",
46  StringAttr::get(oldValue.getContext(), newNamehint));
47  }
48  }
49 }
50 
51 // Extract a bit range, specified via start bit and width, from a given value.
52 static Value extractBits(OpBuilder &builder, Location loc, Value value,
53  unsigned startBit, unsigned bitWidth) {
54  SmallVector<Value, 1> result;
55  builder.createOrFold<comb::ExtractOp>(result, loc, value, startBit, bitWidth);
56  Value extractedValue = result[0];
57  if (extractedValue != value) {
58  // only change namehint if a new operation was created.
59  auto *newOp = extractedValue.getDefiningOp();
60  improveNamehint(value, newOp, [&](StringRef oldNamehint) {
61  return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
62  std::to_string(startBit + bitWidth))
63  .str();
64  });
65  }
66  return extractedValue;
67 }
68 
69 // Perform the specified bit-extension (either sign- or zero-extension) for a
70 // given value to a desired target width.
71 static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
72  unsigned targetWidth, bool signExtension) {
73  unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
74  unsigned extensionLength = targetWidth - sourceWidth;
75 
76  if (extensionLength == 0)
77  return value;
78 
79  Value extensionBits;
80  // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
81  if (signExtension) {
82  // Sign extension
83  Value highBit = extractBits(builder, loc, value,
84  /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
85  SmallVector<Value, 1> result;
86  builder.createOrFold<comb::ReplicateOp>(result, loc, highBit,
87  extensionLength);
88  extensionBits = result[0];
89  } else {
90  // Zero extension
91  extensionBits = builder
92  .create<hw::ConstantOp>(
93  loc, builder.getIntegerType(extensionLength), 0)
94  ->getOpResult(0);
95  }
96 
97  auto extOp = builder.create<comb::ConcatOp>(loc, extensionBits, value);
98  improveNamehint(value, extOp, [&](StringRef oldNamehint) {
99  return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
100  std::to_string(targetWidth))
101  .str();
102  });
103 
104  return extOp->getOpResult(0);
105 }
106 
107 static bool isSignednessType(Type type) {
108  auto match =
109  llvm::TypeSwitch<Type, bool>(type)
110  .Case<IntegerType>([](auto type) { return !type.isSignless(); })
111  .Case<hw::ArrayType>(
112  [](auto type) { return isSignednessType(type.getElementType()); })
113  .Case<hw::StructType>([](auto type) {
114  return llvm::any_of(type.getElements(), [](auto element) {
115  return isSignednessType(element.type);
116  });
117  })
118  .Case<hw::InOutType>(
119  [](auto type) { return isSignednessType(type.getElementType()); })
120  .Default([](auto type) { return false; });
121 
122  return match;
123 }
124 
125 static bool isSignednessAttr(Attribute attr) {
126  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
127  return isSignednessType(typeAttr.getValue());
128  return false;
129 }
130 
131 /// Returns true if the given `op` is considered as legal for HWArith
132 /// conversion.
133 static bool isLegalOp(Operation *op) {
134  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
135  return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
136  llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
137  llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
139  }
140 
141  if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
142  return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
143  llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
145  }
146 
147  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
148  return attr.getValue();
149  });
150 
151  bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
152  bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
153  bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
154  return operandsOK && resultsOK && attrsOK;
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // Conversion patterns
159 //===----------------------------------------------------------------------===//
160 
161 namespace {
162 struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
164 
165  LogicalResult
166  matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
167  ConversionPatternRewriter &rewriter) const override {
168  rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
169  constOp.getConstantValue());
170  return success();
171  }
172 };
173 struct DivOpLowering : public OpConversionPattern<DivOp> {
175 
176  LogicalResult
177  matchAndRewrite(DivOp op, OpAdaptor adaptor,
178  ConversionPatternRewriter &rewriter) const override {
179  auto loc = op.getLoc();
180  auto isLhsTypeSigned =
181  op.getOperand(0).getType().template cast<IntegerType>().isSigned();
182  auto rhsType = op.getOperand(1).getType().template cast<IntegerType>();
183  auto targetType = op.getResult().getType().template cast<IntegerType>();
184 
185  // comb.div* needs identical bitwidths for its operands and its result.
186  // Hence, we need to calculate the minimal bitwidth that can be used to
187  // represent the result as well as the operands without precision or sign
188  // loss. The target size only depends on LHS and already handles the edge
189  // cases where the bitwidth needs to be increased by 1. Thus, the targetType
190  // is good enough for both the result as well as LHS.
191  // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
192  // about to perform a signed division, then we need one additional bit to
193  // avoid misinterpretation of RHS as a signed value!
194  bool signedDivision = targetType.isSigned();
195  unsigned extendSize = std::max(
196  targetType.getWidth(),
197  rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
198 
199  // Extend the operands
200  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
201  extendSize, isLhsTypeSigned);
202  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
203  extendSize, rhsType.isSigned());
204 
205  Value divResult;
206  if (signedDivision)
207  divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue, false)
208  ->getOpResult(0);
209  else
210  divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue, false)
211  ->getOpResult(0);
212 
213  // Carry over any attributes from the original div op.
214  divResult.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
215 
216  // finally truncate back to the expected result size!
217  Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
218  /*bitWidth=*/targetType.getWidth());
219  rewriter.replaceOp(op, truncateResult);
220 
221  return success();
222  }
223 };
224 } // namespace
225 
226 namespace {
227 struct CastOpLowering : public OpConversionPattern<CastOp> {
229 
230  LogicalResult
231  matchAndRewrite(CastOp op, OpAdaptor adaptor,
232  ConversionPatternRewriter &rewriter) const override {
233  auto sourceType = op.getIn().getType().cast<IntegerType>();
234  auto sourceWidth = sourceType.getWidth();
235  bool isSourceTypeSigned = sourceType.isSigned();
236  auto targetWidth = op.getOut().getType().cast<IntegerType>().getWidth();
237 
238  Value replaceValue;
239  if (sourceWidth == targetWidth) {
240  // the width does not change, we are done here and can directly use the
241  // lowering input value
242  replaceValue = adaptor.getIn();
243  } else if (sourceWidth < targetWidth) {
244  // bit extensions needed, the type of extension required is determined by
245  // the source type only!
246  replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
247  targetWidth, isSourceTypeSigned);
248  } else {
249  // bit truncation needed
250  replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
251  /*startBit=*/0, /*bitWidth=*/targetWidth);
252  }
253  rewriter.replaceOp(op, replaceValue);
254 
255  return success();
256  }
257 };
258 } // namespace
259 
260 namespace {
261 
262 // Utility lowering function that maps a hwarith::ICmpPredicate predicate and
263 // the information whether the comparison contains signed values to the
264 // corresponding comb::ICmpPredicate.
265 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
266 #define _CREATE_HWARITH_ICMP_CASE(x) \
267  case ICmpPredicate::x: \
268  return isSigned ? comb::ICmpPredicate::s##x : comb::ICmpPredicate::u##x
269 
270  switch (pred) {
271  case ICmpPredicate::eq:
272  return comb::ICmpPredicate::eq;
273 
274  case ICmpPredicate::ne:
275  return comb::ICmpPredicate::ne;
276 
281  }
282 
283 #undef _CREATE_HWARITH_ICMP_CASE
284 
285  llvm_unreachable(
286  "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
287  return comb::ICmpPredicate::eq;
288 }
289 
290 struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
292 
293  LogicalResult
294  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296  auto lhsType = op.getLhs().getType().cast<IntegerType>();
297  auto rhsType = op.getRhs().getType().cast<IntegerType>();
298  IntegerType::SignednessSemantics cmpSignedness;
299  const unsigned cmpWidth =
300  inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
301 
302  ICmpPredicate pred = op.getPredicate();
303  comb::ICmpPredicate combPred = lowerPredicate(
304  pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
305 
306  const auto loc = op.getLoc();
307  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
308  lhsType.isSigned());
309  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
310  rhsType.isSigned());
311 
312  auto newOp = rewriter.replaceOpWithNewOp<comb::ICmpOp>(
313  op, combPred, lhsValue, rhsValue, false);
314  newOp->setDialectAttrs(op->getDialectAttrs());
315 
316  return success();
317  }
318 };
319 
320 template <class BinOp, class ReplaceOp>
321 struct BinaryOpLowering : public OpConversionPattern<BinOp> {
323  using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
324 
325  LogicalResult
326  matchAndRewrite(BinOp op, OpAdaptor adaptor,
327  ConversionPatternRewriter &rewriter) const override {
328  auto loc = op.getLoc();
329  auto isLhsTypeSigned =
330  op.getOperand(0).getType().template cast<IntegerType>().isSigned();
331  auto isRhsTypeSigned =
332  op.getOperand(1).getType().template cast<IntegerType>().isSigned();
333  auto targetWidth =
334  op.getResult().getType().template cast<IntegerType>().getWidth();
335 
336  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
337  targetWidth, isLhsTypeSigned);
338  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
339  targetWidth, isRhsTypeSigned);
340  auto newOp =
341  rewriter.replaceOpWithNewOp<ReplaceOp>(op, lhsValue, rhsValue, false);
342  newOp->setDialectAttrs(op->getDialectAttrs());
343  return success();
344  }
345 };
346 
347 } // namespace
348 
349 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
350  auto it = conversionCache.find(type);
351  if (it != conversionCache.end())
352  return it->second.type;
353 
354  auto convertedType =
355  llvm::TypeSwitch<Type, Type>(type)
356  .Case<IntegerType>([](auto type) {
357  if (type.isSignless())
358  return type;
359  return IntegerType::get(type.getContext(), type.getWidth());
360  })
361  .Case<hw::ArrayType>([this](auto type) {
362  return hw::ArrayType::get(removeSignedness(type.getElementType()),
363  type.getNumElements());
364  })
365  .Case<hw::StructType>([this](auto type) {
366  // Recursively convert each element.
367  llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
368  for (auto element : type.getElements()) {
369  convertedElements.push_back(
370  {element.name, removeSignedness(element.type)});
371  }
372  return hw::StructType::get(type.getContext(), convertedElements);
373  })
374  .Case<hw::InOutType>([this](auto type) {
375  return hw::InOutType::get(removeSignedness(type.getElementType()));
376  })
377  .Default([](auto type) { return type; });
378 
379  return convertedType;
380 }
381 
382 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
383  // Pass any type through the signedness remover.
384  addConversion([this](Type type) { return removeSignedness(type); });
385 
386  addTargetMaterialization(
387  [&](mlir::OpBuilder &builder, mlir::Type resultType,
388  mlir::ValueRange inputs,
389  mlir::Location loc) -> std::optional<mlir::Value> {
390  if (inputs.size() != 1)
391  return std::nullopt;
392  return inputs[0];
393  });
394 
395  addSourceMaterialization(
396  [&](mlir::OpBuilder &builder, mlir::Type resultType,
397  mlir::ValueRange inputs,
398  mlir::Location loc) -> std::optional<mlir::Value> {
399  if (inputs.size() != 1)
400  return std::nullopt;
401  return inputs[0];
402  });
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // Pass driver
407 //===----------------------------------------------------------------------===//
408 
410  HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
411  patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
412  BinaryOpLowering<AddOp, comb::AddOp>,
413  BinaryOpLowering<SubOp, comb::SubOp>,
414  BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
415  typeConverter, patterns.getContext());
416 }
417 
418 namespace {
419 
420 class HWArithToHWPass : public HWArithToHWBase<HWArithToHWPass> {
421 public:
422  void runOnOperation() override {
423  ModuleOp module = getOperation();
424 
425  ConversionTarget target(getContext());
426  target.markUnknownOpDynamicallyLegal(isLegalOp);
427  RewritePatternSet patterns(&getContext());
428  HWArithToHWTypeConverter typeConverter;
429  target.addIllegalDialect<HWArithDialect>();
430 
431  // Add HWArith-specific conversion patterns.
433 
434  // ALL other operations are converted via the TypeConversionPattern which
435  // will replace an operation to an identical operation with replaced
436  // result types and operands.
437  patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
438 
439  // Apply a full conversion - all operations must either be legal, be caught
440  // by one of the HWArith patterns or be converted by the
441  // TypeConversionPattern.
442  if (failed(applyFullConversion(module, target, std::move(patterns))))
443  return signalPassFailure();
444  }
445 };
446 } // namespace
447 
448 //===----------------------------------------------------------------------===//
449 // Pass initialization
450 //===----------------------------------------------------------------------===//
451 
452 std::unique_ptr<Pass> circt::createHWArithToHWPass() {
453  return std::make_unique<HWArithToHWPass>();
454 }
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
#define _CREATE_HWARITH_ICMP_CASE(x)
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
Definition: HWArithToHW.cpp:39
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
Definition: HWArithToHW.cpp:52
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
Definition: HWArithToHW.cpp:71
llvm::SmallVector< StringAttr > inputs
Builder builder
A helper type converter class that automatically populates the relevant materializations and type con...
Definition: HWArithToHW.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
Definition: HWArithOps.cpp:194
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...