Index: llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2298,22 +2298,30 @@ if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy()) return nullptr; - // We need 0 or all-1's bitmasks. - if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits()) - return nullptr; - - // If B is the 'not' value of A, we have our answer. + // If A is the 'not' operand of B and has enough signbits, we have our answer. if (match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. if (Ty->isIntOrIntVectorTy(1)) return A; - return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty)); + + // If we look through a vector bitcast, the caller will bitcast the operands + // to match the condition's number of bits (N x i1). + // To make this poison-safe, disallow bitcast from wide element to narrow + // element. That could allow poison in lanes where it was not present in the + // original code. + A = peekThroughBitcast(A); + unsigned NumSignBits = ComputeNumSignBits(A); + if (NumSignBits == A->getType()->getScalarSizeInBits() && + NumSignBits <= Ty->getScalarSizeInBits()) + return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType())); + return nullptr; } // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) - if (AConst == ConstantExpr::getNot(BConst)) + if (AConst == ConstantExpr::getNot(BConst) && + ComputeNumSignBits(A) == Ty->getScalarSizeInBits()) return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty)); // Look for more complex patterns. The 'not' op may be hidden behind various @@ -2357,10 +2365,17 @@ B = peekThroughBitcast(B, true); if (Value *Cond = getSelectCondition(A, B)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will // not create unnecessary casts if the types already match. - Value *BitcastC = Builder.CreateBitCast(C, A->getType()); - Value *BitcastD = Builder.CreateBitCast(D, A->getType()); + Type *SelTy = A->getType(); + if (auto *VecTy = dyn_cast(Cond->getType())) { + unsigned Elts = VecTy->getNumElements(); + Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts); + SelTy = FixedVectorType::get(EltTy, Elts); + } + Value *BitcastC = Builder.CreateBitCast(C, SelTy); + Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); } Index: llvm/test/Transforms/InstCombine/logical-select.ll =================================================================== --- llvm/test/Transforms/InstCombine/logical-select.ll +++ llvm/test/Transforms/InstCombine/logical-select.ll @@ -682,15 +682,15 @@ ret <4 x i32> %sel } +; Bitcast of condition from narrow source element type can be converted to select. + define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) { ; CHECK-LABEL: @bitcast_vec_cond( -; CHECK-NEXT: [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8> -; CHECK-NEXT: [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64> -; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i64> [[T9]], -; CHECK-NEXT: [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]] -; CHECK-NEXT: [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]] -; CHECK-NEXT: [[R:%.*]] = or <2 x i64> [[T11]], [[T12]] -; CHECK-NEXT: ret <2 x i64> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8> +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64> +; CHECK-NEXT: ret <2 x i64> [[TMP4]] ; %s = sext <16 x i1> %cond to <16 x i8> %t9 = bitcast <16 x i8> %s to <2 x i64> @@ -701,6 +701,8 @@ ret <2 x i64> %r } +; Negative test - bitcast of condition from wide source element type cannot be converted to select. + define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) { ; CHECK-LABEL: @bitcast_vec_cond_commute1( ; CHECK-NEXT: [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]] @@ -726,13 +728,11 @@ ; CHECK-LABEL: @bitcast_vec_cond_commute2( ; CHECK-NEXT: [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]] ; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]] -; CHECK-NEXT: [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8> -; CHECK-NEXT: [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16> -; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i16> [[T9]], -; CHECK-NEXT: [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]] -; CHECK-NEXT: [[T12:%.*]] = and <2 x i16> [[D]], [[T9]] -; CHECK-NEXT: [[R:%.*]] = or <2 x i16> [[T11]], [[T12]] -; CHECK-NEXT: ret <2 x i16> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8> +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16> +; CHECK-NEXT: ret <2 x i16> [[TMP4]] ; %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization @@ -745,17 +745,18 @@ ret <2 x i16> %r } +; Condition doesn't have to be a bool vec - just all signbits. + define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) { ; CHECK-LABEL: @bitcast_vec_cond_commute3( ; CHECK-NEXT: [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]] ; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]] -; CHECK-NEXT: [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], -; CHECK-NEXT: [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16> -; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i16> [[T9]], -; CHECK-NEXT: [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]] -; CHECK-NEXT: [[T12:%.*]] = and <2 x i16> [[D]], [[T9]] -; CHECK-NEXT: [[R:%.*]] = or <2 x i16> [[T11]], [[T12]] -; CHECK-NEXT: ret <2 x i16> [[R]] +; CHECK-NEXT: [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8> +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16> +; CHECK-NEXT: ret <2 x i16> [[TMP4]] ; %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization