diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -420,33 +420,6 @@
   }];
 }
 
-def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent",
-    [DeclareOpInterfaceMethods<TransformOpInterface>,
-     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Gets handles to the closest isolated-from-above parents";
-  let description = [{
-    The handles defined by this Transform op correspond to the closest isolated
-    from above ancestor of the Payload IR operations associated with its
-    operand. If any of the given Payload IR ops has no such parent (unlikely as
-    there usually is a top-level ModuleOp), the transformation is considered to
-    have failed.
-
-    Ancestor ops follow the same order as the ops associated with the
-    operand, except for potential duplicates (multiple Payload IR ops associated
-    with the operand have the same parent) for which the ancestor will only be
-    listed once for the first time it occurs. For example, given the list
-    "(childof(A), childof(B), childof(B), childof(A), childof(B))", the
-    resulting list will be just "(A, B)". Note that no other semantic ordering
-    is applied, e.g., "B" may itself be a parent of "A". This may have an impact
-    on the further transformation applied to the handle produced here.
-  }];
-
-  let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs TransformHandleTypeInterface:$parent);
-  let assemblyFormat =
-    "$target attr-dict `:` functional-type(operands, results)";
-}
-
 def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
@@ -485,6 +458,40 @@
       "functional-type(operands, results)";
 }
 
+def GetParentOp : TransformDialectOp<"get_parent_op",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+  let summary = "Gets handles to the closest parent ops";
+  let description = [{
+    The handle defined by this Transform op corresponds to the parents of the
+    targeted payload ops (in the same order).
+
+    Requirements that parent ops must fulfill can be optionally specified. In
+    that case for each target op, the closest parent op that fulfills all
+    requirements, is returned.
+    - `isolated_from_above`: the parent op must be isolated from above
+    - `op_name`: the parent op must have the specified name
+
+    If `deduplicate` is set, the result handle does not contain any duplicate
+    ops. For example, given the list
+    "(childof(A), childof(B), childof(B), childof(A), childof(B))", the
+    resulting list will be just "(A, B)". Note that no other semantic ordering
+    is applied, e.g., "B" may itself be a parent of "A". This may have an impact
+    on the further transformation applied to the handle produced here.
+
+    If any of the given Payload IR ops has no such suitable parent, the
+    transformation produces a silenceable failure.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       UnitAttr:$isolated_from_above,
+                       OptionalAttr<StrAttr>:$op_name,
+                       UnitAttr:$deduplicate);
+  let results = (outs TransformHandleTypeInterface:$parent);
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -909,26 +909,42 @@
 }
 
 //===----------------------------------------------------------------------===//
-// GetClosestIsolatedParentOp
+// GetParentOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  SetVector<Operation *> parents;
+DiagnosedSilenceableFailure
+transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
+                              transform::TransformResults &results,
+                              transform::TransformState &state) {
+  SmallVector<Operation *> parents;
+  DenseSet<Operation *> resultSet;
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    Operation *parent =
-        target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
+    // Walk up the ancestor chain until an op satisfying every requested
+    // requirement is found. `parent` is null when `target` has no parent op,
+    // so it has to be tested before it is dereferenced.
+    Operation *parent = target->getParentOp();
+    while (parent) {
+      bool checkIsolatedFromAbove =
+          !getIsolatedFromAbove() ||
+          parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
+      bool checkOpName = !getOpName().has_value() ||
+                         parent->getName().getStringRef() == *getOpName();
+      if (checkIsolatedFromAbove && checkOpName)
+        break;
+      parent = parent->getParentOp();
+    }
     if (!parent) {
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
-          << "could not find an isolated-from-above parent op";
+          << "could not find a parent op that matches all requirements";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
-    parents.insert(parent);
+    // With `deduplicate`, only the first occurrence of a parent is recorded.
+    if (!getDeduplicate() || resultSet.insert(parent).second)
+      parents.push_back(parent);
   }
-  results.set(llvm::cast<OpResult>(getResult()), parents.getArrayRef());
+  results.set(llvm::cast<OpResult>(getResult()), parents);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -15,17 +15,43 @@
 
 
 class CastOp:
-    def __init__(
-        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
-    ):
-        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      loc=None,
+      ip=None,
+  ):
+    super().__init__(
+        result_type, _get_op_result_or_value(target), loc=loc, ip=ip
+    )
 
