diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_LINALG_PASSES_H_ #define MLIR_DIALECT_LINALG_PASSES_H_ +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" @@ -77,6 +78,41 @@ /// Create a pass to tile a LinalgOp and fuse its producers. std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass(); +//===----------------------------------------------------------------------===// +/// Linalg strategy passes. +//===----------------------------------------------------------------------===// +/// Create a LinalgStrategyTilePass. +std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass( + linalg::LinalgTilingOptions opt = linalg::LinalgTilingOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyPromotePass. +std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyPromotePass( + linalg::LinalgPromotionOptions opt = linalg::LinalgPromotionOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyVectorizePass. +std::unique_ptr<OperationPass<FuncOp>> +createLinalgStrategyVectorizePass(linalg::LinalgVectorizationOptions opt = + linalg::LinalgVectorizationOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyEnablePass. +std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyEnablePass( + linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyLowerVectorsPass. 
+std::unique_ptr<OperationPass<FuncOp>> +createLinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt = + linalg::LinalgVectorLoweringOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -229,4 +229,66 @@ let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"]; } +def LinalgStrategyTilePass + : FunctionPass<"linalg-strategy-tile-pass"> { + let summary = "Configurable pass to apply pattern-based linalg tiling."; + let constructor = "mlir::createLinalgStrategyTilePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyPromotePass + : FunctionPass<"linalg-strategy-promote-pass"> { + let summary = "Configurable pass to apply pattern-based linalg promotion."; + let constructor = "mlir::createLinalgStrategyPromotePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyVectorizePass + : FunctionPass<"linalg-strategy-vectorize-pass"> { + let summary = "Configurable pass to apply pattern-based linalg vectorization."; + let constructor = 
"mlir::createLinalgStrategyVectorizePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyEnablePass + : FunctionPass<"linalg-strategy-enable-pass"> { + let summary = "Configurable pass to enable the application of other " + "pattern-based linalg passes."; + let constructor = "mlir::createLinalgStrategyEnablePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + +def LinalgStrategyLowerVectorsPass + : FunctionPass<"linalg-strategy-lower-vectors-pass"> { + let summary = "Configurable pass to lower vector operations."; + let constructor = "mlir::createLinalgStrategyLowerVectorsPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -10,7 +10,8 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Pass/PassManager.h" namespace mlir { @@ -21,69 +22,23 @@ /// Abstract Transformation class applied in a sequence that also handles state /// through markers.
struct Transformation { - explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f) + explicit Transformation(LinalgTransformationFilter::FilterFunction f) : filter(f) {} virtual ~Transformation() = default; - virtual RewritePatternSet - buildRewritePatterns(MLIRContext *context, - linalg::LinalgTransformationFilter m) = 0; - linalg::LinalgTransformationFilter::FilterFunction filter = nullptr; + virtual void addToPassPipeline(OpPassManager &pm, + LinalgTransformationFilter m) const = 0; + LinalgTransformationFilter::FilterFunction filter = nullptr; }; -/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`. -template