diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -250,6 +250,9 @@ void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); } void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); } void setStep(Value step) { getOperation()->setOperand(2, step); } + void setIterArg(unsigned iterArgNum, Value iterArgValue) { + getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue); + } /// Number of induction variables, always 1 for scf::ForOp. unsigned getNumInductionVars() { return 1; } @@ -267,6 +270,17 @@ unsigned getNumIterOperands() { return getOperation()->getNumOperands() - getNumControlOperands(); } + /// Get the iter arg number for an operand. If it isn't an iter arg + /// operand of this loop, return llvm::None. + Optional getIterArgNumberForOpOperand(OpOperand &opOperand) { + if (opOperand.getOwner() != getOperation()) + return llvm::None; + unsigned operandNumber = opOperand.getOperandNumber(); + if (operandNumber < getNumControlOperands()) + return llvm::None; + return operandNumber - getNumControlOperands(); + } + /// Get the region iter arg that corresponds to an OpOperand. /// This helper prevents internal op implementation detail leakage to /// clients by hiding the operand / block argument mapping. diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -44,13 +44,15 @@ /// - `loop` isnt erased, but is left in a "no-op" state where the body of the /// loop just yields the basic block arguments that correspond to the /// initialization values of a loop. The loop is dead after this method.
-/// - All uses of the `newIterOperands` within the generated new loop -/// are replaced with the corresponding `BlockArgument` in the loop body. +/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the +/// `newIterOperands` within the generated new loop are replaced +/// with the corresponding `BlockArgument` in the loop body. using NewYieldValueFn = std::function( OpBuilder &b, Location loc, ArrayRef newBBArgs)>; scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn); + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop = true); /// Update a perfectly nested loop nest to yield new values from the innermost /// loop and propagating it up through the loop nest. This function @@ -64,12 +66,14 @@ /// the body of the loop just yields the basic block arguments that correspond /// to the initialization values of a loop. The original loops are dead after /// this method. -/// - All uses of the `newIterOperands` within the generated new loop -/// are replaced with the corresponding `BlockArgument` in the loop body. +/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the +/// `newIterOperands` within the generated new loop are replaced with the +/// corresponding `BlockArgument` in the loop body. SmallVector replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef loopNest, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValueFn); + const NewYieldValueFn &newYieldValueFn, + bool replaceIterOperandsUsesInLoop = true); /// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -167,6 +167,44 @@ return loops; } +/// If the tiled operation is in destination passing style, update the +/// slice of the destination used (which refers to the untiled destination) +/// to use the corresponding region argument of the innermost loop. +/// +/// ```mlir +/// %0 = +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %0 +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is +/// available generally. +static void +updateDestinationOperandsForTiledOp(OpBuilder &builder, + ValueRange tiledOpDestinationValues, + ValueRange bbArgsList) { + for (auto destValue : llvm::enumerate(tiledOpDestinationValues)) { + auto sliceOp = destValue.value().getDefiningOp(); + if (!sliceOp) + continue; + sliceOp.setOperand(0, bbArgsList[destValue.index()]); + } +} + scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, PatternBenefit benefit) @@ -281,7 +319,6 @@ // 5. If the original operations has results, modify the loop nest to yield // the replacement values. - SmallVector replacements; if (tilingResult.loops.empty()) { // 5a. If there were no loops, the tiled implementation results are the // replacements. @@ -289,7 +326,15 @@ return tilingResult; } - // 5b. `scf.for` with tensor semantics requires the loop nest to yield the + // 6.
Yield the results of the tiled operation from the loop nest as + // replacements for the original untiled ops. + if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) { + return rewriter.notifyMatchFailure( + tilingResult.tiledOp, + "expected tiled op to have as many results as the untiled operation"); + } + + // `scf.for` with tensor semantics requires the loop nest to yield the // replacement values using destructive updates. Use the `TilingInterface` // to get the position of the result tiles and use that to generate the // destructive update pattern, i.e., @@ -335,7 +380,7 @@ }; SmallVector newLoops = replaceLoopNestWithNewYields( rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), - yieldValueFn); + yieldValueFn, /*replaceIterOperandsUsesInLoop=*/false); for (const auto &loop : llvm::enumerate(tilingResult.loops)) { rewriter.eraseOp(loop.value()); tilingResult.loops[loop.index()] = newLoops[loop.index()]; @@ -363,36 +408,26 @@ : OpInterfaceRewritePattern(context, benefit), tilingPattern(context, std::move(options)) {} -/// Return the `Value` that is defined by an operation that implements -/// the `TilingInterface`. Looks through `iter_args` of scf.for nest -/// if required. -static Optional getFusableProducer(Value v) { - while (auto blockArg = v.dyn_cast()) { - auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!loopOp) - return llvm::None; - v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); - } - if (!isa_and_nonnull(v.getDefiningOp())) - return llvm::None; - return v.cast(); -} - -// Replace iter args of the outer most loop with region args of the inner most -// one.
-static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, - PatternRewriter &rewriter) { - assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && - "expect same number of iter args"); - Block *block = &(*innerFor.getRegion().begin()); - for (auto it : - llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { - Value source = std::get<0>(it); - Value target = std::get<1>(it); - source.replaceUsesWithIf(target, [&](OpOperand &use) { - return use.getOwner()->getBlock() == block; - }); +/// Return the untiled producer whose slice is used in a tiled consumer. The +/// method walks the tile loop nest (`loops`, outermost first) inside-out +/// through `iter_args`, never past the outermost loop. If every loop was +/// traversed, the slice is of a destination operand of the consumer and the +/// second tuple value is the outermost loop's init operand; else it is empty. +static std::tuple> +getUntiledProducerFromSliceSource(OpOperand *source, + ArrayRef loops) { + Optional destinationIterArg; + auto loopIt = loops.rbegin(); + for (; loopIt != loops.rend(); ++loopIt) { + scf::ForOp loop = *loopIt; + auto iterArg = source->get().dyn_cast(); + if (!iterArg || iterArg.getOwner()->getParentOp() != loop) + break; + source = &loop.getOpOperandForRegionIterArg(iterArg); } + if (loopIt == loops.rend()) + destinationIterArg = source; + return {source->get().dyn_cast(), destinationIterArg}; } FailureOr @@ -441,8 +476,9 @@ // 2b.
Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) - Optional fusableProducer = - getFusableProducer(candidateSliceOp.getSource()); + auto [fusableProducer, destinationIterArg] = + getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), + tileAndFuseResult.loops); if (!fusableProducer) continue; @@ -450,7 +486,7 @@ rewriter.setInsertionPoint(candidateSliceOp); FailureOr fusedProducerValue = tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, - fusableProducer.value()); + fusableProducer); if (failed(fusedProducerValue)) continue; rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); @@ -462,56 +498,81 @@ tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); addCandidateSlices(fusedProducer, candidates); - // 2e. If the operation being fused creates a value that is used as `outs` - // in the tiled operation, the result of the unfused operation will be - // used in the `iter_args` of the tiled loop generated. When the - // operation is fused, this use in `iter_args` needs to be modified to - // use the destination of the fused operation. For example, starting - // with + // 2e. If the slice is for a destination operand, for example, // - // ```mlir - // %0 = linalg.init_tensor ... - // %1 = linalg.fill ... outs(%0:...)... - // %2 = linalg.matmul ... outs(%1:...).... - // ``` + // ```mlir + // %0 = linalg.init + // %1 = linalg.fill .. outs(%0 : ) + // %2 = scf.for .. iter_args(%arg0 = %1) { + // %3 = scf.for .. iter_args(%arg1 = %arg0) { + // %4 = tensor.extract_slice %arg1 [..] + // .. = linalg.matmul .. outs(%4 : ) + // } + // } + // ``` // - // First the `linalg.matmul` gets tiled + // the IR is currently // - // ```mlir - // %0 = linalg.init_tensor - // %1 = linalg.fill - // %2 = scf.for .... iter_args(%arg0 = %1)... - // ... - // ... = linalg.matmul ... + // ``` + // %0 = linalg.init + // %1 = linalg.fill + // %2 = scf.for ..
iter_args(%arg0 = %1 /* incorrect value */ ) { + // %3 = scf.for .. iter_args(%arg1 = %arg0) { + // %4 = tensor.extract_slice %0 /*incorrect value */ [..] + // %5 = linalg.fill .. outs(%4 : ) + // .. = linalg.matmul .. outs(%5 : ) + // } + // } + // ``` // - // ``` + // The untiled `linalg.fill` is still used as the `init_value` since it + // was originally a destination operand of the untiled `linalg.matmul`. + // When fusing an operand that is a destination operand: + // - Update the iter_arg of the outermost loop to use the destination + // of the untiled producer. + // - Update the destination of the slice of the tiled producer generated + // to use the same basic block argument as the slice that the tiled + // implementation of the producer replaced. + // With this, the IR becomes: // - // When the `linalg.fill` gets fused, the `iter_args` needs to be - // modified - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = scf.for ... iter_args(%arg0 = %0)... - // ... - // %2 = linalg.fill ... - // %3 = linalg.matmul ... outs(%2: ...)... - // ``` - TilingInterface unfusedProducerOp = - cast(fusableProducer->getOwner()); - scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); - SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); - for (OpOperand &uses : unfusedProducerOp->getUses()) { - if (uses.getOwner() == outerMostTiledLoop.getOperation()) { - unsigned resultNumber = uses.get().cast().getResultNumber(); - unsigned operandNumber = uses.getOperandNumber(); - outerMostTiledLoop->setOperand( - operandNumber, unfusedProducerOpDestValues[resultNumber]); + // ``` + // %0 = linalg.init + // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { + // %2 = scf.for .. iter_args(%arg1 = %arg0) { + // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] + // %4 = linalg.fill .. outs(%3 : ) + // .. = linalg.matmul ..
outs(%4 : ) + // } + // } + // ``` + // TODO: This can be modeled better using the `DestinationStyleOpInterface`. + // Update to use it once that interface becomes available. + scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); + Optional iterArgNumber; + if (destinationIterArg) { + iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand( + *destinationIterArg.value()); + } + if (iterArgNumber) { + unsigned resultNumber = fusableProducer.getResultNumber(); + if (auto producerOp = + dyn_cast(fusableProducer.getOwner())) { + SmallVector destination = + producerOp.getDestinationOperands(rewriter); + outerMostLoop.setIterArg(iterArgNumber.value(), + destination[resultNumber]); + } + if (auto tiledAndFusedInterfaceOp = + fusedProducerValue.value().getDefiningOp()) { + scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); + SmallVector destination = + tiledAndFusedInterfaceOp.getDestinationOperands(rewriter); + updateDestinationOperandsForTiledOp( + rewriter, destination[resultNumber], + innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); } } } - replaceIterArgs(tileAndFuseResult.loops.front(), - tileAndFuseResult.loops.back(), rewriter); return tileAndFuseResult; } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -40,7 +40,8 @@ scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn) { + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop) { // Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(loop); @@ -79,13 +80,15 @@ llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - // Replace all uses of `newIterOperands` with the corresponding basic block - // arguments. - for (auto it : llvm::zip(newIterOperands, newBBArgs)) { - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { - Operation *user = use.getOwner(); - return newLoop->isProperAncestor(user); - }); + if (replaceIterOperandsUsesInLoop) { + // Replace all uses of `newIterOperands` with the corresponding basic block + // arguments. + for (auto it : llvm::zip(newIterOperands, newBBArgs)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } } // Replace all uses of the original loop with corresponding values from the @@ -104,7 +107,8 @@ SmallVector mlir::replaceLoopNestWithNewYields( OpBuilder &builder, ArrayRef loopNest, - ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn) { + ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, + bool replaceIterOperandsUsesInLoop) { if (loopNest.empty()) return {}; SmallVector newLoopNest(loopNest.size()); @@ -121,8 +125,41 @@ newIterOperands.size())); return newYields; }; - newLoopNest[loopDepth] = replaceLoopWithNewYields( - builder, loopNest[loopDepth], newIterOperands, fn); + newLoopNest[loopDepth] = + replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands, + fn, replaceIterOperandsUsesInLoop); + if (!replaceIterOperandsUsesInLoop) { + // The yield is expected to produce the following structure + // ``` + // %0 = scf.for ... iter_args(%arg0 = %init) { + // %1 = scf.for ...
iter_args(%arg1 = %arg0) { + // scf.yield %yield + // } + // } + // ``` + // + // Since the yield is propagated from inside out, after the inner + // loop is processed the IR is in this form + // + // ``` + // scf.for ... iter_args { + // %1 = scf.for ... iter_args(%arg1 = %init) { + // scf.yield %yield + // } + // ``` + // + // If `replaceIterOperandsUsesInLoop` is true, there is nothing to do. + // `%init` will be replaced with `%arg0` when it is created for the + // outer loop. Otherwise, this has to be done explicitly here. + unsigned subLen = newIterOperands.size(); + unsigned subStart = + newLoopNest[loopDepth + 1].getNumIterOperands() - subLen; + auto resetOperands = + newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart, + subLen); + resetOperands.assign( + newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen)); + } } return newLoopNest; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -30,7 +30,7 @@ // CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : // CHECK-SAME: outs(%[[FILL_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -68,7 +68,7 @@ // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK:
%[[GEMM_TILE:.+]] = linalg.matmul @@ -80,7 +80,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] : // CHECK-SAME: outs(%[[OUTS_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -130,7 +130,7 @@ // CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : // CHECK-SAME: outs(%[[FILL1_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -182,7 +182,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]] : // CHECK-SAME: outs(%[[OUTS_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -218,7 +218,7 @@ // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -229,7 +229,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]] : // CHECK-SAME: outs(%[[INIT_TILE_2]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // -----