// Patch summary (text before the first `diff --git` header is not part of any hunk).
//
// mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
//   * Hunk at -85/+85: when an op is cloned to additionally return a token,
//     the token type is now appended AFTER the op's own result types instead
//     of being placed first. The follow-up code is adjusted to match: the
//     current token is taken from results.back() and the original op's uses
//     are replaced with results.drop_back().
//   * Hunk at -165/+166: the result types taken from executeOp are no longer
//     copied verbatim. Each type goes through a lambda that returns
//     getValueType() when the dyn_cast succeeds (the inline comment describes
//     this as extracting the value type from !async.value); any other type is
//     asserted to be a token type and is kept unchanged.
//   * Hunk at -266/+274: comment-only fix, "gpu.execute" -> "async.execute".
//
// mlir/test/Dialect/GPU/async-region.mlir
//   * Adds a gpu.alloc / gpu.dealloc pair to the test body. The CHECK lines
//     expect gpu.alloc to yield the memref first and the token second
//     (%[[m]], %[[t3]]), gpu.dealloc to depend on %[[t3]], and the final
//     gpu.wait to depend on %[[t4]] rather than %[[t2]].
//
// NOTE(review): in this copy of the patch the original line breaks have been
// collapsed, and template argument lists appear to have been stripped (e.g.
// `SmallVector resultTypes`, `type.dyn_cast()`, `type.isa()`,
// `builder.getType()`). Confirm against the upstream commit before applying.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -85,18 +85,19 @@ asyncOp.addAsyncDependency(currentToken); // Clone the op to return a token in addition to the other results. - SmallVector resultTypes = {tokenType}; + SmallVector resultTypes; resultTypes.reserve(1 + op->getNumResults()); copy(op->getResultTypes(), std::back_inserter(resultTypes)); + resultTypes.push_back(tokenType); auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, op->getOperands(), op->getMutableAttrDict(), op->getSuccessors()); // Replace the op with the async clone. auto results = newOp->getResults(); - currentToken = results.front(); + currentToken = results.back(); builder.insert(newOp); - op->replaceAllUsesWith(results.drop_front()); + op->replaceAllUsesWith(results.drop_back()); op->erase(); return success(); @@ -165,7 +166,14 @@ // Construct new result type list with `count` additional types. SmallVector resultTypes; resultTypes.reserve(numResults); - copy(executeOp.getResultTypes(), std::back_inserter(resultTypes)); + transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), + [](Type type) { + // Extract value type from !async.value. + if (auto valueType = type.dyn_cast()) + return valueType.getValueType(); + assert(type.isa() && "expected token type"); + return type; + }); OpBuilder builder(executeOp); auto tokenType = builder.getType(); resultTypes.resize(numResults, tokenType); @@ -266,7 +274,7 @@ .wasInterrupted()) return signalPassFailure(); - // Collect gpu.wait ops that we can move out of gpu.execute regions. + // Collect gpu.wait ops that we can move out of async.execute regions. 
getFunction().getRegion().walk(DeferWaitCallback()); } diff --git a/mlir/test/Dialect/GPU/async-region.mlir b/mlir/test/Dialect/GPU/async-region.mlir --- a/mlir/test/Dialect/GPU/async-region.mlir +++ b/mlir/test/Dialect/GPU/async-region.mlir @@ -18,7 +18,11 @@ // CHECK: %[[t2:.*]] = gpu.launch_func async [%[[t1]]] gpu.launch_func @kernels::@kernel blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) - // CHECK: gpu.wait [%[[t2]]] + // CHECK: %[[m:.*]], %[[t3:.*]] = gpu.alloc async [%[[t2]]] () + %0 = gpu.alloc() : memref<7xf32> + // CHECK: %[[t4:.*]] = gpu.dealloc async [%[[t3]]] %[[m]] + gpu.dealloc %0 : memref<7xf32> + // CHECK: gpu.wait [%[[t4]]] // CHECK: call @foo call @foo() : () -> () return