diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -46,16 +46,22 @@ DeclareOpInterfaceMethods]> { let summary = "Outlines a loop into a named function"; let description = [{ - Moves the loop into a separate function with the specified name and - replaces the loop in the Payload IR with a call to that function. Takes - care of forwarding values that are used in the loop as function arguments. - If the operand is associated with more than one loop, each loop will be - outlined into a separate function. The provided name is used as a _base_ - for forming actual function names following SymbolTable auto-renaming - scheme to avoid duplicate symbols. Expects that all ops in the Payload IR - have a SymbolTable ancestor (typically true because of the top-level - module). Returns the handle to the list of outlined functions in the same - order as the operand handle. + Moves the loop into a separate function with the specified name and replaces + the loop in the Payload IR with a call to that function. Takes care of + forwarding values that are used in the loop as function arguments. If the + operand is associated with more than one loop, each loop will be outlined + into a separate function. The provided name is used as a _base_ for forming + actual function names following `SymbolTable` auto-renaming scheme to avoid + duplicate symbols. Expects that all ops in the Payload IR have a + `SymbolTable` ancestor (typically true because of the top-level module). + + #### Return Modes + + Returns a handle to the list of outlined functions and a handle to the + corresponding function call operations in the same order as the operand + handle. + + Produces a definite failure if outlining failed for any of the targets. }]; // Note that despite the name of the transform operation and related utility @@ -63,7 +69,8 @@ // a loop. let arguments = (ins TransformHandleTypeInterface:$target, StrAttr:$func_name); - let results = (outs TransformHandleTypeInterface:$transformed); + let results = (outs TransformHandleTypeInterface:$function, + TransformHandleTypeInterface:$call); let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -85,7 +85,8 @@ DiagnosedSilenceableFailure transform::LoopOutlineOp::apply(transform::TransformResults &results, transform::TransformState &state) { - SmallVector transformed; + SmallVector functions; + SmallVector calls; DenseMap symbolTables; for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); @@ -112,9 +113,11 @@ symbolTable.insert(*outlined); call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); } - transformed.push_back(*outlined); + functions.push_back(*outlined); + calls.push_back(call); } - results.set(getTransformed().cast(), transformed); + results.set(getFunction().cast(), functions); + results.set(getCall().cast(), calls); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -39,7 +39,8 @@ def __init__( self, - result_type: Type, + function_type: Type, + call_type: Type, target: Union[Operation, Value], *, func_name: Union[str, StringAttr], @@ -47,7 +48,8 @@ loc=None, ): super().__init__( - result_type, + function_type, + call_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir --- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir @@ -54,8 +54,8 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["scf.while"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{failed to outline}} - transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation + transform.loop.outline %0 {func_name = "foo"} : (!transform.any_op) -> (!transform.any_op, !transform.any_op) } diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -75,11 +75,11 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> // CHECK: = transform.loop.outline %{{.*}} - transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation + transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) } // ----- diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -33,7 +33,7 @@ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): - loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo") + loop.LoopOutlineOp(transform.AnyOpType.get(), transform.AnyOpType.get(), sequence.bodyTarget, func_name="foo") transform.YieldOp() # CHECK-LABEL: TEST: loopOutline # CHECK: = transform.loop.outline %