diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1180,6 +1180,27 @@
   return nullptr;
 }
 
+// Fold an add of a one-use shl of a one-use negation into a sub:
+//   (add A, (shl (neg B), Y)) -> (sub A, (shl B, Y))
+//
+// Legal because a left shift by Y is a multiply by 2^Y, which distributes over
+// negation in two's complement: (0 - B) << Y == 0 - (B << Y).
+// nuw/nsw on the original shl are deliberately not transferred to the new shl:
+// they need not hold for (shl B, Y) (e.g. i8: (-64) << 1 is nsw, 64 << 1 is
+// not). The one-use checks make the old neg and shl dead after the rewrite, so
+// the instruction count does not grow.
+static Instruction *foldAddOfShlOfNeg(InstCombiner::BuilderTy &Builder,
+                                      const BinaryOperator &I) {
+  Value *A, *B, *Cnt;
+  if (match(&I,
+            m_c_Add(m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(B))), m_Value(Cnt))),
+                    m_Value(A)))) {
+    Value *NewShl = Builder.CreateShl(B, Cnt);
+    return BinaryOperator::CreateSub(A, NewShl);
+  }
+  return nullptr;
+}
+
 /// Try to reduce signed division by power-of-2 to an arithmetic shift right.
 static Instruction *foldAddToAshr(BinaryOperator &Add) {
   // Division must be by power-of-2, but not the minimum signed value.
@@ -1420,6 +1441,9 @@
   if (Instruction *R = foldBinOpShiftWithShift(I))
     return R;
 
+  if (Instruction *R = foldAddOfShlOfNeg(Builder, I))
+    return R;
+
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
diff --git a/llvm/test/Transforms/InstCombine/add-shift.ll b/llvm/test/Transforms/InstCombine/add-shift.ll
--- a/llvm/test/Transforms/InstCombine/add-shift.ll
+++ b/llvm/test/Transforms/InstCombine/add-shift.ll
@@ -3,9 +3,8 @@
 define i8 @flip_add_of_shift_neg(i8 %v, i8 %sh, i8 %x) {
 ; CHECK-LABEL: define i8 @flip_add_of_shift_neg
 ; CHECK-SAME: (i8 [[V:%.*]], i8 [[SH:%.*]], i8 [[X:%.*]]) {
-; CHECK-NEXT:    [[NV:%.*]] = sub i8 0, [[V]]
-; CHECK-NEXT:    [[SV:%.*]] = shl nuw nsw i8 [[NV]], [[SH]]
-; CHECK-NEXT:    [[R:%.*]] = add i8 [[SV]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[V]], [[SH]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[X]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %nv = sub i8 0, %v
@@ -18,9 +17,8 @@
 ; CHECK-LABEL: define <2 x i8> @flip_add_of_shift_neg_vec
 ; CHECK-SAME: (<2 x i8> [[V:%.*]], <2 x i8> [[SH:%.*]], <2 x i8> [[XX:%.*]]) {
 ; CHECK-NEXT:    [[X:%.*]] = mul <2 x i8> [[XX]], [[XX]]
-; CHECK-NEXT:    [[NV:%.*]] = sub <2 x i8> zeroinitializer, [[V]]
-; CHECK-NEXT:    [[SV:%.*]] = shl <2 x i8> [[NV]], [[SH]]
-; CHECK-NEXT:    [[R:%.*]] = add <2 x i8> [[X]], [[SV]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i8> [[V]], [[SH]]
+; CHECK-NEXT:    [[R:%.*]] = sub <2 x i8> [[X]], [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x = mul <2 x i8> %xx, %xx