diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -106,8 +106,11 @@
   if (write.insertSliceOp)
     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
                       << *write.insertSliceOp.getOperation() << "\n");
-
-  for (Operation *user : srcTensor.getUsers()) {
+  // Worklist of users to inspect; extended below through disjoint writes.
+  SmallVector<Operation *> users(srcTensor.getUsers().begin(),
+                                 srcTensor.getUsers().end());
+  while (!users.empty()) {
+    Operation *user = users.pop_back_val();
     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                       << "\n");
 
@@ -153,6 +156,17 @@
     if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
         read.getVectorType() == write.transferWriteOp.getVectorType())
       return HoistableRead{read, sliceOp};
+
+    if (isa<vector::TransferWriteOp>(user)) {
+      // If we find a write with disjoint indices, recurse through its uses:
+      // a read matching `write` may appear further down that SSA chain.
+      if (vector::isDisjointTransferIndices(
+              cast<VectorTransferOpInterface>(user),
+              cast<VectorTransferOpInterface>(
+                  write.transferWriteOp.getOperation()))) {
+        users.append(user->getUsers().begin(), user->getUsers().end());
+      }
+    }
   }
   return HoistableRead();
 }
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -431,3 +431,42 @@
   }
   return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
+
+// -----
+
+// The write at [0, 3] is disjoint from the write at [0, 0], so both transfer pairs are hoisted.
+// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+// CHECK-SAME: %[[T:.*]]: tensor<?x?xf32>,
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) {
+// CHECK: %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: scf.yield %[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32>
+// CHECK: }
+// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: return %[[W1]] : tensor<?x?xf32>
+func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+    %tensor: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.0 : f32
+  %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
+    -> (tensor<?x?xf32>) {
+    %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+    %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+    %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+    %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
+    scf.yield %w11 : tensor<?x?xf32>
+  }
+  return %1 : tensor<?x?xf32>
+}
+