diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1588,6 +1588,30 @@
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     return CV->getSplatValue(AllowUndefs);
+
+  // Check if this is a constant expression splat of the form returned by
+  // ConstantVector::getSplat():
+  //   shufflevector (insertelement undef, V, 0), undef, zeroinitializer
+  // This is needed for scalable vectors, which ConstantVector and
+  // ConstantDataVector cannot represent.
+  const auto *Shuf = dyn_cast<ConstantExpr>(this);
+  if (Shuf && Shuf->getOpcode() == Instruction::ShuffleVector &&
+      isa<UndefValue>(Shuf->getOperand(1))) {
+
+    const auto *IElt = dyn_cast<ConstantExpr>(Shuf->getOperand(0));
+    if (IElt && IElt->getOpcode() == Instruction::InsertElement &&
+        isa<UndefValue>(IElt->getOperand(0))) {
+
+      // The scalar must be inserted into lane 0, and every lane of the
+      // shuffle must select that lane.
+      ArrayRef<int> Mask = Shuf->getShuffleMask();
+      const auto *Index = dyn_cast<ConstantInt>(IElt->getOperand(2));
+      if (Index && Index->isZero() &&
+          llvm::all_of(Mask, [](int I) { return I == 0; }))
+        return IElt->getOperand(1);
+    }
+  }
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstSimplify/vscale.ll b/llvm/test/Transforms/InstSimplify/vscale.ll
--- a/llvm/test/Transforms/InstSimplify/vscale.ll
+++ b/llvm/test/Transforms/InstSimplify/vscale.ll
@@ -95,6 +95,15 @@
   ret i32 %r
 }
 
+; more complicated expressions
+
+define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: @cmp_le_smax_always_true(
+; CHECK-NEXT:    ret <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+  %cmp = icmp sle <vscale x 2 x i64> %x, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 9223372036854775807, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer)
+  ret <vscale x 2 x i1> %cmp
+}
+
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; Memory Access and Addressing Operations
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -638,5 +638,34 @@
   EXPECT_FALSE(CP00U->isElementWiseEqual(CP00U0));
 }
 
+TEST(ConstantsTest, GetSplatValueRoundTrip) {
+  // getSplatValue() must invert getSplat() for scalable and fixed vectors.
+  LLVMContext Context;
+  Type *FloatTy = Type::getFloatTy(Context);
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  Type *Int8Ty = Type::getInt8Ty(Context);
+
+  for (unsigned Min : {1, 2, 8}) {
+    ElementCount SEC = {Min, true};
+    ElementCount FEC = {Min, false};
+
+    for (auto EC : {SEC, FEC}) {
+      for (auto *Ty : {FloatTy, Int32Ty, Int8Ty}) {
+        // Null and all-bits-set values (-1 for integers, a NaN for float).
+        Constant *Zero = Constant::getNullValue(Ty);
+        Constant *AllOnes = Constant::getAllOnesValue(Ty);
+
+        for (auto *C : {Zero, AllOnes}) {
+          Constant *Splat = ConstantVector::getSplat(EC, C);
+          ASSERT_NE(nullptr, Splat);
+          Constant *SplatVal = Splat->getSplatValue();
+          ASSERT_NE(nullptr, SplatVal);
+          EXPECT_EQ(C, SplatVal);
+        }
+      }
+    }
+  }
+}
+
 } // end anonymous namespace
 } // end namespace llvm