diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -85,18 +85,19 @@ asyncOp.addAsyncDependency(currentToken); // Clone the op to return a token in addition to the other results. - SmallVector resultTypes = {tokenType}; + SmallVector resultTypes; resultTypes.reserve(1 + op->getNumResults()); copy(op->getResultTypes(), std::back_inserter(resultTypes)); + resultTypes.push_back(tokenType); auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, op->getOperands(), op->getMutableAttrDict(), op->getSuccessors()); // Replace the op with the async clone. auto results = newOp->getResults(); - currentToken = results.front(); + currentToken = results.back(); builder.insert(newOp); - op->replaceAllUsesWith(results.drop_front()); + op->replaceAllUsesWith(results.drop_back()); op->erase(); return success(); @@ -165,7 +166,14 @@ // Construct new result type list with `count` additional types. SmallVector resultTypes; resultTypes.reserve(numResults); - copy(executeOp.getResultTypes(), std::back_inserter(resultTypes)); + transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), + [](Type type) { + // Extract value type from !async.value. + if (auto valueType = type.dyn_cast()) + return valueType.getValueType(); + assert(type.isa() && "expected token type"); + return type; + }); OpBuilder builder(executeOp); auto tokenType = builder.getType(); resultTypes.resize(numResults, tokenType); @@ -266,7 +274,7 @@ .wasInterrupted()) return signalPassFailure(); - // Collect gpu.wait ops that we can move out of gpu.execute regions. + // Collect gpu.wait ops that we can move out of async.execute regions. getFunction().getRegion().walk(DeferWaitCallback()); } diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir --- a/mlir/test/Dialect/GPU/async-region.mlir +++ b/mlir/test/Dialect/GPU/async-region.mlir @@ -18,7 +18,11 @@ // CHECK: %[[t2:.*]] = gpu.launch_func async [%[[t1]]] gpu.launch_func @kernels::@kernel blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) - // CHECK: gpu.wait [%[[t2]]] + // CHECK: %[[m:.*]], %[[t3:.*]] = gpu.alloc async [%[[t2]]] () + %0 = gpu.alloc() : memref<7xf32> + // CHECK: %[[t4:.*]] = gpu.dealloc async [%[[t3]]] %[[m]] + gpu.dealloc %0 : memref<7xf32> + // CHECK: gpu.wait [%[[t4]]] // CHECK: call @foo call @foo() : () -> () return @@ -98,4 +102,27 @@ async.await %a1 : !async.token return } + + // CHECK-LABEL:func @async_execute_with_result(%{{.*}}: index) + func @async_execute_with_result(%sz : index) -> index { + // CHECK: %[[a0:.*]], %[[f0:.*]]:2 = async.execute + // CHECK-SAME: -> (!async.value, !async.value) + %a0, %f0 = async.execute -> !async.value { + // CHECK: %[[t:.*]] = gpu.launch_func async + gpu.launch_func @kernels::@kernel + blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) + // CHECK-NOT: gpu.wait + // CHECK: async.yield {{.*}}, %[[t]] : index, !gpu.async.token + async.yield %sz : index + } + + // CHECK: async.await %[[a0]] : !async.token + // CHECK: %[[t:.*]] = async.await %[[f0]]#1 : !async.value + // CHECK: gpu.wait [%[[t]]] + async.await %a0 : !async.token + // CHECK: %[[x:.*]] = async.await %[[f0]]#0 : !async.value + %x = async.await %f0 : !async.value + // CHECK: return %[[x]] : index + return %x : index + } }