diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -339,14 +339,7 @@ } // Otherwise, replace the pointed-to object of all handles while preserving - // their relative order. - if (op->getNumResults() != replacement->getNumResults()) { - return emitError(op->getLoc()) - << "cannot replace an op with another op producing a different " - "number of results while tracking handles"; - } - - // Replace the mapped operation if present. + // their relative order. First, replace the mapped operation if present. for (Value handle : opHandles) { Mappings &mappings = getMapping(handle); auto it = mappings.direct.find(handle); @@ -362,9 +355,21 @@ mappings.reverse[replacement].push_back(handle); } - // Replace the mapped results of the operation. - for (auto [origResult, replacementResult, handleList] : llvm::zip( - op->getResults(), replacement->getResults(), resultValueHandles)) { + // Second, replace the mapped results of the operation. + for (auto [origResult, handleList] : + llvm::zip(op->getResults(), resultValueHandles)) { + // No handles to the value, skip even if there is no replacement. 
+ if (handleList.empty()) + continue; + + unsigned resultNumber = origResult.getResultNumber(); + if (resultNumber >= replacement->getNumResults()) { + return emitError(op->getLoc()) + << "cannot replace an op with another op producing less results " + "while tracking handles"; + } + + Value replacementResult = replacement->getResult(resultNumber); for (Value resultHandle : handleList) { Mappings &mappings = getMapping(resultHandle); auto it = mappings.values.find(resultHandle); diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir --- a/mlir/test/Dialect/Transform/transform-state-extension.mlir +++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir @@ -47,15 +47,36 @@ // ----- -// expected-error @below {{cannot replace an op with another op producing a different number of results while tracking handles}} -module { - transform.sequence failures(propagate) { - ^bb0(%arg0: !pdl.operation): - test_add_test_extension "A" - %dummy = test_remap_operand_to_self %arg0 : !transform.any_op - } +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + // This is okay because we are replacing the top-level module operation + // (0 results) with this operation that has _more_ (1) results. + %dummy = test_remap_operand_to_self %arg0 : !pdl.operation +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + %dummy = test_remap_operand_to_self %arg0 : !pdl.operation + // This is still okay. Even though we are replacing the previous + // operation (1 result) with this operation that has fewer (0) results, + // there is no handle to the result, hence no issue with value handle update. 
+ test_remap_operand_to_self %dummy } +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + test_add_test_extension "A" + // expected-error @below {{cannot replace an op with another op producing less results while tracking handles}} + %dummy = test_remap_operand_to_self %arg0 : !pdl.operation + %valuehandle = transform.get_result %dummy[0] : (!pdl.operation) -> !transform.any_value + test_remap_operand_to_self %dummy +} // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -297,6 +297,8 @@ if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); + if (getNumResults() > 0) + results.set(getResult(0).cast(), getOperation()); return DiagnosedSilenceableFailure::success(); }