diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include @@ -907,6 +908,52 @@ } }; +//===----------------------------------------------------------------------===// +// ConvertSIToFP +//===----------------------------------------------------------------------===// + +struct ConvertSIToFP final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Value in = op.getIn(); + Type oldTy = in.getType(); + auto newTy = + dyn_cast_or_null(getTypeConverter()->convertType(oldTy)); + if (!newTy) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unsupported type: {0}", oldTy)); + + unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth(); + Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0); + Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1); + Value allOnesCst = createScalarOrSplatConstant( + rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth)); + + // To avoid operating on very large unsigned numbers, perform the + // conversion on the absolute value. Then, decide whether to negate the + // result or not based on the input's sign bit. We assume two's complement and + // implement negation by flipping all bits and adding 1. + // Note that this relies on the other conversion patterns to legalize + // created ops and narrow the bit widths. 
+ Value isNeg = rewriter.create(loc, arith::CmpIPredicate::slt, + in, zeroCst); + Value bitwiseNeg = rewriter.create(loc, in, allOnesCst); + Value neg = rewriter.create(loc, bitwiseNeg, oneCst); + Value abs = rewriter.create(loc, isNeg, neg, in); + + Value absResult = rewriter.create(loc, op.getType(), abs); + Value negResult = rewriter.create(loc, absResult); + rewriter.replaceOpWithNewOp(op, isNeg, negResult, + absResult); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertUIToFP //===----------------------------------------------------------------------===// @@ -1146,5 +1193,5 @@ ConvertIndexCastIntToIndex, ConvertIndexCastIndexToInt, ConvertIndexCastIndexToInt, - ConvertUIToFP>(typeConverter, patterns.getContext()); + ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext()); } diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir @@ -964,3 +964,46 @@ %r = arith.uitofp %a : i64 to f16 return %r : f16 } + +// CHECK-LABEL: func @sitofp_i64_f64 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64 +// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32> +// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : vector<2xi32> +// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : vector<2xi32> +// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32 +// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32 +// CHECK: [[CST0:%.+]] = arith.constant 0 : i32 +// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32 +// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64 +// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI]] : i32 to f64 +// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64 +// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : f64 +// CHECK-NEXT: [[RES:%.+]] = arith.addf 
[[LOWFP]], [[RESHI]] : f64 +// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : f64 +// CHECK-NEXT: [[NEG:%.+]] = arith.negf [[SEL]] : f64 +// CHECK-NEXT: [[FINAL:%.+]] = arith.select %{{.+}}, [[NEG]], [[SEL]] : f64 +// CHECK-NEXT: return [[FINAL]] : f64 +func.func @sitofp_i64_f64(%a : i64) -> f64 { + %r = arith.sitofp %a : i64 to f64 + return %r : f64 +} + +// CHECK-LABEL: func @sitofp_i64_f64_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64> +// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32> +// CHECK: arith.xori +// CHECK-NEXT: arith.xori +// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32> +// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64> +// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64> +// CHECK-NEXT: [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64> +// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : vector<3xf64> +// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : vector<3xf64> +// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : vector<3xi1>, vector<3xf64> +// CHECK-NEXT: [[NEG:%.+]] = arith.negf [[SEL]] : vector<3xf64> +// CHECK-NEXT: [[FINAL:%.+]] = arith.select %{{.+}}, [[NEG]], [[SEL]] : vector<3xi1>, vector<3xf64> +// CHECK-NEXT: return [[FINAL]] : vector<3xf64> +func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> { + %r = arith.sitofp %a : vector<3xi64> to vector<3xf64> + return %r : vector<3xf64> +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir @@ -0,0 +1,68 @@ +// Check that the wide integer `arith.sitofp` emulation produces the same result as wide +// 
`arith.sitofp`. Emulate i32 ops with i16 ops. + +// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ +// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s --match-full-lines + +// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \ +// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \ +// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: --shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s --match-full-lines + +// Ops in this function *only* will be emulated using i16 types. +func.func @emulate_sitofp(%arg: i32) -> f32 { + %res = arith.sitofp %arg : i32 to f32 + return %res : f32 +} + +func.func @check_sitofp(%arg : i32) -> () { + %res = func.call @emulate_sitofp(%arg) : (i32) -> (f32) + vector.print %res : f32 + return +} + +func.func @entry() { + %cst0 = arith.constant 0 : i32 + %cst1 = arith.constant 1 : i32 + %cst2 = arith.constant 2 : i32 + %cst7 = arith.constant 7 : i32 + %cst1337 = arith.constant 1337 : i32 + + %cst_n1 = arith.constant -1 : i32 + %cst_n13 = arith.constant -13 : i32 + %cst_n1337 = arith.constant -1337 : i32 + + %cst_i16_min = arith.constant -32768 : i32 + + %cst_f32_int_max = arith.constant 16777217 : i32 + %cst_f32_int_min = arith.constant -16777217 : i32 + + // CHECK: 0 + func.call @check_sitofp(%cst0) : (i32) -> () + // CHECK-NEXT: 1 + func.call @check_sitofp(%cst1) : (i32) -> () + // CHECK-NEXT: 2 + func.call @check_sitofp(%cst2) : (i32) -> () + // CHECK-NEXT: 7 + func.call @check_sitofp(%cst7) : (i32) -> () + // CHECK-NEXT: 1337 + func.call @check_sitofp(%cst1337) : (i32) -> () + // CHECK-NEXT: -1 + func.call @check_sitofp(%cst_n1) : (i32) -> () + // CHECK-NEXT: -1337 + func.call @check_sitofp(%cst_n1337) : (i32) -> () + + // CHECK-NEXT: 
-32768 + func.call @check_sitofp(%cst_i16_min) : (i32) -> () + // CHECK-NEXT: 1.6{{.+}}e+07 + func.call @check_sitofp(%cst_f32_int_max) : (i32) -> () + // CHECK-NEXT: -1.6{{.+}}e+07 + func.call @check_sitofp(%cst_f32_int_min) : (i32) -> () + + return +}