diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -67,7 +67,7 @@ /// buffers instead. std::unique_ptr> createLinalgBufferizePass(); -/// Create a pass to conver named Linalg operations to Linalg generic +/// Create a pass to convert named Linalg operations to Linalg generic /// operations. std::unique_ptr> createLinalgGeneralizationPass(); @@ -108,6 +108,11 @@ linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter()); +/// Create a LinalgStrategyDecomposePass (decomposes named Linalg ops). +// TODO: atm this is applied to all supported ops. If/when we need finer control +// this should be exposed with an opName + filter and a proper pattern. +std::unique_ptr> createLinalgStrategyDecomposePass(); + /// Create a LinalgStrategyInterchangePass. std::unique_ptr> createLinalgStrategyInterchangePass(ArrayRef iteratorInterchange = {}, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -287,6 +287,19 @@ ]; } +// TODO: atm this is applied to all supported ops. If/when we need finer control +// this should be exposed with an opName + filter and a proper pattern.
+def LinalgStrategyDecomposePass + : FunctionPass<"linalg-strategy-decompose-pass"> { + let summary = "Configurable pass to apply pattern-based decomposition."; + let constructor = "mlir::createLinalgStrategyDecomposePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + def LinalgStrategyInterchangePass : FunctionPass<"linalg-strategy-interchange-pass"> { let summary = "Configurable pass to apply pattern-based iterator interchange."; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -134,6 +134,25 @@ LinalgTransformationFilter filter; }; +/// Configurable pass to apply lowering of coarser-grained named linalg ops into +/// finer-grained named versions. +struct LinalgStrategyDecomposePass + : public LinalgStrategyDecomposePassBase { + + LinalgStrategyDecomposePass() = default; + + void runOnFunction() override { + auto funcOp = getFunction(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + RewritePatternSet decompositionPattern(funcOp.getContext()); + populateDecomposeConvolutionPatterns(decompositionPattern); + if (failed(applyPatternsAndFoldGreedily(funcOp, + std::move(decompositionPattern)))) + signalPassFailure(); + } +}; + /// Configurable pass to apply pattern-based linalg generalization. struct LinalgStrategyInterchangePass : public LinalgStrategyInterchangePassBase { @@ -386,6 +405,13 @@ LinalgTransformationFilter filter) { return std::make_unique(opName, filter); } +/// Create a LinalgStrategyDecomposePass (decomposes named Linalg ops). +// TODO: atm this is applied to all supported ops.
If/when we need finer control +// this should be exposed with an opName + filter and a proper pattern. +std::unique_ptr> +mlir::createLinalgStrategyDecomposePass() { + return std::make_unique(); +} /// Create a LinalgStrategyInterchangePass. std::unique_ptr> diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -14,13 +14,14 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" @@ -581,12 +582,6 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyDecomposeConvolutionPatterns(FuncOp funcOp) { - RewritePatternSet patterns(funcOp.getContext()); - populateDecomposeConvolutionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - static void applyPadTensorToGenericPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); @@ -830,8 +825,15 @@ return applyPadPattern(getFunction(), packPaddings, hoistPaddings); if (testInterchangePattern.hasValue()) return applyInterchangePattern(getFunction(), testInterchangePattern); - if (testDecomposeConvolutionPattern) - return applyDecomposeConvolutionPatterns(getFunction()); + + if (testDecomposeConvolutionPattern) { + // TODO: thread all tests through
LinalgStrategy passes. + // Run the decompose pass on this function via a nested dynamic pipeline. + OpPassManager dynamicPM(FuncOp::getOperationName()); + dynamicPM.addPass(createLinalgStrategyDecomposePass()); + if (failed(runPipeline(dynamicPM, getFunction()))) + return signalPassFailure(); + } } namespace mlir {