diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -767,6 +767,10 @@ Note that this transformation is invalidating the handles to any payload IR operation that is contained inside the vectorization target. + `disable_multi_reduction_to_contract_patterns` and + `disable_transfer_permutation_map_lowering_patterns` limit the power of + vectorization. They are currently intended for testing purposes. + #### Return modes: This operation produces `definiteFailure` if vectorization fails for any @@ -776,7 +780,9 @@ }]; let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$vectorize_padding); + DefaultValuedAttr:$vectorize_padding, + DefaultValuedAttr:$disable_multi_reduction_to_contract_patterns, + DefaultValuedAttr:$disable_transfer_permutation_map_lowering_patterns); let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -926,31 +926,6 @@ /// Empty for now, used for SFINAE purposes only. struct LinalgVectorizationOptions {}; -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. -struct LinalgVectorizationPattern : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. 
- LinalgVectorizationPattern( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(LinalgOp linalgOp, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; -}; - /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `vectorizeLinalgOp` for more details. struct CopyVectorizationPattern : public OpRewritePattern { @@ -1335,18 +1310,6 @@ const LinalgTransformationFilter &f) {} }; -template -class VectorizationPatterns { -public: - static void insert(RewritePatternSet &patterns, - const LinalgVectorizationOptions &options, - const LinalgTransformationFilter &f) { - patterns.add(OpTy::getOperationName(), - patterns.getContext(), options, f); - VectorizationPatterns::insert(patterns, options, f); - } -}; - template class TilingPatterns; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1166,6 +1166,22 @@ // VectorizeOp //===----------------------------------------------------------------------===// +namespace { +/// This is a helper only to call vectorize via a pattern inside of +/// VectorizeOp::applyToOne. 
+struct VectorizationPattern : public RewritePattern { + explicit VectorizationPattern(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + return vectorize(rewriter, linalgOp); + } +}; +} // namespace + DiagnosedSilenceableFailure transform::VectorizeOp::applyToOne(Operation *target, SmallVectorImpl &results, @@ -1178,15 +1194,22 @@ MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add(ctx); + + if (!getDisableTransferPermutationMapLoweringPatterns()) + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + if (!getDisableMultiReductionToContractPatterns()) + vector::populateVectorReductionToContractPatterns(patterns); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - vector::populateVectorReductionToContractPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + + patterns.add(ctx); + if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); @@ -1212,7 +1235,7 @@ void init() { declareDependentDialect(); - + declareDependentDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -590,25 +590,6 @@ return success(); } -mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgVectorizationOptions options, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - 
filter(std::move(f)) {} - -mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, LinalgVectorizationOptions options, - LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)) {} - -LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite( - LinalgOp linalgOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - return vectorize(rewriter, linalgOp); -} - LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( memref::CopyOp copyOp, PatternRewriter &rewriter) const { return vectorizeCopy(rewriter, copyOp); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s // ----- @@ -12,6 +12,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: contraction_matvec @@ -24,6 +34,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // 
CHECK-LABEL: contraction_matmul @@ -35,6 +55,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: contraction_batch_matmul @@ -47,6 +77,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- #matmul_trait = { @@ -80,6 +120,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #matmul_transpose_out_trait = { @@ -113,6 +163,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -133,6 +193,16 @@ return %1 : tensor<128x12x32xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 
failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #matmul_trait = { @@ -166,6 +236,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @vectorization_test_2 @@ -179,6 +259,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: func @test_vectorize_scalar_input @@ -196,6 +286,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types @@ -213,6 +313,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = 
transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -223,6 +333,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -234,6 +354,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_copy @@ -244,6 +374,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_copy_scalar @@ -257,6 +397,15 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} // ----- // CHECK-LABEL: func @test_vectorize_trailing_index @@ -278,6 +427,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_inner_index @@ -300,6 
+459,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @generic_vectorize @@ -378,6 +547,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @generic_vectorize_tensor @@ -462,6 +641,16 @@ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)> @@ -499,6 +688,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // Test different input maps. 
@@ -535,6 +734,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @matmul_tensors @@ -560,6 +769,16 @@ return %0 : tensor<8x12xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @pad_static( @@ -581,6 +800,17 @@ return %0 : tensor<2x3x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @pad_static_source( @@ -602,6 +832,18 @@ return %0 : tensor<2x6x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + + // ----- // CHECK-LABEL: func @pad_static_dynamic( @@ -630,6 +872,18 @@ return %0 : tensor<6x?x?x?xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = 
transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + + // ----- // CHECK-LABEL: func @pad_and_transfer_read @@ -652,6 +906,17 @@ return %1 : vector<7x9xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + // ----- func.func private @make_vector() -> vector<7x9xf32> @@ -678,6 +943,17 @@ return %3 : tensor<5x6xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> vector<7x9xf32> @@ -707,6 +983,17 @@ return %3 : tensor } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> tensor<12x13xf32> @@ -733,6 +1020,17 @@ return %r : tensor<12x13xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> tensor<12x13xf32> @@ -753,6 +1051,16 @@ return %r : 
tensor<1x12x13xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @pad_tensor_non_const_pad_value @@ -782,6 +1090,17 @@ return %0 : tensor<12x13xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @sum_exp @@ -809,6 +1128,17 @@ return %0 : tensor<4x16xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> @@ -846,13 +1176,23 @@ return %0 : tensor<5x2xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @red_max_2d( func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // 
CHECK: vector.multi_reduction , {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 @@ -869,13 +1209,23 @@ return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @red_min_2d( func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> @@ -893,12 +1243,22 @@ return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_mul_2d( func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> @@ -916,12 +1276,22 @@ return %red : tensor<4xf32> } + 
+transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_or_2d( func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -939,12 +1309,22 @@ return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_and_2d( func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -962,12 +1342,22 @@ return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_xor_2d( func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { 
// CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -985,6 +1375,17 @@ return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1011,6 +1412,17 @@ return %red : tensor<4x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1041,6 +1453,21 @@ return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @reduce_1d( @@ -1054,8 +1481,6 @@ // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor %0 = linalg.init_tensor [] : tensor - // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] - // CHECK-SAME: : vector, tensor %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor) -> tensor 
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> @@ -1063,7 +1488,7 @@ // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]], %[[f0]] [0] // CHECK-SAME: : vector<32xf32> to f32 // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] + // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][] // CHECK-SAME: : vector, tensor %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -1079,6 +1504,16 @@ return %2 : tensor } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- @@ -1103,6 +1538,16 @@ return %result : tensor<6x6x3x3xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values. 
@@ -1134,3 +1579,13 @@ // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction , %[[MUL]], %[[V2]] // CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]] // CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]] + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -225,9 +225,6 @@ //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===--------------------------------------------------------------------===// - patterns.add( - ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) - .addOpFilter()); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); @@ -441,9 +438,6 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); auto *ctx = funcOp.getContext(); - patterns.add( - ctx, LinalgTransformationFilter() - .addOpFilter()); patterns.add(ctx); populatePadOpVectorizationPatterns(patterns); populateConvolutionVectorizationPatterns(patterns);