diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -727,12 +727,35 @@
 // hoistPaddingOnTensors Implementation.
 //===----------------------------------------------------------------------===//
 
-// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
-// arg), propagate the `packedTensor` value through the same iter arg.
-// TODO: for multiple loops we need to track the use to the innermost loop.
-static Value padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor,
-                                   tensor::ExtractSliceOp sliceOp,
-                                   scf::ForOp forOp) {
+/// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter
+/// arg), propagate the `packedTensor` value through the same iter arg.
+/// TODO: for multiple loops we need to track the use to the innermost loop.
+///
+/// Match:
+/// ```
+///   %s = tensor.extract_slice ..
+///   %f = scf.for ... iter_args(%arg0 = %s) {
+///     %0 = tensor.pad %arg0
+///     %1 = compute %0
+///     %2 = tensor.extract_slice %1
+///     scf.yield %2
+///   }
+/// ```
+///
+/// and rewrite as:
+/// ```
+///   %s = tensor.extract_slice ..
+///   %0 = tensor.pad %s
+///   %f = scf.for ... iter_args(%arg0 = %0) {
+///     %1 = compute %arg0
+///     scf.yield %1
+///   }
+///   %2 = tensor.extract_slice %f
+/// ```
+static tensor::ExtractSliceOp
+padThroughLoopIterArg(RewriterBase &rewriter, tensor::PadOp opToHoist,
+                      Value packedTensor, tensor::ExtractSliceOp sliceOp,
+                      scf::ForOp forOp) {
   OpOperand *pUse = nullptr;
   for (OpOperand &use : sliceOp->getUses()) {
     if (use.getOwner() == forOp) {
@@ -743,19 +766,59 @@
   assert(pUse && "No slice use in the for loop");
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(packedTensor.getDefiningOp());
-  Value casted = rewriter.create<tensor::CastOp>(
-      packedTensor.getLoc(), pUse->get().getType(), packedTensor);
-  std::optional<int64_t> operandNumber =
+  std::optional<int64_t> maybeOperandNumber =
       forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(operandNumber.has_value() && "expected a proper iter arg number");
+  if (!maybeOperandNumber.has_value())
+    return tensor::ExtractSliceOp();
+  int64_t operandNumber = maybeOperandNumber.value();
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
+  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+                                    .getDefiningOp<tensor::ExtractSliceOp>();
+  if (!yieldingExtractSliceOp)
+    return tensor::ExtractSliceOp();
+
+  // TODO: check that `sliceOp` and `yieldingExtractSliceOp` are the same
+  // slice.
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber.value()] = casted;
-  rewriter.startRootUpdate(forOp);
-  forOp.getInitArgsMutable().assign(initArgs);
-  rewriter.finalizeRootUpdate(forOp);
-  return forOp.getRegionIterArgForOpOperand(*pUse);
+  initArgs[operandNumber] = packedTensor;
+  SmallVector<Value> yieldOperands = yieldOp.getOperands();
+  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+
+  int64_t numOriginalForOpResults = initArgs.size();
+  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
+                    << "\n");
+  tensor::ExtractSliceOp extracted;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(forOp);
+    extracted = rewriter.create<tensor::ExtractSliceOp>(
+        packedTensor.getLoc(), packedTensor, sliceOp.getMixedOffsets(),
+        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+  }
+  scf::ForOp newForOp =
+      replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+
+  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
+  LLVM_DEBUG(DBGS() << "with result #"
+                    << numOriginalForOpResults + operandNumber
+                    << " of forOp, giving us: " << extracted << "\n");
+  rewriter.startRootUpdate(extracted);
+  extracted.getSourceMutable().assign(
+      newForOp.getResult(numOriginalForOpResults + operandNumber));
+  rewriter.finalizeRootUpdate(extracted);
+
+  LLVM_DEBUG(DBGS() << "replace uses of: " << opToHoist << "\n");
+  LLVM_DEBUG(DBGS() << "with region iter arg #"
+                    << numOriginalForOpResults + operandNumber << "\n");
+  rewriter.replaceAllUsesWith(
+      opToHoist,
+      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+
+  return extracted;
 }
 
 /// Produce a tensor extracted from the packingResult. This can be used as a
@@ -811,8 +874,8 @@
   // If the consumer of `padOp` was a `forOp`, propagate through iter args.
   scf::ForOp forOp = analysis.padConsumingForOp;
   if (forOp) {
-    packedTensor =
-        padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp);
+    return padThroughLoopIterArg(rewriter, opToHoist, packedTensor,
+                                 analysis.sliceOp, forOp);
   }
 
   // offsets = [maybe_leading_ivs, 0 .. 0].
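For concreteness, the rewrite implemented above can be sketched on IR shaped like the test below. This is a hand-simplified sketch rather than pass output: the values `%t`, `%a`, `%b`, `%cst`, `%i`, `%sz`, `%h` and the loop bounds are assumed to be defined elsewhere, and the `compute` of the doc comment is instantiated as a `linalg.matmul`.

```mlir
// Before: tensor.pad is re-materialized on the iter_arg in every iteration.
%s = tensor.extract_slice %t[%i, 0] [%sz, 25] [1, 1]
    : tensor<24x25xf32> to tensor<?x25xf32>
%f = scf.for %iv = %lb to %ub step %step
    iter_args(%arg0 = %s) -> (tensor<?x25xf32>) {
  %0 = tensor.pad %arg0 low[0, 0] high[%h, 0] {
  ^bb0(%ii: index, %jj: index):
    tensor.yield %cst : f32
  } : tensor<?x25xf32> to tensor<5x25xf32>
  %1 = linalg.matmul ins(%a, %b : tensor<5x12xf32>, tensor<12x25xf32>)
                     outs(%0 : tensor<5x25xf32>) -> tensor<5x25xf32>
  %2 = tensor.extract_slice %1[0, 0] [%sz, 25] [1, 1]
      : tensor<5x25xf32> to tensor<?x25xf32>
  scf.yield %2 : tensor<?x25xf32>
}

// After: the pad runs once before the loop, the loop iterates on the
// statically shaped padded tensor, and a single extract_slice after the
// loop recovers the original dynamic shape.
%s = tensor.extract_slice %t[%i, 0] [%sz, 25] [1, 1]
    : tensor<24x25xf32> to tensor<?x25xf32>
%p = tensor.pad %s low[0, 0] high[%h, 0] {
^bb0(%ii: index, %jj: index):
  tensor.yield %cst : f32
} : tensor<?x25xf32> to tensor<5x25xf32>
%f = scf.for %iv = %lb to %ub step %step
    iter_args(%arg0 = %p) -> (tensor<5x25xf32>) {
  %1 = linalg.matmul ins(%a, %b : tensor<5x12xf32>, tensor<12x25xf32>)
                     outs(%arg0 : tensor<5x25xf32>) -> tensor<5x25xf32>
  scf.yield %1 : tensor<5x25xf32>
}
%e = tensor.extract_slice %f[0, 0] [%sz, 25] [1, 1]
    : tensor<5x25xf32> to tensor<?x25xf32>
```

Because the rebuilt loop's iter_arg now carries the padded static type `tensor<5x25xf32>` rather than the original `tensor<?x25xf32>`, the implementation rebuilds the loop via `replaceLoopWithNewYields` instead of merely `tensor.cast`-ing the init operand as the old code did.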
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -161,12 +161,13 @@
   // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
   // CHECK:   %[[PADDED:.*]] = tensor.pad %{{.*}}
   // CHECK:     : tensor<?x25xf32> to tensor<5x25xf32>
-  // CHECK:   scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
+  // CHECK:   %[[SCF_YIELD:.*]] = scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>)
   // CHECK:     %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]]
   // CHECK-SAME:   : tensor<5x25xf32>
   // CHECK:     scf.yield %[[RES]] : tensor<5x25xf32>
-  // CHECK:   %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32>
-  // CHECK:   tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  // CHECK:   %[[EXTRACTED:.*]] = tensor.extract_slice %[[SCF_YIELD]][%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  // CHECK-SAME:   : tensor<5x25xf32> to tensor<?x25xf32>
+  // CHECK:   tensor.insert_slice %[[EXTRACTED]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
   // CHECK-SAME:   : tensor<?x25xf32> into tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
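To make the interleaved CHECK/CHECK-SAME lines easier to read, the IR they collectively match should look roughly as follows. This is hand-assembled from the CHECKs above, with named placeholders for the values the test does not constrain, so treat it as a schematic rather than verified output:

```mlir
scf.for %i = ... -> (tensor<24x25xf32>) {
  %padded = tensor.pad %slice ... : tensor<?x25xf32> to tensor<5x25xf32>
  %scf_yield = scf.for %j = ... iter_args(%inner_padded = %padded)
      -> (tensor<5x25xf32>) {
    %res = linalg.matmul ... outs(%inner_padded : tensor<5x25xf32>) ...
    scf.yield %res : tensor<5x25xf32>
  }
  %extracted = tensor.extract_slice %scf_yield[%off, 0] [%sz, 25] [1, 1]
      : tensor<5x25xf32> to tensor<?x25xf32>
  tensor.insert_slice %extracted into %acc[%off, 0] [%sz, 25] [1, 1]
      : tensor<?x25xf32> into tensor<24x25xf32>
  ...
}
```

The key observable change is that the inner loop now yields the padded `tensor<5x25xf32>` directly, and the `tensor.cast` back to `tensor<?x25xf32>` is replaced by a `tensor.extract_slice` of the inner loop's result.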