Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1387,30 +1387,9 @@
     }
   }
 
-  // select C, 0, B + select C, A, 0 -> select C, A, B
-  {
-    Value *A1, *B1, *C1, *A2, *B2, *C2;
-    if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) &&
-        match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) {
-      if (C1 == C2) {
-        Constant *Z1=nullptr, *Z2=nullptr;
-        Value *A, *B, *C=C1;
-        if (match(A1, m_AnyZero()) && match(B2, m_AnyZero())) {
-          Z1 = dyn_cast<Constant>(A1); A = A2;
-          Z2 = dyn_cast<Constant>(B2); B = B1;
-        } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) {
-          Z1 = dyn_cast<Constant>(B1); B = B2;
-          Z2 = dyn_cast<Constant>(A2); A = A1;
-        }
-
-        if (Z1 && Z2 &&
-            (I.hasNoSignedZeros() ||
-             (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) {
-          return SelectInst::Create(C, A, B);
-        }
-      }
-    }
-  }
+  // Handle special cases for FAdd with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
+    return replaceInstUsesWith(I, V);
 
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
@@ -1760,6 +1739,10 @@
     }
   }
 
+  // Handle special cases for FSub with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
       return replaceInstUsesWith(I, V);
Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -600,6 +600,11 @@
   /// value, or null if it didn't simplify.
   Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
 
+  /// Fold a binary operator whose operands are selects on the same condition
+  /// by applying it to the select arms. Returns null if no fold is done.
+  Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
+                                        Value *RHS);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -736,6 +736,10 @@
     }
   }
 
+  // Handle special cases for FMul with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   // (X*Y) * X => (X*X) * Y where Y != X
   // The purpose is two-fold:
   // 1) to form a power expression (of X).
Index: llvm/trunk/lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -719,36 +719,46 @@
     }
   }
 
-  // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0))
-  // (op (select (a, b, c)), (select (a, b, d))) -> (select (a, 0, (op c, d)))
-  if (auto *SI0 = dyn_cast<SelectInst>(LHS)) {
-    if (auto *SI1 = dyn_cast<SelectInst>(RHS)) {
-      if (SI0->getCondition() == SI1->getCondition()) {
-        Value *SI = nullptr;
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                              SI1->getFalseValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(SI0->getCondition(),
-                                    Builder.CreateBinOp(TopLevelOpcode,
-                                                        SI0->getTrueValue(),
-                                                        SI1->getTrueValue()),
-                                    V);
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(),
-                              SI1->getTrueValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(
-              SI0->getCondition(), V,
-              Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                                  SI1->getFalseValue()));
-        if (SI) {
-          SI->takeName(&I);
-          return SI;
-        }
-      }
-    }
+  return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
+}
+
+Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
+                                                    Value *LHS, Value *RHS) {
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  // (op (select (a, b, c)), (select (a, d, e)))
+  //   -> (select (a, (op b, d), (op c, e)))
+  Value *A, *B, *C, *D, *E;
+  Value *SI = nullptr;
+  if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
+      match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) {
+    // If only one arm folds, a new binop is created for the other arm. That
+    // is only a win when the selects die afterwards; otherwise it adds code.
+    bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse();
+
+    // Fold each arm with I's fast-math flags (e.g. "fadd X, +0.0 -> X" needs
+    // nsz, which the FAdd-only fold this replaces relied on); non-FP opcodes
+    // fall through to plain SimplifyBinOp. New instructions get the flags too.
+    FastMathFlags FMF;
+    BuilderTy::FastMathFlagGuard Guard(Builder);
+    if (isa<FPMathOperator>(&I)) {
+      FMF = I.getFastMathFlags();
+      Builder.setFastMathFlags(FMF);
+    }
+
+    Value *V1 = SimplifyFPBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I));
+    Value *V2 = SimplifyFPBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I));
+    if (V1 && V2)
+      SI = Builder.CreateSelect(A, V2, V1);
+    else if (V2 && SelectsHaveOneUse)
+      SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E));
+    else if (V1 && SelectsHaveOneUse)
+      SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1);
+
+    if (SI)
+      SI->takeName(&I);
   }
 
-  return nullptr;
+  return SI;
 }
 
 /// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a
Index: llvm/trunk/test/Transforms/InstCombine/select_arithmetic.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/select_arithmetic.ll
+++ llvm/trunk/test/Transforms/InstCombine/select_arithmetic.ll
@@ -0,0 +1,66 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; Tests folding constants from two similar selects that feed an fadd
+define float @test1(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fadd float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test1(
+; CHECK: %tmp2 = select i1 %arg, float 6.000000e+00, float 1.500000e+01
+; CHECK-NOT: fadd
+; CHECK: ret float %tmp2
+}
+
+; Tests folding constants from two similar selects that feed an fsub
+define float @test2(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fsub float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test2(
+; CHECK: %tmp2 = select i1 %arg, float 4.000000e+00, float -3.000000e+00
+; CHECK-NOT: fsub
+; CHECK: ret float %tmp2
+}
+
+; Tests folding constants from two similar selects that feed an fmul
+define float @test3(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fmul float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test3(
+; CHECK: %tmp2 = select i1 %arg, float 5.000000e+00, float 5.400000e+01
+; CHECK-NOT: fmul
+; CHECK: ret float %tmp2
+}
+
+; Tests that the fold is not done when only one arm folds and the selects
+; have other uses, since that would increase the instruction count
+define float @test4(i1 zeroext %arg, float %x) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float %x
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fadd float %tmp, %tmp1
+  %tmp3 = fmul float %tmp2, %tmp
+  ret float %tmp3
+; CHECK-LABEL: @test4(
+; CHECK: %tmp2 = fadd float %tmp, %tmp1
+; CHECK: ret float %tmp3
+}
+
+; Tests folding zeros from two similar selects that feed an fadd when the
+; fadd has the nsz flag
+define float @test5(i1 zeroext %arg, float %a, float %b) #0 {
+  %tmp = select i1 %arg, float 0.000000e+00, float %b
+  %tmp1 = select i1 %arg, float %a, float 0.000000e+00
+  %tmp2 = fadd nsz float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test5(
+; CHECK: %tmp2 = select i1 %arg, float %a, float %b
+; CHECK-NOT: fadd
+; CHECK: ret float %tmp2
+}
+