CIRCT  19.0.0git
HWArithToHW.cpp
Go to the documentation of this file.
1 //===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main HWArith to HW Lowering Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "../PassDetail.h"
17 #include "circt/Dialect/HW/HWOps.h"
20 #include "circt/Dialect/SV/SVOps.h"
22 
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace circt;
28 using namespace hwarith;
29 
30 //===----------------------------------------------------------------------===//
31 // Utility functions
32 //===----------------------------------------------------------------------===//
33 
34 // Function for setting the 'sv.namehint' attribute on 'newOp' based on any
35 // currently existing 'sv.namehint' attached to the source operation of 'value'.
36 // The user provides a callback which returns a new namehint based on the old
37 // namehint.
38 static void
39 improveNamehint(Value oldValue, Operation *newOp,
40  llvm::function_ref<std::string(StringRef)> namehintCallback) {
41  if (auto *sourceOp = oldValue.getDefiningOp()) {
42  if (auto namehint =
43  sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
44  auto newNamehint = namehintCallback(namehint.strref());
45  newOp->setAttr("sv.namehint",
46  StringAttr::get(oldValue.getContext(), newNamehint));
47  }
48  }
49 }
50 
51 // Extract a bit range, specified via start bit and width, from a given value.
52 static Value extractBits(OpBuilder &builder, Location loc, Value value,
53  unsigned startBit, unsigned bitWidth) {
54  Value extractedValue =
55  builder.createOrFold<comb::ExtractOp>(loc, value, startBit, bitWidth);
56  Operation *definingOp = extractedValue.getDefiningOp();
57  if (extractedValue != value && definingOp) {
58  // only change namehint if a new operation was created.
59  improveNamehint(value, definingOp, [&](StringRef oldNamehint) {
60  return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
61  std::to_string(startBit + bitWidth))
62  .str();
63  });
64  }
65  return extractedValue;
66 }
67 
68 // Perform the specified bit-extension (either sign- or zero-extension) for a
69 // given value to a desired target width.
70 static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
71  unsigned targetWidth, bool signExtension) {
72  unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
73  unsigned extensionLength = targetWidth - sourceWidth;
74 
75  if (extensionLength == 0)
76  return value;
77 
78  Value extensionBits;
79  // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
80  if (signExtension) {
81  // Sign extension
82  Value highBit = extractBits(builder, loc, value,
83  /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
84  extensionBits =
85  builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
86  } else {
87  // Zero extension
88  extensionBits = builder
89  .create<hw::ConstantOp>(
90  loc, builder.getIntegerType(extensionLength), 0)
91  ->getOpResult(0);
92  }
93 
94  auto extOp = builder.create<comb::ConcatOp>(loc, extensionBits, value);
95  improveNamehint(value, extOp, [&](StringRef oldNamehint) {
96  return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
97  std::to_string(targetWidth))
98  .str();
99  });
100 
101  return extOp->getOpResult(0);
102 }
103 
104 static bool isSignednessType(Type type) {
105  auto match =
106  llvm::TypeSwitch<Type, bool>(type)
107  .Case<IntegerType>([](auto type) { return !type.isSignless(); })
108  .Case<hw::ArrayType>(
109  [](auto type) { return isSignednessType(type.getElementType()); })
110  .Case<hw::StructType>([](auto type) {
111  return llvm::any_of(type.getElements(), [](auto element) {
112  return isSignednessType(element.type);
113  });
114  })
115  .Case<hw::InOutType>(
116  [](auto type) { return isSignednessType(type.getElementType()); })
117  .Default([](auto type) { return false; });
118 
119  return match;
120 }
121 
122 static bool isSignednessAttr(Attribute attr) {
123  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
124  return isSignednessType(typeAttr.getValue());
125  return false;
126 }
127 
128 /// Returns true if the given `op` is considered as legal for HWArith
129 /// conversion.
130 static bool isLegalOp(Operation *op) {
131  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
132  return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
133  llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
134  llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
136  }
137 
138  if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
139  return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
140  llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
142  }
143 
144  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
145  return attr.getValue();
146  });
147 
148  bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
149  bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
150  bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
151  return operandsOK && resultsOK && attrsOK;
152 }
153 
154 //===----------------------------------------------------------------------===//
155 // Conversion patterns
156 //===----------------------------------------------------------------------===//
157 
158 namespace {
159 struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
161 
162  LogicalResult
163  matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
164  ConversionPatternRewriter &rewriter) const override {
165  rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
166  constOp.getConstantValue());
167  return success();
168  }
169 };
170 struct DivOpLowering : public OpConversionPattern<DivOp> {
172 
173  LogicalResult
174  matchAndRewrite(DivOp op, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  auto loc = op.getLoc();
177  auto isLhsTypeSigned =
178  op.getOperand(0).getType().template cast<IntegerType>().isSigned();
179  auto rhsType = op.getOperand(1).getType().template cast<IntegerType>();
180  auto targetType = op.getResult().getType().template cast<IntegerType>();
181 
182  // comb.div* needs identical bitwidths for its operands and its result.
183  // Hence, we need to calculate the minimal bitwidth that can be used to
184  // represent the result as well as the operands without precision or sign
185  // loss. The target size only depends on LHS and already handles the edge
186  // cases where the bitwidth needs to be increased by 1. Thus, the targetType
187  // is good enough for both the result as well as LHS.
188  // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
189  // about to perform a signed division, then we need one additional bit to
190  // avoid misinterpretation of RHS as a signed value!
191  bool signedDivision = targetType.isSigned();
192  unsigned extendSize = std::max(
193  targetType.getWidth(),
194  rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
195 
196  // Extend the operands
197  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
198  extendSize, isLhsTypeSigned);
199  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
200  extendSize, rhsType.isSigned());
201 
202  Value divResult;
203  if (signedDivision)
204  divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue, false)
205  ->getOpResult(0);
206  else
207  divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue, false)
208  ->getOpResult(0);
209 
210  // Carry over any attributes from the original div op.
211  divResult.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
212 
213  // finally truncate back to the expected result size!
214  Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
215  /*bitWidth=*/targetType.getWidth());
216  rewriter.replaceOp(op, truncateResult);
217 
218  return success();
219  }
220 };
221 } // namespace
222 
223 namespace {
224 struct CastOpLowering : public OpConversionPattern<CastOp> {
226 
227  LogicalResult
228  matchAndRewrite(CastOp op, OpAdaptor adaptor,
229  ConversionPatternRewriter &rewriter) const override {
230  auto sourceType = op.getIn().getType().cast<IntegerType>();
231  auto sourceWidth = sourceType.getWidth();
232  bool isSourceTypeSigned = sourceType.isSigned();
233  auto targetWidth = op.getOut().getType().cast<IntegerType>().getWidth();
234 
235  Value replaceValue;
236  if (sourceWidth == targetWidth) {
237  // the width does not change, we are done here and can directly use the
238  // lowering input value
239  replaceValue = adaptor.getIn();
240  } else if (sourceWidth < targetWidth) {
241  // bit extensions needed, the type of extension required is determined by
242  // the source type only!
243  replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
244  targetWidth, isSourceTypeSigned);
245  } else {
246  // bit truncation needed
247  replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
248  /*startBit=*/0, /*bitWidth=*/targetWidth);
249  }
250  rewriter.replaceOp(op, replaceValue);
251 
252  return success();
253  }
254 };
255 } // namespace
256 
257 namespace {
258 
259 // Utility lowering function that maps a hwarith::ICmpPredicate predicate and
260 // the information whether the comparison contains signed values to the
261 // corresponding comb::ICmpPredicate.
262 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
263  switch (pred) {
264  case ICmpPredicate::eq:
265  return comb::ICmpPredicate::eq;
266  case ICmpPredicate::ne:
267  return comb::ICmpPredicate::ne;
268  case ICmpPredicate::lt:
269  return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
270  case ICmpPredicate::ge:
271  return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
272  case ICmpPredicate::le:
273  return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
274  case ICmpPredicate::gt:
275  return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
276  }
277 
278  llvm_unreachable(
279  "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
280  return comb::ICmpPredicate::eq;
281 }
282 
283 struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
285 
286  LogicalResult
287  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
288  ConversionPatternRewriter &rewriter) const override {
289  auto lhsType = op.getLhs().getType().cast<IntegerType>();
290  auto rhsType = op.getRhs().getType().cast<IntegerType>();
291  IntegerType::SignednessSemantics cmpSignedness;
292  const unsigned cmpWidth =
293  inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
294 
295  ICmpPredicate pred = op.getPredicate();
296  comb::ICmpPredicate combPred = lowerPredicate(
297  pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
298 
299  const auto loc = op.getLoc();
300  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
301  lhsType.isSigned());
302  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
303  rhsType.isSigned());
304 
305  auto newOp = rewriter.replaceOpWithNewOp<comb::ICmpOp>(
306  op, combPred, lhsValue, rhsValue, false);
307  newOp->setDialectAttrs(op->getDialectAttrs());
308 
309  return success();
310  }
311 };
312 
313 template <class BinOp, class ReplaceOp>
314 struct BinaryOpLowering : public OpConversionPattern<BinOp> {
316  using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
317 
318  LogicalResult
319  matchAndRewrite(BinOp op, OpAdaptor adaptor,
320  ConversionPatternRewriter &rewriter) const override {
321  auto loc = op.getLoc();
322  auto isLhsTypeSigned =
323  op.getOperand(0).getType().template cast<IntegerType>().isSigned();
324  auto isRhsTypeSigned =
325  op.getOperand(1).getType().template cast<IntegerType>().isSigned();
326  auto targetWidth =
327  op.getResult().getType().template cast<IntegerType>().getWidth();
328 
329  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
330  targetWidth, isLhsTypeSigned);
331  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
332  targetWidth, isRhsTypeSigned);
333  auto newOp =
334  rewriter.replaceOpWithNewOp<ReplaceOp>(op, lhsValue, rhsValue, false);
335  newOp->setDialectAttrs(op->getDialectAttrs());
336  return success();
337  }
338 };
339 
340 } // namespace
341 
342 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
343  auto it = conversionCache.find(type);
344  if (it != conversionCache.end())
345  return it->second.type;
346 
347  auto convertedType =
348  llvm::TypeSwitch<Type, Type>(type)
349  .Case<IntegerType>([](auto type) {
350  if (type.isSignless())
351  return type;
352  return IntegerType::get(type.getContext(), type.getWidth());
353  })
354  .Case<hw::ArrayType>([this](auto type) {
355  return hw::ArrayType::get(removeSignedness(type.getElementType()),
356  type.getNumElements());
357  })
358  .Case<hw::StructType>([this](auto type) {
359  // Recursively convert each element.
360  llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
361  for (auto element : type.getElements()) {
362  convertedElements.push_back(
363  {element.name, removeSignedness(element.type)});
364  }
365  return hw::StructType::get(type.getContext(), convertedElements);
366  })
367  .Case<hw::InOutType>([this](auto type) {
368  return hw::InOutType::get(removeSignedness(type.getElementType()));
369  })
370  .Default([](auto type) { return type; });
371 
372  return convertedType;
373 }
374 
375 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
376  // Pass any type through the signedness remover.
377  addConversion([this](Type type) { return removeSignedness(type); });
378 
379  addTargetMaterialization(
380  [&](mlir::OpBuilder &builder, mlir::Type resultType,
381  mlir::ValueRange inputs,
382  mlir::Location loc) -> std::optional<mlir::Value> {
383  if (inputs.size() != 1)
384  return std::nullopt;
385  return inputs[0];
386  });
387 
388  addSourceMaterialization(
389  [&](mlir::OpBuilder &builder, mlir::Type resultType,
390  mlir::ValueRange inputs,
391  mlir::Location loc) -> std::optional<mlir::Value> {
392  if (inputs.size() != 1)
393  return std::nullopt;
394  return inputs[0];
395  });
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // Pass driver
400 //===----------------------------------------------------------------------===//
401 
403  HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
404  patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
405  BinaryOpLowering<AddOp, comb::AddOp>,
406  BinaryOpLowering<SubOp, comb::SubOp>,
407  BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
408  typeConverter, patterns.getContext());
409 }
410 
411 namespace {
412 
413 class HWArithToHWPass : public HWArithToHWBase<HWArithToHWPass> {
414 public:
415  void runOnOperation() override {
416  ModuleOp module = getOperation();
417 
418  ConversionTarget target(getContext());
419  target.markUnknownOpDynamicallyLegal(isLegalOp);
420  RewritePatternSet patterns(&getContext());
421  HWArithToHWTypeConverter typeConverter;
422  target.addIllegalDialect<HWArithDialect>();
423 
424  // Add HWArith-specific conversion patterns.
426 
427  // ALL other operations are converted via the TypeConversionPattern which
428  // will replace an operation to an identical operation with replaced
429  // result types and operands.
430  patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
431 
432  // Apply a full conversion - all operations must either be legal, be caught
433  // by one of the HWArith patterns or be converted by the
434  // TypeConversionPattern.
435  if (failed(applyFullConversion(module, target, std::move(patterns))))
436  return signalPassFailure();
437  }
438 };
439 } // namespace
440 
441 //===----------------------------------------------------------------------===//
442 // Pass initialization
443 //===----------------------------------------------------------------------===//
444 
445 std::unique_ptr<Pass> circt::createHWArithToHWPass() {
446  return std::make_unique<HWArithToHWPass>();
447 }
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
Definition: HWArithToHW.cpp:39
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
Definition: HWArithToHW.cpp:52
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
Definition: HWArithToHW.cpp:70
llvm::SmallVector< StringAttr > inputs
Builder builder
A helper type converter class that automatically populates the relevant materializations and type conversions for lowering HWArith to HW.
Definition: HWArithToHW.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
Definition: HWArithOps.cpp:194
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted attributes, operand types, and result types.