InstCombine: fold a binary operator into a select operand when both arms simplify

SimplifySelectsFeedingBinaryOp used to handle only two selects that share a
condition:

    (A ? B : C) op (A ? E : F)  ->  A ? (B op E) : (C op F)

This patch keeps that case (now passing the instruction's fast-math flags to
SimplifyFPBinOp as well as to the builder) and adds two cases in which only one
operand is a select:

    (A ? B : C) op Y  ->  A ? (B op Y) : (C op Y)
    X op (D ? E : F)  ->  D ? (X op E) : (X op F)

The one-select cases fire only when BOTH arms simplify (V1 && V2), so the only
instruction they create is the replacement select. The takeName() call moves
out of the first branch so that it covers all three cases.

NOTE(review): the fast-math-flag/guard setup is repeated verbatim in all three
branches and could be hoisted above the if-chain.
NOTE(review): when both operands are selects with different conditions, the
`else if (LHSIsSelect)` branch is taken and the RHS select is never tried,
even if the LHS arms fail to simplify. Confirm that this is intended.

Tests: three fmul cases (scalar, vector, sqrt) in fmul.ll and one integer
mul/udiv-exact case in mul.ll.

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -756,30 +756,59 @@
 Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
                                                     Value *LHS, Value *RHS) {
   Instruction::BinaryOps Opcode = I.getOpcode();
-  // (op (select (a, b, c)), (select (a, d, e))) -> (select (a, (op b, d), (op
-  // c, e)))
-  Value *A, *B, *C, *D, *E;
+  Value *A, *B, *C, *D, *E, *F;
   Value *SI = nullptr;
-  if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
-      match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) {
+  bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C)));
+  bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F)));
+
+  if (LHSIsSelect && RHSIsSelect && A == D) {
+    // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
     bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse();
 
+    FastMathFlags FMF;
     BuilderTy::FastMathFlagGuard Guard(Builder);
-    if (isa<FPMathOperator>(&I))
-      Builder.setFastMathFlags(I.getFastMathFlags());
+    if (isa<FPMathOperator>(&I)) {
+      FMF = I.getFastMathFlags();
+      Builder.setFastMathFlags(FMF);
+    }
 
-    Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I));
-    Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I));
+    Value *V1 = SimplifyFPBinOp(Opcode, C, F, FMF, SQ.getWithInstruction(&I));
+    Value *V2 = SimplifyFPBinOp(Opcode, B, E, FMF, SQ.getWithInstruction(&I));
     if (V1 && V2)
       SI = Builder.CreateSelect(A, V2, V1);
     else if (V2 && SelectsHaveOneUse)
-      SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E));
+      SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, F));
     else if (V1 && SelectsHaveOneUse)
-      SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1);
+      SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, E), V1);
+  } else if (LHSIsSelect) {
+    // (A ? B : C) op Y -> A ? (B op Y) : (C op Y)
+    FastMathFlags FMF;
+    BuilderTy::FastMathFlagGuard Guard(Builder);
+    if (isa<FPMathOperator>(&I)) {
+      FMF = I.getFastMathFlags();
+      Builder.setFastMathFlags(FMF);
+    }
+
+    Value *V1 = SimplifyFPBinOp(Opcode, C, RHS, FMF, SQ.getWithInstruction(&I));
+    Value *V2 = SimplifyFPBinOp(Opcode, B, RHS, FMF, SQ.getWithInstruction(&I));
+    if (V1 && V2)
+      SI = Builder.CreateSelect(A, V2, V1);
+  } else if (RHSIsSelect) {
+    // X op (D ? E : F) -> D ? (X op E) : (X op F)
+    FastMathFlags FMF;
+    BuilderTy::FastMathFlagGuard Guard(Builder);
+    if (isa<FPMathOperator>(&I)) {
+      FMF = I.getFastMathFlags();
+      Builder.setFastMathFlags(FMF);
+    }
 
-    if (SI)
-      SI->takeName(&I);
+    Value *V1 = SimplifyFPBinOp(Opcode, LHS, F, FMF, SQ.getWithInstruction(&I));
+    Value *V2 = SimplifyFPBinOp(Opcode, LHS, E, FMF, SQ.getWithInstruction(&I));
+    if (V1 && V2)
+      SI = Builder.CreateSelect(D, V2, V1);
   }
+  if (SI)
+    SI->takeName(&I);
 
   return SI;
 }
diff --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll
--- a/llvm/test/Transforms/InstCombine/fmul.ll
+++ b/llvm/test/Transforms/InstCombine/fmul.ll
@@ -994,3 +994,38 @@
   %r = fmul double %x, fsub (double -0.000000e+00, double bitcast (i64 ptrtoint (i8** getelementptr inbounds ({ [2 x i8*] }, { [2 x i8*] }* @g, i64 0, inrange i32 0, i64 2) to i64) to double))
   ret double %r
 }
+
+; X *fast (C ? 1.0 : 0.0) -> C ? X : 0.0
+define float @fmul_select(float %x, i1 %c) {
+; CHECK-LABEL: @fmul_select(
+; CHECK-NEXT:    [[MUL:%.*]] = select fast i1 [[C:%.*]], float [[X:%.*]], float 0.000000e+00
+; CHECK-NEXT:    ret float [[MUL]]
+;
+  %sel = select i1 %c, float 1.0, float 0.0
+  %mul = fmul fast float %x, %sel
+  ret float %mul
+}
+
+; X *fast (C ? 1.0 : 0.0) -> C ? X : 0.0
+define <2 x float> @fmul_select_vec(<2 x float> %x, i1 %c) {
+; CHECK-LABEL: @fmul_select_vec(
+; CHECK-NEXT:    [[MUL:%.*]] = select fast i1 [[C:%.*]], <2 x float> [[X:%.*]], <2 x float> zeroinitializer
+; CHECK-NEXT:    ret <2 x float> [[MUL]]
+;
+  %sel = select i1 %c, <2 x float> <float 1.0, float 1.0>, <2 x float> zeroinitializer
+  %mul = fmul fast <2 x float> %x, %sel
+  ret <2 x float> %mul
+}
+
+; sqrt(X) *fast (C ? sqrt(X) : 1.0) -> C ? X : sqrt(X)
+define double @fmul_sqrt_select(double %x, i1 %c) {
+; CHECK-LABEL: @fmul_sqrt_select(
+; CHECK-NEXT:    [[SQR:%.*]] = call double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    [[MUL:%.*]] = select fast i1 [[C:%.*]], double [[X]], double [[SQR]]
+; CHECK-NEXT:    ret double [[MUL]]
+;
+  %sqr = call double @llvm.sqrt.f64(double %x)
+  %sel = select i1 %c, double %sqr, double 1.0
+  %mul = fmul fast double %sqr, %sel
+  ret double %mul
+}
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -517,3 +517,15 @@
   %B4 = mul i64 %B8, %L1
   ret i64 %B4
 }
+
+; (C ? (X /exact Y) : 1) * Y -> C ? X : Y
+define i32 @mul_div_select(i32 %x, i32 %y, i1 %c) {
+; CHECK-LABEL: @mul_div_select(
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[C:%.*]], i32 [[X:%.*]], i32 [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[MUL]]
+;
+  %div = udiv exact i32 %x, %y
+  %sel = select i1 %c, i32 %div, i32 1
+  %mul = mul i32 %sel, %y
+  ret i32 %mul
+}