diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2200,6 +2200,12 @@
   return m_Intrinsic(Op0, Op1);
 }
 
+/// Matches a call to the llvm.sqrt intrinsic whose operand matches \p Op0.
+template <typename Opnd0>
+inline typename m_Intrinsic_Ty<Opnd0>::Ty m_Sqrt(const Opnd0 &Op0) {
+  return m_Intrinsic<Intrinsic::sqrt>(Op0);
+}
+
 template <typename Opnd0, typename Opnd1, typename Opnd2>
 inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2>::Ty
 m_FShl(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6754,6 +6754,25 @@
     }
   }
 
+  if (match(Op0, m_Sqrt(m_Value(X)))) {
+    // fcmp Pred sqrt(X), C --> fcmp Pred X, C*C   (C >= 0)
+    // sqrt(X) is never less than zero, so negative C is not handled. The
+    // fold needs full fast-math: 'nnan' lets us ignore NaN results for X < 0,
+    // and the rounding of C*C is accepted as an approximation. Bail out if
+    // C*C is not finite: under 'ninf' an infinite constant would be poison.
+    const APFloat *CF;
+    if (match(Op1, m_APFloat(CF)) && !CF->isNegative() && I.isFast()) {
+      APFloat Square = *CF;
+      Square.multiply(*CF, APFloat::rmNearestTiesToEven);
+      if (Square.isFinite()) {
+        Instruction *FCmp =
+            new FCmpInst(Pred, X, ConstantFP::get(X->getType(), Square));
+        FCmp->setFastMathFlags(I.getFastMathFlags());
+        return FCmp;
+      }
+    }
+  }
+
   if (match(Op0, m_FPExt(m_Value(X)))) {
     // fcmp (fpext X), (fpext Y) -> fcmp X, Y
     if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType())
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -3,6 +3,7 @@
 
 declare half @llvm.fabs.f16(half)
 declare double @llvm.fabs.f64(double)
+declare double @llvm.sqrt.f64(double)
 declare <2 x float> @llvm.fabs.v2f32(<2 x float>)
 declare double @llvm.copysign.f64(double, double)
 declare <2 x double> @llvm.copysign.v2f64(<2 x double>, <2 x double>)
@@ -1210,3 +1211,51 @@
   %cmp = fcmp ninf une float %a, %fneg
   ret i1 %cmp
 }
+
+; fcmp sqrt(X), C --> fcmp X, C*C
+define i1 @fcmp_fsqrt_test1(double %v) {
+; CHECK-LABEL: @fcmp_fsqrt_test1(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp fast ogt double [[V:%.*]], 4.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call double @llvm.sqrt.f64(double %v)
+  %cmp = fcmp fast ogt double %sqrt, 2.000000e+00
+  ret i1 %cmp
+}
+
+; Negative constant: no fold, because sqrt(X) is never less than zero.
+define i1 @fcmp_fsqrt_test2(double %v) {
+; CHECK-LABEL: @fcmp_fsqrt_test2(
+; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[V:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp fast ogt double [[SQRT]], -2.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call double @llvm.sqrt.f64(double %v)
+  %cmp = fcmp fast ogt double %sqrt, -2.000000e+00
+  ret i1 %cmp
+}
+
+; No fast-math flags (NaNs and exact rounding must be honored): no fold.
+define i1 @fcmp_fsqrt_test3(double %v) {
+; CHECK-LABEL: @fcmp_fsqrt_test3(
+; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[V:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt double [[SQRT]], 2.000000e+00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call double @llvm.sqrt.f64(double %v)
+  %cmp = fcmp ogt double %sqrt, 2.000000e+00
+  ret i1 %cmp
+}
+
+; C*C overflows to infinity: no fold, because with 'ninf' an infinite
+; constant operand would make the new compare poison.
+define i1 @fcmp_fsqrt_test4(double %v) {
+; CHECK-LABEL: @fcmp_fsqrt_test4(
+; CHECK-NEXT:    [[SQRT:%.*]] = call double @llvm.sqrt.f64(double [[V:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp fast ogt double [[SQRT]], 1.000000e+200
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call double @llvm.sqrt.f64(double %v)
+  %cmp = fcmp fast ogt double %sqrt, 1.000000e+200
+  ret i1 %cmp
+}