diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -804,4 +804,79 @@
 }];
 }
 
+def GPU_AllocOp : GPU_Op<"alloc", [
+    GPU_AsyncOpInterface,
+    AttrSizedOperandSegments,
+    MemoryEffects<[MemAlloc]>
+  ]> {
+
+  let summary = "GPU memory allocation operation.";
+  let description = [{
+    The `gpu.alloc` operation allocates a region of memory on the GPU. It is
+    similar to the `std.alloc` op, but supports asynchronous GPU execution.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it also returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %memref, %token = gpu.alloc async [%dep] (%width) : memref<64x?xf32, 1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Variadic<Index>:$dynamicSizes, Variadic<Index>:$symbolOperands);
+  let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$memref,
+                 Optional<GPU_AsyncToken>:$asyncToken);
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return memref().getType().cast<MemRefType>(); }
+  }];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) ` `
+    `(` $dynamicSizes `)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref)
+  }];
+}
+
+def GPU_DeallocOp : GPU_Op<"dealloc", [
+    GPU_AsyncOpInterface, MemoryEffects<[MemFree]>
+  ]> {
+
+  let summary = "GPU memory deallocation operation";
+
+  let description = [{
+    The `gpu.dealloc` operation frees the region of memory referenced by a
+    memref which was originally created by the `gpu.alloc` operation. It is
+    similar to the `std.dealloc` op, but supports asynchronous GPU execution.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.dealloc async [%dep] %memref : memref<8x64xf32, 1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Arg<AnyMemRef, "", [MemFree]>:$memref);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $memref attr-dict `:` type($memref)
+  }];
+}
+
 #endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -142,6 +142,15 @@
       {llvmIntPtrType /* intptr_t rank */,
        llvmPointerType /* void *memrefDesc */,
        llvmIntPtrType /* intptr_t elementSizeBytes */}};
+  FunctionCallBuilder allocCallBuilder = {
+      "mgpuMemAlloc",
+      llvmPointerType /* void * */,
+      {llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder deallocCallBuilder = {
+      "mgpuMemFree",
+      llvmVoidType,
+      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
 /// call. Currently it supports CUDA and ROCm (HIP).
@@ -158,6 +167,34 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertAllocOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
+public:
+  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertDeallocOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
+public:
+  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
 /// call. Currently it supports CUDA and ROCm (HIP).
 class ConvertWaitOpToGpuRuntimeCallPattern
@@ -231,7 +268,6 @@
     return success();
   }
 };
-
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -260,17 +296,35 @@
       builder.getSymbolRefAttr(function), arguments);
 }
 
-// Returns whether value is of LLVM type.
-static bool isLLVMType(Value value) {
-  return value.getType().isa<LLVM::LLVMType>();
+// Returns whether all operands are of LLVM type.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return value.getType().isa<LLVM::LLVMType>();
+      }))
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert if operands aren't of LLVM type.");
+  return success();
+}
+
+static LogicalResult
+isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
+                         gpu::AsyncOpInterface op) {
+  if (op.getAsyncDependencies().size() != 1)
+    return rewriter.notifyMatchFailure(
+        op, "Can only convert with exactly one async dependency.");
+
+  if (!op.getAsyncToken())
+    return rewriter.notifyMatchFailure(op, "Can convert only async version.");
+
+  return success();
 }
 
 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!llvm::all_of(operands, isLLVMType))
-    return rewriter.notifyMatchFailure(
-        op, "Cannot convert if operands aren't of LLVM type.");
+  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+    return failure();
 
   Location loc = op->getLoc();
@@ -287,6 +341,71 @@
   return success();
 }
 
+LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto allocOp = cast<gpu::AllocOp>(op);
+  MemRefType memRefType = allocOp.getType();
+
+  if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
+      !isSupportedMemRefType(memRefType) ||
+      failed(
+          isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
+    return failure();
+
+  auto loc = op->getLoc();
+
+  // Get shape of the memref as values: static sizes are constant
+  // values and dynamic sizes are passed to 'alloc' as operands.
+  SmallVector<Value, 4> shape;
+  SmallVector<Value, 4> strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, shape, strides,
+                           sizeBytes);
+
+  // Allocate the underlying buffer and store a pointer to it in the MemRef
+  // descriptor.
+  Type elementPtrType = this->getElementPtrType(memRefType);
+  auto adaptor = gpu::AllocOpAdaptor(operands, op->getAttrDictionary());
+  auto stream = adaptor.asyncDependencies().front();
+  Value allocatedPtr =
+      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
+  allocatedPtr =
+      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
+
+  // No alignment.
+  Value alignedPtr = allocatedPtr;
+
+  // Create the MemRef descriptor.
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
+
+  rewriter.replaceOp(op, {memRefDescriptor, stream});
+
+  return success();
+}
+
+LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
+      failed(
+          isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
+    return failure();
+
+  Location loc = op->getLoc();
+
+  auto adaptor = gpu::DeallocOpAdaptor(operands, op->getAttrDictionary());
+  Value pointer =
+      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
+  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
+  Value stream = adaptor.asyncDependencies().front();
+  deallocCallBuilder.create(loc, rewriter, {casted, stream});
+
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
 // Converts `gpu.wait` to runtime calls. The operands are all CUDA or ROCm
 // streams (i.e. void*). The converted op synchronizes the host with every
 // stream and then destroys it. That is, it assumes that the stream is not used
@@ -447,9 +566,8 @@
 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!llvm::all_of(operands, isLLVMType))
-    return rewriter.notifyMatchFailure(
-        op, "Cannot convert if operands aren't of LLVM type.");
+  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+    return failure();
 
   auto launchOp = cast<gpu::LaunchFuncOp>(op);
 
@@ -537,9 +655,11 @@
       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
         return LLVM::LLVMType::getInt8PtrTy(context);
       });
-  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern,
-                  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
-                  ConvertWaitOpToGpuRuntimeCallPattern>(converter);
+  patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
+                  ConvertDeallocOpToGpuRuntimeCallPattern,
+                  ConvertHostRegisterOpToGpuRuntimeCallPattern,
+                  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
+                  ConvertWaitOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation);
   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
diff --git a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+  func @main() {
+    // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate()
+    %0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: llvm.call @mgpuMemAlloc(%[[size_bytes]], %[[stream]])
+    %1, %2 = gpu.alloc async [%0] () : memref<13xf32>
+    // CHECK: %[[float_ptr:.*]] = llvm.extractvalue {{.*}}[0]
+    // CHECK: %[[void_ptr:.*]] = llvm.bitcast %[[float_ptr]]
+    // CHECK: llvm.call @mgpuMemFree(%[[void_ptr]], %[[stream]])
+    %3 = gpu.dealloc async [%2] %1 : memref<13xf32>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]])
+    gpu.wait [%3]
+    return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -144,6 +144,23 @@
     } ) {gpu.kernel, sym_name = "kernel_1", type = (f32, memref<?xf32>) -> (), workgroup_attributions = 1: i64} : () -> ()
   }
 
+  func @alloc() {
+    // CHECK-LABEL: func @alloc()
+
+    // CHECK: %[[m0:.*]] = gpu.alloc () : memref<13xf32, 1>
+    %m0 = gpu.alloc () : memref<13xf32, 1>
+    // CHECK: gpu.dealloc %[[m0]] : memref<13xf32, 1>
+    gpu.dealloc %m0 : memref<13xf32, 1>
+
+    %t0 = gpu.wait async
+    // CHECK: %[[m1:.*]], %[[t1:.*]] = gpu.alloc async [{{.*}}] () : memref<13xf32, 1>
+    %m1, %t1 = gpu.alloc async [%t0] () : memref<13xf32, 1>
+    // CHECK: gpu.dealloc async [%[[t1]]] %[[m1]] : memref<13xf32, 1>
+    %t2 = gpu.dealloc async [%t1] %m1 : memref<13xf32, 1>
+
+    return
+  }
+
   func @async_token(%arg0 : !gpu.async.token) -> !gpu.async.token {
     // CHECK-LABEL: func @async_token({{.*}}: !gpu.async.token)
     // CHECK: return {{.*}} : !gpu.async.token
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -107,6 +107,16 @@
   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
 }
 
+extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+  CUdeviceptr ptr;
+  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
+  return reinterpret_cast<void *>(ptr);
+}
+
+extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
+  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -108,6 +108,16 @@
   HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
 }
 
+extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
+  void *ptr;
+  HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
+  return ptr;
+}
+
+extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
+  HIP_REPORT_IF_ERROR(hipFree(ptr));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have
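
Taken together, the async variants above are meant to be chained through `!gpu.async.token` values, which the `--gpu-to-llvm` lowering maps onto a single runtime stream (see the new lower-alloc-to-gpu-runtime-calls.mlir test). Below is a minimal usage sketch, not part of this patch: the function name and the `%width` operand are illustrative only, and the conversion patterns added here handle just the `async` forms with exactly one dependency.

```mlir
// Illustrative sketch only: allocate a memref with one dynamic dimension on a
// token chain, free it on the same chain, then synchronize the host.
func @alloc_dealloc_chain(%width : index) {
  %t0 = gpu.wait async                   // start the token/stream chain
  %memref, %t1 = gpu.alloc async [%t0] (%width) : memref<64x?xf32, 1>
  // ... ops (e.g. kernel launches) that consume %memref, chained on %t1 ...
  %t2 = gpu.dealloc async [%t1] %memref : memref<64x?xf32, 1>
  gpu.wait [%t2]                         // block until the chain has finished
  return
}
```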