[mlir] Build the GPU kernel-to-CUBIN conversion only when CUDA conversions are enabled

  * mlir/CMakeLists.txt: pass MLIR_CUDA_CONVERSIONS_ENABLED to C++ code as a
    preprocessor definition.
  * mlir/include/mlir/InitAllPasses.h: reference
    createConvertGPUKernelToCubinPass only when MLIR_CUDA_CONVERSIONS_ENABLED
    evaluates to non-zero in the preprocessor.
  * mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt: always compile
    ConvertLaunchFuncToCudaCalls.cpp; add ConvertKernelFuncToCubin.cpp to the
    library sources only when MLIR_CUDA_CONVERSIONS_ENABLED is true. The
    optional file is listed in LLVM_OPTIONAL_SOURCES so that LLVM's CMake
    source-list check accepts it when it is not compiled. It is appended with
    list(APPEND ...) because CMake has no stand-alone `append` command.
  * ConvertKernelFuncToCubin.cpp: initialize only the NVPTX target (target,
    target info, target MC and asm printer) instead of every target LLVM was
    configured with.

NOTE(review): the LLVMInitializeNVPTX* entry points are provided by LLVM's
NVPTX target libraries (presumably LLVMNVPTXCodeGen, LLVMNVPTXDesc and
LLVMNVPTXInfo). This patch does not modify the target_link_libraries list of
MLIRGPUtoCUDATransforms; confirm those libraries are linked when
MLIR_CUDA_CONVERSIONS_ENABLED is true, especially in BUILD_SHARED_LIBS builds.

diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -44,6 +44,8 @@
 else()
   set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
 endif()
+# TODO: we should use a config.h file like LLVM does
+add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
 
 set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
 
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -90,8 +90,10 @@
 
   // CUDA
   createConvertGpuLaunchFuncToCudaCallsPass();
+#if MLIR_CUDA_CONVERSIONS_ENABLED
   createConvertGPUKernelToCubinPass(
       [](const std::string &, Location, StringRef) { return nullptr; });
+#endif
   createLowerGpuOpsToNVVMOpsPass();
 
   // Linalg
diff --git a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
--- a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
@@ -1,7 +1,16 @@
-add_llvm_library(MLIRGPUtoCUDATransforms
+set(LLVM_OPTIONAL_SOURCES
   ConvertKernelFuncToCubin.cpp
+)
+
+set(SOURCES
   ConvertLaunchFuncToCudaCalls.cpp
 )
+
+if (MLIR_CUDA_CONVERSIONS_ENABLED)
+  list(APPEND SOURCES ConvertKernelFuncToCubin.cpp)
+endif()
+
+add_llvm_library(MLIRGPUtoCUDATransforms ${SOURCES})
 target_link_libraries(MLIRGPUtoCUDATransforms
   MLIRGPU
   MLIRLLVMIR
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -57,9 +57,10 @@
     gpu::GPUModuleOp module = getOperation();
 
     // Make sure the NVPTX target is initialized.
-    llvm::InitializeAllTargets();
-    llvm::InitializeAllTargetMCs();
-    llvm::InitializeAllAsmPrinters();
+    LLVMInitializeNVPTXTarget();
+    LLVMInitializeNVPTXTargetInfo();
+    LLVMInitializeNVPTXTargetMC();
+    LLVMInitializeNVPTXAsmPrinter();
 
     auto llvmModule = translateModuleToNVVMIR(module);
     if (!llvmModule)