Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1351,20 +1351,6 @@
     }
   }
 
-  // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts.
-  if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) {
-    if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0))
-      if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() &&
-          SI0->getOperand(1) == SI1->getOperand(1) &&
-          (SI0->hasOneUse() || SI1->hasOneUse())) {
-        Value *NewOp =
-          Builder->CreateAnd(SI0->getOperand(0), SI1->getOperand(0),
-                             SI0->getName());
-        return BinaryOperator::Create(SI1->getOpcode(), NewOp,
-                                      SI1->getOperand(1));
-      }
-  }
-
   {
     Value *X = nullptr;
     bool OpsSwapped = false;
@@ -2128,19 +2114,6 @@
   if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A))))
     return BinaryOperator::CreateOr(Op1, Builder->CreateAnd(A, C));
 
-  // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts.
-  if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) {
-    if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0))
-      if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() &&
-          SI0->getOperand(1) == SI1->getOperand(1) &&
-          (SI0->hasOneUse() || SI1->hasOneUse())) {
-        Value *NewOp = Builder->CreateOr(SI0->getOperand(0), SI1->getOperand(0),
-                                         SI0->getName());
-        return BinaryOperator::Create(SI1->getOpcode(), NewOp,
-                                      SI1->getOperand(1));
-      }
-  }
-
   // (~A | ~B) == (~(A & B)) - De Morgan's Law
   if (Value *Op0NotVal = dyn_castNotVal(Op0))
     if (Value *Op1NotVal = dyn_castNotVal(Op1))
@@ -2486,18 +2459,6 @@
     }
   }
 
-  // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts.
-  if (Op0I && Op1I && Op0I->isShift() &&
-      Op0I->getOpcode() == Op1I->getOpcode() &&
-      Op0I->getOperand(1) == Op1I->getOperand(1) &&
-      (Op0I->hasOneUse() || Op1I->hasOneUse())) {
-    Value *NewOp =
-      Builder->CreateXor(Op0I->getOperand(0), Op1I->getOperand(0),
-                         Op0I->getName());
-    return BinaryOperator::Create(Op1I->getOpcode(), NewOp,
-                                  Op1I->getOperand(1));
-  }
-
   if (Op0I && Op1I) {
     Value *A, *B, *C, *D;
     // (A & B)^(A | B) -> A ^ B
Index: lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- lib/Transforms/InstCombine/InstructionCombining.cpp
+++ lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -390,6 +390,18 @@
                                      Instruction::BinaryOps ROp) {
   if (Instruction::isCommutative(ROp))
     return LeftDistributesOverRight(ROp, LOp);
+
+  // Every shift (shl, lshr, ashr) right-distributes over And/Or/Xor, i.e.
+  // "(X LOp Y) ROp Z" == "(X ROp Z) LOp (Y ROp Z)":
+  //   (X >> Z) & (Y >> Z) -> (X&Y) >> Z
+  //   (X >> Z) | (Y >> Z) -> (X|Y) >> Z
+  //   (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z
+  if ((LOp == Instruction::And || LOp == Instruction::Or ||
+       LOp == Instruction::Xor) &&
+      (ROp == Instruction::Shl || ROp == Instruction::LShr ||
+       ROp == Instruction::AShr))
+    return true;
+
   // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z",
   // but this requires knowing that the addition does not overflow and other
   // such subtleties.
@@ -413,24 +425,32 @@
 /// This function factors binary ops which can be combined using distributive
 /// laws. This also factor SHL as MUL e.g. SHL(X, 2) ==> MUL(X, 4).
 static Instruction::BinaryOps
-getBinOpsForFactorization(BinaryOperator *Op, Value *&LHS, Value *&RHS) {
+getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,
+                          BinaryOperator *Op, Value *&LHS, Value *&RHS) {
   if (!Op)
     return Instruction::BinaryOpsEnd;
 
-  if (Op->getOpcode() == Instruction::Shl) {
-    if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
-      // The multiplier is really 1 << CST.
-      RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
-      LHS = Op->getOperand(0);
-      return Instruction::Mul;
+  LHS = Op->getOperand(0);
+  RHS = Op->getOperand(1);
+
+  // Rewriting "X << C" as "X * (1 << C)" only pays off when the top-level
+  // opcode (TopLevelOpcode) is Add or Sub, over which Mul distributes. Any
+  // other top-level opcode (e.g. And/Or/Xor, over which shifts distribute)
+  // needs the shift opcode itself so the shift can be factorized directly.
+  if (TopLevelOpcode == Instruction::Add ||
+      TopLevelOpcode == Instruction::Sub) {
+    if (Op->getOpcode() == Instruction::Shl) {
+      if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
+        // The multiplier is really 1 << CST.
+        RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
+        return Instruction::Mul;
+      }
     }
   }
 
   // TODO: We can add other conversions e.g. shr => div etc.
 
-  LHS = Op->getOperand(0);
-  RHS = Op->getOperand(1);
   return Op->getOpcode();
 }
 
 /// This tries to simplify binary operations by factorizing out common terms
@@ -529,8 +549,11 @@
 
   // Factorization.
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
-  Instruction::BinaryOps LHSOpcode = getBinOpsForFactorization(Op0, A, B);
-  Instruction::BinaryOps RHSOpcode = getBinOpsForFactorization(Op1, C, D);
+  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+  Instruction::BinaryOps LHSOpcode =
+      getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+  Instruction::BinaryOps RHSOpcode =
+      getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
 
   // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
   // a common term.
@@ -552,7 +575,6 @@
     return V;
 
   // Expansion.
-  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
   if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
     // The instruction has the form "(A op' B) op C". See if expanding it out
     // to "(A op C) op' (B op C)" results in simplifications.