diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -100,6 +100,9 @@ /// The ExtractSliceOp that feeds the PadOp we want to hoist. tensor::ExtractSliceOp sliceOp; + /// If non-empty, this is the unique scf::ForOp that consumes the `sliceOp`. + scf::ForOp padConsumingForOp; + private: /// Drop any non-index dependencies of `padOp` and `sliceOp` from /// `backwardSlice`. The method follows the use-def chains of the index @@ -224,9 +227,12 @@ LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n"); return; } + if (sliceOp->hasOneUse()) { + padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin())); + } - // Check the region of `padOp` depends on a constant only. Adding - // hoisting support for arbitrary padding regions would require cloning all + // Check the region of `padOp` depends on a constant only. Adding hoisting + // support for arbitrary padding regions would require cloning all // dependencies captured by the padding region. Value paddingValue = padOp.getConstantPaddingValue(); if (!paddingValue || @@ -259,6 +265,13 @@ if (backwardSlice.contains(forOp)) packingLoops.push_back(forOp); + // TODO: for multiple loops we need to track the use to the innermost loop. + if (packingLoops.size() > 1 && padConsumingForOp) { + LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> " "Downgrade to 1 loop\n"); + packingLoops.resize(1); + } + // Note: at this point, packing loops may be empty but we would still like // to hoist the padding if so specified. @@ -512,18 +525,21 @@ paddedTensor = maybeTransposeOp.getResult(0); } - // Step 4. Create InsertSliceOp at the innermost loop level, inserting an - // optionally transposed padded slice into the packed tensor. 
- Value inserted = rewriter.create<tensor::InsertSliceOp>( - loc, paddedTensor, packedTensor, offsets, sizes, strides); - - // Step 5. Iteratively pop the stack and propagate the yield. - Value valueToYield = inserted; - for (Value iv : llvm::reverse(clonedLoopIvs)) { - auto forOp = scf::getForInductionVarOwner(iv); - rewriter.setInsertionPointToEnd(&forOp.getRegion().front()); - rewriter.create<scf::YieldOp>(loc, valueToYield); - valueToYield = forOp.getResult(0); + // Innermost tensor.insert_slice and yields are optional / need loops. + if (nPackedLoops > 0) { + // Step 4. Create InsertSliceOp at the innermost loop level, inserting an + // optionally transposed padded slice into the packed tensor. + Value inserted = rewriter.create<tensor::InsertSliceOp>( + loc, paddedTensor, packedTensor, offsets, sizes, strides); + + // Step 5. Iteratively pop the stack and propagate the yield. + Value valueToYield = inserted; + for (Value iv : llvm::reverse(clonedLoopIvs)) { + auto forOp = scf::getForInductionVarOwner(iv); + rewriter.setInsertionPointToEnd(&forOp.getRegion().front()); + rewriter.create<scf::YieldOp>(loc, valueToYield); + valueToYield = forOp.getResult(0); + } + } return PackingLoopNestResult{offsets, @@ -534,6 +550,36 @@ maybeTransposeOp}; } +// If the original consumer of `sliceOp` was a `forOp` (i.e. through an iter +// arg), propagate the `packedTensor` value through the same iter arg. +// TODO: for multiple loops we need to track the use to the innermost loop. 
+static void padThroughLoopIterArg(RewriterBase &rewriter, Value packedTensor, + tensor::ExtractSliceOp sliceOp, + scf::ForOp forOp) { + OpOperand *pUse = nullptr; + for (OpOperand &use : sliceOp->getUses()) { + if (use.getOwner() == forOp) { + assert(!pUse && "Multiple slice uses in the for loop"); + pUse = &use; + } + } + assert(pUse && "No slice use in the for loop"); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(packedTensor.getDefiningOp()); + Value casted = rewriter.create<tensor::CastOp>( + packedTensor.getLoc(), pUse->get().getType(), packedTensor); + + std::optional<unsigned> operandNumber = + forOp.getIterArgNumberForOpOperand(*pUse); + assert(operandNumber.has_value() && "expected a proper iter arg number"); + SmallVector<Value> initArgs = forOp.getInitArgs(); + initArgs[operandNumber.value()] = casted; + rewriter.startRootUpdate(forOp); + forOp.getInitArgsMutable().assign(initArgs); + rewriter.finalizeRootUpdate(forOp); + packedTensor = forOp.getRegionIterArgForOpOperand(*pUse); +} + /// Produce a tensor extracted from the packingResult. This can be used as a /// replacement for `opToHoist` in callers. static Value replaceByPackingLoopNestResult( @@ -588,8 +634,11 @@ LLVM_DEBUG(DBGS() << "packedTensor: " << packedTensor << "\n"); - // TODO: atm we are missing the plumbing of packedTensor through the loop - // bbarg when required (i.e. when hoisting init tensors). + // If the consumer of `padOp` was a `forOp`, propagate through iter args. + scf::ForOp forOp = analysis.padConsumingForOp; + if (forOp) { + padThroughLoopIterArg(rewriter, packedTensor, analysis.sliceOp, forOp); + } // offsets = [maybe_leading_ivs, 0 .. 0]. // sizes = [1 .. 1, transposedShape] (defined above). 
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir @@ -161,14 +161,13 @@ // CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) { // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} // CHECK: : tensor<?x25xf32> to tensor<5x25xf32> - // CHECK: scf.for %{{.*}} -> (tensor<?x25xf32>) { - // CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[PADDED]] : tensor<5x25xf32> - // - // TODO: atm we are missing the plumbing of packedTensor through the loop bbarg - // when required (i.e. when hoisting init tensors). - // CHECK: %[[RES_EXTRACTED:.*]] = tensor.extract_slice %[[RES]][0, 0] [%{{.*}}, 25] [1, 1] - // CHECK-SAME: : tensor<5x25xf32> to tensor<?x25xf32> - // CHECK: scf.yield %[[RES_EXTRACTED]] : tensor<?x25xf32> + // CHECK: scf.for %{{.*}} iter_args(%[[INNER_PADDED:[0-9a-zA-Z]*]] = %[[PADDED]]) -> (tensor<5x25xf32>) + // CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[INNER_PADDED]] + // CHECK-SAME: : tensor<5x25xf32> + // CHECK: scf.yield %[[RES]] : tensor<5x25xf32> + // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<5x25xf32> to tensor<?x25xf32> + // CHECK: tensor.insert_slice %[[CAST]] into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1] + // CHECK-SAME: : tensor<?x25xf32> into tensor<24x25xf32> %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> func.return %0 : tensor<24x25xf32> }