diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6119,6 +6119,55 @@
   return nullptr;
 }
 
+/// Given a min/max intrinsic, see if it can be removed based on having an
+/// operand that is another min/max intrinsic with shared operand(s). The caller
+/// is expected to swap the operand arguments to handle commutation.
+static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
+                                         Value *Op1) {
+  assert((IID == Intrinsic::maxnum || IID == Intrinsic::minnum ||
+          IID == Intrinsic::maximum || IID == Intrinsic::minimum) &&
+         "Unsupported intrinsic");
+
+  auto *M0 = dyn_cast<IntrinsicInst>(Op0);
+  // If Op0 is not the same intrinsic as IID, do not process.
+  // This differs from the integer min/max handling: we do not process the
+  // case like max(min(X,Y),min(X,Y)) => min(X,Y). But it can be handled by GVN.
+  if (!M0 || M0->getIntrinsicID() != IID)
+    return nullptr;
+  Value *X0 = M0->getOperand(0);
+  Value *Y0 = M0->getOperand(1);
+  // Simple case, m(m(X,Y), X) => m(X, Y)
+  //              m(m(X,Y), Y) => m(X, Y)
+  // For minimum/maximum, X is NaN => m(NaN, Y) == NaN and m(NaN, NaN) == NaN.
+  // For minimum/maximum, Y is NaN => m(X, NaN) == NaN and m(NaN, NaN) == NaN.
+  // For minnum/maxnum, X is NaN => m(NaN, Y) == Y and m(Y, Y) == Y.
+  // For minnum/maxnum, Y is NaN => m(X, NaN) == X and m(X, X) == X.
+  if (X0 == Op1 || Y0 == Op1)
+    return M0;
+
+  auto *M1 = dyn_cast<IntrinsicInst>(Op1);
+  if (!M1)
+    return nullptr;
+  // Only m itself or its inverse m' can fold. Ask for the inverse of IID
+  // (always a supported min/max) rather than of IID1: Op1 may be any
+  // intrinsic, and getInverseMinMaxIntrinsic hits llvm_unreachable otherwise.
+  Intrinsic::ID IID1 = M1->getIntrinsicID();
+  if (IID1 != IID && IID1 != getInverseMinMaxIntrinsic(IID))
+    return nullptr;
+  Value *X1 = M1->getOperand(0);
+  Value *Y1 = M1->getOperand(1);
+  // We have a case m(m(X,Y),m'(X,Y)) taking into account m' is commutative.
+  // If m' is m or inversion of m => m(m(X,Y),m'(X,Y)) == m(X,Y).
+  // For minimum/maximum, X is NaN => m(NaN,Y) == m'(NaN, Y) == NaN.
+  // For minimum/maximum, Y is NaN => m(X,NaN) == m'(X, NaN) == NaN.
+  // For minnum/maxnum, X is NaN => m(NaN,Y) == m'(NaN, Y) == Y.
+  // For minnum/maxnum, Y is NaN => m(X,NaN) == m'(X, NaN) == X.
+  if ((X0 == X1 && Y0 == Y1) || (X0 == Y1 && Y0 == X1))
+    return M0;
+
+  return nullptr;
+}
+
 static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
                                       const SimplifyQuery &Q) {
   Intrinsic::ID IID = F->getIntrinsicID();
@@ -6358,14 +6407,10 @@
 
     // Min/max of the same operation with common operand:
     // m(m(X, Y)), X --> m(X, Y) (4 commuted variants)
-    if (auto *M0 = dyn_cast<IntrinsicInst>(Op0))
-      if (M0->getIntrinsicID() == IID &&
-          (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1))
-        return Op0;
-    if (auto *M1 = dyn_cast<IntrinsicInst>(Op1))
-      if (M1->getIntrinsicID() == IID &&
-          (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0))
-        return Op1;
+    if (Value *V = foldMinimumMaximumSharedOp(IID, Op0, Op1))
+      return V;
+    if (Value *V = foldMinimumMaximumSharedOp(IID, Op1, Op0))
+      return V;
 
     break;
   }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7735,6 +7735,12 @@
   case Intrinsic::smin: return Intrinsic::smax;
   case Intrinsic::umax: return Intrinsic::umin;
   case Intrinsic::umin: return Intrinsic::umax;
