Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWArithToHW.cpp
Go to the documentation of this file.
1//===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HWArith to HW Lowering Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/Pass/Pass.h"
22
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace circt {
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33using namespace hwarith;
34
35//===----------------------------------------------------------------------===//
36// Utility functions
37//===----------------------------------------------------------------------===//
38
39// Function for setting the 'sv.namehint' attribute on 'newOp' based on any
40// currently existing 'sv.namehint' attached to the source operation of 'value'.
41// The user provides a callback which returns a new namehint based on the old
42// namehint.
43static void
44improveNamehint(Value oldValue, Operation *newOp,
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (auto *sourceOp = oldValue.getDefiningOp()) {
47 if (auto namehint =
48 sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr("sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
52 }
53 }
54}
55
56// Extract a bit range, specified via start bit and width, from a given value.
57static Value extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit, unsigned bitWidth) {
59 Value extractedValue =
60 builder.createOrFold<comb::ExtractOp>(loc, value, startBit, bitWidth);
61 Operation *definingOp = extractedValue.getDefiningOp();
62 if (extractedValue != value && definingOp) {
63 // only change namehint if a new operation was created.
64 improveNamehint(value, definingOp, [&](StringRef oldNamehint) {
65 return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
66 std::to_string(startBit + bitWidth))
67 .str();
68 });
69 }
70 return extractedValue;
71}
72
73// Perform the specified bit-extension (either sign- or zero-extension) for a
74// given value to a desired target width.
75static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
76 unsigned targetWidth, bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78 unsigned extensionLength = targetWidth - sourceWidth;
79
80 if (extensionLength == 0)
81 return value;
82
83 Value extensionBits;
84 // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
85 if (signExtension) {
86 // Sign extension
87 Value highBit = extractBits(builder, loc, value,
88 /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
89 extensionBits =
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
91 } else {
92 // Zero extension
93 extensionBits = builder
94 .create<hw::ConstantOp>(
95 loc, builder.getIntegerType(extensionLength), 0)
96 ->getOpResult(0);
97 }
98
99 auto extOp = builder.create<comb::ConcatOp>(loc, extensionBits, value);
100 improveNamehint(value, extOp, [&](StringRef oldNamehint) {
101 return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
102 std::to_string(targetWidth))
103 .str();
104 });
105
106 return extOp->getOpResult(0);
107}
108
109static bool isSignednessType(Type type) {
110 auto match =
111 llvm::TypeSwitch<Type, bool>(type)
112 .Case<IntegerType>([](auto type) { return !type.isSignless(); })
113 .Case<hw::ArrayType>(
114 [](auto type) { return isSignednessType(type.getElementType()); })
115 .Case<hw::UnpackedArrayType>(
116 [](auto type) { return isSignednessType(type.getElementType()); })
117 .Case<hw::StructType>([](auto type) {
118 return llvm::any_of(type.getElements(), [](auto element) {
119 return isSignednessType(element.type);
120 });
121 })
122 .Case<hw::InOutType>(
123 [](auto type) { return isSignednessType(type.getElementType()); })
124 .Case<hw::TypeAliasType>(
125 [](auto type) { return isSignednessType(type.getInnerType()); })
126 .Default([](auto type) { return false; });
127
128 return match;
129}
130
131static bool isSignednessAttr(Attribute attr) {
132 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
133 return isSignednessType(typeAttr.getValue());
134 return false;
135}
136
137/// Returns true if the given `op` is considered as legal for HWArith
138/// conversion.
139static bool isLegalOp(Operation *op) {
140 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
141 return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
142 llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
143 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
145 }
146
147 if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
148 return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
149 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
151 }
152
153 auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
154 return attr.getValue();
155 });
156
157 bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
158 bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
159 bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
160 return operandsOK && resultsOK && attrsOK;
161}
162
163//===----------------------------------------------------------------------===//
164// Conversion patterns
165//===----------------------------------------------------------------------===//
166
167namespace {
168struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
170
171 LogicalResult
172 matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
173 ConversionPatternRewriter &rewriter) const override {
174 rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
175 constOp.getConstantValue());
176 return success();
177 }
178};
179struct DivOpLowering : public OpConversionPattern<DivOp> {
181
182 LogicalResult
183 matchAndRewrite(DivOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185 auto loc = op.getLoc();
186 auto isLhsTypeSigned =
187 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
188 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
189 auto targetType = cast<IntegerType>(op.getResult().getType());
190
191 // comb.div* needs identical bitwidths for its operands and its result.
192 // Hence, we need to calculate the minimal bitwidth that can be used to
193 // represent the result as well as the operands without precision or sign
194 // loss. The target size only depends on LHS and already handles the edge
195 // cases where the bitwidth needs to be increased by 1. Thus, the targetType
196 // is good enough for both the result as well as LHS.
197 // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
198 // about to perform a signed division, then we need one additional bit to
199 // avoid misinterpretation of RHS as a signed value!
200 bool signedDivision = targetType.isSigned();
201 unsigned extendSize = std::max(
202 targetType.getWidth(),
203 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
204
205 // Extend the operands
206 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
207 extendSize, isLhsTypeSigned);
208 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
209 extendSize, rhsType.isSigned());
210
211 Value divResult;
212 if (signedDivision)
213 divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue, false)
214 ->getOpResult(0);
215 else
216 divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue, false)
217 ->getOpResult(0);
218
219 // Carry over any attributes from the original div op.
220 auto *divOp = divResult.getDefiningOp();
221 rewriter.modifyOpInPlace(
222 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
223
224 // finally truncate back to the expected result size!
225 Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
226 /*bitWidth=*/targetType.getWidth());
227 rewriter.replaceOp(op, truncateResult);
228
229 return success();
230 }
231};
232} // namespace
233
234namespace {
235struct CastOpLowering : public OpConversionPattern<CastOp> {
237
238 LogicalResult
239 matchAndRewrite(CastOp op, OpAdaptor adaptor,
240 ConversionPatternRewriter &rewriter) const override {
241 auto sourceType = cast<IntegerType>(op.getIn().getType());
242 auto sourceWidth = sourceType.getWidth();
243 bool isSourceTypeSigned = sourceType.isSigned();
244 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
245
246 Value replaceValue;
247 if (sourceWidth == targetWidth) {
248 // the width does not change, we are done here and can directly use the
249 // lowering input value
250 replaceValue = adaptor.getIn();
251 } else if (sourceWidth < targetWidth) {
252 // bit extensions needed, the type of extension required is determined by
253 // the source type only!
254 replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
255 targetWidth, isSourceTypeSigned);
256 } else {
257 // bit truncation needed
258 replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
259 /*startBit=*/0, /*bitWidth=*/targetWidth);
260 }
261 rewriter.replaceOp(op, replaceValue);
262
263 return success();
264 }
265};
266} // namespace
267
268namespace {
269
270// Utility lowering function that maps a hwarith::ICmpPredicate predicate and
271// the information whether the comparison contains signed values to the
272// corresponding comb::ICmpPredicate.
273static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
274 switch (pred) {
275 case ICmpPredicate::eq:
276 return comb::ICmpPredicate::eq;
277 case ICmpPredicate::ne:
278 return comb::ICmpPredicate::ne;
279 case ICmpPredicate::lt:
280 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
281 case ICmpPredicate::ge:
282 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
283 case ICmpPredicate::le:
284 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
285 case ICmpPredicate::gt:
286 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
287 }
288
289 llvm_unreachable(
290 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
291 return comb::ICmpPredicate::eq;
292}
293
294struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
296
297 LogicalResult
298 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter) const override {
300 auto lhsType = cast<IntegerType>(op.getLhs().getType());
301 auto rhsType = cast<IntegerType>(op.getRhs().getType());
302 IntegerType::SignednessSemantics cmpSignedness;
303 const unsigned cmpWidth =
304 inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
305
306 ICmpPredicate pred = op.getPredicate();
307 comb::ICmpPredicate combPred = lowerPredicate(
308 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
309
310 const auto loc = op.getLoc();
311 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
312 lhsType.isSigned());
313 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
314 rhsType.isSigned());
315
316 auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
317 rhsValue, false);
318 rewriter.modifyOpInPlace(
319 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
320 rewriter.replaceOp(op, newOp);
321
322 return success();
323 }
324};
325
326template <class BinOp, class ReplaceOp>
327struct BinaryOpLowering : public OpConversionPattern<BinOp> {
329 using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
330
331 LogicalResult
332 matchAndRewrite(BinOp op, OpAdaptor adaptor,
333 ConversionPatternRewriter &rewriter) const override {
334 auto loc = op.getLoc();
335 auto isLhsTypeSigned =
336 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
337 auto isRhsTypeSigned =
338 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
339 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
340
341 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
342 targetWidth, isLhsTypeSigned);
343 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
344 targetWidth, isRhsTypeSigned);
345 auto newOp =
346 rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue, false);
347 rewriter.modifyOpInPlace(
348 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
349 rewriter.replaceOp(op, newOp);
350
351 return success();
352 }
353};
354
355} // namespace
356
358 auto it = conversionCache.find(type);
359 if (it != conversionCache.end())
360 return it->second.type;
361
362 auto convertedType =
363 llvm::TypeSwitch<Type, Type>(type)
364 .Case<IntegerType>([](auto type) {
365 if (type.isSignless())
366 return type;
367 return IntegerType::get(type.getContext(), type.getWidth());
368 })
369 .Case<hw::ArrayType>([this](auto type) {
370 return hw::ArrayType::get(removeSignedness(type.getElementType()),
371 type.getNumElements());
372 })
373 .Case<hw::UnpackedArrayType>([this](auto type) {
374 return hw::UnpackedArrayType::get(
375 removeSignedness(type.getElementType()), type.getNumElements());
376 })
377 .Case<hw::StructType>([this](auto type) {
378 // Recursively convert each element.
379 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
380 for (auto element : type.getElements()) {
381 convertedElements.push_back(
382 {element.name, removeSignedness(element.type)});
383 }
384 return hw::StructType::get(type.getContext(), convertedElements);
385 })
386 .Case<hw::InOutType>([this](auto type) {
387 return hw::InOutType::get(removeSignedness(type.getElementType()));
388 })
389 .Case<hw::TypeAliasType>([this](auto type) {
390 return hw::TypeAliasType::get(
391 type.getRef(), removeSignedness(type.getInnerType()));
392 })
393 .Default([](auto type) { return type; });
394
395 return convertedType;
396}
397
399 // Pass any type through the signedness remover.
400 addConversion([this](Type type) { return removeSignedness(type); });
401
402 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
403 mlir::ValueRange inputs,
404 mlir::Location loc) -> mlir::Value {
405 if (inputs.size() != 1)
406 return Value();
407 return builder
408 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
409 ->getResult(0);
410 });
411
412 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
413 mlir::ValueRange inputs,
414 mlir::Location loc) -> mlir::Value {
415 if (inputs.size() != 1)
416 return Value();
417 return builder
418 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
419 ->getResult(0);
420 });
421}
422
423//===----------------------------------------------------------------------===//
424// Pass driver
425//===----------------------------------------------------------------------===//
426
428 HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
429 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
430 BinaryOpLowering<AddOp, comb::AddOp>,
431 BinaryOpLowering<SubOp, comb::SubOp>,
432 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
433 typeConverter, patterns.getContext());
434}
435
436namespace {
437
438class HWArithToHWPass : public circt::impl::HWArithToHWBase<HWArithToHWPass> {
439public:
440 void runOnOperation() override {
441 ModuleOp module = getOperation();
442
443 ConversionTarget target(getContext());
444 target.markUnknownOpDynamicallyLegal(isLegalOp);
445 RewritePatternSet patterns(&getContext());
446 HWArithToHWTypeConverter typeConverter;
447 target.addIllegalDialect<HWArithDialect>();
448
449 // Add HWArith-specific conversion patterns.
451
452 // ALL other operations are converted via the TypeConversionPattern which
453 // will replace an operation to an identical operation with replaced
454 // result types and operands.
455 patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
456
457 // Apply a full conversion - all operations must either be legal, be caught
458 // by one of the HWArith patterns or be converted by the
459 // TypeConversionPattern.
460 if (failed(applyFullConversion(module, target, std::move(patterns))))
461 return signalPassFailure();
462 }
463};
464} // namespace
465
466//===----------------------------------------------------------------------===//
467// Pass initialization
468//===----------------------------------------------------------------------===//
469
470std::unique_ptr<Pass> circt::createHWArithToHWPass() {
471 return std::make_unique<HWArithToHWPass>();
472}
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
Definition CombToAIG.cpp:38
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
Definition DCToHW.cpp:844
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
Definition HWArithToHW.h:32
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
Definition HWArithToHW.h:50
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...