diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -127,6 +127,13 @@
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
+/// Create a LinalgStrategyPeelPass; an empty `opName` targets all linalg ops.
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPeelPass(
+    StringRef opName = "",
+    linalg::LinalgPeelOptions opt = linalg::LinalgPeelOptions(),
+    const linalg::LinalgTransformationFilter &filter =
+        linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
     StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -272,6 +272,22 @@
   ];
 }
 
+def LinalgStrategyPeelPass
+    : Pass<"linalg-strategy-peel-pass", "func::FuncOp"> {
+  let summary = "Configurable pass to apply pattern-based linalg peeling.";
+  let constructor = "mlir::createLinalgStrategyPeelPass()";
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "scf::SCFDialect"
+  ];
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyVectorizePass
     : Pass<"linalg-strategy-vectorize-pass", "func::FuncOp"> {
   let summary = "Configurable pass to apply pattern-based linalg vectorization.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -141,6 +141,26 @@
   }
 };
 
+/// Represent one application of createLinalgStrategyPeelPass.
+struct Peel : public Transformation {
+  explicit Peel(linalg::LinalgPeelOptions options,
+                LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(std::move(f)), opName(), options(options) {}
+
+  Peel(StringRef name, linalg::LinalgPeelOptions options,
+       LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(std::move(f)), opName(name), options(options) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyPeelPass(opName, options, m));
+  }
+
+private:
+  std::string opName;
+  linalg::LinalgPeelOptions options;
+};
+
 /// Represent one application of createLinalgStrategyVectorizePass.
 struct Vectorize : public Transformation {
   explicit Vectorize(linalg::LinalgVectorizationOptions options,
@@ -288,6 +308,20 @@
   decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? decompose(std::move(f)) : *this;
   }
+  /// Append a pattern to peel the loops `options` selects for `opName`.
+  CodegenStrategy &
+  peel(StringRef opName, const LinalgPeelOptions &options,
+       const LinalgTransformationFilter::FilterFunction &f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<Peel>(opName, options, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to peel the loops `options` selects.
+  CodegenStrategy &
+  peelIf(bool b, StringRef opName, const LinalgPeelOptions &options,
+         LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? peel(opName, options, std::move(f)) : *this;
+  }
   /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
   CodegenStrategy &
   vectorize(StringRef opName,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -129,6 +129,9 @@
 FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                       const LinalgTilingOptions &options);
 
+/// Peel and canonicalize 'loops', in the order in which they are given.
+void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
+
 /// Peel the loops of a TiledLinalgOp.
 void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                        ArrayRef<int64_t> peeledLoops,
@@ -965,6 +968,51 @@
       : LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
 };
 
+///
+/// Linalg peeling patterns.
+///
+
+/// Callback computing the loops to peel for a given op; it appends them to the
+/// provided SmallVectorImpl. Loops are peeled in order of appearance in that
+/// vector, and this order impacts the output IR: with an inner-to-outer order,
+/// the peeled iterations of the outer loops also contain the peeled inner
+/// loops; with an outer-to-inner order, the peeled iterations of the outer
+/// loops do not contain any peeled inner loops.
+using LoopsToPeelComputationFunction = std::function<void(
+    OpBuilder &, Operation *, SmallVectorImpl<scf::ForOp> &)>;
+
+/// Options for LinalgPeelingPattern. When `loopsToPeelComputationFunction` is
+/// not set, the pattern peels no loop.
+struct LinalgPeelOptions {
+  LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr;
+};
+
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps matching `filter`.
+  LinalgPeelingPattern(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      LinalgPeelOptions options = LinalgPeelOptions(),
+      PatternBenefit benefit = 1);
+
+  /// Construct a pattern specifically applied to `opName`.
+  LinalgPeelingPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgPeelOptions options = LinalgPeelOptions(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  const LinalgTransformationFilter filter;
+  /// Peeling options.
+  const LinalgPeelOptions options;
+};
+
 ///
 /// Linalg vectorization patterns.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -262,6 +262,40 @@
   LinalgTransformationFilter filter;
 };
 
+/// Configurable pass to apply pattern-based linalg peeling.
+struct LinalgStrategyPeelPass
+    : public LinalgStrategyPeelPassBase<LinalgStrategyPeelPass> {
+
+  LinalgStrategyPeelPass() = default;
+
+  LinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
+                         LinalgTransformationFilter filt)
+      : options(std::move(opt)), filter(std::move(filt)) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet peelingPatterns(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      peelingPatterns.add<LinalgPeelingPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      peelingPatterns.add<LinalgPeelingPattern>(funcOp.getContext(), filter,
+                                                options);
+    }
+    if (failed(
+            applyPatternsAndFoldGreedily(funcOp, std::move(peelingPatterns))))
+      return signalPassFailure();
+  }
+
+  LinalgPeelOptions options;
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg vectorization.
 struct LinalgStrategyVectorizePass
     : public LinalgStrategyVectorizePassBase<LinalgStrategyVectorizePass> {
@@ -506,6 +540,14 @@
                                                          filter);
 }
 
+/// Create a LinalgStrategyPeelPass.
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::createLinalgStrategyPeelPass(StringRef opName, LinalgPeelOptions opt,
+                                   const LinalgTransformationFilter &filter) {
+  return std::make_unique<LinalgStrategyPeelPass>(opName, std::move(opt),
+                                                  filter);
+}
+
 /// Create a LinalgStrategyVectorizePass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyVectorizePass(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -323,6 +323,15 @@
       .Default([&](Operation *op) { return op->getResults(); });
 }
 
+/// Peel and canonicalize 'loops', in the order in which they are given.
+void mlir::linalg::peelLoops(RewriterBase &rewriter,
+                             ArrayRef<scf::ForOp> loops) {
+  for (auto loopOp : loops) {
+    // The values returned by `peelLoop` are not needed here.
+    (void)peelLoop(rewriter, loopOp);
+  }
+}
+
 /// Peel loops after tiling.
 void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                                      ArrayRef<int64_t> peeledLoops,
@@ -716,6 +725,42 @@
   return success();
 }
 
+mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
+    MLIRContext *context, LinalgTransformationFilter f,
+    LinalgPeelOptions options, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {}
+
+mlir::linalg::LinalgPeelingPattern::LinalgPeelingPattern(
+    StringRef opName, MLIRContext *context, LinalgPeelOptions options,
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
+
+LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+
+  // Increase marker counter even if peeling doesn't happen for this op.
+  // Remember the previous marker: a pattern must not return failure after
+  // having modified the IR, so the no-callback case below reports whether
+  // the marker update changed anything.
+  Attribute markerBefore =
+      linalgOp->getAttr(LinalgTransforms::kLinalgTransformMarker);
+  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+
+  // Nothing to peel without a loop-selection callback.
+  if (!options.loopsToPeelComputationFunction)
+    return success(markerBefore !=
+                   linalgOp->getAttr(LinalgTransforms::kLinalgTransformMarker));
+
+  SmallVector<scf::ForOp, 4> loopsToPeel;
+  options.loopsToPeelComputationFunction(rewriter, linalgOp, loopsToPeel);
+  peelLoops(rewriter, loopsToPeel);
+  return success();
+}
+
 mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
     MLIRContext *context, LinalgTransformationFilter f,
     LinalgVectorizationOptions options, PatternBenefit benefit)