diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -97,12 +97,6 @@ const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> createLinalgStrategyDecomposePass( - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyRemoveMarkersPass. std::unique_ptr> createLinalgStrategyRemoveMarkersPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -200,18 +200,6 @@ ]; } -// TODO: if/when we need finer control add an anchorOp option. -def LinalgStrategyDecomposePass - : Pass<"linalg-strategy-decompose-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based generalization."; - let constructor = "mlir::createLinalgStrategyDecomposePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - def LinalgStrategyRemoveMarkersPass : Pass<"linalg-strategy-remove-markers-pass", "func::FuncOp"> { let summary = "Cleanup pass that drops markers."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -81,17 +81,6 @@ linalg::LinalgPaddingOptions options; }; -/// Represent one application of createLinalgStrategyDecomposePass. -struct Decompose : public Transformation { - explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyDecomposePass(m)); - } -}; - /// Codegen strategy controls how a Linalg op is progressively lowered. struct CodegenStrategy { /// Append a pattern to tile the Op `opName` and fuse its producers with @@ -142,17 +131,6 @@ LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? pad(opName, std::move(options), std::move(f)) : *this; } - /// Append patterns to decompose convolutions. - CodegenStrategy & - decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back(std::make_unique(f)); - return *this; - } - /// Conditionally append patterns to decompose convolutions. - CodegenStrategy & - decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? decompose(std::move(f)) : *this; - } /// Configure the post staged-patterns global enabling passes options. CodegenStrategy & setVectorTransferToSCFOptions(LinalgEnablingOptions options) { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -764,11 +764,7 @@ template struct DownscaleSizeOneWindowed2DConvolution final : public OpRewritePattern { - DownscaleSizeOneWindowed2DConvolution( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), filter(std::move(f)) {} + using OpRewritePattern::OpRewritePattern; FailureOr returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const; @@ -777,10 +773,6 @@ PatternRewriter &rewriter) const override { return returningMatchAndRewrite(convOp, rewriter); } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; }; extern template struct DownscaleSizeOneWindowed2DConvolution { - DownscaleDepthwiseConv2DNhwcHwcOp( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - filter(std::move(f)) {} + DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} FailureOr returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, @@ -807,10 +796,6 @@ PatternRewriter &rewriter) const override { return returningMatchAndRewrite(convOp, rewriter); } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; }; /// @@ -1007,10 +992,8 @@ /// Populates patterns to decompose high-D convolution ops into low-D ones. This /// is a step in progressive lowering for convolution ops, afterwards we can /// vectorize the low-D convolution ops. -void populateDecomposeConvolutionPatterns( - RewritePatternSet &patterns, - const LinalgTransformationFilter &filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1); +void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); //===----------------------------------------------------------------------===// // Op-specific patterns. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -155,31 +155,6 @@ LinalgTransformationFilter filter; }; -/// Configurable pass to apply lowering of coarser-grained named linalg ops into -/// finer-grained named versions. -struct LinalgStrategyDecomposePass - : public impl::LinalgStrategyDecomposePassBase< - LinalgStrategyDecomposePass> { - - LinalgStrategyDecomposePass() = default; - - LinalgStrategyDecomposePass(LinalgTransformationFilter filter) - : filter(std::move(filter)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - RewritePatternSet decompositionPattern(funcOp.getContext()); - populateDecomposeConvolutionPatterns(decompositionPattern, filter); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(decompositionPattern)))) - signalPassFailure(); - } - - LinalgTransformationFilter filter; -}; - /// Configurable pass to lower vector operations. struct LinalgStrategyRemoveMarkersPass : public impl::LinalgStrategyRemoveMarkersPassBase< @@ -221,14 +196,6 @@ return std::make_unique(opName, opt, filter); } -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> -mlir::createLinalgStrategyDecomposePass( - const LinalgTransformationFilter &filter) { - return std::make_unique(filter); -} - /// Create a LinalgStrategyRemoveMarkersPass. std::unique_ptr> mlir::createLinalgStrategyRemoveMarkersPass() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -782,8 +782,6 @@ template FailureOr DownscaleSizeOneWindowed2DConvolution:: returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); if (convOp.hasBufferSemantics()) return failure(); // To be implemented. @@ -867,7 +865,6 @@ rewriter, loc, conv1DOp.getResult(0), output); rewriter.replaceOp(convOp, inserted); - filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); return conv1DOp; } @@ -879,8 +876,6 @@ FailureOr DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); if (convOp.hasBufferSemantics()) return failure(); // To be implemented. @@ -943,17 +938,15 @@ rewriter, loc, conv1DOp.getResult(0), output); rewriter.replaceOp(convOp, inserted); - filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); return conv1DOp; } -void linalg::populateDecomposeConvolutionPatterns( - RewritePatternSet &patterns, const LinalgTransformationFilter &filter, - PatternBenefit benefit) { +void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { patterns.add, DownscaleSizeOneWindowed2DConvolution, - DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter, + DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), benefit); }