Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1387,30 +1387,9 @@
     }
   }
 
-  // select C, 0, B + select C, A, 0 -> select C, A, B
-  {
-    Value *A1, *B1, *C1, *A2, *B2, *C2;
-    if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) &&
-        match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) {
-      if (C1 == C2) {
-        Constant *Z1=nullptr, *Z2=nullptr;
-        Value *A, *B, *C=C1;
-        if (match(A1, m_AnyZero()) && match(B2, m_AnyZero())) {
-            Z1 = dyn_cast<Constant>(A1); A = A2;
-            Z2 = dyn_cast<Constant>(B2); B = B1;
-        } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) {
-            Z1 = dyn_cast<Constant>(B1); B = B2;
-            Z2 = dyn_cast<Constant>(A2); A = A1;
-        }
-
-        if (Z1 && Z2 &&
-            (I.hasNoSignedZeros() ||
-             (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) {
-          return SelectInst::Create(C, A, B);
-        }
-      }
-    }
-  }
+  // Handle special cases for FAdd with selects feeding the operation
+  if (Value *V = SimplifySelectHelper(I, LHS, RHS, true))
+    return replaceInstUsesWith(I, V);
 
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
@@ -1760,6 +1739,10 @@
     }
   }
 
+  // Handle special cases for FSub with selects feeding the operation
+  if (Value *V = SimplifySelectHelper(I, Op0, Op1, true))
+    return replaceInstUsesWith(I, V);
+
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
       return replaceInstUsesWith(I, V);
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -597,6 +597,15 @@
   /// value, or null if it didn't simplify.
   Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
 
+  /// Fold a binary operator whose operands are two selects on the same
+  /// condition into a select of the per-arm results, provided at least one
+  /// pair of arms simplifies. If \p CarryFastMathFlags is set and \p I is a
+  /// floating-point operation, I's fast-math flags are used to simplify the
+  /// arms and are copied onto any newly created binary operator.
+  /// Returns the new value, or null if the fold did not apply.
+  Value *SimplifySelectHelper(BinaryOperator &I, Value *LHS, Value *RHS,
+                              bool CarryFastMathFlags);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -736,6 +736,10 @@
       }
     }
 
+  // Handle special cases for FMul with selects feeding the operation
+  if (Value *V = SimplifySelectHelper(I, Op0, Op1, true))
+    return replaceInstUsesWith(I, V);
+
   // (X*Y) * X => (X*X) * Y where Y != X
   // The purpose is two-fold:
   // 1) to form a power expression (of X).
Index: lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- lib/Transforms/InstCombine/InstructionCombining.cpp
+++ lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -708,31 +708,58 @@
     }
   }
 
