diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -44,13 +44,15 @@
 /// - `loop` isnt erased, but is left in a "no-op" state where the body of the
 ///   loop just yields the basic block arguments that correspond to the
 ///   initialization values of a loop. The loop is dead after this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+///   `newIterOperands` within the generated new loop are replaced with the
+///   corresponding `BlockArgument` in the loop body (otherwise left as is).
 using NewYieldValueFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
 scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                     ValueRange newIterOperands,
-                                    const NewYieldValueFn &newYieldValuesFn);
+                                    const NewYieldValueFn &newYieldValuesFn,
+                                    bool replaceIterOperandsUsesInLoop = true);
 
 /// Update a perfectly nested loop nest to yield new values from the innermost
 /// loop and propagating it up through the loop nest. This function
@@ -64,12 +66,16 @@
 ///   the body of the loop just yields the basic block arguments that correspond
 ///   to the initialization values of a loop. The original loops are dead after
 ///   this method.
-/// - All uses of the `newIterOperands` within the generated new loop
-///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+///   `newIterOperands` within the generated new loops are replaced with the
+///   corresponding `BlockArgument` in the loop body. If false, only the new
+///   init operands of each inner loop are rewired to the `BlockArgument`s of
+///   the enclosing loop; all other uses are left as is.
 SmallVector<scf::ForOp>
 replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
                              ValueRange newIterOperands,
-                             const NewYieldValueFn &newYieldValueFn);
+                             const NewYieldValueFn &newYieldValueFn,
+                             bool replaceIterOperandsUsesInLoop = true);
 
 /// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -167,6 +167,99 @@ return loops; } +/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, +/// construct the destructive update pattern that inserts the yielded +/// value into a destination tensor provided by `initValue` at offset +/// `tileOffsets` and size `tileSizes`. For example, +/// +/// ```mlir +/// scf.for %iv0 = ... { +/// %0 = tiled_op +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { +/// %0 = tiled_op +/// %1 = tensor.insert_slice %0 into %arg[..] [..] [..] +/// scf.yield %1 +/// } +/// ``` +static SmallVector +yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, + ValueRange yieldedValues, + ArrayRef> tileOffsetsList, + ArrayRef> tileSizesList, + MutableArrayRef loops) { + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector inserts; + inserts.reserve(initValues.size()); + for (auto yieldedValue : llvm::enumerate(yieldedValues)) { + ArrayRef tileOffsets = + tileOffsetsList[yieldedValue.index()]; + ArrayRef tileSizes = tileSizesList[yieldedValue.index()]; + SmallVector tileStrides(tileOffsets.size(), + b.getIndexAttr(1)); + Value insert = b.create( + loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], + tileOffsets, tileSizes, tileStrides); + inserts.push_back(insert); + } + return inserts; + }; + + SmallVector newLoops = + replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, + /*replaceIterOperandsUsesInLoop =*/false); + for (const auto &loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = 
