diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2031,14 +2031,17 @@ { // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z)) - // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Y,Z)) + // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y)) Value *X, *Y, *Z; - if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z)))) && - (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X)))) || - match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X)))))) { - Value *USub = - Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), {Y, Z}); - return BinaryOperator::CreateAdd(X, USub); + if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) { + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Y, Z})); + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Z, Y})); } } } diff --git a/llvm/test/Transforms/InstCombine/sub-minmax.ll b/llvm/test/Transforms/InstCombine/sub-minmax.ll --- a/llvm/test/Transforms/InstCombine/sub-minmax.ll +++ b/llvm/test/Transforms/InstCombine/sub-minmax.ll @@ -715,9 +715,7 @@ ret i8 %s } -; -; TODO: sub(add(X,Y),umin(Y,Z)) --> add(X,usubsat(Y,Z)) -; +; sub(add(X,Y),umin(Y,Z)) --> add(X,usubsat(Y,Z)) define i8 @sub_add_umin(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @sub_add_umin(