Index: streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h =================================================================== --- streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h +++ streamexecutor/include/streamexecutor/platforms/cuda/CUDAPlatformDevice.h @@ -17,6 +17,13 @@ #include "streamexecutor/PlatformDevice.h" +#include "llvm/Support/Mutex.h" + +#include + +struct CUfunc_st; +struct CUmod_st; + namespace streamexecutor { namespace cuda { @@ -85,6 +92,8 @@ CUDAPlatformDevice(size_t DeviceIndex) : DeviceIndex(DeviceIndex) {} int DeviceIndex; + llvm::sys::Mutex Mutex; + std::map> LoadedModules; }; } // namespace cuda Index: streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp =================================================================== --- streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp +++ streamexecutor/lib/platforms/cuda/CUDAPlatformDevice.cpp @@ -90,7 +90,6 @@ Expected CUDAPlatformDevice::createKernel(const MultiKernelLoaderSpec &Spec) { - // TODO(jhen): Maybe first check loaded modules? if (!Spec.hasCUDAPTXInMemory()) return make_error("no CUDA code available to create kernel"); @@ -117,27 +116,26 @@ llvm::Twine(ComputeCapabilityMajor) + "." + llvm::Twine(ComputeCapabilityMinor)); - CUmodule Module; - if (CUresult Result = cuModuleLoadData(&Module, Code)) - return CUresultToError(Result, "cuModuleLoadData"); - - CUfunction Function; - if (CUresult Result = - cuModuleGetFunction(&Function, Module, Spec.getKernelName().c_str())) - return CUresultToError(Result, "cuModuleGetFunction"); - - // TODO(jhen): Should I save this function pointer in case someone asks for - // it again? - - // TODO(jhen): Should I save the module pointer so I can unload it when I - // destroy this device? 
+  // Loaded modules are cached in LoadedModules, keyed by the address of the
+  // code image, so an image that was already loaded is not passed to
+  // cuModuleLoadData again. The kernel function is resolved by name on every
+  // call: the cache key does not include the kernel name, so a function
+  // handle remembered for an image may belong to a different kernel than the
+  // one Spec names.
+  CUfunction Function = nullptr;
+  {
+    // The lock covers both the lookup and the insertion below.
+    llvm::sys::ScopedLock Lock(Mutex);
+    auto Iterator = LoadedModules.find(Code);
+    if (Iterator == LoadedModules.end()) {
+      CUmodule Module = nullptr;
+      if (CUresult Result = cuModuleLoadData(&Module, Code))
+        return CUresultToError(Result, "cuModuleLoadData");
+      if (CUresult Result = cuModuleGetFunction(
+              &Function, Module, Spec.getKernelName().c_str())) {
+        // The module has not been cached yet, so it would be leaked if we
+        // returned without unloading it. The unload result is deliberately
+        // ignored so the original lookup failure is what gets reported.
+        cuModuleUnload(Module);
+        return CUresultToError(Result, "cuModuleGetFunction");
+      }
+      // The function handle is stored alongside the module as before, but
+      // cache hits no longer read it; they look the kernel up by name.
+      LoadedModules.emplace(Code, std::make_pair(Module, Function));
+    } else if (CUresult Result =
+                   cuModuleGetFunction(&Function, Iterator->second.first,
+                                       Spec.getKernelName().c_str()))
+      return CUresultToError(Result, "cuModuleGetFunction");
+  }
   return static_cast<const void *>(Function);
 }
 
 Error CUDAPlatformDevice::destroyKernel(const void *Handle) {
-  // TODO(jhen): Maybe keep track of kernels for each module and unload the
-  // module after they are all destroyed.
   return Error::success();
 }