newLoops[loop.index()]; + } + return llvm::to_vector( + llvm::map_range(loops.front().getResults().take_back(initValues.size()), + [](OpResult r) -> Value { return r; })); +} + +/// If the tiled operation is destination passing style, update the +/// slice of the destination used (which refers to the untiled destination) +/// to use the corresponding region argument of the innermost loop. +/// +/// ```mlir +/// %0 = +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %0 +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +static void +updateDestinationOperandsForTiledOp(OpBuilder &builder, + ValueRange tiledOpDestinationValues, + ValueRange bbArgsList) { + for (auto destValue : llvm::enumerate(tiledOpDestinationValues)) { + auto sliceOp = destValue.value().getDefiningOp(); + if (!sliceOp) + continue; + sliceOp.setOperand(0, bbArgsList[destValue.index()]); + } +} + scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, PatternBenefit benefit) @@ -281,7 +374,6 @@ // 5. If the original operations has results, modify the loop nest to yield // the replacement values. - SmallVector replacements; if (tilingResult.loops.empty()) { // 5a. If there were no loops, the tiled implementation results are the // replacements. @@ -289,58 +381,34 @@ return tilingResult; } - // 5b. `scf.for` with tensor semantics requires the loop nest to yield the - // replacement values using destructive updates. Use the `TilingInterface` - // to get the position of the result tiles and use that to generate the - // destructive update pattern, i.e., - // - // ```mlir - // scf.for %iv0 = ... 
{ - // %0 = tiled_op - // } - // ``` - // - // is transformed to - // - // ```mlir - // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { - // %0 = tiled_op - // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] - // scf.yield %1 - // } - // ``` - NewYieldValueFn yieldValueFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) -> SmallVector { - SmallVector yieldedValues; - Attribute one = b.getIndexAttr(1); - for (auto resultNum : llvm::seq(0, op->getNumResults())) { - SmallVector resultTileOffsets, resultTileSizes; - if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, - resultTileOffsets, - resultTileSizes))) { - op.emitOpError("unable to get position of result ") - << resultNum << " of the tiled implementation"; - return {}; - } - SmallVector resultTileStrides(resultTileOffsets.size(), - one); - Value yieldedValue = b.create( - op->getLoc(), tilingResult.tiledOp->getResult(resultNum), - newBBArgs[resultNum], resultTileOffsets, resultTileSizes, - resultTileStrides); - yieldedValues.push_back(yieldedValue); + // 6. Yield the results of the tiled operation from the loop nest as + // replacements for the original untiled ops. + if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) { + return rewriter.notifyMatchFailure( + tilingResult.tiledOp, + "expected tiled op to have as many results as the untiled operation"); + } + + // Since new loops are created during the process of yielding the values, and + // `offsets` typically depends on induction variable of loops, it is necessary + // to yield all the tiled result values at the same time. 
+  unsigned numResults = tilingResult.tiledOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultTileOffsetsList(numResults);
+  SmallVector<SmallVector<OpFoldResult>> resultTileSizesList(numResults);
+  for (auto tiledResult : llvm::enumerate(tilingResult.tiledOp->getResults())) {
+    if (failed(op.getResultTilePosition(
+            rewriter, tiledResult.index(), offsets, sizes,
+            resultTileOffsetsList[tiledResult.index()],
+            resultTileSizesList[tiledResult.index()]))) {
+      return rewriter.notifyMatchFailure(
+          op, "unable to get insertion position of tiled result");
     }
-    return yieldedValues;
-  };
-  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
-      rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
-      yieldValueFn);
-  for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
-    rewriter.eraseOp(loop.value());
-    tilingResult.loops[loop.index()] = newLoops[loop.index()];
   }
-  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+  SmallVector<Value> destinations = op.getDestinationOperands(rewriter);
+  SmallVector<Value> replacements = yieldTiledValues(
+      rewriter, destinations, tilingResult.tiledOp->getResults(),
+      resultTileOffsetsList, resultTileSizesList, tilingResult.loops);
+  rewriter.replaceOp(op, replacements);
   return tilingResult;
 }
 
@@ -363,36 +431,25 @@
     : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
       tilingPattern(context, std::move(options)) {}
 
