diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -768,6 +768,26 @@ dialect. Operand omission is only allowed for sequences not contained in another sequence. + The type of the block argument must match the type of the operand. If the + sequence is a top-level transform (without an operand), it can be used for + matching operations of the specified type within the top-level container + payload IR (including the container op itself). E.g.: + + ```mlir + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + // %arg1 is mapped to the top-level container of the payload IR, which is + // typically a module + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.op<"func.func">): + // %arg1 is mapped to all "func.func" ops within and including the + // top-level container of the payload IR. Nested operations that have the + // specified op type are not included. + } + ``` + The body of the sequence terminates with an implicit or explicit `transform.yield` op. The operands of the terminator are returned as the results of the sequence op. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1562,7 +1562,24 @@ << " were provided to the interpreter"; } - targets.push_back(state.getTopLevel()); + // Top-level transforms can be used for matching. If no concrete operation + // type is specified, the block argument is mapped to the top-level op. + // Otherwise, it is mapped to all ops of the specified type within the + // top-level op (including the top-level op itself). 
Once an op is added as + // a target, its descendants are not explored any further. + BlockArgument bbArg = region.front().getArgument(0); + if (auto bbArgType = dyn_cast(bbArg.getType())) { + state.getTopLevel()->walk([&](Operation *op) { + if (op->getName().getStringRef() == bbArgType.getOperationName()) { + targets.push_back(op); + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + } else { + targets.push_back(state.getTopLevel()); + } + for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i) extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i))); } diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> @@ -67,13 +65,11 @@ // ----- transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.fold_tensor_empty {fold_single_use_only = true} - } : !transform.any_op + } : !transform.op<"func.func"> } func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index) diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir 
b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir --- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir +++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.tensor.rewrite_as_constant - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK-LABEL: func @tensor_generate_constant( diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -209,10 +209,8 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -250,10 +250,8 @@ // CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32> transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : 
(!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -189,10 +189,8 @@ // CHECK: return %{{.+}} transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -83,10 +83,8 @@ // CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector, memref transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.rank_reducing_subview_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir 
b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -107,12 +107,10 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -170,12 +168,10 @@ // CHECK: } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -240,10 +236,8 @@ // CHECK: } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -102,12 +102,10 @@ } 
transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -162,12 +160,10 @@ transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -226,12 +222,10 @@ // CHECK: } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -270,10 +264,8 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir @@ -1,12 +1,10 @@ // RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.fold_tensor_slice_into_transfer - } : !transform.any_op + } : !transform.op<"func.func"> } // CHECK-LABEL: func @transfer_read_of_extract_slice( diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -239,13 +239,11 @@ transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.transfer_permutation_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -363,11 +361,9 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 
transform.apply_patterns.vector.transfer_permutation_patterns - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -75,12 +75,10 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -100,12 +98,10 @@ transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -121,12 +117,10 @@ transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -610,12 +604,10 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): 
transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -690,12 +682,10 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> } // ----- @@ -771,10 +761,8 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir @@ -30,11 +30,9 @@ } transform.sequence failures(propagate) { -^bb1(%module_op: !transform.any_op): - %func_op = transform.structured.match ops{["func.func"]} in %module_op - : (!transform.any_op) -> !transform.any_op +^bb1(%func_op: !transform.op<"func.func">): transform.apply_patterns to %func_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" - } : !transform.any_op + } : !transform.op<"func.func"> }