diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -390,8 +390,10 @@ class VPIntrinsic : public IntrinsicInst { public: /// \brief Declares a llvm.vp.* intrinsic in \p M that matches the parameters - /// \p Params. + /// \p Params. Additionally, the load and gather intrinsics require + /// \p ReturnType to be specified. static Function *getDeclarationForParams(Module *M, Intrinsic::ID, + Type *ReturnType, ArrayRef Params); static Optional getMaskParamPos(Intrinsic::ID IntrinsicID); diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -468,6 +468,7 @@ } Function *VPIntrinsic::getDeclarationForParams(Module *M, Intrinsic::ID VPID, + Type *ReturnType, ArrayRef Params) { assert(isVPIntrinsic(VPID) && "not a VP intrinsic"); Function *VPFunc; @@ -486,22 +487,15 @@ break; case Intrinsic::vp_load: VPFunc = Intrinsic::getDeclaration( - M, VPID, - {Params[0]->getType()->getPointerElementType(), Params[0]->getType()}); + M, VPID, {ReturnType, Params[0]->getType()}); break; case Intrinsic::vp_gather: VPFunc = Intrinsic::getDeclaration( - M, VPID, - {VectorType::get(cast(Params[0]->getType()) - ->getElementType() - ->getPointerElementType(), - cast(Params[0]->getType())), - Params[0]->getType()}); + M, VPID, {ReturnType, Params[0]->getType()}); break; case Intrinsic::vp_store: VPFunc = Intrinsic::getDeclaration( - M, VPID, - {Params[1]->getType()->getPointerElementType(), Params[1]->getType()}); + M, VPID, {Params[0]->getType(), Params[1]->getType()}); break; case Intrinsic::vp_scatter: VPFunc = Intrinsic::getDeclaration( diff --git a/llvm/unittests/IR/VPIntrinsicTest.cpp b/llvm/unittests/IR/VPIntrinsicTest.cpp --- a/llvm/unittests/IR/VPIntrinsicTest.cpp +++ b/llvm/unittests/IR/VPIntrinsicTest.cpp @@ -272,7 +272,7 @@ ASSERT_NE(F.getIntrinsicID(), Intrinsic::not_intrinsic); auto *NewDecl = VPIntrinsic::getDeclarationForParams( - OutM.get(), F.getIntrinsicID(), Values); + OutM.get(), F.getIntrinsicID(), FuncTy->getReturnType(), Values); ASSERT_TRUE(NewDecl); // Check that 'old decl' == 'new decl'.