Make CUDAPTXInMemorySpec::getCode fall back to lower compute capabilities

CUDAPTXInMemorySpec::getCode now returns the PTX code for the greatest
available compute capability that does not exceed the requested one, instead
of requiring an exact match. It still returns nullptr when no such compute
capability is available.

NOTE(review): the `const llvm::ArrayRef SpecList);` context line in the
KernelSpec.h hunk appears to have lost its template argument list when this
patch was extracted (llvm::ArrayRef is a class template). Restore it from the
real header before applying; otherwise that hunk only applies with fuzz.

Index: streamexecutor/include/streamexecutor/KernelSpec.h
===================================================================
--- streamexecutor/include/streamexecutor/KernelSpec.h
+++ streamexecutor/include/streamexecutor/KernelSpec.h
@@ -121,12 +121,13 @@
       llvm::StringRef KernelName,
       const llvm::ArrayRef SpecList);
 
-  /// Returns a pointer to the PTX code for the requested compute capability.
+  /// Returns a pointer to the PTX code for the greatest available compute
+  /// capability that does not exceed the requested compute capability. For
+  /// example, if code is available for 1.0 and 3.0, requesting 2.0 returns the
+  /// 1.0 code.
   ///
-  /// Returns nullptr on failed lookup (if the requested compute capability is
-  /// not available). Matches exactly the specified compute capability. Doesn't
-  /// try to do anything smart like finding the next best compute capability if
-  /// the specified capability cannot be found.
+  /// Returns nullptr if no PTX code is available for the requested compute
+  /// capability or for any lower compute capability.
   const char *getCode(int ComputeCapabilityMajor,
                       int ComputeCapabilityMinor) const;
 
Index: streamexecutor/lib/KernelSpec.cpp
===================================================================
--- streamexecutor/lib/KernelSpec.cpp
+++ streamexecutor/lib/KernelSpec.cpp
@@ -31,12 +31,16 @@
 
 const char *CUDAPTXInMemorySpec::getCode(int ComputeCapabilityMajor,
                                          int ComputeCapabilityMinor) const {
-  auto PTXIter =
-      PTXByComputeCapability.find(CUDAPTXInMemorySpec::ComputeCapability{
+  // upper_bound yields the first entry greater than the request; the entry
+  // before it is the greatest one not exceeding the request. If upper_bound
+  // yields begin() (e.g. when there are no entries), no such entry exists.
+  auto Iterator =
+      PTXByComputeCapability.upper_bound(CUDAPTXInMemorySpec::ComputeCapability{
           ComputeCapabilityMajor, ComputeCapabilityMinor});
-  if (PTXIter == PTXByComputeCapability.end())
+  if (Iterator == PTXByComputeCapability.begin())
     return nullptr;
-  return PTXIter->second;
+  --Iterator;
+  return Iterator->second;
 }
 
 CUDAFatbinInMemorySpec::CUDAFatbinInMemorySpec(llvm::StringRef KernelName,
Index: streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
===================================================================
--- streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
+++ streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
@@ -30,8 +30,9 @@
   const char *PTXCodeString = "Dummy PTX code";
   se::CUDAPTXInMemorySpec Spec("KernelName", {{{1, 0}, PTXCodeString}});
   EXPECT_EQ("KernelName", Spec.getKernelName());
+  EXPECT_EQ(nullptr, Spec.getCode(0, 5));
   EXPECT_EQ(PTXCodeString, Spec.getCode(1, 0));
-  EXPECT_EQ(nullptr, Spec.getCode(2, 0));
+  EXPECT_EQ(PTXCodeString, Spec.getCode(2, 0));
 }
 
 TEST(CUDAPTXInMemorySpec, TwoComputeCapabilities) {
@@ -40,9 +41,12 @@
   se::CUDAPTXInMemorySpec Spec(
       "KernelName", {{{1, 0}, PTXCodeString10}, {{3, 0}, PTXCodeString30}});
   EXPECT_EQ("KernelName", Spec.getKernelName());
+  EXPECT_EQ(nullptr, Spec.getCode(0, 5));
   EXPECT_EQ(PTXCodeString10, Spec.getCode(1, 0));
   EXPECT_EQ(PTXCodeString30, Spec.getCode(3, 0));
-  EXPECT_EQ(nullptr, Spec.getCode(2, 0));
+  EXPECT_EQ(PTXCodeString10, Spec.getCode(1, 5));
+  EXPECT_EQ(PTXCodeString10, Spec.getCode(2, 0));
+  EXPECT_EQ(PTXCodeString30, Spec.getCode(4, 0));
 }
 
 TEST(CUDAFatbinInMemorySpec, BasicUsage) {
@@ -89,8 +93,9 @@
   EXPECT_TRUE(MultiSpec.hasOpenCLTextInMemory());
 
   EXPECT_EQ(KernelName, MultiSpec.getCUDAPTXInMemory().getKernelName());
+  EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(0, 5));
   EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(1, 0));
-  EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(2, 0));
+  EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(2, 0));
 
   EXPECT_EQ(KernelName, MultiSpec.getCUDAFatbinInMemory().getKernelName());
   EXPECT_EQ(FatbinBytes, MultiSpec.getCUDAFatbinInMemory().getBytes());