diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2298,22 +2298,34 @@
   if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy())
     return nullptr;
 
-  // We need 0 or all-1's bitmasks.
-  if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits())
-    return nullptr;
-
-  // If B is the 'not' value of A, we have our answer.
+  // If A is the 'not' operand of B and has enough signbits, we have our answer.
   if (match(B, m_Not(m_Specific(A)))) {
     // If these are scalars or vectors of i1, A can be used directly.
     if (Ty->isIntOrIntVectorTy(1))
       return A;
-    return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty));
+
+    // If we look through a vector bitcast, the caller will bitcast the operands
+    // to match the condition's number of bits (N x i1).
+    // To make this poison-safe, disallow bitcast from wide element to narrow
+    // element. That could allow poison in lanes where it was not present in the
+    // original code.
+    A = peekThroughBitcast(A);
+    // The bitcast source may be a non-integer (e.g. FP) vector. Sign-bit
+    // analysis and the trunc below are only valid for integer types.
+    if (A->getType()->isIntOrIntVectorTy()) {
+      unsigned NumSignBits = ComputeNumSignBits(A);
+      if (NumSignBits == A->getType()->getScalarSizeInBits() &&
+          NumSignBits <= Ty->getScalarSizeInBits())
+        return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType()));
+    }
+    return nullptr;
   }
 
   // If both operands are constants, see if the constants are inverse bitmasks.
   Constant *AConst, *BConst;
   if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
-    if (AConst == ConstantExpr::getNot(BConst))
+    if (AConst == ConstantExpr::getNot(BConst) &&
+        ComputeNumSignBits(A) == Ty->getScalarSizeInBits())
       return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty));
 
   // Look for more complex patterns. The 'not' op may be hidden behind various
@@ -2357,10 +2369,21 @@
   B = peekThroughBitcast(B, true);
   if (Value *Cond = getSelectCondition(A, B)) {
     // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
+    // If this is a vector, we may need to cast to match the condition's length.
     // The bitcasts will either all exist or all not exist. The builder will
     // not create unnecessary casts if the types already match.
-    Value *BitcastC = Builder.CreateBitCast(C, A->getType());
-    Value *BitcastD = Builder.CreateBitCast(D, A->getType());
+    Type *SelTy = A->getType();
+    if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) {
+      // For a fixed or scalable vector, get N from <{vscale x} N x iM>.
+      unsigned Elts = VecTy->getElementCount().getKnownMinValue();
+      // Use the known-minimum bit width so scalable vectors do not trip the
+      // implicit TypeSize -> integer conversion; for fixed types it is exact.
+      unsigned SelBits = SelTy->getPrimitiveSizeInBits().getKnownMinValue();
+      Type *EltTy = Builder.getIntNTy(SelBits / Elts);
+      SelTy = VectorType::get(EltTy, VecTy->getElementCount());
+    }
+    Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+    Value *BitcastD = Builder.CreateBitCast(D, SelTy);
     Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
     return Builder.CreateBitCast(Select, OrigType);
   }
diff --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll
--- a/llvm/test/Transforms/InstCombine/logical-select.ll
+++ b/llvm/test/Transforms/InstCombine/logical-select.ll
@@ -682,15 +682,15 @@
   ret <4 x i32> %sel
 }
 
+; Bitcast of condition from narrow source element type can be converted to select.
+
 define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond(
-; CHECK-NEXT:    [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i64> [[T9]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i64> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP4]]
 ;
   %s = sext <16 x i1> %cond to <16 x i8>
   %t9 = bitcast <16 x i8> %s to <2 x i64>
@@ -701,6 +701,8 @@
   ret <2 x i64> %r
 }
 
+; Negative test - bitcast of condition from wide source element type cannot be converted to select.
+
 define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute1(
 ; CHECK-NEXT:    [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]]
@@ -726,13 +728,11 @@
 ; CHECK-LABEL: @bitcast_vec_cond_commute2(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
@@ -745,17 +745,18 @@
   ret <2 x i16> %r
 }
 
+; Condition doesn't have to be a bool vector - just all sign bits.
+
 define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute3(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], <i8 7, i8 7, i8 7, i8 7>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], <i8 -1, i8 -1, i8 -1, i8 -1>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
@@ -68,15 +68,9 @@
 define <2 x i64> @abs_v4i32(<2 x i64> %x) {
 ; CHECK-LABEL: @abs_v4i32(
 ; CHECK-NEXT:    [[T1_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[SUB_I:%.*]] = sub <4 x i32> zeroinitializer, [[T1_I]]
-; CHECK-NEXT:    [[T1_I_LOBIT:%.*]] = ashr <4 x i32> [[T1_I]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[T1_I_LOBIT]] to <2 x i64>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = xor <2 x i64> [[TMP1]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I1:%.*]] = and <4 x i32> [[T1_I_LOBIT]], [[SUB_I]]
-; CHECK-NEXT:    [[AND_I_I:%.*]] = bitcast <4 x i32> [[AND_I_I1]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[T1_I]], i1 false)
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @_mm_set1_epi32(i32 -1)
   %call1 = call <2 x i64> @_mm_setzero_si128()
@@ -90,13 +84,9 @@
 ; CHECK-NEXT:    [[T0_I_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[T1_I_I:%.*]] = bitcast <2 x i64> [[Y:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I_I:%.*]] = icmp sgt <4 x i32> [[T0_I_I]], [[T1_I_I]]
-; CHECK-NEXT:    [[SEXT_I_I:%.*]] = sext <4 x i1> [[CMP_I_I]] to <4 x i32>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = bitcast <4 x i32> [[SEXT_I_I]] to <2 x i64>
-; CHECK-NEXT:    [[NEG_I_I:%.*]] = xor <2 x i64> [[T2_I_I]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I:%.*]] = and <2 x i64> [[NEG_I_I]], [[Y]]
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select <4 x i1> [[CMP_I_I]], <4 x i32> [[T0_I_I]], <4 x i32> [[T1_I_I]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @cmpgt_i32_sel_m128i(<2 x i64> %x, <2 x i64> %y, <2 x i64> %y, <2 x i64> %x)
   ret <2 x i64> %call