diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -13,8 +13,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -906,6 +910,105 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSIToFP
+//===----------------------------------------------------------------------===//
+
+/// Emulates `arith.sitofp` on a wide integer: converts the magnitude of the
+/// input from its two narrowed halves and restores the sign at the end.
+struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type oldTy = op.getIn().getType();
+    auto newTy =
+        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", oldTy));
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
+    Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
+    Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
+    Type halfTy = hiInt.getType();
+
+    // Creates a scalar or splat constant of the narrowed integer type.
+    auto createHalfConstant = [&](const llvm::APInt &value) -> Value {
+      Attribute attr =
+          rewriter.getIntegerAttr(getElementTypeOrSelf(halfTy), value);
+      if (auto vecTy = dyn_cast<VectorType>(halfTy))
+        attr = SplatElementsAttr::get(vecTy, attr);
+      return rewriter.create<arith::ConstantOp>(loc, halfTy, attr);
+    };
+
+    // The input is `x = hi * 2^BW + low`, where `BW` is the bitwidth of the
+    // narrowed integer type, `hi` is signed, and `low` is unsigned. Adding
+    // separately converted halves suffers from cancellation for negative
+    // inputs (e.g., -1 is -2^BW + (2^BW - 1)), so convert the magnitude |x|
+    // instead and negate the result at the end. The magnitude is computed
+    // with narrow ops only, using two's complement negation:
+    //   -x = (-hi) * 2^BW            if low == 0 (carry into the high half),
+    //   -x = (~hi) * 2^BW + (-low)   otherwise.
+    // Both halves of |x| are non-negative and are converted as unsigned
+    // values, which also covers the minimum signed value, whose magnitude
+    // does not fit in the signed wide type.
+    Value zero = createHalfConstant(llvm::APInt::getZero(newBitWidth));
+    Value allOnes = createHalfConstant(llvm::APInt::getAllOnes(newBitWidth));
+    Value isNeg = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, hiInt, zero);
+    Value lowIsZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, lowInt, zero);
+    Value negLow = rewriter.create<arith::SubIOp>(loc, zero, lowInt);
+    Value negHiWithCarry = rewriter.create<arith::SubIOp>(loc, zero, hiInt);
+    Value negHiNoCarry = rewriter.create<arith::XOrIOp>(loc, hiInt, allOnes);
+    Value negHi = rewriter.create<arith::SelectOp>(loc, lowIsZero,
+                                                   negHiWithCarry, negHiNoCarry);
+    Value absLow = rewriter.create<arith::SelectOp>(loc, isNeg, negLow, lowInt);
+    Value absHi = rewriter.create<arith::SelectOp>(loc, isNeg, negHi, hiInt);
+
+    // The magnitude has the following form:
+    //   |fp| = uitofp(absLow) + (uitofp(absHi) * 2^BW)
+    //
+    // When `absHi` is zero, `uitofp(absLow)` is selected directly: 2^BW may
+    // not be representable in the result type (it rounds to +inf for f16),
+    // and `0 * inf` would produce NaN.
+    //
+    // Inputs exactly representable in the result type are converted exactly.
+    // Other inputs may be rounded more than once, so the result is not
+    // guaranteed to be correctly rounded.
+    Type resultTy = op.getType();
+    auto resultElemTy = cast<FloatType>(getElementTypeOrSelf(resultTy));
+    Value absHiIsZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, absHi, zero);
+    Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, absLow);
+    Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, absHi);
+
+    // Build 2^BW with APFloat: `int64_t(1) << BW` is undefined behavior for
+    // BW >= 63. Out-of-range powers round to +inf.
+    llvm::APFloat pow2 = llvm::scalbn(
+        llvm::APFloat::getOne(resultElemTy.getFloatSemantics()), newBitWidth,
+        llvm::APFloat::rmNearestTiesToEven);
+    Attribute pow2Attr = rewriter.getFloatAttr(resultElemTy, pow2);
+    if (auto vecTy = dyn_cast<VectorType>(resultTy))
+      pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
+
+    Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
+    Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
+    Value sum = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
+    Value absFp =
+        rewriter.create<arith::SelectOp>(loc, absHiIsZero, lowFp, sum);
+
+    Value negFp = rewriter.create<arith::NegFOp>(loc, absFp);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negFp, absFp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertTruncI
 //===----------------------------------------------------------------------===//
