Index: lib/Target/NVPTX/NVPTXLowerKernelArgs.cpp
===================================================================
--- lib/Target/NVPTX/NVPTXLowerKernelArgs.cpp
+++ lib/Target/NVPTX/NVPTXLowerKernelArgs.cpp
@@ -47,6 +47,36 @@
 //      ...
 //    }
 //
+// 3. Convert pointers in a byval kernel parameter to pointers in the global
+//    address space. Like #2, this allows NVPTX to emit more ld/st.global. E.g.,
+//
+//    struct S {
+//      int *x;
+//      int *y;
+//    };
+//    __global__ void foo(S s) {
+//      int *b = s.y;
+//      // use b
+//    }
+//
+//    "b" points to the global address space. At the IR level,
+//
+//    define void @foo({i32*, i32*}* byval %input) {
+//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
+//      %b = load i32*, i32** %b_ptr
+//      ; use %b
+//    }
+//
+//    becomes
+//
+//    define void @foo({i32*, i32*}* byval %input) {
+//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
+//      %b = load i32*, i32** %b_ptr
+//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
+//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+//      ; use %b_generic
+//    }
+//
 // TODO: merge this pass with NVPTXFavorNonGenericAddrSpace so that other passes
 // don't cancel the addrspacecast pair this pass emits.
 //===----------------------------------------------------------------------===//
@@ -54,6 +84,7 @@
 #include "NVPTX.h"
 #include "NVPTXUtilities.h"
 #include "NVPTXTargetMachine.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -71,9 +102,12 @@
   bool runOnFunction(Function &F) override;
 
   // handle byval parameters
-  void handleByValParam(Argument *);
-  // handle non-byval pointer parameters
-  void handlePointerParam(Argument *);
+  void handleByValParam(Argument *Arg);
+  // Knowing Ptr must point to the global address space, this function
+  // addrspacecasts Ptr to global and then back to generic. This allows
+  // NVPTXFavorNonGenericAddrSpace to fold the global-to-generic cast into
+  // later loads/stores. Non-generic pointers are left untouched.
+  void markPointerAsGlobal(Value *Ptr);
 
 public:
   static char ID; // Pass identification, replacement for typeid
@@ -128,27 +162,34 @@
   new StoreInst(LI, AllocA, FirstInst);
 }
 
-void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
-  assert(!Arg->hasByValAttr() &&
-         "byval params should be handled by handleByValParam");
-
-  // Do nothing if the argument already points to the global address space.
-  if (Arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
+void NVPTXLowerKernelArgs::markPointerAsGlobal(Value *Ptr) {
+  if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
     return;
 
-  Instruction *FirstInst = Arg->getParent()->getEntryBlock().begin();
-  Instruction *ArgInGlobal = new AddrSpaceCastInst(
-      Arg, PointerType::get(Arg->getType()->getPointerElementType(),
+  // Decide where to emit the addrspacecast pair.
+  BasicBlock::iterator InsertPt;
+  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
+    // Insert at the function entry if Ptr is an argument.
+    InsertPt = Arg->getParent()->getEntryBlock().begin();
+  } else {
+    // Insert right after Ptr if Ptr is an instruction.
+    assert(!isa<TerminatorInst>(Ptr) &&
+           "We don't call this function with Ptr being a terminator.");
+    InsertPt = cast<Instruction>(Ptr);
+    ++InsertPt;
+  }
+
+  Instruction *PtrInGlobal = new AddrSpaceCastInst(
+      Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
                             ADDRESS_SPACE_GLOBAL),
-      Arg->getName(), FirstInst);
-  Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
-                                              Arg->getName(), FirstInst);
-  // Replace with ArgInGeneric all uses of Args except ArgInGlobal.
-  Arg->replaceAllUsesWith(ArgInGeneric);
-  ArgInGlobal->setOperand(0, Arg);
+      Ptr->getName(), InsertPt);
+  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
+                                              Ptr->getName(), InsertPt);
+  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
+  Ptr->replaceAllUsesWith(PtrInGeneric);
+  PtrInGlobal->setOperand(0, Ptr);
 }
 
-
 // =============================================================================
 // Main function for this pass.
 // =============================================================================
@@ -157,12 +198,32 @@
   if (!isKernelFunction(F))
     return false;
 
+  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
+    // Mark pointers in byval structs as global.
+    for (auto &B : F) {
+      for (auto &I : B) {
+        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+          if (LI->getType()->isPointerTy()) {
+            Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
+                                            F.getParent()->getDataLayout());
+            if (Argument *Arg = dyn_cast<Argument>(UO)) {
+              if (Arg->hasByValAttr()) {
+                // LI is a load from a pointer within a byval kernel parameter.
+                markPointerAsGlobal(LI);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   for (Argument &Arg : F.args()) {
     if (Arg.getType()->isPointerTy()) {
       if (Arg.hasByValAttr())
         handleByValParam(&Arg);
       else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
-        handlePointerParam(&Arg);
+        markPointerAsGlobal(&Arg);
     }
   }
   return true;
Index: test/CodeGen/NVPTX/lower-kernel-ptr-arg.ll
===================================================================
--- test/CodeGen/NVPTX/lower-kernel-ptr-arg.ll
+++ test/CodeGen/NVPTX/lower-kernel-ptr-arg.ll
@@ -1,7 +1,7 @@
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
 
 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
-target triple = "nvptx64-unknown-unknown"
+target triple = "nvptx64-nvidia-cuda"
 
 ; Verify that both %input and %output are converted to global pointers and then
 ; addrspacecast'ed back to the original type.
@@ -26,6 +26,22 @@
   ret void
 }
 
-!nvvm.annotations = !{!0, !1}
+%struct.S = type { i32*, i32* }
+
+define void @ptr_in_byval(%struct.S* byval %input, i32* %output) {
+; CHECK-LABEL: .visible .entry ptr_in_byval(
+; CHECK: cvta.to.global.u64
+; CHECK: cvta.to.global.u64
+  %b_ptr = getelementptr inbounds %struct.S, %struct.S* %input, i64 0, i32 1
+  %b = load i32*, i32** %b_ptr, align 4
+  %v = load i32, i32* %b, align 4
+; CHECK: ld.global.u32
+  store i32 %v, i32* %output, align 4
+; CHECK: st.global.u32
+  ret void
+}
+
+!nvvm.annotations = !{!0, !1, !2}
 !0 = !{void (float*, float*)* @kernel, !"kernel", i32 1}
 !1 = !{void (float addrspace(1)*, float addrspace(1)*)* @kernel2, !"kernel", i32 1}
+!2 = !{void (%struct.S*, i32*)* @ptr_in_byval, !"kernel", i32 1}