diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1401,13 +1401,6 @@ return nullptr; } - /// Returns the function called and removes the pointer casting around it - Function *getCalledFunctionRemovingPtrCasts() const { - if (auto *Operand = getCalledOperand()) - return dyn_cast_or_null(Operand->stripPointerCasts()); - return nullptr; - } - /// Return true if the callsite is an indirect call. bool isIndirectCall() const; diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp --- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp +++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp @@ -40,6 +40,7 @@ #include #include #define NVVM_REFLECT_FUNCTION "__nvvm_reflect" +#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl" using namespace llvm; @@ -78,7 +79,8 @@ if (!NVVMReflectEnabled) return false; - if (F.getName() == NVVM_REFLECT_FUNCTION) { + if (F.getName() == NVVM_REFLECT_FUNCTION || + F.getName() == NVVM_REFLECT_OCL_FUNCTION) { assert(F.isDeclaration() && "_reflect function should not have a body"); assert(F.getReturnType()->isIntegerTy() && "_reflect's return type should be integer"); @@ -118,9 +120,8 @@ if (!Call) continue; Function *Callee = Call->getCalledFunction(); - if (!Callee) - Callee = Call->getCalledFunctionRemovingPtrCasts(); if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION && + Callee->getName() != NVVM_REFLECT_OCL_FUNCTION && Callee->getIntrinsicID() != Intrinsic::nvvm_reflect)) continue; diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-ocl.ll rename from llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll rename to llvm/test/CodeGen/NVPTX/nvvm-reflect-ocl.ll --- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll +++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-ocl.ll @@ -8,12 +8,12 @@ @"$str" = private addrspace(4) constant [12 x i8] c"__CUDA_ARCH\00" -declare i32 @__nvvm_reflect(i8*) +declare i32 @__nvvm_reflect_ocl(i8 addrspace(4)* noundef) ; COMMON-LABEL: @foo define i32 @foo(float %a, float %b) { -; COMMON-NOT: call i32 @__nvvm_reflect - %reflect = tail call i32 bitcast (i32 (i8*)* @__nvvm_reflect to i32 (i8 addrspace(4)*)*)(i8 addrspace(4)* noundef getelementptr inbounds ([12 x i8], [12 x i8] addrspace(4)* @"$str", i64 0, i64 0)) +; COMMON-NOT: call i32 @__nvvm_reflect_ocl + %reflect = tail call i32 @__nvvm_reflect_ocl(i8 addrspace(4)* noundef getelementptr inbounds ([12 x i8], [12 x i8] addrspace(4)* @"$str", i64 0, i64 0)) ; SM20: ret i32 200 ; SM35: ret i32 350 ret i32 %reflect