diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -13,7 +13,10 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -781,6 +784,69 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertShRSI
+//===----------------------------------------------------------------------===//
+
+struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    Type oldTy = op.getType();
+    auto newTy =
+        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
+
+    Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
+    Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+    Type narrowTy = rhsElem0.getType();
+    int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
+
+    // Rewrite this as a bitwise or of `arith.shrui` and sign extension bits.
+    // Perform as many ops over the narrow integer type as possible and let the
+    // other emulation patterns convert the rest.
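+    // For intuition: with the wide type split into two narrow halves (e.g.,
+    // i64 emulated as two i32s), for a shift amount n != 0:
+    //   shrsi(x, n) == shrui(x, n) | (allSign << (origBitwidth - n)),
+    // where `allSign` is all ones when x is negative and all zeros otherwise.
+    // The n == 0 case is handled separately below, because shifting by the
+    // full original bitwidth is not a valid shift amount.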
+    Value elemZero =
+        createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
+    Value signBit = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
+    signBit = dropTrailingX1Dim(rewriter, loc, signBit);
+
+    // Create a bit pattern of either all ones or all zeros. Then shift it left
+    // to calculate the sign extension bits created by shifting the original
+    // sign bit right.
+    Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
+    Value maxShift =
+        createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
+    Value numNonSignExtBits =
+        rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
+    numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
+    numNonSignExtBits =
+        rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
+    Value signBits =
+        rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
+
+    // Use the original arguments to create the right shift.
+    Value shrui =
+        rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
+    Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
+
+    // Handle shifting by zero. This is necessary when the `signBits` shift is
+    // invalid.
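+    // (With a zero shift amount, `numNonSignExtBits` equals `origBitwidth`,
+    // which is not a valid `arith.shli` amount, so the original LHS is
+    // returned unchanged instead.)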
+    Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  rhsElem0, elemZero);
+    isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
+                                                 shrsi);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertTruncI
 //===----------------------------------------------------------------------===//
@@ -799,7 +865,8 @@
           loc, llvm::formatv("unsupported truncation result type: {0}",
                              op.getType()));
 
-    // Discard the high half of the input. Truncate the low half, if necessary.
+    // Discard the high half of the input. Truncate the low half, if
+    // necessary.
     Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
     extracted = dropTrailingX1Dim(rewriter, loc, extracted);
     Value truncated =
@@ -940,7 +1007,7 @@
       // Misc ops.
       ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
       // Binary ops.
-      ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
+      ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
       // Bitwise binary ops.
      ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
      ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -587,6 +587,45 @@
   return %m : vector<3xi64>
 }
 
+// CHECK-LABEL: func.func @shrsi_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[NEG:%.+]] = arith.cmpi slt, [[HIGH0]], [[CST0]] : i32
+// CHECK-NEXT: [[NEGEXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
+// CHECK: [[CST64:%.+]] = arith.constant 64 : i32
+// CHECK-NEXT: [[SIGNBITS:%.+]] = arith.subi [[CST64]], [[LOW1]] : i32
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: arith.shli
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: arith.shrui
+// CHECK: arith.shli
+// CHECK: arith.shrui
+// CHECK: return {{%.+}} : vector<2xi32>
+func.func @shrsi_scalar(%a : i64, %b : i64) -> i64 {
+  %c = arith.shrsi %a, %b : i64
+  return %c : i64
+}
+
+// CHECK-LABEL: func.func @shrsi_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @shrsi_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %m = arith.shrsi %a, %b : vector<3xi64>
+  return %m : vector<3xi64>
+}
+
 // CHECK-LABEL: func @andi_scalar_a_b
 // CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
 // CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
@@ -203,6 +203,53 @@
   return
 }
 
+//===----------------------------------------------------------------------===//
+// Test arith.shrsi
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shrsi %lhs, %rhs : i16
+  return %res : i16
+}
+
+// Performs both wide and emulated `arith.shrsi`, and checks that the results
+// match.
+func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
+  %wide = arith.shrsi %lhs, %rhs : i16
+  %emulated = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
+  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+  return
+}
+
+// Checks that `arith.shrsi` is emulated properly by sampling the input space.
+// Checks all valid shift amounts for i16: 0 to 15.
+// In total, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shrsi() -> () {
+  %idx0 = arith.constant 0 : index
+  %idx1 = arith.constant 1 : index
+  %idx16 = arith.constant 16 : index
+  %idx100 = arith.constant 100 : index
+
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+
+  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
+    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
+
+    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
+      func.call @check_shrsi(%arg_lhs, %rhs) : (i16, i16) -> ()
+      %rhs_next = arith.addi %rhs, %cst1 : i16
+      scf.yield %rhs_next : i16
+    }
+
+    %lhs_next = arith.addi %lhs, %cst1 : i16
+    scf.yield %lhs_next : i16
+  }
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test arith.shrui
 //===----------------------------------------------------------------------===//
@@ -258,6 +305,7 @@
   func.call @test_addi() : () -> ()
   func.call @test_muli() : () -> ()
   func.call @test_shli() : () -> ()
+  func.call @test_shrsi() : () -> ()
   func.call @test_shrui() : () -> ()
   return
 }
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir
new file
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrsi-i16.mlir
@@ -0,0 +1,100 @@
+// Check that the wide integer `arith.shrsi` emulation produces the same result as wide
+// `arith.shrsi`. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i8 types.
+func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shrsi %lhs, %rhs : i16
+  return %res : i16
+}
+
+func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
+  %res = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
+  vector.print %res : i16
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+  %cst2 = arith.constant 2 : i16
+  %cst7 = arith.constant 7 : i16
+  %cst8 = arith.constant 8 : i16
+  %cst9 = arith.constant 9 : i16
+  %cst15 = arith.constant 15 : i16
+
+  %cst_n1 = arith.constant -1 : i16
+
+  %cst1337 = arith.constant 1337 : i16
+  %cst_n1337 = arith.constant -1337 : i16
+
+  %cst_i16_min = arith.constant -32768 : i16
+
+  // CHECK:      -32768
+  // CHECK-NEXT: -16384
+  // CHECK-NEXT: -8192
+  // CHECK-NEXT: -256
+  // CHECK-NEXT: -128
+  // CHECK-NEXT: -64
+  // CHECK-NEXT: -1
+  func.call @check_shrsi(%cst_i16_min, %cst0) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst1) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst2) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst7) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst8) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst9) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_i16_min, %cst15) : (i16, i16) -> ()
+
+  // CHECK-NEXT: 0
+  // CHECK-NEXT: 0
+  // CHECK-NEXT: 0
+  // CHECK-NEXT: 1
+  // CHECK-NEXT: -1
+  // CHECK-NEXT: -1
+  func.call @check_shrsi(%cst0, %cst0) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst0, %cst1) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1, %cst1) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1, %cst0) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1, %cst1) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1, %cst15) : (i16, i16) -> ()
+
+  // CHECK-NEXT: 1337
+  // CHECK-NEXT: 334
+  // CHECK-NEXT: 10
+  // CHECK-NEXT: 5
+  // CHECK-NEXT: 2
+  // CHECK-NEXT: 0
+  func.call @check_shrsi(%cst1337, %cst0) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1337, %cst2) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1337, %cst7) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1337, %cst8) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1337, %cst9) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst1337, %cst15) : (i16, i16) -> ()
+
+  // CHECK-NEXT: -1337
+  // CHECK-NEXT: -335
+  // CHECK-NEXT: -11
+  // CHECK-NEXT: -6
+  // CHECK-NEXT: -3
+  // CHECK-NEXT: -1
+  func.call @check_shrsi(%cst_n1337, %cst0) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1337, %cst2) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1337, %cst7) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1337, %cst8) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1337, %cst9) : (i16, i16) -> ()
+  func.call @check_shrsi(%cst_n1337, %cst15) : (i16, i16) -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shrui-i16.mlir
@@ -38,6 +38,7 @@
   %cst_n1 = arith.constant -1 : i16
 
   %cst1337 = arith.constant 1337 : i16
+  %cst_n1337 = arith.constant -1337 : i16
 
   %cst_i16_min = arith.constant -32768 : i16
 
@@ -67,6 +68,19 @@
   func.call @check_shrui(%cst1337, %cst9) : (i16, i16) -> ()
   func.call @check_shrui(%cst1337, %cst15) : (i16, i16) -> ()
 
+  // CHECK-NEXT: -1337
+  // CHECK-NEXT: 16049
+  // CHECK-NEXT: 501
+  // CHECK-NEXT: 250
+  // CHECK-NEXT: 125
+  // CHECK-NEXT: 1
+  func.call @check_shrui(%cst_n1337, %cst0) : (i16, i16) -> ()
+  func.call @check_shrui(%cst_n1337, %cst2) : (i16, i16) -> ()
+  func.call @check_shrui(%cst_n1337, %cst7) : (i16, i16) -> ()
+  func.call @check_shrui(%cst_n1337, %cst8) : (i16, i16) -> ()
+  func.call @check_shrui(%cst_n1337, %cst9) : (i16, i16) -> ()
+  func.call @check_shrui(%cst_n1337, %cst15) : (i16, i16) -> ()
+
   // CHECK-NEXT: 16384
   // CHECK-NEXT: 8192
   // CHECK-NEXT: 256