diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2015,11 +2015,10 @@
     }
   }
 
-  {
-    // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
-    // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
-    // TODO: generalize to sub(add(Z,Y),umin(X,Y)) --> add(Z,usub.sat(Y,X))?
-    if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) {
+  if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) {
+    {
+      // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y)
+      // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y)
       Value *X = II->getLHS();
       Value *Y = II->getRHS();
       if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) &&
@@ -2029,6 +2028,27 @@
         return replaceInstUsesWith(I, InvMaxMin);
       }
     }
+
+    {
+      // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z))
+      // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y))
+      // In unsigned (wrapping) arithmetic W - umin(W, V) == usub.sat(W, V),
+      // so the first usub.sat operand must be the umin operand that also
+      // feeds the add; the other add operand (X) is carried through.
+      Value *X, *Y, *Z;
+      if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) {
+        if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) {
+          Value *USub =
+              Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), {Y, Z});
+          return BinaryOperator::CreateAdd(X, USub);
+        }
+        if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) {
+          Value *USub =
+              Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), {Z, Y});
+          return BinaryOperator::CreateAdd(X, USub);
+        }
+      }
+    }
   }
 
   {
diff --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll
--- a/llvm/test/Transforms/InstCombine/sub-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll
@@ -721,9 +721,8 @@
 
 define i8 @sub_add_umin(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @sub_add_umin(
-; CHECK-NEXT:    [[A:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[Z:%.*]])
-; CHECK-NEXT:    [[S:%.*]] = sub i8 [[A]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[Y:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[S:%.*]] = add i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
@@ -734,9 +733,8 @@
 
 define i8 @sub_add_umin_commute_umin(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @sub_add_umin_commute_umin(
-; CHECK-NEXT:    [[A:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[Y]])
-; CHECK-NEXT:    [[S:%.*]] = sub i8 [[A]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[Y:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[S:%.*]] = add i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %x, %y
@@ -747,9 +745,8 @@
 
 define i8 @sub_add_umin_commute_add(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @sub_add_umin_commute_add(
-; CHECK-NEXT:    [[A:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Y]], i8 [[Z:%.*]])
-; CHECK-NEXT:    [[S:%.*]] = sub i8 [[A]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[Y:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[S:%.*]] = add i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %y, %x
@@ -760,9 +757,8 @@
 
 define i8 @sub_add_umin_commute_add_umin(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @sub_add_umin_commute_add_umin(
-; CHECK-NEXT:    [[A:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[Y]])
-; CHECK-NEXT:    [[S:%.*]] = sub i8 [[A]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[Y:%.*]], i8 [[Z:%.*]])
+; CHECK-NEXT:    [[S:%.*]] = add i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[S]]
 ;
   %a = add i8 %y, %x
@@ -773,9 +769,8 @@
 
 define <2 x i8> @sub_add_umin_vec(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @sub_add_umin_vec(
-; CHECK-NEXT:    [[A:%.*]] = add <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[M:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[Y]], <2 x i8> [[Z:%.*]])
-; CHECK-NEXT:    [[S:%.*]] = sub <2 x i8> [[A]], [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.usub.sat.v2i8(<2 x i8> [[Y:%.*]], <2 x i8> [[Z:%.*]])
+; CHECK-NEXT:    [[S:%.*]] = add <2 x i8> [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[S]]
 ;
   %a = add <2 x i8> %x, %y