diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/GPU/Utils.h"
@@ -22,24 +23,37 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
 namespace {
 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
-  struct Callback;
+  struct ThreadTokenCallback;
+  struct DeferWaitCallback;
   void runOnFunction() override;
 };
 } // namespace
 
+static bool isTerminator(Operation &op) {
+  return op.hasTrait<OpTrait::IsTerminator>();
+}
+static bool isTerminatorOrHasSideEffects(Operation &op) {
+  return isTerminator(op) || !MemoryEffectOpInterface::hasNoEffect(&op);
+}
+
 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
 // execute asynchronously.
-struct GpuAsyncRegionPass::Callback {
+struct GpuAsyncRegionPass::ThreadTokenCallback {
+  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
+
   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
   // create a current token (unless it already exists), and 'thread' that token
   // through the `op` so that it executes asynchronously.
   //
   // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
-  // host-synchronize execution.
+  // host-synchronize execution. A `!gpu.async.token` will therefore only be
+  // used inside of its block and GPU execution will always synchronize with
+  // the host at block boundaries.
   WalkResult operator()(Operation *op) {
     if (isa<gpu::LaunchOp>(op))
       return op->emitOpError("replace with gpu.launch_func first");
@@ -50,14 +64,13 @@
       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
     if (!currentToken)
       return success();
-    if (!op->hasTrait<OpTrait::IsTerminator>() &&
-        MemoryEffectOpInterface::hasNoEffect(op))
-      return success();
     // Insert host synchronization before terminator or op with side effects.
-    currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
+    if (isTerminatorOrHasSideEffects(*op))
+      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
     return success();
   }
 
+private:
   // Replaces asyncOp with a clone that returns a token.
   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
     auto *op = asyncOp.getOperation();
@@ -104,15 +117,149 @@
   Value currentToken = {};
 };
 
+// Callback for `async.execute` ops which tries to push the contained
+// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
+struct GpuAsyncRegionPass::DeferWaitCallback {
+  // If the `executeOp`s token is used only in `async.execute` or `async.await`
+  // ops, add the region's last `gpu.wait` op to the worklist if it is
+  // synchronous and is the last op with side effects.
+  void operator()(async::ExecuteOp executeOp) {
+    if (!areUsersSupported(executeOp))
+      return;
+    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
+      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
+        if (!waitOp.asyncToken())
+          worklist.push_back(waitOp);
+        return;
+      }
+      if (isTerminatorOrHasSideEffects(op))
+        return;
+    }
+  }
+
+  // The destructor performs the actual rewrite work.
+  ~DeferWaitCallback() {
+    for (size_t i = 0; i < worklist.size(); ++i) {
+      auto waitOp = worklist[i];
+      auto executeOp = waitOp.getParentOfType<async::ExecuteOp>();
+      auto numDependencies = waitOp.asyncDependencies().size();
+
+      // Erase `gpu.wait` and return async dependencies from region instead.
+      auto &yieldOp = executeOp.getBody()->getOperations().back();
+      yieldOp.insertOperands(yieldOp.getNumOperands(),
+                             waitOp.asyncDependencies());
+      waitOp.erase();
+      auto asyncTokens = addAsyncTokenResults(executeOp, numDependencies);
+
+      // Add the async dependency to each user of the `async.execute` token.
+      for (Operation *user : executeOp.token().getUsers())
+        addAsyncDependencyAfter(asyncTokens, user);
+    }
+  }
+
+private:
+  // Adds `count` `!async.value` results to `executeOp`.
+  static ValueRange addAsyncTokenResults(async::ExecuteOp &executeOp,
+                                         unsigned count) {
+    auto numResults = executeOp.getNumResults() + count;
+
+    // Construct new result type list with `count` additional types.
+    SmallVector<Type, 1> resultTypes;
+    resultTypes.reserve(numResults);
+    copy(executeOp.getResultTypes(), std::back_inserter(resultTypes));
+    OpBuilder builder(executeOp.getOperation());
+    auto tokenType = builder.getType<gpu::AsyncTokenType>();
+    resultTypes.resize(numResults, async::ValueType::get(tokenType));
+
+    // Clone executeOp with an extra `!gpu.async.token` results.
+    auto newOp = builder.create<async::ExecuteOp>(
+        executeOp.getLoc(), TypeRange{resultTypes}, executeOp.dependencies(),
+        executeOp.operands());
+    BlockAndValueMapping mapper;
+    executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+
+    // Replace executeOp with cloned one.
+    executeOp.getOperation()->replaceAllUsesWith(
+        newOp.getResults().drop_back());
+    executeOp.erase();
+    executeOp = newOp;
+
+    // Return the new result values.
+    return executeOp.getResults().take_back(count);
+  }
+
+  // Returns whether we support adding async dependencies to the users of
+  // `executeOp`. Specifically, we do not support terminator users which could
+  // mean that the `executeOp` is inside control flow code.
+  static bool areUsersSupported(async::ExecuteOp executeOp) {
+    return llvm::all_of(executeOp.token().getUsers(), [](Operation *user) {
+      return isa<async::ExecuteOp>(user) || isa<async::AwaitOp>(user);
+    });
+  }
+
+  // Add the `asyncToken` as dependency as needed after `op`.
+  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
+    OpBuilder builder(op->getContext());
+    auto loc = op->getLoc();
+
+    SmallVector<Value, 1> tokens;
+    tokens.reserve(asyncTokens.size());
+    TypeSwitch<Operation *>(op)
+        .Case<async::AwaitOp>([&](auto awaitOp) {
+          builder.setInsertionPointAfter(op);
+          for (auto asyncToken : asyncTokens)
+            tokens.push_back(
+                builder.create<async::AwaitOp>(loc, asyncToken).getResult(0));
+        })
+        .Case<async::ExecuteOp>([&](auto executeOp) {
+          // Set `op` to the beginning of the region and add asyncTokens to the
+          // async.execute operands.
+          op = &executeOp.getBody()->front();
+          executeOp.operandsMutable().append(asyncTokens);
+          SmallVector<Type, 1> tokenTypes(
+              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
+          copy(executeOp.getBody()->addArguments(tokenTypes),
+               std::back_inserter(tokens));
+        });
+
+    // Find terminator or op with side-effects after `op`.
+    op = &*std::find_if(Block::iterator(op), Block::iterator(),
+                        isTerminatorOrHasSideEffects);
+
+    // If `op` implements the AsyncOpInterface, add `token` to the list of async
+    // dependencies.
+    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op)) {
+      for (auto token : tokens)
+        asyncOp.addAsyncDependency(token);
+      return;
+    }
+
+    // Otherwise, insert a gpu.wait before 'op'.
+    builder.setInsertionPoint(op);
+    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
+
+    // If the new waitOp is at the end of an async.execute region, add it to the
+    // worklist.
+    auto executeOp = dyn_cast<async::ExecuteOp>(op->getParentOp());
+    if (executeOp && areUsersSupported(executeOp) && isTerminator(*op))
+      worklist.push_back(waitOp);
+  }
+
+  SmallVector<gpu::WaitOp, 8> worklist;
+};
+
 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
 // execution semantics and that no GPU ops are asynchronous yet.
 void GpuAsyncRegionPass::runOnFunction() {
   if (getFunction()
           .getRegion()
-          .walk(Callback{OpBuilder(&getContext())})
+          .walk(ThreadTokenCallback(getContext()))
           .wasInterrupted())
     return signalPassFailure();
+
+  // Collect gpu.wait ops that we can move out of gpu.execute regions.
+  getFunction().getRegion().walk(DeferWaitCallback());
 }
 
 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() {
diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir
--- a/mlir/test/Dialect/GPU/async-region.mlir
+++ b/mlir/test/Dialect/GPU/async-region.mlir
@@ -9,8 +9,8 @@
 
   func @foo() -> ()
 
-  // CHECK-LABEL:func @async(%{{.*}}: index)
-  func @async(%sz : index) {
+  // CHECK-LABEL:func @thread_token(%{{.*}}: index)
+  func @thread_token(%sz : index) {
     // CHECK: %[[t0:.*]] = gpu.wait async
     // CHECK: %[[t1:.*]] = gpu.launch_func async [%[[t0]]]
     gpu.launch_func @kernels::@kernel
@@ -24,4 +24,78 @@
     return
   }
 
+  // CHECK-LABEL:func @defer_wait(%{{.*}}: index)
+  func @defer_wait(%sz : index) {
+    // CHECK: %[[a0:.*]], %[[f0:.*]] = async.execute
+    %a0 = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: %[[a1:.*]], %[[f1:.*]] = async.execute
+    // CHECK-SAME: %[[f0]]
+    %a1 = async.execute [%a0] {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: async.await %[[a1]]
+    // CHECK: %[[t:.*]] = async.await %[[f1]]
+    // CHECK: gpu.wait [%[[t]]]
+    async.await %a1 : !async.token
+    return
+  }
+
+  // CHECK-LABEL:func @defer_wait_blocked_by_side_effect(%{{.*}}: index)
+  func @defer_wait_blocked_by_side_effect(%sz : index) {
+    // CHECK: %[[a:.*]] = async.execute
+    %a = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK: gpu.wait [%[[t]]]
+      call @foo() : () -> ()
+      async.yield
+    }
+
+    // CHECK: async.await %[[a]]
+    // CHECK-NOT: gpu.wait
+    async.await %a : !async.token
+    return
+  }
+
+  // CHECK-LABEL:func @defer_wait_pass_through(%{{.*}}: index)
+  func @defer_wait_pass_through(%sz : index) {
+    // CHECK: %[[a0:.*]], %[[f0:.*]] = async.execute
+    %a0 = async.execute {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %[[t]]
+      async.yield
+    }
+
+    // CHECK: %[[a1:.*]], %[[f1:.*]] = async.execute
+    // CHECK-SAME: %[[f0]]
+    %a1 = async.execute [%a0] {
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield %{{.*}}
+      async.yield
+    }
+
+    // CHECK: async.await %[[a1]]
+    // CHECK: %[[t:.*]] = async.await %[[f1]]
+    // CHECK: gpu.wait [%[[t]]]
+    async.await %a1 : !async.token
+    return
+  }
 }