CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWArithToHW.cpp
Go to the documentation of this file.
1//===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HWArith to HW Lowering Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/Pass/Pass.h"
22
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace circt {
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33using namespace hwarith;
34
35//===----------------------------------------------------------------------===//
36// Utility functions
37//===----------------------------------------------------------------------===//
38
39// Function for setting the 'sv.namehint' attribute on 'newOp' based on any
40// currently existing 'sv.namehint' attached to the source operation of 'value'.
41// The user provides a callback which returns a new namehint based on the old
42// namehint.
43static void
44improveNamehint(Value oldValue, Operation *newOp,
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (auto *sourceOp = oldValue.getDefiningOp()) {
47 if (auto namehint =
48 sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr("sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
52 }
53 }
54}
55
56// Extract a bit range, specified via start bit and width, from a given value.
57static Value extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit, unsigned bitWidth) {
59 Value extractedValue =
60 builder.createOrFold<comb::ExtractOp>(loc, value, startBit, bitWidth);
61 Operation *definingOp = extractedValue.getDefiningOp();
62 if (extractedValue != value && definingOp) {
63 // only change namehint if a new operation was created.
64 improveNamehint(value, definingOp, [&](StringRef oldNamehint) {
65 return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
66 std::to_string(startBit + bitWidth))
67 .str();
68 });
69 }
70 return extractedValue;
71}
72
73// Perform the specified bit-extension (either sign- or zero-extension) for a
74// given value to a desired target width.
75static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
76 unsigned targetWidth, bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78 unsigned extensionLength = targetWidth - sourceWidth;
79
80 if (extensionLength == 0)
81 return value;
82
83 Value extensionBits;
84 // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
85 if (signExtension) {
86 // Sign extension
87 Value highBit = extractBits(builder, loc, value,
88 /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
89 extensionBits =
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
91 } else {
92 // Zero extension
93 extensionBits =
94 hw::ConstantOp::create(builder, loc,
95 builder.getIntegerType(extensionLength), 0)
96 ->getOpResult(0);
97 }
98
99 auto extOp = comb::ConcatOp::create(builder, loc, extensionBits, value);
100 improveNamehint(value, extOp, [&](StringRef oldNamehint) {
101 return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
102 std::to_string(targetWidth))
103 .str();
104 });
105
106 return extOp->getOpResult(0);
107}
108
109static bool isSignednessType(Type type) {
110 auto match =
111 llvm::TypeSwitch<Type, bool>(type)
112 .Case<IntegerType>([](auto type) { return !type.isSignless(); })
113 .Case<hw::ArrayType>(
114 [](auto type) { return isSignednessType(type.getElementType()); })
115 .Case<hw::UnpackedArrayType>(
116 [](auto type) { return isSignednessType(type.getElementType()); })
117 .Case<hw::StructType>([](auto type) {
118 return llvm::any_of(type.getElements(), [](auto element) {
119 return isSignednessType(element.type);
120 });
121 })
122 .Case<hw::UnionType>([](auto type) {
123 return llvm::any_of(type.getElements(), [](auto element) {
124 return isSignednessType(element.type);
125 });
126 })
127 .Case<hw::InOutType>(
128 [](auto type) { return isSignednessType(type.getElementType()); })
129 .Case<hw::TypeAliasType>(
130 [](auto type) { return isSignednessType(type.getInnerType()); })
131 .Default([](auto type) { return false; });
132
133 return match;
134}
135
136static bool isSignednessAttr(Attribute attr) {
137 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
138 return isSignednessType(typeAttr.getValue());
139 return false;
140}
141
142/// Returns true if the given `op` is considered as legal for HWArith
143/// conversion.
144static bool isLegalOp(Operation *op) {
145 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
146 return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
147 llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
148 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
150 }
151
152 if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
153 return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
154 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
156 }
157
158 auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
159 return attr.getValue();
160 });
161
162 bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
163 bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
164 bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
165 return operandsOK && resultsOK && attrsOK;
166}
167
168//===----------------------------------------------------------------------===//
169// Conversion patterns
170//===----------------------------------------------------------------------===//
171
172namespace {
173struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
175
176 LogicalResult
177 matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
178 ConversionPatternRewriter &rewriter) const override {
179 rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
180 constOp.getConstantValue());
181 return success();
182 }
183};
184struct DivOpLowering : public OpConversionPattern<DivOp> {
186
187 LogicalResult
188 matchAndRewrite(DivOp op, OpAdaptor adaptor,
189 ConversionPatternRewriter &rewriter) const override {
190 auto loc = op.getLoc();
191 auto isLhsTypeSigned =
192 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
193 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
194 auto targetType = cast<IntegerType>(op.getResult().getType());
195
196 // comb.div* needs identical bitwidths for its operands and its result.
197 // Hence, we need to calculate the minimal bitwidth that can be used to
198 // represent the result as well as the operands without precision or sign
199 // loss. The target size only depends on LHS and already handles the edge
200 // cases where the bitwidth needs to be increased by 1. Thus, the targetType
201 // is good enough for both the result as well as LHS.
202 // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
203 // about to perform a signed division, then we need one additional bit to
204 // avoid misinterpretation of RHS as a signed value!
205 bool signedDivision = targetType.isSigned();
206 unsigned extendSize = std::max(
207 targetType.getWidth(),
208 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
209
210 // Extend the operands
211 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
212 extendSize, isLhsTypeSigned);
213 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
214 extendSize, rhsType.isSigned());
215
216 Value divResult;
217 if (signedDivision)
218 divResult = comb::DivSOp::create(rewriter, loc, lhsValue, rhsValue, false)
219 ->getOpResult(0);
220 else
221 divResult = comb::DivUOp::create(rewriter, loc, lhsValue, rhsValue, false)
222 ->getOpResult(0);
223
224 // Carry over any attributes from the original div op.
225 auto *divOp = divResult.getDefiningOp();
226 rewriter.modifyOpInPlace(
227 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
228
229 // finally truncate back to the expected result size!
230 Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
231 /*bitWidth=*/targetType.getWidth());
232 rewriter.replaceOp(op, truncateResult);
233
234 return success();
235 }
236};
237} // namespace
238
239namespace {
240struct CastOpLowering : public OpConversionPattern<CastOp> {
242
243 LogicalResult
244 matchAndRewrite(CastOp op, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 auto sourceType = cast<IntegerType>(op.getIn().getType());
247 auto sourceWidth = sourceType.getWidth();
248 bool isSourceTypeSigned = sourceType.isSigned();
249 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
250
251 Value replaceValue;
252 if (sourceWidth == targetWidth) {
253 // the width does not change, we are done here and can directly use the
254 // lowering input value
255 replaceValue = adaptor.getIn();
256 } else if (sourceWidth < targetWidth) {
257 // bit extensions needed, the type of extension required is determined by
258 // the source type only!
259 replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
260 targetWidth, isSourceTypeSigned);
261 } else {
262 // bit truncation needed
263 replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
264 /*startBit=*/0, /*bitWidth=*/targetWidth);
265 }
266 rewriter.replaceOp(op, replaceValue);
267
268 return success();
269 }
270};
271} // namespace
272
273namespace {
274
275// Utility lowering function that maps a hwarith::ICmpPredicate predicate and
276// the information whether the comparison contains signed values to the
277// corresponding comb::ICmpPredicate.
278static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
279 switch (pred) {
280 case ICmpPredicate::eq:
281 return comb::ICmpPredicate::eq;
282 case ICmpPredicate::ne:
283 return comb::ICmpPredicate::ne;
284 case ICmpPredicate::lt:
285 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
286 case ICmpPredicate::ge:
287 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
288 case ICmpPredicate::le:
289 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
290 case ICmpPredicate::gt:
291 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
292 }
293
294 llvm_unreachable(
295 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
296 return comb::ICmpPredicate::eq;
297}
298
299struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
301
302 LogicalResult
303 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
304 ConversionPatternRewriter &rewriter) const override {
305 auto lhsType = cast<IntegerType>(op.getLhs().getType());
306 auto rhsType = cast<IntegerType>(op.getRhs().getType());
307 IntegerType::SignednessSemantics cmpSignedness;
308 const unsigned cmpWidth =
309 inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
310
311 ICmpPredicate pred = op.getPredicate();
312 comb::ICmpPredicate combPred = lowerPredicate(
313 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
314
315 const auto loc = op.getLoc();
316 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
317 lhsType.isSigned());
318 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
319 rhsType.isSigned());
320
321 auto newOp = comb::ICmpOp::create(rewriter, op->getLoc(), combPred,
322 lhsValue, rhsValue, false);
323 rewriter.modifyOpInPlace(
324 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
325 rewriter.replaceOp(op, newOp);
326
327 return success();
328 }
329};
330
331template <class BinOp, class ReplaceOp>
332struct BinaryOpLowering : public OpConversionPattern<BinOp> {
334 using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
335
336 LogicalResult
337 matchAndRewrite(BinOp op, OpAdaptor adaptor,
338 ConversionPatternRewriter &rewriter) const override {
339 auto loc = op.getLoc();
340 auto isLhsTypeSigned =
341 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
342 auto isRhsTypeSigned =
343 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
344 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
345
346 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
347 targetWidth, isLhsTypeSigned);
348 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
349 targetWidth, isRhsTypeSigned);
350 auto newOp =
351 ReplaceOp::create(rewriter, op.getLoc(), lhsValue, rhsValue, false);
352 rewriter.modifyOpInPlace(
353 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
354 rewriter.replaceOp(op, newOp);
355
356 return success();
357 }
358};
359
360} // namespace
361
363 auto it = conversionCache.find(type);
364 if (it != conversionCache.end())
365 return it->second.type;
366
367 auto convertedType =
368 llvm::TypeSwitch<Type, Type>(type)
369 .Case<IntegerType>([](auto type) {
370 if (type.isSignless())
371 return type;
372 return IntegerType::get(type.getContext(), type.getWidth());
373 })
374 .Case<hw::ArrayType>([this](auto type) {
375 return hw::ArrayType::get(removeSignedness(type.getElementType()),
376 type.getNumElements());
377 })
378 .Case<hw::UnpackedArrayType>([this](auto type) {
379 return hw::UnpackedArrayType::get(
380 removeSignedness(type.getElementType()), type.getNumElements());
381 })
382 .Case<hw::StructType>([this](auto type) {
383 // Recursively convert each element.
384 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
385 for (auto element : type.getElements()) {
386 convertedElements.push_back(
387 {element.name, removeSignedness(element.type)});
388 }
389 return hw::StructType::get(type.getContext(), convertedElements);
390 })
391 .Case<hw::UnionType>([this](auto type) {
392 llvm::SmallVector<hw::UnionType::FieldInfo> convertedElements;
393 for (auto element : type.getElements()) {
394 convertedElements.push_back({element.name,
395 removeSignedness(element.type),
396 element.offset});
397 }
398 return hw::UnionType::get(type.getContext(), convertedElements);
399 })
400 .Case<hw::InOutType>([this](auto type) {
401 return hw::InOutType::get(removeSignedness(type.getElementType()));
402 })
403 .Case<hw::TypeAliasType>([this](auto type) {
404 return hw::TypeAliasType::get(
405 type.getRef(), removeSignedness(type.getInnerType()));
406 })
407 .Default([](auto type) { return type; });
408
409 return convertedType;
410}
411
413 // Pass any type through the signedness remover.
414 addConversion([this](Type type) { return removeSignedness(type); });
415
416 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
417 mlir::ValueRange inputs,
418 mlir::Location loc) -> mlir::Value {
419 if (inputs.size() != 1)
420 return Value();
421 return UnrealizedConversionCastOp::create(builder, loc, resultType,
422 inputs[0])
423 ->getResult(0);
424 });
425
426 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
427 mlir::ValueRange inputs,
428 mlir::Location loc) -> mlir::Value {
429 if (inputs.size() != 1)
430 return Value();
431 return UnrealizedConversionCastOp::create(builder, loc, resultType,
432 inputs[0])
433 ->getResult(0);
434 });
435}
436
437//===----------------------------------------------------------------------===//
438// Pass driver
439//===----------------------------------------------------------------------===//
440
442 HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
443 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
444 BinaryOpLowering<AddOp, comb::AddOp>,
445 BinaryOpLowering<SubOp, comb::SubOp>,
446 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
447 typeConverter, patterns.getContext());
448}
449
450namespace {
451
452class HWArithToHWPass : public circt::impl::HWArithToHWBase<HWArithToHWPass> {
453public:
454 void runOnOperation() override {
455 ModuleOp module = getOperation();
456
457 ConversionTarget target(getContext());
458 target.markUnknownOpDynamicallyLegal(isLegalOp);
459 RewritePatternSet patterns(&getContext());
460 HWArithToHWTypeConverter typeConverter;
461 target.addIllegalDialect<HWArithDialect>();
462
463 // Add HWArith-specific conversion patterns.
465
466 // ALL other operations are converted via the TypeConversionPattern which
467 // will replace an operation to an identical operation with replaced
468 // result types and operands.
469 patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
470
471 // Apply a full conversion - all operations must either be legal, be caught
472 // by one of the HWArith patterns or be converted by the
473 // TypeConversionPattern.
474 if (failed(applyFullConversion(module, target, std::move(patterns))))
475 return signalPassFailure();
476 }
477};
478} // namespace
479
480//===----------------------------------------------------------------------===//
481// Pass initialization
482//===----------------------------------------------------------------------===//
483
484std::unique_ptr<Pass> circt::createHWArithToHWPass() {
485 return std::make_unique<HWArithToHWPass>();
486}
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
Definition DCToHW.cpp:844
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
Definition HWArithToHW.h:32
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
Definition HWArithToHW.h:50
create(data_type, value)
Definition hw.py:433
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...