// Review notes for this patch:
// - GPUOps.td adds `gpu.host_unregister` (GPU_HostUnregisterOp): a single
//   unranked memref operand `$value` and no results.
// - GPUToLLVMConversion.cpp adds `hostUnregisterCallBuilder` for
//   `void mgpuMemHostUnregisterMemRef(intptr_t rank, void *memrefDesc,
//   intptr_t elementSizeBytes)`, plus a pattern that fails unless all
//   converted operands are LLVM types, promotes the operands, appends the
//   element size in bytes, emits that call, and erases the op.
// - The runtime-wrapper hunks add `mgpuMemHostUnregisterMemRef`, which
//   forwards `descriptor->data + descriptor->offset * elementSizeBytes` to
//   `mgpuMemHostUnregister`.
// NOTE(review): this patch text is flattened and looks damaged: the text
// between `patterns.add` and `*descriptor` is missing (presumably the pattern
// registration list and the start of the CUDA runtime-wrapper hunk), so it
// likely cannot be applied as-is. Regenerate it from the source tree.
Index: mlir/include/mlir/Dialect/GPU/IR/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -929,6 +929,19 @@ let assemblyFormat = "$value attr-dict `:` type($value)"; } +def GPU_HostUnregisterOp : GPU_Op<"host_unregister">, + Arguments<(ins AnyUnrankedMemRef:$value)> { + let summary = "Unregisters a memref for access from device."; + let description = [{ + This op unmaps the provided host buffer from the device address space. + + This operation may not be supported in every environment; there is not yet a + way to check at runtime whether this feature is supported. + }]; + + let assemblyFormat = "$value attr-dict `:` type($value)"; +} + def GPU_WaitOp : GPU_Op<"wait", [GPU_AsyncOpInterface]> { let summary = "Wait for async gpu ops to complete."; let description = [{ Index: mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp =================================================================== --- mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -161,6 +161,12 @@ {llvmIntPtrType /* intptr_t rank */, llvmPointerType /* void *memrefDesc */, llvmIntPtrType /* intptr_t elementSizeBytes */}}; + FunctionCallBuilder hostUnregisterCallBuilder = { + "mgpuMemHostUnregisterMemRef", + llvmVoidType, + {llvmIntPtrType /* intptr_t rank */, + llvmPointerType /* void *memrefDesc */, + llvmIntPtrType /* intptr_t elementSizeBytes */}}; FunctionCallBuilder allocCallBuilder = { "mgpuMemAlloc", llvmPointerType /* void * */, @@ -202,6 +208,20 @@ ConversionPatternRewriter &rewriter) const override; }; +class ConvertHostUnregisterOpToGpuRuntimeCallPattern + : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> { +public: + ConvertHostUnregisterOpToGpuRuntimeCallPattern( + LLVMTypeConverter &typeConverter) + : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) { + } + +private: + LogicalResult +
matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime /// call. Currently it supports CUDA and ROCm (HIP). class ConvertAllocOpToGpuRuntimeCallPattern @@ -446,6 +466,28 @@ return success(); } +LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( + gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Operation *op = hostUnregisterOp.getOperation(); + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) + return failure(); + + Location loc = op->getLoc(); + + auto memRefType = hostUnregisterOp.getValue().getType(); + auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType(); + auto elementSize = getSizeInBytes(loc, elementType, rewriter); + + auto arguments = getTypeConverter()->promoteOperands( + loc, op->getOperands(), adaptor.getOperands(), rewriter); + arguments.push_back(elementSize); + hostUnregisterCallBuilder.create(loc, rewriter, arguments); + + rewriter.eraseOp(op); + return success(); +} + LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::AllocOp allocOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -928,6 +970,7 @@ patterns.add *descriptor, + int64_t elementSizeBytes) { + auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostUnregister(ptr); +} + extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) { defaultDevice = device; } Index: mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp =================================================================== --- mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp +++ mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp @@ -152,6 +152,22 @@ mgpuMemHostRegister(ptr, sizeBytes); } +// Unregisters a byte array with the ROCm runtime.
Helpful until we have +// transfer functions implemented. +extern "C" void mgpuMemHostUnregister(void *ptr) { + HIP_REPORT_IF_ERROR(hipHostUnregister(ptr)); +} + +// Unregisters the MemRef's data (data + offset * elementSizeBytes) with the +// ROCm runtime. Helpful until we have transfer functions implemented. +extern "C" void +mgpuMemHostUnregisterMemRef(int64_t rank, + StridedMemRefType *descriptor, // NOTE(review): template arguments look stripped from this patch; confirm against the source tree + int64_t elementSizeBytes) { + auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostUnregister(ptr); +} + template void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { HIP_REPORT_IF_ERROR(hipSetDevice(0));