diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -73,6 +73,14 @@
               __func__, __VA_ARGS__);                                          \
   } while (0)
 
+// Returns the CUdevice handle for the ordinal stored in `defaultDevice`.
+// The CUDA driver API must already be initialized (cuInit) by the caller.
+static CUdevice getDefaultCuDevice() {
+  CUdevice device;
+  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+  return device;
+}
+
 // Make the primary context of the current default device current for the
 // duration
 // of the instance and restore the previous context on destruction.
@@ -83,11 +91,10 @@
     // defaultDevice.
     static CUcontext context = [] {
       CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
-      CUdevice device;
-      CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
       CUcontext ctx;
       // Note: this does not affect the current context.
-      CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
+      CUDA_REPORT_IF_ERROR(
+          cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
       return ctx;
     }();
 
@@ -140,6 +147,29 @@
                                  intptr_t blockZ, int32_t smem, CUstream stream,
                                  void **params, void **extra) {
   ScopedContext scopedContext;
+  if (smem > 0) {
+    // Only consult the driver when dynamic shared memory is requested; this
+    // avoids extra driver calls on every launch that uses none.
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%d bytes) is larger than maximum "
+              "allowed shared memory (%d bytes) for this device\n",
+              smem, maxShmem);
+    }
+    // Kernels must opt in to dynamic shared memory beyond the default
+    // per-block limit before launch.
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+  }
+  debug_print("Launching kernel, grid=%ld,%ld,%ld, "
+              "threads: %ld, %ld, %ld, "
+              "smem: %d bytes\n",
+              gridX, gridY, gridZ, blockX, blockY, blockZ, smem);
   CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                       blockY, blockZ, smem, stream, params,
                                       extra));