diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1401,6 +1401,17 @@
     return nullptr;
   }
 
+  /// Return the function called, looking through any pointer casts on the
+  /// called operand (e.g. a bitcast of the callee to a different function
+  /// pointer type), or null if the stripped operand is not a Function.
+  /// The call itself is not modified, and the returned function's type is
+  /// not checked against this call's function type.
+  Function *getCalledFunctionRemovingPtrCasts() const {
+    if (auto *Operand = getCalledOperand())
+      return dyn_cast<Function>(Operand->stripPointerCasts());
+    return nullptr;
+  }
+
   /// Return true if the callsite is an indirect call.
   bool isIndirectCall() const;
 
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -118,6 +118,10 @@
     if (!Call)
       continue;
     Function *Callee = Call->getCalledFunction();
+    // The callee may be wrapped in a pointer cast (e.g. a bitcast that changes
+    // the argument type), which getCalledFunction() does not look through.
+    if (!Callee)
+      Callee = Call->getCalledFunctionRemovingPtrCasts();
     if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
                     Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
       continue;
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-cast.ll
@@ -0,0 +1,21 @@
+; Verify that __nvvm_reflect() is replaced with an appropriate value when
+; wrapped in bitcast.
+;
+; RUN: opt %s -S -passes='default<O2>' -mtriple=nvptx64 \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,SM20
+; RUN: opt %s -S -passes='default<O2>' -mtriple=nvptx64 -mcpu=sm_35 \
+; RUN:   | FileCheck %s --check-prefixes=COMMON,SM35
+
+@"$str" = private addrspace(4) constant [12 x i8] c"__CUDA_ARCH\00"
+
+declare i32 @__nvvm_reflect(i8*)
+
+; COMMON-LABEL: @foo
+define i32 @foo(float %a, float %b) {
+; COMMON-NOT: call i32 @__nvvm_reflect
+  %reflect = tail call i32 bitcast (i32 (i8*)* @__nvvm_reflect to i32 (i8 addrspace(4)*)*)(i8 addrspace(4)* noundef getelementptr inbounds ([12 x i8], [12 x i8] addrspace(4)* @"$str", i64 0, i64 0))
+; SM20: ret i32 200
+; SM35: ret i32 350
+  ret i32 %reflect
+}
+