diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -58,8 +58,12 @@
   WalkResult operator()(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
-    if (isa<gpu::WaitOp>(op))
-      return op->emitOpError("unexpected pre-existing gpu.wait");
+    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
+      if (currentToken)
+        waitOp.addAsyncDependency(currentToken);
+      currentToken = waitOp.asyncToken();
+      return success();
+    }
     builder.setInsertionPoint(op);
     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
@@ -75,10 +79,6 @@
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
-    if (asyncOp.getAsyncToken())
-      // TODO: Support ops that are already async.
-      return op->emitOpError("is already async");
-
     auto tokenType = builder.getType<gpu::AsyncTokenType>();

     // If there is no current token, insert a `gpu.wait async` without
@@ -87,6 +87,11 @@
       currentToken = createWaitOp(op->getLoc(), tokenType, {});
     asyncOp.addAsyncDependency(currentToken);

+    // Return early if op returns token already.
+    currentToken = asyncOp.getAsyncToken();
+    if (currentToken)
+      return success();
+
     // Clone the op to return a token in addition to the other results.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir
--- a/mlir/test/Dialect/GPU/async-region.mlir
+++ b/mlir/test/Dialect/GPU/async-region.mlir
@@ -169,4 +169,24 @@
     }
     return
   }
+
+  // CHECK-LABEL:func @existing_tokens()
+  func @existing_tokens() {
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    // CHECK-NOT: [{{.*}}]
+    %t0 = gpu.wait async
+    // CHECK: %[[t1:.*]] = gpu.wait async [%[[t0]], %[[t0]]]
+    %t1 = gpu.wait async [%t0]
+    // CHECK: %[[m:.*]], %[[t2:.*]] = gpu.alloc async [%[[t1]], %[[t0]]] ()
+    %0 = gpu.alloc [%t0] () : memref<7xf32>
+    // CHECK: %[[t3:.*]] = gpu.dealloc async [%[[t2]]] %[[m]]
+    %t2 = gpu.dealloc async %0 : memref<7xf32>
+    // CHECK: gpu.wait [%[[t3]]]
+    gpu.wait
+    // CHECK: gpu.wait
+    // CHECK-NOT: async
+    // CHECK-NOT: [{{.*}}]
+    gpu.wait
+    return
+  }
 }