diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -54,6 +54,7 @@ #include #include #include +#include namespace llvm { @@ -3120,12 +3121,44 @@ cast<StoreSDNode>(N)->getAddressingMode() == ISD::UNINDEXED; } - /// Attempt to match a unary predicate against a scalar/splat constant or + /// Attempt to match a unary predicate against a scalar/splat constant or /// every element of a constant BUILD_VECTOR. /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match. - bool matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match, - bool AllowUndefs = false); + bool matchUnaryPredicateAnyNode(SDValue Op, + std::function<bool(SDNode *)> Match, + bool AllowUndefs = false); + + /// Wrapper for matchUnaryPredicateAnyNode to accept a function that only + /// matches a specific NodeType (ConstantSDNode or ConstantFPSDNode). + template <typename ConstNodeType> + inline bool + matchUnaryPredicateWrapper(SDValue Op, + std::function<bool(ConstNodeType *)> Match, + bool AllowUndefs) { + auto ConstNodeWrapper = [Match](SDNode *N) { + if (N == nullptr) + return Match(nullptr); + else if (auto *C = dyn_cast<ConstNodeType>(N)) + return Match(C); + return false; + }; + return matchUnaryPredicateAnyNode(Op, ConstNodeWrapper, AllowUndefs); + } + + /// Hook for matching ConstantSDNode predicate + inline bool matchUnaryPredicate(SDValue Op, + std::function<bool(ConstantSDNode *)> Match, + bool AllowUndefs = false) { + return matchUnaryPredicateWrapper(Op, Match, AllowUndefs); + } + + /// Hook for matching ConstantFPSDNode predicate + inline bool + matchUnaryFpPredicate(SDValue Op, + std::function<bool(ConstantFPSDNode *)> Match, + bool AllowUndefs = false) { + return matchUnaryPredicateWrapper(Op, Match, AllowUndefs); + } /// Attempt to match a binary predicate against a pair of scalar/splat /// constants or every element of a pair of constant BUILD_VECTORs. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -16214,7 +16214,7 @@ EVT VT = N->getValueType(0); SDValue ConstOp, Pow2Op; - int Mantissa = -1; + std::optional<int> Mantissa; auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) { if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV) return false; @@ -16228,36 +16228,43 @@ Pow2Op = Pow2Op.getOperand(0); - // TODO(1): We may be able to include undefs. - // TODO(2): We could also handle non-splat vector types. - ConstantFPSDNode *CFP = - isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false); - if (CFP == nullptr) - return false; - const APFloat &APF = CFP->getValueAPF(); - - // Make sure we have normal/ieee constant. - if (!APF.isNormal() || !APF.isIEEE()) - return false; - // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`. // TODO: We could use knownbits to make this bound more precise. int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits(); - // Make sure the floats exponent is within the bounds that this transform - // produces bitwise equals value. - int CurExp = ilogb(APF); - // FMul by pow2 will only increase exponent. - int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange); - // FDiv by pow2 will only decrease exponent. - int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange); - if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) || - MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics())) - return false; + auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) { + if (CFP == nullptr) + return false; + + const APFloat &APF = CFP->getValueAPF(); + + // Make sure we have normal/ieee constant. + if (!APF.isNormal() || !APF.isIEEE()) + return false; + + // Make sure the floats exponent is within the bounds that this transform + // produces bitwise equals value. 
+ int CurExp = ilogb(APF); + // FMul by pow2 will only increase exponent. + int MinExp = + N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange); + // FDiv by pow2 will only decrease exponent. + int MaxExp = + N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange); + if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) || + MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics())) + return false; + + // Finally make sure we actually know the mantissa for the float type. + int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1; + if (!Mantissa) + Mantissa = ThisMantissa; + + return *Mantissa == ThisMantissa && ThisMantissa > 0; + }; - // Finally make sure we actually know the mantissa for the float type. - Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1; - return Mantissa > 0; + // TODO: We may be able to include undefs. + return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid); }; if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1)) @@ -16282,7 +16289,7 @@ // Perform actual transform. SDValue MantissaShiftCnt = - DAG.getConstant(Mantissa, DL, DC->getShiftAmountTy(NewIntVT)); + DAG.getConstant(*Mantissa, DL, DC->getShiftAmountTy(NewIntVT)); // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to // `(X << C1) + (C << C1)`, but that isn't always the case because of the // cast. We could implement that by handle here to handle the casts. diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -344,14 +344,17 @@ return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef(); } -bool ISD::matchUnaryPredicate(SDValue Op, - std::function<bool(ConstantSDNode *)> Match, - bool AllowUndefs) { - // FIXME: Add support for scalar UNDEF cases? 
- if (auto *Cst = dyn_cast<ConstantSDNode>(Op)) - return Match(Cst); +bool ISD::matchUnaryPredicateAnyNode(SDValue Op, + std::function<bool(SDNode *)> Match, + bool AllowUndefs) { + auto IsRecognizedConstNode = [](SDValue V) { + return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V); + }; - // FIXME: Add support for vector UNDEF cases? + if (IsRecognizedConstNode(Op)) + return Match(Op.getNode()); + + // FIXME: Add support for scalar UNDEF cases? if (ISD::BUILD_VECTOR != Op.getOpcode() && ISD::SPLAT_VECTOR != Op.getOpcode()) return false; @@ -364,8 +367,9 @@ continue; } - auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i)); - if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst)) + SDValue OpI = Op.getOperand(i); + if (!IsRecognizedConstNode(OpI) || OpI.getValueType() != SVT || + !Match(OpI.getNode())) return false; } return true; diff --git a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll --- a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll +++ b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll @@ -320,44 +320,15 @@ define <2 x double> @fmul_pow_shl_cnt_vec_non_splat_todo(<2 x i64> %cnt) { ; CHECK-SSE-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: ; CHECK-SSE: # %bb.0: -; CHECK-SSE-NEXT: movdqa {{.*#+}} xmm1 = [2,2] -; CHECK-SSE-NEXT: movdqa %xmm1, %xmm2 -; CHECK-SSE-NEXT: psllq %xmm0, %xmm2 -; CHECK-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3] -; CHECK-SSE-NEXT: psllq %xmm0, %xmm1 -; CHECK-SSE-NEXT: movsd {{.*#+}} xmm1 = xmm2[0],xmm1[1] -; CHECK-SSE-NEXT: movapd {{.*#+}} xmm0 = [4294967295,4294967295] -; CHECK-SSE-NEXT: andpd %xmm1, %xmm0 -; CHECK-SSE-NEXT: orpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 -; CHECK-SSE-NEXT: psrlq $32, %xmm1 -; CHECK-SSE-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: addpd %xmm0, %xmm1 -; CHECK-SSE-NEXT: mulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1 -; CHECK-SSE-NEXT: movapd %xmm1, %xmm0 +; CHECK-SSE-NEXT: psllq $52, %xmm0 +; CHECK-SSE-NEXT: 
paddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 ; CHECK-SSE-NEXT: retq ; -; CHECK-AVX2-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: -; CHECK-AVX2: # %bb.0: -; CHECK-AVX2-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2] -; CHECK-AVX2-NEXT: vpsllvq %xmm0, %xmm1, %xmm0 -; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1 -; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3] -; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1 -; CHECK-AVX2-NEXT: vpsrlq $32, %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: vaddpd %xmm0, %xmm1, %xmm0 -; CHECK-AVX2-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-AVX2-NEXT: retq -; -; CHECK-FMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: -; CHECK-FMA: # %bb.0: -; CHECK-FMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2] -; CHECK-FMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0 -; CHECK-FMA-NEXT: vcvtuqq2pd %xmm0, %xmm0 -; CHECK-FMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 -; CHECK-FMA-NEXT: retq +; CHECK-AVX-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo: +; CHECK-AVX: # %bb.0: +; CHECK-AVX-NEXT: vpsllq $52, %xmm0, %xmm0 +; CHECK-AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 +; CHECK-AVX-NEXT: retq %shl = shl nsw nuw <2 x i64> <i64 2, i64 2>, %cnt %conv = uitofp <2 x i64> %shl to <2 x double> %mul = fmul <2 x double> <double 15.000000e+00, double 14.000000e+00>, %conv