diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -85,18 +85,19 @@ asyncOp.addAsyncDependency(currentToken); // Clone the op to return a token in addition to the other results. - SmallVector resultTypes = {tokenType}; + SmallVector resultTypes; resultTypes.reserve(1 + op->getNumResults()); copy(op->getResultTypes(), std::back_inserter(resultTypes)); + resultTypes.push_back(tokenType); auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, op->getOperands(), op->getMutableAttrDict(), op->getSuccessors()); // Replace the op with the async clone. auto results = newOp->getResults(); - currentToken = results.front(); + currentToken = results.back(); builder.insert(newOp); - op->replaceAllUsesWith(results.drop_front()); + op->replaceAllUsesWith(results.drop_back()); op->erase(); return success(); @@ -165,7 +166,13 @@ // Construct new result type list with `count` additional types. SmallVector resultTypes; resultTypes.reserve(numResults); - copy(executeOp.getResultTypes(), std::back_inserter(resultTypes)); + transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), + [](Type type) { + // Extract value type from !async.value. + if (auto value_type = type.dyn_cast()) + return value_type.getValueType(); + return type; // This is an !async.token type. + }); OpBuilder builder(executeOp); auto tokenType = builder.getType(); resultTypes.resize(numResults, tokenType); @@ -266,7 +273,7 @@ .wasInterrupted()) return signalPassFailure(); - // Collect gpu.wait ops that we can move out of gpu.execute regions. + // Collect gpu.wait ops that we can move out of async.execute regions. getFunction().getRegion().walk(DeferWaitCallback()); }