diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -663,6 +664,38 @@
   return true;
 }
 
+namespace {
+/// Unsafely exposes an internal protected method of TransformState::Extension
+/// as public.
+///
+/// MUST NOT be used directly.
+class UnsafeOpReplacementStateExtension : public TransformState::Extension {
+public:
+  explicit UnsafeOpReplacementStateExtension(TransformState &state)
+      : TransformState::Extension(state) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      UnsafeOpReplacementStateExtension)
+
+  LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
+    return replacePayloadOp(op, replacement);
+  }
+};
+} // namespace
+
+/// Replaces `payload` with `replacement` in all handles stored in the state.
+/// MUST NOT be used except for the case immediately below.
+static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
+                                                      Operation *payload,
+                                                      Operation *replacement) {
+  // Short-lived local whose only purpose is to reach the protected
+  // `replacePayloadOp`; it goes out of scope when this function returns.
+  UnsafeOpReplacementStateExtension extension(state);
+  // This may return failure if the payload is not associated with any handle,
+  // ignore that.
+  (void)extension.doReplacePayloadOp(payload, replacement);
+}
+
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                        transform::TransformState &state) {
@@ -757,6 +790,14 @@
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
+  // Update handles associated with the containing op so we don't need to
+  // invalidate them. This is a hack to support better composability between
+  // tiling and fusion while a proper mechanism is being investigated.
+  //
+  // DO NOT replicate this elsewhere unless you understand what you are doing.
+  forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
+                                            containingOp);
+
   results.set(cast<OpResult>(getFusedOp()), fusedOps);
   results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
   return DiagnosedSilenceableFailure::success();
@@ -765,7 +806,7 @@
 void transform::FuseIntoContainingOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getProducerOp(), effects);
-  consumesHandle(getContainingOp(), effects);
+  onlyReadsHandle(getContainingOp(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -35,3 +35,15 @@
   // CHECK: transform.structured.scalarize
   %0 = transform.structured.scalarize %arg0 : (!transform.any_op) -> !transform.any_op
 }
+
+// Check that the second argument of `fuse_into_containing_op` is not consumed
+// (if it had been, we would have seen a diagnostic about multiple consumers).
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
+  %loop = transform.structured.match ops{["scf.forall"]} in %arg0
+    : (!transform.any_op) -> !transform.any_op
+  %0:2 = transform.structured.fuse_into_containing_op %arg1 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %1:2 = transform.structured.fuse_into_containing_op %arg2 into %loop
+    : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+}