diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -699,11 +699,6 @@ transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); - // If nothing to fuse, propagate success. - if (std::empty(producerOps)) { - results.set(cast(getFusedOp()), SmallVector{}); - return DiagnosedSilenceableFailure::success(); - } auto containingOps = state.getPayloadOps(getContainingOp()); if (!llvm::hasSingleElement(containingOps)) { return emitDefiniteFailure() @@ -712,6 +707,13 @@ } Operation *containingOp = *containingOps.begin(); + // If nothing to fuse, propagate success. + if (std::empty(producerOps)) { + results.set(cast(getFusedOp()), SmallVector{}); + results.set(cast(getNewContainingOp()), {containingOp}); + return DiagnosedSilenceableFailure::success(); + } + // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. SetVector remainingProducers(producerOps.begin(),