Index: include/llvm/Analysis/ValueTracking.h
===================================================================
--- include/llvm/Analysis/ValueTracking.h
+++ include/llvm/Analysis/ValueTracking.h
@@ -292,12 +292,14 @@
   /// \brief Specific patterns of select instructions we can match.
   enum SelectPatternFlavor {
     SPF_UNKNOWN = 0,
-    SPF_SMIN, // Signed minimum
-    SPF_UMIN, // Unsigned minimum
-    SPF_SMAX, // Signed maximum
-    SPF_UMAX, // Unsigned maximum
-    SPF_ABS, // Absolute value
-    SPF_NABS // Negated absolute value
+    SPF_SMIN,    ///< Signed minimum
+    SPF_UMIN,    ///< Unsigned minimum
+    SPF_SMAX,    ///< Signed maximum
+    SPF_UMAX,    ///< Unsigned maximum
+    SPF_FMINNUM, ///< Floating point minnum
+    SPF_FMAXNUM, ///< Floating point maxnum
+    SPF_ABS,     ///< Absolute value
+    SPF_NABS     ///< Negated absolute value
   };
   /// Pattern match integer [SU]MIN, [SU]MAX and ABS idioms, returning the kind
   /// and providing the out parameter results if we successfully match.
Index: lib/Analysis/ValueTracking.cpp
===================================================================
--- lib/Analysis/ValueTracking.cpp
+++ lib/Analysis/ValueTracking.cpp
@@ -3316,13 +3316,49 @@
   return OverflowResult::MayOverflow;
 }
 
-static SelectPatternFlavor matchSelectPattern(ICmpInst::Predicate Pred,
+/// Return true if \p V is known not to be a NaN, either because the fcmp
+/// carries the 'nnan' fast-math flag or because \p V is a floating point
+/// constant that is not a NaN.
+static bool isKnownNonNaN(Value *V, FastMathFlags FMF) {
+  if (FMF.noNaNs())
+    return true;
+
+  if (auto *C = dyn_cast<ConstantFP>(V))
+    return !C->isNaN();
+  return false;
+}
+
+static SelectPatternFlavor matchSelectPattern(CmpInst::Predicate Pred,
+                                              FastMathFlags FMF,
                                               Value *CmpLHS, Value *CmpRHS,
                                               Value *TrueVal, Value *FalseVal,
                                               Value *&LHS, Value *&RHS) {
   LHS = CmpLHS;
   RHS = CmpRHS;
 
+  // Canonicalize "(cmp X, Y) ? Y : X" to "(cmp' Y, X) ? Y : X", so that only
+  // the "(cmp X, Y) ? X : Y" form needs to be matched below. The out
+  // parameters are swapped along with the operands, so LHS is always the
+  // value the select returns when the comparison is true.
+  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
+    std::swap(CmpLHS, CmpRHS);
+    std::swap(LHS, RHS);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+  }
+
+  // Floating point fcmp+select may not return the same value as a minnum/maxnum
+  // operation in the presence of NaNs. If exactly one operand is a NaN,
+  // minnum/maxnum return the other one, but "(fcmp X, Y) ? X : Y" returns Y
+  // for an ordered predicate (false on NaN) and X for an unordered one (true
+  // on NaN). So the value returned on NaN must be known not to be a NaN. This
+  // must be checked after the canonicalization above, which decides which
+  // value ends up on which side.
+  if (CmpInst::isFPPredicate(Pred)) {
+    Value *ReturnedOnNaN = CmpInst::isOrdered(Pred) ? CmpRHS : CmpLHS;
+    if (!isKnownNonNaN(ReturnedOnNaN, FMF))
+      return SPF_UNKNOWN;
+  }
+
   // (icmp X, Y) ? X : Y
   if (TrueVal == CmpLHS && FalseVal == CmpRHS) {
     switch (Pred) {
@@ -3335,21 +3371,14 @@
     case ICmpInst::ICMP_ULE: return SPF_UMIN;
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE: return SPF_SMIN;
-    }
-  }
-
-  // (icmp X, Y) ? Y : X
-  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
-    switch (Pred) {
-    default: return SPF_UNKNOWN; // Equality.
-    case ICmpInst::ICMP_UGT:
-    case ICmpInst::ICMP_UGE: return SPF_UMIN;
-    case ICmpInst::ICMP_SGT:
-    case ICmpInst::ICMP_SGE: return SPF_SMIN;
-    case ICmpInst::ICMP_ULT:
-    case ICmpInst::ICMP_ULE: return SPF_UMAX;
-    case ICmpInst::ICMP_SLT:
-    case ICmpInst::ICMP_SLE: return SPF_SMAX;
+    case FCmpInst::FCMP_UGT:
+    case FCmpInst::FCMP_UGE:
+    case FCmpInst::FCMP_OGT:
+    case FCmpInst::FCMP_OGE: return SPF_FMAXNUM;
+    case FCmpInst::FCMP_ULT:
+    case FCmpInst::FCMP_ULE:
+    case FCmpInst::FCMP_OLT:
+    case FCmpInst::FCMP_OLE: return SPF_FMINNUM;
     }
   }
 
