diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -74,25 +74,29 @@
   SmallVector<scf::ForOp> packingLoops;
 
 private:
-  /// Returns the loops in `backwardSlice` used to index the padded data. The
-  /// method starts from `padTensorOp` and `sliceOp`, follows the use-def
-  /// chains of their index operands, and stores any enclosing loop whose
-  /// induction variable is part of the walked index computation.
+  /// Drop any non-index dependencies of `padTensorOp` and `sliceOp` from
+  /// `backwardSlice`. The method follows the use-def chains of the index
+  /// operands consumed by `padTensorOp` and `sliceOp` and drops the operations
+  /// not part of this index computation. Afterwards, the filtered
+  /// `backwardSlice` contains only the loops whose induction variable is used,
+  /// directly or indirectly, to index the padded tensor.
   ///
   /// Example:
   /// ```
   /// %source = linalg.fill(%cst, %arg0)
   /// scf.for %i
-  ///   scf.for %j
-  ///     scf.for %k // not used to index %source!
+  ///   %unrelated = linalg.fill(%cst, %arg1) // not used to index %source!
+  ///   scf.for %j (%arg2 = %unrelated)
+  ///     scf.for %k // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = linalg.pad_tensor %slice
  /// ```
-  /// getIndexingLoops(%padded_slice, %slice) returns [scf.for %i, scf.for %j]
-  SmallVector<scf::ForOp> getIndexingLoops(PadTensorOp padTensorOp,
-                                           tensor::ExtractSliceOp sliceOp);
+  /// dropNonIndexDependencies(%padded_slice, %slice)
+  /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
+  void dropNonIndexDependencies(PadTensorOp padTensorOp,
+                                tensor::ExtractSliceOp sliceOp);
 
   /// Encodes whether the analysis is valid and hoisting can proceed.
   bool valid;
@@ -144,7 +148,7 @@
   if (!isOnlyUsedAsInputOfLinalgOp(padTensorOp))
     return;
 
-  // Get at most nLevels of immediately enclosing loops.
+  // Get at most `numLoops` of immediately enclosing loops.
   SmallVector<scf::ForOp> reverseEnclosingLoops;
   getAtMostNEnclosingLoops(padTensorOp, numLoops, reverseEnclosingLoops);
   if (reverseEnclosingLoops.empty()) {
@@ -154,28 +158,6 @@
 
   outermostEnclosingForOp = reverseEnclosingLoops.back();
 
-  // Get all the ops in the backwards slice starting from `padTensorOp` and that
-  // are dominated by the outermost enclosing loop.
-  // Bail on any op with a region that is not either a scf::ForOp or a LinalgOp.
-  bool analysisFailure = false;
-  DominanceInfo domInfo(outermostEnclosingForOp);
-  getBackwardSlice(
-      padTensorOp.getOperation(), &backwardSlice, [&](Operation *op) {
-        if (!domInfo.dominates(outermostEnclosingForOp, op))
-          return false;
-        if (op != padTensorOp && op->getNumRegions() > 0 &&
-            !isa<scf::ForOp, LinalgOp>(op)) {
-          analysisFailure = true;
-          LLVM_DEBUG(DBGS()
-                     << "Unsupported op with region: " << *op << " -> skip\n");
-          return false;
-        }
-        return true;
-      });
-
-  if (analysisFailure || backwardSlice.empty())
-    return;
-
   // Get the `sliceOp` that defines the source tensor of `padTensorOp` and
   // check its source is defined outside of the outermost loop. This check
   // ensures the padded data is available for packing before entering the
@@ -201,21 +183,42 @@
     return;
   }
 
-  // Search the loops found in `backwardSlice` used to index the padded data.
-  SmallVector<scf::ForOp> indexingLoops =
-      getIndexingLoops(padTensorOp, sliceOp);
+  // Get all the ops in the backward slice starting from `padTensorOp` that
+  // are dominated by the outermost enclosing loop.
+  DominanceInfo domInfo(outermostEnclosingForOp);
+  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
+                   [&](Operation *op) {
+                     return domInfo.dominates(outermostEnclosingForOp, op);
+                   });
+  if (backwardSlice.empty())
+    return;
+  // Add `padTensorOp` itself to the backward slice.
+  backwardSlice.insert(padTensorOp.getOperation());
+
+  // Remove all ops in the backward slice that are not used to index the padded
+  // tensor. In particular, keep `padTensorOp`, `sliceOp`, and the loop and
+  // affine operations used for the index computation.
+  dropNonIndexDependencies(padTensorOp, sliceOp);
+
+  // Check that any op with a region is either `padTensorOp`, a scf::ForOp, or
+  // a LinalgOp.
+  for (Operation *op : backwardSlice) {
+    if (op != padTensorOp && op->getNumRegions() > 0 &&
+        !isa<scf::ForOp, LinalgOp>(op)) {
+      LLVM_DEBUG(DBGS() << "Unsupported op with region: " << *op
+                        << " -> skip\n");
+      return;
+    }
+  }
 
-  // Add only the loops part of `indexingLoops` to the packing loops. All other
-  // loops are not used to index the padded data and consequently access the
-  // same data in every loop iteration. Adding them to the packing loops would
-  // increase the cache footprint of the packed data by storing the same data
-  // multiple times.
+  // Add only the loops part of the filtered `backwardSlice` to the packing
+  // loops. All other loops are not used to index the padded data and
+  // consequently access the same data in every loop iteration. Adding them to
+  // the packing loops would increase the cache footprint of the packed data
+  // by storing the same data multiple times.
   for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
-    if (!indexingLoops.empty() && indexingLoops.back() == forOp)
-      packingLoops.push_back(indexingLoops.pop_back_val());
-  assert(indexingLoops.empty() &&
-         "expect the all indexing loops are enclosing loops");
-
+    if (backwardSlice.contains(forOp))
+      packingLoops.push_back(forOp);
   if (packingLoops.empty()) {
     LLVM_DEBUG(DBGS() << "Cannot find a packing loop -> skip\n");
     return;
@@ -225,9 +228,8 @@
   valid = true;
 }
 
-SmallVector<scf::ForOp>
-HoistingAnalysis::getIndexingLoops(PadTensorOp padTensorOp,
-                                   tensor::ExtractSliceOp sliceOp) {
+void HoistingAnalysis::dropNonIndexDependencies(
+    PadTensorOp padTensorOp, tensor::ExtractSliceOp sliceOp) {
   // Set of all values used for index computation.
   SetVector<Value> indexEdges;
 
@@ -239,17 +241,25 @@
     indexEdges.insert(operand);
   };
 
+  // Check if any operation result is contained in `indexEdges`.
+  auto hasIndexResult = [&](Operation *operation) {
+    return llvm::any_of(operation->getResults(), [&](Value result) {
+      return indexEdges.contains(result);
+    });
+  };
+
   // Starting from `padTensorOp` and `sliceOp` walk the use-def edges of index
   // type in `backwardSlice`. Add the index operands of an operation to
-  // `indexEdges` if one of its results is an index edge found so far and store
-  // all loops part of the index computation to `indexingLoops`.
+  // `indexEdges` and remove all operations from `backwardSlice` that are not
+  // part of the index computation.
   //
   // Example:
   // ```
   // %source = linalg.fill(%cst, %arg0)
   // scf.for %i
-  //   scf.for %j
-  //     scf.for %k // not used to index %source!
+  //   %unrelated = linalg.fill(%cst, %arg1) // not used to index %source!
+  //   scf.for %j (%arg2 = %unrelated)
+  //     scf.for %k // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
@@ -257,8 +267,7 @@
   // ```
   // After iterating `backwardSlice` we obtain:
   // indexEdges = [%i, %j, %ubi, %ubj]
-  // indexingLoops = [scf.for %i, scf.for %j]
-  SmallVector<scf::ForOp> indexingLoops;
+  // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
   for (Operation *op : llvm::reverse(backwardSlice)) {
     // Add the index operands of `padTensorOp` and `sliceOp` to start the
     // exploration of the index computation.
@@ -267,22 +276,24 @@
       continue;
     }
     // Add the index operands of the loop if its induction variable is
-    // used for index computation. Additionally, insert the loop into
-    // `indexingLoops`
+    // used for index computation.
     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       if (indexEdges.contains(forOp.getInductionVar())) {
         addIndexOperandsToIndexEdges(op);
-        indexingLoops.push_back(forOp);
         continue;
       }
     }
     // Add the index operands of all other operations if at least one result is
     // used for index computation.
-    if (llvm::any_of(op->getResults(),
-                     [&](Value result) { return indexEdges.contains(result); }))
+    if (hasIndexResult(op)) {
       addIndexOperandsToIndexEdges(op);
+      continue;
+    }
+    // Remove all other operations not used by the index computation except for
+    // constant operations that may be padding values used by `padTensorOp`.
+    if (!isa<arith::ConstantOp>(op))
+      backwardSlice.remove(op);
   }
-  return indexingLoops;
 }
 
 SmallVector
@@ -387,8 +398,6 @@
   clonedLoopIvs.reserve(nPackedLoops);
   leadingPackedTensorIndexings.reserve(nPackedLoops);
   BlockAndValueMapping bvm;
-  // Insert `opToHoist` into the backwardSlice so we clone it too.
-  analysis.backwardSlice.insert(opToHoist);
   // Stack step 1. iteratively clone loops and push `packedTensor`.
   for (Operation *op : analysis.backwardSlice) {
     // Specifically sit out in the extract_slice(packedTensor) case: this is the
@@ -405,10 +414,8 @@
     }
     // TODO: support more cases as they appear.
     auto forOp = dyn_cast<scf::ForOp>(op);
-    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
-    // Unused loop, just skip it.
-    if (!llvm::is_contained(analysis.packingLoops, forOp))
-      continue;
+    assert(forOp && llvm::is_contained(analysis.packingLoops, forOp) &&
+           "expect an scf::ForOp that is a packing loop");
     auto clonedForOp =
         b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -195,6 +195,82 @@
 
 // -----
 
+// CHECK-DAG: #[[DIV3:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 3)>
+
+// CHECK: multiple_operations
+// CHECK-DOUBLE: multiple_operations
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @multiple_operations(%arg0: tensor<24x12xf32>,
+                          %arg1: tensor<12x25xf32>,
+                          %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c12 = arith.constant 12 : index
+  %c3 = arith.constant 3 : index
+  %c0 = arith.constant 0 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c5 = arith.constant 5 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c4 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // Packing the first input operand for all values of IV2 (IV2x4x3).
+    // CHECK: = linalg.init_tensor [4, 4, 3]
+    // CHECK: %[[PT0:.*]] = scf.for %[[PIV0:[0-9a-z]+]] =
+    // CHECK:   %[[PIDX0:.*]] = affine.apply #[[DIV3]](%[[PIV0]])
+    // CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+    // CHECK:   %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold
+    // CHECK:   %[[T2:.*]] = tensor.insert_slice %[[T1:.*]] into %{{.*}}[%[[PIDX0]], 0, 0]
+    // CHECK:   scf.yield %[[T2:.*]]
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c5 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %2 = tensor.extract_slice %arg6[%arg3, %arg5] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+      // Check the fill and pad_tensor ops do not prevent hoisting.
+      %3 = linalg.pad_tensor %2 nofold low[%c0, %c0] high[%c0, %c0] {
+      ^bb0(%arg7: index, %arg8: index): // no predecessors
+        linalg.yield %cst : f32
+      } : tensor<4x5xf32> to tensor<4x5xf32>
+      %4 = linalg.fill(%cst, %3) : f32, tensor<4x5xf32> -> tensor<4x5xf32>
+
+      // Packing the second input operand for all values of IV2 (IV2x3x5).
+      // CHECK: = linalg.init_tensor [4, 3, 5]
+      // CHECK: %[[PT1:.*]] = scf.for %[[PIV1:[0-9a-z]+]] =
+      // CHECK:   %[[PIDX1:.*]] = affine.apply #[[DIV3]](%[[PIV1]])
+      // CHECK:   %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
+      // CHECK:   %[[T4:.*]] = linalg.pad_tensor %[[T3]] nofold
+      // CHECK:   %[[T5:.*]] = tensor.insert_slice %[[T4:.*]] into %{{.*}}[%[[PIDX1]], 0, 0]
+      // CHECK:   scf.yield %[[T5:.*]]
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %5 = scf.for %arg7 = %c0 to %c12 step %c3 iter_args(%arg8 = %4) -> (tensor<4x5xf32>) {
+
+        // Index the packed operands.
+        // CHECK-DAG: %[[IDX0:.*]] = affine.apply #[[DIV3]](%[[IV2]])
+        // CHECK-DAG: %[[T6:.*]] = tensor.extract_slice %[[PT0]][%[[IDX0]]
+        // CHECK-DAG: %[[T7:.*]] = tensor.extract_slice %[[PT1]][%[[IDX0]]
+        %7 = tensor.extract_slice %arg0[%arg3, %arg7] [4, 3] [1, 1] : tensor<24x12xf32> to tensor<4x3xf32>
+        %8 = tensor.extract_slice %arg1[%arg7, %arg5] [3, 5] [1, 1] : tensor<12x25xf32> to tensor<3x5xf32>
+
+        // Check matmul uses the packed input operands.
+        // CHECK: = linalg.matmul ins(%[[T6]], %[[T7]]
+        %9 = linalg.matmul ins(%7, %8 : tensor<4x3xf32>, tensor<3x5xf32>) outs(%arg8 : tensor<4x5xf32>) -> tensor<4x5xf32>
+        scf.yield %9 : tensor<4x5xf32>
+      }
+      %6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+      scf.yield %6 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}
+
+// -----
+
 // CHECK-DOUBLE-DAG: #[[DIV5:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 5)>
 // CHECK-DOUBLE-DAG: #[[DIV6:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 6)>
 #map0 = affine_map<(d0) -> (15, -d0 + 24)>
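A note on the pruning this patch performs: `dropNonIndexDependencies` walks the backward slice in reverse, seeds a set of "index edges" with the index operands of the pad and extract_slice ops, adds the operands of any op that defines an index edge (including loops whose induction variable is such an edge), and removes everything else except constants that may serve as padding values. The snippet below is a minimal standalone sketch of that pruning pattern over a toy use-def representation; the `Op` struct and `pruneNonIndexDeps` helper are illustrative stand-ins, not MLIR APIs or code from this patch.

```cpp
#include <algorithm>
#include <set>
#include <string>
#include <vector>

// Toy stand-ins for SSA values and operations; not MLIR types.
struct Op {
  std::string name;
  std::vector<std::string> results;   // values defined by this op
  std::vector<std::string> operands;  // values consumed by this op
  bool isConstant = false;            // constants are kept as potential pad values
};

// Prune `slice` (ordered producers-before-users, like a backward slice) down to
// the ops participating in the index computation feeding `rootIndexOperands`.
void pruneNonIndexDeps(std::vector<Op> &slice,
                       const std::vector<std::string> &rootIndexOperands) {
  // The index operands of the pad/extract_slice roots seed the edge set.
  std::set<std::string> indexEdges(rootIndexOperands.begin(),
                                   rootIndexOperands.end());
  std::vector<Op> kept;
  // Visit users before producers, mirroring llvm::reverse(backwardSlice).
  for (auto it = slice.rbegin(); it != slice.rend(); ++it) {
    bool feedsIndex = std::any_of(
        it->results.begin(), it->results.end(),
        [&](const std::string &r) { return indexEdges.count(r) != 0; });
    if (feedsIndex) {
      // Keep the op and follow its operands further up the def chain.
      indexEdges.insert(it->operands.begin(), it->operands.end());
      kept.push_back(*it);
    } else if (it->isConstant) {
      kept.push_back(*it); // kept, but contributes no new index edges
    }
    // Anything else never feeds the index computation and is dropped.
  }
  std::reverse(kept.begin(), kept.end());
  slice = kept;
}
```

Applied to the example in the doc comment above, the fill producing %unrelated and the %k loop never define an index edge, so they are dropped, while scf.for %i, scf.for %j, and the affine.min ops remain and become the packing loops and their bounds.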