diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -824,6 +824,17 @@ decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); + // Handle the corner case where no target is specified. + // This is typically the case when the matcher fails to apply and we need to + // propagate gracefully. + // In this case, we fill all results with an empty vector. + if (targets.empty()) { + SmallVector emptyResult; + for (auto r : this->getOperation()->getResults()) + transformResults.set(r.template cast(), emptyResult); + return DiagnosedSilenceableFailure::success(); + } + SmallVector, 1> results; // In the multi-result case, collect the number of results each transform // produced. @@ -831,14 +842,17 @@ targets, results, [&](TransformOpType specificOp) { return static_cast(this)->applyToOne(specificOp, state); }); + // Propagate the failure (definite or silenceable) if any. if (!result.succeeded()) return result; - if (results.empty()) + + // Legitimately no results, bail early. + if (results.empty() && OpTy::template hasTrait()) return DiagnosedSilenceableFailure::success(); // Ensure all applications return the same number of results. // Variadic cases are much trickier to handle in a generic fashion. - int64_t nRes = results[0].size(); + int64_t nRes = results.empty() ? 0 : results[0].size(); if (llvm::any_of(results, [&](const auto &r) { return static_cast(r.size()) != nRes; })) { @@ -849,6 +863,8 @@ "generic `apply` instead of the specialized `applyToOne`"; } // Ensure the number of results agrees with what the transform op expects. + // Unless we see empty results, in which case we just want to propagate the + // emptiness. 
if (this->getOperation()->getNumResults() != nRes) { InFlightDiagnostic diag = static_cast(this)->emitError() << "unexpected number of results (got " << nRes @@ -857,10 +873,6 @@ return DiagnosedSilenceableFailure::definiteFailure(); } - // If no results, bail early. - if (OpTy::template hasTrait()) - return DiagnosedSilenceableFailure::success(); - // Perform transposition of M applications producing N results each into N // results for each of the M applications. SmallVector> transposedResults = diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -436,3 +436,27 @@ %1:2 = transform.test_correct_number_of_multi_results %0 } } + +// ----- + +func.func @foo() { + "wrong_op_name" () : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @some : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "op"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + // Transform fails to match any op but still produces 2 results. + %1:2 = transform.test_correct_number_of_multi_results %0 + } +}