@@ -3387,7 +3416,7 @@
   return SPF_UNKNOWN;
 }
 
-static Constant *lookThroughCast(ICmpInst *CmpI, Value *V1, Value *V2,
+static Constant *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
                                  Instruction::CastOps *CastOp) {
   CastInst *CI = dyn_cast<CastInst>(V1);
   Constant *C = dyn_cast<Constant>(V2);
@@ -3409,6 +3438,32 @@
   if (isa<TruncInst>(CI))
     return ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned());
 
+  // Floating point casts: compute the constant that CI would map to C. This
+  // is only valid if the conversion is lossless, i.e. casting the result back
+  // with CI's own opcode yields exactly C again. Otherwise a rounded constant
+  // (e.g. 5.5 -> 5 when looking through a sitofp) could let the compare match
+  // and silently change the value the select produces.
+  Constant *CastedTo = nullptr;
+  if (isa<FPToUIInst>(CI))
+    CastedTo = ConstantExpr::getUIToFP(C, CI->getSrcTy(), true);
+  else if (isa<FPToSIInst>(CI))
+    CastedTo = ConstantExpr::getSIToFP(C, CI->getSrcTy(), true);
+  else if (isa<UIToFPInst>(CI))
+    CastedTo = ConstantExpr::getFPToUI(C, CI->getSrcTy(), true);
+  else if (isa<SIToFPInst>(CI))
+    CastedTo = ConstantExpr::getFPToSI(C, CI->getSrcTy(), true);
+  else if (isa<FPTruncInst>(CI))
+    CastedTo = ConstantExpr::getFPExtend(C, CI->getSrcTy(), true);
+  else if (isa<FPExtInst>(CI))
+    CastedTo = ConstantExpr::getFPTrunc(C, CI->getSrcTy(), true);
+
+  if (CastedTo) {
+    Constant *CastedBack =
+        ConstantExpr::getCast(CI->getOpcode(), CastedTo, C->getType(), true);
+    if (CastedBack == C)
+      return CastedTo;
+  }
+
   return nullptr;
 }
 
@@ -3418,14 +3473,17 @@
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI) return SPF_UNKNOWN;
 
-  ICmpInst *CmpI = dyn_cast<ICmpInst>(SI->getCondition());
+  CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
   if (!CmpI) return SPF_UNKNOWN;
 
-  ICmpInst::Predicate Pred = CmpI->getPredicate();
+  CmpInst::Predicate Pred = CmpI->getPredicate();
   Value *CmpLHS = CmpI->getOperand(0);
   Value *CmpRHS = CmpI->getOperand(1);
   Value *TrueVal = SI->getTrueValue();
   Value *FalseVal = SI->getFalseValue();
+  FastMathFlags FMF;
+  if (isa<FPMathOperator>(CmpI))
+    FMF = CmpI->getFastMathFlags();
 
   // Bail out early.
   if (CmpI->isEquality())
@@ -3434,14 +3492,14 @@
   // Deal with type mismatches.
   if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
     if (Constant *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp))
-      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                   cast<CastInst>(TrueVal)->getOperand(0), C,
                                   LHS, RHS);
     if (Constant *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp))
-      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                   C, cast<CastInst>(FalseVal)->getOperand(0),
                                   LHS, RHS);
   }
