CIRCT  19.0.0git
HWArithToHW.cpp
Go to the documentation of this file.
1 //===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main HWArith to HW Lowering Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
19 #include "circt/Dialect/SV/SVOps.h"
21 #include "mlir/Pass/Pass.h"
22 
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 namespace circt {
27 #define GEN_PASS_DEF_HWARITHTOHW
28 #include "circt/Conversion/Passes.h.inc"
29 } // namespace circt
30 
31 using namespace mlir;
32 using namespace circt;
33 using namespace hwarith;
34 
35 //===----------------------------------------------------------------------===//
36 // Utility functions
37 //===----------------------------------------------------------------------===//
38 
39 // Function for setting the 'sv.namehint' attribute on 'newOp' based on any
40 // currently existing 'sv.namehint' attached to the source operation of 'value'.
41 // The user provides a callback which returns a new namehint based on the old
42 // namehint.
43 static void
44 improveNamehint(Value oldValue, Operation *newOp,
45  llvm::function_ref<std::string(StringRef)> namehintCallback) {
46  if (auto *sourceOp = oldValue.getDefiningOp()) {
47  if (auto namehint =
48  sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
49  auto newNamehint = namehintCallback(namehint.strref());
50  newOp->setAttr("sv.namehint",
51  StringAttr::get(oldValue.getContext(), newNamehint));
52  }
53  }
54 }
55 
56 // Extract a bit range, specified via start bit and width, from a given value.
57 static Value extractBits(OpBuilder &builder, Location loc, Value value,
58  unsigned startBit, unsigned bitWidth) {
59  Value extractedValue =
60  builder.createOrFold<comb::ExtractOp>(loc, value, startBit, bitWidth);
61  Operation *definingOp = extractedValue.getDefiningOp();
62  if (extractedValue != value && definingOp) {
63  // only change namehint if a new operation was created.
64  improveNamehint(value, definingOp, [&](StringRef oldNamehint) {
65  return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
66  std::to_string(startBit + bitWidth))
67  .str();
68  });
69  }
70  return extractedValue;
71 }
72 
73 // Perform the specified bit-extension (either sign- or zero-extension) for a
74 // given value to a desired target width.
75 static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
76  unsigned targetWidth, bool signExtension) {
77  unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78  unsigned extensionLength = targetWidth - sourceWidth;
79 
80  if (extensionLength == 0)
81  return value;
82 
83  Value extensionBits;
84  // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
85  if (signExtension) {
86  // Sign extension
87  Value highBit = extractBits(builder, loc, value,
88  /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
89  extensionBits =
90  builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
91  } else {
92  // Zero extension
93  extensionBits = builder
94  .create<hw::ConstantOp>(
95  loc, builder.getIntegerType(extensionLength), 0)
96  ->getOpResult(0);
97  }
98 
99  auto extOp = builder.create<comb::ConcatOp>(loc, extensionBits, value);
100  improveNamehint(value, extOp, [&](StringRef oldNamehint) {
101  return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
102  std::to_string(targetWidth))
103  .str();
104  });
105 
106  return extOp->getOpResult(0);
107 }
108 
109 static bool isSignednessType(Type type) {
110  auto match =
111  llvm::TypeSwitch<Type, bool>(type)
112  .Case<IntegerType>([](auto type) { return !type.isSignless(); })
113  .Case<hw::ArrayType>(
114  [](auto type) { return isSignednessType(type.getElementType()); })
115  .Case<hw::StructType>([](auto type) {
116  return llvm::any_of(type.getElements(), [](auto element) {
117  return isSignednessType(element.type);
118  });
119  })
120  .Case<hw::InOutType>(
121  [](auto type) { return isSignednessType(type.getElementType()); })
122  .Default([](auto type) { return false; });
123 
124  return match;
125 }
126 
127 static bool isSignednessAttr(Attribute attr) {
128  if (auto typeAttr = dyn_cast<TypeAttr>(attr))
129  return isSignednessType(typeAttr.getValue());
130  return false;
131 }
132 
133 /// Returns true if the given `op` is considered as legal for HWArith
134 /// conversion.
135 static bool isLegalOp(Operation *op) {
136  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
137  return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
138  llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
139  llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
141  }
142 
143  if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
144  return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
145  llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
147  }
148 
149  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
150  return attr.getValue();
151  });
152 
153  bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
154  bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
155  bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
156  return operandsOK && resultsOK && attrsOK;
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Conversion patterns
161 //===----------------------------------------------------------------------===//
162 
163 namespace {
164 struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
166 
167  LogicalResult
168  matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
169  ConversionPatternRewriter &rewriter) const override {
170  rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
171  constOp.getConstantValue());
172  return success();
173  }
174 };
175 struct DivOpLowering : public OpConversionPattern<DivOp> {
177 
178  LogicalResult
179  matchAndRewrite(DivOp op, OpAdaptor adaptor,
180  ConversionPatternRewriter &rewriter) const override {
181  auto loc = op.getLoc();
182  auto isLhsTypeSigned =
183  cast<IntegerType>(op.getOperand(0).getType()).isSigned();
184  auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
185  auto targetType = cast<IntegerType>(op.getResult().getType());
186 
187  // comb.div* needs identical bitwidths for its operands and its result.
188  // Hence, we need to calculate the minimal bitwidth that can be used to
189  // represent the result as well as the operands without precision or sign
190  // loss. The target size only depends on LHS and already handles the edge
191  // cases where the bitwidth needs to be increased by 1. Thus, the targetType
192  // is good enough for both the result as well as LHS.
193  // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
194  // about to perform a signed division, then we need one additional bit to
195  // avoid misinterpretation of RHS as a signed value!
196  bool signedDivision = targetType.isSigned();
197  unsigned extendSize = std::max(
198  targetType.getWidth(),
199  rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
200 
201  // Extend the operands
202  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
203  extendSize, isLhsTypeSigned);
204  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
205  extendSize, rhsType.isSigned());
206 
207  Value divResult;
208  if (signedDivision)
209  divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue, false)
210  ->getOpResult(0);
211  else
212  divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue, false)
213  ->getOpResult(0);
214 
215  // Carry over any attributes from the original div op.
216  auto *divOp = divResult.getDefiningOp();
217  rewriter.modifyOpInPlace(
218  divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
219 
220  // finally truncate back to the expected result size!
221  Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
222  /*bitWidth=*/targetType.getWidth());
223  rewriter.replaceOp(op, truncateResult);
224 
225  return success();
226  }
227 };
228 } // namespace
229 
230 namespace {
231 struct CastOpLowering : public OpConversionPattern<CastOp> {
233 
234  LogicalResult
235  matchAndRewrite(CastOp op, OpAdaptor adaptor,
236  ConversionPatternRewriter &rewriter) const override {
237  auto sourceType = cast<IntegerType>(op.getIn().getType());
238  auto sourceWidth = sourceType.getWidth();
239  bool isSourceTypeSigned = sourceType.isSigned();
240  auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
241 
242  Value replaceValue;
243  if (sourceWidth == targetWidth) {
244  // the width does not change, we are done here and can directly use the
245  // lowering input value
246  replaceValue = adaptor.getIn();
247  } else if (sourceWidth < targetWidth) {
248  // bit extensions needed, the type of extension required is determined by
249  // the source type only!
250  replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
251  targetWidth, isSourceTypeSigned);
252  } else {
253  // bit truncation needed
254  replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
255  /*startBit=*/0, /*bitWidth=*/targetWidth);
256  }
257  rewriter.replaceOp(op, replaceValue);
258 
259  return success();
260  }
261 };
262 } // namespace
263 
264 namespace {
265 
266 // Utility lowering function that maps a hwarith::ICmpPredicate predicate and
267 // the information whether the comparison contains signed values to the
268 // corresponding comb::ICmpPredicate.
269 static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
270  switch (pred) {
271  case ICmpPredicate::eq:
272  return comb::ICmpPredicate::eq;
273  case ICmpPredicate::ne:
274  return comb::ICmpPredicate::ne;
275  case ICmpPredicate::lt:
276  return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
277  case ICmpPredicate::ge:
278  return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
279  case ICmpPredicate::le:
280  return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
281  case ICmpPredicate::gt:
282  return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
283  }
284 
285  llvm_unreachable(
286  "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
287  return comb::ICmpPredicate::eq;
288 }
289 
290 struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
292 
293  LogicalResult
294  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296  auto lhsType = cast<IntegerType>(op.getLhs().getType());
297  auto rhsType = cast<IntegerType>(op.getRhs().getType());
298  IntegerType::SignednessSemantics cmpSignedness;
299  const unsigned cmpWidth =
300  inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
301 
302  ICmpPredicate pred = op.getPredicate();
303  comb::ICmpPredicate combPred = lowerPredicate(
304  pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
305 
306  const auto loc = op.getLoc();
307  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
308  lhsType.isSigned());
309  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
310  rhsType.isSigned());
311 
312  auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
313  rhsValue, false);
314  rewriter.modifyOpInPlace(
315  newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
316  rewriter.replaceOp(op, newOp);
317 
318  return success();
319  }
320 };
321 
322 template <class BinOp, class ReplaceOp>
323 struct BinaryOpLowering : public OpConversionPattern<BinOp> {
325  using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
326 
327  LogicalResult
328  matchAndRewrite(BinOp op, OpAdaptor adaptor,
329  ConversionPatternRewriter &rewriter) const override {
330  auto loc = op.getLoc();
331  auto isLhsTypeSigned =
332  cast<IntegerType>(op.getOperand(0).getType()).isSigned();
333  auto isRhsTypeSigned =
334  cast<IntegerType>(op.getOperand(1).getType()).isSigned();
335  auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
336 
337  Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
338  targetWidth, isLhsTypeSigned);
339  Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
340  targetWidth, isRhsTypeSigned);
341  auto newOp =
342  rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue, false);
343  rewriter.modifyOpInPlace(
344  newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
345  rewriter.replaceOp(op, newOp);
346 
347  return success();
348  }
349 };
350 
351 } // namespace
352 
353 Type HWArithToHWTypeConverter::removeSignedness(Type type) {
354  auto it = conversionCache.find(type);
355  if (it != conversionCache.end())
356  return it->second.type;
357 
358  auto convertedType =
359  llvm::TypeSwitch<Type, Type>(type)
360  .Case<IntegerType>([](auto type) {
361  if (type.isSignless())
362  return type;
363  return IntegerType::get(type.getContext(), type.getWidth());
364  })
365  .Case<hw::ArrayType>([this](auto type) {
366  return hw::ArrayType::get(removeSignedness(type.getElementType()),
367  type.getNumElements());
368  })
369  .Case<hw::StructType>([this](auto type) {
370  // Recursively convert each element.
371  llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
372  for (auto element : type.getElements()) {
373  convertedElements.push_back(
374  {element.name, removeSignedness(element.type)});
375  }
376  return hw::StructType::get(type.getContext(), convertedElements);
377  })
378  .Case<hw::InOutType>([this](auto type) {
379  return hw::InOutType::get(removeSignedness(type.getElementType()));
380  })
381  .Default([](auto type) { return type; });
382 
383  return convertedType;
384 }
385 
386 HWArithToHWTypeConverter::HWArithToHWTypeConverter() {
387  // Pass any type through the signedness remover.
388  addConversion([this](Type type) { return removeSignedness(type); });
389 
390  addTargetMaterialization(
391  [&](mlir::OpBuilder &builder, mlir::Type resultType,
392  mlir::ValueRange inputs,
393  mlir::Location loc) -> std::optional<mlir::Value> {
394  if (inputs.size() != 1)
395  return std::nullopt;
396  return inputs[0];
397  });
398 
399  addSourceMaterialization(
400  [&](mlir::OpBuilder &builder, mlir::Type resultType,
401  mlir::ValueRange inputs,
402  mlir::Location loc) -> std::optional<mlir::Value> {
403  if (inputs.size() != 1)
404  return std::nullopt;
405  return inputs[0];
406  });
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // Pass driver
411 //===----------------------------------------------------------------------===//
412 
414  HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
415  patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
416  BinaryOpLowering<AddOp, comb::AddOp>,
417  BinaryOpLowering<SubOp, comb::SubOp>,
418  BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
419  typeConverter, patterns.getContext());
420 }
421 
422 namespace {
423 
424 class HWArithToHWPass : public circt::impl::HWArithToHWBase<HWArithToHWPass> {
425 public:
426  void runOnOperation() override {
427  ModuleOp module = getOperation();
428 
429  ConversionTarget target(getContext());
430  target.markUnknownOpDynamicallyLegal(isLegalOp);
431  RewritePatternSet patterns(&getContext());
432  HWArithToHWTypeConverter typeConverter;
433  target.addIllegalDialect<HWArithDialect>();
434 
435  // Add HWArith-specific conversion patterns.
437 
438  // ALL other operations are converted via the TypeConversionPattern which
439  // will replace an operation to an identical operation with replaced
440  // result types and operands.
441  patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
442 
443  // Apply a full conversion - all operations must either be legal, be caught
444  // by one of the HWArith patterns or be converted by the
445  // TypeConversionPattern.
446  if (failed(applyFullConversion(module, target, std::move(patterns))))
447  return signalPassFailure();
448  }
449 };
450 } // namespace
451 
452 //===----------------------------------------------------------------------===//
453 // Pass initialization
454 //===----------------------------------------------------------------------===//
455 
456 std::unique_ptr<Pass> circt::createHWArithToHWPass() {
457  return std::make_unique<HWArithToHWPass>();
458 }
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
Definition: HWArithToHW.cpp:44
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
Definition: HWArithToHW.cpp:57
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
Definition: HWArithToHW.cpp:75
llvm::SmallVector< StringAttr > inputs
Builder builder
A helper type converter class that automatically populates the relevant materializations and type conversions.
Definition: HWArithToHW.h:32
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
Definition: HWArithOps.cpp:194
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted attributes, operands, and result types.