diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -638,6 +638,25 @@
   return success();
 }
 
+/// Returns true if `destAccess` depends on `srcAccess` at the loop depth just
+/// inside the innermost affine loop surrounding both accesses. Conservatively
+/// returns false if the two accesses are not in the same affine scope.
+static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
+                                 const MemRefAccess &destAccess) {
+  // Affine dependence analysis is possible only if both ops are in the same
+  // AffineScope; otherwise, conservatively report no dependence.
+  if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
+    return false;
+
+  unsigned nsLoops =
+      getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
+  FlatAffineValueConstraints dependenceConstraints;
+  DependenceResult result = checkMemrefAccessDependence(
+      srcAccess, destAccess, nsLoops + 1, &dependenceConstraints,
+      /*dependenceComponents=*/nullptr);
+  return hasDependence(result);
+}
+
 /// Returns true if `srcMemOp` may have an effect on `destMemOp` within the
 /// scope of the outermost `minSurroundingLoops` loops that surround them.
 /// `srcMemOp` and `destMemOp` are expected to be affine read/write ops.
@@ -858,7 +877,14 @@
     if (!domInfo.dominates(storeOp, loadOp))
       continue;
 
-    // 3. Ensure there is no intermediate operation which could replace the
+    // 3. The store must reach the load. Access function equivalence only
+    // guarantees this for accesses in the same block. The load could be in a
+    // nested block that is unreachable (e.g. the body of a dead affine.if).
+    if (storeOp->getBlock() != loadOp->getBlock() &&
+        !mustReachAtInnermost(srcAccess, destAccess))
+      continue;
+
+    // 4. Ensure there is no intermediate operation which could replace the
     // value in memory.
     if (!mlir::hasNoInterveningEffect(storeOp, loadOp))
       continue;
diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -846,3 +846,26 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return
 }
+
+
+// CHECK-LABEL: func.func @dead_affine_region_op
+func.func @dead_affine_region_op() {
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() : memref<15xi1>
+  %true = arith.constant true
+  affine.store %true, %alloc[%c1] : memref<15xi1>
+  // Redundant with the identical store above; the CHECKs expect only one.
+  affine.store %true, %alloc[%c1] : memref<15xi1>
+  // This affine.if is dead: with d0 = 1, `d0 * -8 >= 0` never holds.
+  affine.if affine_set<(d0, d1, d2, d3) : ((d0 + 1) mod 8 >= 0, d0 * -8 >= 0)>(%c1, %c1, %c1, %c1){
+    // The store must not be forwarded to this unreachable load.
+    affine.load %alloc[%c1] : memref<15xi1>
+  }
+  // CHECK-NEXT: arith.constant
+  // CHECK-NEXT: memref.alloc
+  // CHECK-NEXT: arith.constant
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: affine.if
+  // CHECK-NEXT: affine.load
+  return
+}