diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Identifier.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SmallBitVector.h" @@ -206,15 +207,14 @@ /// Helper class to control common attribute matching and setting behavior. struct LinalgMarker { - LinalgMarker(ArrayRef matchDisjunction = {}, - Optional replacement = None); - LinalgMarker(ArrayRef matchDisjunction, StringRef replacement); + explicit LinalgMarker(ArrayRef matchDisjunction = {}, + Optional replacement = None); LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const; private: - SmallVector matchDisjunction; - Optional replacement; + SmallVector matchDisjunction; + Optional replacement; }; /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -459,8 +459,8 @@ public: static void insert(OwningRewritePatternList &patterns, const LinalgTilingOptions &options, MLIRContext *ctx) { - patterns.insert>(ctx, options, - LinalgMarker({}, "tiled")); + patterns.insert>( + ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx))); RewritePatternList::insert(patterns, options, ctx); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -46,15 +46,11 @@ const 
StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; -mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, - Optional replacement) +mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, + Optional replacement) : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), replacement(replacement) {} -mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef matchDisjunction, - StringRef replacement) - : LinalgMarker(matchDisjunction, Optional{replacement}) {} - LogicalResult mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter, Operation *op) const { @@ -66,12 +62,7 @@ if (matchDisjunction.empty()) return success(); - // 2. Has no marker and matchDisjuntion matches the no-moarker case. - for (auto marker : matchDisjunction) - if (marker.empty()) - return success(); - - // 3. Has no marker but was expecting a marker. + // 2. Has no marker but was expecting a marker. return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << " does not have any marker from list: "; interleaveComma(matchDisjunction, diag); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -14,9 +14,10 @@ func @dot(%x: memref, %y: memref, %v: memref) { - linalg.dot(%x, %y, %v) : memref, - memref, - memref + linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } : + memref, + memref, + memref return } // CHECK-LABEL: func @dot @@ -35,9 +36,10 @@ func @matvec(%A: memref, %x: memref, %y: memref) { - linalg.matvec(%A, %x, %y) : memref, - memref, - memref + linalg.matvec(%A, %x, %y) : + memref, + memref, + memref return } // CHECK-LABEL: func @matvec @@ -51,9 +53,10 @@ func @matmul(%A: memref, %B: memref, %C: memref) { - linalg.matmul(%A, %B, %C) : memref, - memref, - memref + linalg.matmul(%A, %B, %C) { 
__internal_linalg_transform__ = "MEM" } : + memref, + memref, + memref return } // CHECK-LABEL: func @matmul diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -66,26 +66,29 @@ //===--------------------------------------------------------------------===// patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), - LinalgMarker({"MEM", {}}, "L3")); + LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), - LinalgMarker({"L3"}, "L2")); + LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), - LinalgMarker({"L2"}, "L1")); + LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), - LinalgMarker({"L1"}, "REG")); + LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( LinalgTilingLoopType::ParallelLoops), - LinalgMarker({}, "L1")); + LinalgMarker({}, Identifier::get("L1", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes(8000), - LinalgMarker({"MEM", "L3", "L2", {}}, "REG")); + LinalgMarker(ArrayRef{Identifier::get("MEM", ctx), + Identifier::get("L3", ctx), + Identifier::get("L2", ctx)}, + Identifier::get("REG", ctx))); //===--------------------------------------------------------------------===// // Linalg tiling and permutation patterns. 
@@ -95,20 +98,24 @@ LinalgTilingOptions() .setTileSizes({2000, 3000, 4000}) .setInterchange({1, 2, 0}), - LinalgMarker({"__with_perm__"}, "L2__with_perm__")); + LinalgMarker(Identifier::get("__with_perm__", ctx), + Identifier::get("L2__with_perm__", ctx))); patterns.insert>( ctx, LinalgTilingOptions() .setTileSizes({200, 300, 400}) .setInterchange({1, 0, 2}), - LinalgMarker({"L2__with_perm__"}, "L1__with_perm__")); + LinalgMarker(Identifier::get("L2__with_perm__", ctx), + Identifier::get("L1__with_perm__", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), - LinalgMarker({"L1__with_perm__"}, "REG__with_perm__")); + LinalgMarker(Identifier::get("L1__with_perm__", ctx), + Identifier::get("REG__with_perm__", ctx))); patterns.insert>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), - LinalgMarker({"__with_perm__"}, "L1__with_perm__")); + LinalgMarker(Identifier::get("__with_perm__", ctx), + Identifier::get("L1__with_perm__", ctx))); patterns.insert>( ctx, @@ -116,14 +123,16 @@ .setTileSizes({16, 8, 4}) .setInterchange({1, 2, 0}) .setLoopType(LinalgTilingLoopType::ParallelLoops), - LinalgMarker({"par__with_perm__"}, "after_par__with_perm__")); + LinalgMarker(Identifier::get("par__with_perm__", ctx), + Identifier::get("after_par__with_perm__", ctx))); //===--------------------------------------------------------------------===// // Linalg to loops patterns. //===--------------------------------------------------------------------===// patterns.insert>( ctx, - /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"})); + /*loweringType=*/LinalgLoweringType::Loops, + LinalgMarker(Identifier::get("REG", ctx))); //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. 
@@ -131,7 +140,7 @@ patterns.insert, LinalgVectorizationPattern, LinalgVectorizationPattern>( - ctx, LinalgMarker({"VECTORIZE"})); + ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx))); //===--------------------------------------------------------------------===// // Linalg generic permutation patterns. @@ -139,31 +148,34 @@ patterns.insert>( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, - LinalgMarker({}, "PERMUTED")); + LinalgMarker({}, Identifier::get("PERMUTED", ctx))); patterns.insert>( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, - LinalgMarker({}, "PERMUTED")); + LinalgMarker({}, Identifier::get("PERMUTED", ctx))); //===--------------------------------------------------------------------===// // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.insert>( ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(), - LinalgMarker({"_promote_views_"}, "_views_promoted_")); + LinalgMarker(Identifier::get("_promote_views_", ctx), + Identifier::get("_views_promoted_", ctx))); patterns.insert>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0}) .useFullTileBuffersByDefault(), - LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); + LinalgMarker(Identifier::get("_promote_first_view_", ctx), + Identifier::get("_first_view_promoted_", ctx))); patterns.insert>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0}) .setUseFullTileBuffers({true}) .setAlignment(32), - LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); + LinalgMarker(Identifier::get("_promote_views_aligned_", ctx), + Identifier::get("_views_aligned_promoted_", ctx))); applyPatternsAndFoldGreedily(funcOp, patterns); @@ -176,21 +188,21 @@ static void fillL1TilingAndMatmulToVectorPatterns( - FuncOp funcOp, StringRef startMarker, + FuncOp funcOp, Identifier startMarker, SmallVectorImpl &patternsVector) { - MLIRContext *context = funcOp.getContext(); + MLIRContext *ctx = funcOp.getContext(); 
patternsVector.emplace_back(LinalgTilingPattern( - context, + ctx, LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}), - LinalgMarker({startMarker}, "L1"))); + LinalgMarker(startMarker, Identifier::get("L1", ctx)))); patternsVector.emplace_back(LinalgPromotionPattern( - context, LinalgPromotionOptions().useFullTileBuffersByDefault(), - LinalgMarker({"L1"}, "VEC"))); + ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(), + LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx)))); - patternsVector.emplace_back( - LinalgVectorizationPattern(context, LinalgMarker({"VEC"}))); + patternsVector.emplace_back(LinalgVectorizationPattern( + ctx, LinalgMarker(Identifier::get("VEC", ctx)))); patternsVector.back() .insert, - LinalgVectorizationPattern>(context); + LinalgVectorizationPattern>(ctx); } //===----------------------------------------------------------------------===// @@ -231,13 +243,14 @@ return success(); } -void fillPromotionCallBackPatterns(MLIRContext *context, +void fillPromotionCallBackPatterns(MLIRContext *ctx, OwningRewritePatternList &patterns) { patterns.insert>( - context, LinalgTilingOptions().setTileSizes({16, 16, 16}), - LinalgMarker({"START"}, "PROMOTE")); + ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), + LinalgMarker(Identifier::get("START", ctx), + Identifier::get("PROMOTE", ctx))); patterns.insert>( - context, + ctx, LinalgPromotionOptions() .setOperandsToPromote({0, 2}) .setUseFullTileBuffers({false, false}) @@ -251,7 +264,7 @@ copyCallBackFn(b, src, dst, true); return success(); }), - LinalgMarker({"PROMOTE"})); + LinalgMarker(Identifier::get("PROMOTE", ctx))); } static void @@ -261,15 +274,18 @@ MLIRContext *ctx = funcOp.getContext(); SmallVector stage1Patterns; if (testMatmulToVectorPatterns1dTiling) { - fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); + fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), 
stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { - stage1Patterns.emplace_back( - LinalgTilingPattern(ctx, - LinalgTilingOptions() - .setTileSizes({768, 264, 768}) - .setInterchange({1, 2, 0}), - LinalgMarker({"START"}, "L2"))); - fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); + stage1Patterns.emplace_back(LinalgTilingPattern( + ctx, + LinalgTilingOptions() + .setTileSizes({768, 264, 768}) + .setInterchange({1, 2, 0}), + LinalgMarker(Identifier::get("START", ctx), + Identifier::get("L2", ctx)))); + fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), + stage1Patterns); } OwningRewritePatternList stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx);