diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1146,6 +1146,30 @@ return nullptr; } +// Transform: +// (add A, (shl (neg X), Y)) +// -> (sub A, (shl X, Y)) +static Instruction *combineAddSubWithShlAddSub(InstCombiner::BuilderTy &Builder, + const BinaryOperator &I) { + Value *Cnt, *X, *Other; + auto MatchShlOfNeg = [&I, &Cnt, &X, &Other](unsigned OpIdx) { + // TODO: We could also match multiuse shl and replace: + // (shl (neg X), Y) -> (neg (shl X, Y)) + // and continue with the transform. + if (!match(I.getOperand(OpIdx), + m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(X))), m_Value(Cnt))))) + return false; + Other = I.getOperand(1 - OpIdx); + return true; + }; + + if (!MatchShlOfNeg(0) && !MatchShlOfNeg(1)) + return nullptr; + + Value *NewShl = Builder.CreateShl(X, Cnt); + return BinaryOperator::CreateSub(Other, NewShl); +} + /// Try to reduce signed division by power-of-2 to an arithmetic shift right. static Instruction *foldAddToAshr(BinaryOperator &Add) { // Division must be by power-of-2, but not the minimum signed value. @@ -1386,6 +1410,9 @@ if (Instruction *R = foldBinOpShiftWithShift(I)) return R; + if (Instruction *R = combineAddSubWithShlAddSub(Builder, I)) + return R; + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) diff --git a/llvm/test/Transforms/InstCombine/add-shift.ll b/llvm/test/Transforms/InstCombine/add-shift.ll --- a/llvm/test/Transforms/InstCombine/add-shift.ll +++ b/llvm/test/Transforms/InstCombine/add-shift.ll @@ -3,9 +3,8 @@ define i8 @flip_add_of_shift_neg(i8 %v, i8 %sh, i8 %x) { ; CHECK-LABEL: define i8 @flip_add_of_shift_neg ; CHECK-SAME: (i8 [[V:%.*]], i8 [[SH:%.*]], i8 [[X:%.*]]) { -; CHECK-NEXT: [[NV:%.*]] = sub i8 0, [[V]] -; CHECK-NEXT: [[SV:%.*]] = shl nuw nsw i8 [[NV]], [[SH]] -; CHECK-NEXT: [[R:%.*]] = add i8 [[SV]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[V]], [[SH]] +; CHECK-NEXT: [[R:%.*]] = sub i8 [[X]], [[TMP1]] ; CHECK-NEXT: ret i8 [[R]] ; %nv = sub i8 0, %v @@ -18,9 +17,8 @@ ; CHECK-LABEL: define <2 x i8> @flip_add_of_shift_neg_vec ; CHECK-SAME: (<2 x i8> [[V:%.*]], <2 x i8> [[SH:%.*]], <2 x i8> [[XX:%.*]]) { ; CHECK-NEXT: [[X:%.*]] = mul <2 x i8> [[XX]], [[XX]] -; CHECK-NEXT: [[NV:%.*]] = sub <2 x i8> zeroinitializer, [[V]] -; CHECK-NEXT: [[SV:%.*]] = shl <2 x i8> [[NV]], [[SH]] -; CHECK-NEXT: [[R:%.*]] = add <2 x i8> [[X]], [[SV]] +; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[V]], [[SH]] +; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> [[X]], [[TMP1]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %x = mul <2 x i8> %xx, %xx