@@ -1080,6 +1183,6 @@
       ConvertIndexCastIntToIndex<arith::IndexCastOp>,
       ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
       ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
-      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
-      typeConverter, patterns.getContext());
+      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
+      ConvertSIToFP>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -908,3 +908,92 @@
     %x = arith.xori %a, %b : vector<3xi64>
     return %x : vector<3xi64>
 }
+
+// CHECK-LABEL: func @sitofp_i64_f64
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> f64
+// CHECK-NEXT:    [[LOW:%.+]]    = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HI:%.+]]     = vector.extract [[ARG]][1] : vector<2xi32>
+// CHECK-NEXT:    [[CST0:%.+]]   = arith.constant 0 : i32
+// CHECK-NEXT:    [[CSTM1:%.+]]  = arith.constant -1 : i32
+// CHECK-NEXT:    [[ISNEG:%.+]]  = arith.cmpi slt, [[HI]], [[CST0]] : i32
+// CHECK-NEXT:    [[LOWZ:%.+]]   = arith.cmpi eq, [[LOW]], [[CST0]] : i32
+// CHECK-NEXT:    [[NLOW:%.+]]   = arith.subi [[CST0]], [[LOW]] : i32
+// CHECK-NEXT:    [[NHI0:%.+]]   = arith.subi [[CST0]], [[HI]] : i32
+// CHECK-NEXT:    [[NHI1:%.+]]   = arith.xori [[HI]], [[CSTM1]] : i32
+// CHECK-NEXT:    [[NHI:%.+]]    = arith.select [[LOWZ]], [[NHI0]], [[NHI1]] : i32
+// CHECK-NEXT:    [[ALOW:%.+]]   = arith.select [[ISNEG]], [[NLOW]], [[LOW]] : i32
+// CHECK-NEXT:    [[AHI:%.+]]    = arith.select [[ISNEG]], [[NHI]], [[HI]] : i32
+// CHECK-NEXT:    [[AHIZ:%.+]]   = arith.cmpi eq, [[AHI]], [[CST0]] : i32
+// CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[ALOW]] : i32 to f64
+// CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[AHI]] : i32 to f64
+// CHECK-NEXT:    [[POW:%.+]]    = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT:    [[RESHI:%.+]]  = arith.mulf [[HIFP]], [[POW]] : f64
+// CHECK-NEXT:    [[SUM:%.+]]    = arith.addf [[LOWFP]], [[RESHI]] : f64
+// CHECK-NEXT:    [[ABS:%.+]]    = arith.select [[AHIZ]], [[LOWFP]], [[SUM]] : f64
+// CHECK-NEXT:    [[NEG:%.+]]    = arith.negf [[ABS]] : f64
+// CHECK-NEXT:    [[RES:%.+]]    = arith.select [[ISNEG]], [[NEG]], [[ABS]] : f64
+// CHECK-NEXT:    return [[RES]] : f64
+func.func @sitofp_i64_f64(%a : i64) -> f64 {
+    %r = arith.sitofp %a : i64 to f64
+    return %r : f64
+}
+
+// CHECK-LABEL: func @sitofp_i64_f64_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
+// CHECK-NEXT:    [[EXTLOW:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[EXTHI:%.+]]  = vector.extract_strided_slice [[ARG]] {offsets = [0, 1], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[LOW:%.+]]    = vector.shape_cast [[EXTLOW]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[HI:%.+]]     = vector.shape_cast [[EXTHI]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[CST0:%.+]]   = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT:    [[CSTM1:%.+]]  = arith.constant dense<-1> : vector<3xi32>
+// CHECK-NEXT:    [[ISNEG:%.+]]  = arith.cmpi slt, [[HI]], [[CST0]] : vector<3xi32>
+// CHECK-NEXT:    [[LOWZ:%.+]]   = arith.cmpi eq, [[LOW]], [[CST0]] : vector<3xi32>
+// CHECK-NEXT:    [[NLOW:%.+]]   = arith.subi [[CST0]], [[LOW]] : vector<3xi32>
+// CHECK-NEXT:    [[NHI0:%.+]]   = arith.subi [[CST0]], [[HI]] : vector<3xi32>
+// CHECK-NEXT:    [[NHI1:%.+]]   = arith.xori [[HI]], [[CSTM1]] : vector<3xi32>
+// CHECK-NEXT:    [[NHI:%.+]]    = arith.select [[LOWZ]], [[NHI0]], [[NHI1]] : vector<3xi1>, vector<3xi32>
+// CHECK-NEXT:    [[ALOW:%.+]]   = arith.select [[ISNEG]], [[NLOW]], [[LOW]] : vector<3xi1>, vector<3xi32>
+// CHECK-NEXT:    [[AHI:%.+]]    = arith.select [[ISNEG]], [[NHI]], [[HI]] : vector<3xi1>, vector<3xi32>
+// CHECK-NEXT:    [[AHIZ:%.+]]   = arith.cmpi eq, [[AHI]], [[CST0]] : vector<3xi32>
+// CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[ALOW]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[AHI]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT:    [[POW:%.+]]    = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT:    [[RESHI:%.+]]  = arith.mulf [[HIFP]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:    [[SUM:%.+]]    = arith.addf [[LOWFP]], [[RESHI]] : vector<3xf64>
+// CHECK-NEXT:    [[ABS:%.+]]    = arith.select [[AHIZ]], [[LOWFP]], [[SUM]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT:    [[NEG:%.+]]    = arith.negf [[ABS]] : vector<3xf64>
+// CHECK-NEXT:    [[RES:%.+]]    = arith.select [[ISNEG]], [[NEG]], [[ABS]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT:    return [[RES]] : vector<3xf64>
+func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
+    %r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
+    return %r : vector<3xf64>
+}
+
+// CHECK-LABEL: func @sitofp_i64_f16
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> f16
+// CHECK-NEXT:    [[LOW:%.+]]    = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[HI:%.+]]     = vector.extract [[ARG]][1] : vector<2xi32>
+// CHECK-NEXT:    [[CST0:%.+]]   = arith.constant 0 : i32
+// CHECK-NEXT:    [[CSTM1:%.+]]  = arith.constant -1 : i32
+// CHECK-NEXT:    [[ISNEG:%.+]]  = arith.cmpi slt, [[HI]], [[CST0]] : i32
+// CHECK-NEXT:    [[LOWZ:%.+]]   = arith.cmpi eq, [[LOW]], [[CST0]] : i32
+// CHECK-NEXT:    [[NLOW:%.+]]   = arith.subi [[CST0]], [[LOW]] : i32
+// CHECK-NEXT:    [[NHI0:%.+]]   = arith.subi [[CST0]], [[HI]] : i32
+// CHECK-NEXT:    [[NHI1:%.+]]   = arith.xori [[HI]], [[CSTM1]] : i32
+// CHECK-NEXT:    [[NHI:%.+]]    = arith.select [[LOWZ]], [[NHI0]], [[NHI1]] : i32
+// CHECK-NEXT:    [[ALOW:%.+]]   = arith.select [[ISNEG]], [[NLOW]], [[LOW]] : i32
+// CHECK-NEXT:    [[AHI:%.+]]    = arith.select [[ISNEG]], [[NHI]], [[HI]] : i32
+// CHECK-NEXT:    [[AHIZ:%.+]]   = arith.cmpi eq, [[AHI]], [[CST0]] : i32
+// CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[ALOW]] : i32 to f16
+// CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[AHI]] : i32 to f16
+// CHECK-NEXT:    [[POW:%.+]]    = arith.constant 0x7C00 : f16
+// CHECK-NEXT:    [[RESHI:%.+]]  = arith.mulf [[HIFP]], [[POW]] : f16
+// CHECK-NEXT:    [[SUM:%.+]]    = arith.addf [[LOWFP]], [[RESHI]] : f16
+// CHECK-NEXT:    [[ABS:%.+]]    = arith.select [[AHIZ]], [[LOWFP]], [[SUM]] : f16
+// CHECK-NEXT:    [[NEG:%.+]]    = arith.negf [[ABS]] : f16
+// CHECK-NEXT:    [[RES:%.+]]    = arith.select [[ISNEG]], [[NEG]], [[ABS]] : f16
+// CHECK-NEXT:    return [[RES]] : f16
+func.func @sitofp_i64_f16(%a : i64) -> f16 {
+    %r = arith.sitofp %a : i64 to f16
+    return %r : f16
+}