diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -155,6 +155,15 @@
   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
 }
 
+/// Sets each result of the op that owns `block` to an empty list of payload
+/// operations, e.g. when that op exits early. `state` is unused; it is kept so
+/// the signature mirrors `forwardTerminatorOperands`.
+static void forwardEmptyOperands(Block *block, transform::TransformState &state,
+                                 transform::TransformResults &results) {
+  for (const auto &res : block->getParentOp()->getOpResults())
+    results.set(res, {});
+}
+
 static void
 forwardTerminatorOperands(Block *block, transform::TransformState &state,
                           transform::TransformResults &results) {
@@ -594,8 +603,11 @@
       return result;
 
     if (result.isSilenceableFailure()) {
-      if (getFailurePropagationMode() == FailurePropagationMode::Propagate)
+      if (getFailurePropagationMode() == FailurePropagationMode::Propagate) {
+        // Propagate empty results in case of early exit.
+        forwardEmptyOperands(getBodyBlock(), state, results);
         return result;
+      }
       (void)result.silence();
     }
   }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -920,3 +920,21 @@
 }
 
 "test.some_op"() : () -> ()
+
+// -----
+
+func.func @split_handles(%a: index, %b: index, %c: index) {
+  %0 = arith.muli %a, %b : index
+  %1 = arith.muli %a, %c : index
+  return
+}
+
+transform.sequence -> !pdl.operation failures(propagate) {
+^bb1(%fun: !pdl.operation):
+  %muli = transform.structured.match ops{["arith.muli"]} in %fun
+  // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}}
+  %h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+  // Test that yield does not crash in the presence of silenceable error in
+  // propagate mode.
+  yield %fun : !pdl.operation
+}