diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -294,7 +294,7 @@
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   if (cast<gpu::WaitOp>(op).asyncToken())
-    return failure(); // The gpu.wait is async.
+    return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
 
   Location loc = op->getLoc();
 
@@ -317,7 +317,7 @@
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   if (!cast<gpu::WaitOp>(op).asyncToken())
-    return failure(); // The gpu.wait is not async.
+    return rewriter.notifyMatchFailure(op, "Can only convert async op.");
 
   Location loc = op->getLoc();
 
@@ -445,6 +445,18 @@
         op, "Cannot convert if operands aren't of LLVM type.");
 
   auto launchOp = cast<gpu::LaunchFuncOp>(op);
+
+  if (launchOp.asyncDependencies().size() > 1)
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert with more than one async dependency.");
+
+  // Fail when the synchronous version of the op has an async dependency: the
+  // synchronous lowering destroys the stream after the launch, and we do not
+  // want to check that the dependency's stream has no other uses.
+  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert non-async op with async dependencies.");
+
   Location loc = launchOp.getLoc();
 
   // Create an LLVM global with CUBIN extracted from the kernel annotation and
@@ -475,8 +487,13 @@
       loc, rewriter, {module.getResult(0), kernelName});
   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                 rewriter.getI32IntegerAttr(0));
-  // Grab the global stream needed for execution.
-  auto stream = streamCreateCallBuilder.create(loc, rewriter, {});
+  // Launch on the stream of the (single) async dependency, if any; otherwise
+  // create a fresh stream for this launch.
+  auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+  Value stream =
+      adaptor.asyncDependencies().empty()
+          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+          : adaptor.asyncDependencies().front();
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
@@ -484,13 +501,21 @@
       loc, rewriter,
       {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
        launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
-       launchOp.blockSizeZ(), zero, /* sharedMemBytes */
-       stream.getResult(0),         /* stream */
-       kernelParams,                /* kernel params */
-       nullpointer /* extra */});
-  streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0));
+       launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
+       /*extra=*/nullpointer});
+
+  if (launchOp.asyncToken()) {
+    // Async launch: make dependent ops use the same stream.
+    rewriter.replaceOp(op, {stream});
+  } else {
+    // Synchronize with host and destroy stream. This must be the stream
+    // created above (with no other uses) because the synchronous version with
+    // async dependencies is rejected above.
+    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
+    streamDestroyCallBuilder.create(loc, rewriter, stream);
+    rewriter.eraseOp(op);
+  }
 
-  rewriter.eraseOp(op);
   return success();
 }
 