diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -71,9 +71,8 @@ let description = [{Fuse a producer into a containing operation.}]; let summary = [{ - Fuses the `producer_op` into the `containing_op`. Only producers with a - single result are supported at the moment. Returns a handle to the fused - ops. + Fuses the `producer_op` into the `containing_op`. + Returns a handle to the fused ops. The producer is typically a slice of a tileable op (i.e., implements TilingInterface). In that case, this transform computes the accessed @@ -98,8 +97,10 @@ This is the case when tiling fails or when no producer op could be found among the remaining producers that has at least one use within the containing op. I.e., "producers" that are not consumed within the containing - op are rejected by this operation. This operation reads and frees the - producer handle. It reads the containing op handle. + op are rejected by this operation. + + This operation reads and frees the producer handle. + This operation reads the containing op handle. }]; let arguments = (ins Arg extractUIntArray(ArrayAttr attr) { @@ -258,6 +261,7 @@ Diagnostic &diag, Operation *producerOp, Operation *containingOp) { + LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) @@ -286,18 +290,23 @@ rewriter.setInsertionPoint(sliceOpToTile); // Tile the producer. + int64_t resultNumber = + sliceOpToTile.getSource().cast().getResultNumber(); + LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + FailureOr tiledProducer = tileableProducer.generateResultTileValue( - rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); if (failed(tiledProducer)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return nullptr; } + LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. Operation *fusedOp = tiledProducer->getDefiningOp(); - rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber)); return fusedOp; } @@ -310,6 +319,8 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { + LLVM_DEBUG( + llvm::dbgs() << "Try to fuse an extract use through block argument\n"); auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) { @@ -318,16 +329,6 @@ return nullptr; } - // Ensure `tileableProducer` has exactly one destination operand that we can - // replace the ForeachThreadOp bbArg with. - auto destinationOperands = tileableProducer.getDestinationOperands(rewriter); - if (destinationOperands.size() != 1) { - diag.attachNote(tileableProducer->getLoc()) - << "tileableProducer must have exactly one destination operand: " - << *tileableProducer; - return nullptr; - } - // Search the first use by a "scf::ForeachThreadOp" user. scf::ForeachThreadOp foreachThreadOp; auto itProducerUses = @@ -371,8 +372,13 @@ // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. + int64_t resultNumber = pUse->get().cast().getResultNumber(); + LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + + auto destinationOperands = tileableProducer.getDestinationOperands(rewriter); + BlockAndValueMapping bvm; - bvm.map(destinationOperands.front(), bbArg); + bvm.map(destinationOperands[resultNumber], bbArg); auto tileableProducerClone = cast(rewriter.clone(*tileableProducer, bvm)); auto scopeGuard = @@ -381,17 +387,18 @@ // Tile the producer. FailureOr tiledProducer = tileableProducerClone.generateResultTileValue( - rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), sliceOpToTile.getMixedSizes()); if (failed(tiledProducer)) { diag.attachNote(tileableProducer->getLoc()) << "failed to tile producer op: " << *tileableProducer; return nullptr; } + LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); // Replace the extract op. Operation *fusedOp = tiledProducer->getDefiningOp(); - rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber)); // Replace the use in containingOp. rewriter.updateRootInPlace(containingOp, [&]() { @@ -405,6 +412,8 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { + LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n"); + // Gather all uses inside the containing op. SmallVector uses; for (OpResult result : producerOp->getOpResults()) { @@ -437,6 +446,8 @@ assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = use->get().cast().getResultNumber(); + LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); + OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); fusedOp = rewriter.clone(*producerOp); @@ -453,21 +464,17 @@ ArrayRef producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. if (producerOps.empty()) { - results.set(getResult().cast(), SmallVector{}); + results.set(getFusedOp().cast(), + SmallVector{}); return DiagnosedSilenceableFailure::success(); } - for (Operation *producerOp : producerOps) { - if (producerOp->getNumResults() != 1) { - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); - diag << "op with != 1 results not supported"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); - } - } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); - if (containingOps.size() != 1) + if (containingOps.size() != 1) { + // Definite failure. return DiagnosedSilenceableFailure( this->emitOpError("requires exactly one containing_op handle (got ") << containingOps.size() << ")"); + } Operation *containingOp = containingOps.front(); // Helper function to find the next producer that should be fused. Take any @@ -498,6 +505,7 @@ while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { + results.set(getFusedOp().cast(), ArrayRef()); Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not find next producer to fuse into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); @@ -505,7 +513,7 @@ Operation *producerOp = *nextProducer; - // Detaul diagnostic, to be complemented with more failure information. + // Default diagnostic, to be complemented with more failure information. Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not fuse " << *producerOp << " into " << *containingOp; @@ -517,6 +525,8 @@ Operation *tiled = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (tiled) { + LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n" + << *containingOp); fusedOps.push_back(tiled); continue; } @@ -525,6 +535,9 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (tiledContainingOpOperand) { + LLVM_DEBUG(llvm::dbgs() + << "\nFused an extract use through block argument\n" + << *containingOp); fusedOps.push_back(tiledContainingOpOperand); continue; } @@ -532,10 +545,12 @@ Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { + LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n" + << *containingOp); fusedOps.push_back(cloned); continue; } - + results.set(getFusedOp().cast(), ArrayRef()); return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -141,3 +141,63 @@ transform.structured.fuse_into_containing_op %0 into %1 } } + +// ----- + +#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> + +module { + // CHECK-LABEL: func.func @fuse_tileable_multi_output_op + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor + func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor, %out_1: tensor, %out_2: tensor, %out_3: tensor) -> tensor { + %cst = arith.constant 4.200000e+01 : f32 + %c0 = arith.constant 0 : index + + %0:2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%in : tensor) outs(%out_1, %out_3 : tensor, tensor) { + ^bb0(%a: f32, %b: f32, %c: f32): + %d = arith.addf %a, %b : f32 + %e = arith.addf %d, %c : f32 + linalg.yield %d, %e : f32, f32 + } -> (tensor, tensor) + %d0 = tensor.dim %out_1, %c0 : tensor + + %1 = affine.apply #map0()[%d0, %idx] + + // CHECK: scf.foreach_thread {{.*}} { + %2 = scf.foreach_thread (%i) in (%1) shared_outs(%o = %out_2) -> (tensor) { + %3 = affine.apply #map1(%i)[%idx] + %4 = affine.min #map2(%i)[%d0, %idx] + %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor to tensor + + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}] + // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]] + %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor to tensor + + // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0 + %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor) -> tensor + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor into tensor + } + } + // CHECK: } + func.return %2 : tensor + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 + + // linalg.generic is tileable. The op is tiled and fused. + transform.structured.fuse_into_containing_op %0 into %1 + } +}