diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -362,67 +362,6 @@ //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. //===----------------------------------------------------------------------===// -// Marker used as attribute name in generated Linalg rewriting transformations. -struct LinalgTransforms { - static const StringLiteral kLinalgTransformMarker; -}; - -/// Helper class to control application of linalg transformation patterns. -/// Control comes in 2 forms: -/// 1. attribute matching and setting behavior using the attribute named -/// `kLinalgTransformMarker`. This can be used to build a state machine -/// using attributes and incrementally applying patterns to advance states. -/// 2. filter function, which is a simple lambda on the Operation* that -/// returns a LogicalResult. -struct LinalgTransformationFilter { - using FilterFunction = std::function; - - explicit LinalgTransformationFilter( - ArrayRef matchDisjunction = {}, - Optional replacement = None); - - explicit LinalgTransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction = {}, - Optional replacement = None); - - LinalgTransformationFilter(LinalgTransformationFilter &&) = default; - LinalgTransformationFilter(const LinalgTransformationFilter &) = default; - LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; - void replaceLinalgTransformationFilter(PatternRewriter &rewriter, - Operation *op) const; - bool hasReplacementFilter(Operation *op) const; - - LinalgTransformationFilter &addFilter(const FilterFunction &f) { - if (f) - filters.push_back(f); - return *this; - } - - template - LinalgTransformationFilter &addOpFilter() { - return addFilter( - [](Operation *op) { return success(isa(op)); }); - } - - LinalgTransformationFilter &addOpNameFilter(StringRef opName) { - return addFilter([opName](Operation *op) { - return success(op->getName().getStringRef() == opName); - }); - } - - LinalgTransformationFilter &setMatchByDefault() { - matchByDefault = true; - return *this; - } - -private: - SmallVector filters; - SmallVector matchDisjunction; - Optional replacement; - /// When set to true, if the attribute is not set, it will be treated as - /// a match. Default is false. - bool matchByDefault; -}; using TileSizeComputationFunction = std::function(OpBuilder &, Operation *)>; @@ -793,14 +732,7 @@ } }; -/// -/// Linalg vectorization patterns. -/// -/// Empty for now, used for SFINAE purposes only. -struct LinalgVectorizationOptions {}; - -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. +/// Vectorization pattern for memref::CopyOp. struct CopyVectorizationPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -811,34 +743,6 @@ /// Return vector::CombiningKind for the given op. llvm::Optional getCombinerOpKind(Operation *combinerOp); -//===----------------------------------------------------------------------===// -// Transformation and lowering options exposed as auxiliary structs. -//===----------------------------------------------------------------------===// -/// Options to control the application of enabling transformations. -/// Hoisting transformations are always deemed beneficial and must be disabled -/// explicitly. -struct LinalgEnablingOptions { - /// Enable LICM. - bool licm = true; - LinalgEnablingOptions &enableLICM(bool val = true) { - licm = val; - return *this; - } - /// Enable hoisting of redundant vector transfer ops. - bool hoistRedundantVectorTransfers = true; - LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) { - hoistRedundantVectorTransfers = val; - return *this; - } - /// Enable hoisting of redundant vector transfer ops on tensor. - bool hoistRedundantVectorTransfersOnTensor = true; - LinalgEnablingOptions & - enableHoistRedundantVectorTransfersOnTensor(bool val = true) { - hoistRedundantVectorTransfersOnTensor = val; - return *this; - } -}; - //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. //===----------------------------------------------------------------------===// @@ -971,24 +875,6 @@ PatternRewriter &rewriter) const override; }; -//===----------------------------------------------------------------------===// -// Support for staged pattern application. -//===----------------------------------------------------------------------===// -/// Helper function to allow applying rewrite patterns, interleaved with more -/// global transformations, in a staged fashion: -/// 1. the first stage consists of a list of FrozenRewritePatternSet. Each -/// FrozenRewritePatternSet in this list is applied once, in order. -/// 2. the second stage consists of a single RewritePattern that is applied -/// greedily until convergence. -/// 3. the third stage consists of applying a lambda, generally used for -/// non-local transformation effects. This allows creating custom fused -/// transformations where patterns can be ordered and applied at a finer -/// granularity than a sequence of traditional compiler passes. -LogicalResult applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const FrozenRewritePatternSet &stage2Patterns, - function_ref stage3Lambda = nullptr); - /// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)). struct ExtractSliceOfPadTensorSwapPattern : public OpRewritePattern { @@ -1015,20 +901,6 @@ ControlFn controlFn; }; -//===----------------------------------------------------------------------===// -// Helper classes for type list expansion. -//===----------------------------------------------------------------------===// -template -class VectorizationPatterns; - -template <> -class VectorizationPatterns<> { -public: - static void insert(RewritePatternSet &patterns, - const LinalgVectorizationOptions &options, - const LinalgTransformationFilter &f) {} -}; - /// Split Reduction options. struct SplitReductionOptions { // Ratio used to split the reduction dimension. If the ratio is <= 1, nothing diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -659,14 +659,10 @@ LogicalResult matchAndRewrite(tensor::PadOp op, PatternRewriter &rewriter) const override { - if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker)) - return failure(); tensor::PadOp newPadOp; LoopNest loopNest; if (failed(tilePadOp(rewriter, op, newPadOp, loopNest, options))) return failure(); - newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getUnitAttr()); // Replace all uses of the original tensor::PadOp. rewriter.replaceOp(op, loopNest.getResults()[0]); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -47,75 +47,6 @@ //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. //===----------------------------------------------------------------------===// -// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = - "__internal_linalg_transform__"; - -mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( - ArrayRef matchDisjunction, Optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) {} - -mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction, - Optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) { - if (f) - filters.push_back(f); -} - -LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify( - PatternRewriter &rewriter, Operation *op) const { - if (llvm::any_of(filters, - [&](const FilterFunction &f) { return failed(f(op)); })) - return failure(); - - auto attr = op->template getAttrOfType( - LinalgTransforms::kLinalgTransformMarker); - - if (!attr) { - // 1. Has no filter case and matchDisjunction is empty. - if (matchDisjunction.empty() || matchByDefault) - return success(); - - // 2. Has no filter but was expecting a filter. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); - } - - // 4. Match explicit filter. - for (auto filter : matchDisjunction) - if (attr.getValue() == filter) - return success(); - - // 5. Fail to match. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); -} - -void mlir::linalg::LinalgTransformationFilter:: - replaceLinalgTransformationFilter(PatternRewriter &rewriter, - Operation *op) const { - if (replacement.has_value()) - op->setAttr(LinalgTransforms::kLinalgTransformMarker, replacement.value()); - else - op->removeAttr( - rewriter.getStringAttr(LinalgTransforms::kLinalgTransformMarker)); -} - -bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter( - Operation *op) const { - if (!replacement) - return false; - auto attr = op->getAttr(LinalgTransforms::kLinalgTransformMarker) - .dyn_cast(); - return attr && attr == *replacement; -} LinalgTilingOptions & mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { @@ -432,37 +363,6 @@ return vectorizeCopy(rewriter, copyOp); } -LogicalResult mlir::linalg::applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const FrozenRewritePatternSet &stage2Patterns, - function_ref stage3Lambda) { - unsigned iteration = 0; - (void)iteration; - for (const auto &patterns : stage1Patterns) { - LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n" - << *op); - if (failed(applyPatternsAndFoldGreedily(op, patterns))) { - LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge"); - return failure(); - } - LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n" - << *op); - if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) { - LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge"); - return failure(); - } - LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n" - << *op); - if (stage3Lambda) { - if (failed(stage3Lambda(op))) - return failure(); - LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n" - << *op); - } - } - return success(); -} - static SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { return SmallVector(nParallelLoops, getParallelIteratorTypeName()); } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -127,11 +127,6 @@ patterns.add(ctx); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - - // Drop the marker. - funcOp.walk([](LinalgOp op) { - op->removeAttr(LinalgTransforms::kLinalgTransformMarker); - }); } static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { @@ -182,13 +177,6 @@ /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { - auto lambda = [&](void *) { - getOperation().walk([](LinalgOp op) { - op->removeAttr(LinalgTransforms::kLinalgTransformMarker); - }); - }; - std::unique_ptr cleanupGuard{(void *)1, lambda}; - if (testPatterns) return applyPatterns(getOperation()); if (testVectorTransferForwardingPatterns) diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -31,26 +31,149 @@ using namespace mlir; +// TODO: this file should disappear and instead tests should make use of the +// transform dialect. namespace { +/// Marker used as attribute name in generated Linalg rewriting transformations. +const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__"; + +/// Helper class to control application of linalg transformation patterns. +/// Control comes in 2 forms: +/// 1. attribute matching and setting behavior using the attribute named +/// `kLinalgTransformMarker`. This can be used to build a state machine +/// using attributes and incrementally applying patterns to advance states. +/// 2. filter function, which is a simple lambda on the Operation* that +/// returns a LogicalResult. +struct LinalgTransformationFilter { + using FilterFunction = std::function; + + explicit LinalgTransformationFilter( + ArrayRef matchDisjunction = {}, + Optional replacement = None); + + explicit LinalgTransformationFilter( + const FilterFunction &f, ArrayRef matchDisjunction = {}, + Optional replacement = None); + + LinalgTransformationFilter(LinalgTransformationFilter &&) = default; + LinalgTransformationFilter(const LinalgTransformationFilter &) = default; + LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; + void replaceLinalgTransformationFilter(PatternRewriter &rewriter, + Operation *op) const; + bool hasReplacementFilter(Operation *op) const; + + LinalgTransformationFilter &addFilter(const FilterFunction &f) { + if (f) + filters.push_back(f); + return *this; + } + + template + LinalgTransformationFilter &addOpFilter() { + return addFilter( + [](Operation *op) { return success(isa(op)); }); + } + + LinalgTransformationFilter &addOpNameFilter(StringRef opName) { + return addFilter([opName](Operation *op) { + return success(op->getName().getStringRef() == opName); + }); + } + + LinalgTransformationFilter &setMatchByDefault() { + matchByDefault = true; + return *this; + } + +private: + SmallVector filters; + SmallVector matchDisjunction; + Optional replacement; + /// When set to true, if the attribute is not set, it will be treated as + /// a match. Default is false. + bool matchByDefault; +}; + +LinalgTransformationFilter::LinalgTransformationFilter( + ArrayRef matchDisjunction, Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement), matchByDefault(false) {} + +LinalgTransformationFilter::LinalgTransformationFilter( + const FilterFunction &f, ArrayRef matchDisjunction, + Optional replacement) + : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), + replacement(replacement), matchByDefault(false) { + if (f) + filters.push_back(f); +} + +LogicalResult +LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter, + Operation *op) const { + if (llvm::any_of(filters, + [&](const FilterFunction &f) { return failed(f(op)); })) + return failure(); + + auto attr = op->template getAttrOfType(kLinalgTransformMarker); + + if (!attr) { + // 1. Has no filter case and matchDisjunction is empty. + if (matchDisjunction.empty() || matchByDefault) + return success(); + + // 2. Has no filter but was expecting a filter. + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any filter from list: "; + interleaveComma(matchDisjunction, diag); + }); + } + + // 4. Match explicit filter. + for (auto filter : matchDisjunction) + if (attr.getValue() == filter) + return success(); + + // 5. Fail to match. + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << " does not have any filter from list: "; + interleaveComma(matchDisjunction, diag); + }); +} + +void LinalgTransformationFilter::replaceLinalgTransformationFilter( + PatternRewriter &rewriter, Operation *op) const { + if (replacement.has_value()) + op->setAttr(kLinalgTransformMarker, replacement.value()); + else + op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker)); +} + +bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const { + if (!replacement) + return false; + auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast(); + return attr && attr == *replacement; +} + /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while /// using a `filter` to avoid recursive application. struct TestTileUsingSCFForOp : public OpInterfaceRewritePattern { - TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) + TestTileUsingSCFForOp( + MLIRContext *context, scf::SCFTilingOptions options, + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`. - TestTileUsingSCFForOp(StringRef opName, MLIRContext *context, - scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) + TestTileUsingSCFForOp( + StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -76,7 +199,7 @@ private: scf::SCFTilingOptions options; - linalg::LinalgTransformationFilter filter; + LinalgTransformationFilter filter; }; /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern @@ -87,8 +210,7 @@ : public OpInterfaceRewritePattern { TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( MLIRContext *context, scf::SCFTileAndFuseOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), + LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -97,8 +219,7 @@ TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( StringRef opName, MLIRContext *context, scf::SCFTileAndFuseOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), + LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), options(std::move(options)), filter(std::move(filter)) {} @@ -129,7 +250,7 @@ private: scf::SCFTileAndFuseOptions options; - linalg::LinalgTransformationFilter filter; + LinalgTransformationFilter filter; }; /// Pattern to lower operations that implement the `TilingInterface` to @@ -202,8 +323,8 @@ ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); - linalg::LinalgTransformationFilter filter( - StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + LinalgTransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); patterns.add(context, tilingOptions, filter); } @@ -215,8 +336,8 @@ scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( interchange); - linalg::LinalgTransformationFilter filter( - StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + LinalgTransformationFilter filter(StringAttr::get(context, filterName), + StringAttr::get(context, "tiled")); patterns.add( context, tileAndFuseOptions, filter); }