+  // Note that for the next four intrinsics the original and inverted case may
+  // produce the same result even if X != Y, because NaN is handled specially.
+  case Intrinsic::maximum: return Intrinsic::minimum;
+  case Intrinsic::minimum: return Intrinsic::maximum;
+  case Intrinsic::maxnum: return Intrinsic::minnum;
+  case Intrinsic::minnum: return Intrinsic::maxnum;
   default: llvm_unreachable("Unexpected intrinsic");
   }
 }
diff --git a/llvm/test/Transforms/InstSimplify/fminmax-folds.ll b/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
--- a/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
+++ b/llvm/test/Transforms/InstSimplify/fminmax-folds.ll
@@ -1206,9 +1206,7 @@
 define float @maximum_maximum_minimum(float %x, float %y) {
 ; CHECK-LABEL: @maximum_maximum_minimum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maximum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minimum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.maximum.f32(float [[MAX]], float [[MIN]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MAX]]
 ;
   %max = call float @llvm.maximum.f32(float %x, float %y)
   %min = call float @llvm.minimum.f32(float %x, float %y)
@@ -1219,9 +1217,7 @@
 define double @maximum_minimum_maximum(double %x, double %y) {
 ; CHECK-LABEL: @maximum_minimum_maximum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maximum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minimum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.maximum.f64(double [[MIN]], double [[MAX]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    ret double [[MAX]]
 ;
   %max = call double @llvm.maximum.f64(double %x, double %y)
   %min = call double @llvm.minimum.f64(double %x, double %y)
@@ -1245,9 +1241,7 @@
 define half @maximum_maximum_maximum(half %x, half %y) {
 ; CHECK-LABEL: @maximum_maximum_maximum(
 ; CHECK-NEXT:    [[MAX1:%.*]] = call half @llvm.maximum.f16(half [[X:%.*]], half [[Y:%.*]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call half @llvm.maximum.f16(half [[X]], half [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call half @llvm.maximum.f16(half [[MAX1]], half [[MAX2]])
-; CHECK-NEXT:    ret half [[VAL]]
+; CHECK-NEXT:    ret half [[MAX1]]
 ;
   %max1 = call half @llvm.maximum.f16(half %x, half %y)
   %max2 = call half @llvm.maximum.f16(half %x, half %y)
@@ -1257,10 +1251,8 @@
 
 define <2 x float> @minimum_maximum_minimum(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @minimum_maximum_minimum(
-; CHECK-NEXT:    [[MAX:%.*]] = call <2 x float> @llvm.maximum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[MAX]], <2 x float> [[MIN]])
-; CHECK-NEXT:    ret <2 x float> [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minimum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
+; CHECK-NEXT:    ret <2 x float> [[MIN]]
 ;
   %max = call <2 x float> @llvm.maximum.v2f32(<2 x float> %x, <2 x float> %y)
   %min = call <2 x float> @llvm.minimum.v2f32(<2 x float> %x, <2 x float> %y)
@@ -1270,10 +1262,8 @@
 
 define <2 x double> @minimum_minimum_maximum(<2 x double> %x, <2 x double> %y) {
 ; CHECK-LABEL: @minimum_minimum_maximum(
-; CHECK-NEXT:    [[MAX:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[X]], <2 x double> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[MIN]], <2 x double> [[MAX]])
-; CHECK-NEXT:    ret <2 x double> [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call <2 x double> @llvm.minimum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[MIN]]
 ;
   %max = call <2 x double> @llvm.maximum.v2f64(<2 x double> %x, <2 x double> %y)
   %min = call <2 x double> @llvm.minimum.v2f64(<2 x double> %x, <2 x double> %y)
@@ -1297,9 +1287,7 @@
 define float @minimum_minimum_minimum(float %x, float %y) {
 ; CHECK-LABEL: @minimum_minimum_minimum(
 ; CHECK-NEXT:    [[MIN1:%.*]] = call float @llvm.minimum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call float @llvm.minimum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.minimum.f32(float [[MIN1]], float [[MIN2]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MIN1]]
 ;
   %min1 = call float @llvm.minimum.f32(float %x, float %y)
   %min2 = call float @llvm.minimum.f32(float %x, float %y)
@@ -1310,9 +1298,7 @@
 define double @maxnum_maxnum_minnum(double %x, double %y) {
 ; CHECK-LABEL: @maxnum_maxnum_minnum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maxnum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.maxnum.f64(double [[MAX]], double [[MIN]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    ret double [[MAX]]
 ;
   %max = call double @llvm.maxnum.f64(double %x, double %y)
   %min = call double @llvm.minnum.f64(double %x, double %y)
@@ -1323,9 +1309,7 @@
 define <2 x float> @maxnum_minnum_maxnum(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @maxnum_minnum_maxnum(
 ; CHECK-NEXT:    [[MAX:%.*]] = call <2 x float> @llvm.maxnum.v2f32(<2 x float> [[X:%.*]], <2 x float> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call <2 x float> @llvm.minnum.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x float> @llvm.maxnum.v2f32(<2 x float> [[MIN]], <2 x float> [[MAX]])
-; CHECK-NEXT:    ret <2 x float> [[VAL]]
+; CHECK-NEXT:    ret <2 x float> [[MAX]]
 ;
   %max = call <2 x float> @llvm.maxnum.v2f32(<2 x float> %x, <2 x float> %y)
   %min = call <2 x float> @llvm.minnum.v2f32(<2 x float> %x, <2 x float> %y)
@@ -1349,9 +1333,7 @@
 define float @maxnum_maxnum_maxnum(float %x, float %y) {
 ; CHECK-LABEL: @maxnum_maxnum_maxnum(
 ; CHECK-NEXT:    [[MAX1:%.*]] = call float @llvm.maxnum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MAX2:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.maxnum.f32(float [[MAX1]], float [[MAX2]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    ret float [[MAX1]]
 ;
   %max1 = call float @llvm.maxnum.f32(float %x, float %y)
   %max2 = call float @llvm.maxnum.f32(float %x, float %y)
@@ -1361,10 +1343,8 @@
 
 define double @minnum_maxnum_minnum(double %x, double %y) {
 ; CHECK-LABEL: @minnum_maxnum_minnum(
-; CHECK-NEXT:    [[MAX:%.*]] = call double @llvm.maxnum.f64(double [[X:%.*]], double [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X]], double [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call double @llvm.minnum.f64(double [[MAX]], double [[MIN]])
-; CHECK-NEXT:    ret double [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call double @llvm.minnum.f64(double [[X:%.*]], double [[Y:%.*]])
+; CHECK-NEXT:    ret double [[MIN]]
 ;
   %max = call double @llvm.maxnum.f64(double %x, double %y)
   %min = call double @llvm.minnum.f64(double %x, double %y)
@@ -1374,10 +1354,8 @@
 
 define float @minnum_minnum_maxnum(float %x, float %y) {
 ; CHECK-LABEL: @minnum_minnum_maxnum(
-; CHECK-NEXT:    [[MAX:%.*]] = call float @llvm.maxnum.f32(float [[X:%.*]], float [[Y:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call float @llvm.minnum.f32(float [[MIN]], float [[MAX]])
-; CHECK-NEXT:    ret float [[VAL]]
+; CHECK-NEXT:    [[MIN:%.*]] = call float @llvm.minnum.f32(float [[X:%.*]], float [[Y:%.*]])
+; CHECK-NEXT:    ret float [[MIN]]
 ;
   %max = call float @llvm.maxnum.f32(float %x, float %y)
   %min = call float @llvm.minnum.f32(float %x, float %y)
@@ -1401,9 +1379,7 @@
 define <2 x double> @minnum_minnum_minmum(<2 x double> %x, <2 x double> %y) {
 ; CHECK-LABEL: @minnum_minnum_minmum(
 ; CHECK-NEXT:    [[MIN1:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[X:%.*]], <2 x double> [[Y:%.*]])
-; CHECK-NEXT:    [[MIN2:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[X]], <2 x double> [[Y]])
-; CHECK-NEXT:    [[VAL:%.*]] = call <2 x double> @llvm.minnum.v2f64(<2 x double> [[MIN1]], <2 x double> [[MIN2]])
-; CHECK-NEXT:    ret <2 x double> [[VAL]]
+; CHECK-NEXT:    ret <2 x double> [[MIN1]]
 ;
   %min1 = call <2 x double> @llvm.minnum.v2f64(<2 x double> %x, <2 x double> %y)
   %min2 = call <2 x double> @llvm.minnum.v2f64(<2 x double> %x, <2 x double> %y)