[InstCombine] Rename SimplifyUsingDistributiveLaws; split out factorization

Rename InstCombinerImpl::SimplifyUsingDistributiveLaws to
foldUsingDistributiveLaws (declaration in InstCombineInternal.h, definition in
InstructionCombining.cpp) and update its callers in InstCombineAddSub.cpp,
InstCombineAndOrXor.cpp and InstCombineMulDivRem.cpp, plus the comment on
factorizeMathWithShlOps that refers to it by name.

Move the factorization part of that routine (the "(A op' B) op (C op' D)",
"(A op' B) op C" and "B op (C op' D)" attempts) into a new member,
tryFactorizationFolds, which foldUsingDistributiveLaws now calls before trying
expansion. The helper tryFactorization changes from an InstCombinerImpl member
into a file-static function that receives the SimplifyQuery and the IR builder
as explicit parameters; its member declaration is dropped from the header.

Apart from passing SQ/Builder explicitly, the moved statements are unchanged,
so this is intended to be a non-functional change (NFC).

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1231,7 +1231,7 @@
 }
 
 /// This is a specialization of a more general transform from
-/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// foldUsingDistributiveLaws. If that code can be made to work optimally
 /// for multi-use cases or propagating nsw/nuw, then we would not need this.
 static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
                                             InstCombiner::BuilderTy &Builder) {
@@ -1322,7 +1322,7 @@
     return Phi;
 
   // (A*B)+(A*C) -> A*(B+C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *R = foldBoxMultiply(I))
@@ -1944,7 +1944,7 @@
     return TryToNarrowDeduceFlags(); // Should have been handled in Negator!
 
   // (A*B)-(A*C) -> A*(B-C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (I.getType()->isIntOrIntVectorTy(1))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1847,7 +1847,7 @@
     return X;
 
   // (A|B)&(A|C) -> A|(B&C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyBSwap(I, Builder))
@@ -2825,7 +2825,7 @@
     return X;
 
   // (A&B)|(A&C) -> A&(B|C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyBSwap(I, Builder))
@@ -3768,7 +3768,7 @@
     return NewXor;
 
   // (A&B)^(A&C) -> A&(B^C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   // See if we can simplify any instructions used by the instruction whose sole
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -544,7 +544,7 @@
   /// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A
   /// & (B | C) -> (A&B) | (A&C)" if this is a win). Returns the simplified
   /// value, or null if it didn't simplify.
-  Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
+  Value *foldUsingDistributiveLaws(BinaryOperator &I);
 
   /// Tries to simplify add operations using the definition of remainder.
   ///
@@ -560,8 +560,7 @@
 
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-  Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
-                          Value *, Value *, Value *);
+  Value *tryFactorizationFolds(BinaryOperator &I);
 
   /// Match a select chain which produces one of three values based on whether
   /// the LHS is less than, equal to, or greater than RHS respectively.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -188,7 +188,7 @@
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
   Type *Ty = I.getType();
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -633,10 +633,10 @@
 
 /// This tries to simplify binary operations by factorizing out common terms
 /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
-                                          Instruction::BinaryOps InnerOpcode,
-                                          Value *A, Value *B, Value *C,
-                                          Value *D) {
+static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
+                               InstCombiner::BuilderTy &Builder,
+                               Instruction::BinaryOps InnerOpcode, Value *A,
+                               Value *B, Value *C, Value *D) {
   assert(A && B && C && D && "All values must be provided");
 
   Value *V = nullptr;
@@ -730,46 +730,58 @@
   return RetVal;
 }
 
+Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
+  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+  Value *A, *B, *C, *D;
+  Instruction::BinaryOps LHSOpcode, RHSOpcode;
+
+  if (Op0)
+    LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+  if (Op1)
+    RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+  // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
+  // a common term.
+  if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+    if (Value *V = tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, D))
+      return V;
+
+  // The instruction has the form "(A op' B) op (C)". Try to factorize common
+  // term.
+  if (Op0)
+    if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, LHSOpcode, A, B, RHS, Ident))
+        return V;
+
+  // The instruction has the form "(B) op (C op' D)". Try to factorize common
+  // term.
+  if (Op1)
+    if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+      if (Value *V =
+              tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D))
+        return V;
+
+  return nullptr;
+}
+
 /// This tries to simplify binary operations which some other binary operation
 /// distributes over either by factorizing out common terms
 /// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in
 /// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win).
 /// Returns the simplified value, or null if it didn't simplify.
-Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
+Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
   Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
 
-  {
-    // Factorization.
-    Value *A, *B, *C, *D;
-    Instruction::BinaryOps LHSOpcode, RHSOpcode;
-    if (Op0)
-      LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
-    if (Op1)
-      RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
-
-    // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
-    // a common term.
-    if (Op0 && Op1 && LHSOpcode == RHSOpcode)
-      if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D))
-        return V;
-
-    // The instruction has the form "(A op' B) op (C)". Try to factorize common
-    // term.
-    if (Op0)
-      if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
-        if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident))
-          return V;
-
-    // The instruction has the form "(B) op (C op' D)". Try to factorize common
-    // term.
-    if (Op1)
-      if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
-        if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D))
-          return V;
-  }
+  // Factorization.
+  if (Value *R = tryFactorizationFolds(I))
+    return R;
 
   // Expansion.
   if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {