diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -509,26 +509,28 @@
   }];
 }
 
-def SplitHandlesOp : TransformDialectOp<"split_handles",
+def SplitHandleOp : TransformDialectOp<"split_handle",
     [FunctionalStyleTransformOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let summary = "Splits handles from a union of payload ops to a list";
+  let summary = "Splits a handle of payload ops into handles with a single op";
   let description = [{
-    Creates `num_result_handles` transform IR handles extracted from the
-    `handle` operand. The resulting Payload IR operation handles are listed
-    in the same order as the operations appear in the source `handle`.
-    This is useful for ensuring a statically known number of operations are
-    tracked by the source `handle` and to extract them into individual handles
-    that can be further manipulated in isolation.
-
-    This operation succeeds and returns `num_result_handles` if the statically
-    specified `num_result_handles` corresponds to the dynamic number of
-    operations contained in the source `handle`. Otherwise it silently fails.
+    Creates one transform IR handle per result of this op, extracted from the
+    `handle` operand. Each result handle is mapped to one payload op from the
+    `handle` operand. The result handles are listed in the same order as the
+    operations appear in the source `handle`.
+
+    This operation is useful for ensuring a statically known number of
+    operations are tracked by the source `handle` and to extract them into
+    individual handles that can be further manipulated in isolation.
+
+    This operation succeeds if the number of results of this op corresponds
+    to the dynamic number of operations mapped to the source `handle`. It
+    also succeeds if the source `handle` is empty, in which case all result
+    handles are empty. Otherwise it produces a silenceable failure.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle,
-                       I64Attr:$num_result_handles);
+  let arguments = (ins TransformHandleTypeInterface:$handle);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
 
   let builders = [
@@ -536,8 +538,7 @@
   ];
 
   let assemblyFormat = [{
-    $handle `in` `[` $num_result_handles `]`
-    attr-dict `:` functional-type(operands, results)
+    $handle attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1488,44 +1488,40 @@
 }
 
 //===----------------------------------------------------------------------===//
-// SplitHandlesOp
+// SplitHandleOp
 //===----------------------------------------------------------------------===//
 
-void transform::SplitHandlesOp::build(OpBuilder &builder,
-                                      OperationState &result, Value target,
-                                      int64_t numResultHandles) {
+void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
+                                     Value target, int64_t numResultHandles) {
   result.addOperands(target);
-  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
-                      builder.getI64IntegerAttr(numResultHandles));
   auto pdlOpType = pdl::OperationType::get(builder.getContext());
   result.addTypes(SmallVector<Type>(numResultHandles, pdlOpType));
 }
 
 DiagnosedSilenceableFailure
-transform::SplitHandlesOp::apply(transform::TransformResults &results,
-                                 transform::TransformState &state) {
-  int64_t numResultHandles =
-      getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
-  int64_t expectedNumResultHandles = getNumResultHandles();
-  if (numResultHandles != expectedNumResultHandles) {
-    // Empty input handle corner case: always propagates empty handles in both
-    // suppress and propagate modes.
-    if (numResultHandles == 0)
-      return DiagnosedSilenceableFailure::success();
-    // If the input handle was not empty and the number of result handles does
-    // not match, this is a legit silenceable error.
+transform::SplitHandleOp::apply(transform::TransformResults &results,
+                                transform::TransformState &state) {
+  int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
+
+  // Empty handle corner case: all result handles are empty.
+  if (numPayloadOps == 0)
+    return DiagnosedSilenceableFailure::success();
+
+  // A non-empty handle must map to exactly one payload op per result handle;
+  // too few or too many payload ops is a silenceable error.
+  if (numPayloadOps != getNumResults())
     return emitSilenceableError()
-           << getHandle() << " expected to contain " << expectedNumResultHandles
-           << " operation handles but it only contains " << numResultHandles
-           << " handles";
-  }
-  // Normal successful case.
+           << getHandle() << " expected to contain " << getNumResults()
+           << " payload ops but it contains " << numPayloadOps
+           << " payload ops";
+
   for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
     results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+
   return DiagnosedSilenceableFailure::success();
 }
 
-void transform::SplitHandlesOp::getEffects(
+void transform::SplitHandleOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getHandle(), effects);
   producesHandle(getResults(), effects);
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -822,7 +822,7 @@
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -831,17 +831,17 @@
 transform.sequence failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -850,12 +850,12 @@
 transform.sequence failures(suppress) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   // expected-remark @below {{1}}
   transform.test_print_number_of_associated_payload_ir_ops %h#0
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
   // Silenceable failure and all handles are now empty.
-  %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   // expected-remark @below {{0}}
   transform.test_print_number_of_associated_payload_ir_ops %h_2#0
 }
@@ -970,7 +970,7 @@
 
 // -----
 
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
   %0 = arith.muli %a, %b : index
   %1 = arith.muli %a, %c : index
   return
@@ -979,8 +979,8 @@
 transform.sequence -> !pdl.operation failures(propagate) {
 ^bb1(%fun: !pdl.operation):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
-  %h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+  %h_2:3 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
   /// Test that yield does not crash in the presence of silenceable error in
   /// propagate mode.
   yield %fun : !pdl.operation