diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -521,8 +521,8 @@
   let description = [{
     Splits `handle` into one or multiple handles, as specified by the number
     of results of this operation. `handle` should be mapped to as many payload
-    ops as there are results. Otherwise, this transform will fail silently.
-    Each result handle is mapped to exactly one payload op. The order
+    ops as there are results. Otherwise, this transform will fail silently by
+    default. Each result handle is mapped to exactly one payload op. The order
     of the payload ops is preserved, i.e., the i-th payload op is mapped to
     the i-th result handle.
 
@@ -530,12 +530,23 @@
     operations are tracked by the source `handle` and to extract them into
     individual handles that can be further manipulated in isolation.
 
-    If `handle` is empty, this transform will succeed and all result handles
-    are empty.
+    If there are more payload ops than results, the remaining ops are mapped to
+    the result with index `overflow_result`. If no `overflow_result` is
+    specified, the transform fails silently.
+
+    If there are fewer payload ops than results, the transform fails silently
+    if `fail_on_payload_too_small` is set to "true". Otherwise, it succeeds and
+    the remaining result handles are not mapped to any op. It also succeeds if
+    `handle` is empty and `pass_through_empty_handle` is set to "true",
+    regardless of `fail_on_payload_too_small`.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$handle);
+  let arguments = (ins TransformHandleTypeInterface:$handle,
+                       DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
+                       DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
+                       OptionalAttr<I64Attr>:$overflow_result);
   let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let hasVerifier = 1;
 
   let builders = [
     OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1502,24 +1502,42 @@
 transform::SplitHandleOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
   int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
-
-  // Empty handle corner case: all result handles are empty.
-  if (numPayloadOps == 0) {
-    for (OpResult result : getResults())
-      results.set(result, {});
-    return DiagnosedSilenceableFailure::success();
-  }
-
-  // If the input handle was not empty and the number of payload ops does not
-  // match, this is a legit silenceable error.
-  if (numPayloadOps != getNumResults())
+  auto produceNumOpsError = [&]() {
     return emitSilenceableError()
-           << getHandle() << " expected to contain " << getNumResults()
+           << getHandle() << " expected to contain " << this->getNumResults()
            << " payload ops but it contains " << numPayloadOps
            << " payload ops";
+  };
 
-  for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
-    results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+  // Fail if there are more payload ops than results and no overflow result was
+  // specified.
+  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+    return produceNumOpsError();
+
+  // Fail if there are more results than payload ops. Unless:
+  // - "fail_on_payload_too_small" is set to "false", or
+  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
+  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
+      !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
+    return produceNumOpsError();
+
+  // Distribute payload ops. Payload ops beyond the number of results go to the
+  // overflow result; only reserve room for them when there actually is an
+  // overflow, otherwise `numPayloadOps - getNumResults()` would be negative.
+  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
+  if (getOverflowResult() && numPayloadOps > getNumResults())
+    resultHandles[*getOverflowResult()].reserve(numPayloadOps -
+                                                getNumResults());
+  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+    int64_t resultNum = en.index();
+    if (resultNum >= getNumResults())
+      resultNum = *getOverflowResult();
+    resultHandles[resultNum].push_back(en.value());
+  }
+
+  // Set transform op results.
+  for (auto &&it : llvm::enumerate(resultHandles))
+    results.set(getResult(it.index()).cast<OpResult>(), it.value());
 
   return DiagnosedSilenceableFailure::success();
 }
@@ -1532,6 +1550,13 @@
   // manipulation.
 }
 
+LogicalResult transform::SplitHandleOp::verify() {
+  if (getOverflowResult().has_value() &&
+      !(*getOverflowResult() >= 0 && *getOverflowResult() < getNumResults()))
+    return emitOpError("overflow_result is not a valid result index");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // PDLMatchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -858,6 +858,69 @@
 
 // -----
 
+func.func @split_handle(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+  // No error, last result handle is empty.
+  %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#1
+  // expected-remark @below {{0}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#2
+}
+
+// -----
+
+func.func @split_handle(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  %2 = arith.muli %a, %c : index
+  %3 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+  %h:2 = split_handle %muli_2 {overflow_result = 0} : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+  // expected-remark @below {{3}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#1
+}
+
+// -----
+
+func.func @split_handle(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+  // No error and no crash when an overflow result is specified but there are
+  // fewer payload ops than results: the last result handle is empty.
+  %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false, overflow_result = 0} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#0
+  // expected-remark @below {{1}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#1
+  // expected-remark @below {{0}}
+  transform.test_print_number_of_associated_payload_ir_ops %h#2
+}
+
+// -----
+
 "test.some_op"() : () -> ()
 "other_dialect.other_op"() : () -> ()
 