-/// Return the `Value` that is defined by an operation that implements
-/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
-/// if required.
-static Optional<OpResult> getFusableProducer(Value v) {
-  while (auto blockArg = v.dyn_cast<BlockArgument>()) {
-    auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
-    if (!loopOp)
-      return llvm::None;
-    v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
-  }
-  if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
-    return llvm::None;
-  return v.cast<OpResult>();
-}
-
-// Replace iter args of the outer most loop with region args of the inner most
-static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, - PatternRewriter &rewriter) { - assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && - "expect same number of iter args"); - Block *block = &(*innerFor.getRegion().begin()); - for (auto it : - llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { - Value source = std::get<0>(it); - Value target = std::get<1>(it); - source.replaceUsesWithIf(target, [&](OpOperand &use) { - return use.getOwner()->getBlock() == block; - }); +/// Return the untiled producer whose slice is used in a tiled consumer. The +/// method traverses the tile loop nest (`loops`) if needed, and returns the +/// `iter_args` of the outer most that is encountered. Traversing the iter_args +/// indicates that this is a destination operand of the consumer. If there was +/// no loop traversal needed, the second value of the returned tuple is empty. +static std::tuple> +getProducerOp(OpOperand *source, ArrayRef loops) { + Optional destinationIterArg; + auto loopIt = loops.rbegin(); + while (auto iterArg = source->get().dyn_cast()) { + scf::ForOp loop = *loopIt; + if (iterArg.getOwner()->getParentOp() != loop) + break; + source = &loop.getOpOperandForRegionIterArg(iterArg); + loopIt++; } + if (loopIt == loops.rend()) + destinationIterArg = source; + return {source->get().dyn_cast(), destinationIterArg}; } FailureOr @@ -441,8 +498,10 @@ // 2b. 
Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) - Optional fusableProducer = - getFusableProducer(candidateSliceOp.getSource()); + OpResult fusableProducer; + Optional destinationIterArg; + std::tie(fusableProducer, destinationIterArg) = getProducerOp( + &candidateSliceOp->getOpOperand(0), tileAndFuseResult.loops); if (!fusableProducer) continue; @@ -450,7 +509,7 @@ rewriter.setInsertionPoint(candidateSliceOp); FailureOr fusedProducerValue = tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, - fusableProducer.value()); + fusableProducer); if (failed(fusedProducerValue)) continue; rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); @@ -462,56 +521,56 @@ tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); addCandidateSlices(fusedProducer, candidates); - // 2e. If the operation being fused creates a value that is used as `outs` - // in the tiled operation, the result of the unfused operation will be - // used in the `iter_args` of the tiled loop generated. When the - // operation is fused, this use in `iter_args` needs to be modified to - // use the destination of the fused operation. For example, starting - // with + // 2e. If the slice is for a destination operand, then + // - Update the iter_arg of the outer most loop to use the destination + // of the untiled producer. + // - Update the destination of the slice of the tiled producer generated + // to use the same basic block argument as the slice that was used to + // producer the tiled implementation. + // For example // - // ```mlir - // %0 = linalg.init_tensor ... - // %1 = linalg.fill ... outs(%0:...)... - // %2 = linalg.matmul ... outs(%1:...).... - // ``` + // ```mlir + // %0 = linalg.init + // %1 = linalg.fill .. outs(%0 : ) + // %2 = scf.for .. iter_args(%arg0 = %1) { + // %3 = tensor.extract_slice %arg0 + // .. = linalg.matmul .. 
outs(%3 : ) + // } + // ``` // - // First the `linalg.matmul` gets tiled + // is transformed to // - // ```mlir - // %0 = linalg.init_tensor - // %1 = linalg.fill - // %2 = scf.for .... iter_args(%arg0 = %1)... - // ... - // ... = linalg.matmul ... - // - // ``` - // - // When the `linalg.fill` gets fused, the `iter_args` needs to be - // modified - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = scf.for ... iter_args(%arg0 = %0)... - // ... - // %2 = linalg.fill ... - // %3 = linalg.matmul ... outs(%2: ...)... - // ``` - TilingInterface unfusedProducerOp = - cast(fusableProducer->getOwner()); - scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); - SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); - for (OpOperand &uses : unfusedProducerOp->getUses()) { - if (uses.getOwner() == outerMostTiledLoop.getOperation()) { - unsigned resultNumber = uses.get().cast().getResultNumber(); - unsigned operandNumber = uses.getOperandNumber(); - outerMostTiledLoop->setOperand( - operandNumber, unfusedProducerOpDestValues[resultNumber]); + // ```mlir + // %0 = linalg.init + // %1 = scf.for .. iter_args(%arg0 = %0) { + // %3 = tensor.extract_slice %arg0 + // %4 = linalg.fill .. outs(%3 : ) + // .. = linalg.matmul .. 
outs(%4 : ) + // } + // ``` + if (destinationIterArg) { + unsigned iterArgNumber = destinationIterArg.value()->getOperandNumber(); + unsigned resultNumber = fusableProducer.getResultNumber(); + if (auto producerOp = + dyn_cast(fusableProducer.getOwner())) { + scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); + SmallVector destination = + producerOp.getDestinationOperands(rewriter); + outerMostLoop.setOperand(iterArgNumber, destination[resultNumber]); + } + if (auto tiledAndFusedInterfaceOp = + fusedProducerValue.value().getDefiningOp()) { + scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); + SmallVector destination = + tiledAndFusedInterfaceOp.getDestinationOperands(rewriter); + updateDestinationOperandsForTiledOp( + rewriter, destination[resultNumber], + innerMostLoop + .getRegionIterArgs()[iterArgNumber - + innerMostLoop.getNumControlOperands()]); } } } - replaceIterArgs(tileAndFuseResult.loops.front(), - tileAndFuseResult.loops.back(), rewriter); return tileAndFuseResult; } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -40,7 +40,8 @@ scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn) { + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop) { // Create a new loop before the existing one, with the extra operands. OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(loop); @@ -79,13 +80,15 @@ llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - // Replace all uses of `newIterOperands` with the corresponding basic block - // arguments. 
-  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
-    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
-      Operation *user = use.getOwner();
-      return newLoop->isProperAncestor(user);
-    });
+  if (replaceIterOperandsUsesInLoop) {
+    // Replace all uses of `newIterOperands` with the corresponding basic block
+    // arguments.
+    for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
+    }
   }
 
   // Replace all uses of the original loop with corresponding values from the
