diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -3128,9 +3128,25 @@ /// Attempt to match a unary predicate against a scalar/splat constant or /// every element of a constant BUILD_VECTOR. /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match. - bool matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match, - bool AllowUndefs = false); + template <typename ConstNodeType> + bool matchUnaryPredicateImpl(SDValue Op, + std::function<bool(ConstNodeType *)> Match, + bool AllowUndefs = false); + + /// Hook for matching ConstantSDNode predicate + inline bool matchUnaryPredicate(SDValue Op, + std::function<bool(ConstantSDNode *)> Match, + bool AllowUndefs = false) { + return matchUnaryPredicateImpl(Op, Match, AllowUndefs); + } + + /// Hook for matching ConstantFPSDNode predicate + inline bool + matchUnaryFpPredicate(SDValue Op, + std::function<bool(ConstantFPSDNode *)> Match, + bool AllowUndefs = false) { + return matchUnaryPredicateImpl(Op, Match, AllowUndefs); + } /// Attempt to match a binary predicate against a pair of scalar/splat /// constants or every element of a pair of constant BUILD_VECTORs. diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -16352,7 +16352,7 @@ EVT VT = N->getValueType(0); SDValue ConstOp, Pow2Op; - int Mantissa = -1; + std::optional<int> Mantissa; auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) { if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV) return false; @@ -16366,36 +16366,43 @@ Pow2Op = Pow2Op.getOperand(0); - // TODO(1): We may be able to include undefs. - // TODO(2): We could also handle non-splat vector types. 
- ConstantFPSDNode *CFP = - isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false); - if (CFP == nullptr) - return false; - const APFloat &APF = CFP->getValueAPF(); - - // Make sure we have normal/ieee constant. - if (!APF.isNormal() || !APF.isIEEE()) - return false; - // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`. // TODO: We could use knownbits to make this bound more precise. int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits(); - // Make sure the floats exponent is within the bounds that this transform - // produces bitwise equals value. - int CurExp = ilogb(APF); - // FMul by pow2 will only increase exponent. - int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange); - // FDiv by pow2 will only decrease exponent. - int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange); - if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) || - MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics())) - return false; + auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) { + if (CFP == nullptr) + return false; + + const APFloat &APF = CFP->getValueAPF(); + + // Make sure we have normal/ieee constant. + if (!APF.isNormal() || !APF.isIEEE()) + return false; + + // Make sure the floats exponent is within the bounds that this transform + // produces bitwise equals value. + int CurExp = ilogb(APF); + // FMul by pow2 will only increase exponent. + int MinExp = + N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange); + // FDiv by pow2 will only decrease exponent. + int MaxExp = + N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange); + if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) || + MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics())) + return false; + + // Finally make sure we actually know the mantissa for the float type. 
+ int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1; + if (!Mantissa) + Mantissa = ThisMantissa; + + return *Mantissa == ThisMantissa && ThisMantissa > 0; + }; - // Finally make sure we actually know the mantissa for the float type. - Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1; - return Mantissa > 0; + // TODO: We may be able to include undefs. + return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid); }; if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1)) @@ -16420,7 +16427,7 @@ // Perform actual transform. SDValue MantissaShiftCnt = - DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT)); + DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT)); // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to // `(X << C1) + (C << C1)`, but that isn't always the case because of the // cast. We could implement that by handle here to handle the casts. diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -344,12 +344,13 @@ return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef(); } -bool ISD::matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match, - bool AllowUndefs) { +template <typename ConstNodeType> +bool ISD::matchUnaryPredicateImpl(SDValue Op, + std::function<bool(ConstNodeType *)> Match, + bool AllowUndefs) { // FIXME: Add support for scalar UNDEF cases? - if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) - return Match(Cst); + if (auto *C = dyn_cast<ConstNodeType>(Op)) + return Match(C); // FIXME: Add support for vector UNDEF cases? if (ISD::BUILD_VECTOR != Op.getOpcode() && @@ -364,12 +365,17 @@ continue; } - auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i)); + auto *Cst = dyn_cast<ConstNodeType>(Op.getOperand(i)); if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst)) return false; } return true; } +// Build used template types. 
+template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>( + SDValue, std::function<bool(ConstantSDNode *)>, bool); +template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>( + SDValue, std::function<bool(ConstantFPSDNode *)>, bool); bool ISD::matchBinaryPredicate( SDValue LHS, SDValue RHS, diff --git a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll --- a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll +++ b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll @@ -1104,58 +1104,15 @@ define <2 x double> @fmul_pow_shl_cnt_vec_non_splat_todo(<2 x i64> %cnt) nounwind { ; CHECK-SSE-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: ; CHECK-SSE: # %bb.0: -; CHECK-SSE-NEXT: movdqa {{.*#+}} xmm1 = [2,2] -; CHECK-SSE-NEXT: movdqa %xmm1, %xmm2 -; CHECK-SSE-NEXT: psllq %xmm0, %xmm2 -; CHECK-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3] -; CHECK-SSE-NEXT: psllq %xmm0, %xmm1 -; CHECK-SSE-NEXT: movsd {{.*#+}} xmm1 = xmm2[0],xmm1[1] -; CHECK-SSE-NEXT: movapd {{.*#+}} xmm0 = [4294967295,4294967295] -; CHECK-SSE-NEXT: andpd %xmm1, %xmm0 -; CHECK-SSE-NEXT: orpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 -; CHECK-SSE-NEXT: psrlq $32, %xmm1 -; CHECK-SSE-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: addpd %xmm0, %xmm1 -; CHECK-SSE-NEXT: mulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: movapd %xmm1, %xmm0 +; CHECK-SSE-NEXT: psllq $52, %xmm0 +; CHECK-SSE-NEXT: paddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 ; CHECK-SSE-NEXT: retq ; -; CHECK-AVX2-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: -; CHECK-AVX2: # %bb.0: -; CHECK-AVX2-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2] -; CHECK-AVX2-NEXT: vpsllvq %xmm0, %xmm1, %xmm0 -; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3] -; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1 -; CHECK-AVX2-NEXT: vpsrlq $32, %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vpor 
{{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vaddpd %xmm0, %xmm1, %xmm0 -; CHECK-AVX2-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: retq -; -; CHECK-NO-FASTFMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: -; CHECK-NO-FASTFMA: # %bb.0: -; CHECK-NO-FASTFMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2] -; CHECK-NO-FASTFMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0 -; CHECK-NO-FASTFMA-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; CHECK-NO-FASTFMA-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3] -; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1 -; CHECK-NO-FASTFMA-NEXT: vpsrlq $32, %xmm0, %xmm0 -; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-NO-FASTFMA-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-NO-FASTFMA-NEXT: vaddpd %xmm0, %xmm1, %xmm0 -; CHECK-NO-FASTFMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-NO-FASTFMA-NEXT: retq -; -; CHECK-FMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: -; CHECK-FMA: # %bb.0: -; CHECK-FMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2] -; CHECK-FMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0 -; CHECK-FMA-NEXT: vcvtuqq2pd %xmm0, %xmm0 -; CHECK-FMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-FMA-NEXT: retq +; CHECK-AVX-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: +; CHECK-AVX: # %bb.0: +; CHECK-AVX-NEXT: vpsllq $52, %xmm0, %xmm0 +; CHECK-AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 +; CHECK-AVX-NEXT: retq %shl = shl nsw nuw <2 x i64> <i64 2, i64 2>, %cnt %conv = uitofp <2 x i64> %shl to <2 x double> %mul = fmul <2 x double> <double 15.000000e+00, double 14.000000e+00>, %conv