GpuAsyncRegionPass: accept GPU ops that already carry async tokens.

What this patch does:
  * ThreadTokenCallback is now invoked once per Block and visits the block's
    operations itself (early-inc range, because visit() may replace the op).
  * A pre-existing gpu.wait is no longer an error: if there is a current
    token it is added as an async dependency of the wait, and the wait's own
    async token becomes the current token.
  * rewriteAsyncOp() no longer rejects ops that are already async: after
    adding the current token as a dependency it adopts the op's existing
    token and returns early; only ops without a token are cloned.
  * Adds the @existing_tokens lit test covering the cases above.

Review notes:
  * The disableMultithreading()/enableMultithreading() pair that an earlier
    revision of this patch added around the walk in runOnFunction() has been
    dropped. The failure path returned before threading was re-enabled, and
    a function pass should not toggle context-wide threading state.
  * NOTE(review): the template arguments (gpu::LaunchOp, gpu::WaitOp,
    gpu::AsyncOpInterface, gpu::AsyncTokenType, SmallVector<Type, 1>) and all
    leading indentation were missing from the copy of the patch this was
    restored from. Confirm them against the tree before applying; use
    `git apply --ignore-whitespace` if the context fails to match.

diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -47,6 +47,15 @@
 struct GpuAsyncRegionPass::ThreadTokenCallback {
   ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
 
+  WalkResult operator()(Block *block) {
+    for (Operation &op : make_early_inc_range(*block)) {
+      if (failed(visit(&op)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  }
+
+private:
   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
   // create a current token (unless it already exists), and 'thread' that token
   // through the `op` so that it executes asynchronously.
@@ -55,11 +64,15 @@
   // host-synchronize execution. A `!gpu.async.token` will therefore only be
   // used inside of its block and GPU execution will always synchronize with
   // the host at block boundaries.
-  WalkResult operator()(Operation *op) {
+  LogicalResult visit(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
-    if (isa<gpu::WaitOp>(op))
-      return op->emitOpError("unexpected pre-existing gpu.wait");
+    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
+      if (currentToken)
+        waitOp.addAsyncDependency(currentToken);
+      currentToken = waitOp.asyncToken();
+      return success();
+    }
     builder.setInsertionPoint(op);
     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
@@ -71,14 +84,9 @@
     return success();
   }
 
-private:
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
-    if (asyncOp.getAsyncToken())
-      // TODO: Support ops that are already async.
-      return op->emitOpError("is already async");
-
     auto tokenType = builder.getType<gpu::AsyncTokenType>();
 
     // If there is no current token, insert a `gpu.wait async` without
@@ -87,6 +95,11 @@
       currentToken = createWaitOp(op->getLoc(), tokenType, {});
     asyncOp.addAsyncDependency(currentToken);
 
+    // Return early if op returns a token already.
+    currentToken = asyncOp.getAsyncToken();
+    if (currentToken)
+      return success();
+
     // Clone the op to return a token in addition to the other results.
     SmallVector<Type, 1> resultTypes;
     resultTypes.reserve(1 + op->getNumResults());
@@ -315,11 +328,8 @@
 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
 // execution semantics and that no GPU ops are asynchronous yet.
 void GpuAsyncRegionPass::runOnFunction() {
-  if (getFunction()
-          .getRegion()
-          .walk(ThreadTokenCallback(getContext()))
-          .wasInterrupted())
+  if (getFunction()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
     return signalPassFailure();
 
   // Collect gpu.wait ops that we can move out of async.execute regions.
   getFunction().getRegion().walk(DeferWaitCallback());
diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir
--- a/mlir/test/Dialect/GPU/async-region.mlir
+++ b/mlir/test/Dialect/GPU/async-region.mlir
@@ -169,4 +169,24 @@
     }
     return
   }
+
+  // CHECK-LABEL:func @existing_tokens()
+  func @existing_tokens() {
+    // CHECK: %[[t0:.*]] = gpu.wait async
+    // CHECK-NOT: [{{.*}}]
+    %t0 = gpu.wait async
+    // CHECK: %[[t1:.*]] = gpu.wait async [%[[t0]], %[[t0]]]
+    %t1 = gpu.wait async [%t0]
+    // CHECK: %[[m:.*]], %[[t2:.*]] = gpu.alloc async [%[[t1]], %[[t0]]] ()
+    %0 = gpu.alloc [%t0] () : memref<7xf32>
+    // CHECK: %[[t3:.*]] = gpu.dealloc async [%[[t2]]] %[[m]]
+    %t2 = gpu.dealloc async %0 : memref<7xf32>
+    // CHECK: gpu.wait [%[[t3]]]
+    gpu.wait
+    // CHECK: gpu.wait
+    // CHECK-NOT: async
+    // CHECK-NOT: [{{.*}}]
+    gpu.wait
+    return
+  }
 }