diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -151,6 +151,12 @@
       "mgpuMemFree",
       llvmVoidType,
       {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
+  FunctionCallBuilder memcpyCallBuilder = {
+      "mgpuMemcpy",
+      llvmVoidType,
+      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
+       llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -268,6 +274,20 @@
     return success();
   }
 };
+
+/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertMemcpyOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
+public:
+  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -643,6 +663,50 @@
   return success();
 }
 
+LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();
+
+  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
+      !isSupportedMemRefType(memRefType) ||
+      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
+    return failure();
+
+  auto loc = memcpyOp.getLoc();
+  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());
+
+  MemRefDescriptor srcDesc(adaptor.src());
+
+  Value numElements =
+      memRefType.hasStaticShape()
+          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
+          // For identity layouts (verified above), the number of elements is
+          // stride[0] * size[0].
+          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
+                                         srcDesc.size(rewriter, loc, 0));
+
+  Type elementPtrType = getElementPtrType(memRefType);
+  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  auto sizeBytes =
+      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+
+  auto src = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
+  auto dst = rewriter.create<LLVM::BitcastOp>(
+      loc, llvmPointerType,
+      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));
+
+  auto stream = adaptor.asyncDependencies().front();
+  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
+
+  rewriter.replaceOp(memcpyOp, {stream});
+
+  return success();
+}
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
   return std::make_unique<GpuToLLVMConversionPass>(gpuBinaryAnnotation);
 }
@@ -658,6 +722,7 @@
   patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
                   ConvertDeallocOpToGpuRuntimeCallPattern,
                   ConvertHostRegisterOpToGpuRuntimeCallPattern,
+                  ConvertMemcpyOpToGpuRuntimeCallPattern,
                   ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                   ConvertWaitOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @foo
+  func @foo(%dst : memref<7xf32, 1>, %src : memref<7xf32>) {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: %[[src:.*]] = llvm.bitcast
+    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[t0]])
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<7xf32, 1>, memref<7xf32>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    gpu.wait [%t1]
+    return
+  }
+}
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -117,6 +117,13 @@
   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
 }
 
+extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+                           CUstream stream) {
+  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
+                                     reinterpret_cast<CUdeviceptr>(src),
+                                     sizeBytes, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -118,6 +118,11 @@
   HIP_REPORT_IF_ERROR(hipMemFree(ptr));
 }
 
+extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+                           hipStream_t stream) {
+  HIP_REPORT_IF_ERROR(hipMemcpyAsync(dst, src, sizeBytes, stream));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have
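Two notes for readers, each with a small standalone sketch. Neither block is part of the patch; every name in them is illustrative.

First, the `sizeBytes` computation in `ConvertMemcpyOpToGpuRuntimeCallPattern` uses the "GEP from null" idiom: an `llvm.getelementptr` of `numElements` off a null pointer of the element type, followed by `llvm.ptrtoint`, yields `numElements * sizeof(elementType)` without the pattern hard-coding any element size. A minimal sketch of the equivalent integer arithmetic, assuming an identity-layout source memref (`DescriptorView` and `memcpySizeBytes` are hypothetical stand-ins, not patch code):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the two MemRef descriptor fields the lowering
// reads on the dynamic-shape path: sizes[0] and strides[0].
struct DescriptorView {
  int64_t size0;   // sizes[0]
  int64_t stride0; // strides[0]; for identity layouts this is the product of
                   // all inner dimension sizes, so stride0 * size0 == #elements
};

// What the null-GEP + ptrtoint pair computes, written in plain integers.
int64_t memcpySizeBytes(const DescriptorView &desc, int64_t elementBytes) {
  int64_t numElements = desc.stride0 * desc.size0;
  return numElements * elementBytes;
}

int main() {
  // memref<7xf32>: 7 contiguous f32 elements -> 28 bytes passed to mgpuMemcpy.
  assert(memcpySizeBytes({/*size0=*/7, /*stride0=*/1}, /*elementBytes=*/4) == 28);
  return 0;
}
```

Second, the test above checks the lowered IR structurally; read as a host-side call sequence, `@foo` becomes roughly the transcript below. The `extern` declarations mirror the wrapper entry points the CHECK lines reference, but the `void *` stream type is a simplification for illustration (the real wrappers use `CUstream` / `hipStream_t`):

```cpp
#include <cstdint>

// Declarations matching the runtime wrapper entry points named in the test;
// the void* handle type here is an assumption for the sketch.
extern "C" void *mgpuStreamCreate();
extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
                           void *stream);
extern "C" void mgpuStreamSynchronize(void *stream);
extern "C" void mgpuStreamDestroy(void *stream);

// Hypothetical transcript of the calls @foo performs after lowering.
void fooLowered(void *dst, void *src) {
  void *t0 = mgpuStreamCreate();               // %t0 = gpu.wait async
  mgpuMemcpy(dst, src, 7 * sizeof(float), t0); // gpu.memcpy async [%t0]
  mgpuStreamSynchronize(t0);                   // gpu.wait [%t1] waits on the
  mgpuStreamDestroy(t0);                       // stream, then destroys it
}
```

Because the copy is enqueued on the stream that carries the async token, it executes in stream order; `gpu.memcpy` returns that same stream as its token, so the later `gpu.wait` lowers to a synchronize followed by a destroy on it.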