CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWArithToHW.cpp
Go to the documentation of this file.
1//===- HWArithToHW.cpp - HWArith to HW Lowering pass ------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main HWArith to HW Lowering Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/Pass/Pass.h"
22
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace circt {
27#define GEN_PASS_DEF_HWARITHTOHW
28#include "circt/Conversion/Passes.h.inc"
29} // namespace circt
30
31using namespace mlir;
32using namespace circt;
33using namespace hwarith;
34
35//===----------------------------------------------------------------------===//
36// Utility functions
37//===----------------------------------------------------------------------===//
38
39// Function for setting the 'sv.namehint' attribute on 'newOp' based on any
40// currently existing 'sv.namehint' attached to the source operation of 'value'.
41// The user provides a callback which returns a new namehint based on the old
42// namehint.
43static void
44improveNamehint(Value oldValue, Operation *newOp,
45 llvm::function_ref<std::string(StringRef)> namehintCallback) {
46 if (auto *sourceOp = oldValue.getDefiningOp()) {
47 if (auto namehint =
48 sourceOp->getAttrOfType<mlir::StringAttr>("sv.namehint")) {
49 auto newNamehint = namehintCallback(namehint.strref());
50 newOp->setAttr("sv.namehint",
51 StringAttr::get(oldValue.getContext(), newNamehint));
52 }
53 }
54}
55
56// Extract a bit range, specified via start bit and width, from a given value.
57static Value extractBits(OpBuilder &builder, Location loc, Value value,
58 unsigned startBit, unsigned bitWidth) {
59 Value extractedValue =
60 builder.createOrFold<comb::ExtractOp>(loc, value, startBit, bitWidth);
61 Operation *definingOp = extractedValue.getDefiningOp();
62 if (extractedValue != value && definingOp) {
63 // only change namehint if a new operation was created.
64 improveNamehint(value, definingOp, [&](StringRef oldNamehint) {
65 return (oldNamehint + "_" + std::to_string(startBit) + "_to_" +
66 std::to_string(startBit + bitWidth))
67 .str();
68 });
69 }
70 return extractedValue;
71}
72
73// Perform the specified bit-extension (either sign- or zero-extension) for a
74// given value to a desired target width.
75static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value,
76 unsigned targetWidth, bool signExtension) {
77 unsigned sourceWidth = value.getType().getIntOrFloatBitWidth();
78 unsigned extensionLength = targetWidth - sourceWidth;
79
80 if (extensionLength == 0)
81 return value;
82
83 Value extensionBits;
84 // https://circt.llvm.org/docs/Dialects/Comb/RationaleComb/#no-complement-negate-zext-sext-operators
85 if (signExtension) {
86 // Sign extension
87 Value highBit = extractBits(builder, loc, value,
88 /*startBit=*/sourceWidth - 1, /*bitWidth=*/1);
89 extensionBits =
90 builder.createOrFold<comb::ReplicateOp>(loc, highBit, extensionLength);
91 } else {
92 // Zero extension
93 extensionBits = builder
94 .create<hw::ConstantOp>(
95 loc, builder.getIntegerType(extensionLength), 0)
96 ->getOpResult(0);
97 }
98
99 auto extOp = builder.create<comb::ConcatOp>(loc, extensionBits, value);
100 improveNamehint(value, extOp, [&](StringRef oldNamehint) {
101 return (oldNamehint + "_" + (signExtension ? "sext_" : "zext_") +
102 std::to_string(targetWidth))
103 .str();
104 });
105
106 return extOp->getOpResult(0);
107}
108
109static bool isSignednessType(Type type) {
110 auto match =
111 llvm::TypeSwitch<Type, bool>(type)
112 .Case<IntegerType>([](auto type) { return !type.isSignless(); })
113 .Case<hw::ArrayType>(
114 [](auto type) { return isSignednessType(type.getElementType()); })
115 .Case<hw::StructType>([](auto type) {
116 return llvm::any_of(type.getElements(), [](auto element) {
117 return isSignednessType(element.type);
118 });
119 })
120 .Case<hw::InOutType>(
121 [](auto type) { return isSignednessType(type.getElementType()); })
122 .Case<hw::TypeAliasType>(
123 [](auto type) { return isSignednessType(type.getInnerType()); })
124 .Default([](auto type) { return false; });
125
126 return match;
127}
128
129static bool isSignednessAttr(Attribute attr) {
130 if (auto typeAttr = dyn_cast<TypeAttr>(attr))
131 return isSignednessType(typeAttr.getValue());
132 return false;
133}
134
135/// Returns true if the given `op` is considered as legal for HWArith
136/// conversion.
137static bool isLegalOp(Operation *op) {
138 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
139 return llvm::none_of(funcOp.getArgumentTypes(), isSignednessType) &&
140 llvm::none_of(funcOp.getResultTypes(), isSignednessType) &&
141 llvm::none_of(funcOp.getFunctionBody().getArgumentTypes(),
143 }
144
145 if (auto modOp = dyn_cast<hw::HWModuleLike>(op)) {
146 return llvm::none_of(modOp.getPortTypes(), isSignednessType) &&
147 llvm::none_of(modOp.getModuleBody().getArgumentTypes(),
149 }
150
151 auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
152 return attr.getValue();
153 });
154
155 bool operandsOK = llvm::none_of(op->getOperandTypes(), isSignednessType);
156 bool resultsOK = llvm::none_of(op->getResultTypes(), isSignednessType);
157 bool attrsOK = llvm::none_of(attrs, isSignednessAttr);
158 return operandsOK && resultsOK && attrsOK;
159}
160
161//===----------------------------------------------------------------------===//
162// Conversion patterns
163//===----------------------------------------------------------------------===//
164
165namespace {
166struct ConstantOpLowering : public OpConversionPattern<ConstantOp> {
168
169 LogicalResult
170 matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor,
171 ConversionPatternRewriter &rewriter) const override {
172 rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp,
173 constOp.getConstantValue());
174 return success();
175 }
176};
177struct DivOpLowering : public OpConversionPattern<DivOp> {
179
180 LogicalResult
181 matchAndRewrite(DivOp op, OpAdaptor adaptor,
182 ConversionPatternRewriter &rewriter) const override {
183 auto loc = op.getLoc();
184 auto isLhsTypeSigned =
185 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
186 auto rhsType = cast<IntegerType>(op.getOperand(1).getType());
187 auto targetType = cast<IntegerType>(op.getResult().getType());
188
189 // comb.div* needs identical bitwidths for its operands and its result.
190 // Hence, we need to calculate the minimal bitwidth that can be used to
191 // represent the result as well as the operands without precision or sign
192 // loss. The target size only depends on LHS and already handles the edge
193 // cases where the bitwidth needs to be increased by 1. Thus, the targetType
194 // is good enough for both the result as well as LHS.
195 // The bitwidth for RHS is bit tricky. If the RHS is unsigned and we are
196 // about to perform a signed division, then we need one additional bit to
197 // avoid misinterpretation of RHS as a signed value!
198 bool signedDivision = targetType.isSigned();
199 unsigned extendSize = std::max(
200 targetType.getWidth(),
201 rhsType.getWidth() + (signedDivision && !rhsType.isSigned() ? 1 : 0));
202
203 // Extend the operands
204 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
205 extendSize, isLhsTypeSigned);
206 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
207 extendSize, rhsType.isSigned());
208
209 Value divResult;
210 if (signedDivision)
211 divResult = rewriter.create<comb::DivSOp>(loc, lhsValue, rhsValue, false)
212 ->getOpResult(0);
213 else
214 divResult = rewriter.create<comb::DivUOp>(loc, lhsValue, rhsValue, false)
215 ->getOpResult(0);
216
217 // Carry over any attributes from the original div op.
218 auto *divOp = divResult.getDefiningOp();
219 rewriter.modifyOpInPlace(
220 divOp, [&]() { divOp->setDialectAttrs(op->getDialectAttrs()); });
221
222 // finally truncate back to the expected result size!
223 Value truncateResult = extractBits(rewriter, loc, divResult, /*startBit=*/0,
224 /*bitWidth=*/targetType.getWidth());
225 rewriter.replaceOp(op, truncateResult);
226
227 return success();
228 }
229};
230} // namespace
231
232namespace {
233struct CastOpLowering : public OpConversionPattern<CastOp> {
235
236 LogicalResult
237 matchAndRewrite(CastOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const override {
239 auto sourceType = cast<IntegerType>(op.getIn().getType());
240 auto sourceWidth = sourceType.getWidth();
241 bool isSourceTypeSigned = sourceType.isSigned();
242 auto targetWidth = cast<IntegerType>(op.getOut().getType()).getWidth();
243
244 Value replaceValue;
245 if (sourceWidth == targetWidth) {
246 // the width does not change, we are done here and can directly use the
247 // lowering input value
248 replaceValue = adaptor.getIn();
249 } else if (sourceWidth < targetWidth) {
250 // bit extensions needed, the type of extension required is determined by
251 // the source type only!
252 replaceValue = extendTypeWidth(rewriter, op.getLoc(), adaptor.getIn(),
253 targetWidth, isSourceTypeSigned);
254 } else {
255 // bit truncation needed
256 replaceValue = extractBits(rewriter, op.getLoc(), adaptor.getIn(),
257 /*startBit=*/0, /*bitWidth=*/targetWidth);
258 }
259 rewriter.replaceOp(op, replaceValue);
260
261 return success();
262 }
263};
264} // namespace
265
266namespace {
267
268// Utility lowering function that maps a hwarith::ICmpPredicate predicate and
269// the information whether the comparison contains signed values to the
270// corresponding comb::ICmpPredicate.
271static comb::ICmpPredicate lowerPredicate(ICmpPredicate pred, bool isSigned) {
272 switch (pred) {
273 case ICmpPredicate::eq:
274 return comb::ICmpPredicate::eq;
275 case ICmpPredicate::ne:
276 return comb::ICmpPredicate::ne;
277 case ICmpPredicate::lt:
278 return isSigned ? comb::ICmpPredicate::slt : comb::ICmpPredicate::ult;
279 case ICmpPredicate::ge:
280 return isSigned ? comb::ICmpPredicate::sge : comb::ICmpPredicate::uge;
281 case ICmpPredicate::le:
282 return isSigned ? comb::ICmpPredicate::sle : comb::ICmpPredicate::ule;
283 case ICmpPredicate::gt:
284 return isSigned ? comb::ICmpPredicate::sgt : comb::ICmpPredicate::ugt;
285 }
286
287 llvm_unreachable(
288 "Missing hwarith::ICmpPredicate to comb::ICmpPredicate lowering");
289 return comb::ICmpPredicate::eq;
290}
291
292struct ICmpOpLowering : public OpConversionPattern<ICmpOp> {
294
295 LogicalResult
296 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
297 ConversionPatternRewriter &rewriter) const override {
298 auto lhsType = cast<IntegerType>(op.getLhs().getType());
299 auto rhsType = cast<IntegerType>(op.getRhs().getType());
300 IntegerType::SignednessSemantics cmpSignedness;
301 const unsigned cmpWidth =
302 inferAddResultType(cmpSignedness, lhsType, rhsType) - 1;
303
304 ICmpPredicate pred = op.getPredicate();
305 comb::ICmpPredicate combPred = lowerPredicate(
306 pred, cmpSignedness == IntegerType::SignednessSemantics::Signed);
307
308 const auto loc = op.getLoc();
309 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getLhs(), cmpWidth,
310 lhsType.isSigned());
311 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getRhs(), cmpWidth,
312 rhsType.isSigned());
313
314 auto newOp = rewriter.create<comb::ICmpOp>(op->getLoc(), combPred, lhsValue,
315 rhsValue, false);
316 rewriter.modifyOpInPlace(
317 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
318 rewriter.replaceOp(op, newOp);
319
320 return success();
321 }
322};
323
324template <class BinOp, class ReplaceOp>
325struct BinaryOpLowering : public OpConversionPattern<BinOp> {
327 using OpAdaptor = typename OpConversionPattern<BinOp>::OpAdaptor;
328
329 LogicalResult
330 matchAndRewrite(BinOp op, OpAdaptor adaptor,
331 ConversionPatternRewriter &rewriter) const override {
332 auto loc = op.getLoc();
333 auto isLhsTypeSigned =
334 cast<IntegerType>(op.getOperand(0).getType()).isSigned();
335 auto isRhsTypeSigned =
336 cast<IntegerType>(op.getOperand(1).getType()).isSigned();
337 auto targetWidth = cast<IntegerType>(op.getResult().getType()).getWidth();
338
339 Value lhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[0],
340 targetWidth, isLhsTypeSigned);
341 Value rhsValue = extendTypeWidth(rewriter, loc, adaptor.getInputs()[1],
342 targetWidth, isRhsTypeSigned);
343 auto newOp =
344 rewriter.create<ReplaceOp>(op.getLoc(), lhsValue, rhsValue, false);
345 rewriter.modifyOpInPlace(
346 newOp, [&]() { newOp->setDialectAttrs(op->getDialectAttrs()); });
347 rewriter.replaceOp(op, newOp);
348
349 return success();
350 }
351};
352
353} // namespace
354
356 auto it = conversionCache.find(type);
357 if (it != conversionCache.end())
358 return it->second.type;
359
360 auto convertedType =
361 llvm::TypeSwitch<Type, Type>(type)
362 .Case<IntegerType>([](auto type) {
363 if (type.isSignless())
364 return type;
365 return IntegerType::get(type.getContext(), type.getWidth());
366 })
367 .Case<hw::ArrayType>([this](auto type) {
368 return hw::ArrayType::get(removeSignedness(type.getElementType()),
369 type.getNumElements());
370 })
371 .Case<hw::StructType>([this](auto type) {
372 // Recursively convert each element.
373 llvm::SmallVector<hw::StructType::FieldInfo> convertedElements;
374 for (auto element : type.getElements()) {
375 convertedElements.push_back(
376 {element.name, removeSignedness(element.type)});
377 }
378 return hw::StructType::get(type.getContext(), convertedElements);
379 })
380 .Case<hw::InOutType>([this](auto type) {
381 return hw::InOutType::get(removeSignedness(type.getElementType()));
382 })
383 .Case<hw::TypeAliasType>([this](auto type) {
384 return hw::TypeAliasType::get(
385 type.getRef(), removeSignedness(type.getInnerType()));
386 })
387 .Default([](auto type) { return type; });
388
389 return convertedType;
390}
391
393 // Pass any type through the signedness remover.
394 addConversion([this](Type type) { return removeSignedness(type); });
395
396 addTargetMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
397 mlir::ValueRange inputs,
398 mlir::Location loc) -> mlir::Value {
399 if (inputs.size() != 1)
400 return Value();
401 return builder
402 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
403 ->getResult(0);
404 });
405
406 addSourceMaterialization([&](mlir::OpBuilder &builder, mlir::Type resultType,
407 mlir::ValueRange inputs,
408 mlir::Location loc) -> mlir::Value {
409 if (inputs.size() != 1)
410 return Value();
411 return builder
412 .create<UnrealizedConversionCastOp>(loc, resultType, inputs[0])
413 ->getResult(0);
414 });
415}
416
417//===----------------------------------------------------------------------===//
418// Pass driver
419//===----------------------------------------------------------------------===//
420
422 HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns) {
423 patterns.add<ConstantOpLowering, CastOpLowering, ICmpOpLowering,
424 BinaryOpLowering<AddOp, comb::AddOp>,
425 BinaryOpLowering<SubOp, comb::SubOp>,
426 BinaryOpLowering<MulOp, comb::MulOp>, DivOpLowering>(
427 typeConverter, patterns.getContext());
428}
429
430namespace {
431
432class HWArithToHWPass : public circt::impl::HWArithToHWBase<HWArithToHWPass> {
433public:
434 void runOnOperation() override {
435 ModuleOp module = getOperation();
436
437 ConversionTarget target(getContext());
438 target.markUnknownOpDynamicallyLegal(isLegalOp);
439 RewritePatternSet patterns(&getContext());
440 HWArithToHWTypeConverter typeConverter;
441 target.addIllegalDialect<HWArithDialect>();
442
443 // Add HWArith-specific conversion patterns.
445
446 // ALL other operations are converted via the TypeConversionPattern which
447 // will replace an operation to an identical operation with replaced
448 // result types and operands.
449 patterns.add<TypeConversionPattern>(typeConverter, patterns.getContext());
450
451 // Apply a full conversion - all operations must either be legal, be caught
452 // by one of the HWArith patterns or be converted by the
453 // TypeConversionPattern.
454 if (failed(applyFullConversion(module, target, std::move(patterns))))
455 return signalPassFailure();
456 }
457};
458} // namespace
459
460//===----------------------------------------------------------------------===//
461// Pass initialization
462//===----------------------------------------------------------------------===//
463
464std::unique_ptr<Pass> circt::createHWArithToHWPass() {
465 return std::make_unique<HWArithToHWPass>();
466}
static SmallVector< Value > extractBits(ConversionPatternRewriter &rewriter, Value val)
Definition CombToAIG.cpp:33
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal - i.e.
Definition DCToHW.cpp:844
static bool isSignednessAttr(Attribute attr)
static bool isLegalOp(Operation *op)
Returns true if the given op is considered as legal for HWArith conversion.
static void improveNamehint(Value oldValue, Operation *newOp, llvm::function_ref< std::string(StringRef)> namehintCallback)
static Value extractBits(OpBuilder &builder, Location loc, Value value, unsigned startBit, unsigned bitWidth)
static bool isSignednessType(Type type)
static Value extendTypeWidth(OpBuilder &builder, Location loc, Value value, unsigned targetWidth, bool signExtension)
A helper type converter class that automatically populates the relevant materializations and type con...
Definition HWArithToHW.h:32
mlir::Type removeSignedness(mlir::Type type)
llvm::DenseMap< mlir::Type, ConvertedType > conversionCache
Definition HWArithToHW.h:50
unsigned inferAddResultType(IntegerType::SignednessSemantics &signedness, IntegerType lhs, IntegerType rhs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWArithToHWConversionPatterns(HWArithToHWTypeConverter &typeConverter, RewritePatternSet &patterns)
Get the HWArith to HW conversion patterns.
std::unique_ptr< mlir::Pass > createHWArithToHWPass()
Generic pattern which replaces an operation by one of the same operation name, but with converted att...