diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -167,23 +167,64 @@
   if (analysisFailure || backwardSlice.empty())
     return;
 
-  // Backward slice is a topologically sorted list of ops starting at
-  // `outermostEnclosingForOp`.
-  assert(outermostEnclosingForOp == backwardSlice.front());
-
-  // Filter out the loops whose induction variable is not used to compute the
-  // padded result. As a first approximation, just look for IVs that have no
-  // use in the backwardSlice.
-  // These are the dimensions of reuse that we can exploit to reduce the
-  // amount of copy / memory.
-  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops)) {
-    for (Operation *user : forOp.getInductionVar().getUsers()) {
-      if (backwardSlice.contains(user)) {
-        packingLoops.insert(forOp);
-        break;
-      }
+  // Compute the backward slice used to index the padded operand and filter
+  // out any enclosing loop not used by the indexing. All iterations of a loop
+  // that is not part of the index computation can share the same packed data.
+  // Detecting these loops is thus performance critical since sharing the
+  // packed data increases cache reuse and minimizes the size of the packing.
+  // We approximate the index computation by keeping track of all values
+  // holding index computations or data consumed by the padding. We only add
+  // the index and data edges of an operation if one of its results is found
+  // among the edges collected so far. Note that this algorithm
+  // overapproximates the index computation since it is not aware of the op
+  // semantics.
+  SetVector<Value> indexEdges;
+  indexEdges.insert(padTensorOp.getResult());
+  SetVector<Operation *> indexSlice;
+  getBackwardSlice(padTensorOp.getOperation(), &indexSlice, [&](Operation *op) {
+    if (!domInfo.dominates(outermostEnclosingForOp, op))
+      return false;
+    // Add ForOps if the induction variable is part of the index computation.
+    // At the moment, ForOps are the only operations with a body containing
+    // index computation.
+    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+      // Adding the block arguments mapped to the iteration arguments is not
+      // needed since we do not support hoisting for output tensors. Verify
+      // that neither the block arguments nor the results are part of the
+      // index computation.
+      assert(llvm::none_of(forOp.getIterOpOperands(),
+                           [&](OpOperand &opOperand) {
+                             return indexEdges.contains(
+                                 forOp.getRegionIterArgForOpOperand(opOperand));
+                           }) &&
+             "expect the loop iter args are not part of the index computation");
+      assert(llvm::none_of(
+                 forOp.getResults(),
+                 [&](Value value) { return indexEdges.contains(value); }) &&
+             "expect the loop results are not part of the index computation");
+      if (!indexEdges.contains(forOp.getInductionVar()))
+        return false;
+      indexEdges.insert(forOp.lowerBound());
+      indexEdges.insert(forOp.upperBound());
+      indexEdges.insert(forOp.step());
+      return true;
     }
+    // All other operations are part of the index computation if one of their
+    // results is an index edge collected so far.
+    if (llvm::none_of(op->getResults(),
+                      [&](Value value) { return indexEdges.contains(value); }))
+      return false;
+    for (Value value : op->getOperands())
+      indexEdges.insert(value);
+    return true;
+  });
+
+  // Add only the loops that are part of the index slice to the packing loops.
+  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops)) {
+    if (indexSlice.contains(forOp))
+      packingLoops.insert(forOp);
   }
+  if (packingLoops.empty())
+    return;
 
   // The analysis is valid and hoisting can occur.
   valid = true;
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=2,1,0" -cse -canonicalize -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=4,3,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=3,2,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE
 
 // CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (5, -d0 + 24)>
 // CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (8, -d0 + 12)>
@@ -114,15 +114,11 @@
   %c5 = arith.constant 5 : index
   %c6 = arith.constant 6 : index
 
-  // Packing the first input operand.
-  // CHECK-DOUBLE: = linalg.init_tensor
-  // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
-  // CHECK-DOUBLE: scf.for %[[IV0:[0-9a-zA-Z]*]] =
   %0 = scf.for %arg3 = %c0 to %c24 step %c15 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
 
-    // Packing the second input operand.
-    // CHECK-DOUBLE: = linalg.init_tensor
+    // Packing the first input operand.
+    // CHECK-DOUBLE: = linalg.init_tensor [3, 5, 12]
     // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
     // CHECK-DOUBLE: scf.for %[[IV1:[0-9a-zA-Z]*]] =
 
@@ -131,6 +127,10 @@
     %3 = affine.min #map1(%arg5)
     %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
 
+    // Packing the second input operand.
+    // CHECK-DOUBLE: = linalg.init_tensor [%{{.*}}, 12, 6]
+    // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
+    // CHECK-DOUBLE: scf.for %[[IV2:[0-9a-zA-Z]*]] =
     %5 = scf.for %arg7 = %c0 to %2 step %c5 iter_args(%arg8 = %4) -> (tensor<?x?xf32>) {
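For intuition, here is a minimal standalone sketch (not part of the patch) of the index-edge propagation the new analysis performs. It deliberately avoids the MLIR API: `MiniOp`, the integer value ids, and the example graph are hypothetical stand-ins that model only the def-use structure, and a loop's induction variable is modeled as one of the loop's results.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation: the value ids it consumes and
// produces. A loop's induction variable is modeled as one of its results so
// the walk can ask "is this IV part of the index computation?".
struct MiniOp {
  std::string name;
  std::vector<int> operands;
  std::vector<int> results;
};

// Backward propagation of "index edges": an op joins the index slice only if
// one of its results is already a known index edge; its operands then become
// index edges themselves. Visiting the ops in reverse topological order
// mirrors walking the backward slice from the pad operation.
std::set<std::string> collectIndexSlice(const std::vector<MiniOp> &topoOrder,
                                        std::set<int> indexEdges) {
  std::set<std::string> indexSlice;
  for (auto it = topoOrder.rbegin(); it != topoOrder.rend(); ++it) {
    bool producesIndexEdge = false;
    for (int result : it->results)
      producesIndexEdge |= indexEdges.count(result) != 0;
    if (!producesIndexEdge)
      continue; // Not part of the index computation; skip it.
    indexSlice.insert(it->name);
    for (int operand : it->operands)
      indexEdges.insert(operand);
  }
  return indexSlice;
}

int main() {
  // Value ids: 0 = IV of loopA, 1 = IV of loopB, 2 = affine.min result,
  // 3 = extracted slice, 4 = padded tensor. The graph loosely follows the
  // test case: only loopB's IV feeds the slice offsets.
  std::vector<MiniOp> ops = {
      {"loopA", {}, {0}},
      {"loopB", {}, {1}},
      {"affine.min", {1}, {2}},
      {"extract_slice", {1, 2}, {3}},
      {"pad_tensor", {3}, {4}},
  };
  // Seed with the pad result, as the analysis seeds with padTensorOp.
  std::set<std::string> slice = collectIndexSlice(ops, /*indexEdges=*/{4});
  // loopB indexes the padded data; loopA does not, so every iteration of
  // loopA can reuse the same packed buffer.
  assert(slice.count("loopB") == 1 && slice.count("loopA") == 0);
  return 0;
}
```

In this toy graph, `loopA` never feeds the index computation of the pad, so it is excluded from the index slice and all of its iterations can share the same packed data; `loopB` feeds the slice offsets and therefore becomes a packing loop. This is the reuse the comment in the C++ hunk describes, and it is why the test above now needs smaller hoist-paddings depths to reach the same loops.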