diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -854,6 +854,30 @@ "type($pattern) `,` type($handles)"; } +def SelectOp : TransformDialectOp<"select", + [DeclareOpInterfaceMethods, + NavigationTransformOpTrait, MemoryEffectsOpInterface]> { + let summary = "Select payload ops by name"; + let description = [{ + The handle defined by this Transform op corresponds to all operations among + `target` that have the specified properties. Currently the following + properties are supported: + + - `op_name`: The op must have the specified name. + + The result payload ops are in the same relative order as the targeted ops. + This transform op reads the `target` handle and produces the `result` + handle. It reads the payload, but does not modify it. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + StrAttr:$op_name); + let results = (outs TransformHandleTypeInterface:$result); + let assemblyFormat = [{ + $op_name `in` $target attr-dict `:` functional-type(operands, results) + }]; +} + def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods result; + auto payloadOps = state.getPayloadOps(getTarget()); + for (Operation *op : payloadOps) { + if (op->getName().getStringRef() == getOpName()) + result.push_back(op); + } + results.set(cast(getResult()), result); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // SplitHandleOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1945,3 +1945,32 @@ // expected-error @below{{failed to verify payload op}} transform.verify %0 : !transform.any_op } + +// ----- + +func.func @select() { + // expected-remark @below{{found foo}} + "test.foo"() : () -> () + // expected-remark @below{{found bar}} + "test.bar"() : () -> () + // expected-remark @below{{found foo}} + "test.foo"() : () -> () + func.return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // Match all ops inside the function (including the function itself). + %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op + // expected-remark @below{{5}} + test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op + + // Select "test.foo". + %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %foo, "found foo" : !transform.any_op + + // Select "test.bar". + %bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %bar, "found bar" : !transform.any_op +}