diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -18,7 +18,11 @@
 namespace mlir {
 std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
 
-/// Collect a set of patterns to rewrite ops within the GPU dialect.
+/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
+void populateGpuAllReducePatterns(MLIRContext *context,
+                                  OwningRewritePatternList &patterns);
+
+/// Collect all patterns to rewrite ops within the GPU dialect.
 void populateGpuRewritePatterns(MLIRContext *context,
                                 OwningRewritePatternList &patterns);
 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -4,6 +4,7 @@
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/Passes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -397,7 +397,7 @@
 };
 } // namespace
 
-void mlir::populateGpuRewritePatterns(MLIRContext *context,
-                                      OwningRewritePatternList &patterns) {
+void mlir::populateGpuAllReducePatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns) {
   patterns.insert<GpuAllReduceConversion>(context);
 }
diff --git a/mlir/lib/Dialect/GPU/Transforms/Passes.cpp b/mlir/lib/Dialect/GPU/Transforms/Passes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/Passes.cpp
@@ -0,0 +1,21 @@
+//===- Passes.cpp - GPU rewrite pattern collection ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines populateGpuRewritePatterns, which forwards to each GPU
+// dialect pattern-population function (currently only the all-reduce one).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Passes.h"
+
+using namespace mlir;
+
+void mlir::populateGpuRewritePatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns) {
+  populateGpuAllReducePatterns(context, patterns);
+}
diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
--- a/mlir/test/Dialect/GPU/all-reduce-max.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
--- a/mlir/test/Dialect/GPU/all-reduce.mlir
+++ b/mlir/test/Dialect/GPU/all-reduce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
-  TestAllReduceLowering.cpp
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
   TestExpandTanh.cpp
@@ -15,6 +14,7 @@
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
+  TestGpuRewrite.cpp
   TestInlining.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
rename from mlir/test/lib/Transforms/TestAllReduceLowering.cpp
rename to mlir/test/lib/Transforms/TestGpuRewrite.cpp
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -18,8 +18,8 @@
 using namespace mlir;
 
 namespace {
-struct TestAllReduceLoweringPass
-    : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+struct TestGpuRewritePass
+    : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<StandardOpsDialect>();
   }
@@ -33,8 +33,8 @@
 
 namespace mlir {
 void registerTestAllReduceLoweringPass() {
-  PassRegistration<TestAllReduceLoweringPass> pass(
-      "test-all-reduce-lowering",
-      "Lowers gpu.all-reduce ops within the GPU dialect.");
+  PassRegistration<TestGpuRewritePass> pass(
+      "test-gpu-rewrite",
+      "Applies all rewrite patterns within the GPU dialect.");
 }
 } // namespace mlir