diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -134,15 +134,6 @@
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());

-/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyVectorizePass(
-    StringRef opName = "",
-    linalg::LinalgVectorizationOptions opt =
-        linalg::LinalgVectorizationOptions(),
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter(),
-    bool padVectorize = false);
-
 /// Create a LinalgStrategyLowerVectorsPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyLowerVectorsPass(
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -238,21 +238,6 @@
   ];
 }

-def LinalgStrategyVectorizePass
-    : Pass<"linalg-strategy-vectorize-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to apply pattern-based linalg vectorization.";
-  let constructor = "mlir::createLinalgStrategyVectorizePass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
-      "Which linalg op within the func is the anchor to latch on.">,
-    Option<"vectorizePadding", "vectorize-padding", "bool", "false",
-      "Enable vectorization of padding ops.">,
-  ];
-}
-
 def LinalgStrategyLowerVectorsPass
     : Pass<"linalg-strategy-lower-vectors-pass", "func::FuncOp"> {
   let summary = "Configurable pass to lower vector operations.";
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -773,7 +773,9 @@
   }];

   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
+                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,
+                   DefaultValuedAttr<BoolAttr, "true">:$reduction_to_contract,
+                   DefaultValuedAttr<BoolAttr, "true">:$permutation_map);
   let results = (outs PDL_Operation:$transformed);

   let assemblyFormat = "$target attr-dict";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -112,32 +112,6 @@
   linalg::LinalgPeelOptions options;
 };

-/// Represent one application of createLinalgStrategyVectorizePass.
-struct Vectorize : public Transformation {
-  explicit Vectorize(linalg::LinalgVectorizationOptions options,
-                     LinalgTransformationFilter::FilterFunction f = nullptr,
-                     bool padVectorize = false)
-      : Transformation(std::move(f)), options(options),
-        vectorizePadding(padVectorize) {}
-
-  Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
-            LinalgTransformationFilter::FilterFunction f = nullptr,
-            bool padVectorize = false)
-      : Transformation(std::move(f)), opName(name), options(options),
-        vectorizePadding(padVectorize) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyVectorizePass(opName, options, m,
-                                                 vectorizePadding));
-  }
-
-private:
-  std::string opName;
-  linalg::LinalgVectorizationOptions options;
-  bool vectorizePadding;
-};
-
 /// Represent one application of createLinalgStrategyLowerVectorsPass.
 struct VectorLowering : public Transformation {
   explicit VectorLowering(
@@ -203,7 +177,7 @@
   padIf(bool b, StringRef opName, linalg::LinalgPaddingOptions options,
         LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? pad(opName, std::move(options), std::move(f)) : *this;
-  } 
+  }
   /// Append patterns to decompose convolutions.
   CodegenStrategy &
   decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {
@@ -229,23 +203,6 @@
          LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? peel(opName, options, std::move(f)) : *this;
   }
-  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
-  CodegenStrategy &
-  vectorize(StringRef opName,
-            const LinalgTransformationFilter::FilterFunction &f = nullptr,
-            bool vectorizePadding = false) {
-    transformationSequence.emplace_back(std::make_unique<Vectorize>(
-        opName, linalg::LinalgVectorizationOptions(), f, vectorizePadding));
-    return *this;
-  }
-  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
-  /// operation.
-  CodegenStrategy &
-  vectorizeIf(bool b, StringRef opName,
-              LinalgTransformationFilter::FilterFunction f = nullptr,
-              bool vectorizePadding = false) {
-    return b ? vectorize(opName, std::move(f), vectorizePadding) : *this;
-  }
   /// Append a pattern to lower all vector operations.
   CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
     transformationSequence.emplace_back(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -921,31 +921,6 @@
 /// Empty for now, used for SFINAE purposes only.
 struct LinalgVectorizationOptions {};

-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern
-    : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
-  LinalgVectorizationPattern(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      PatternBenefit benefit = 1);
-
-  /// Construct a pattern specifically applied to `opName`.
-  LinalgVectorizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  LogicalResult matchAndRewrite(LinalgOp linalgOp,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
-
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `vectorizeLinalgOp` for more details.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
@@ -1330,18 +1305,6 @@
                      const LinalgTransformationFilter &f) {}
 };

-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgVectorizationOptions &options,
-                     const LinalgTransformationFilter &f) {
-    patterns.add<LinalgVectorizationPattern>(
-        OpTy::getOperationName(), patterns.getContext(), options, f);
-    VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
-  }
-};
-
 template <typename... OpTypes>
 class TilingPatterns;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1160,6 +1160,16 @@
 // VectorizeOp
 //===----------------------------------------------------------------------===//

+/// Pattern that matches any LinalgOp and rewrites it with the generic
+/// `vectorize` entry point.
+struct VectorizationPattern : public RewritePattern {
+  explicit VectorizationPattern(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    return vectorize(rewriter, linalgOp);
+  }
+};
+
 DiagnosedSilenceableFailure
 transform::VectorizeOp::applyToOne(Operation *target,
                                    SmallVectorImpl<Operation *> &results,
@@ -1172,10 +1182,14 @@
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<LinalgVectorizationPattern>(ctx);
-
-  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-  vector::populateVectorReductionToContractPatterns(patterns);
+  patterns.add<VectorizationPattern>(ctx);
+
+  if (getPermutationMap())
+    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+
+  if (getReductionToContract())
+    vector::populateVectorReductionToContractPatterns(patterns);
+
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx, /*benefit=*/2);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -40,7 +40,6 @@
 #define GEN_PASS_DEF_LINALGSTRATEGYPADPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYDECOMPOSEPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYPEELPASS
-#define GEN_PASS_DEF_LINALGSTRATEGYVECTORIZEPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYLOWERVECTORSPASS
 #define GEN_PASS_DEF_LINALGSTRATEGYREMOVEMARKERSPASS
 #include "mlir/Dialect/Linalg/Passes.h.inc"
@@ -215,62 +214,6 @@
   LinalgTransformationFilter filter;
 };

-/// Configurable pass to apply pattern-based linalg vectorization.
-struct LinalgStrategyVectorizePass
-    : public impl::LinalgStrategyVectorizePassBase<
-          LinalgStrategyVectorizePass> {
-
-  LinalgStrategyVectorizePass() = default;
-
-  LinalgStrategyVectorizePass(StringRef opName, LinalgVectorizationOptions opt,
-                              LinalgTransformationFilter filt,
-                              bool padVectorize = false)
-      : options(opt), filter(std::move(filt)) {
-    this->anchorOpName.setValue(opName.str());
-    this->vectorizePadding.setValue(padVectorize);
-  }
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    RewritePatternSet vectorizationPatterns(funcOp.getContext());
-    if (!anchorOpName.empty()) {
-      vectorizationPatterns.add<LinalgVectorizationPattern>(
-          anchorOpName, funcOp.getContext(), options, filter);
-    } else {
-      vectorizationPatterns.add<LinalgVectorizationPattern>(
-          funcOp.getContext(), filter, options);
-    }
-    vector::populateVectorTransferPermutationMapLoweringPatterns(
-        vectorizationPatterns);
-    vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
-    vectorizationPatterns.add<LinalgCopyVTRForwardingPattern,
-                              LinalgCopyVTWForwardingPattern>(
-        funcOp.getContext(), /*benefit=*/2);
-    TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
-                                                funcOp.getContext());
-    TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
-                                                 funcOp.getContext());
-    (void)applyPatternsAndFoldGreedily(funcOp,
-                                       std::move(vectorizationPatterns));
-
-    // Apply the pad tensor op vectorization separately to avoid running the
-    // GenericPadOpVectorizationPattern too early.
-    // TODO: Improve once we have better infrastructure to control pattern
-    // application.
-    if (vectorizePadding) {
-      RewritePatternSet patterns(funcOp.getContext());
-      linalg::populatePadOpVectorizationPatterns(patterns);
-      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-    }
-  }
-
-  LinalgVectorizationOptions options;
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to lower vector operations.
 struct LinalgStrategyLowerVectorsPass
     : public impl::LinalgStrategyLowerVectorsPassBase<
@@ -393,15 +336,6 @@
   return std::make_unique<LinalgStrategyPeelPass>(opName, opt, filter);
 }

-/// Create a LinalgStrategyVectorizePass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyVectorizePass(
-    StringRef opName, LinalgVectorizationOptions opt,
-    const LinalgTransformationFilter &filter, bool padVectorize) {
-  return std::make_unique<LinalgStrategyVectorizePass>(opName, opt, filter,
-                                                       padVectorize);
-}
-
 /// Create a LinalgStrategyLowerVectorsPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyLowerVectorsPass(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -590,25 +590,6 @@
   return success();
 }

-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    MLIRContext *context, LinalgTransformationFilter f,
-    LinalgVectorizationOptions options, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)) {}
-
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
-    LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
-    LinalgOp linalgOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  return vectorize(rewriter, linalgOp);
-}
-
 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
   return vectorizeCopy(rewriter, copyOp);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s

 // -----

@@ -12,6 +12,16 @@
   return
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.dot"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false}
+  }
+}
+
 // -----

 // CHECK-LABEL: contraction_matvec
@@ -24,6 +34,16 @@
   return
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false}
+  }
+}
+
 // -----

 // CHECK-LABEL: contraction_matmul
@@ -35,6 +55,16 @@
   return
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false}
+  }
+}
+
 // -----

 // CHECK-LABEL: contraction_batch_matmul
@@ -47,6 +77,16 @@
   return
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false}
+  }
+}
+
 // -----

 #matmul_trait = {
@@ -80,6 +120,16 @@
   return
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0
failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- #matmul_transpose_out_trait = { @@ -113,6 +163,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -133,6 +193,16 @@ return %1 : tensor<128x12x32xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- #matmul_trait = { @@ -166,6 +236,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- // CHECK-LABEL: func @vectorization_test_2 @@ -179,6 +259,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- // CHECK-LABEL: func @test_vectorize_scalar_input @@ -196,6 +286,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types @@ -213,6 +313,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -223,6 +333,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -234,29 +354,58 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- 
// CHECK-LABEL: func @test_vectorize_copy func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { - // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> - // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + // CHECKFAILS: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECKFAILS: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_copy_scalar func.func @test_vectorize_copy_scalar(%A : memref, %B : memref) { // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector - // CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector - // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector, memref + // CHECKFAILS: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector + // CHECKFAILS: %[[val:.*]] = vector.extractelement %[[V]][] : vector + // CHECKFAILS: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECKFAILS: vector.transfer_write %[[VV]], %[[B]][] : vector, memref memref.copy %A, %B : memref to memref return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} // ----- // CHECK-LABEL: func @test_vectorize_trailing_index @@ -278,6 +427,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_inner_index @@ -300,6 +459,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @generic_vectorize @@ -378,6 +547,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {permutation_map = false } + } +} + // ----- // CHECK-LABEL: func @generic_vectorize_tensor @@ -462,6 +641,16 @@ tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { permutation_map = 
false } + } +} + // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)> @@ -499,6 +688,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {permutation_map = false } + } +} + // ----- // Test different input maps. @@ -535,6 +734,16 @@ return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {permutation_map = false } + } +} + // ----- // CHECK-LABEL: func @matmul_tensors @@ -560,19 +769,29 @@ return %0 : tensor<8x12xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false } + } +} + // ----- // CHECK-LABEL: func @pad_static( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x?x2xf32>, %[[PAD:.*]]: f32 -// CHECK-NOT: tensor.pad -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> -// CHECK-DAG: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32> -// CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32> -// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x3x2xf32>, tensor<2x3x4xf32> -// CHECK: return %[[RESULT]] +// CHECKFAILS: tensor.pad +// CHECKFAILS: %[[C0:.*]] = arith.constant 0 : index +// CHECKFAILS: %[[C2:.*]] = arith.constant 2 : index +// CHECKFAILS: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32> +// CHECKFAILS: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32> +// CHECKFAILS: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32> +// CHECKFAILS: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32> +// CHECKFAILS: %[[RESULT:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x3x2xf32>, tensor<2x3x4xf32> +// CHECKFAILS: return %[[RESULT]] func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4xf32> { %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] { ^bb0(%arg1: index, %arg2: index, %arg3: index): @@ -581,19 +800,30 @@ return %0 : tensor<2x3x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func 
@pad_static_source( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x5x2xf32>, %[[PAD:.*]]: f32 -// CHECK-NOT: tensor.pad -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32> -// CHECK: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32> -// CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32> -// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32> -// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x5x2xf32>, tensor<2x6x4xf32> -// CHECK: return %[[WRITE]] +// CHECKFAILS: tensor.pad +// CHECKFAILS: %[[C0:.*]] = arith.constant 0 : index +// CHECKFAILS: %[[C2:.*]] = arith.constant 2 : index +// CHECKFAILS: %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32> +// CHECKFAILS: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32> +// CHECKFAILS: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32> +// CHECKFAILS: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32> +// CHECKFAILS: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x5x2xf32>, tensor<2x6x4xf32> +// CHECKFAILS: return %[[WRITE]] func.func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tensor<2x6x4xf32> { %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] { ^bb0(%arg1: index, %arg2: index, %arg3: index): @@ -602,25 +832,37 @@ return %0 : tensor<2x6x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + + // ----- // CHECK-LABEL: func @pad_static_dynamic( // CHECK-SAME: %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index -// CHECK-NOT: tensor.pad -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index -// CHECK: %[[V0:.*]] = arith.addi %[[LOW]], %[[C2]] : index -// CHECK: %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : index -// CHECK: %[[V2:.*]] = arith.addi %[[HIGH]], %[[C5]] : index -// CHECK: %[[DIM3:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> -// CHECK: %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index -// CHECK: %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index -// CHECK: %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32> -// CHECK: %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> -// CHECK: %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32> -// CHECK: return %[[RESULT]] +// CHECKFAILS: tensor.pad +// CHECKFAILS: %[[C2:.*]] = 
arith.constant 2 : index +// CHECKFAILS: %[[C3:.*]] = arith.constant 3 : index +// CHECKFAILS: %[[C5:.*]] = arith.constant 5 : index +// CHECKFAILS: %[[V0:.*]] = arith.addi %[[LOW]], %[[C2]] : index +// CHECKFAILS: %[[V1:.*]] = arith.addi %[[V0]], %[[C3]] : index +// CHECKFAILS: %[[V2:.*]] = arith.addi %[[HIGH]], %[[C5]] : index +// CHECKFAILS: %[[DIM3:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> +// CHECKFAILS: %[[V4:.*]] = arith.addi %[[DIM3]], %[[C3]] : index +// CHECKFAILS: %[[V5:.*]] = arith.addi %[[V4]], %[[C2]] : index +// CHECKFAILS: %[[INIT:.*]] = linalg.init_tensor [6, %[[V1]], %[[V2]], %[[V5]]] : tensor<6x?x?x?xf32> +// CHECKFAILS: %[[FILL:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%[[INIT]] : tensor<6x?x?x?xf32>) -> tensor<6x?x?x?xf32> +// CHECKFAILS: %[[SRCDIM:.*]] = tensor.dim %[[SRC]], %[[C3]] : tensor<1x2x2x?xf32> +// CHECKFAILS: %[[RESULT:.*]] = tensor.insert_slice %[[SRC]] into %[[FILL]][2, %[[LOW]], 3, 3] [1, 2, 2, %[[SRCDIM]]] [1, 1, 1, 1] : tensor<1x2x2x?xf32> into tensor<6x?x?x?xf32> +// CHECKFAILS: return %[[RESULT]] func.func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index, %pad_value: f32) -> tensor<6x?x?x?xf32> { %0 = tensor.pad %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] { @@ -630,6 +872,18 @@ return %0 : tensor<6x?x?x?xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + + // ----- // CHECK-LABEL: func @pad_and_transfer_read @@ -637,8 +891,8 @@ // CHECK-NOT: tensor.pad // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C5:.*]] = arith.constant 5.0 -// CHECK: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32> -// CHECK: return %[[RESULT]] +// CHECKFAILS: %[[RESULT:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32> +// CHECKFAILS: return %[[RESULT]] func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> { %c0 = arith.constant 0 : index %c5 = arith.constant 5.0 : f32 @@ -652,17 +906,28 @@ return %1 : vector<7x9xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- func.func private @make_vector() -> vector<7x9xf32> // CHECK-LABEL: func @pad_and_transfer_write_static // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32> -// CHECK-NOT: tensor.pad -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32> -// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32> -// CHECK: return %[[RESULT]] +// CHECKFAILS: tensor.pad +// CHECKFAILS: %[[C0:.*]] = arith.constant 0 : index +// CHECKFAILS: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32> +// CHECKFAILS: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32> +// CHECKFAILS: return %[[RESULT]] func.func @pad_and_transfer_write_static( %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> { %c0 = arith.constant 0 : index @@ -678,6 +943,17 @@ return %3 : tensor<5x6xf32> } 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
+
 // -----

 func.func private @make_vector() -> vector<7x9xf32>

@@ -688,8 +964,8 @@
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
 // CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
-// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
-// CHECK: return %[[RESULT]]
+// CHECKFAILS: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+// CHECKFAILS: return %[[RESULT]]
 func.func @pad_and_transfer_write_dynamic_static(
     %arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
   %c0 = arith.constant 0 : index
@@ -707,19 +983,30 @@
   return %3 : tensor<?x6xf32>
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
+
 // -----

 func.func private @make_vector() -> tensor<12x13xf32>

 // CHECK-LABEL: func @pad_and_insert_slice_source
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
-// CHECK-NOT: tensor.pad
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C5:.*]] = arith.constant 5.0
-// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
-// CHECK: return %[[WRITE]]
+// CHECKFAILS: tensor.pad
+// CHECKFAILS: %[[C0:.*]] = arith.constant 0 : index
+// CHECKFAILS: %[[C5:.*]] = arith.constant 5.0
+// CHECKFAILS: %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
+// CHECKFAILS: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
+// CHECKFAILS: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+// CHECKFAILS: return %[[WRITE]]
 func.func @pad_and_insert_slice_source(
     %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
   %c0 = arith.constant 0 : index
@@ -733,6 +1020,17 @@
   return %r : tensor<12x13xf32>
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
+
 // -----

 func.func private @make_vector() -> tensor<12x13xf32>

@@ -753,20 +1051,30 @@
   return %r : tensor<1x12x13xf32>
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @pad_tensor_non_const_pad_value
 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
-// CHECK-NOT: tensor.pad
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[FILL:.*]] = tensor.generate
-// CHECK: %[[RES:.*]] = arith.mulf
-// CHECK: tensor.yield %[[RES]] : f32
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : tensor<5x6xf32>, vector<5x6xf32>
-// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C3]], %[[C4]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<12x13xf32>
-// CHECK: return %[[WRITE]]
+// CHECKFAILS: tensor.pad
+// CHECKFAILS: %[[C0:.*]] = arith.constant 0 : index
+// CHECKFAILS: %[[C3:.*]] = arith.constant 3 : index
+// CHECKFAILS: %[[C4:.*]] = arith.constant 4 : index
+// CHECKFAILS: %[[FILL:.*]] = tensor.generate
+// CHECKFAILS: %[[RES:.*]] = arith.mulf
+// CHECKFAILS: tensor.yield %[[RES]] : f32
+// CHECKFAILS: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : tensor<5x6xf32>, vector<5x6xf32>
+// CHECKFAILS: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C3]], %[[C4]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<12x13xf32>
+// CHECKFAILS: return %[[WRITE]]
 func.func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
   %c0 = arith.constant 0 : index
   %c5 = arith.constant 5.0 : f32
@@ -782,6 +1090,17 @@
   return %0 : tensor<12x13xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @sum_exp
@@ -809,6 +1128,17 @@
   return %0 : tensor<4x16xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
@@ -846,13 +1176,24 @@
   return %0 : tensor<5x2xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 {reduction_to_contract = false, permutation_map = false}
+  }
+}
+
 // -----

 // CHECK-LABEL: func @red_max_2d(
 func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
@@ -869,13 +1210,24 @@
   return %red : tensor<4xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @red_min_2d(
 func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -893,12 +1245,23 @@
   return %red : tensor<4xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @red_mul_2d(
 func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+  // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -916,12 +1279,23 @@
   return %red : tensor<4xf32>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @red_or_2d(
 func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
+  // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -939,12 +1313,23 @@
   return %red : tensor<4xi1>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4
+  }
+}
+
 // -----

 // CHECK-LABEL: func @red_and_2d(
 func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
+  // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -962,12 +1347,23 @@
   return %red : tensor<4xi1>
 }

+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_xor_2d( func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> + // CHECKFAILS: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -985,6 +1381,17 @@ return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1011,6 +1418,17 @@ return %red : tensor<4x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1041,6 +1459,21 @@ return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @reduce_1d( @@ -1054,8 +1487,8 @@ // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor %0 = linalg.init_tensor [] : tensor - // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] - // CHECK-SAME: : vector, tensor + // CHECKFAILS: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] + // CHECKFAILS: : vector, tensor %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor) -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> @@ -1063,8 +1496,8 @@ // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]], %[[f0]] [0] // CHECK-SAME: : vector<32xf32> to f32 // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] - // CHECK-SAME: : vector, tensor + // CHECKFAILS: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] + // CHECKFAILS: : vector, tensor %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], @@ -1079,6 +1512,20 @@ return %2 : tensor } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- @@ -1103,6 +1550,16 @@ return %result : 
 }

+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1
+  }
+}
+
 // -----

 // Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
@@ -1134,3 +1591,13 @@
 // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
 // CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]]
 // CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]]
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 {reduction_to_contract = false, permutation_map = false}
+  }
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -225,9 +225,6 @@
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
-               .addOpFilter<ContractionOpInterface>());
   patterns.add<CopyVectorizationPattern>(ctx);

   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -440,10 +437,7 @@
 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  auto *ctx = funcOp.getContext();
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter()
-               .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
+  auto *ctx = funcOp.getContext();
   patterns.add<CopyVectorizationPattern>(ctx);
   populatePadOpVectorizationPatterns(patterns);
   populateConvolutionVectorizationPatterns(patterns);
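
Note on usage: with `LinalgStrategyVectorizePass` and `LinalgVectorizationPattern` removed, vectorization is driven entirely through `transform.structured.vectorize`, as the updated tests above do under `-test-transform-dialect-interpreter`. A minimal sketch in the style of those tests, assuming the two new unit attributes default to true (so omitting them keeps the previously unconditional pattern sets) and that the payload module contains a `linalg.matmul`:

  transform.with_pdl_patterns {
  ^bb0(%arg0: !pdl.operation):
    transform.sequence %arg0 failures(propagate) {
    ^bb1(%arg1: !pdl.operation):
      // Match the op to vectorize, then hand its enclosing isolated parent
      // (typically the func.func) to the vectorizer.
      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
      %1 = get_closest_isolated_parent %0
      // Opt out of the two newly gated pattern sets; opt in to pad-op
      // vectorization.
      %2 = transform.structured.vectorize %1
          {reduction_to_contract = false, permutation_map = false,
           vectorize_padding = true}
    }
  }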