diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -557,6 +557,7 @@
                              llvm::ArrayRef<fir::ExtendedValue> args);
   template <typename Shift>
   mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genSpacing(mlir::Type resultType,
@@ -958,7 +959,7 @@
        {"radix", asAddr, handleDynamicOptional}}},
      /*isElemental=*/false},
     {"set_exponent", &I::genSetExponent},
-    {"shifta", &I::genShift<mlir::arith::ShRSIOp>},
+    {"shifta", &I::genShiftA},
     {"shiftl", &I::genShift<mlir::arith::ShLIOp>},
     {"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
     {"sign", &I::genSign},
@@ -4015,7 +4016,7 @@
                                    fir::getBase(args[1])));
 }
 
-// SHIFTA, SHIFTL, SHIFTR
+// SHIFTL, SHIFTR
 template <typename Shift>
 mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
                                        llvm::ArrayRef<mlir::Value> args) {
@@ -4041,6 +4042,31 @@
   return builder.create<mlir::arith::SelectOp>(loc, outOfBounds, zero, shifted);
 }
 
+// SHIFTA
+mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
+                                        llvm::ArrayRef<mlir::Value> args) {
+  unsigned bits = resultType.getIntOrFloatBitWidth();
+  mlir::Value bitSize = builder.createIntegerConstant(loc, resultType, bits);
+  mlir::Value shift = builder.createConvert(loc, resultType, args[1]);
+  mlir::Value shiftEqBitSize = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::eq, shift, bitSize);
+
+  // The lowering of mlir::arith::ShRSIOp uses `ashr`, which is undefined when
+  // the shift amount equals the bit width, so SHIFT == BIT_SIZE is handled as
+  // a special case: -1 if the value is negative, 0 otherwise.
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  mlir::Value minusOne = builder.createIntegerConstant(loc, resultType, -1);
+  mlir::Value valueIsNeg = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::slt, args[0], zero);
+  mlir::Value specialRes =
+      builder.create<mlir::arith::SelectOp>(loc, valueIsNeg, minusOne, zero);
+
+  mlir::Value shifted =
+      builder.create<mlir::arith::ShRSIOp>(loc, args[0], shift);
+  return builder.create<mlir::arith::SelectOp>(loc, shiftEqBitSize, specialRes,
+                                               shifted);
+}
+
 // SIGN
 mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
                                       llvm::ArrayRef<mlir::Value> args) {
diff --git a/flang/test/Lower/Intrinsics/shifta.f90 b/flang/test/Lower/Intrinsics/shifta.f90
--- a/flang/test/Lower/Intrinsics/shifta.f90
+++ b/flang/test/Lower/Intrinsics/shifta.f90
@@ -12,13 +12,14 @@
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i8
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i8
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i8
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8
 end subroutine shifta1_test
 
 ! CHECK-LABEL: shifta2_test
@@ -32,13 +33,14 @@
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i16
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i16
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i16
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16
 end subroutine shifta2_test
 
 ! CHECK-LABEL: shifta4_test
@@ -52,12 +54,13 @@
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i32
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32
 end subroutine shifta4_test
 
 ! CHECK-LABEL: shifta8_test
@@ -71,13 +74,14 @@
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i64
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i64
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i64
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64
 end subroutine shifta8_test
 
 ! CHECK-LABEL: shifta16_test
@@ -91,11 +95,12 @@
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i128
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i128
+  ! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128
 end subroutine shifta16_test
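
For reference, here is a minimal scalar model of the semantics the new SHIFTA lowering produces. It is illustrative only and not flang code; the helper name shiftaModel is hypothetical, and it assumes the usual arithmetic behaviour of >> on negative signed values. It mirrors the generated IR: the ashr-based path is never taken when the shift amount equals the bit width (where ashr would yield poison); that case is resolved by selecting -1 or 0 from the sign of the value.

#include <cassert>
#include <cstdint>

// Scalar model of the SHIFTA lowering sketched above: SHIFT == BIT_SIZE is
// special-cased, every other in-range SHIFT is a plain arithmetic right shift.
template <typename Int>
Int shiftaModel(Int value, unsigned shift) {
  constexpr unsigned bits = sizeof(Int) * 8; // BIT_SIZE of the operand kind
  assert(shift <= bits && "SHIFT must be in [0, BIT_SIZE]");
  if (shift == bits)
    return value < 0 ? Int(-1) : Int(0); // the arith.select of -1 / 0
  return Int(value >> shift);            // the arith.shrsi path
}

int main() {
  assert(shiftaModel<int8_t>(int8_t(-128), 8) == -1); // sign bit fills the result
  assert(shiftaModel<int8_t>(int8_t(64), 8) == 0);
  assert(shiftaModel<int8_t>(int8_t(-64), 2) == -16);
  return 0;
}

For 0 <= SHIFT < BIT_SIZE the plain arithmetic shift already matches Fortran's SHIFTA; the extra select only changes the result when SHIFT equals BIT_SIZE, which is exactly the case the updated FileCheck lines verify.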