diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
@@ -68,6 +68,11 @@
   void loadCSE(AffineReadOpInterface loadOp,
                SmallVectorImpl<Operation *> &loadOpsToErase,
                DominanceInfo &domInfo);
+
+  void findUnusedStore(AffineWriteOpInterface storeOp,
+                       SmallVectorImpl<Operation *> &storeOpsToErase,
+                       SmallPtrSetImpl<Value> &memrefsToErase,
+                       PostDominanceInfo &postDominanceInfo);
 };
 
 } // end anonymous namespace
@@ -256,6 +261,51 @@
   return !hasSideEffect;
 }
 
+// This attempts to find stores which have no impact on the final result.
+// A writing op writeA will be eliminated if there exists an op writeB if
+// 1) writeA and writeB have mathematically equivalent affine access functions.
+// 2) writeB postdominates writeA.
+// 3) There is no potential read between writeA and writeB.
+void AffineScalarReplacement::findUnusedStore(
+    AffineWriteOpInterface writeA, SmallVectorImpl<Operation *> &opsToErase,
+    SmallPtrSetImpl<Value> &memrefsToErase,
+    PostDominanceInfo &postDominanceInfo) {
+
+  for (Operation *user : writeA.getMemRef().getUsers()) {
+    // Only consider writing operations.
+    auto writeB = dyn_cast<AffineWriteOpInterface>(user);
+    if (!writeB)
+      continue;
+
+    // The operations must be distinct.
+    if (writeB == writeA)
+      continue;
+
+    // Both operations must lie in the same region.
+    if (writeB->getParentRegion() != writeA->getParentRegion())
+      continue;
+
+    // Both operations must write to the same memory.
+    MemRefAccess srcAccess(writeB);
+    MemRefAccess destAccess(writeA);
+
+    if (srcAccess != destAccess)
+      continue;
+
+    // writeB must postdominate writeA.
+    if (!postDominanceInfo.postDominates(writeB, writeA))
+      continue;
+
+    // There cannot be an operation which reads from memory between
+    // the two writes.
+    if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
+      continue;
+
+    opsToErase.push_back(writeA);
+    break;
+  }
+}
+
 /// Attempt to eliminate loadOp by replacing it with a value stored into memory
 /// which the load is guaranteed to retrieve. This check involves three
 /// components: 1) The store and load must be on the same location 2) The store
@@ -394,6 +444,7 @@
   SmallPtrSet<Value, 4> memrefsToErase;
 
   auto &domInfo = getAnalysis<DominanceInfo>();
+  auto &postDomInfo = getAnalysis<PostDominanceInfo>();
 
   // Walk all load's and perform store to load forwarding.
   f.walk([&](AffineReadOpInterface loadOp) {
@@ -404,6 +455,15 @@
   });
 
   // Erase all load op's whose results were replaced with store fwd'ed ones.
+  for (auto *op : opsToErase)
+    op->erase();
+  opsToErase.clear();
+
+  // Walk all store's and perform unused store elimination
+  f.walk([&](AffineWriteOpInterface storeOp) {
+    findUnusedStore(storeOp, opsToErase, memrefsToErase, postDomInfo);
+  });
+  // Erase all store op's which don't impact the program
   for (auto *op : opsToErase)
     op->erase();
 
diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -642,3 +642,38 @@
 // CHECK-NEXT: return %{{.*}} : f32
 }
 
+// CHECK-LABEL: func @redundant_store_elim
+
+func @redundant_store_elim(%out : memref<512xf32>) {
+  %cf1 = constant 1.0 : f32
+  %cf2 = constant 2.0 : f32
+  affine.for %i = 0 to 16 {
+    affine.store %cf1, %out[32*%i] : memref<512xf32>
+    affine.store %cf2, %out[32*%i] : memref<512xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.store
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @redundant_store_elim_fail
+
+func @redundant_store_elim_fail(%out : memref<512xf32>) {
+  %cf1 = constant 1.0 : f32
+  %cf2 = constant 2.0 : f32
+  affine.for %i = 0 to 16 {
+    affine.store %cf1, %out[32*%i] : memref<512xf32>
+    "test.use"(%out) : (memref<512xf32>) -> ()
+    affine.store %cf2, %out[32*%i] : memref<512xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.store
+// CHECK-NEXT: "test.use"
+// CHECK-NEXT: affine.store
+// CHECK-NEXT: }
+