diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -447,18 +447,40 @@ // Search the producer slices accessed within the containing operation. // TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. - auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { + Attribute commonSize, commonStride, commonOffset; + SmallVector consumers; + llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); - return sliceOp && containingOp->isProperAncestor(sliceOp); + if (sliceOp && containingOp->isProperAncestor(sliceOp)) { + auto size = sliceOp.getMixedSizes()[0].dyn_cast(); + auto offset = sliceOp.getMixedOffsets()[0].dyn_cast(); + auto stride = sliceOp.getMixedStrides()[0].dyn_cast(); + // Find fully overlapped consumers that share the same producer. + if (!commonSize) + commonSize = size; + if (size != commonSize) + return false; + if (!commonOffset) + commonOffset = offset; + if (offset != commonOffset) + return false; + if (!commonStride) + commonStride = stride; + if (stride != commonStride) + return false; + consumers.push_back(sliceOp); + } + return false; }); // Find a fusion opportunity. - if (it == tileableProducer->getUsers().end()) { + if (consumers.size() == 0) { diag.attachNote(tileableProducer->getLoc()) << "could not find fusion opportunity for: " << *tileableProducer; return {}; } - auto sliceOpToTile = cast(*it); + auto firstConsumer = consumers[consumers.size() - 1]; + auto sliceOpToTile = cast(*firstConsumer); // Try to fuse the producer in-place. 
OpBuilder::InsertionGuard guard(rewriter); @@ -495,6 +517,14 @@ assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); + // Fully overlapped consumers can share a single fused producer. + Operation *fusedOp = tileAndFuseResult->tiledOps[0]; + for (int64_t i = consumers.size() - 2; i >= 0; --i) { + auto consumer = consumers[i]; + auto sliceOp = dyn_cast(*consumer); + rewriter.replaceOp(sliceOp, fusedOp->getResult(resultNumber)); + } + // Add new outputs to containing op, if required Operation *newContainingOp = replaceForAllWithNewSignature( rewriter, diag, producerOp, containingOp, *tileAndFuseResult, @@ -736,6 +766,8 @@ return producerOp; } } + // This point is reached when remainingProducers contains ONLY producerOps + // with zero numUsersInContainingOp. return failure(); }; @@ -743,10 +775,12 @@ IRRewriter rewriter(getContext(), &listener); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); - if (failed(nextProducer)) { - return mlir::emitSilenceableFailure(containingOp->getLoc()) - << "could not find next producer to fuse into container"; - } + + // Only producers with zero numUsesInContainingOp reach this point. These + // producers have multiple overlapped consumers that were replaced in + // one go. 
+ if (failed(nextProducer)) + break; Operation *producerOp = *nextProducer; diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -626,3 +626,79 @@ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) } } + + +// ----- + +#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> + +module { + // Two producers: + // First producer has two fully overlapped consumers. + // Second producer has two fully overlapped consumers and a non-overlapping consumer. + // + // CHECK-LABEL: func.func @two_producers_multiple_consumers + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<32xf32> + func.func @two_producers_multiple_consumers(%arg0: index, + %arg1: tensor, %arg2: tensor<32xf32>, %arg3: tensor<64xf32>) -> tensor<32xf32> { + + // CHECK: %[[VAL:.*]] = arith.constant + %cst0 = arith.constant 4.0 : f32 + %p0 = linalg.fill ins(%cst0 : f32) outs(%arg1 : tensor) -> tensor + + // CHECK: %[[VAL2:.*]] = arith.constant + %cst2 = arith.constant 5.0 : f32 + %p1 = linalg.fill ins(%cst2 : f32) outs(%arg3 : tensor<64xf32>) -> tensor<64xf32> + + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arg1, %c0 : tensor + %1 = affine.apply #map0()[%d0, %arg0] + + // CHECK: scf.forall {{.*}} { + %2 = scf.forall (%arg4) in (%1) shared_outs(%out = %arg2) -> (tensor<32xf32>) { + %3 = affine.apply #map1(%arg4)[%arg0] + %4 = affine.min #map2(%arg4)[%d0, %arg0] + // CHECK: %extracted_slice + %5 = tensor.extract_slice %out[5] [32] [1] : tensor<32xf32> to tensor<32xf32> + + // CHECK: tensor.extract_slice + %6 = tensor.extract_slice %p0[%3] [%4] 
[1] : tensor to tensor + // CHECK: %[[SLICE:.*]] = linalg.fill ins(%[[VAL]] : f32) + // CHECK: linalg.elemwise_unary ins(%[[SLICE]] : tensor) + %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor<32xf32>) -> tensor<32xf32> + + // CHECK: tensor.extract_slice + %8 = tensor.extract_slice %p1[0] [32] [1] : tensor<64xf32> to tensor<32xf32> + // CHECK: %[[SLICE2:.*]] = linalg.fill ins(%[[VAL2]] : f32) + // CHECK: linalg.elemwise_unary ins(%[[SLICE2]] : tensor<32xf32>) + %9 = linalg.elemwise_unary ins(%8 : tensor<32xf32>) outs(%5 : tensor<32xf32>) -> tensor<32xf32> + + // CHECK: linalg.elemwise_unary ins(%[[SLICE]] : tensor) + %10 = tensor.extract_slice %p0[%3] [%4] [1] : tensor to tensor + %11 = linalg.elemwise_unary ins(%10 : tensor) outs(%5 : tensor<32xf32>) -> tensor<32xf32> + + // CHECK: linalg.elemwise_unary ins(%[[SLICE2]] : tensor<32xf32>) + %12 = tensor.extract_slice %p1[0] [32] [1] : tensor<64xf32> to tensor<32xf32> + %13 = linalg.elemwise_unary ins(%12 : tensor<32xf32>) outs(%5 : tensor<32xf32>) -> tensor<32xf32> + + // CHECK: tensor.extract_slice + %14 = tensor.extract_slice %p1[7] [32] [1] : tensor<64xf32> to tensor<32xf32> + // CHECK: %[[SLICE3:.*]] = linalg.fill ins(%[[VAL2]] : f32) + // CHECK: linalg.elemwise_unary ins(%[[SLICE3]] : tensor<32xf32>) + %15 = linalg.elemwise_unary ins(%14 : tensor<32xf32>) outs(%5 : tensor<32xf32>) -> tensor<32xf32> + } + func.return %2 : tensor<32xf32> + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> + %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %0 into %1 : + (!transform.any_op, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">) + } +}