Index: llvm/trunk/lib/Analysis/ValueTracking.cpp =================================================================== --- llvm/trunk/lib/Analysis/ValueTracking.cpp +++ llvm/trunk/lib/Analysis/ValueTracking.cpp @@ -2512,26 +2512,41 @@ return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); case Instruction::ShuffleVector: { - // If the shuffle mask contains any undefined elements, that element of the - // result is undefined. Propagating information from a source operand may - // not be correct in that case, so just bail out. - if (cast(U)->getMask()->containsUndefElement()) - break; - - // If everything is undef, we can't say anything. This should be simplified. - Value *Op0 = U->getOperand(0), *Op1 = U->getOperand(1); - if (isa(Op0) && isa(Op1)) + // TODO: This is copied almost directly from the SelectionDAG version of + // ComputeNumSignBits. It would be better if we could share common + // code. If not, make sure that changes are translated to the DAG. + + // Collect the minimum number of sign bits that are shared by every vector + // element referenced by the shuffle. + auto *Shuf = cast(U); + int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); + int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); + APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); + for (int i = 0; i != NumMaskElts; ++i) { + int M = Shuf->getMaskValue(i); + assert(M < NumElts * 2 && "Invalid shuffle mask constant"); + // For undef elements, we don't know anything about the common state of + // the shuffle result. + if (M == -1) + return 1; + if (M < NumElts) + DemandedLHS.setBit(M % NumElts); + else + DemandedRHS.setBit(M % NumElts); + } + Tmp = std::numeric_limits::max(); + if (!!DemandedLHS) + Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q); + if (!!DemandedRHS) { + Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q); + Tmp = std::min(Tmp, Tmp2); + } + // If we don't know anything, early out and try computeKnownBits fall-back. + if (Tmp == 1) break; - - // Look through shuffle of 1 source vector. - if (isa(Op0)) - return ComputeNumSignBits(Op1, Depth + 1, Q); - if (isa(Op1)) - return ComputeNumSignBits(Op0, Depth + 1, Q); - - // TODO: We can look through shuffles of 2 sources by computing the minimum - // sign bits for each operand (similar to what we do for binops). - break; + assert(Tmp <= V->getType()->getScalarSizeInBits() && + "Failed to determine minimum sign bits"); + return Tmp; } } Index: llvm/trunk/test/Transforms/InstCombine/logical-select.ll =================================================================== --- llvm/trunk/test/Transforms/InstCombine/logical-select.ll +++ llvm/trunk/test/Transforms/InstCombine/logical-select.ll @@ -621,11 +621,9 @@ ; CHECK-NEXT: [[SEXT1:%.*]] = sext <4 x i1> [[COND1:%.*]] to <4 x i32> ; CHECK-NEXT: [[SEXT2:%.*]] = sext <4 x i1> [[COND2:%.*]] to <4 x i32> ; CHECK-NEXT: [[COND:%.*]] = shufflevector <4 x i32> [[SEXT1]], <4 x i32> [[SEXT2]], <4 x i32> -; CHECK-NEXT: [[NOTCOND:%.*]] = xor <4 x i32> [[COND]], -; CHECK-NEXT: [[AND1:%.*]] = and <4 x i32> [[NOTCOND]], [[X:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = and <4 x i32> [[COND]], [[Y:%.*]] -; CHECK-NEXT: [[SEL:%.*]] = or <4 x i32> [[AND1]], [[AND2]] -; CHECK-NEXT: ret <4 x i32> [[SEL]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc <4 x i32> [[COND]] to <4 x i1> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> [[TMP1]], <4 x i32> [[Y:%.*]], <4 x i32> [[X:%.*]] +; CHECK-NEXT: ret <4 x i32> [[TMP2]] ; %sext1 = sext <4 x i1> %cond1 to <4 x i32> %sext2 = sext <4 x i1> %cond2 to <4 x i32> Index: llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp =================================================================== --- llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp +++ llvm/trunk/unittests/Analysis/ValueTrackingTest.cpp @@ -514,7 +514,6 @@ EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u); } -// FIXME: // No guarantees for canonical IR in this analysis, so a shuffle element that // references an undef value means this can't return any extra information. TEST(ValueTracking, ComputeNumSignBits_Shuffle2) { @@ -534,7 +533,7 @@ auto *RVal = cast(F->getEntryBlock().getTerminator())->getOperand(0); - EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 32u); + EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u); } TEST(ValueTracking, ComputeKnownBits) {