Review notes on this patch (ConstantExpr::getBinOpIdentity gains an
AllowRHSConstant parameter; InstCombine's select-shuffle fold opts in):

  * As captured here, the patch text is whitespace-collapsed: many diff lines
    are joined onto a few physical lines and indentation is lost. The vector
    constant operands in shuffle_select.ll also appear to have been stripped.
    For example, "shl nuw <4 x i32> [[V:%.*]]," ends in a bare comma,
    "%b = shl nuw <4 x i32> %v," has no second operand, and the shufflevector
    masks stop at "<4 x i32>". Regenerate the patch from the source tree
    before trying to apply it.

  * Constants.cpp removes the old "default: return nullptr" path
    ("Doesn't have an identity.") and adds
    assert(Instruction::isBinaryOp(Opcode) && "Only binops allowed").
    Any caller that used to pass a non-binop opcode and relied on a null
    result will now hit the assert. The new doc comment in Constants.h says
    "Return nullptr if the operator does not have an identity constant" but
    does not state the binop-only precondition.
    NOTE(review): either document the precondition in the header, or keep
    returning nullptr for non-binops.

  * With AllowRHSConstant=true, the patch returns these operand-1 identities:
      - null value for Sub/Shl/LShr/AShr/FSub;
      - 1 for SDiv/UDiv;
      - 1.0 for FDiv.
    InstCombineVectorOps.cpp now passes true. The code after the call site
    that builds the replacement constant from IdC is not visible in this
    patch.
    NOTE(review): if that code shuffles IdC with the binop's constant using
    the shuffle mask, an undef mask lane could turn into an undef divisor or
    undef shift amount in the new instruction. Confirm how undef mask
    elements are handled. Consider adding tests that use an undef mask lane
    with sdiv/udiv/shl.

  * The call site passes a bare "true". An argument comment such as
    /*AllowRHSConstant=*/true would make the intent readable at the call.

Index: include/llvm/IR/Constants.h =================================================================== --- include/llvm/IR/Constants.h +++ include/llvm/IR/Constants.h @@ -1018,10 +1018,15 @@ return getLShr(C1, C2, true); } - /// Return the identity for the given binary operation, - /// i.e. a constant C such that X op C = X and C op X = X for every X. It - /// returns null if the operator doesn't have an identity. - static Constant *getBinOpIdentity(unsigned Opcode, Type *Ty); + /// Return the identity constant for a binary opcode. + /// The identity constant C is defined as X op C = X and C op X = X for every + /// X when the binary operation is commutative. If the binop is not + /// commutative, callers can acquire the operand 1 identity constant by + /// setting AllowRHSConstant to true. For example, any shift has a zero + /// identity constant for operand 1: X shift 0 = X. + /// Return nullptr if the operator does not have an identity constant. + static Constant *getBinOpIdentity(unsigned Opcode, Type *Ty, + bool AllowRHSConstant = false); /// Return the absorbing element for the given binary /// operation, i.e. a constant C such that X op C = C and C op X = C for Index: lib/IR/Constants.cpp =================================================================== --- lib/IR/Constants.cpp +++ lib/IR/Constants.cpp @@ -2261,32 +2261,49 @@ isExact ? PossiblyExactOperator::IsExact : 0); } -// FIXME: Add a parameter to specify the operand number for non-commutative ops. -// For example, the operand 1 identity constant for any shift is the null value -// because shift-by-0 always returns operand 0. -Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty) { - switch (Opcode) { - default: - // Doesn't have an identity. 
- return nullptr; - - case Instruction::Add: - case Instruction::Or: - case Instruction::Xor: - return Constant::getNullValue(Ty); - - case Instruction::Mul: - return ConstantInt::get(Ty, 1); - - case Instruction::And: - return Constant::getAllOnesValue(Ty); +Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty, + bool AllowRHSConstant) { + assert(Instruction::isBinaryOp(Opcode) && "Only binops allowed"); + + // Commutative opcodes: it does not matter if AllowRHSConstant is set. + if (Instruction::isCommutative(Opcode)) { + switch (Opcode) { + case Instruction::Add: // X + 0 = X + case Instruction::Or: // X | 0 = X + case Instruction::Xor: // X ^ 0 = X + return Constant::getNullValue(Ty); + case Instruction::Mul: // X * 1 = X + return ConstantInt::get(Ty, 1); + case Instruction::And: // X & -1 = X + return Constant::getAllOnesValue(Ty); + case Instruction::FAdd: // X + -0.0 = X + // TODO: If the fadd has 'nsz', should we return +0.0? + return ConstantFP::getNegativeZero(Ty); + case Instruction::FMul: // X * 1.0 = X + return ConstantFP::get(Ty, 1.0); + default: + llvm_unreachable("Every commutative binop has an identity constant"); + } + } - // TODO: If the fadd has 'nsz', should we return +0.0? - case Instruction::FAdd: - return ConstantFP::getNegativeZero(Ty); + // Non-commutative opcodes: AllowRHSConstant must be set. 
+ if (!AllowRHSConstant) + return nullptr; - case Instruction::FMul: - return ConstantFP::get(Ty, 1.0); + switch (Opcode) { + case Instruction::Sub: // X - 0 = X + case Instruction::Shl: // X << 0 = X + case Instruction::LShr: // X >>u 0 = X + case Instruction::AShr: // X >> 0 = X + case Instruction::FSub: // X - 0.0 = X + return Constant::getNullValue(Ty); + case Instruction::SDiv: // X / 1 = X + case Instruction::UDiv: // X /u 1 = X + return ConstantInt::get(Ty, 1); + case Instruction::FDiv: // X / 1.0 = X + return ConstantFP::get(Ty, 1.0); + default: + return nullptr; } } Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1207,7 +1207,7 @@ // a vector, this is a splat of something like 0, -1, or 1. // If there's no identity constant for this binop, we're done. BinaryOperator::BinaryOps BOpcode = BO->getOpcode(); - Constant *IdC = ConstantExpr::getBinOpIdentity(BOpcode, Shuf.getType()); + Constant *IdC = ConstantExpr::getBinOpIdentity(BOpcode, Shuf.getType(), true); if (!IdC) return nullptr; Index: test/Transforms/InstCombine/shuffle_select.ll =================================================================== --- test/Transforms/InstCombine/shuffle_select.ll +++ test/Transforms/InstCombine/shuffle_select.ll @@ -43,8 +43,7 @@ define <4 x i32> @shl(<4 x i32> %v) { ; CHECK-LABEL: @shl( -; CHECK-NEXT: [[B:%.*]] = shl nuw <4 x i32> [[V:%.*]], -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[B]], <4 x i32> [[V]], <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shl nuw <4 x i32> [[V:%.*]], ; CHECK-NEXT: ret <4 x i32> [[S]] ; %b = shl nuw <4 x i32> %v, @@ -54,8 +53,7 @@ define <4 x i32> @lshr_constant_op0(<4 x i32> %v) { ; CHECK-LABEL: @lshr_constant_op0( -; CHECK-NEXT: [[B:%.*]] = lshr exact <4 x i32> [[V:%.*]], -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[V]], <4 x i32> [[B]], 
<4 x i32> +; CHECK-NEXT: [[S:%.*]] = lshr exact <4 x i32> [[V:%.*]], ; CHECK-NEXT: ret <4 x i32> [[S]] ; %b = lshr exact <4 x i32> %v, @@ -77,8 +75,7 @@ define <3 x i32> @ashr(<3 x i32> %v) { ; CHECK-LABEL: @ashr( -; CHECK-NEXT: [[B:%.*]] = ashr <3 x i32> [[V:%.*]], -; CHECK-NEXT: [[S:%.*]] = shufflevector <3 x i32> [[B]], <3 x i32> [[V]], <3 x i32> +; CHECK-NEXT: [[S:%.*]] = ashr <3 x i32> [[V:%.*]], ; CHECK-NEXT: ret <3 x i32> [[S]] ; %b = ashr <3 x i32> %v, @@ -217,8 +214,7 @@ define <4 x double> @fdiv_constant_op1(<4 x double> %v) { ; CHECK-LABEL: @fdiv_constant_op1( -; CHECK-NEXT: [[B:%.*]] = fdiv reassoc <4 x double> [[V:%.*]], -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x double> [[V]], <4 x double> [[B]], <4 x i32> +; CHECK-NEXT: [[S:%.*]] = fdiv reassoc <4 x double> [[V:%.*]], ; CHECK-NEXT: ret <4 x double> [[S]] ; %b = fdiv reassoc <4 x double> %v,