Index: mlir/include/mlir/IR/PatternMatch.h =================================================================== --- mlir/include/mlir/IR/PatternMatch.h +++ mlir/include/mlir/IR/PatternMatch.h @@ -1416,7 +1416,10 @@ // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. (void)std::initializer_list{ - 0, (addImpl(/*debugLabels=*/llvm::None, arg, args...), 0)...}; + 0, (addImpl(/*debugLabels=*/llvm::None, + std::forward(arg), + std::forward(args)...), + 0)...}; return *this; } /// An overload of the above `add` method that allows for attaching a set Index: mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp =================================================================== --- mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -23,91 +23,60 @@ using namespace mlir; using namespace mlir::linalg; -template -static void fillFusionPatterns(MLIRContext *context, - const LinalgDependenceGraph &dependenceGraph, - RewritePatternSet &patterns) { - patterns.add, - LinalgTileAndFusePattern>( +/// Use this to safely fill patterns for this test, since RewritePatternSet::add +/// forwards Rvalues only to the first pattern. +template +static void fillFusionPattern(MLIRContext *context, + const LinalgDependenceGraph &dependenceGraph, + RewritePatternSet &patterns, + const Twine &testCase, + ArrayRef tileSizes, + ArrayRef indicesToFuse) { + patterns.add>( context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({2}), - LinalgTransformationFilter( - StringAttr::get(context, "basic_fusion"), - StringAttr::get(context, "after_basic_fusion")), + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType), + LinalgFusionOptions().setIndicesToFuse(indicesToFuse), LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_basic_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_basic_fusion_original"))); - - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({0}), - LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), - StringAttr::get(context, "after_lhs_fusion")), + StringAttr::get(context, testCase + "_fusion"), + StringAttr::get(context, "after_" + testCase + "_fusion")), LinalgTransformationFilter( ArrayRef(), - StringAttr::get(context, "after_lhs_fusion_producer")), + StringAttr::get(context, "after_" + testCase + "_fusion_producer")), LinalgTransformationFilter( ArrayRef(), - StringAttr::get(context, "after_lhs_fusion_original"))); - - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({2}), - LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), - StringAttr::get(context, "after_out_fusion")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_out_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_out_fusion_original"))); + StringAttr::get(context, "after_" + testCase + "_fusion_original"))); +} - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({1}), - LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), - StringAttr::get(context, "after_rhs_fusion")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_rhs_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_rhs_fusion_original"))); +template +static void fillFusionPatterns(MLIRContext *context, + const LinalgDependenceGraph &dependenceGraph, + RewritePatternSet &patterns) { + fillFusionPattern(context, dependenceGraph, patterns, + /*testCase=*/"basic", + /*tileSizes=*/{32, 64, 16}, + /*indicesToFuse=*/{2}); - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({0, 2}), - LinalgTransformationFilter( - StringAttr::get(context, "two_operand_fusion"), - StringAttr::get(context, "after_two_operand_fusion")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_two_operand_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_two_operand_fusion_original"))); + auto fillMatmulPattern = [&](const Twine &testCase, + ArrayRef indicesToFuse) { + fillFusionPattern(context, dependenceGraph, patterns, + testCase, /*tileSizes=*/{32, 64, 16}, + indicesToFuse); + }; + fillMatmulPattern(/*testCase=*/"basic", + /*indicesToFuse=*/{2}); + fillMatmulPattern(/*testCase=*/"lhs", + /*indicesToFuse=*/{0}); + fillMatmulPattern(/*testCase=*/"out", + /*indicesToFuse=*/{2}); + fillMatmulPattern(/*testCase=*/"rhs", + /*indicesToFuse=*/{1}); + fillMatmulPattern(/*testCase=*/"two_operand", + /*indicesToFuse=*/{0, 2}); - patterns.add>( - context, dependenceGraph, - LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), - LinalgFusionOptions().setIndicesToFuse({0, 1}), - LinalgTransformationFilter( - StringAttr::get(context, "transpose_fusion"), - StringAttr::get(context, "after_transpose_fusion")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_transpose_fusion_producer")), - LinalgTransformationFilter( - ArrayRef(), - StringAttr::get(context, "after_transpose_fusion_original"))); + fillFusionPattern(context, dependenceGraph, patterns, + /*testCase=*/"transpose", + /*tileSizes=*/{32, 64}, + /*indicesToFuse=*/{0, 1}); } namespace {