@@ -104,7 +107,8 @@
 
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
-    ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn) {
+    ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn,
+    bool replaceIterOperandsUsesInLoop) {
   if (loopNest.empty())
     return {};
   SmallVector<scf::ForOp> newLoopNest(loopNest.size());
@@ -121,8 +125,19 @@
               newIterOperands.size()));
       return newYields;
     };
-    newLoopNest[loopDepth] = replaceLoopWithNewYields(
-        builder, loopNest[loopDepth], newIterOperands, fn);
+    newLoopNest[loopDepth] =
+        replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
+                                 fn, replaceIterOperandsUsesInLoop);
+    if (!replaceIterOperandsUsesInLoop) {
+      unsigned subLen = newIterOperands.size();
+      unsigned subStart =
+          newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
+      auto resetOperands =
+          newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
+                                                                subLen);
+      resetOperands.assign(
+          newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
+    }
   }
   return newLoopNest;
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -30,7 +30,7 @@
 // CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
 // CHECK-SAME:           outs(%[[FILL_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -68,7 +68,7 @@
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -80,7 +80,7 @@
 // CHECK-SAME:           ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
 // CHECK-SAME:           outs(%[[OUTS_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -130,7 +130,7 @@
 // CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
 // CHECK-SAME:         outs(%[[FILL1_TILE]] :
 //      CHECK:     %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
-//      CHECK      scf.yield %[[INSERT]]
+//      CHECK:     scf.yield %[[INSERT]]
 
 // -----
 
@@ -182,7 +182,7 @@
 // CHECK-SAME:           ins(%[[GEMM_TILE]] :
 // CHECK-SAME:           outs(%[[OUTS_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 
@@ -218,7 +218,7 @@
 // CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 //  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
 //  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
-//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
 //      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME:           outs(%[[INIT_TILE]] :
 //      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
@@ -229,7 +229,7 @@
 // CHECK-SAME:           ins(%[[GEMM_TILE]] :
 // CHECK-SAME:           outs(%[[INIT_TILE_2]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-//      CHECK        scf.yield %[[INSERT]]
+//      CHECK:       scf.yield %[[INSERT]]
 
 // -----
 