diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,16 +17,15 @@
 namespace mlir {
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<Pass> createAsyncParallelForPass();
 
-std::unique_ptr<OperationPass<FuncOp>>
-createAsyncParallelForPass(int numWorkerThreads);
+std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
 
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
 
-std::unique_ptr<OperationPass<FuncOp>> createAsyncRuntimeRefCountingOptPass();
+std::unique_ptr<Pass> createAsyncRuntimeRefCountingOptPass();
 
 //===----------------------------------------------------------------------===//
 // Registration
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -11,7 +11,7 @@
 include "mlir/Pass/PassBase.td"
 
-def AsyncParallelFor : FunctionPass<"async-parallel-for"> {
+def AsyncParallelFor : Pass<"async-parallel-for"> {
   let summary = "Convert scf.parallel operations to multiple async regions "
                 "executed concurrently for non-overlapping iteration ranges";
   let constructor = "mlir::createAsyncParallelForPass()";
@@ -31,7 +31,7 @@
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncRuntimeRefCounting : FunctionPass<"async-runtime-ref-counting"> {
+def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
   let summary = "Automatic reference counting for Async runtime operations";
   let description = [{
     This pass works at the async runtime abtraction level, after all
@@ -48,8 +48,7 @@
   let dependentDialects = ["async::AsyncDialect"];
 }
 
-def AsyncRuntimeRefCountingOpt :
-    FunctionPass<"async-runtime-ref-counting-opt"> {
+def AsyncRuntimeRefCountingOpt : Pass<"async-runtime-ref-counting-opt"> {
   let summary = "Optimize automatic reference counting operations for the"
                 "Async runtime by removing redundant operations";
   let constructor = "mlir::createAsyncRuntimeRefCountingOptPass()";
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -100,7 +100,7 @@
     assert(numWorkerThreads >= 1);
     numConcurrentAsyncExecute = numWorkerThreads;
   }
-  void runOnFunction() override;
+  void runOnOperation() override;
 };
 
 } // namespace
@@ -267,21 +267,20 @@
   return success();
 }
 
-void AsyncParallelForPass::runOnFunction() {
+void AsyncParallelForPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
 
   RewritePatternSet patterns(ctx);
   patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
 
-  if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
+  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncParallelForPass(int numWorkerThreads) {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -32,7 +32,7 @@
     : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
 public:
   AsyncRuntimeRefCountingPass() = default;
-  void runOnFunction() override;
+  void runOnOperation() override;
 
 private:
   /// Adds an automatic reference counting to the `value`.
@@ -323,13 +323,13 @@
   return success();
 }
 
-void AsyncRuntimeRefCountingPass::runOnFunction() {
-  FuncOp func = getFunction();
+void AsyncRuntimeRefCountingPass::runOnOperation() {
+  Operation *op = getOperation();
 
   // Check that we do not have high level async operations in the IR because
   // otherwise automatic reference counting will produce incorrect results after
   // execute operations will be lowered to `async.runtime`
-  WalkResult executeOpWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult executeOpWalk = op->walk([&](Operation *op) -> WalkResult {
     if (!isa<ExecuteOp>(op))
       return WalkResult::advance();
@@ -343,7 +343,7 @@
   }
 
   // Add reference counting to block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
     for (BlockArgument arg : block->getArguments())
       if (isRefCounted(arg.getType()))
         if (failed(addAutomaticRefCounting(arg)))
@@ -358,7 +358,7 @@
   }
 
   // Add reference counting to operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
     for (unsigned i = 0; i < op->getNumResults(); ++i)
       if (isRefCounted(op->getResultTypes()[i]))
         if (failed(addAutomaticRefCounting(op->getResult(i))))
@@ -371,7 +371,6 @@
     signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
   return std::make_unique<AsyncRuntimeRefCountingPass>();
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp
@@ -26,7 +26,7 @@
     : public AsyncRuntimeRefCountingOptBase<AsyncRuntimeRefCountingOptPass> {
 public:
   AsyncRuntimeRefCountingOptPass() = default;
-  void runOnFunction() override;
+  void runOnOperation() override;
 
 private:
   LogicalResult optimizeReferenceCounting(
@@ -124,8 +124,8 @@
   return success();
 }
 
-void AsyncRuntimeRefCountingOptPass::runOnFunction() {
-  FuncOp func = getFunction();
+void AsyncRuntimeRefCountingOptPass::runOnOperation() {
+  Operation *op = getOperation();
 
   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
   //
@@ -134,7 +134,7 @@
   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
 
   // Optimize reference counting for values defined by block arguments.
-  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
     for (BlockArgument arg : block->getArguments())
       if (isRefCounted(arg.getType()))
         if (failed(optimizeReferenceCounting(arg, cancellable)))
@@ -147,7 +147,7 @@
     signalPassFailure();
 
   // Optimize reference counting for values defined by operation results.
-  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+  WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
     for (unsigned i = 0; i < op->getNumResults(); ++i)
       if (isRefCounted(op->getResultTypes()[i]))
         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
@@ -171,7 +171,6 @@
   }
 }
 
-std::unique_ptr<OperationPass<FuncOp>>
-mlir::createAsyncRuntimeRefCountingOptPass() {
+std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
 }
diff --git a/mlir/test/Integration/GPU/CUDA/async.mlir b/mlir/test/Integration/GPU/CUDA/async.mlir
--- a/mlir/test/Integration/GPU/CUDA/async.mlir
+++ b/mlir/test/Integration/GPU/CUDA/async.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s \
 // RUN:   -gpu-kernel-outlining \
 // RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin)' \
-// RUN:   -gpu-async-region -async-ref-counting -gpu-to-llvm \
-// RUN:   -async-to-async-runtime -convert-async-to-llvm -convert-std-to-llvm \
+// RUN:   -gpu-async-region -gpu-to-llvm \
+// RUN:   -async-to-async-runtime -async-runtime-ref-counting \
+// RUN:   -convert-async-to-llvm -convert-std-to-llvm \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
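Not part of the patch: a minimal C++ sketch of how the now op-agnostic pass constructors could be assembled into a pipeline programmatically, mirroring the updated RUN line above. The helper name `buildAsyncPipeline`, the pass ordering, and the worker-thread count are illustrative assumptions, not code from this change.

```cpp
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Hypothetical pipeline assembly. Because the constructors now return
// std::unique_ptr<Pass> instead of std::unique_ptr<OperationPass<FuncOp>>,
// the passes are no longer pinned to FuncOp and can be added directly to a
// module-level PassManager.
static void buildAsyncPipeline(PassManager &pm) {
  pm.addPass(createAsyncParallelForPass(/*numWorkerThreads=*/4));
  pm.addPass(createAsyncToAsyncRuntimePass());
  pm.addPass(createAsyncRuntimeRefCountingPass());
  pm.addPass(createAsyncRuntimeRefCountingOptPass());
}
```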