diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1468,6 +1468,46 @@
     }
   }
 
+  // (umax X, (xor X, Pow2))
+  //      -> (or X, Pow2)
+  // (umin X, (xor X, Pow2))
+  //      -> (and X, ~Pow2)
+  // (smax X, (xor X, Pos_Pow2))
+  //      -> (or X, Pos_Pow2)
+  // (smin X, (xor X, Pos_Pow2))
+  //      -> (and X, ~Pos_Pow2)
+  // (smax X, (xor X, Neg_Pow2))
+  //      -> (and X, ~Neg_Pow2)
+  // (smin X, (xor X, Neg_Pow2))
+  //      -> (or X, Neg_Pow2)
+  if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) ||
+       match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) &&
+      isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) {
+    bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax;
+    bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin;
+
+    if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
+      auto KnownSign = getKnownSign(X, II, DL, &AC, &DT);
+      if (KnownSign == std::nullopt) {
+        UseOr = false;
+        UseAndN = false;
+      } else if (*KnownSign /* true is Signed. */) {
+        UseOr ^= true;
+        UseAndN ^= true;
+        Type *Ty = I0->getType();
+        // A negative power of 2 must be IntMin. It's possible to prove
+        // negative / power of 2 without actually having known bits, so just
+        // construct the value by hand.
+        X = Constant::getIntegerValue(
+            Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits()));
+      }
+    }
+    if (UseOr)
+      return BinaryOperator::CreateOr(I0, X);
+    else if (UseAndN)
+      return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X));
+  }
+
   // If we can eliminate ~A and Y is free to invert:
   // max ~A, Y --> ~(min A, ~Y)
   //
diff --git a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
--- a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
@@ -15,8 +15,7 @@
 
 define <2 x i8> @umax_xor_Cpow2(<2 x i8> %x) {
 ; CHECK-LABEL: @umax_xor_Cpow2(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]],
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i8> [[X:%.*]],
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x,
@@ -26,8 +25,7 @@
 
 define i8 @umin_xor_Cpow2(i8 %x) {
 ; CHECK-LABEL: @umin_xor_Cpow2(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], 64
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[X:%.*]], -65
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 64
@@ -37,8 +35,7 @@
 
 define i8 @smax_xor_Cpow2_pos(i8 %x) {
 ; CHECK-LABEL: @smax_xor_Cpow2_pos(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], 32
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[X:%.*]], 32
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 32
@@ -48,8 +45,7 @@
 
 define <2 x i8> @smin_xor_Cpow2_pos(<2 x i8> %x) {
 ; CHECK-LABEL: @smin_xor_Cpow2_pos(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]],
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[X:%.*]],
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x,
@@ -59,8 +55,7 @@
 
 define <2 x i8> @smax_xor_Cpow2_neg(<2 x i8> %x) {
 ; CHECK-LABEL: @smax_xor_Cpow2_neg(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]],
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[X:%.*]],
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x,
@@ -70,8 +65,7 @@
 
 define i8 @smin_xor_Cpow2_neg(i8 %x) {
 ; CHECK-LABEL: @smin_xor_Cpow2_neg(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], -128
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[X:%.*]], -128
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 128
@@ -83,8 +77,7 @@
 ; CHECK-LABEL: @umax_xor_pow2(
 ; CHECK-NEXT:    [[NY:%.*]] = sub i8 0, [[Y:%.*]]
 ; CHECK-NEXT:    [[YP2:%.*]] = and i8 [[NY]], [[Y]]
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[YP2]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %ny = sub i8 0, %y
@@ -98,8 +91,8 @@
 ; CHECK-LABEL: @umin_xor_pow2(
 ; CHECK-NEXT:    [[NY:%.*]] = sub <2 x i8> zeroinitializer, [[Y:%.*]]
 ; CHECK-NEXT:    [[YP2:%.*]] = and <2 x i8> [[NY]], [[Y]]
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i8> [[YP2]],
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %ny = sub <2 x i8> , %y
@@ -146,8 +139,7 @@
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[YP2]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
 ; CHECK:       neg:
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[X:%.*]], 127
 ; CHECK-NEXT:    ret i8 [[R]]
 ; CHECK:       pos:
 ; CHECK-NEXT:    call void @barrier()
@@ -173,8 +165,8 @@
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[YP2]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
 ; CHECK:       neg:
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i8 [[YP2]], -1
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ; CHECK:       pos:
 ; CHECK-NEXT:    call void @barrier()
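
The six rewrites implemented above are small enough to verify exhaustively at i8. The standalone program below (an illustrative sketch, not part of the patch) brute-forces every i8 value against every power-of-two constant and mirrors the UseOr/UseAndN selection, including the or/and-not role swap for the one negative power of two (IntMin):

// Exhaustive check of the six min/max-of-xor-with-pow2 rewrites at i8.
// Standalone sketch for illustration only -- not part of the patch.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  for (int Xi = 0; Xi < 256; ++Xi) {
    uint8_t UX = (uint8_t)Xi;
    int8_t SX = (int8_t)Xi;
    for (int Shift = 0; Shift < 8; ++Shift) {
      uint8_t C = (uint8_t)(1u << Shift); // The power-of-2 constant.
      int8_t SC = (int8_t)C;
      // umax(X, X ^ C) == X | C and umin(X, X ^ C) == X & ~C.
      if (std::max<uint8_t>(UX, UX ^ C) != (uint8_t)(UX | C) ||
          std::min<uint8_t>(UX, UX ^ C) != (uint8_t)(UX & ~C)) {
        printf("unsigned fold failed: X=%d C=%d\n", UX, C);
        return 1;
      }
      if (SC > 0) {
        // Positive pow2: smax -> or, smin -> and-not, same as unsigned.
        if (std::max<int8_t>(SX, SX ^ SC) != (int8_t)(SX | SC) ||
            std::min<int8_t>(SX, SX ^ SC) != (int8_t)(SX & ~SC)) {
          printf("signed pos fold failed: X=%d C=%d\n", SX, SC);
          return 1;
        }
      } else {
        // The only negative pow2 is IntMin (-128); or and and-not swap roles.
        if (std::max<int8_t>(SX, SX ^ SC) != (int8_t)(SX & ~SC) ||
            std::min<int8_t>(SX, SX ^ SC) != (int8_t)(SX | SC)) {
          printf("signed neg fold failed: X=%d C=%d\n", SX, SC);
          return 1;
        }
      }
    }
  }
  puts("all min/max-of-xor folds verified for i8");
  return 0;
}

The KnownSign == std::nullopt bail-out in the fold has no analogue in this check: when the constant is enumerated directly its sign is always known, whereas InstCombine must give up when it cannot prove the sign of the xor'ed power of two.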