[InstCombine] Fold binary operators fed by two selects on the same condition

Factor the select fold in InstructionCombining.cpp out into the new helper
InstCombiner::SimplifySelectsFeedingBinaryOp and also call it from the FAdd,
FSub and FMul combines:

  (op (select A, B, C), (select A, D, E)) -> (select A, (op B, D), (op C, E))

The fold fires only when at least one arm of the new select simplifies, and
an op is only materialized for an arm that did not simplify, so no dead
instructions are created. FP ops created by the fold inherit the fast-math
flags of the original instruction. This replaces the FAdd-only special case
  select C, 0, B + select C, A, 0 -> select C, A, B

NOTE(review): the removed FAdd special case also folded +0.0 arms when the
fadd had nsz; the helper calls SimplifyBinOp without the instruction's
fast-math flags, so that case may no longer fold -- confirm with a test.

Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1387,30 +1387,9 @@
     }
   }
 
-  // select C, 0, B + select C, A, 0 -> select C, A, B
-  {
-    Value *A1, *B1, *C1, *A2, *B2, *C2;
-    if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) &&
-        match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) {
-      if (C1 == C2) {
-        Constant *Z1=nullptr, *Z2=nullptr;
-        Value *A, *B, *C=C1;
-        if (match(A1, m_AnyZero()) && match(B2, m_AnyZero())) {
-            Z1 = dyn_cast<Constant>(A1); A = A2;
-            Z2 = dyn_cast<Constant>(B2); B = B1;
-        } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) {
-            Z1 = dyn_cast<Constant>(B1); B = B2;
-            Z2 = dyn_cast<Constant>(A2); A = A1;
-        }
-
-        if (Z1 && Z2 &&
-            (I.hasNoSignedZeros() ||
-             (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) {
-          return SelectInst::Create(C, A, B);
-        }
-      }
-    }
-  }
+  // Handle special cases for FAdd with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
+    return replaceInstUsesWith(I, V);
 
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
@@ -1760,6 +1739,10 @@
     }
   }
 
+  // Handle special cases for FSub with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   if (I.hasUnsafeAlgebra()) {
     if (Value *V = FAddCombine(Builder).simplify(&I))
       return replaceInstUsesWith(I, V);
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -600,6 +600,12 @@
   /// value, or null if it didn't simplify.
   Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
 
+  /// Fold (op (select A, B, C), (select A, D, E)) into
+  /// (select A, (op B, D), (op C, E)) when at least one arm simplifies.
+  Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
+                                        Value *LHS,
+                                        Value *RHS);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -736,6 +736,10 @@
     }
   }
 
+  // Handle special cases for FMul with selects feeding the operation
+  if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
+    return replaceInstUsesWith(I, V);
+
   // (X*Y) * X => (X*X) * Y where Y != X
   // The purpose is two-fold:
   // 1) to form a power expression (of X).
Index: lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- lib/Transforms/InstCombine/InstructionCombining.cpp
+++ lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -719,31 +719,38 @@
     }
   }
 
-  // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0))
-  // (op (select (a, b, c)), (select (a, b, d))) -> (select (a, 0, (op c, d)))
-  if (auto *SI0 = dyn_cast<SelectInst>(LHS)) {
-    if (auto *SI1 = dyn_cast<SelectInst>(RHS)) {
-      if (SI0->getCondition() == SI1->getCondition()) {
-        Value *SI = nullptr;
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                              SI1->getFalseValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(SI0->getCondition(),
-                                    Builder.CreateBinOp(TopLevelOpcode,
-                                                        SI0->getTrueValue(),
-                                                        SI1->getTrueValue()),
-                                    V);
-        if (Value *V =
-                SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(),
-                              SI1->getTrueValue(), SQ.getWithInstruction(&I)))
-          SI = Builder.CreateSelect(
-              SI0->getCondition(), V,
-              Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                                  SI1->getFalseValue()));
-        if (SI) {
-          SI->takeName(&I);
-          return SI;
-        }
+  return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
+}
+
+Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
+                                                    Value *LHS,
+                                                    Value *RHS) {
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  // (op (select (a, b, c)), (select (a, d, e)))
+  //   -> (select (a, (op b, d), (op c, e)))
+  Value *A1, *B, *C, *A2, *D, *E;
+  if (match(LHS, m_Select(m_Value(A1), m_Value(B), m_Value(C))) &&
+      match(RHS, m_Select(m_Value(A2), m_Value(D), m_Value(E)))) {
+    if (A1 == A2) {
+      // Any FP op created below inherits the fast-math flags of I.
+      BuilderTy::FastMathFlagGuard Guard(Builder);
+      if (isa<FPMathOperator>(&I))
+        Builder.setFastMathFlags(I.getFastMathFlags());
+      // Simplify both arms first and only materialize an op for an arm that
+      // did not fold, so no dead instructions are created.
+      Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I));
+      Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I));
+      Value *SI = nullptr;
+      if (V1 && V2)
+        SI = Builder.CreateSelect(A1, V2, V1);
+      else if (V1)
+        SI = Builder.CreateSelect(A1, Builder.CreateBinOp(Opcode, B, D), V1);
+      else if (V2)
+        SI = Builder.CreateSelect(A1, V2, Builder.CreateBinOp(Opcode, C, E));
+
+      if (SI) {
+        SI->takeName(&I);
+        return SI;
       }
     }
   }
Index: test/Transforms/InstCombine/select_arithmetic.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/select_arithmetic.ll
@@ -0,0 +1,40 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; Tests folding constants from two similar selects that feed an add
+define float @test1(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fadd float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test1(
+; CHECK: %tmp2 = select i1 %arg, float 6.000000e+00, float 1.500000e+01
+; CHECK-NOT: fadd
+; CHECK: ret float %tmp2
+}
+
+; Tests folding constants from two similar selects that feed a sub
+define float @test2(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fsub float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test2(
+; CHECK: %tmp2 = select i1 %arg, float 4.000000e+00, float -3.000000e+00
+; CHECK-NOT: fsub
+; CHECK: ret float %tmp2
+}
+
+; Tests folding constants from two similar selects that feed a mul
+define float @test3(i1 zeroext %arg) #0 {
+  %tmp = select i1 %arg, float 5.000000e+00, float 6.000000e+00
+  %tmp1 = select i1 %arg, float 1.000000e+00, float 9.000000e+00
+  %tmp2 = fmul float %tmp, %tmp1
+  ret float %tmp2
+; CHECK-LABEL: @test3(
+; CHECK: %tmp2 = select i1 %arg, float 5.000000e+00, float 5.400000e+01
+; CHECK-NOT: fmul
+; CHECK: ret float %tmp2
+}
+