diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -41,14 +41,14 @@
 // 2) the store/load op should dominate the load op,
 //
 // 3) among all op's that satisfy both (1) and (2), for store to load
-// forwarding, the one that postdominates all store op's that have a dependence
-// into the load, is provably the last writer to the particular memref location
-// being loaded at the load op, and its store value can be forwarded to the
-// load; for load CSE, any op that postdominates all store op's that have a
-// dependence into the load can be forwarded and the first one found is chosen.
-// Note that the only dependences that are to be considered are those that are
-// satisfied at the block* of the innermost common surrounding loop of the
-// <store, load> being considered.
+// forwarding, the one that does not dominate any store op that has a
+// dependence into the load, is provably the last writer to the particular
+// memref location being loaded at the load op, and its store value can be
+// forwarded to the load; for load CSE, any op that does not dominate any of
+// the store ops that have a dependence into the load can be forwarded and the
+// first one found is chosen. Note that the only dependences that are to be
+// considered are those that are satisfied at the block* of the innermost
+// common surrounding loop of the <store, load> being considered.
 //
 // (* A dependence being satisfied at a block: a dependence that is satisfied by
 // virtue of the destination operation appearing textually / lexically after
@@ -74,12 +74,11 @@
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;

-  // Load op's whose results were replaced by those forwarded from stores
+  // Load ops whose results were replaced by those forwarded from stores
   // dominating stores or loads..
   SmallVector<Operation *, 8> loadOpsToErase;

   DominanceInfo *domInfo = nullptr;
-  PostDominanceInfo *postDomInfo = nullptr;
 };

 } // end anonymous namespace
@@ -136,7 +135,7 @@

   // Store ops that have a dependence into the load (even if they aren't
   // forwarding candidates). Each forwarding candidate will be checked for a
-  // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
+  // dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
   SmallVector<Operation *, 8> depSrcStores;

   for (auto *storeOp : storeOps) {
@@ -167,16 +166,23 @@
     fwdingCandidates.push_back(storeOp);
   }

-  // 3. Of all the store op's that meet the above criteria, the store that
-  // postdominates all 'depSrcStores' (if one exists) is the unique store
-  // providing the value to the load, i.e., provably the last writer to that
-  // memref loc.
-  // Note: this can be implemented in a cleaner way with postdominator tree
-  // traversals. Consider this for the future if needed.
+  // 3. Of all the store ops that meet the above criteria, the store op
+  // that does not dominate any of the ops in 'depSrcStores' (if such exists)
+  // will not have any of those latter ops on its paths to `loadOp`. It would
+  // thus be the unique store providing the value to the load. This condition is
+  // however conservative, e.g., for:
+  //
+  //   for ... {
+  //     store
+  //     load
+  //     store
+  //     load
+  //   }
+  //
   Operation *lastWriteStoreOp = nullptr;
   for (auto *storeOp : fwdingCandidates) {
     if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
-          return postDomInfo->postDominates(storeOp, depStore);
+          return !domInfo->properlyDominates(storeOp, depStore);
        })) {
      lastWriteStoreOp = storeOp;
      break;
@@ -205,7 +211,8 @@
 // loadA will be be replaced with loadB if:
 // 1) loadA and loadB have mathematically equivalent affine access functions.
 // 2) loadB dominates loadA.
-// 3) loadB postdominates all the store op's that have a dependence into loadA.
+// 3) loadB does not dominate any of the store ops that have a dependence into
+// loadA.
 void MemRefDataFlowOpt::loadCSE(AffineReadOpInterface loadOp) {
   // The list of load op candidates for forwarding that satisfy conditions
   // (1) and (2) above - they will be filtered later when checking (3).
@@ -257,15 +264,15 @@
   }

   // 3. Of all the load op's that meet the above criteria, return the first load
-  // found that postdominates all 'depSrcStores' and has the same shape as the
-  // load to be replaced (if one exists). The shape check is needed for affine
-  // vector loads.
+  // found that does not dominate any of 'depSrcStores' and has the same shape
+  // as the load to be replaced (if one exists). The shape check is needed for
+  // affine vector loads.
   Operation *firstLoadOp = nullptr;
   Value oldVal = loadOp.getValue();

   for (auto *loadOp : fwdingCandidates) {
     if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
-          return postDomInfo->postDominates(loadOp, depStore);
+          return !domInfo->properlyDominates(loadOp, depStore);
        }) &&
        cast<AffineReadOpInterface>(loadOp).getValue().getType() ==
            oldVal.getType()) {
@@ -292,7 +299,6 @@
   }

   domInfo = &getAnalysis<DominanceInfo>();
-  postDomInfo = &getAnalysis<PostDominanceInfo>();

   loadOpsToErase.clear();
   memrefsToErase.clear();
diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir
--- a/mlir/test/Transforms/memref-dataflow-opt.mlir
+++ b/mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -515,6 +515,31 @@
   return
 }

+// CHECK-LABEL: func @reduction_multi_store
+func @reduction_multi_store() -> memref<1xf32> {
+  %A = memref.alloc() : memref<1xf32>
+  %cf0 = constant 0.0 : f32
+  %cf5 = constant 5.0 : f32
+
+  affine.store %cf0, %A[0] : memref<1xf32>
+  affine.for %i = 0 to 100 step 2 {
+    %l = affine.load %A[0] : memref<1xf32>
+    %s = addf %l, %cf5 : f32
+    // Store to load forwarding from this store should happen.
+    affine.store %s, %A[0] : memref<1xf32>
+    %m = affine.load %A[0] : memref<1xf32>
+    "test.foo"(%m) : (f32) -> ()
+  }
+
+// CHECK:       affine.for
+// CHECK:         affine.load
+// CHECK:         affine.store %[[S:.*]],
+// CHECK-NOT:     affine.load
+// CHECK:         "test.foo"(%[[S]])
+
+  return %A : memref<1xf32>
+}
+
 // CHECK-LABEL: func @vector_load_affine_apply_store_load
 func @vector_load_affine_apply_store_load(%in : memref<512xf32>, %out : memref<512xf32>) {
   %cf1 = constant 1: index