diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -143,7 +143,8 @@ def FuseIntoContainingOp : Op, + [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "Fuse a producer into a containing operation."; @@ -160,7 +161,7 @@ producer op handle may be associated with multiple payload ops. This transform fuses producers one-by-one, always picking an unspecified producer that has at least one use inside the containing op among the - producers. + producers. A producer can be listed multiple times in the handle. Note: If a producer has multiple uses inside the containing op, it is currently tiled and/or cloned multiple times into the containing op. @@ -176,8 +177,8 @@ containing op. I.e., "producers" that are not consumed within the containing op are rejected by this operation. - This operation reads and frees the producer handle. - This operation reads the containing op handle. + This operation consumes the producer handle. + This operation only reads the containing op handle. }]; let arguments = (ins PDL_Operation:$producer_op, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -571,6 +571,11 @@ return fusedOp; } +bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { + // Repeats are harmless: apply() gathers the producers into a SetVector.
+ return true; +} + DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -591,8 +596,8 @@ // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. - SmallVector remainingProducers(producerOps.begin(), - producerOps.end()); + SetVector remainingProducers(producerOps.begin(), + producerOps.end()); auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -724,6 +724,10 @@ FULL_LDBG("--handle not consumed -> SKIP\n"); continue; } + if (transform.allowsRepeatedHandleOperands()) { + FULL_LDBG("--op allows repeated handles -> SKIP\n"); + continue; + } FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -247,3 +247,39 @@ transform.structured.fuse_into_containing_op %0 into %1 } } + +// ----- + +module { + // CHECK-LABEL: func.func @fuse_repeated + func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32> + + // CHECK: scf.forall + %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) { + %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32> + %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to 
tensor<1xf32> + // CHECK: %[[FUSED:.+]] = linalg.fill + // CHECK: elemwise_unary ins(%[[FUSED]] + %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32> + } + } + + return %1 : tensor<2xf32> + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation + + // Create a new handle that points to `linalg.fill` twice. + %2 = transform.merge_handles %0, %0 : !pdl.operation + + // It shouldn't be a problem to fuse this handle. + transform.structured.fuse_into_containing_op %2 into %1 + } +}