diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -39,7 +39,7 @@
 static constexpr const char *kGpuModuleGetFunctionName =
     "mgpuModuleGetFunction";
 static constexpr const char *kGpuLaunchKernelName = "mgpuLaunchKernel";
-static constexpr const char *kGpuGetStreamHelperName = "mgpuGetStreamHelper";
+static constexpr const char *kGpuStreamCreateName = "mgpuStreamCreate";
 static constexpr const char *kGpuStreamSynchronizeName =
     "mgpuStreamSynchronize";
 static constexpr const char *kGpuMemHostRegisterName = "mgpuMemHostRegister";
@@ -168,27 +168,21 @@
   if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleLoadName)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuModuleLoadName,
-        LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
-            {
-                getPointerPointerType(), /* CUmodule *module */
-                getPointerType()         /* void *cubin */
-            },
-            /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(getPointerType(),
+                                      getPointerType(), /* void *cubin */
+                                      /*isVarArg=*/false));
   }
   if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleGetFunctionName)) {
     // The helper uses void* instead of CUDA's opaque CUmodule and
     // CUfunction, or ROCm (HIP)'s opaque hipModule_t and hipFunction_t.
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuModuleGetFunctionName,
-        LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
-            {
-                getPointerPointerType(), /* void **function */
-                getPointerType(),        /* void *module */
-                getPointerType()         /* char *name */
-            },
-            /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(getPointerType(),
+                                      {
+                                          getPointerType(), /* void *module */
+                                          getPointerType()  /* char *name */
+                                      },
+                                      /*isVarArg=*/false));
   }
   if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuLaunchKernelName)) {
     // Other than the CUDA or ROCm (HIP) api, the wrappers use uintptr_t to
@@ -198,7 +192,7 @@
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuLaunchKernelName,
         LLVM::LLVMType::getFunctionTy(
-            getGpuRuntimeResultType(),
+            getVoidType(),
             {
                 getPointerType(), /* void* f */
                 getIntPtrType(),  /* intptr_t gridXDim */
@@ -214,18 +208,18 @@
             },
             /*isVarArg=*/false));
   }
-  if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuGetStreamHelperName)) {
+  if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamCreateName)) {
     // Helper function to get the current GPU compute stream. Uses void*
     // instead of CUDA's opaque CUstream, or ROCm (HIP)'s opaque hipStream_t.
     builder.create<LLVM::LLVMFuncOp>(
-        loc, kGpuGetStreamHelperName,
+        loc, kGpuStreamCreateName,
         LLVM::LLVMType::getFunctionTy(getPointerType(), /*isVarArg=*/false));
   }
   if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamSynchronizeName)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kGpuStreamSynchronizeName,
-        LLVM::LLVMType::getFunctionTy(getGpuRuntimeResultType(),
-                                      getPointerType() /* CUstream stream */,
+        LLVM::LLVMType::getFunctionTy(getVoidType(),
+                                      getPointerType() /* void *stream */,
                                       /*isVarArg=*/false));
   }
   if (!module.lookupSymbol<LLVM::LLVMFuncOp>(kGpuMemHostRegisterName)) {
@@ -365,17 +359,13 @@
 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
 //
 // %0 = call %binarygetter
-// %1 = alloca sizeof(void*)
-// call %moduleLoad(%2, %1)
-// %2 = alloca sizeof(void*)
-// %3 = load %1
-// %4 = <see generateKernelNameConstant>
-// call %moduleGetFunction(%2, %3, %4)
-// %5 = call %getStreamHelper()
-// %6 = load %2
-// %7 = <see setupParamsArray>
-// call %launchKernel(%6, <launchOp operands 0..5>, 0, %5, %7, nullptr)
-// call %streamSynchronize(%5)
+// %1 = call %moduleLoad(%0)
+// %2 = <see generateKernelNameConstant>
+// %3 = call %moduleGetFunction(%1, %2)
+// %4 = call %streamCreate()
+// %5 = <see setupParamsArray>
+// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
+// call %streamSynchronize(%4)
 void GpuLaunchFuncToGpuRuntimeCallsPass::translateGpuLaunchCalls(
     mlir::gpu::LaunchFuncOp launchOp) {
   OpBuilder builder(launchOp);
@@ -405,36 +395,30 @@
   // Emit the load module call to load the module data. Error checking is done
   // in the called helper function.
-  auto gpuModule = allocatePointer(builder, loc);
   auto gpuModuleLoad =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleLoadName);
-  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
-                               builder.getSymbolRefAttr(gpuModuleLoad),
-                               ArrayRef<Value>{gpuModule, data});
+  auto module = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
+      builder.getSymbolRefAttr(gpuModuleLoad), ArrayRef<Value>{data});
   // Get the function from the module. The name corresponds to the name of
   // the kernel function.
-  auto gpuOwningModuleRef =
-      builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuModule);
   auto kernelName = generateKernelNameConstant(
       launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, builder);
-  auto gpuFunction = allocatePointer(builder, loc);
   auto gpuModuleGetFunction =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuModuleGetFunctionName);
-  builder.create<LLVM::CallOp>(
-      loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+  auto function = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
       builder.getSymbolRefAttr(gpuModuleGetFunction),
-      ArrayRef<Value>{gpuFunction, gpuOwningModuleRef, kernelName});
+      ArrayRef<Value>{module.getResult(0), kernelName});
   // Grab the global stream needed for execution.
-  auto gpuGetStreamHelper =
-      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuGetStreamHelperName);
-  auto gpuStream = builder.create<LLVM::CallOp>(
+  auto gpuStreamCreate =
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamCreateName);
+  auto stream = builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getPointerType()},
-      builder.getSymbolRefAttr(gpuGetStreamHelper), ArrayRef<Value>{});
+      builder.getSymbolRefAttr(gpuStreamCreate), ArrayRef<Value>{});
   // Invoke the function with required arguments.
   auto gpuLaunchKernel =
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuLaunchKernelName);
-  auto gpuFunctionRef =
-      builder.create<LLVM::LoadOp>(loc, getPointerType(), gpuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
   if (!paramsArray) {
     launchOp.emitOpError() << "cannot pass given parameters to the kernel";
@@ -445,11 +429,11 @@
   builder.create<LLVM::CallOp>(
-      loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+      loc, ArrayRef<Type>{getVoidType()},
       builder.getSymbolRefAttr(gpuLaunchKernel),
-      ArrayRef<Value>{gpuFunctionRef, launchOp.getOperand(0),
+      ArrayRef<Value>{function.getResult(0), launchOp.getOperand(0),
                       launchOp.getOperand(1), launchOp.getOperand(2),
                       launchOp.getOperand(3), launchOp.getOperand(4),
                       launchOp.getOperand(5), zero, /* sharedMemBytes */
-                      gpuStream.getResult(0), /* stream */
+                      stream.getResult(0), /* stream */
                       paramsArray, /* kernel params */
                       nullpointer /* extra */});
   // Sync on the stream to make it synchronous.
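For orientation, the call sequence the rewritten lowering emits corresponds roughly to the host-side C++ below; the numbered comments line up with the SSA names in the lowering comment above. This is an illustrative sketch only: the void* prototypes mirror the declarations the pass itself creates (the typed CUDA/HIP signatures live in the wrapper files further down), and the kernel name, grid/block sizes, and the launch helper are made-up placeholders.

    #include <cstdint>

    extern "C" void *mgpuModuleLoad(void *data);
    extern "C" void *mgpuModuleGetFunction(void *module, const char *name);
    extern "C" void *mgpuStreamCreate();
    extern "C" void mgpuLaunchKernel(void *function, intptr_t gridX,
                                     intptr_t gridY, intptr_t gridZ,
                                     intptr_t blockX, intptr_t blockY,
                                     intptr_t blockZ, int32_t smem,
                                     void *stream, void **params, void **extra);
    extern "C" void mgpuStreamSynchronize(void *stream);

    // Placeholder driver mirroring the emitted IR: no allocas or loads remain;
    // each runtime call forwards the value returned by the previous one.
    void launch(void *cubin, void **params) {
      void *module = mgpuModuleLoad(cubin);                     // %1
      void *function = mgpuModuleGetFunction(module, "kernel"); // %3 (name: %2)
      void *stream = mgpuStreamCreate();                        // %4
      mgpuLaunchKernel(function, 1, 1, 1, 1, 1, 1, /*smem=*/0, stream, params,
                       /*extra=*/nullptr);
      mgpuStreamSynchronize(stream);
    }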
@@ -457,7 +441,7 @@
       getOperation().lookupSymbol<LLVM::LLVMFuncOp>(kGpuStreamSynchronizeName);
-  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getGpuRuntimeResultType()},
+  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
                                builder.getSymbolRefAttr(gpuStreamSync),
-                               ArrayRef<Value>(gpuStream.getResult(0)));
+                               ArrayRef<Value>(stream.getResult(0)));
   launchOp.erase();
 }
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -26,7 +26,7 @@
-  // CHECK: llvm.call @mgpuModuleLoad(%[[module_ptr]], %[[binary_ptr]]) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
-  // CHECK: %[[func_ptr:.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
-  // CHECK: llvm.call @mgpuModuleGetFunction(%[[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
-  // CHECK: llvm.call @mgpuGetStreamHelper
+  // CHECK: %[[module:.*]] = llvm.call @mgpuModuleLoad(%[[binary_ptr]]) : (!llvm<"i8*">) -> !llvm<"i8*">
+  // CHECK: llvm.call @mgpuModuleGetFunction(%[[module]], {{.*}}) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm<"i8*">
+  // CHECK: llvm.call @mgpuStreamCreate
   // CHECK: llvm.call @mgpuLaunchKernel
   // CHECK: llvm.call @mgpuStreamSynchronize
   "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_module::@kernel }
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,6 @@
   return
 }

-func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,6 @@
   return
 }

-func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,6 @@
   return
 }

-func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
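The mgpuMemHostRegisterInt32/Float helpers exercised by these tests take an unranked memref, which MLIR lowers to a (rank, pointer-to-descriptor) pair. Below is a minimal sketch of calling the renamed float helper straight from C++. The hand-rolled descriptor struct is an assumption that matches MLIR's strided-memref convention (in-tree code would use StridedMemRefType from mlir/ExecutionEngine/CRunnerUtils.h), and succeeding at runtime additionally requires linking the wrapper library and having a current CUDA context.

    #include <cstdint>

    // Wrapper defined in cuda-runtime-wrappers.cpp further down in this patch.
    extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr);

    // Hand-written rank-2 strided descriptor (assumed layout).
    struct MemRef2DF32 {
      float *allocated;
      float *aligned;
      int64_t offset;
      int64_t sizes[2];
      int64_t strides[2];
    };

    int main() {
      static float data[2][6] = {};
      MemRef2DF32 desc{&data[0][0], &data[0][0], 0, {2, 6}, {6, 1}};
      // Page-locks the 2x6 buffer and, per the wrapper's definition,
      // fills it with 1.23f.
      mgpuMemHostRegisterFloat(/*rank=*/2, &desc);
      return 0;
    }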
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -11,7 +11,7 @@
   %sy = dim %dst, %c1 : memref<?x?x?xf32>
   %sz = dim %dst, %c0 : memref<?x?x?xf32>
   %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
     %t0 = muli %tz, %block_y : index
@@ -28,5 +28,5 @@
   return
 }

-func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,6 @@
   return
 }

-func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -8,7 +8,7 @@
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
   %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %val = index_cast %tx : index to i32
@@ -25,5 +25,5 @@
   return
 }

-func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -25,9 +25,9 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,6 @@
   return
 }

-func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
 func @print_memref_i32(memref<*xi32>)
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -18,7 +18,7 @@
   %21 = constant 5 : i32
   %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
   %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
   call @print_memref_f32(%23) : (memref<*xf32>) -> ()
   %24 = constant 1.0 : f32
   call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,5 @@
   return
 }

-func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -26,11 +26,11 @@
   %c6 = constant 6 : index

   %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
   %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
   %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()

   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,5 @@
   return
 }

-func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -8,7 +8,7 @@
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xf32>
   %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
-  call @mcuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -28,5 +28,5 @@
   return
 }

-func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
 func @print_memref_f32(%ptr : memref<*xf32>)
diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir
--- a/mlir/test/mlir-cuda-runner/two-modules.mlir
+++ b/mlir/test/mlir-cuda-runner/two-modules.mlir
@@ -8,7 +8,7 @@
   %c0 = constant 0 : index
   %sx = dim %dst, %c0 : memref<?xi32>
   %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
-  call @mcuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
+  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
              threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
     %t0 = index_cast %tx : index to i32
@@ -25,5 +25,5 @@
   return
 }

-func @mcuMemHostRegisterInt32(%memref : memref<*xi32>)
+func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
 func @print_memref_i32(%memref : memref<*xi32>)
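The trailing func declarations in these tests are resolved by the execution engine against the extern "C" symbols compiled into the runner process, which is why the mcu-to-mgpu rename has to land in the tests and in the wrapper libraries as one change. A hypothetical standalone probe for that symbol visibility, using LLVM's DynamicLibrary API (sketch only; nothing like this exists in the patch itself):

    #include <cstdio>
    #include "llvm/Support/DynamicLibrary.h"

    int main() {
      using llvm::sys::DynamicLibrary;
      // Make the process's own exported symbols searchable.
      DynamicLibrary::LoadLibraryPermanently(nullptr);
      void *sym =
          DynamicLibrary::SearchForAddressOfSymbol("mgpuMemHostRegisterFloat");
      std::printf("mgpuMemHostRegisterFloat: %s\n", sym ? "found" : "missing");
      return sym ? 0 : 1;
    }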
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -21,54 +21,50 @@

 #include "cuda.h"

-namespace {
-int32_t reportErrorIfAny(CUresult result, const char *where) {
-  if (result != CUDA_SUCCESS) {
-    llvm::errs() << "CUDA failed with " << result << " in " << where << "\n";
-  }
-  return result;
+#define CUDA_REPORT_IF_ERROR(expr)                                            \
+  [](CUresult result) {                                                       \
+    if (!result)                                                              \
+      return;                                                                 \
+    const char *name = nullptr;                                               \
+    cuGetErrorName(result, &name);                                            \
+    if (!name)                                                                \
+      name = "<unknown>";                                                     \
+    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";       \
+  }(expr)
+
+extern "C" CUmodule mgpuModuleLoad(void *data) {
+  CUmodule module = nullptr;
+  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
+  return module;
 }
-} // anonymous namespace

-extern "C" int32_t mgpuModuleLoad(void **module, void *data) {
-  int32_t err = reportErrorIfAny(
-      cuModuleLoadData(reinterpret_cast<CUmodule *>(module), data),
-      "ModuleLoad");
-  return err;
-}
-
-extern "C" int32_t mgpuModuleGetFunction(void **function, void *module,
-                                         const char *name) {
-  return reportErrorIfAny(
-      cuModuleGetFunction(reinterpret_cast<CUfunction *>(function),
-                          reinterpret_cast<CUmodule>(module), name),
-      "GetFunction");
+extern "C" CUfunction mgpuModuleGetFunction(CUmodule module,
+                                            const char *name) {
+  CUfunction function = nullptr;
+  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
+  return function;
 }

 // The wrapper uses intptr_t instead of CUDA's unsigned int to match
 // the type of MLIR's index type. This avoids the need for casts in the
 // generated MLIR code.
-extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX,
-                                    intptr_t gridY, intptr_t gridZ,
-                                    intptr_t blockX, intptr_t blockY,
-                                    intptr_t blockZ, int32_t smem, void *stream,
-                                    void **params, void **extra) {
-  return reportErrorIfAny(
-      cuLaunchKernel(reinterpret_cast<CUfunction>(function), gridX, gridY,
-                     gridZ, blockX, blockY, blockZ, smem,
-                     reinterpret_cast<CUstream>(stream), params, extra),
-      "LaunchKernel");
+extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
+                                 intptr_t gridY, intptr_t gridZ,
+                                 intptr_t blockX, intptr_t blockY,
+                                 intptr_t blockZ, int32_t smem, CUstream stream,
+                                 void **params, void **extra) {
+  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
+                                      blockY, blockZ, smem, stream, params,
+                                      extra));
 }

-extern "C" void *mgpuGetStreamHelper() {
-  CUstream stream;
-  reportErrorIfAny(cuStreamCreate(&stream, CU_STREAM_DEFAULT), "StreamCreate");
+extern "C" CUstream mgpuStreamCreate() {
+  CUstream stream = nullptr;
+  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
   return stream;
 }

-extern "C" int32_t mgpuStreamSynchronize(void *stream) {
-  return reportErrorIfAny(
-      cuStreamSynchronize(reinterpret_cast<CUstream>(stream)), "StreamSync");
+extern "C" void mgpuStreamSynchronize(CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
 }

 /// Helper functions for writing mlir example code
@@ -76,14 +72,13 @@
 // Allows to register byte array with the CUDA runtime. Helpful until we have
 // transfer functions implemented.
 extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
-  reportErrorIfAny(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0),
-                   "MemHostRegister");
+  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }

 // Allows to register a MemRef with the CUDA runtime. Initializes array with
 // value. Helpful until we have transfer functions implemented.
 template <typename T>
-void mcuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
+void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &mem_ref, T value) {
   llvm::SmallVector<int64_t, 4> denseStrides(mem_ref.rank);
   llvm::ArrayRef<int64_t> sizes(mem_ref.sizes, mem_ref.rank);
   llvm::ArrayRef<int64_t> strides(mem_ref.strides, mem_ref.rank);
@@ -103,12 +98,12 @@
   mgpuMemHostRegister(pointer, count * sizeof(T));
 }

-extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) {
+extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
   UnrankedMemRefType<float> mem_ref = {rank, ptr};
-  mcuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
+  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(mem_ref), 1.23f);
 }

-extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) {
+extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
   UnrankedMemRefType<int32_t> mem_ref = {rank, ptr};
-  mcuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
+  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(mem_ref), 123);
 }
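CUDA_REPORT_IF_ERROR replaces reportErrorIfAny with an immediately invoked lambda, so the stringified call (#expr) and the driver's own error name both land in the message while call sites stay single expressions. The same pattern reduced to a CUDA-free sketch, with a made-up error enum and error-name function so it compiles and runs anywhere:

    #include <cstdio>

    // Stand-ins for CUresult and cuGetErrorName; assumptions for this sketch.
    enum FakeResult { kSuccess = 0, kOutOfMemory = 2 };

    const char *fakeGetErrorName(FakeResult result) {
      return result == kOutOfMemory ? "OUT_OF_MEMORY" : "<unknown>";
    }

    // Same shape as CUDA_REPORT_IF_ERROR: evaluate the expression exactly
    // once, return early on success, otherwise report #expr and the name.
    #define REPORT_IF_ERROR(expr)                                              \
      [](FakeResult result) {                                                  \
        if (!result)                                                           \
          return;                                                              \
        std::fprintf(stderr, "'%s' failed with '%s'\n", #expr,                 \
                     fakeGetErrorName(result));                                \
      }(expr)

    FakeResult fakeAlloc() { return kOutOfMemory; }

    int main() {
      REPORT_IF_ERROR(fakeAlloc()); // 'fakeAlloc()' failed with 'OUT_OF_MEMORY'
      return 0;
    }

The immediately invoked lambda gives the macro statement-like behavior without GNU statement expressions, and guarantees the wrapped expression is evaluated only once.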
-extern "C" int32_t mgpuLaunchKernel(void *function, intptr_t gridX, - intptr_t gridY, intptr_t gridZ, - intptr_t blockX, intptr_t blockY, - intptr_t blockZ, int32_t smem, void *stream, - void **params, void **extra) { - return reportErrorIfAny( - hipModuleLaunchKernel(reinterpret_cast(function), gridX, - gridY, gridZ, blockX, blockY, blockZ, smem, - reinterpret_cast(stream), params, - extra), - "LaunchKernel"); +extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, + hipStream_t stream, void **params, + void **extra) { + HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, + blockX, blockY, blockZ, smem, + stream, params, extra)); } -extern "C" void *mgpuGetStreamHelper() { - hipStream_t stream; - reportErrorIfAny(hipStreamCreate(&stream), "StreamCreate"); +extern "C" void *mgpuStreamCreate() { + hipStream_t stream = nullptr; + HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); return stream; } -extern "C" int32_t mgpuStreamSynchronize(void *stream) { - return reportErrorIfAny( - hipStreamSynchronize(reinterpret_cast(stream)), - "StreamSync"); +extern "C" int32_t mgpuStreamSynchronize(hipStream_t stream) { + return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream)); } /// Helper functions for writing mlir example code @@ -78,8 +74,8 @@ // Allows to register byte array with the ROCM runtime. Helpful until we have // transfer functions implemented. extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { - reportErrorIfAny(hipHostRegister(ptr, sizeBytes, /*flags=*/0), - "MemHostRegister"); + HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0), + "MemHostRegister"); } // Allows to register a MemRef with the ROCM runtime. Initializes array with @@ -120,8 +116,8 @@ template void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { - reportErrorIfAny(hipSetDevice(0), "hipSetDevice"); - reportErrorIfAny( + HIP_REPORT_IF_ERROR(hipSetDevice(0), "hipSetDevice"); + HIP_REPORT_IF_ERROR( hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0), "hipHostGetDevicePointer"); }