-  // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0))
-  // (op (select (a, b, c)), (select (a, b, d))) -> (select (a, 0, (op c, d)))
-  if (auto *SI0 = dyn_cast<SelectInst>(LHS)) {
-    if (auto *SI1 = dyn_cast<SelectInst>(RHS)) {
-      if (SI0->getCondition() == SI1->getCondition()) {
-        Value *SI = nullptr;
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                              SI1->getFalseValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(SI0->getCondition(),
-                                    Builder.CreateBinOp(TopLevelOpcode,
-                                                        SI0->getTrueValue(),
-                                                        SI1->getTrueValue()),
-                                    V);
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(),
-                              SI1->getTrueValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(
-              SI0->getCondition(), V,
-              Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                                  SI1->getFalseValue()));
-        if (SI) {
-          SI->takeName(&I);
-          return SI;
-        }
+  if (Value *V = SimplifySelectHelper(I, LHS, RHS, false))
+    return V;
+
+  return nullptr;
+}
+
+Value *InstCombiner::SimplifySelectHelper(BinaryOperator &I, Value *LHS,
+                                          Value *RHS, bool CarryFastMathFlags) {
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  // (op (select a, b, c), (select a, d, e)) -> (select a, (op b, d), (op c, e))
+  // Only done when at least one of (op b, d) and (op c, e) simplifies to an
+  // existing value, so the transform creates at most one new binary operator.
+  Value *A1, *B1, *C1, *A2, *B2, *C2;
+  if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) &&
+      match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) {
+    if (C1 == C2) {
+      // Only FP operators carry fast-math flags (getFastMathFlags() asserts on
+      // anything else), so never consult them for integer opcodes.
+      bool UseFMF = CarryFastMathFlags && isa<FPMathOperator>(&I);
+
+      // Builder is shared by the whole combiner: restore its fast-math flags
+      // on exit so I's flags do not leak into unrelated instructions.
+      BuilderTy::FastMathFlagGuard FMFGuard(Builder);
+      if (UseFMF)
+        Builder.setFastMathFlags(I.getFastMathFlags());
+
+      // Fold one arm pair. For FP ops I's flags (e.g. nsz) may be used, since
+      // the arm result is exactly what I would have produced for that arm.
+      auto SimplifyArms = [&](Value *X, Value *Y) -> Value * {
+        if (UseFMF)
+          return SimplifyFPBinOp(Opcode, X, Y, I.getFastMathFlags(),
+                                 SQ.getWithInstruction(&I));
+        return SimplifyBinOp(Opcode, X, Y, SQ.getWithInstruction(&I));
+      };
+      Value *TrueV = SimplifyArms(A1, A2);
+      Value *FalseV = SimplifyArms(B1, B2);
+
+      // Reuse both folded arms when possible instead of rebuilding one of
+      // them as a fresh binop.
+      Value *SI = nullptr;
+      if (TrueV && FalseV)
+        SI = Builder.CreateSelect(C1, TrueV, FalseV);
+      else if (TrueV)
+        SI = Builder.CreateSelect(C1, TrueV,
+                                  Builder.CreateBinOp(Opcode, B1, B2));
+      else if (FalseV)
+        SI = Builder.CreateSelect(C1, Builder.CreateBinOp(Opcode, A1, A2),
+                                  FalseV);
+
+      if (SI) {
+        SI->takeName(&I);
+        return SI;
       }
     }
   }
Index: test/Transforms/InstCombine/select_arithmetic.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/select_arithmetic.ll
@@ -0,0 +1,71 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; Binary FP operators whose operands are two selects on the same condition are
+; folded into a select of the per-arm results when at least one arm simplifies.
+
+define float @test1(i1 zeroext %cond) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float 6.000000e+00, float 1.500000e+01
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float 5.000000e+00, float 6.000000e+00
+  %b = select i1 %cond, float 1.000000e+00, float 9.000000e+00
+  %r = fadd float %a, %b
+  ret float %r
+}
+
+define float @test2(i1 zeroext %cond) {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float 4.000000e+00, float -3.000000e+00
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float 5.000000e+00, float 6.000000e+00
+  %b = select i1 %cond, float 1.000000e+00, float 9.000000e+00
+  %r = fsub float %a, %b
+  ret float %r
+}
+
+define float @test3(i1 zeroext %cond) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float 5.000000e+00, float 5.400000e+01
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float 5.000000e+00, float 6.000000e+00
+  %b = select i1 %cond, float 1.000000e+00, float 9.000000e+00
+  %r = fmul float %a, %b
+  ret float %r
+}
+
+; fadd (select C, X, -0.0), (select C, -0.0, Y) -> select C, X, Y
+define float @test4(i1 %cond, float %x, float %y) {
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float %x, float %y
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float %x, float -0.000000e+00
+  %b = select i1 %cond, float -0.000000e+00, float %y
+  %r = fadd float %a, %b
+  ret float %r
+}
+
+; With nsz, +0.0 arms fold as well: I's fast-math flags are used when
+; simplifying the arms.
+define float @test5(i1 %cond, float %x, float %y) {
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float %x, float %y
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float %x, float 0.000000e+00
+  %b = select i1 %cond, float 0.000000e+00, float %y
+  %r = fadd nsz float %a, %b
+  ret float %r
+}
+
+; The newly created fadd inherits the fast-math flags of the original fadd.
+define float @test6(i1 %cond, float %x, float %y) {
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    [[ADD:%.*]] = fadd fast float %x, %y
+; CHECK-NEXT:    [[R:%.*]] = select i1 %cond, float [[ADD]], float 3.000000e+00
+; CHECK-NEXT:    ret float [[R]]
+  %a = select i1 %cond, float %x, float 1.000000e+00
+  %b = select i1 %cond, float %y, float 2.000000e+00
+  %r = fadd fast float %a, %b
+  ret float %r
+}