[mlir][vector] Account for memref.subview aliases in transfer op optimizations

Dead-store elimination (deadStoreOp) and store-to-load forwarding previously
only inspected the direct users of a transfer op's source memref, so accesses
made through a memref.subview of the same buffer were not seen. This change:

- walks up through memref.subview ops to the root memref, then visits every
  transitive user of it, following the users of any memref.subview;
- ignores users for which isSideEffectFree() holds;
- treats every other non-disjoint access (including writes through an alias)
  as blocking, and only considers writes to the identical source value as
  overwrite / forwarding candidates;
- adds a negative test in which writes through an aliasing subview must keep
  all stores and the load in place.

Index: mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
@@ -95,14 +96,32 @@
 void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
   LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                     << "\n");
-  llvm::SmallVector<Operation *, 8> reads;
+  llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  for (auto *user : write.getSource().getUsers()) {
+  Value source = write.getSource();
+  // Skip subview ops.
+  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
+    source = subView.getSource();
+  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
+                                           source.getUsers().end());
+  llvm::SmallDenseSet<Operation *, 32> processed;
+  while (!users.empty()) {
+    Operation *user = users.pop_back_val();
+    // If the user has already been processed skip.
+    if (!processed.insert(user).second)
+      continue;
+    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
+      users.append(subView->getUsers().begin(), subView->getUsers().end());
+      continue;
+    }
+    if (isSideEffectFree(user))
+      continue;
     if (user == write.getOperation())
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (checkSameValueWAW(nextWrite, write) &&
+      if (write.getSource() == nextWrite.getSource() &&
+          checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -110,17 +129,17 @@
         else
           assert(
               postDominators.postDominates(nextWrite, firstOverwriteCandidate));
+        continue;
       }
-    } else {
-      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
-        // Don't need to consider disjoint reads.
-        if (vector::isDisjointTransferSet(
-                cast<VectorTransferOpInterface>(write.getOperation()),
-                cast<VectorTransferOpInterface>(read.getOperation())))
-          continue;
-      }
-      reads.push_back(user);
     }
+    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
+      // Don't need to consider disjoint accesses.
+      if (vector::isDisjointTransferSet(
+              cast<VectorTransferOpInterface>(write.getOperation()),
+              cast<VectorTransferOpInterface>(transferOp.getOperation())))
+        continue;
+    }
+    blockingAccesses.push_back(user);
   }
   if (firstOverwriteCandidate == nullptr)
     return;
@@ -129,15 +148,16 @@
   assert(writeAncestor &&
          "write op should be recursively part of the top region");
 
-  for (Operation *read : reads) {
-    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
-    // TODO: if the read and write have the same ancestor we could recurse in
-    // the region to know if the read is reachable with more precision.
-    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
+  for (Operation *access : blockingAccesses) {
+    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
+    // TODO: if the access and write have the same ancestor we could recurse in
+    // the region to know if the access is reachable with more precision.
+    if (accessAncestor == nullptr ||
+        !isReachable(writeAncestor, accessAncestor))
       continue;
-    if (!dominators.dominates(firstOverwriteCandidate, read)) {
-      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
-                        << "\n");
+    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
+      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
+                        << *accessAncestor << "\n");
       return;
     }
   }
@@ -164,8 +184,23 @@
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  for (Operation *user : read.getSource().getUsers()) {
-    if (isa<vector::TransferReadOp>(user))
+  Value source = read.getSource();
+  // Skip subview ops.
+  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
+    source = subView.getSource();
+  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
+                                           source.getUsers().end());
+  llvm::SmallDenseSet<Operation *, 32> processed;
+  while (!users.empty()) {
+    Operation *user = users.pop_back_val();
+    // If the user has already been processed skip.
+    if (!processed.insert(user).second)
+      continue;
+    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
+      users.append(subView->getUsers().begin(), subView->getUsers().end());
+      continue;
+    }
+    if (isSideEffectFree(user) || isa<vector::TransferReadOp>(user))
       continue;
     if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
       // If there is a write, but we can prove that it is disjoint we can ignore
@@ -174,7 +209,8 @@
               cast<VectorTransferOpInterface>(write.getOperation()),
               cast<VectorTransferOpInterface>(read.getOperation())))
         continue;
-      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
+      if (write.getSource() == read.getSource() &&
+          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;
         else
Index: mlir/test/Dialect/Vector/vector-transferop-opt.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -184,3 +184,35 @@
   return
 }
 
+// CHECK-LABEL: func @forward_dead_store_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x1xf32>, %v2 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
+  %alias = memref.subview %arg1[0, 0] [2, 2] [1, 1] :
+    memref<4x4xf32> to memref<2x2xf32, strided<[4, 1]>>
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  // blocking write.
+  vector.transfer_write %v1, %alias[%c0, %c0] {in_bounds = [true, true]} :
+    vector<1x1xf32>, memref<2x2xf32, strided<[4, 1]>>
+  vector.transfer_write %v2, %arg1[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  // blocking write.
+  vector.transfer_write %v1, %alias[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x1xf32>, memref<2x2xf32, strided<[4, 1]>>
+  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
+    memref<4x4xf32>, vector<1x4xf32>
+  vector.transfer_write %v2, %arg1[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  return %0 : vector<1x4xf32>
+}