-  return ::matchSelectPattern(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal,
+  return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
                               LHS, RHS);
 }
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2292,6 +2292,8 @@
     case SPF_UMIN: Opc = ISD::UMIN; break;
     case SPF_SMAX: Opc = ISD::SMAX; break;
     case SPF_SMIN: Opc = ISD::SMIN; break;
+    case SPF_FMINNUM: Opc = ISD::FMINNUM; break;
+    case SPF_FMAXNUM: Opc = ISD::FMAXNUM; break;
     default: break;
     }
 
Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1307,10 +1307,15 @@
 
   // (fptrunc (select cond, R1, Cst)) -->
   // (select cond, (fptrunc R1), (fptrunc Cst))
+  //
+  // - but only if this isn't part of a min/max operation, else we'll
+  // ruin min/max canonical form.
+  Value *LHS, *RHS;
   SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0));
   if (SI &&
       (isa<ConstantFP>(SI->getOperand(1)) ||
-       isa<ConstantFP>(SI->getOperand(2)))) {
+       isa<ConstantFP>(SI->getOperand(2))) &&
+      matchSelectPattern(SI, LHS, RHS) == SPF_UNKNOWN) {
     Value *LHSTrunc = Builder->CreateFPTrunc(SI->getOperand(1),
                                              CI.getType());
     Value *RHSTrunc = Builder->CreateFPTrunc(SI->getOperand(2),
Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -38,7 +38,11 @@
   }
 }
 
-static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) {
+/// Return the predicate used to express the min/max flavor \p SPF as a
+/// "(cmp A, B) ? A : B" select. For the floating point flavors, \p Ordered
+/// chooses between the ordered and the unordered predicate.
+static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF,
+                                                   bool Ordered = false) {
   switch (SPF) {
   default:
     llvm_unreachable("unhandled!");
@@ -51,13 +55,18 @@
     return ICmpInst::ICMP_SGT;
   case SPF_UMAX:
     return ICmpInst::ICMP_UGT;
+  case SPF_FMINNUM:
+    return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
+  case SPF_FMAXNUM:
+    return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
   }
 }
 
 static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder,
                                           SelectPatternFlavor SPF,
                                           Value *A, Value *B) {
-  CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+  CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF);
+  assert(CmpInst::isIntPredicate(Pred) && "Expected an integer min/max");
   return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B);
 }
 
@@ -1054,7 +1063,7 @@
   }
 
   // See if we can fold the select into one of our operands.
