diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -39,7 +39,7 @@ static constexpr const char *kGpuModuleGetFunctionName = "mgpuModuleGetFunction"; static constexpr const char *kGpuLaunchKernelName = "mgpuLaunchKernel"; -static constexpr const char *kGpuGetStreamHelperName = "mgpuGetStreamHelper"; +static constexpr const char *kGpuStreamCreateName = "mgpuStreamCreate"; static constexpr const char *kGpuStreamSynchronizeName = "mgpuStreamSynchronize"; static constexpr const char *kGpuMemHostRegisterName = "mgpuMemHostRegister"; @@ -100,12 +100,6 @@ getLLVMDialect(), module.getDataLayout().getPointerSizeInBits()); } - LLVM::LLVMType getGpuRuntimeResultType() { - // This is declared as an enum in both CUDA and ROCm (HIP), but helpers - // use i32. - return getInt32Type(); - } - // Allocate a void pointer on the stack. Value allocatePointer(OpBuilder &builder, Location loc) { auto one = builder.create(loc, getInt32Type(), @@ -168,27 +162,21 @@ if (!module.lookupSymbol(kGpuModuleLoadName)) { builder.create( loc, kGpuModuleLoadName, - LLVM::LLVMType::getFunctionTy( - getGpuRuntimeResultType(), - { - getPointerPointerType(), /* CUmodule *module */ - getPointerType() /* void *cubin */ - }, - /*isVarArg=*/false)); + LLVM::LLVMType::getFunctionTy(getPointerType(), + {getPointerType()}, /* void *cubin */ + /*isVarArg=*/false)); } if (!module.lookupSymbol(kGpuModuleGetFunctionName)) { // The helper uses void* instead of CUDA's opaque CUmodule and // CUfunction, or ROCm (HIP)'s opaque hipModule_t and hipFunction_t. builder.create( loc, kGpuModuleGetFunctionName, - LLVM::LLVMType::getFunctionTy( - getGpuRuntimeResultType(), - { - getPointerPointerType(), /* void **function */ - getPointerType(), /* void *module */ - getPointerType() /* char *name */ - }, - /*isVarArg=*/false)); + LLVM::LLVMType::getFunctionTy(getPointerType(), + { + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + /*isVarArg=*/false)); } if (!module.lookupSymbol(kGpuLaunchKernelName)) { // Other than the CUDA or ROCm (HIP) api, the wrappers use uintptr_t to @@ -198,7 +186,7 @@ builder.create( loc, kGpuLaunchKernelName, LLVM::LLVMType::getFunctionTy( - getGpuRuntimeResultType(), + getVoidType(), { getPointerType(), /* void* f */ getIntPtrType(), /* intptr_t gridXDim */ @@ -214,18 +202,18 @@ }, /*isVarArg=*/false)); } - if (!module.lookupSymbol(kGpuGetStreamHelperName)) { + if (!module.lookupSymbol(kGpuStreamCreateName)) { // Helper function to get the current GPU compute stream. Uses void* // instead of CUDA's opaque CUstream, or ROCm (HIP)'s opaque hipStream_t. builder.create( - loc, kGpuGetStreamHelperName, + loc, kGpuStreamCreateName, LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false)); } if (!module.lookupSymbol(kGpuStreamSynchronizeName)) { builder.create( loc, kGpuStreamSynchronizeName, - LLVM::LLVMType::getFunctionTy(getGpuRuntimeResultType(), - getPointerType() /* CUstream stream */, + LLVM::LLVMType::getFunctionTy(getVoidType(), + {getPointerType()}, /* void *stream */ /*isVarArg=*/false)); } if (!module.lookupSymbol(kGpuMemHostRegisterName)) { @@ -365,17 +353,13 @@ // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR. // // %0 = call %binarygetter -// %1 = alloca sizeof(void*) -// call %moduleLoad(%2, %1) -// %2 = alloca sizeof(void*) -// %3 = load %1 -// %4 = -// call %moduleGetFunction(%2, %3, %4) -// %5 = call %getStreamHelper() -// %6 = load %2 -// %7 = -// call %launchKernel(%6, , 0, %5, %7, nullptr) -// call %streamSynchronize(%5) +// %1 = call %moduleLoad(%0) +// %2 = +// %3 = call %moduleGetFunction(%1, %2) +// %4 = call %streamCreate() +// %5 = +// call %launchKernel(%3, , 0, %4, %5, nullptr) +// call %streamSynchronize(%4) void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls( mlir::gpu::LaunchFuncOp launchOp) { OpBuilder builder(launchOp); @@ -405,36 +389,30 @@ // Emit the load module call to load the module data. Error checking is done // in the called helper function. - auto gpuModule = allocatePointer(builder, loc); auto gpuModuleLoad = getOperation().lookupSymbol(kGpuModuleLoadName); - builder.create(loc, ArrayRef{getGpuRuntimeResultType()}, - builder.getSymbolRefAttr(gpuModuleLoad), - ArrayRef{gpuModule, data}); + auto module = builder.create( + loc, ArrayRef{getPointerType()}, + builder.getSymbolRefAttr(gpuModuleLoad), ArrayRef{data}); // Get the function from the module. The name corresponds to the name of // the kernel function. - auto gpuOwningModuleRef = - builder.create(loc, getPointerType(), gpuModule); auto kernelName = generateKernelNameConstant( launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, builder); - auto gpuFunction = allocatePointer(builder, loc); auto gpuModuleGetFunction = getOperation().lookupSymbol(kGpuModuleGetFunctionName); - builder.create( - loc, ArrayRef{getGpuRuntimeResultType()}, + auto function = builder.create( + loc, ArrayRef{getPointerType()}, builder.getSymbolRefAttr(gpuModuleGetFunction), - ArrayRef{gpuFunction, gpuOwningModuleRef, kernelName}); + ArrayRef{module.getResult(0), kernelName}); // Grab the global stream needed for execution. - auto gpuGetStreamHelper = - getOperation().lookupSymbol(kGpuGetStreamHelperName); - auto gpuStream = builder.create( + auto gpuStreamCreate = + getOperation().lookupSymbol(kGpuStreamCreateName); + auto stream = builder.create( loc, ArrayRef{getPointerType()}, - builder.getSymbolRefAttr(gpuGetStreamHelper), ArrayRef{}); + builder.getSymbolRefAttr(gpuStreamCreate), ArrayRef{}); // Invoke the function with required arguments. auto gpuLaunchKernel = getOperation().lookupSymbol(kGpuLaunchKernelName); - auto gpuFunctionRef = - builder.create(loc, getPointerType(), gpuFunction); auto paramsArray = setupParamsArray(launchOp, builder); if (!paramsArray) { launchOp.emitOpError() << "cannot pass given parameters to the kernel"; @@ -443,21 +421,21 @@ auto nullpointer = builder.create(loc, getPointerPointerType(), zero); builder.create( - loc, ArrayRef{getGpuRuntimeResultType()}, + loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(gpuLaunchKernel), - ArrayRef{gpuFunctionRef, launchOp.getOperand(0), + ArrayRef{function.getResult(0), launchOp.getOperand(0), launchOp.getOperand(1), launchOp.getOperand(2), launchOp.getOperand(3), launchOp.getOperand(4), launchOp.getOperand(5), zero, /* sharedMemBytes */ - gpuStream.getResult(0), /* stream */ + stream.getResult(0), /* stream */ paramsArray, /* kernel params */ nullpointer /* extra */}); // Sync on the stream to make it synchronous. auto gpuStreamSync = getOperation().lookupSymbol(kGpuStreamSynchronizeName); - builder.create(loc, ArrayRef{getGpuRuntimeResultType()}, + builder.create(loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(gpuStreamSync), - ArrayRef(gpuStream.getResult(0))); + ArrayRef(stream.getResult(0))); launchOp.erase(); } diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir --- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir @@ -20,13 +20,11 @@ // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]] // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) - // CHECK: %[[binary_ptr:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]] + // CHECK: %[[binary:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]] // CHECK-SAME: -> !llvm<"i8*"> - // CHECK: %[[module_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**"> - // CHECK: llvm.call @mgpuModuleLoad(%[[module_ptr]], %[[binary_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32 - // CHECK: %[[func_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**"> - // CHECK: llvm.call @mgpuModuleGetFunction(%[[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32 - // CHECK: llvm.call @mgpuGetStreamHelper + // CHECK: %[[module:.*]] = llvm.call @mgpuModuleLoad(%[[binary]]) : (!llvm<"i8*">) -> !llvm<"i8*"> + // CHECK: %[[func:.*]] = llvm.call @mgpuModuleGetFunction(%[[module]], {{.*}}) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm<"i8*"> + // CHECK: llvm.call @mgpuStreamCreate // CHECK: llvm.call @mgpuLaunchKernel // CHECK: llvm.call @mgpuStreamSynchronize "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_module::@kernel } diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -21,54 +21,50 @@ #include "cuda.h" -namespace { -int32_t reportErrorIfAny(CUresult result, const char *where) { - if (result != CUDA_SUCCESS) { - llvm::errs() << "CUDA failed with " << result << " in " << where << "\n"; - } - return result; +#define CUDA_REPORT_IF_ERROR(expr) \ + [](CUresult result) { \ + if (!result) \ + return; \ + const char *name = nullptr; \ + cuGetErrorName(result, &name); \ + if (!name) \ + name = ""; \ + llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ + }(expr) + +extern "C" CUmodule mgpuModuleLoad(void *data) { + CUmodule module = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); + return module; } -} // anonymous namespace -extern "C" int32_t mgpuModuleLoad(void **module, void *data) { - int32_t err = reportErrorIfAny( - cuModuleLoadData(reinterpret_cast(module), data), - "ModuleLoad"); - return err; -} - -extern "C" int32_t mgpuModuleGetFunction(void **function, void *module, - const char *name) { - return reportErrorIfAny( - cuModuleGetFunction(reinterpret_cast(function), - reinterpret_cast(module), name), - "GetFunction"); +extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) { + CUfunction function = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); + return function; } // The wrapper uses intptr_t instead of CUDA's unsigned int to match // the type of MLIR's index type. This avoids the need for casts in the // generated MLIR code. -extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX, - intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, - intptr_t blockZ, int32_t smem, void *stream, - void **params, void **extra) { - return reportErrorIfAny( - cuLaunchKernel(reinterpret_cast(function), gridX, gridY, - gridZ, blockX, blockY, blockZ, smem, - reinterpret_cast(stream), params, extra), - "LaunchKernel"); +extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, CUstream stream, + void **params, void **extra) { + CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, + blockY, blockZ, smem, stream, params, + extra)); } -extern "C" void *mgpuGetStreamHelper() { - CUstream stream; - reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate"); +extern "C" CUstream mgpuStreamCreate() { + CUstream stream = nullptr; + CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); return stream; } -extern "C" int32_t mgpuStreamSynchronize(void *stream) { - return reportErrorIfAny( - cuStreamSynchronize(reinterpret_cast(stream)), "StreamSync"); +extern "C" void mgpuStreamSynchronize(CUstream stream) { + CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); } /// Helper functions for writing mlir example code @@ -76,17 +72,16 @@ // Allows to register byte array with the CUDA runtime. Helpful until we have // transfer functions implemented. extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { - reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0), - "MemHostRegister"); + CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); } // Allows to register a MemRef with the CUDA runtime. Initializes array with // value. Helpful until we have transfer functions implemented. template -void mgpuMemHostRegisterMemRef(const DynamicMemRefType &mem_ref, T value) { - llvm::SmallVector denseStrides(mem_ref.rank); - llvm::ArrayRef sizes(mem_ref.sizes, mem_ref.rank); - llvm::ArrayRef strides(mem_ref.strides, mem_ref.rank); +void mgpuMemHostRegisterMemRef(const DynamicMemRefType &memRef, T value) { + llvm::SmallVector denseStrides(memRef.rank); + llvm::ArrayRef sizes(memRef.sizes, memRef.rank); + llvm::ArrayRef strides(memRef.strides, memRef.rank); std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), std::multiplies()); @@ -98,17 +93,17 @@ denseStrides.back() = 1; assert(strides == llvm::makeArrayRef(denseStrides)); - auto *pointer = mem_ref.data + mem_ref.offset; + auto *pointer = memRef.data + memRef.offset; std::fill_n(pointer, count, value); mgpuMemHostRegister(pointer, count * sizeof(T)); } extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) { - UnrankedMemRefType mem_ref = {rank, ptr}; - mgpuMemHostRegisterMemRef(DynamicMemRefType(mem_ref), 1.23f); + UnrankedMemRefType memRef = {rank, ptr}; + mgpuMemHostRegisterMemRef(DynamicMemRefType(memRef), 1.23f); } extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) { - UnrankedMemRefType mem_ref = {rank, ptr}; - mgpuMemHostRegisterMemRef(DynamicMemRefType(mem_ref), 123); + UnrankedMemRefType memRef = {rank, ptr}; + mgpuMemHostRegisterMemRef(DynamicMemRefType(memRef), 123); } diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp --- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp +++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp @@ -21,56 +21,52 @@ #include "hip/hip_runtime.h" -namespace { -int32_t reportErrorIfAny(hipError_t result, const char *where) { - if (result != hipSuccess) { - llvm::errs() << "HIP failed with " << result << " in " << where << "\n"; - } - return result; +#define HIP_REPORT_IF_ERROR(expr) \ + [](hipError_t result) { \ + if (!result) \ + return; \ + const char *name = nullptr; \ + hipGetErrorName(result, &name); \ + if (!name) \ + name = ""; \ + llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ + }(expr) + +extern "C" hipModule_t mgpuModuleLoad(void *data) { + hipModule_t module = nullptr; + HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data)); + return module; } -} // anonymous namespace -extern "C" int32_t mgpuModuleLoad(void **module, void *data) { - int32_t err = reportErrorIfAny( - hipModuleLoadData(reinterpret_cast(module), data), - "ModuleLoad"); - return err; -} - -extern "C" int32_t mgpuModuleGetFunction(void **function, void *module, - const char *name) { - return reportErrorIfAny( - hipModuleGetFunction(reinterpret_cast(function), - reinterpret_cast(module), name), - "GetFunction"); +extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module, + const char *name) { + hipFunction_t function = nullptr; + HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name)); + return function; } // The wrapper uses intptr_t instead of ROCM's unsigned int to match // the type of MLIR's index type. This avoids the need for casts in the // generated MLIR code. -extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX, - intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, - intptr_t blockZ, int32_t smem, void *stream, - void **params, void **extra) { - return reportErrorIfAny( - hipModuleLaunchKernel(reinterpret_cast(function), gridX, - gridY, gridZ, blockX, blockY, blockZ, smem, - reinterpret_cast(stream), params, - extra), - "LaunchKernel"); +extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, + hipStream_t stream, void **params, + void **extra) { + HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, + blockX, blockY, blockZ, smem, + stream, params, extra)); } -extern "C" void *mgpuGetStreamHelper() { - hipStream_t stream; - reportErrorIfAny(hipStreamCreate(&stream), "StreamCreate"); +extern "C" void *mgpuStreamCreate() { + hipStream_t stream = nullptr; + HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); return stream; } -extern "C" int32_t mgpuStreamSynchronize(void *stream) { - return reportErrorIfAny( - hipStreamSynchronize(reinterpret_cast(stream)), - "StreamSync"); +extern "C" void mgpuStreamSynchronize(hipStream_t stream) { + return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream)); } /// Helper functions for writing mlir example code @@ -78,8 +74,8 @@ // Allows to register byte array with the ROCM runtime. Helpful until we have // transfer functions implemented. extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { - reportErrorIfAny(hipHostRegister(ptr, sizeBytes, /*flags=*/0), - "MemHostRegister"); + HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0), + "MemHostRegister"); } // Allows to register a MemRef with the ROCM runtime. Initializes array with @@ -120,8 +116,8 @@ template void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { - reportErrorIfAny(hipSetDevice(0), "hipSetDevice"); - reportErrorIfAny( + HIP_REPORT_IF_ERROR(hipSetDevice(0), "hipSetDevice"); + HIP_REPORT_IF_ERROR( hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0), "hipHostGetDevicePointer"); }