diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -34,13 +34,33 @@
   return op;
 }
 
+/// Return true if the transfer_write fully writes the data accessed by the
+/// transfer_read.
+static bool transferEncompasses(vector::TransferWriteOp defWrite,
+                                vector::TransferReadOp read) {
+  return !defWrite.hasMaskedDim() && defWrite.indices() == read.indices() &&
+         defWrite.getVectorType() == read.getVectorType() &&
+         defWrite.permutation_map() == read.permutation_map();
+}
+
+/// Return true if the write op fully over-writes the priorWrite
+/// transfer_write op.
+static bool transferEncompasses(vector::TransferWriteOp write,
+                                vector::TransferWriteOp priorWrite) {
+  return priorWrite.indices() == write.indices() &&
+         priorWrite.getVectorType() == write.getVectorType() &&
+         priorWrite.permutation_map() == write.permutation_map();
+}
+
 namespace {
 
 class TransferOptimization {
 public:
   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
   void deadStoreOp(vector::TransferWriteOp);
+  void deadStoreOpTensor(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
+  void storeToLoadForwardingTensor(vector::TransferReadOp);
   void removeDeadOp() {
     for (Operation *op : opToErase)
       op->erase();
@@ -99,9 +119,7 @@
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (write.indices() == nextWrite.indices() &&
-          write.getVectorType() == nextWrite.getVectorType() &&
-          write.permutation_map() == write.permutation_map() &&
+      if (transferEncompasses(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -173,10 +191,8 @@
               cast<VectorTransferOpInterface>(write.getOperation()),
               cast<VectorTransferOpInterface>(read.getOperation())))
         continue;
-      if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
-          write.indices() == read.indices() &&
-          write.getVectorType() == read.getVectorType() &&
-          write.permutation_map() == read.permutation_map()) {
+      if (dominators.dominates(write, read) &&
+          transferEncompasses(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;
         else
@@ -214,15 +230,62 @@
   opToErase.push_back(read.getOperation());
 }
 
+/// Walk up the SSA links; if any write gets fully overwritten we can skip it.
+/// If it has no more uses it becomes dead.
+void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
+  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(write, defWrite)) {
+      write.sourceMutable().assign(defWrite.source());
+      if (defWrite->use_empty())
+        opToErase.push_back(defWrite.getOperation());
+      return;
+    }
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(write.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+}
+
+/// Walk up the SSA links; if any write fully matches the read, we can replace
+/// the read with the written vector. The read becomes dead and can be removed.
+void TransferOptimization::storeToLoadForwardingTensor(
+    vector::TransferReadOp read) {
+  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(defWrite, read)) {
+      read.replaceAllUsesWith(defWrite.vector());
+      opToErase.push_back(read.getOperation());
+      return;
+    }
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(read.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+}
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
   TransferOptimization opt(func);
   // Run store to load forwarding first since it can expose more dead store
   // opportunity.
-  func.walk(
-      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+  func.walk([&](vector::TransferReadOp read) {
+    if (read.getShapedType().isa<MemRefType>())
+      opt.storeToLoadForwarding(read);
+    else
+      opt.storeToLoadForwardingTensor(read);
+  });
   opt.removeDeadOp();
-  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+  func.walk([&](vector::TransferWriteOp write) {
+    if (write.getShapedType().isa<MemRefType>())
+      opt.deadStoreOp(write);
+    else
+      opt.deadStoreOpTensor(write);
+  });
   opt.removeDeadOp();
 }
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -13,16 +13,16 @@
   %c4 = constant 4 : index
   %c0 = constant 0 : index
   %cf0 = constant 0.0 : f32
-  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : 
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
-  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : 
+  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
     memref<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) 
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
     -> (vector<1x4xf32>) {
     %1 = addf %acc, %acc : vector<1x4xf32>
     scf.yield %1 : vector<1x4xf32>
   }
-  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : 
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
   return
 }
@@ -103,7 +103,7 @@
 // CHECK: vector.transfer_read
 // CHECK: return
 func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) 
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
   -> (vector<1x4xf32>) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -184,3 +184,56 @@
   return
 }
 
+// CHECK-LABEL: func @forward_dead_store_tensor
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+    -> (vector<1x4xf32>) {
+    %1 = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %1 : vector<1x4xf32>
+  }
+  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @forward_dead_store_negative_tensor
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+    -> (vector<1x4xf32>) {
+    %1 = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %1 : vector<1x4xf32>
+  }
+  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w1 : tensor<4x4xf32>
+}
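
Note (commentary, not part of the patch): a minimal sketch of how the new entry point could be driven from a test pass so the .mlir cases above can run under FileCheck. Only mlir::vector::transferOpflowOpt(FuncOp) comes from this diff; the pass name, header paths, and the PassWrapper/FunctionPass scaffolding are assumptions based on the usual MLIR test-pass pattern of this era.

// Hypothetical test pass (assumption): exercises both the memref and tensor
// paths through the function-level entry point added by this patch.
#include "mlir/Dialect/Vector/VectorOps.h" // assumed header for transferOpflowOpt
#include "mlir/Pass/Pass.h"

namespace {
struct TestVectorTransferOptPass
    : public mlir::PassWrapper<TestVectorTransferOptPass, mlir::FunctionPass> {
  void runOnFunction() override {
    // Store-to-load forwarding runs first, then dead-store elimination; the
    // entry point dispatches per transfer op between the memref and tensor
    // variants based on the shaped type.
    mlir::vector::transferOpflowOpt(getFunction());
  }
};
} // namespace

// Registration sketch (flag name is an assumption):
//   mlir::PassRegistration<TestVectorTransferOptPass> reg(
//       "test-vector-transferop-opt",
//       "Run vector.transfer forwarding and dead-store elimination");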