-  if (SI.getType()->isIntOrIntVectorTy()) {
+  if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) {
     if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
       return FoldI;
 
@@ -1063,11 +1072,25 @@
     SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS, &CastOp);
 
     if (SPF) {
+      CmpInst *C = cast<CmpInst>(SI.getCondition());
       // Canonicalize so that type casts are outside select patterns.
       if (LHS->getType()->getPrimitiveSizeInBits() !=
           SI.getType()->getPrimitiveSizeInBits()) {
-        CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
-        Value *Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+        // For min/max patterns LHS is the value the select returns when its
+        // compare is true, so reusing the original compare's orderedness
+        // keeps the result for NaN inputs unchanged.
+        bool Ordered = CmpInst::isOrdered(C->getPredicate());
+        CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, Ordered);
+
+        Value *Cmp;
+        if (CmpInst::isIntPredicate(Pred)) {
+          Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+        } else {
+          IRBuilder<>::FastMathFlagGuard FMFG(*Builder);
+          Builder->SetFastMathFlags(C->getFastMathFlags());
+          Cmp = Builder->CreateFCmp(Pred, LHS, RHS);
+        }
+
         Value *NewSI = Builder->CreateCast(CastOp,
                                            Builder->CreateSelect(Cmp, LHS, RHS),
                                            SI.getType());
Index: test/Transforms/InstCombine/minmax-fp.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/minmax-fp.ll
@@ -0,0 +1,159 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+; CHECK-LABEL: @t1
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fpext
+define double @t1(float %a) {
+  ; This is the canonical form for a type-changing min/max.
+  %1 = fcmp olt float %a, 5.0
+  %2 = select i1 %1, float %a, float 5.0
+  %3 = fpext float %2 to double
+  ret double %3
+}
+
+; CHECK-LABEL: @t2
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fpext
+define double @t2(float %a) {
+  ; Check this is converted into canonical form, as above.
+  %1 = fcmp olt float %a, 5.0
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double %2, double 5.0
+  ret double %3
+}
+
+; CHECK-LABEL: @t3
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: fpext
+; CHECK-NEXT: select
+define double @t3(float %a) {
+  ; If %a is NaN the ordered compare fails and %a is returned, whereas
+  ; maxnum would return 5.0: should not be converted.
+  %1 = fcmp olt float %a, 5.0
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double 5.0, double %2
+  ret double %3
+}
+
+; CHECK-LABEL: @t4
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fptrunc
+define float @t4(double %a) {
+  ; Same again, with trunc.
+  %1 = fcmp olt double %a, 5.0
+  %2 = fptrunc double %a to float
+  %3 = select i1 %1, float %2, float 5.0
+  ret float %3
+}
+
+; CHECK-LABEL: @t5
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: fpext
+; CHECK-NEXT: select
+define double @t5(float %a) {
+  ; different values, should not be converted.
+  %1 = fcmp olt float %a, 5.0
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double %2, double 5.001
+  ret double %3
+}
+
+; CHECK-LABEL: @t6
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: fpext
+; CHECK-NEXT: select
+define double @t6(float %a) {
+  ; Signed zero, should not be converted
+  %1 = fcmp olt float %a, -0.0
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double %2, double 0.0
+  ret double %3
+}
+
+; CHECK-LABEL: @t7
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: fpext
+; CHECK-NEXT: select
+define double @t7(float %a) {
+  ; Signed zero, should not be converted
+  %1 = fcmp olt float %a, 0.0
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double %2, double -0.0
+  ret double %3
+}
+
+; CHECK-LABEL: @t8
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fptoui
+define i64 @t8(float %a) {
+  %1 = fcmp olt float %a, 5.0
+  %2 = fptoui float %a to i64
+  %3 = select i1 %1, i64 %2, i64 5
+  ret i64 %3
+}
+
+; CHECK-LABEL: @t9
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fptosi
+define i8 @t9(float %a) {
+  %1 = fcmp olt float %a, 0.0
+  %2 = fptosi float %a to i8
+  %3 = select i1 %1, i8 %2, i8 0
+  ret i8 %3
+}
+
+; CHECK-LABEL: @t10
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt float %a, 5.000000e+00
+; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], float %a, float 5.000000e+00
+; CHECK-NEXT: fpext float [[SEL]] to double
+define double @t10(float %a) {
+  ; Constant on the LHS of the compare: on NaN the ordered compare still
+  ; returns the non-NaN constant, so this is maxnum and is converted.
+  %1 = fcmp olt float 5.0, %a
+  %2 = fpext float %a to double
+  %3 = select i1 %1, double %2, double 5.0
+  ret double %3
+}
+
+; CHECK-LABEL: @t11
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: select
+; CHECK-NEXT: fptosi
+define i8 @t11(float %a, float %b) {
+  ; Either operand could be NaN, but fast modifier applied.
+  %1 = fcmp fast olt float %b, %a
+  %2 = fptosi float %a to i8
+  %3 = fptosi float %b to i8
+  %4 = select i1 %1, i8 %3, i8 %2
+  ret i8 %4
+}
+
+; CHECK-LABEL: @t12
+; CHECK-NEXT: icmp
+; CHECK-NEXT: sitofp
+; CHECK-NEXT: select
+define double @t12(i32 %a) {
+  ; No integer converts to 5.5 through sitofp, so looking through the cast
+  ; would change the result: should not be converted.
+  %1 = icmp slt i32 %a, 5
+  %2 = sitofp i32 %a to double
+  %3 = select i1 %1, double %2, double 5.5
+  ret double %3
+}
+
+; CHECK-LABEL: @t13
+; CHECK-NEXT: fcmp
+; CHECK-NEXT: fptosi
+; CHECK-NEXT: select
+define i8 @t13(float %a) {
+  ; Float and int values do not match.
+  %1 = fcmp olt float %a, 1.5
+  %2 = fptosi float %a to i8
+  %3 = select i1 %1, i8 %2, i8 1
+  ret i8 %3
+}