diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -505,6 +505,108 @@
   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
+/// Tiles and fuses the producer of the source of `candidateSliceOp` in place
+/// of the slice. `loops` is the loop nest walked (through `iter_args`) to find
+/// the untiled producer; if the slice feeds a destination operand, the
+/// `iter_arg` of `loops.front()` is rewired as described in step 3 below.
+/// Returns the operation defining the tiled producer value, or `std::nullopt`
+/// if no fusable producer exists or its tiled implementation cannot be
+/// generated.
+static std::optional<Operation *>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                           tensor::ExtractSliceOp candidateSliceOp,
+                           MutableArrayRef<scf::ForOp> loops) {
+  // 1. Get the producer of the source (potentially walking through
+  // `iter_args` of nested `scf.for`)
+  auto [fusableProducer, destinationIterArg] =
+      getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+                                        loops);
+  if (!fusableProducer)
+    return std::nullopt;
+
+  // 2. Generate the tiled implementation of the producer of the source
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+  FailureOr<Value> fusedProducerValue =
+      tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
+                                                   fusableProducer);
+  if (failed(fusedProducerValue))
+    return std::nullopt;
+  rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
+
+  // 3. If the slice is for a destination operand, for example,
+  //
+  // ```mlir
+  // %0 = linalg.init
+  // %1 = linalg.fill .. outs(%0 : )
+  // %2 = scf.for .. iter_args(%arg0 = %1) {
+  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %4 = tensor.extract_slice %arg1 [..]
+  //     .. = linalg.matmul .. outs(%4 : )
+  //   }
+  // }
+  // ```
+  //
+  // the IR is currently
+  //
+  // ```
+  // %0 = linalg.init
+  // %1 = linalg.fill
+  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
+  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+  //     %5 = linalg.fill .. outs(%4 : )
+  //     .. = linalg.matmul .. outs(%5 : )
+  //   }
+  // }
+  // ```
+  //
+  // The untiled `linalg.fill` is still used as the `init_value` since it
+  // was originally a destination operand of the untiled `linalg.matmul`.
+  // When fusing an operand that is a destination operand.
+  //   - Update the iter_arg of the outer most loop to use the destination
+  //     of the untiled producer.
+  //   - Update the destination of the slice of the tiled producer generated
+  //     to use the same basic block argument as the slice that was used to
+  //     generate inplace the tiled implementation of the producer.
+  // With this the IR will be.
+  //
+  // ```
+  // %0 = linalg.init
+  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
+  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+  //     %4 = linalg.fill .. outs(%3 : )
+  //     .. = linalg.matmul .. outs(%4 : )
+  //   }
+  // }
+  // ```
+  // TODO: This can be modeled better using the `DestinationStyleOpInterface`.
+  // Update to use that when it does become available.
+  scf::ForOp outerMostLoop = loops.front();
+  std::optional<unsigned> iterArgNumber;
+  if (destinationIterArg) {
+    iterArgNumber =
+        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
+  }
+  if (iterArgNumber) {
+    int64_t resultNumber = fusableProducer.getResultNumber();
+    if (auto dstOp =
+            dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
+      outerMostLoop.setIterArg(iterArgNumber.value(),
+                               dstOp.getTiedOpOperand(fusableProducer)->get());
+    }
+    if (auto dstOp = fusedProducerValue.value()
+                         .getDefiningOp<DestinationStyleOpInterface>()) {
+      scf::ForOp innerMostLoop = loops.back();
+      updateDestinationOperandsForTiledOp(
+          rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
+          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+    }
+  }
+  return fusedProducerValue->getDefiningOp();
+}
+
 /// Implementation of tile consumer and fuse producer greedily.
 FailureOr<scf::SCFTileAndFuseResult>
 mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
@@ -559,105 +661,21 @@
   addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
-    // 2a. Traverse the slices in BFS fashion.
+    // Traverse the slices in BFS fashion.
     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
     candidates.pop_front();
 
-    // 2b. Get the producer of the source (potentially walking through
-    // `iter_args` of nested `scf.for`)
-    auto [fusableProducer, destinationIterArg] =
-        getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
-                                          tileAndFuseResult.loops);
-    if (!fusableProducer)
+    // Tile and fuse the producer of the slice source, if there is one.
+    std::optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
+        rewriter, candidateSliceOp, tileAndFuseResult.loops);
+    if (!fusedProducer)
       continue;
 
-    // 2c. Generate the tiled implementation of the producer of the source
-    rewriter.setInsertionPoint(candidateSliceOp);
-    FailureOr<Value> fusedProducerValue =
-        tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
-                                                     fusableProducer);
-    if (failed(fusedProducerValue))
-      continue;
-    rewriter.replaceOp(candidateSliceOp, *fusedProducerValue);
-
-    // 2d. The operands of the fused producer might themselved be slices of
-    //     values produced by operations that implement the `TilingInterface`.
-    //     Add these operations to the worklist.
-    Operation *fusedProducer = fusedProducerValue->getDefiningOp();
-    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
-    addCandidateSlices(fusedProducer, candidates);
-
-    // 2e. If the slice is for a destination operand, for example,
-    //
-    // ```mlir
-    // %0 = linalg.init
-    // %1 = linalg.fill .. outs(%0 : )
-    // %2 = scf.for .. iter_args(%arg0 = %1) {
-    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %4 = tensor.extract_slice %arg1 [..]
-    //     .. = linalg.matmul .. outs(%4 : )
-    //   }
-    // }
-    // ```
-    //
-    // the IR is currently
-    //
-    // ```
-    // %0 = linalg.init
-    // %1 = linalg.fill
-    // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
-    //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %4 = tensor.extract_slice %0 /*incorrect value */ [..]
-    //     %5 = linalg.fill .. outs(%4 : )
-    //     .. = linalg.matmul .. outs(%5 : )
-    //   }
-    // }
-    // ```
-    //
-    // The untiled `linalg.fill` is still used as the `init_value` since it
-    // was originally a destination operand of the untiled `linalg.matmul`.
-    // When fusing an operand that is a destination operand.
-    //   - Update the iter_arg of the outer most loop to use the destination
-    //     of the untiled producer.
-    //   - Update the destination of the slice of the tiled producer generated
-    //     to use the same basic block argument as the slice that was used to
-    //     generate inplace the tiled implementation of the producer.
-    // With this the IR will be.
-    //
-    // ```
-    // %0 = linalg.init
-    // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
-    //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
-    //     %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
-    //     %4 = linalg.fill .. outs(%3 : )
-    //     .. = linalg.matmul .. outs(%4 : )
-    //   }
-    // }
-    // ```
-    // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
-    // Update to use that when it does become available.
-    scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
-    std::optional<unsigned> iterArgNumber;
-    if (destinationIterArg) {
-      iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
-          *destinationIterArg.value());
-    }
-    if (iterArgNumber) {
-      int64_t resultNumber = fusableProducer.getResultNumber();
-      if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
-              fusableProducer.getOwner())) {
-        outerMostLoop.setIterArg(
-            iterArgNumber.value(),
-            dstOp.getTiedOpOperand(fusableProducer)->get());
-      }
-      if (auto dstOp = fusedProducerValue
-                           ->getDefiningOp<DestinationStyleOpInterface>()) {
-        scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
-        updateDestinationOperandsForTiledOp(
-            rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-            innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
-      }
-    }
+    // The operands of the fused producer might themselves be slices of
+    // values produced by operations that implement the `TilingInterface`.
+    // Add these operations to the worklist.
+    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
+    addCandidateSlices(fusedProducer.value(), candidates);
   }
   return tileAndFuseResult;
 }