Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -267,6 +267,7 @@
     SDValue PromoteIntShiftOp(SDValue Op);
     SDValue PromoteExtend(SDValue Op);
     bool PromoteLoad(SDValue Op);
+    SDValue combineSelectFP(SDNode *N);
 
     /// Call the node-specific routine that knows how to fold each
     /// particular type of node. If that doesn't do anything, try the
@@ -6637,6 +6638,97 @@
   return SDValue();
 }
 
+/// Fold a select that re-clamps a value already clamped by a constant
+/// min/max into a min/max pair:
+///
+///   select (setcc X, C1, lt/le), C1, (fminnum X, C2)  where C1 < C2
+///     --> fmaxnum (fminnum X, C2), C1
+///   select (setcc X, C1, gt/ge), C1, (fmaxnum X, C2)  where C1 > C2
+///     --> fminnum (fmaxnum X, C2), C1
+///
+/// The FMINNAN/FMAXNAN forms of the inner node are accepted as well. This
+/// routine performs no NaN or signed-zero checks of its own; the rewrite is
+/// only equivalent when neither can occur, so the caller must establish
+/// that. Returns a null SDValue if the pattern does not match or the outer
+/// node is not legal for the result type.
+SDValue DAGCombiner::combineSelectFP(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  EVT VT = N->getValueType(0);
+
+  // The condition must compare some value X against a constant C1, and the
+  // value selected when it holds must be a constant as well.
+  if (N0.getOpcode() != ISD::SETCC)
+    return SDValue();
+  auto *CmpC = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
+  auto *TrueC = dyn_cast<ConstantFPSDNode>(N1);
+  if (!CmpC || !TrueC)
+    return SDValue();
+
+  // The direction of the comparison determines which inner min/max must
+  // appear on the false arm and which outer min/max completes the clamp.
+  bool IsLess;
+  switch (cast<CondCodeSDNode>(N0.getOperand(2))->get()) {
+  case ISD::SETOLT:
+  case ISD::SETOLE:
+  case ISD::SETLT:
+  case ISD::SETLE:
+  case ISD::SETULT:
+  case ISD::SETULE:
+    IsLess = true;
+    break;
+  case ISD::SETOGT:
+  case ISD::SETOGE:
+  case ISD::SETGT:
+  case ISD::SETGE:
+  case ISD::SETUGT:
+  case ISD::SETUGE:
+    IsLess = false;
+    break;
+  default:
+    return SDValue();
+  }
+
+  // The false arm must be the matching min/max of X and some constant C2.
+  unsigned InnerOpc = N2.getOpcode();
+  bool InnerMatches =
+      IsLess ? (InnerOpc == ISD::FMINNUM || InnerOpc == ISD::FMINNAN)
+             : (InnerOpc == ISD::FMAXNUM || InnerOpc == ISD::FMAXNAN);
+  if (!InnerMatches || N2.getOperand(0) != N0.getOperand(0))
+    return SDValue();
+  auto *InnerC = dyn_cast<ConstantFPSDNode>(N2.getOperand(1));
+  if (!InnerC)
+    return SDValue();
+
+  const APFloat &CmpVal = CmpC->getValueAPF();
+  const APFloat &TrueVal = TrueC->getValueAPF();
+  const APFloat &InnerVal = InnerC->getValueAPF();
+
+  // APFloat::compare requires both operands to have the same semantics.
+  if (&CmpVal.getSemantics() != &TrueVal.getSemantics() ||
+      &TrueVal.getSemantics() != &InnerVal.getSemantics())
+    return SDValue();
+
+  // The select must yield the same constant it compared against (C1), and C1
+  // must lie strictly below C2 for a less-than compare (strictly above for
+  // greater-than) so that the min/max pair clamps X between the two.
+  if (CmpVal.compare(TrueVal) != APFloat::cmpEqual)
+    return SDValue();
+  APFloat::cmpResult Required =
+      IsLess ? APFloat::cmpLessThan : APFloat::cmpGreaterThan;
+  if (TrueVal.compare(InnerVal) != Required)
+    return SDValue();
+
+  // Only form the outer node when the target supports it for VT; otherwise
+  // the select would just be traded for a node that must be expanded again.
+  unsigned OuterOpc = IsLess ? ISD::FMAXNUM : ISD::FMINNUM;
+  if (!TLI.isOperationLegal(OuterOpc, VT))
+    return SDValue();
+
+  return DAG.getNode(OuterOpc, SDLoc(N), VT, N2, N1);
+}
+
 SDValue DAGCombiner::visitSELECT(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -6783,6 +6875,16 @@
       if (SDValue FMinMax = combineMinNumMaxNum(
               DL, VT, N0.getOperand(0), N0.getOperand(1), N1, N2, CC, TLI, DAG))
         return FMinMax;
+
+      // Clamp pattern, e.g. for a greater-than compare with f1 > f2:
+      //   t5: i1 = setcc t2, ConstantFP:f1, setgt:ch
+      //   t9: f32 = fmaxnum t2, ConstantFP:f2
+      //   t10: f32 = select t5, ConstantFP:f1, t9
+      // ==>
+      //   t10: f32 = fminnum t9, ConstantFP:f1
+      if (isa<ConstantFPSDNode>(N1) && N0.getOpcode() == ISD::SETCC)
+        if (SDValue CombinedFP = combineSelectFP(N))
+          return CombinedFP;
     }
 
     if ((!LegalOperations &&
Index: llvm/test/CodeGen/AArch64/aarch64-DAGCombine-fminmax.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/aarch64-DAGCombine-fminmax.ll
@@ -0,0 +1,88 @@
+; RUN: llc --enable-unsafe-fp-math -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck -check-prefix=CHECK1 %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+declare float @llvm.minnum.f32(float, float)
+declare float @llvm.maxnum.f32(float, float)
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK: fmax{{(nm)?}} {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1: fmax{{(nm)?}} {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampNUM(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %cmp3 = fcmp olt float %a, 1.280000e+02
+  %.a = select i1 %cmp3, float 1.280000e+02, float %a
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampNAN(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp olt float %a, 1.280000e+02
+  %cmp2 = fcmp fast ogt float %a, 2.550000e+02
+  %.a = select i1 %cmp2, float 2.550000e+02, float %a
+  %retval.0 = select i1 %cmp, float 1.280000e+02, float %.a
+  ret float %retval.0
+}
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampIntrinsicFmax(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %.a = call float @llvm.maxnum.f32(float %a, float 1.280000e+02) readnone
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1: fminnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+; CHECK1-NOT: fmaxnm {{s[0-9]+}}, {{s[0-9]+}}, {{s[0-9]+}}
+define float @clampIntrinsicFmin(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp olt float %a, 1.280000e+02
+  %.a = call float @llvm.minnum.f32(float %a, float 2.550000e+02) readnone
+  %retval.0 = select i1 %cmp, float 1.280000e+02, float %.a
+  ret float %retval.0
+}
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK-NOT: fmin
+; CHECK1-NOT: fmin
+define double @clampNoConvert(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %.inv = fcmp ole float %a, 3.000000e+00
+  %0 = select i1 %.inv, float 3.000000e+00, float %a
+  %1 = fpext float %0 to double
+  %retval.0 = select i1 %cmp, double 2.550000e+02, double %1
+  ret double %retval.0
+}
+
+; Function Attrs: norecurse nounwind readnone
+; CHECK-NOT: fmin
+; CHECK1-NOT: fmin
+define float @clampNo2Convert(float %a) local_unnamed_addr #0 {
+entry:
+  %cmp = fcmp ogt float %a, 2.550000e+02
+  %cmp2 = fcmp fast olt float %a, 0.000000e+00
+  %.a = select i1 %cmp2, float 3.000000e+01, float %a
+  %retval.0 = select i1 %cmp, float 2.550000e+02, float %.a
+  ret float %retval.0
+}
+
+attributes #0 = { norecurse nounwind readnone "no-nans-fp-math"="true" }