diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -842,6 +842,24 @@
   LinalgTransformationFilter filter;
 };
 
+///
+/// Linalg skip pattern.
+///
+/// Update the LinalgTransformMarker if `filter` matches. The pattern has a
+/// default benefit of zero and is meant to update the operations that are not
+/// matched by any other pattern of the current transformation stage.
+struct LinalgSkipPattern : public OpInterfaceRewritePattern<LinalgOp> {
+  LinalgSkipPattern(MLIRContext *context, LinalgTransformationFilter filter,
+                    PatternBenefit benefit = 0);
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
 ///
 /// Linalg promotion patterns.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -199,6 +199,8 @@
       return;
     RewritePatternSet decompositionPattern(funcOp.getContext());
     populateDecomposeConvolutionPatterns(decompositionPattern, filter);
+    // Update the transformation marker if no decompose pattern matches.
+    decompositionPattern.add<LinalgSkipPattern>(funcOp.getContext(), filter);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(decompositionPattern))))
       signalPassFailure();
@@ -229,9 +231,9 @@
     RewritePatternSet interchangePattern(funcOp.getContext());
     interchangePattern.add<GenericOpInterchangePattern>(
         funcOp.getContext(), interchangeVector, filter);
-    if (failed(applyPatternsAndFoldGreedily(funcOp,
-                                            std::move(interchangePattern))))
-      signalPassFailure();
+    // Update the transformation marker for all non-generic operations.
+    interchangePattern.add<LinalgSkipPattern>(funcOp.getContext(), filter);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
   }
 
   SmallVector<int64_t> iteratorInterchange;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -656,6 +656,21 @@
   return success();
 }
 
+// LinalgSkipPattern
+mlir::linalg::LinalgSkipPattern::LinalgSkipPattern(
+    MLIRContext *context, LinalgTransformationFilter filter,
+    PatternBenefit benefit)
+    : mlir::OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(filter) {}
+
+LogicalResult mlir::linalg::LinalgSkipPattern::matchAndRewrite(
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+  return success();
+}
+
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     MLIRContext *context, LinalgTransformationFilter filter,
     LinalgPromotionOptions options, PatternBenefit benefit)
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -4,6 +4,8 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matvec generalize iterator-interchange=1,0" -split-input-file | FileCheck %s --check-prefix=CHECK-GENERALIZE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matvec iterator-interchange=1,0 vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE-AND-VECTORIZE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv decompose generalize" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMPOSE-AND-GENERALIZE
 
 // CHECK-INTRINSIC: func @matmul(
 // CHECK-OUTER: func @matmul(
@@ -83,7 +85,10 @@
 
 // CHECK-GENERALIZE: func @matvec(
 func @matvec(%arg0: tensor<72x72xf32>, %arg1: tensor<72xf32>, %arg2: tensor<72xf32>) -> tensor<72xf32> {
-  // Check the generic op iterators are interchanged although generalization was not needed.
+
+  // Check both operations are interchanged even though generalization only matches the matvec.
+  // CHECK-GENERALIZE: linalg.generic
+  // CHECK-GENERALIZE-SAME: iterator_types = ["reduction", "parallel"]
   // CHECK-GENERALIZE: linalg.generic
   // CHECK-GENERALIZE-SAME: iterator_types = ["reduction", "parallel"]
   %0 = linalg.generic {indexing_maps = [#map0, #map1],
@@ -94,10 +99,50 @@
     %2 = arith.addf %arg3, %arg5 : f32
     linalg.yield %2 : f32
   } -> tensor<72xf32>
+  %1 = linalg.matvec ins(%arg0, %0: tensor<72x72xf32>, tensor<72xf32>) outs(%arg2: tensor<72xf32>) -> tensor<72xf32>
+  return %1 : tensor<72xf32>
+}
 
-  // Check matvec is generalized and its iterators are interchanged.
-  // CHECK-GENERALIZE: linalg.generic
-  // CHECK-GENERALIZE-SAME: iterator_types = ["reduction", "parallel"]
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-INTERCHANGE-AND-VECTORIZE: func @matvec(
+func @matvec(%arg0: tensor<72x72xf32>, %arg1: tensor<72xf32>, %arg2: tensor<72xf32>) -> tensor<72xf32> {
+
+  // Check both operations are vectorized even though interchange only matches the generic op.
+  // CHECK-INTERCHANGE-AND-VECTORIZE-NOT: linalg.generic
+  // CHECK-INTERCHANGE-AND-VECTORIZE-NOT: linalg.matvec
+  %0 = linalg.generic {indexing_maps = [#map0, #map1],
+                       iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<72x72xf32>)
+    outs(%arg1 : tensor<72xf32>) {
+  ^bb0(%arg3: f32, %arg5: f32):  // no predecessors
+    %2 = arith.addf %arg3, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<72xf32>
   %1 = linalg.matvec ins(%arg0, %0: tensor<72x72xf32>, tensor<72xf32>) outs(%arg2: tensor<72xf32>) -> tensor<72xf32>
   return %1 : tensor<72xf32>
 }
+
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-DECOMPOSE-AND-GENERALIZE: func @conv(
+func @conv(%input: tensor<4x1x6x3xf32>, %filter: tensor<1x2x3x8xf32>, %init: tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32> {
+  %cst = arith.constant 0.0 : f32
+
+  // Check both operations are generalized even though decomposition only matches the convolution.
+  // CHECK-DECOMPOSE-AND-GENERALIZE-NOT: linalg.fill
+  // CHECK-DECOMPOSE-AND-GENERALIZE-NOT: linalg.conv_2d_nhwc_hwcf
+  %0 = linalg.fill(%cst, %init) : f32, tensor<4x1x2x8xf32> -> tensor<4x1x2x8xf32>
+  %1 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[3, 2]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<4x1x6x3xf32>, tensor<1x2x3x8xf32>)
+    outs(%0 : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
+  return %1 : tensor<4x1x2x8xf32>
+}
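
Usage note: the benefit-0 registration is what makes the fallback work. The greedy driver prefers the stage-specific patterns at their default benefit of 1, so LinalgSkipPattern only fires on operations those patterns leave unmatched, and its sole rewrite is to advance the filter marker so later stages still apply to them. A minimal sketch of a stage wiring this up, using only names from the patch (pass boilerplate elided; `funcOp` and `filter` are the pass members as in LinalgStrategyPasses.cpp):

  RewritePatternSet patterns(funcOp.getContext());
  // Stage-specific patterns run at their default benefit (1).
  populateDecomposeConvolutionPatterns(patterns, filter);
  // Benefit-0 fallback: any op the stage patterns did not match still gets
  // its transformation marker advanced, so the next stage picks it up.
  patterns.add<LinalgSkipPattern>(funcOp.getContext(), filter);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));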