diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -286,12 +286,17 @@ if (!isa(Y) || !Ty->isVectorTy() || Ty != Y->getType()) return false; + // TODO: Compare pointer constants? + if (!(Ty->getVectorElementType()->isIntegerTy() || + Ty->getVectorElementType()->isFloatingPointTy())) + return false; + // They may still be identical element-wise (if they have `undef`s). - // FIXME: This crashes on FP vector constants. - return match(ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_EQ, - const_cast(this), - cast(Y)), - m_One()); + // Bitcast to integer to allow exact bitwise comparison for all types. + Type *IntTy = VectorType::getInteger(cast(Ty)); + Constant *C0 = ConstantExpr::getBitCast(const_cast(this), IntTy); + Constant *C1 = ConstantExpr::getBitCast(cast(Y), IntTy); + return match(ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1), m_One()); } bool Constant::containsUndefElement() const { diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -586,7 +586,8 @@ } // Check that undefined elements in vector constants are matched -// correctly for both integer and floating-point types. +// correctly for both integer and floating-point types. Just don't +// crash on vectors of pointers (could be handled?). TEST(ConstantsTest, isElementWiseEqual) { LLVMContext Context; @@ -607,7 +608,6 @@ EXPECT_FALSE(C12U1->isElementWiseEqual(C12U2)); EXPECT_FALSE(C12U21->isElementWiseEqual(C12U2)); -/* FIXME: This will crash. Type *FltTy = Type::getFloatTy(Context); Constant *CFU = UndefValue::get(FltTy); Constant *CF1 = ConstantFP::get(FltTy, 1.0); @@ -621,7 +621,19 @@ EXPECT_TRUE(CF12U1->isElementWiseEqual(CF1211)); EXPECT_FALSE(CF12U2->isElementWiseEqual(CF12U1)); EXPECT_FALSE(CF12U1->isElementWiseEqual(CF12U2)); -*/ + + PointerType *PtrTy = Type::getInt8PtrTy(Context); + Constant *CPU = UndefValue::get(PtrTy); + Constant *CP0 = ConstantPointerNull::get(PtrTy); + + Constant *CP0000 = ConstantVector::get({CP0, CP0, CP0, CP0}); + Constant *CP00U0 = ConstantVector::get({CP0, CP0, CPU, CP0}); + Constant *CP00U = ConstantVector::get({CP0, CP0, CPU}); + + EXPECT_FALSE(CP0000->isElementWiseEqual(CP00U0)); + EXPECT_FALSE(CP00U0->isElementWiseEqual(CP0000)); + EXPECT_FALSE(CP0000->isElementWiseEqual(CP00U)); + EXPECT_FALSE(CP00U->isElementWiseEqual(CP00U0)); } } // end anonymous namespace