diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -21,8 +21,9 @@ // TODO: generalize on a per-need basis. void hoistViewAllocOps(FuncOp func); -/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately -/// enclosing scf::ForOp iteratively, if the following conditions are true: +/// Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of +/// immediately enclosing scf::ForOp iteratively, if the following conditions +/// are true: /// 1. The two ops access the same memref with the same indices. /// 2. All operands are invariant under the enclosing scf::ForOp. /// 3. No uses of the memref either dominate the transfer_read or are @@ -35,6 +36,10 @@ // TODO: generalize on a per-need basis. void hoistRedundantVectorTransfers(FuncOp func); +/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors +/// instead of buffers. +void hoistRedundantVectorTransfersOnTensor(FuncOp func); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -165,6 +165,12 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB); +/// Same behavior as `isDisjointTransferSet` but doesn't require the operations +/// to have the same tensor/memref. This allows comparing operations accessing +/// different tensors. 
+bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
+                               VectorTransferOpInterface transferB);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -81,12 +81,151 @@
   }
 }
 
+/// Look for a transfer_read, in the given tensor uses, accessing the same
+/// offset as the transfer_write.
+static vector::TransferReadOp
+findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+  for (Operation *user : srcTensor.getUsers()) {
+    auto read = dyn_cast<vector::TransferReadOp>(user);
+    if (read && read.indices() == write.indices() &&
+        read.getVectorType() == write.getVectorType()) {
+      return read;
+    }
+  }
+  return nullptr;
+}
+
+/// Check if the chunk of data inserted by the transfer_write in the given
+/// tensor are read by any other op than the read candidate.
+static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
+                                           vector::TransferReadOp candidateRead,
+                                           Value srcTensor) {
+  // Make sure none of the other uses read the part of the tensor modified
+  // by the transfer_write.
+  llvm::SmallVector<Value::use_range, 1> uses;
+  uses.push_back(srcTensor.getUses());
+  while (!uses.empty()) {
+    for (OpOperand &use : uses.pop_back_val()) {
+      Operation *user = use.getOwner();
+      // Skip the candidate use, only inspect the "other" uses.
+      if (user == candidateRead.getOperation() || user == write.getOperation())
+        continue;
+      // Consider all transitive uses through a vector.transfer_write.
+      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
+        uses.push_back(writeUser->getResult(0).getUses());
+        continue;
+      }
+      // Consider all nested uses through an scf::ForOp. We may have
+      // pass-through tensor arguments left from previous level of
+      // hoisting.
+      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
+        Value arg = forUser.getLoopBody().getArgument(
+            use.getOperandNumber() - forUser.getNumControlOperands() +
+            /*iv value*/ 1);
+        uses.push_back(arg.getUses());
+        continue;
+      }
+      // Follow the use yield as long as it doesn't escape the original
+      // region.
+      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
+      if (yieldUser &&
+          write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
+        uses.push_back(ret.getUses());
+        continue;
+      }
+      auto read = dyn_cast<vector::TransferReadOp>(user);
+      if (!read || !isDisjointTransferIndices(
+                       cast<VectorTransferOpInterface>(read.getOperation()),
+                       cast<VectorTransferOpInterface>(write.getOperation()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// To hoist transfer op on tensor the logic can be significantly simplified
+// compared to the case on buffer. The transformation follows this logic:
+// 1. Look for transfer_write with a single use from ForOp yield
+// 2. Check the uses of the matching block argument and look for a transfer_read
+// with the same indices.
+// 3. Check that all the other uses of the tensor argument are either disjoint
+// tensor_read or transfer_write. For transfer_write uses recurse to make sure
+// the new tensor has the same restrictions on its uses.
+// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
+// After this transformation the scf.forOp may have unused arguments that can be
+// removed by the canonicalization pass.
+void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    func.walk([&](scf::ForOp forOp) {
+      Operation *yield = forOp.getBody()->getTerminator();
+      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
+        Value ret = yield->getOperand(it.index());
+        auto write = ret.getDefiningOp<vector::TransferWriteOp>();
+        if (!write || !write->hasOneUse())
+          continue;
+        LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
+                          << *write.getOperation() << "\n");
+        if (llvm::any_of(write.indices(), [&forOp](Value index) {
+              return !forOp.isDefinedOutsideOfLoop(index);
+            }))
+          continue;
+        // Find a read with the same type and indices.
+        vector::TransferReadOp matchingRead =
+            findMatchingTransferRead(write, it.value());
+        // Make sure none of the other uses read the part of the tensor modified
+        // by the transfer_write.
+        if (!matchingRead ||
+            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
+          continue;
+
+        // Hoist read before.
+        if (failed(forOp.moveOutOfLoop({matchingRead})))
+          llvm_unreachable(
+              "Unexpected failure to move transfer read out of loop");
+        // Update the source tensor.
+        matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
+
+        // Hoist write after.
+        write->moveAfter(forOp);
+        yield->setOperand(it.index(), write.source());
+
+        // Rewrite `loop` with new yields by cloning and erase the original
+        // loop.
+        OpBuilder b(matchingRead);
+        auto newForOp =
+            cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
+
+        // Transfer write has been hoisted, need to update the vector and tensor
+        // source. Replace the result of the loop to use the new tensor created
+        // outside the loop.
+        newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
+        write.vectorMutable().assign(newForOp.getResults().back());
+        write.sourceMutable().assign(newForOp.getResult(it.index()));
+
+        changed = true;
+        forOp.erase();
+        // Need to interrupt and restart because erasing the loop messes up the
+        // walk.
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  }
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
   bool changed = true;
   while (changed) {
     changed = false;
     func.walk([&](vector::TransferReadOp transferRead) {
+      if (!transferRead.getShapedType().isa<MemRefType>())
+        return WalkResult::advance();
+
       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                         << *transferRead.getOperation() << "\n");
       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -312,10 +312,8 @@
   return true;
 }
 
-bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
-                                 VectorTransferOpInterface transferB) {
-  if (transferA.source() != transferB.source())
-    return false;
+bool mlir::isDisjointTransferIndices(VectorTransferOpInterface transferA,
+                                     VectorTransferOpInterface transferB) {
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
@@ -345,3 +343,10 @@
   }
   return false;
 }
+
+bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
+                                 VectorTransferOpInterface transferB) {
+  if (transferA.source() != transferB.source())
+    return false;
+  return isDisjointTransferIndices(transferA, transferB);
+}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -230,3 +230,169 @@
   }
   return
 }
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor
+func @hoist_vector_transfer_pairs_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+     tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : 
(vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
+  %0:6 = scf.for %i = %lb to %ub step %step
+    iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+              %arg3 = %tensor3, %arg4 = %tensor4, %arg5 = %tensor5)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:6 = scf.for %j = %lb to %ub step %step
+      iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
+                %arg9 = %arg3, %arg10 = %arg4, %arg11 = %arg5)
+      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+          tensor<?x?xf32>, tensor<?x?xf32>) {
+      %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+      %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r2 = vector.transfer_read %arg8[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
+      "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
+      %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
+      %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
+      "some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%arg8) : 
(tensor<?x?xf32>) -> vector<3xf32>
+      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+      %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+      %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
+      %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
+      %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
+      "some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
+      scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5 :
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+      tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 :
+    tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+    tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
+// VECTOR_TRANSFERS-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+func @hoist_vector_transfer_pairs_disjoint_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
+    %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index,
+    %random_index : index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: vector.transfer_read 
%[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %{{.*}}, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = 
vector.transfer_write %{{.*}}, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+  %0:4 = scf.for %i = %lb to %ub step %step
+    iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+              %arg3 = %tensor3)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:4 = scf.for %j = %lb to %ub step %step
+      iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
+                %arg7 = %arg3)
+      -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+      %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
+      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+      %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+      %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
+      %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
+      %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : 
vector<4xf32>, tensor<?x?xf32>
+      %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
+      %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
+      scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0, %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -47,6 +47,7 @@
   }
   if (testHoistRedundantTransfers) {
     hoistRedundantVectorTransfers(getFunction());
+    hoistRedundantVectorTransfersOnTensor(getFunction());
     return;
   }
 }