diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -485,6 +485,102 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertShLI
+//===----------------------------------------------------------------------===//
+
+struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    Type oldTy = op.getType();
+    auto newTy =
+        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
+    Type newOperandTy = reduceInnermostDim(newTy);
+    // `oldBitWidth` == `2 * newBitWidth`
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    auto [lhsElem0, lhsElem1] =
+        extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+    Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
+
+    // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
+    // high halves of the result separately:
+    //   1. low := LHS.low shli RHS
+    //
+    //   2. high := a or b or c, where:
+    //     a) Bits from LHS.high, shifted by the RHS.
+    //     b) Bits from LHS.low, shifted right. These come into play when
+    //        0 < RHS < newBitWidth, e.g.:
+    //         [0000][llll] shli 3 --> [0lll][l000]
+    //                                   ^
+    //                                   |
+    //                                [llll] shrui (4 - 3)
+    //     c) Bits from LHS.low, shifted left. These matter when
+    //        RHS >= newBitWidth, e.g.:
+    //         [0000][llll] shli 7 --> [l000][0000]
+    //                                  ^
+    //                                  |
+    //                               [llll] shli (7 - 4)
+    //
+    // Because shifts by values >= newBitWidth are undefined, we ignore the high
+    // half of RHS, and introduce 'bounds checks' to account for
+    // RHS.low >= newBitWidth.
+    //
+    // TODO: Explore possible optimizations.
+    Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
+    Value elemBitWidth =
+        createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
+
+    Value illegalElemShift = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
+
+    Value shiftedElem0 =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
+    Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
+                                                      zeroCst, shiftedElem0);
+
+    Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, elemBitWidth, rhsElem0);
+    Value rightShiftAmount =
+        rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
+    Value shiftedRight =
+        rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
+    // For RHS == 0, `rightShiftAmount` == newBitWidth, which makes the shift
+    // above out of range and its result undefined. No bits of LHS.low cross
+    // into the high half in that case, so select zero explicitly.
+    Value isZeroShift = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, rhsElem0, zeroCst);
+    Value shiftedRightSafe = rewriter.create<arith::SelectOp>(
+        loc, isZeroShift, zeroCst, shiftedRight);
+    Value overshotShiftAmount =
+        rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
+    Value shiftedLeft =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
+
+    Value shiftedElem1 =
+        rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
+    Value resElem1High = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, zeroCst, shiftedElem1);
+    Value resElem1Low = rewriter.create<arith::SelectOp>(
+        loc, illegalElemShift, shiftedLeft, shiftedRightSafe);
+    Value resElem1 =
+        rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
+
+    Value resultVec =
+        constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
+    rewriter.replaceOp(op, resultVec);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertShRUI
 //===----------------------------------------------------------------------===//
@@ -498,8 +594,13 @@
     Location loc = op->getLoc();
 
     Type oldTy = op.getType();
-    auto newTy = getTypeConverter()->convertType(oldTy).cast<VectorType>();
+    auto newTy =
+        getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(loc, "unsupported type");
+
     Type newOperandTy = reduceInnermostDim(newTy);
+    // `oldBitWidth` == `2 * newBitWidth`
     unsigned newBitWidth = newTy.getElementTypeBitWidth();
 
     auto [lhsElem0, lhsElem1] =
@@ -727,7 +828,7 @@
       // Misc ops.
       ConvertConstant, ConvertVectorPrint,
       // Binary ops.
-      ConvertAddI, ConvertMulI, ConvertShRUI,
+      ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRUI,
       // Bitwise binary ops.
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -278,6 +278,48 @@
   return %m : vector<3xi64>
 }
 
+// CHECK-LABEL: func.func @shli_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[CST32:%.+]] = arith.constant 32 : i32
+// CHECK-NEXT: [[OOB:%.+]] = arith.cmpi uge, [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT: [[SHLOW0:%.+]] = arith.shli [[LOW0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES0:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLOW0]] : i32
+// CHECK-NEXT: [[SHAMT:%.+]] = arith.select [[OOB]], [[CST32]], [[LOW1]] : i32
+// CHECK-NEXT: [[RSHAMT:%.+]] = arith.subi [[CST32]], [[SHAMT]] : i32
+// CHECK-NEXT: [[SHRHIGH0:%.+]] = arith.shrui [[LOW0]], [[RSHAMT]] : i32
+// CHECK-NEXT: [[ZEROSH:%.+]] = arith.cmpi eq, [[LOW1]], [[CST0]] : i32
+// CHECK-NEXT: [[SHRHIGH0SAFE:%.+]] = arith.select [[ZEROSH]], [[CST0]], [[SHRHIGH0]] : i32
+// CHECK-NEXT: [[LSHAMT:%.+]] = arith.subi [[LOW1]], [[CST32]] : i32
+// CHECK-NEXT: [[SHLHIGH0:%.+]] = arith.shli [[LOW0]], [[LSHAMT]] : i32
+// CHECK-NEXT: [[SHLHIGH1:%.+]] = arith.shli [[HIGH0]], [[LOW1]] : i32
+// CHECK-NEXT: [[RES1HIGH:%.+]] = arith.select [[OOB]], [[CST0]], [[SHLHIGH1]] : i32
+// CHECK-NEXT: [[RES1LOW:%.+]] = arith.select [[OOB]], [[SHLHIGH0]], [[SHRHIGH0SAFE]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.ori [[RES1LOW]], [[RES1HIGH]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @shli_scalar(%a : i64, %b : i64) -> i64 {
+  %c = arith.shli %a, %b : i64
+  return %c : i64
+}
+
+// CHECK-LABEL: func.func @shli_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @shli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+  %m = arith.shli %a, %b : vector<3xi64>
+  return %m : vector<3xi64>
+}
+
 // CHECK-LABEL: func.func @shrui_scalar
 // CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
 // CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
