diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/GPU/Utils.h"
@@ -22,24 +23,35 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
 namespace {
 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
-  struct Callback;
+  struct ThreadTokenCallback;
+  struct DeferWaitCallback;
   void runOnFunction() override;
 };
 } // namespace
 
+static bool isTerminator(Operation *op) { return !op->isKnownNonTerminator(); }
+static bool hasSideEffects(Operation *op) {
+  return !MemoryEffectOpInterface::hasNoEffect(op);
+}
+
 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
 // execute asynchronously.
-struct GpuAsyncRegionPass::Callback {
+struct GpuAsyncRegionPass::ThreadTokenCallback {
+  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
+
   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
   // create a current token (unless it already exists), and 'thread' that token
   // through the `op` so that it executes asynchronously.
   //
   // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
-  // host-synchronize execution.
+  // host-synchronize execution. A `!gpu.async.token` will therefore only be
+  // used inside of its block and GPU execution will always synchronize with
+  // the host at block boundaries.
   WalkResult operator()(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
@@ -50,14 +62,13 @@
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
     if (!currentToken)
       return success();
-    if (!op->hasTrait<OpTrait::IsTerminator>() &&
-        MemoryEffectOpInterface::hasNoEffect(op))
-      return success();
     // Insert host synchronization before terminator or op with side effects.
-    currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
+    if (isTerminator(op) || hasSideEffects(op))
+      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
     return success();
   }
 
+private:
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
@@ -104,13 +115,159 @@
   Value currentToken = {};
 };
 
+// Callback for `async.execute` ops which tries to push the contained
+// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
+struct GpuAsyncRegionPass::DeferWaitCallback {
+  // If the `executeOp`'s token is used only in `async.execute` or
+  // `async.await` ops, add the region's last `gpu.wait` op to the worklist if
+  // it is synchronous and is the last op with side effects.
+  void operator()(async::ExecuteOp executeOp) {
+    if (!areAllUsersExecuteOrAwait(executeOp.token()))
+      return;
+    // async.execute's region is currently restricted to one block.
+    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
+      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
+        if (!waitOp.asyncToken())
+          worklist.push_back(waitOp);
+        return;
+      }
+      if (hasSideEffects(&op))
+        return;
+    }
+  }
+
+  // The destructor performs the actual rewrite work.
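+  //
+  // A hand-written sketch of the rewrite for one worklist entry (not
+  // verbatim pass output; the %-names are invented for illustration):
+  //
+  //   %a = async.execute {
+  //     %t = gpu.launch_func async ...
+  //     gpu.wait [%t]  // synchronous, trailing
+  //     async.yield
+  //   }
+  //
+  // becomes
+  //
+  //   %a, %f = async.execute -> !async.value<!gpu.async.token> {
+  //     %t = gpu.launch_func async ...
+  //     async.yield %t : !gpu.async.token
+  //   }
+  //
+  // and each user of %a is additionally made dependent on %f.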
+  ~DeferWaitCallback() {
+    for (size_t i = 0; i < worklist.size(); ++i) {
+      auto waitOp = worklist[i];
+      auto executeOp = waitOp.getParentOfType<async::ExecuteOp>();
+      auto numDependencies = waitOp.asyncDependencies().size();
+
+      // Erase `gpu.wait` and return async dependencies from region instead.
+      auto &yieldOp = executeOp.getBody()->getOperations().back();
+      yieldOp.insertOperands(yieldOp.getNumOperands(),
+                             waitOp.asyncDependencies());
+      waitOp.erase();
+      auto asyncTokens = addAsyncTokenResults(executeOp, numDependencies);
+
+      // Add the async dependency to each user of the `async.execute` token.
+      for (Operation *user : executeOp.token().getUsers())
+        addAsyncDependencyAfter(asyncTokens, user);
+    }
+  }
+
+private:
+  // Append `count` `!async.value<!gpu.async.token>` results to `executeOp`.
+  static ValueRange addAsyncTokenResults(async::ExecuteOp &executeOp,
+                                         unsigned count) {
+    auto numResults = executeOp.getNumResults() + count;
+
+    // Construct new result type list with `count` additional types.
+    SmallVector<Type, 1> resultTypes;
+    resultTypes.reserve(numResults);
+    copy(executeOp.getResultTypes(), std::back_inserter(resultTypes));
+    OpBuilder builder(executeOp);
+    auto tokenType = builder.getType<gpu::AsyncTokenType>();
+    resultTypes.resize(numResults, tokenType);
+
+    // Clone executeOp with the extra `!gpu.async.token` results.
+    auto newOp = builder.create<async::ExecuteOp>(
+        executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
+        executeOp.dependencies(), executeOp.operands());
+    BlockAndValueMapping mapper;
+    newOp.getRegion().getBlocks().clear();
+    executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+
+    // Replace executeOp with the cloned one.
+    executeOp.getOperation()->replaceAllUsesWith(
+        newOp.getResults().drop_back(count));
+    executeOp.erase();
+    executeOp = newOp;
+
+    // Return the new result values.
+    return executeOp.getResults().take_back(count);
+  }
+
+  // Returns whether all token users are either `async.execute` or
+  // `async.await` ops. This is used as a requirement for pushing `gpu.wait`
+  // ops from an `async.execute` body to its users. Specifically, we do not
+  // allow terminator users, because it could mean that the `async.execute`
+  // is inside control flow code.
+  static bool areAllUsersExecuteOrAwait(Value token) {
+    return llvm::all_of(token.getUsers(), [](Operation *user) {
+      return isa<async::ExecuteOp, async::AwaitOp>(user);
+    });
+  }
+
+  // Add the `asyncTokens` as dependencies as needed after `op`.
+  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
+    OpBuilder builder(op->getContext());
+    auto loc = op->getLoc();
+
+    Block::iterator it;
+    SmallVector<Value, 1> tokens;
+    tokens.reserve(asyncTokens.size());
+    TypeSwitch<Operation *>(op)
+        .Case<async::AwaitOp>([&](auto awaitOp) {
+          // Add async.await ops to wait for the !gpu.async.tokens.
+          builder.setInsertionPointAfter(op);
+          for (auto asyncToken : asyncTokens)
+            tokens.push_back(
+                builder.create<async::AwaitOp>(loc, asyncToken).result());
+          // Set `it` after the inserted async.await ops.
+          it = builder.getInsertionPoint();
+        })
+        .Case<async::ExecuteOp>([&](auto executeOp) {
+          // Set `it` to the beginning of the region and add asyncTokens to
+          // the async.execute operands.
+          it = executeOp.getBody()->begin();
+          executeOp.operandsMutable().append(asyncTokens);
+          SmallVector<Type, 1> tokenTypes(
+              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
+          copy(executeOp.getBody()->addArguments(tokenTypes),
+               std::back_inserter(tokens));
+        });
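+
+    // A hand-written sketch of the two cases above (%-names invented): after
+    // an `async.await %a` user we insert one await per deferred token,
+    //   %t = async.await %f : !async.value<!gpu.async.token>
+    // while an `async.execute [%a]` user receives each deferred token as an
+    // unwrapped region argument:
+    //   async.execute [%a] (%f as %t: !async.value<!gpu.async.token>) { ... }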
+
+    // Advance `it` to the next terminator or op with side effects.
+    it = std::find_if(it, Block::iterator(), [](Operation &op) {
+      return isTerminator(&op) || hasSideEffects(&op);
+    });
+
+    // If that op implements the AsyncOpInterface, add the tokens to its list
+    // of async dependencies.
+    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
+      for (auto token : tokens)
+        asyncOp.addAsyncDependency(token);
+      return;
+    }
+
+    // Otherwise, insert a gpu.wait before `it`.
+    builder.setInsertionPoint(it->getBlock(), it);
+    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
+
+    // If the new waitOp is at the end of an async.execute region, add it to
+    // the worklist. 'operator()(executeOp)' would do the same, but this is
+    // faster.
+    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
+    if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
+        !it->getNextNode())
+      worklist.push_back(waitOp);
+  }
+
+  SmallVector<gpu::WaitOp, 8> worklist;
+};
+
 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
 // execution semantics and that no GPU ops are asynchronous yet.
 void GpuAsyncRegionPass::runOnFunction() {
-  Callback callback{OpBuilder(&getContext())};
-  if (getFunction().getRegion().walk(callback).wasInterrupted())
+  if (getFunction()
+          .getRegion()
+          .walk(ThreadTokenCallback(getContext()))
+          .wasInterrupted())
     return signalPassFailure();
+
+  // Collect gpu.wait ops that we can move out of async.execute regions.
+  getFunction().getRegion().walk(DeferWaitCallback());
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() {
diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir
--- a/mlir/test/Dialect/GPU/async-region.mlir
+++ b/mlir/test/Dialect/GPU/async-region.mlir
@@ -24,4 +24,78 @@
     return
   }
 
+  // CHECK-LABEL:func @defer_wait(%{{.*}}: index)
+  func @defer_wait(%sz : index) {
+    // CHECK: %[[a0:.*]], %[[f0:.*]] = async.execute
+    %a0 = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: %[[a1:.*]], %[[f1:.*]] = async.execute
+    // CHECK-SAME: %[[f0]]
+    %a1 = async.execute [%a0] {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: async.await %[[a1]]
+    // CHECK: %[[t:.*]] = async.await %[[f1]]
+    // CHECK: gpu.wait [%[[t]]]
+    async.await %a1 : !async.token
+    return
+  }
+
+  // CHECK-LABEL:func @defer_wait_blocked_by_side_effect(%{{.*}}: index)
+  func @defer_wait_blocked_by_side_effect(%sz : index) {
+    // CHECK: %[[a:.*]] = async.execute
+    %a = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK: gpu.wait [%[[t]]]
+      call @foo() : () -> ()
+      async.yield
+    }
+
+    // CHECK: async.await %[[a]]
+    // CHECK-NOT: gpu.wait
+    async.await %a : !async.token
+    return
+  }
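+
+  // The pass-through case: the second `async.execute` region below contains
+  // no GPU ops, so the deferred token is forwarded through a region argument
+  // and re-yielded. A hand-written sketch of the expected rewrite (the CHECK
+  // lines are authoritative; the %-names are invented):
+  //
+  //   %a1, %f1 = async.execute [%a0]
+  //       (%f0 as %t: !async.value<!gpu.async.token>)
+  //       -> !async.value<!gpu.async.token> {
+  //     async.yield %t : !gpu.async.token
+  //   }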
+  // CHECK-LABEL:func @defer_wait_pass_through(%{{.*}}: index)
+  func @defer_wait_pass_through(%sz : index) {
+    // CHECK: %[[a0:.*]], %[[f0:.*]] = async.execute
+    %a0 = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: %[[a1:.*]], %[[f1:.*]] = async.execute
+    // CHECK-SAME: %[[f0]]
+    %a1 = async.execute [%a0] {
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %{{.*}}
+      async.yield
+    }
+
+    // CHECK: async.await %[[a1]]
+    // CHECK: %[[t:.*]] = async.await %[[f1]]
+    // CHECK: gpu.wait [%[[t]]]
+    async.await %a1 : !async.token
+    return
+  }
 }