[mlir][transform] Do not apply patterns to the transform IR itself

ApplyPatternsOp::applyToOne now walks from the transform op up through all of
its parent ops. If the payload target is the transform op itself or one of
those ancestors, it emits a definite failure ("cannot apply transform to itself
(or one of its ancestors)") with a "target payload op" note attached at the
target's location. A negative test for this is added to
test-pattern-application.mlir.

Tests that ran transform.apply_patterns on %module_op now first match the
func.func ops with transform.structured.match and apply the patterns to those
instead. Presumably %module_op is the root module that also contains the
transform script, which the new check rejects.

vector-multi-reduction-lowering.mlir drops its "// -----" separators.
vector-transfer-drop-unit-dims-patterns.mlir drops --split-input-file and keeps
only its last transform.sequence. Presumably both changes make the single
remaining transform script apply to every function in the file.

NOTE(review): this copy of the patch is damaged and will not apply as-is.
Many diff lines are joined onto single physical lines, so the hunks cannot be
parsed. Text inside angle brackets also appears to have been stripped, for
example "vector.multi_reduction , %arg0", "%arg0: tensor, %arg1: tensor)
-> tensor", and "to memref // CHECK". Regenerate the patch from the source
tree instead of repairing it by hand.

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -239,6 +239,22 @@ transform::ApplyPatternsOp::applyToOne(Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { + // Make sure that this transform is not applied to itself. Modifying the + // transform IR while it is being interpreted is generally dangerous. Even + // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver + // performs many additional simplifications such as dead code elimination. + Operation *transformAncestor = getOperation(); + while (transformAncestor) { + if (transformAncestor == target) { + DiagnosedDefiniteFailure diag = + emitDefiniteFailure() + << "cannot apply transform to itself (or one of its ancestors)"; + diag.attachNote(target->getLoc()) << "target payload op"; + return diag; + } + transformAncestor = transformAncestor->getParentOp(); + } + // Gather all specified patterns. 
MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -2,7 +2,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty } : !transform.any_op } @@ -66,7 +68,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty {fold_single_use_only = true} } : !transform.any_op diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir --- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir +++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir @@ -2,7 +2,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.tensor.rewrite_as_constant } : !transform.any_op } diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -155,3 +155,22 @@ } : !transform.any_op transform.test_print_remark_at_operand %0, "op was 
replaced" : !transform.any_op } + +// ----- + +// expected-note @below{{target payload op}} +module { + func.func @invalid_pattern_application_to_transform_ir() { + return + } + + module { + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + // expected-error @below {{cannot apply transform to itself (or one of its ancestors)}} + transform.apply_patterns to %arg1 { + transform.apply_patterns.canonicalization + } : !transform.any_op + } + } +} diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -210,7 +210,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -19,8 +19,6 @@ // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] -// ----- - func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 { %0 = vector.multi_reduction , %arg0, %acc [0, 1] : vector<2x4xf32> to f32 return %0 : f32 @@ -33,8 +31,6 @@ // CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32> // CHECK: return %[[RES]] -// ----- - func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: 
vector<2x3xi32>) -> vector<2x3xi32> { %0 = vector.multi_reduction , %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> @@ -76,8 +72,6 @@ // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32> // CHECK: return %[[RESULT]] -// ----- - func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> { %0 = vector.multi_reduction , %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> return %0 : vector<2x5xf32> @@ -90,8 +84,6 @@ // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32> // CHECK: return %[[RESULT]] -// ----- - func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> { %0 = vector.multi_reduction , %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32> return %0 : vector<2x4xf32> @@ -143,8 +135,6 @@ // CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> // CHECK: return %[[RESHAPED_VEC]] -// ----- - func.func @vectorize_dynamic_reduction(%arg0: tensor, %arg1: tensor) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -187,8 +177,6 @@ // CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 // CHECK: %[[VAL_33:.*]] = vector.insertelement -// ----- - func.func @vectorize_1d_dynamic_reduction(%arg0: tensor) -> f32 { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -207,8 +195,6 @@ // CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1> // CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 -// ----- - func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor, %arg1: tensor) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor @@ -254,8 +240,6 @@ // 
CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction // CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]] -// ----- - func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> { %0 = vector.multi_reduction , %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32> return %0 : vector<4xf32> @@ -267,7 +251,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -190,7 +190,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s func.func 
@transfer_read_rank_reducing( %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> { @@ -8,44 +8,24 @@ memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> return %v : vector<3x2xi8> } - // CHECK-LABEL: func @transfer_read_rank_reducing // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_read %[[SUBVIEW]] -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { - transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op -} - -// ----- - func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> return } - // CHECK-LABEL: func @transfer_write_rank_reducing // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { - transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op -} - -// ----- - func.func @transfer_read_and_vector_rank_reducing( %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> { %c0 = arith.constant 0 : index @@ -54,22 +34,12 @@ memref<1x1x3x2x1xf32>, vector<3x2x1xf32> return %v : vector<3x2x1xf32> } - // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 
1, 3, 2, 1] [1, 1, 1, 1, 1] // CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> // CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32> -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { - transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op -} - -// ----- - func.func @transfer_write_and_vector_rank_reducing( %arg : memref<1x1x3x2x1xf32>, %vec : vector<3x2x1xf32>) { @@ -78,22 +48,12 @@ vector<3x2x1xf32>, memref<1x1x3x2x1xf32> return } - // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1] // CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> // CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32> -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { - transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op -} - -// ----- - func.func @transfer_read_and_vector_rank_reducing_to_0d( %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> { %c0 = arith.constant 0 : index @@ -102,22 +62,12 @@ memref<1x1x1x1x1xf32>, vector<1x1x1xf32> return %v : vector<1x1x1xf32> } - // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d // CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32> // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref // CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref, vector // CHECK: vector.shape_cast %[[READ]] : vector to vector<1x1x1xf32> -transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { - 
transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op -} - -// ----- - func.func @transfer_write_and_vector_rank_reducing_to_0d( %arg : memref<1x1x1x1x1xf32>, %vec : vector<1x1x1xf32>) { @@ -126,7 +76,6 @@ vector<1x1x1xf32>, memref<1x1x1x1x1xf32> return } - // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d // CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32> // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref @@ -135,7 +84,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.rank_reducing_subview_patterns } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -108,7 +108,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } @@ -169,7 +171,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } @@ -237,7 +241,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -103,7 +103,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -161,7 +163,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -223,7 +227,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + 
transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } @@ -265,7 +271,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir @@ -2,7 +2,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.fold_tensor_slice_into_transfer } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -240,7 +240,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
transform.apply_patterns.vector.transfer_permutation_patterns } : !transform.any_op @@ -362,7 +364,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.transfer_permutation_patterns } : !transform.any_op diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -76,7 +76,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise" } : !transform.any_op } @@ -99,7 +101,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } @@ -118,7 +122,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" } : !transform.any_op
} @@ -605,7 +611,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true } : !transform.any_op } @@ -683,7 +691,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } @@ -762,7 +772,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir @@ -31,7 +31,9 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns to %module_op { + %func_op = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op }