-class GetClosestIsolatedParentOp:
-    def __init__(
-        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
-    ):
-        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+
+class GetParentOp:
+  """Builder extension that accepts an op or a value as the target handle."""
+
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      *,
+      isolated_from_above: bool = False,
+      op_name: Optional[str] = None,
+      deduplicate: bool = False,
+      loc=None,
+      ip=None,
+  ):
+    super().__init__(
+        result_type,
+        _get_op_result_or_value(target),
+        isolated_from_above=isolated_from_above,
+        op_name=op_name,
+        deduplicate=deduplicate,
+        loc=loc,
+        ip=ip,
+    )
 
 
 class MergeHandlesOp:
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -16,7 +16,7 @@
 ^bb1(%module_op: !transform.any_op):
   %0 = 
transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op %1, %loops:3 = transform.structured.tile %0 [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %2 = get_closest_isolated_parent %1 : (!transform.any_op) -> !transform.any_op + %2 = get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %2 : (!transform.any_op) -> !transform.any_op %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op {bufferize_function_boundaries = true} diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir --- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir @@ -30,7 +30,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %2 { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -19,7 +19,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> 
!transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -44,7 +44,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -64,7 +64,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -110,7 +110,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -158,7 +158,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 {vectorize_padding} : (!transform.any_op) 
-> !transform.any_op } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -13,7 +13,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op } @@ -32,7 +32,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op } @@ -50,7 +50,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op } @@ -69,7 +69,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> 
!transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op } @@ -109,7 +109,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -149,7 +149,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -176,7 +176,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -216,7 +216,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: 
!transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -236,7 +236,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op } @@ -260,7 +260,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -284,7 +284,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -329,7 +329,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = 
transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -346,7 +346,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -364,7 +364,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -381,7 +381,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -401,7 +401,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> 
!transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -417,7 +417,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -445,7 +445,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -474,7 +474,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -559,7 +559,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { 
disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -650,7 +650,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -694,7 +694,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -737,7 +737,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -769,7 +769,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op 
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -798,7 +798,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -827,7 +827,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -864,7 +864,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -884,7 +884,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = 
transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -914,7 +914,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -947,7 +947,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -984,7 +984,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -1018,7 +1018,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ 
-1046,7 +1046,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1083,7 +1083,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -1118,7 +1118,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1163,7 +1163,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -1193,7 +1193,7 @@ transform.sequence 
failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -1224,7 +1224,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1254,7 +1254,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1284,7 +1284,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1314,7 +1314,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> 
!transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1344,7 +1344,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1378,7 +1378,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1416,11 +1416,11 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op 
} @@ -1463,7 +1463,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1494,7 +1494,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1533,7 +1533,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op } @@ -1557,7 +1557,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1576,7 +1576,7 @@ transform.sequence failures(propagate) { 
^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1599,7 +1599,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1666,7 +1666,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1695,7 +1695,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -1716,7 +1716,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent 
%0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op } @@ -1738,7 +1738,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } @@ -1775,7 +1775,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op + %4 = get_parent_op %3 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op } diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -30,7 +30,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } @@ -66,7 +66,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match 
ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -105,7 +105,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -157,7 +157,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -205,7 +205,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } // ----- @@ -249,7 +249,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> 
!transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -291,7 +291,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -333,7 +333,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -377,7 +377,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -417,7 +417,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) 
-> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -457,7 +457,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } @@ -496,6 +496,6 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op } diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -19,7 +19,7 @@ ^bb1(%arg1: !transform.any_op): // expected-note @below {{handle to invalidated ops}} %0 = pdl_match @return in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op // expected-note @below {{invalidated by this transform op that consumes its operand #0}} test_consume_operand %1 : !transform.any_op // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} diff --git 
a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir --- a/mlir/test/Dialect/Transform/selective-targeting.mlir +++ b/mlir/test/Dialect/Transform/selective-targeting.mlir @@ -79,7 +79,7 @@ %0 = pdl_match @pdl_target_attrA in %arg1 : (!transform.any_op) -> !transform.any_op transform.structured.tile %0 [4, 4, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) %1 = pdl_match @pdl_target_attrC in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.get_closest_isolated_parent %1 : (!transform.any_op) -> !transform.any_op + %2 = get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %2 : (!transform.any_op) -> !transform.any_op } } @@ -124,7 +124,7 @@ transform.sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = pdl_match @pdl_target in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -107,7 +107,7 @@ transform.sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op - %m = get_closest_isolated_parent %f : (!transform.any_op) -> !transform.any_op + %m = get_parent_op %f {isolated_from_above} : (!transform.any_op) -> !transform.any_op test_print_remark_at_operand %m, "parent function" : !transform.any_op } } @@ -169,7 +169,7 @@ transform.sequence %arg0 : !transform.any_op 
failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = pdl_match @match_call in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op // expected-error @below {{all alternatives failed}} transform.alternatives %1 : !transform.any_op { ^bb2(%arg2: !transform.any_op): @@ -202,7 +202,7 @@ transform.sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = pdl_match @match_call in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.alternatives %1 : !transform.any_op { ^bb2(%arg2: !transform.any_op): %2 = transform.pdl_match @match_call in %arg2 : (!transform.any_op) -> !transform.any_op @@ -243,7 +243,7 @@ transform.sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = pdl_match @match_call in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.alternatives %1 : !transform.any_op { ^bb2(%arg2: !transform.any_op): %2 = transform.pdl_match @match_call in %arg2 : (!transform.any_op) -> !transform.any_op @@ -279,7 +279,7 @@ transform.sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = pdl_match @match_call in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op + %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.alternatives %1 : !transform.any_op -> !transform.any_op { ^bb2(%arg2: !transform.any_op): %3 = 
transform.pdl_match @match_call in %arg2 : (!transform.any_op) -> !transform.any_op @@ -1826,3 +1826,37 @@ transform.apply_licm to %arg1 : !transform.any_op } } + +// ----- + +func.func @get_parent_op() { + // expected-remark @below{{found test.foo parent}} + "test.foo"() ({ + // expected-remark @below{{direct parent}} + "test.bar"() ({ + "test.qux"() : () -> () + "test.qux"() : () -> () + }) : () -> () + }) : () -> () +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // Get parent by name. + %1 = transform.get_parent_op %0 {op_name = "test.foo"} : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %1, "found test.foo parent" : !transform.any_op + + // Get immediate parent. + %2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %2, "direct parent" : !transform.any_op + // expected-remark @below{{2}} + test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op + + // Deduplicate results. 
+ %3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op +} diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -18,7 +18,7 @@ %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) - %2 = get_closest_isolated_parent %1 : (!transform.any_op) -> !transform.any_op + %2 = get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %2 : (!transform.any_op) -> !transform.any_op %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -94,67 +94,67 @@ @run def testNestedSequenceOpWithExtras(): - sequence = transform.SequenceOp( + sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get(), [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")], ) - with InsertionPoint(sequence.body): - nested = transform.SequenceOp( + with InsertionPoint(sequence.body): + nested = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget, sequence.bodyExtraArgs, ) - with InsertionPoint(nested.body): - transform.YieldOp() - transform.YieldOp() - # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras - # CHECK: transform.sequence 
failures(propagate) - # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): - # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) + with InsertionPoint(nested.body): + transform.YieldOp() + transform.YieldOp() + # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras + # CHECK: transform.sequence failures(propagate) + # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">): + # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">) @run def testTransformPDLOps(): - withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) - with InsertionPoint(withPdl.body): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, - [transform.AnyOpType.get()], - withPdl.bodyTarget, - ) - with InsertionPoint(sequence.body): - match = transform_pdl.PDLMatchOp( - transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" - ) - transform.YieldOp(match) - # CHECK-LABEL: TEST: testTransformPDLOps - # CHECK: transform.with_pdl_patterns { - # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): - # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] - # CHECK: yield %[[RES]] : !transform.any_op - # CHECK: } - # CHECK: } + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) + with InsertionPoint(withPdl.body): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, + [transform.AnyOpType.get()], + withPdl.bodyTarget, + ) + with InsertionPoint(sequence.body): + match = transform_pdl.PDLMatchOp( + transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher" + ) + transform.YieldOp(match) + # CHECK-LABEL: TEST: 
testTransformPDLOps + # CHECK: transform.with_pdl_patterns { + # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op): + # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) { + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] + # CHECK: yield %[[RES]] : !transform.any_op + # CHECK: } + # CHECK: } @run -def testGetClosestIsolatedParentOp(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() +def testGetParentOp(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + transform.GetParentOp( + transform.AnyOpType.get(), sequence.bodyTarget, isolated_from_above=True ) - with InsertionPoint(sequence.body): - transform.GetClosestIsolatedParentOp( - transform.AnyOpType.get(), sequence.bodyTarget - ) - transform.YieldOp() - # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp - # CHECK: transform.sequence - # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = get_closest_isolated_parent %[[ARG1]] + transform.YieldOp() + # CHECK-LABEL: TEST: testGetParentOp + # CHECK: transform.sequence + # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): + # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above} @run