diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -91,6 +91,7 @@
 def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
   let summary = "Convert GPU dialect to LLVM dialect with GPU runtime calls";
   let constructor = "mlir::createGpuToLLVMConversionPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
   let options = [
     Option<"gpuBinaryAnnotation", "gpu-binary-annotation", "std::string",
            "", "Annotation attribute string for GPU binary">,
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -157,6 +157,34 @@
       ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A rewrite pattern to convert gpu.wait operations without an async token
+/// into GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
+class ConvertWaitOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
+public:
+  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertWaitAsyncOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
+public:
+  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// A rewrite patter to convert gpu.launch_func operations into a sequence of
 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
 ///
@@ -257,6 +285,62 @@
   return success();
 }
 
+// Converts `gpu.wait` without an async token to runtime calls; the async form
+// is rejected. The operands are all CUDA or ROCm streams (i.e. void*). The
+// converted op first synchronizes the host with every stream and then destroys
+// them all, i.e. it assumes the streams are not used afterwards (otherwise we
+// get a runtime error). Eventually, a pass will guarantee this property.
+LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (cast<gpu::WaitOp>(op).asyncToken())
+    return failure(); // The gpu.wait is async.
+
+  Location loc = op->getLoc();
+
+  for (Value asyncDependency : operands)
+    streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
+  for (Value asyncDependency : operands)
+    streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
+// Converts `gpu.wait async` to runtime calls. The result is a new stream that
+// is synchronized with all operands, which are CUDA or ROCm streams (i.e.
+// void*). We create and record an event after the definition of the stream
+// and make the new stream wait on that event before destroying it again. This
+// assumes that there is no other use between the definition and this op, and
+// the plan is to have a pass that guarantees this property.
+//
+// A dependency that has no defining op (e.g. a block argument) has its event
+// recorded at the start of the block containing this op instead.
+LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!cast<gpu::WaitOp>(op).asyncToken())
+    return failure(); // The gpu.wait is not async.
+
+  Location loc = op->getLoc();
+
+  // Events are emitted next to each dependency's definition, so remember where
+  // the replacement for this op has to go.
+  auto insertionPoint = rewriter.saveInsertionPoint();
+  SmallVector<Value, 4> events;
+  for (Value asyncDependency : operands) {
+    // getDefiningOp() returns null when the dependency is not an op result;
+    // the start of this op's block is later than ideal but still precedes the
+    // wait, so recording the event there remains valid.
+    if (Operation *defOp = asyncDependency.getDefiningOp())
+      rewriter.setInsertionPointAfter(defOp);
+    else
+      rewriter.setInsertionPointToStart(op->getBlock());
+    Value event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+    // Record on the dependency itself: result 0 of its defining op is not
+    // necessarily the stream.
+    eventRecordCallBuilder.create(loc, rewriter, {event, asyncDependency});
+    events.push_back(event);
+  }
+  rewriter.restoreInsertionPoint(insertionPoint);
+  Value stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+  // Make the new stream wait on every event before any event is destroyed.
+  for (Value event : events)
+    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
+  for (Value event : events)
+    eventDestroyCallBuilder.create(loc, rewriter, {event});
+  rewriter.replaceOp(op, {stream});
+
+  return success();
+}
+
 // Creates a struct containing all kernel parameters on the stack and returns
 // an array of type-erased pointers to the fields of the struct. The array can
 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -411,7 +495,13 @@
 void mlir::populateGpuToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     StringRef gpuBinaryAnnotation) {
-  patterns.insert(converter);
+  converter.addConversion(
+      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
+        return LLVM::LLVMType::getInt8PtrTy(context);
+      });
+  patterns.insert(converter);
   patterns.insert(
       converter, gpuBinaryAnnotation);
   patterns.insert(&converter.getContext());
diff --git a/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  func @foo() {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    // CHECK: %[[e0:.*]] = llvm.call @mgpuEventCreate
+    // CHECK: llvm.call @mgpuEventRecord(%[[e0]], %[[t0]])
+    %t0 = gpu.wait async  // No dependencies: lowers to a new stream only.
+    // CHECK: %[[t1:.*]] = llvm.call @mgpuStreamCreate
+    // CHECK: llvm.call @mgpuStreamWaitEvent(%[[t1]], %[[e0]])
+    // CHECK: llvm.call @mgpuEventDestroy(%[[e0]])
+    %t1 = gpu.wait async [%t0]  // Waits on e0, inserted right after t0 above.
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t1]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t1]])
+    gpu.wait [%t0, %t1]  // All streams are synchronized before any is destroyed.
+    return
+  }
+}