diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -862,9 +862,10 @@ /// other operations will overwrite the memory loaded between the given load /// and store. If such a value exists, the replaced `loadOp` will be added to /// `loadOpsToErase` and its memref will be added to `memrefsToErase`. -static LogicalResult forwardStoreToLoad( - AffineReadOpInterface loadOp, SmallVectorImpl &loadOpsToErase, - SmallPtrSetImpl &memrefsToErase, DominanceInfo &domInfo) { +static void forwardStoreToLoad(AffineReadOpInterface loadOp, + SmallVectorImpl &loadOpsToErase, + SmallPtrSetImpl &memrefsToErase, + DominanceInfo &domInfo) { // The store op candidate for forwarding that satisfies all conditions // to replace the load, if any. @@ -911,7 +912,7 @@ } if (!lastWriteStoreOp) - return failure(); + return; // Perform the actual store to load forwarding. Value storeVal = @@ -919,13 +920,12 @@ // Check if 2 values have the same shape. This is needed for affine vector // loads and stores. if (storeVal.getType() != loadOp.getValue().getType()) - return failure(); + return; loadOp.getValue().replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); // Record this to erase later. loadOpsToErase.push_back(loadOp); - return success(); } template bool @@ -1000,7 +1000,7 @@ continue; } - // 2. The store has to dominate the load op to be candidate. + // 2. loadB dominates loadA. if (!domInfo.dominates(loadB, loadA)) continue; @@ -1073,13 +1073,8 @@ // Walk all load's and perform store to load forwarding. f.walk([&](AffineReadOpInterface loadOp) { - if (failed( - forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) { - loadCSE(loadOp, opsToErase, domInfo); - } + forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo); }); - - // Erase all load op's whose results were replaced with store fwd'ed ones. for (auto *op : opsToErase) op->erase(); opsToErase.clear(); @@ -1088,9 +1083,9 @@ f.walk([&](AffineWriteOpInterface storeOp) { findUnusedStore(storeOp, opsToErase, postDomInfo); }); - // Erase all store op's which don't impact the program for (auto *op : opsToErase) op->erase(); + opsToErase.clear(); // Check if the store fwd'ed memrefs are now left with only stores and // deallocs and can thus be completely deleted. Note: the canonicalize pass @@ -1114,6 +1109,15 @@ user->erase(); defOp->erase(); } + + // To eliminate as many loads as possible, run load CSE after eliminating + // stores. Otherwise, some stores are wrongly seen as having an intervening + // effect. + f.walk([&](AffineReadOpInterface loadOp) { + loadCSE(loadOp, opsToErase, domInfo); + }); + for (auto *op : opsToErase) + op->erase(); } // Perform the replacement in `op`. diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir --- a/mlir/test/Dialect/Affine/scalrep.mlir +++ b/mlir/test/Dialect/Affine/scalrep.mlir @@ -280,6 +280,31 @@ return } +// CHECK-LABEL: func @elim_load_after_store +func.func @elim_load_after_store(%arg0: memref<100xf32>, %arg1: memref<100xf32>) { + %alloc = memref.alloc() : memref<1xf32> + %alloc_0 = memref.alloc() : memref<1xf32> + // CHECK: affine.for + affine.for %arg2 = 0 to 100 { + // CHECK: affine.load + %0 = affine.load %arg0[%arg2] : memref<100xf32> + %1 = affine.load %arg0[%arg2] : memref<100xf32> + // CHECK: arith.addf + %2 = arith.addf %0, %1 : f32 + affine.store %2, %alloc_0[0] : memref<1xf32> + %3 = affine.load %arg0[%arg2] : memref<100xf32> + %4 = affine.load %alloc_0[0] : memref<1xf32> + // CHECK-NEXT: arith.addf + %5 = arith.addf %3, %4 : f32 + affine.store %5, %alloc[0] : memref<1xf32> + %6 = affine.load %arg0[%arg2] : memref<100xf32> + %7 = affine.load %alloc[0] : memref<1xf32> + %8 = arith.addf %6, %7 : f32 + affine.store %8, %arg1[%arg2] : memref<100xf32> + } + return +} + // The test checks for value forwarding from vector stores to vector loads. // The value loaded from %in can directly be stored to %out by eliminating // store and load from %tmp.