@@ -326,6 +368,10 @@
 
 // CHECK-LABEL: func.func @shrui_vector
 // CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shli {{%.+}}, {{%.+}} : vector<3x1xi32>
+// CHECK: {{%.+}} = arith.shrui {{%.+}}, {{%.+}} : vector<3x1xi32>
 // CHECK: return {{%.+}} : vector<3x2xi32>
 func.func @shrui_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
   %m = arith.shrui %a, %b : vector<3xi64>
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-compare-results-i16.mlir
@@ -156,6 +156,53 @@
   return
 }
 
+//===----------------------------------------------------------------------===//
+// Test arith.shli
+//===----------------------------------------------------------------------===//
+
+// Ops in this function will be emulated using i8 ops.
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shli %lhs, %rhs : i16
+  return %res : i16
+}
+
+// Performs both wide and emulated `arith.shli`, and checks that the results
+// match.
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+  %wide = arith.shli %lhs, %rhs : i16
+  %emulated = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
+  return
+}
+
+// Checks that `arith.shli` is emulated properly by sampling the input space.
+// Checks all valid shift amounts for i16: 0 to 15.
+// In total, this test function checks 100 * 16 = 1.6k input pairs.
+func.func @test_shli() -> () {
+  %idx0 = arith.constant 0 : index
+  %idx1 = arith.constant 1 : index
+  %idx16 = arith.constant 16 : index
+  %idx100 = arith.constant 100 : index
+
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+
+  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
+    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
+
+    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
+      func.call @check_shli(%arg_lhs, %rhs) : (i16, i16) -> ()
+      %rhs_next = arith.addi %rhs, %cst1 : i16
+      scf.yield %rhs_next : i16
+    }
+
+    %lhs_next = arith.addi %lhs, %cst1 : i16
+    scf.yield %lhs_next : i16
+  }
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test arith.shrui
 //===----------------------------------------------------------------------===//
@@ -210,6 +257,7 @@
 func.func @entry() {
   func.call @test_addi() : () -> ()
   func.call @test_muli() : () -> ()
+  func.call @test_shli() : () -> ()
   func.call @test_shrui() : () -> ()
   return
 }
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-shli-i16.mlir
@@ -0,0 +1,73 @@
+// Check that the wide integer `arith.shli` emulation produces the same result as wide
+// `arith.shli`. The second RUN line emulates i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Ops in this function *only* are emulated using i8 types (second RUN line).
+func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
+  %res = arith.shli %lhs, %rhs : i16
+  return %res : i16
+}
+
+func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
+  %res = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
+  vector.print %res : i16
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0 : i16
+  %cst1 = arith.constant 1 : i16
+  %cst2 = arith.constant 2 : i16
+  %cst7 = arith.constant 7 : i16
+  %cst8 = arith.constant 8 : i16
+  %cst9 = arith.constant 9 : i16
+  %cst15 = arith.constant 15 : i16
+
+  %cst_n1 = arith.constant -1 : i16
+
+  %cst1337 = arith.constant 1337 : i16
+
+  %cst_i16_min = arith.constant -32768 : i16
+
+  // CHECK:      0
+  // CHECK-NEXT: 0
+  // CHECK-NEXT: 1
+  // CHECK-NEXT: 2
+  // CHECK-NEXT: -2
+  // CHECK-NEXT: -32768
+  func.call @check_shli(%cst0, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst0, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst1, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst1, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst_n1, %cst1) : (i16, i16) -> ()
+  func.call @check_shli(%cst_n1, %cst15) : (i16, i16) -> ()
+
+  // CHECK-NEXT: 1337
+  // CHECK-NEXT: 5348
+  // CHECK-NEXT: -25472
+  // CHECK-NEXT: 14592
+  // CHECK-NEXT: 29184
+  // CHECK-NEXT: -32768
+  // CHECK-NEXT: 0
+  func.call @check_shli(%cst1337, %cst0) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst2) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst7) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst8) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst9) : (i16, i16) -> ()
+  func.call @check_shli(%cst1337, %cst15) : (i16, i16) -> ()
+  func.call @check_shli(%cst_i16_min, %cst1) : (i16, i16) -> ()
+
+  return
+}