diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -18,9 +18,15 @@
 namespace mlir {
 std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
 
-/// Collect a set of patterns to rewrite ops within the GPU dialect.
-void populateGpuRewritePatterns(MLIRContext *context,
-                                OwningRewritePatternList &patterns);
+/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
+void populateGpuAllReducePatterns(MLIRContext *context,
+                                  OwningRewritePatternList &patterns);
+
+/// Collect all patterns to rewrite ops within the GPU dialect.
+inline void populateGpuRewritePatterns(MLIRContext *context,
+                                       OwningRewritePatternList &patterns) {
+  populateGpuAllReducePatterns(context, patterns);
+}
 
 //===----------------------------------------------------------------------===//
 // Registration
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -397,7 +397,7 @@
 };
 } // namespace
 
-void mlir::populateGpuRewritePatterns(MLIRContext *context,
-                                      OwningRewritePatternList &patterns) {
+void mlir::populateGpuAllReducePatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns) {
   patterns.insert<GpuAllReduceConversion>(context);
 }
diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
--- a/mlir/test/Dialect/GPU/all-reduce-max.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
--- a/mlir/test/Dialect/GPU/all-reduce.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
-  TestAllReduceLowering.cpp
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
   TestExpandTanh.cpp
@@ -15,6 +14,7 @@
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
+  TestGpuRewrite.cpp
   TestInlining.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
rename from mlir/test/lib/Transforms/TestAllReduceLowering.cpp
rename to mlir/test/lib/Transforms/TestGpuRewrite.cpp
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -18,8 +18,8 @@
 using namespace mlir;
 
 namespace {
-struct TestAllReduceLoweringPass
-    : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+struct TestGpuRewritePass
+    : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<StandardOpsDialect>();
   }
@@ -33,8 +33,8 @@
 
 namespace mlir {
 void registerTestAllReduceLoweringPass() {
-  PassRegistration<TestAllReduceLoweringPass> pass(
-      "test-all-reduce-lowering",
-      "Lowers gpu.all-reduce ops within the GPU dialect.");
+  PassRegistration<TestGpuRewritePass> pass(
+      "test-gpu-rewrite",
+      "Applies all rewrite patterns within the GPU dialect.");
 }
 } // namespace mlir