diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -297,7 +297,7 @@ Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (cast(op).asyncToken()) - return failure(); // The gpu.wait is async. + return rewriter.notifyMatchFailure(op, "Cannot convert async op."); Location loc = op->getLoc(); @@ -320,7 +320,7 @@ Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!cast(op).asyncToken()) - return failure(); // The gpu.wait is not async. + return rewriter.notifyMatchFailure(op, "Can only convert async op."); Location loc = op->getLoc(); @@ -440,6 +440,11 @@ // %5 = // call %launchKernel(%3, , 0, %4, %5, nullptr) // call %streamSynchronize(%4) +// call %streamDestroy(%4) +// call %moduleUnload(%1) +// +// If the op is async, the stream corresponds to the (single) async dependency +// as well as the async token the op produces. LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -448,6 +453,18 @@ op, "Cannot convert if operands aren't of LLVM type."); auto launchOp = cast(op); + + if (launchOp.asyncDependencies().size() > 1) + return rewriter.notifyMatchFailure( + op, "Cannot convert with more than one async dependency."); + + // Fail when the synchronous version of the op has async dependencies. The + // lowering destroys the stream, and we do not want to check that there is no + // use of the stream after this op. + if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty()) + return rewriter.notifyMatchFailure( + op, "Cannot convert non-async op with async dependencies."); + Location loc = launchOp.getLoc(); // Create an LLVM global with CUBIN extracted from the kernel annotation and @@ -478,8 +495,11 @@ loc, rewriter, {module.getResult(0), kernelName}); auto zero = rewriter.create(loc, llvmInt32Type, rewriter.getI32IntegerAttr(0)); - // Grab the global stream needed for execution. - auto stream = streamCreateCallBuilder.create(loc, rewriter, {}); + auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary()); + Value stream = + adaptor.asyncDependencies().empty() + ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0) + : adaptor.asyncDependencies().front(); // Create array of pointers to kernel arguments. auto kernelParams = generateParamsArray(launchOp, operands, rewriter); auto nullpointer = rewriter.create(loc, llvmPointerPointerType); @@ -487,15 +507,22 @@ loc, rewriter, {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(), launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(), - launchOp.blockSizeZ(), zero, /* sharedMemBytes */ - stream.getResult(0), /* stream */ - kernelParams, /* kernel params */ - nullpointer /* extra */}); - streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0)); - streamDestroyCallBuilder.create(loc, rewriter, stream.getResult(0)); + launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams, + /*extra=*/nullpointer}); + + if (launchOp.asyncToken()) { + // Async launch: make dependent ops use the same stream. + rewriter.replaceOp(op, {stream}); + } else { + // Synchronize with host and destroy stream. This must be the stream created + // above (with no other uses) because we check that the synchronous version + // does not have any async dependencies. + streamSynchronizeCallBuilder.create(loc, rewriter, stream); + streamDestroyCallBuilder.create(loc, rewriter, stream); + rewriter.eraseOp(op); + } moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0)); - rewriter.eraseOp(op); return success(); }