diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -552,6 +552,103 @@ return nullptr; } +// Transform IEEE Floats: +// (fmul C, (uitofp Pow2)) +// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa)) +// (fdiv C, (uitofp Pow2)) +// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa)) +// +// The rationale is fmul/fdiv by a power of 2 is just change the +// exponent, so there is no need for more than an add/sub. +// +// This is valid under the following circumstances: +// 1) We are dealing with IEEE floats +// 2) C is normal +// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds. +static Instruction *foldFMulOrFDivOfConstantAndIntPow2(InstCombinerImpl &IC, + const SimplifyQuery &Q, + BinaryOperator &I) { + Type *FPTy = I.getType(); + // Only do IEEE floats. Also skip scalable vecs. + if (!FPTy->getScalarType()->isIEEELikeFPTy() || FPTy->isScalableTy()) + return nullptr; + + // Make sure we can get a valid mantissa. + int Mantissa = FPTy->getFPMantissaWidth() - 1; + if (Mantissa <= 0) + return nullptr; + + // We are explicitly only matching FDiv where the denominator is a Pow2 + // Integer. This means the constant must be the first operand. For FMul, + // however, the constant will canonicalize to the second argument. + unsigned ConstOpIdx = I.getOpcode() == Instruction::FMul; + unsigned UIToFPOpIdx = 1 - ConstOpIdx; + + const APFloat *APF; + if (!match(I.getOperand(ConstOpIdx), m_APFloat(APF))) + return nullptr; + + // Make sure we have normal float (there might be some cases where this works + // for subnormal, but thats probably rare and code in not worth the risk of a + // bug). 
+ if (!APF->isNormal() || !APF->isIEEE()) + return nullptr; + + Value *V; + if (!match(I.getOperand(UIToFPOpIdx), m_UIToFP(m_Value(V)))) { + // We can match (sitofp Pow2) if we know Pow2 is non-negative. Doing this in + // InstCombineCasts causes regressions. + // TODO: Remove this if InstCombineCasts is ever updated. + if (!match(I.getOperand(UIToFPOpIdx), m_SIToFP(m_Value(V))) || + !isKnownNonNegative(V, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return nullptr; + } + + // Make sure add/sub will never breach the bounds of min/max exponent. We + // could be a little more clever here if we used knownbits to get an + // upperbound of V, but most floats are well within the bounds so it's probably + // not worth the risk of a bug to handle such a rare case. + int MaxExpChange = V->getType()->getScalarSizeInBits(); + int CurExp = ilogb(*APF); + int NewExpBound = I.getOpcode() == Instruction::FMul + ? (CurExp + MaxExpChange) + : (CurExp - MaxExpChange); + if (NewExpBound <= APFloat::semanticsMinExponent(APF->getSemantics()) || + NewExpBound >= APFloat::semanticsMaxExponent(APF->getSemantics())) + return nullptr; + + // Finally check if the integer is a Pow2 we can get the log of. All other + // checks need to be before this, as takeLog2 will create our log instruction. + Value *Log2 = + takeLog2(IC.Builder, V, /*Depth*/ 0, /*AssumeNonZero*/ + isKnownNonZero(V, Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT), + /*DoFold*/ true); + if (Log2 == nullptr) + return nullptr; + + // uitofp can change bitwidth, so need to cast our Log2 value to proper + // integer width. + unsigned FPWidth = FPTy->getScalarSizeInBits(); + Type *IntTy = Type::getIntNTy(I.getContext(), FPWidth); + if (FPTy->isVectorTy()) { + auto *FVTy = dyn_cast<FixedVectorType>(FPTy); + assert(FVTy && + "We should have already made sure this is not a scalable type!"); + IntTy = + VectorType::get(IntTy, ElementCount::getFixed(FVTy->getNumElements())); + } + Value *CastedLog2 = IC.Builder.CreateZExtOrTrunc(Log2, IntTy); + + // Do actual fold. 
+ Value *ShiftedLog2 = IC.Builder.CreateBinOp( + Instruction::Shl, CastedLog2, ConstantInt::get(IntTy, Mantissa)); + Value *BitwiseMulOrDiv = IC.Builder.CreateBinOp( + I.getOpcode() == Instruction::FMul ? Instruction::Add : Instruction::Sub, + IC.Builder.CreateBitCast(I.getOperand(ConstOpIdx), IntTy), ShiftedLog2); + Value *R = IC.Builder.CreateBitCast(BitwiseMulOrDiv, FPTy); + return IC.replaceInstUsesWith(I, R); +} + Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -813,6 +910,10 @@ return Result; } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (Instruction *R = foldFMulOrFDivOfConstantAndIntPow2(*this, Q, I)) + return R; + return nullptr; } @@ -1756,6 +1857,10 @@ return replaceInstUsesWith(I, Pow); } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (Instruction *R = foldFMulOrFDivOfConstantAndIntPow2(*this, Q, I)) + return R; + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/fmul-and-fdiv-by-int-pow2.ll b/llvm/test/Transforms/InstCombine/fmul-and-fdiv-by-int-pow2.ll --- a/llvm/test/Transforms/InstCombine/fmul-and-fdiv-by-int-pow2.ll +++ b/llvm/test/Transforms/InstCombine/fmul-and-fdiv-by-int-pow2.ll @@ -4,9 +4,9 @@ define double @fmul_dbl_pow_1_shl_cnt(i64 %cnt) { ; CHECK-LABEL: define double @fmul_dbl_pow_1_shl_cnt ; CHECK-SAME: (i64 [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 1, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp i64 [[SHL]] to double -; CHECK-NEXT: [[MUL:%.*]] = fmul double [[CONV]], 9.000000e+00 +; CHECK-NEXT: [[TMP1:%.*]] = shl i64 [[CNT]], 52 +; CHECK-NEXT: [[TMP2:%.*]] = add i64 [[TMP1]], 4621256167635550208 +; CHECK-NEXT: [[MUL:%.*]] = bitcast i64 [[TMP2]] to double ; CHECK-NEXT: ret double [[MUL]] ; %shl = shl nuw i64 1, %cnt @@ -32,9 +32,10 @@ define float @fmul_pow_1_shl_cnt(i64 %cnt) { ; CHECK-LABEL: define float @fmul_pow_1_shl_cnt ; CHECK-SAME: (i64 [[CNT:%.*]]) { -; CHECK-NEXT: 
[[SHL:%.*]] = shl nuw nsw i64 8, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp i64 [[SHL]] to float -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[CONV]], -9.000000e+00 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[CNT]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP1]], 23 +; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[TMP2]], -1030750208 +; CHECK-NEXT: [[MUL:%.*]] = bitcast i32 [[TMP3]] to float ; CHECK-NEXT: ret float [[MUL]] ; %shl = shl nsw nuw i64 8, %cnt @@ -46,9 +47,10 @@ define <2 x float> @fmul_pow_2_shl_cnt_vec(<2 x i64> %cnt) { ; CHECK-LABEL: define <2 x float> @fmul_pow_2_shl_cnt_vec ; CHECK-SAME: (<2 x i64> [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> , [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp <2 x i64> [[SHL]] to <2 x float> -; CHECK-NEXT: [[MUL:%.*]] = fmul <2 x float> [[CONV]], +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i64> [[CNT]] to <2 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = shl <2 x i32> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = add <2 x i32> [[TMP2]], +; CHECK-NEXT: [[MUL:%.*]] = bitcast <2 x i32> [[TMP3]] to <2 x float> ; CHECK-NEXT: ret <2 x float> [[MUL]] ; %shl = shl nsw nuw <2 x i64> , %cnt @@ -88,9 +90,10 @@ define double @fmul_pow_1_shl_cnt_safe(i16 %cnt) { ; CHECK-LABEL: define double @fmul_pow_1_shl_cnt_safe ; CHECK-SAME: (i16 [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i16 1, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp i16 [[SHL]] to double -; CHECK-NEXT: [[MUL:%.*]] = fmul double [[CONV]], 0x7BEFFFFFFF5F3992 +; CHECK-NEXT: [[TMP1:%.*]] = zext i16 [[CNT]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 52 +; CHECK-NEXT: [[TMP3:%.*]] = add i64 [[TMP2]], 8930638061065157010 +; CHECK-NEXT: [[MUL:%.*]] = bitcast i64 [[TMP3]] to double ; CHECK-NEXT: ret double [[MUL]] ; %shl = shl nuw i16 1, %cnt @@ -102,9 +105,9 @@ define <2 x double> @fdiv_pow_1_shl_cnt_vec(<2 x i64> %cnt) { ; CHECK-LABEL: define <2 x double> @fdiv_pow_1_shl_cnt_vec ; CHECK-SAME: (<2 x i64> [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw <2 x 
i64> , [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp <2 x i64> [[SHL]] to <2 x double> -; CHECK-NEXT: [[MUL:%.*]] = fdiv <2 x double> , [[CONV]] +; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i64> [[CNT]], +; CHECK-NEXT: [[TMP2:%.*]] = sub <2 x i64> , [[TMP1]] +; CHECK-NEXT: [[MUL:%.*]] = bitcast <2 x i64> [[TMP2]] to <2 x double> ; CHECK-NEXT: ret <2 x double> [[MUL]] ; %shl = shl nuw <2 x i64> , %cnt @@ -144,10 +147,11 @@ define float @fdiv_pow_1_shl_cnt(i64 %cnt_in) { ; CHECK-LABEL: define float @fdiv_pow_1_shl_cnt ; CHECK-SAME: (i64 [[CNT_IN:%.*]]) { -; CHECK-NEXT: [[CNT:%.*]] = and i64 [[CNT_IN]], 31 -; CHECK-NEXT: [[SHL:%.*]] = shl i64 8, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = sitofp i64 [[SHL]] to float -; CHECK-NEXT: [[MUL:%.*]] = fdiv float -5.000000e-01, [[CONV]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i64 [[CNT_IN]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP1]], 23 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 260046848 +; CHECK-NEXT: [[TMP4:%.*]] = sub nuw nsw i32 -1115684864, [[TMP3]] +; CHECK-NEXT: [[MUL:%.*]] = bitcast i32 [[TMP4]] to float ; CHECK-NEXT: ret float [[MUL]] ; %cnt = and i64 %cnt_in, 31 @@ -174,9 +178,9 @@ define half @fdiv_pow_1_shl_cnt_in_bounds(i16 %cnt) { ; CHECK-LABEL: define half @fdiv_pow_1_shl_cnt_in_bounds ; CHECK-SAME: (i16 [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i16 1, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp i16 [[SHL]] to half -; CHECK-NEXT: [[MUL:%.*]] = fdiv half 0xH7000, [[CONV]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i16 [[CNT]], 10 +; CHECK-NEXT: [[TMP2:%.*]] = sub i16 28672, [[TMP1]] +; CHECK-NEXT: [[MUL:%.*]] = bitcast i16 [[TMP2]] to half ; CHECK-NEXT: ret half [[MUL]] ; %shl = shl nuw i16 1, %cnt @@ -188,9 +192,9 @@ define half @fdiv_pow_1_shl_cnt_in_bounds2(i16 %cnt) { ; CHECK-LABEL: define half @fdiv_pow_1_shl_cnt_in_bounds2 ; CHECK-SAME: (i16 [[CNT:%.*]]) { -; CHECK-NEXT: [[SHL:%.*]] = shl nuw i16 1, [[CNT]] -; CHECK-NEXT: [[CONV:%.*]] = uitofp i16 [[SHL]] to half -; CHECK-NEXT: [[MUL:%.*]] = fdiv half 
0xH4800, [[CONV]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i16 [[CNT]], 10 +; CHECK-NEXT: [[TMP2:%.*]] = sub i16 18432, [[TMP1]] +; CHECK-NEXT: [[MUL:%.*]] = bitcast i16 [[TMP2]] to half ; CHECK-NEXT: ret half [[MUL]] ; %shl = shl nuw i16 1, %cnt