[mlir][transform] Rename split_handles to split_handle

Drop the `num_result_handles` attribute of the op: the number of results
already states how many handles `handle` is split into. `apply` now compares
the number of payload ops directly against `getNumResults()`:
- an empty `handle` yields empty result handles and succeeds;
- any other mismatch produces a silenceable failure;
- otherwise the i-th payload op is mapped to the i-th result handle.
The custom assembly format loses the `in [N]` suffix accordingly.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -509,26 +509,28 @@
   }];
 }
 
-def SplitHandlesOp : TransformDialectOp<"split_handles",
+def SplitHandleOp : TransformDialectOp<"split_handle",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let summary = "Splits handles from a union of payload ops to a list";
+  let summary = "Splits a handle of payload ops into handles with a single op";
   let description = [{
-    Creates `num_result_handles` transform IR handles extracted from the
-    `handle` operand. The resulting Payload IR operation handles are listed
-    in the same order as the operations appear in the source `handle`.
-    This is useful for ensuring a statically known number of operations are
-    tracked by the source `handle` and to extract them into individual handles
-    that can be further manipulated in isolation.
-
-    This operation succeeds and returns `num_result_handles` if the statically
-    specified `num_result_handles` corresponds to the dynamic number of
-    operations contained in the source `handle`. Otherwise it silently fails.
+    Splits `handle` into one or multiple handles, as specified by the number
+    of results of this operation. `handle` should be mapped to as many payload
+    ops as there are results; otherwise, this transform produces a silenceable
+    failure. Each result handle is mapped to exactly one payload op. The order
+    of the payload ops is preserved, i.e., the i-th payload op is mapped to the
+    i-th result handle.
+
+    This operation is useful for ensuring a statically known number of
+    operations are tracked by the source `handle` and to extract them into
+    individual handles that can be further manipulated in isolation.
+
+    If `handle` is empty, this transform will succeed and all result handles
+    are empty.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle,
-                       I64Attr:$num_result_handles);
+  let arguments = (ins TransformHandleTypeInterface:$handle);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
 
   let builders = [
@@ -536,8 +538,7 @@
   ];
 
   let assemblyFormat = [{
-    $handle `in` `[` $num_result_handles `]`
-    attr-dict `:` functional-type(operands, results)
+    $handle attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1488,48 +1488,43 @@
 }
 
 //===----------------------------------------------------------------------===//
-// SplitHandlesOp
+// SplitHandleOp
 //===----------------------------------------------------------------------===//
 
-void transform::SplitHandlesOp::build(OpBuilder &builder,
-                                      OperationState &result, Value target,
-                                      int64_t numResultHandles) {
+void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
+                                     Value target, int64_t numResultHandles) {
   result.addOperands(target);
-  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
-                      builder.getI64IntegerAttr(numResultHandles));
   auto pdlOpType = pdl::OperationType::get(builder.getContext());
   result.addTypes(SmallVector<Type>(numResultHandles, pdlOpType));
 }
 
 DiagnosedSilenceableFailure
-transform::SplitHandlesOp::apply(transform::TransformResults &results,
-                                 transform::TransformState &state) {
-  int64_t numResultHandles =
-      getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
-  int64_t expectedNumResultHandles = getNumResultHandles();
-  if (numResultHandles != expectedNumResultHandles) {
-    // Empty input handle corner case: always propagates empty handles in both
-    // suppress and propagate modes.
-    if (numResultHandles == 0) {
-      for (OpResult result : getResults())
-        results.set(result, {});
-      return DiagnosedSilenceableFailure::success();
-    }
+transform::SplitHandleOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
 
-    // If the input handle was not empty and the number of result handles does
-    // not match, this is a legit silenceable error.
-    return emitSilenceableError()
-           << getHandle() << " expected to contain " << expectedNumResultHandles
-           << " operation handles but it contains " << numResultHandles
-           << " handles";
+  // Empty handle corner case: all result handles are empty.
+  if (numPayloadOps == 0) {
+    for (OpResult result : getResults())
+      results.set(result, {});
+    return DiagnosedSilenceableFailure::success();
   }
-  // Normal successful case.
+
+  // If the input handle was not empty and the number of payload ops does not
+  // match, this is a legit silenceable error.
+  if (numPayloadOps != getNumResults())
+    return emitSilenceableError()
+           << getHandle() << " expected to contain " << getNumResults()
+           << " payload ops but it contains " << numPayloadOps
+           << " payload ops";
+
   for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
     results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+
   return DiagnosedSilenceableFailure::success();
 }
 
-void transform::SplitHandlesOp::getEffects(
+void transform::SplitHandleOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getHandle(), effects);
   producesHandle(getResults(), effects);
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -818,7 +818,7 @@
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -827,17 +827,17 @@
 transform.sequence failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -846,12 +846,12 @@
 transform.sequence failures(suppress) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
   // Silenceable failure and all handles are now empty.
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   // expected-remark @below {{0}}
   transform.test_print_number_of_associated_payload_ir_ops %h_2#0
 }
@@ -966,7 +966,7 @@
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -975,8 +975,8 @@
 transform.sequence -> !pdl.operation failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
-  %h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   /// Test that yield does not crash in the presence of silenceable error in
   /// propagate mode.
   yield %fun : !pdl.operation
@@ -988,7 +988,7 @@
 ^bb0(%arg0: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   // Edge case propagating empty handles in splitting.
-  %0:3 = split_handles %muli in [3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %0:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
   // Test does not crash when accessing the empty handle.
   yield %0#0 : !transform.any_op
 }