diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -622,18 +622,34 @@ } ``` -We can include the generated registration calls via: +Using the `gen-pass-decls` generator, we can generate much of the +boilerplate above automatically. This generator takes as input a `-name` +parameter that provides a tag for the group of passes that are being generated. +This generator produces two chunks of output: + +The first is the code for registering the declarative passes with the global +registry. For each pass, the generator produces a `registerFooPass` where `Foo` +is the name of the definition specified in tablegen. It also generates a +`registerGroupPasses`, where `Group` is the tag provided via the `-name` input +parameter, that registers all of the passes present. ```c++ -void registerMyPasses() { - // The generated registration is not static, so we need to include this in - // a location that we can call into. #define GEN_PASS_REGISTRATION #include "Passes.h.inc" + +void registerMyPasses() { + // Register all of our passes. + registerMyPasses(); + + // Register `MyPass` specifically. + registerMyPassPass(); } ``` -We can then update the original C++ pass definition: +The second is a base class for each of the passes, with each containing most of +the boilerplate related to pass definition. These classes are named in the form +of `MyPassBase`, where `MyPass` is the name of the definition in tablegen. We +can update the original C++ pass definition as follows: ```c++ /// Include the generated base pass class definitions. @@ -651,6 +667,10 @@ } ``` +Using the `gen-pass-doc` generator, we can generate markdown documentation for +each of our passes. See [Passes.md](Passes.md) for example output of real MLIR +passes. + ### Tablegen Specification The `Pass` class is used to begin a new pass definition. 
This class takes as an diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h --- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h +++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ -#define MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ +#ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ +#define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ #include @@ -26,4 +26,4 @@ } // namespace mlir -#endif // MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ +#endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h --- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -20,6 +20,7 @@ struct LogicalResult; class MLIRContext; class OpBuilder; +class Pass; class RewritePattern; class Value; class ValueRange; @@ -57,6 +58,12 @@ /// Emit code that computes the upper bound of the given affine loop using /// standard arithmetic operations. Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder); + +/// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) +/// to equivalent lower-level constructs (flow of basic blocks and arithmetic +/// primitives). 
+std::unique_ptr createLowerAffinePass(); + } // namespace mlir #endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt --- a/mlir/include/mlir/Conversion/CMakeLists.txt +++ b/mlir/include/mlir/Conversion/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) add_public_tablegen_target(MLIRConversionPassIncGen) add_mlir_doc(Passes -gen-pass-doc ConversionPasses ./) diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/Passes.h @@ -0,0 +1,41 @@ +//===- Passes.h - Conversion Pass Construction and Registration -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_PASSES_H +#define MLIR_CONVERSION_PASSES_H + +#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" +#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" +#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" +#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" +#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" +#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" + +namespace mlir { + +/// Generate the code for registering conversion passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // MLIR_CONVERSION_PASSES_H diff --git a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(IR) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Affine) add_public_tablegen_target(MLIRAffinePassIncGen) add_mlir_doc(Passes -gen-pass-doc AffinePasses ./) diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -14,17 +14,12 @@ #ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_PASSES_H #define MLIR_DIALECT_AFFINE_TRANSFORMS_PASSES_H -#include "mlir/Support/LLVM.h" -#include +#include "mlir/Pass/Pass.h" #include namespace mlir { class AffineForOp; -class FuncOp; -class ModuleOp; -class Pass; -template class OperationPass; /// Creates a simplification pass for affine structures (maps and sets). In /// addition, this pass also normalizes memrefs to have the trivial (identity) @@ -79,6 +74,14 @@ /// Overload relying on pass options for initialization. std::unique_ptr> createSuperVectorizePass(); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Affine/Passes.h.inc" + } // end namespace mlir #endif // MLIR_DIALECT_AFFINE_RANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt --- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt @@ -12,7 +12,7 @@ add_public_tablegen_target(MLIRParallelLoopMapperEnumsGen) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name GPU) add_public_tablegen_target(MLIRGPUPassIncGen) add_mlir_doc(Passes -gen-pass-doc GPUPasses ./) diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h --- a/mlir/include/mlir/Dialect/GPU/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Passes.h @@ -13,21 +13,23 @@ #ifndef MLIR_DIALECT_GPU_PASSES_H_ #define MLIR_DIALECT_GPU_PASSES_H_ -#include +#include "mlir/Pass/Pass.h" namespace mlir { - -class MLIRContext; -class ModuleOp; -template class OperationPass; -class OwningRewritePatternList; - std::unique_ptr> createGpuKernelOutliningPass(); /// Collect a set of patterns to rewrite ops within the GPU dialect. void populateGpuRewritePatterns(MLIRContext *context, OwningRewritePatternList &patterns); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/GPU/Passes.h.inc" + } // namespace mlir #endif // MLIR_DIALECT_GPU_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVM) add_public_tablegen_target(MLIRLLVMPassIncGen) add_mlir_doc(Passes -gen-pass-doc LLVMPasses ./) diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -0,0 +1,26 @@ +//===- Passes.h - LLVM Pass Construction and Registration -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H + +#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace LLVM { + +/// Generate the code for registering LLVM passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(IR) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg) add_public_tablegen_target(MLIRLinalgPassIncGen) add_mlir_doc(Passes -gen-pass-doc LinalgPasses ./) diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -13,17 +13,9 @@ #ifndef MLIR_DIALECT_LINALG_PASSES_H_ #define MLIR_DIALECT_LINALG_PASSES_H_ -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/ArrayRef.h" +#include "mlir/Pass/Pass.h" namespace mlir { -class FuncOp; -class MLIRContext; -class ModuleOp; -template class OperationPass; -class OwningRewritePatternList; -class Pass; - std::unique_ptr> createLinalgFoldUnitExtentDimsPass(); std::unique_ptr> createLinalgFusionPass(); @@ -66,6 +58,14 @@ void populateLinalgFoldUnitExtentDimsPatterns( MLIRContext *context, OwningRewritePatternList &patterns); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Linalg/Passes.h.inc" + } // namespace mlir #endif // MLIR_DIALECT_LINALG_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_doc(QuantOps -gen-dialect-doc QuantDialect Dialects/) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant) add_public_tablegen_target(MLIRQuantPassIncGen) add_mlir_doc(Passes -gen-pass-doc QuantPasses ./) diff --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h --- a/mlir/include/mlir/Dialect/Quant/Passes.h +++ b/mlir/include/mlir/Dialect/Quant/Passes.h @@ -16,12 +16,9 @@ #ifndef MLIR_DIALECT_QUANT_PASSES_H #define MLIR_DIALECT_QUANT_PASSES_H -#include +#include "mlir/Pass/Pass.h" namespace mlir { -class FuncOp; -template class OperationPass; - namespace quant { /// Creates a pass that converts quantization simulation operations (i.e. @@ -35,6 +32,14 @@ /// destructive and cannot be undone. std::unique_ptr> createConvertConstPass(); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Quant/Passes.h.inc" + } // namespace quant } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_doc(SCFOps -gen-dialect-doc SCFDialect Dialects/) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name SCF) add_public_tablegen_target(MLIRSCFPassIncGen) add_dependencies(mlir-headers MLIRSCFPassIncGen) diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -13,13 +13,10 @@ #ifndef MLIR_DIALECT_SCF_PASSES_H_ #define MLIR_DIALECT_SCF_PASSES_H_ -#include "llvm/ADT/ArrayRef.h" -#include +#include "mlir/Pass/Pass.h" namespace mlir { -class Pass; - /// Creates a pass that specializes for loop for unrolling and /// vectorization. std::unique_ptr createForLoopSpecializationPass(); @@ -35,6 +32,14 @@ std::unique_ptr createParallelLoopTilingPass(llvm::ArrayRef tileSize = {}); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/SCF/Passes.h.inc" + } // namespace mlir #endif // MLIR_DIALECT_SCF_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -38,7 +38,7 @@ add_dependencies(mlir-headers MLIRSPIRVTargetAndABIIncGen) set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name SPIRV) add_public_tablegen_target(MLIRSPIRVPassIncGen) add_dependencies(mlir-headers MLIRSPIRVPassIncGen) diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -50,6 +50,14 @@ /// spv.CompositeInsert into spv.CompositeConstruct. std::unique_ptr> createRewriteInsertsPass(); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/SPIRV/Passes.h.inc" + } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shape) add_public_tablegen_target(MLIRShapeTransformsIncGen) add_mlir_doc(Passes -gen-pass-doc ShapePasses ./) diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -14,15 +14,9 @@ #ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ -#include +#include "mlir/Pass/Pass.h" namespace mlir { - -class FunctionPass; -class MLIRContext; -class OwningRewritePatternList; -class Pass; - /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape /// dialect to be convertible to Standard. For example, `shape.num_elements` get /// transformed to `shape.reduce`, which can be lowered to SCF and Standard. @@ -42,6 +36,14 @@ MLIRContext *ctx); std::unique_ptr createRemoveShapeConstraintsPass(); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Shape/Transforms/Passes.h.inc" + } // end namespace mlir #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Standard) add_public_tablegen_target(MLIRStandardTransformsIncGen) add_mlir_doc(Passes -gen-pass-doc StandardPasses ./) diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -15,12 +15,10 @@ #ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ -#include +#include "mlir/Pass/Pass.h" namespace mlir { -class Pass; -class MLIRContext; class OwningRewritePatternList; /// Creates an instance of the ExpandAtomic pass. @@ -29,6 +27,14 @@ void populateExpandTanhPattern(OwningRewritePatternList &patterns, MLIRContext *ctx); +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc" + } // end namespace mlir #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -14,38 +14,17 @@ #ifndef MLIR_INITALLPASSES_H_ #define MLIR_INITALLPASSES_H_ -#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" -#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" -#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" -#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" -#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h" -#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" -#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" -#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h" -#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" -#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/GPU/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" 
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/Transforms/LocationSnapshot.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/ViewOpGraph.h" -#include "mlir/Transforms/ViewRegionGraph.h" #include @@ -60,48 +39,37 @@ // The global registry is interesting to interact with the command-line tools. inline void registerAllPasses() { // Init general passes -#define GEN_PASS_REGISTRATION -#include "mlir/Transforms/Passes.h.inc" + registerTransformsPasses(); // Conversion passes -#define GEN_PASS_REGISTRATION -#include "mlir/Conversion/Passes.h.inc" + registerConversionPasses(); // Affine -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Affine/Passes.h.inc" + registerAffinePasses(); // GPU -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/GPU/Passes.h.inc" + registerGPUPasses(); // Linalg -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Linalg/Passes.h.inc" + registerLinalgPasses(); // LLVM -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" - - // Loop -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/SCF/Passes.h.inc" + LLVM::registerLLVMPasses(); // Quant -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Quant/Passes.h.inc" + quant::registerQuantPasses(); + + // SCF + registerSCFPasses(); // SPIR-V -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/SPIRV/Passes.h.inc" + spirv::registerSPIRVPasses(); // Standard -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc" + registerStandardPasses(); // Shape -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Shape/Transforms/Passes.h.inc" + registerShapePasses(); } } // namespace mlir diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) 
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transforms) add_public_tablegen_target(MLIRTransformsPassIncGen) add_mlir_doc(Passes -gen-pass-doc GeneralPasses ./) diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -14,17 +14,19 @@ #ifndef MLIR_TRANSFORMS_PASSES_H #define MLIR_TRANSFORMS_PASSES_H -#include "mlir/Support/LLVM.h" -#include +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/ViewOpGraph.h" +#include "mlir/Transforms/ViewRegionGraph.h" #include namespace mlir { class AffineForOp; -class FuncOp; -class ModuleOp; -class Pass; -template class OperationPass; + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// /// Creates an instance of the BufferPlacement pass. std::unique_ptr createBufferPlacementPass(); @@ -89,6 +91,15 @@ /// Creates a pass which delete symbol operations that are unreachable. This /// pass may *only* be scheduled on an operation that defines a SymbolTable. std::unique_ptr createSymbolDCEPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. 
+#define GEN_PASS_REGISTRATION +#include "mlir/Transforms/Passes.h.inc" + } // end namespace mlir #endif // MLIR_TRANSFORMS_PASSES_H diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -14,6 +14,7 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Pass.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -21,6 +22,11 @@ using namespace mlir; using namespace mlir::tblgen; +static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); +static llvm::cl::opt + groupName("name", llvm::cl::desc("The name of this group of passes"), + llvm::cl::cat(passGenCat)); + //===----------------------------------------------------------------------===// // GEN: Pass base class generation //===----------------------------------------------------------------------===// @@ -109,36 +115,49 @@ // GEN: Pass registration generation //===----------------------------------------------------------------------===// +/// The code snippet used to generate the registration function of a pass. +/// +/// {0}: The def name of the pass record. +/// {1}: The argument of the pass. +/// {2}: The summary of the pass. +/// {3}: The code for constructing the pass. +const char *const passRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Pass() {{ + ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{ + return {3}; + }); +} +)"; + +/// {0}: The name of the pass group. 
+const char *const passGroupRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Passes() {{ +)"; + /// Emit the code for registering each of the given passes with the global /// PassRegistry. static void emitRegistration(ArrayRef passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { - os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - } - os << "#endif // GEN_PASS_REGISTRATION\n"; - - for (const Pass &pass : passes) { - os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> " - "std::unique_ptr<::mlir::Pass> {{ return {2}; });\n", + os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), pass.getArgument(), pass.getSummary(), pass.getConstructor()); - os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); } - os << "#ifdef GEN_PASS_REGISTRATION\n"; - for (const Pass &pass : passes) { - os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n", - pass.getDef()->getName()); - } - os << "#endif // GEN_PASS_REGISTRATION\n"; + os << llvm::formatv(passGroupRegistrationCode, groupName); + for (const Pass &pass : passes) + os << " register" << pass.getDef()->getName() << "Pass();\n"; + os << "}\n"; os << "#undef GEN_PASS_REGISTRATION\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; } //===----------------------------------------------------------------------===//