diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -134,15 +134,6 @@
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
-/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
-    StringRef opName = "",
-    linalg::LinalgVectorizationOptions opt =
-        linalg::LinalgVectorizationOptions(),
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter(),
-    bool padVectorize = false);
-
 /// Create a LinalgStrategyLowerVectorsPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyLowerVectorsPass(
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -238,21 +238,6 @@
   ];
 }
 
-def LinalgStrategyVectorizePass
-    : Pass<"linalg-strategy-vectorize-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to apply pattern-based linalg vectorization.";
-  let constructor = "mlir::createLinalgStrategyVectorizePass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
-      "Which linalg op within the func is the anchor to latch on.">,
-    Option<"vectorizePadding", "vectorize-padding", "bool", "false",
-      "Enable vectorization of padding ops.">,
-  ];
-}
-
 def LinalgStrategyLowerVectorsPass
     : Pass<"linalg-strategy-lower-vectors-pass", "func::FuncOp"> {
   let summary = "Configurable pass to lower vector operations.";
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -764,6 +764,10 @@
     Note that this transformation is invalidating the handles to any payload IR
     operation that is contained inside the vectorization target.
 
+    `disable_multi_reduction_to_contract_patterns` and
+    `disable_transfer_permutation_map_lowering_patterns` limit the power of
+    vectorization. They are currently intended for testing purposes.
+
     #### Return modes:
 
     This operation produces `definiteFailure` if vectorization fails for any
@@ -773,7 +777,9 @@
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
+                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
+                   DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,
+                   DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs PDL_Operation:$transformed);
 
   let assemblyFormat = "$target attr-dict";
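A minimal sketch of how the two new attributes compose on `transform.structured.vectorize` (attribute names as declared in the tablegen above; the `with_pdl_patterns`/`sequence` boilerplate and the `linalg.matmul` anchor mirror the updated tests further down):

    transform.with_pdl_patterns {
    ^bb0(%arg0: !pdl.operation):
      transform.sequence %arg0 failures(propagate) {
      ^bb1(%arg1: !pdl.operation):
        %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
        %1 = get_closest_isolated_parent %0
        %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true,
                                                 disable_transfer_permutation_map_lowering_patterns = true }
      }
    }

Both attributes default to false and only restrict the cleanup patterns run after vectorization proper, which is why the description above calls them testing knobs rather than production options.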
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -112,32 +112,6 @@
   linalg::LinalgPeelOptions options;
 };
 
-/// Represent one application of createLinalgStrategyVectorizePass.
-struct Vectorize : public Transformation {
-  explicit Vectorize(linalg::LinalgVectorizationOptions options,
-                     LinalgTransformationFilter::FilterFunction f = nullptr,
-                     bool padVectorize = false)
-      : Transformation(std::move(f)), options(options),
-        vectorizePadding(padVectorize) {}
-
-  Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
-            LinalgTransformationFilter::FilterFunction f = nullptr,
-            bool padVectorize = false)
-      : Transformation(std::move(f)), opName(name), options(options),
-        vectorizePadding(padVectorize) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
-                                                 vectorizePadding));
-  }
-
-private:
-  std::string opName;
-  linalg::LinalgVectorizationOptions options;
-  bool vectorizePadding;
-};
-
 /// Represent one application of createLinalgStrategyLowerVectorsPass.
 struct VectorLowering : public Transformation {
   explicit VectorLowering(
@@ -203,7 +177,7 @@
   padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
         LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? pad(opName, std::move(options), std::move(f)) : *this;
-  } 
+  }
   /// Append patterns to decompose convolutions.
   CodegenStrategy &
   decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {
@@ -229,23 +203,6 @@
               LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? peel(opName, options, std::move(f)) : *this;
   }
-  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
-  CodegenStrategy &
-  vectorize(StringRef opName,
-            const LinalgTransformationFilter::FilterFunction &f = nullptr,
-            bool vectorizePadding = false) {
-    transformationSequence.emplace_back(std::make_unique<Vectorize>(
-        opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
-    return *this;
-  }
-  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
-  /// operation.
-  CodegenStrategy &
-  vectorizeIf(bool b, StringRef opName,
-              LinalgTransformationFilter::FilterFunction f = nullptr,
-              bool vectorizePadding = false) {
-    return b ? vectorize(opName, std::move(f), vectorizePadding) : *this;
-  }
   /// Append a pattern to lower all vector operations.
   CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
     transformationSequence.emplace_back(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -921,31 +921,6 @@
 /// Empty for now, used for SFINAE purposes only.
 struct LinalgVectorizationOptions {};
 
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
-  LinalgVectorizationPattern(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      PatternBenefit benefit = 1);
-
-  /// Construct a pattern specifically applied to `opName`.
-  LinalgVectorizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  LogicalResult matchAndRewrite(LinalgOp linalgOp,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
-
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `vectorizeLinalgOp` for more details.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
@@ -1330,18 +1305,6 @@
                      const LinalgTransformationFilter &f) {}
 };
 
-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns<OpTy, OpTypes...> {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgVectorizationOptions &options,
-                     const LinalgTransformationFilter &f) {
-    patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
-                                             patterns.getContext(), options, f);
-    VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
-  }
-};
-
 template <typename... OpTypes>
 class TilingPatterns;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1160,6 +1160,22 @@
 // VectorizeOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This is a helper only to call vectorize via a pattern inside of
+/// VectorizeOp::applyToOne.
+struct VectorizationPattern : public RewritePattern {
+  explicit VectorizationPattern(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    return vectorize(rewriter, linalgOp);
+  }
+};
+} // namespace
+
 DiagnosedSilenceableFailure
 transform::VectorizeOp::applyToOne(Operation *target,
                                    SmallVectorImpl<Operation *> &results,
@@ -1172,15 +1188,22 @@
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<LinalgVectorizationPattern>(ctx);
+  patterns.add<VectorizationPattern>(ctx);
+
+  if (!getDisableTransferPermutationMapLoweringPatterns())
+    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+
+  if (!getDisableMultiReductionToContractPatterns())
+    vector::populateVectorReductionToContractPatterns(patterns);
 
-  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-  vector::populateVectorReductionToContractPatterns(patterns);
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx, /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+
+  patterns.add<CopyVectorizationPattern>(ctx);
+
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
@@ -1206,7 +1229,7 @@
 
   void init() {
     declareDependentDialect<pdl::PDLDialect>();
-
+    declareDependentDialect<linalg::LinalgDialect>();
     declareGeneratedDialect<AffineDialect>();
     declareGeneratedDialect<arith::ArithmeticDialect>();
     declareGeneratedDialect<scf::SCFDialect>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -40,7 +40,6 @@
 #define GEN_PASS_DEF_LINALGSTRATEGYPADPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYDECOMPOSEPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYPEELPASS
-#define GEN_PASS_DEF_LINALGSTRATEGYVECTORIZEPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYLOWERVECTORSPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYREMOVEMARKERSPASS
 #include "mlir/Dialect/Linalg/Passes.h.inc"
@@ -215,62 +214,6 @@
   LinalgTransformationFilter filter;
 };
 
-/// Configurable pass to apply pattern-based linalg vectorization.
-struct LinalgStrategyVectorizePass
-    : public impl::LinalgStrategyVectorizePassBase<
-          LinalgStrategyVectorizePass> {
-
-  LinalgStrategyVectorizePass() = default;
-
-  LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
-                              LinalgTransformationFilter filt,
-                              bool padVectorize = false)
-      : options(opt), filter(std::move(filt)) {
-    this->anchorOpName.setValue(opName.str());
-    this->vectorizePadding.setValue(padVectorize);
-  }
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    RewritePatternSet vectorizationPatterns(funcOp.getContext());
-    if (!anchorOpName.empty()) {
-      vectorizationPatterns.add<LinalgVectorizationPattern>(
-          anchorOpName, funcOp.getContext(), options, filter);
-    } else {
-      vectorizationPatterns.add<LinalgVectorizationPattern>(funcOp.getContext(),
-                                                            filter, options);
-    }
-    vector::populateVectorTransferPermutationMapLoweringPatterns(
-        vectorizationPatterns);
-    vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
-    vectorizationPatterns.add<LinalgCopyVTRForwardingPattern,
-                              LinalgCopyVTWForwardingPattern>(
-        funcOp.getContext(), /*benefit=*/2);
-    TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
-                                                funcOp.getContext());
-    TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
-                                                 funcOp.getContext());
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(vectorizationPatterns));
-
-    // Apply the pad tensor op vectorization separately to avoid running the
-    // GenericPadOpVectorizationPattern too early.
-    // TODO: Improve once we have better infrastructure to control pattern
-    // application.
-    if (vectorizePadding) {
-      RewritePatternSet patterns(funcOp.getContext());
-      linalg::populatePadOpVectorizationPatterns(patterns);
-      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-    }
-  }
-
-  LinalgVectorizationOptions options;
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to lower vector operations.
 struct LinalgStrategyLowerVectorsPass
     : public impl::LinalgStrategyLowerVectorsPassBase<
@@ -393,15 +336,6 @@
   return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter);
 }
 
-/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyVectorizePass(
-    StringRef opName, LinalgVectorizationOptions opt,
-    const LinalgTransformationFilter &filter, bool padVectorize) {
-  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
-                                                       padVectorize);
-}
-
 /// Create a LinalgStrategyLowerVectorsPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyLowerVectorsPass(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -590,25 +590,6 @@
   return success();
 }
 
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    MLIRContext *context, LinalgTransformationFilter f,
-    LinalgVectorizationOptions options, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)) {}
-
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
-    LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
-    LinalgOp linalgOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  return vectorize(rewriter, linalgOp);
-}
-
 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
   return vectorizeCopy(rewriter, copyOp);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
 
 // -----
 
@@ -12,6 +12,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.dot"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_matvec
@@ -24,6 +34,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_matmul
@@ -35,6 +55,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_batch_matmul
@@ -47,6 +77,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_trait = {
@@ -80,6 +120,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_transpose_out_trait = {
@@ -113,6 +163,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -133,6 +193,16 @@
   return %1 : tensor<128x12x32xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_trait = {
@@ -166,6 +236,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @vectorization_test_2
@@ -179,6 +259,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_scalar_input
@@ -196,6 +286,16 @@
   return
}
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
@@ -213,6 +313,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_fill
@@ -223,6 +333,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_fill
@@ -234,6 +354,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_copy
@@ -244,6 +374,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_copy_scalar
@@ -257,6 +397,15 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
 // -----
 
 // CHECK-LABEL: func @test_vectorize_trailing_index
@@ -278,6 +427,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_inner_index
@@ -300,6 +459,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @generic_vectorize
@@ -378,6 +547,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @generic_vectorize_tensor
@@ -462,6 +641,16 @@
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
@@ -499,6 +688,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // Test different input maps.
@@ -535,6 +734,16 @@
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @matmul_tensors
@@ -560,6 +769,16 @@
   return %0 : tensor<8x12xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_static(
@@ -581,6 +800,17 @@
   return %0 : tensor<2x3x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_static_source(
@@ -602,6 +832,18 @@
   return %0 : tensor<2x6x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 // CHECK-LABEL: func @pad_static_dynamic(
@@ -630,6 +872,18 @@
   return %0 : tensor<6x?x?x?xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 // CHECK-LABEL: func @pad_and_transfer_read
@@ -652,6 +906,17 @@
   return %1 : vector<7x9xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 func.func private @make_vector() -> vector<7x9xf32>
@@ -678,6 +943,17 @@
   return %3 : tensor<5x6xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> vector<7x9xf32>
@@ -707,6 +983,17 @@
   return %3 : tensor
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -733,6 +1020,17 @@
   return %r : tensor<12x13xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -753,6 +1051,16 @@
   return %r : tensor<1x12x13xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_tensor_non_const_pad_value
@@ -782,6 +1090,17 @@
   return %0 : tensor<12x13xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @sum_exp
@@ -809,6 +1128,17 @@
   return %0 : tensor<4x16xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
@@ -846,13 +1176,23 @@
   return %0 : tensor<5x2xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_max_2d(
 func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
@@ -869,13 +1209,23 @@
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_min_2d(
 func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -893,12 +1243,22 @@
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_mul_2d(
 func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -916,12 +1276,22 @@
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_or_2d(
 func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -939,12 +1309,22 @@
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_and_2d(
 func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -962,12 +1342,22 @@
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @red_xor_2d(
 func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -985,6 +1375,17 @@
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1011,6 +1412,17 @@
   return %red : tensor<4x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1041,6 +1453,21 @@
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @reduce_1d(
@@ -1054,8 +1481,6 @@
 
   // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
   %0 = linalg.init_tensor [] : tensor<f32>
-  // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
-  // CHECK-SAME:   : vector<f32>, tensor<f32>
   %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
   // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
@@ -1063,7 +1488,7 @@
   // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
   // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
-  // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
+  // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
   // CHECK-SAME:   : vector<f32>, tensor<f32>
   %2 = linalg.generic {
          indexing_maps = [affine_map<(d0) -> (d0)>,
@@ -1079,6 +1504,16 @@
   return %2 : tensor<f32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 
@@ -1103,6 +1538,16 @@
   return %result : tensor<6x6x3x3xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----
 
 // Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
@@ -1134,3 +1579,13 @@
 // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
 // CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]]
 // CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]]
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -225,9 +225,6 @@
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
-               .addOpFilter());
   patterns.add<CopyVectorizationPattern>(ctx);
 
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -441,9 +438,6 @@
 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   auto *ctx = funcOp.getContext();
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter()
-               .addOpFilter());
   patterns.add<CopyVectorizationPattern>(ctx);
   populatePadOpVectorizationPatterns(patterns);
   populateConvolutionVectorizationPatterns(patterns);
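The `anchor-op`/`anchor-func` options of the removed pass have no pass-level replacement; their role is played by `transform.structured.match` plus `get_closest_isolated_parent`, as the updated tests show. A minimal sketch of the migration, assuming a pipeline that previously ran `-linalg-strategy-vectorize-pass` anchored on `linalg.fill` with `vectorize-padding` enabled (the op names and idiom are taken from the tests above):

    transform.with_pdl_patterns {
    ^bb0(%arg0: !pdl.operation):
      transform.sequence %arg0 failures(propagate) {
      ^bb1(%arg1: !pdl.operation):
        // Anchor on the op kind, then vectorize its enclosing function,
        // enabling pad-op vectorization as the old pass option did.
        %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
        %1 = get_closest_isolated_parent %0
        %2 = transform.structured.vectorize %1 { vectorize_padding = true }
      }
    }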