Index: mlir/lib/Transforms/MemRefDataFlowOpt.cpp
===================================================================
--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements a pass to forward memref stores to loads, thereby
-// potentially getting rid of intermediate memref's entirely.
+// potentially getting rid of intermediate memref's entirely. It also removes
+// redundant loads.
 // TODO: In the future, similar techniques could be used to eliminate
 // dead memref store's and perform more complex forwarding when support for
 // SSA scalars live out of 'affine.for'/'affine.if' statements is available.
@@ -20,6 +21,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include <algorithm>
@@ -29,21 +31,24 @@
 using namespace mlir;
 
 namespace {
-// The store to load forwarding relies on three conditions:
+// The store to load forwarding and load CSE rely on three conditions:
 //
-// 1) they need to have mathematically equivalent affine access functions
-// (checked after full composition of load/store operands); this implies that
-// they access the same single memref element for all iterations of the common
-// surrounding loop,
+// 1) the store/load op and the load op need to have mathematically equivalent
+// affine access functions (checked after full composition of load/store
+// operands); this implies that they access the same single memref element for
+// all iterations of the common surrounding loop,
 //
-// 2) the store op should dominate the load op,
+// 2) the store/load op should dominate the load op,
 //
-// 3) among all op's that satisfy both (1) and (2), the one that postdominates
-// all store op's that have a dependence into the load, is provably the last
-// writer to the particular memref location being loaded at the load op, and its
-// store value can be forwarded to the load. Note that the only dependences
-// that are to be considered are those that are satisfied at the block* of the
-// innermost common surrounding loop of the <store, load> being considered.
+// 3) among all op's that satisfy both (1) and (2), for store to load
+// forwarding, the one that postdominates all store op's that have a dependence
+// into the load is provably the last writer to the particular memref location
+// being loaded at the load op, and its store value can be forwarded to the
+// load; for load CSE, any op that postdominates all store op's that have a
+// dependence into the load can be forwarded, and the first one found is chosen.
+// Note that the only dependences that are to be considered are those that are
+// satisfied at the block* of the innermost common surrounding loop of the
+// <store/load, load> being considered.
 //
 // (* A dependence being satisfied at a block: a dependence that is satisfied by
 // virtue of the destination operation appearing textually / lexically after
@@ -64,11 +69,13 @@
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;
 
-  void forwardStoreToLoad(AffineReadOpInterface loadOp);
+  LogicalResult forwardStoreToLoad(AffineReadOpInterface loadOp);
+  void loadCSE(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
-  // Load op's whose results were replaced by those forwarded from stores.
+  // Load op's whose results were replaced by those forwarded from dominating
+  // stores or loads.
   SmallVector<Operation *, 8> loadOpsToErase;
 
   DominanceInfo *domInfo = nullptr;
@@ -83,9 +90,32 @@
   return std::make_unique<MemRefDataFlowOpt>();
 }
 
+// Returns true if the store may reach the load.
+static bool storeMayReachLoad(Operation *storeOp, Operation *loadOp,
+                              unsigned minSurroundingLoops) {
+  MemRefAccess srcAccess(storeOp);
+  MemRefAccess destAccess(loadOp);
+  FlatAffineConstraints dependenceConstraints;
+  unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
+  unsigned d;
+  // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+  for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
+    DependenceResult result = checkMemrefAccessDependence(
+        srcAccess, destAccess, d, &dependenceConstraints,
+        /*dependenceComponents=*/nullptr);
+    if (hasDependence(result))
+      break;
+  }
+  if (d <= minSurroundingLoops)
+    return false;
+
+  return true;
+}
+
 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
+LogicalResult
+MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
@@ -110,21 +140,7 @@
 
   SmallVector<Operation *, 8> depSrcStores;
   for (auto *storeOp : storeOps) {
-    MemRefAccess srcAccess(storeOp);
-    MemRefAccess destAccess(loadOp);
-    // Find stores that may be reaching the load.
-    FlatAffineConstraints dependenceConstraints;
-    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
-    unsigned d;
-    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
-    for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
-      DependenceResult result = checkMemrefAccessDependence(
-          srcAccess, destAccess, d, &dependenceConstraints,
-          /*dependenceComponents=*/nullptr);
-      if (hasDependence(result))
-        break;
-    }
-    if (d == minSurroundingLoops)
+    if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
       continue;
 
     // Stores that *may* be reaching the load.
@@ -138,6 +154,8 @@
     //     store %A[%M]
     //     load %A[%N]
     // Use the AffineValueMap difference based memref access equality checking.
+    MemRefAccess srcAccess(storeOp);
+    MemRefAccess destAccess(loadOp);
     if (srcAccess != destAccess)
       continue;
 
@@ -165,7 +183,7 @@
     }
   }
   if (!lastWriteStoreOp)
-    return;
+    return failure();
 
   // Perform the actual store to load forwarding.
   Value storeVal =
@@ -175,6 +193,84 @@
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
   loadOpsToErase.push_back(loadOp);
+  return success();
+}
+
+// The load to load forwarding / redundant load elimination is similar to the
+// store to load forwarding.
+// loadA will be replaced with loadB if:
+// 1) loadA and loadB have mathematically equivalent affine access functions.
+// 2) loadB dominates loadA.
+// 3) loadB postdominates all the store op's that have a dependence into loadA.
+void MemRefDataFlowOpt::loadCSE(AffineReadOpInterface loadOp) {
+  // The list of load op candidates for forwarding that satisfy conditions
+  // (1) and (2) above - they will be filtered later when checking (3).
+  SmallVector<Operation *, 8> fwdingCandidates;
+  SmallVector<Operation *, 8> storeOps;
+  unsigned minSurroundingLoops = getNestingDepth(loadOp);
+  MemRefAccess memRefAccess(loadOp);
+  // First pass over the use list to get 1) the minimum number of surrounding
+  // loops common between the load op and a load op candidate, with min taken
+  // across all load op candidates; 2) load op candidates; 3) store ops.
+  // We take min across all load op candidates instead of all load ops to make
+  // sure the later dependence check is performed at loop depths that do matter.
+  for (auto *user : loadOp.getMemRef().getUsers()) {
+    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(user)) {
+      storeOps.push_back(storeOp);
+    } else if (auto aLoadOp = dyn_cast<AffineReadOpInterface>(user)) {
+      MemRefAccess otherMemRefAccess(aLoadOp);
+      // No need to consider load ops that have been replaced in a previous
+      // store to load forwarding or loadCSE. If loadA or storeA can be
+      // forwarded to loadB, then loadA or storeA can be forwarded to loadC iff
+      // loadB can be forwarded to loadC.
+      // If loadB is visited before loadC and replaced with loadA, we do not
+      // put loadB in the candidates list, only loadA. If loadC is visited
+      // before loadB, loadC may be replaced with loadB, which will be replaced
+      // with loadA later.
+      if (aLoadOp != loadOp && !llvm::is_contained(loadOpsToErase, aLoadOp) &&
+          memRefAccess == otherMemRefAccess &&
+          domInfo->dominates(aLoadOp, loadOp)) {
+        fwdingCandidates.push_back(aLoadOp);
+        unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *aLoadOp);
+        minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+      }
+    }
+  }
+
+  // No forwarding candidate.
+  if (fwdingCandidates.empty())
+    return;
+
+  // Store ops that have a dependence into the load.
+  SmallVector<Operation *, 8> depSrcStores;
+
+  for (auto *storeOp : storeOps) {
+    if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
+      continue;
+
+    // Stores that *may* be reaching the load.
+    depSrcStores.push_back(storeOp);
+  }
+
+  // 3. Of all the load op's that meet the above criteria, choose the first
+  // load found that postdominates all 'depSrcStores' (if one exists).
+  Operation *firstLoadOp = nullptr;
+  for (auto *candidate : fwdingCandidates) {
+    if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
+          return postDomInfo->postDominates(candidate, depStore);
+        })) {
+      firstLoadOp = candidate;
+      break;
+    }
+  }
+  if (!firstLoadOp)
+    return;
+
+  // Perform the actual load to load forwarding.
+  Value loadVal = cast<AffineReadOpInterface>(firstLoadOp).getValue();
+  loadOp.getValue().replaceAllUsesWith(loadVal);
+  // Record this to erase later.
+  loadOpsToErase.push_back(loadOp);
 }
 
 void MemRefDataFlowOpt::runOnFunction() {
@@ -191,10 +287,15 @@
   loadOpsToErase.clear();
   memrefsToErase.clear();
 
-  // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
+  // Walk all load's and perform store to load forwarding and loadCSE.
+  f.walk([&](AffineReadOpInterface loadOp) {
+    // Do store to load forwarding first; if it does not succeed, try loadCSE.
+    if (failed(forwardStoreToLoad(loadOp)))
+      loadCSE(loadOp);
+  });
 
-  // Erase all load op's whose results were replaced with store fwd'ed ones.
+  // Erase all load op's whose results were replaced with store or load fwd'ed
+  // ones.
   for (auto *loadOp : loadOpsToErase)
     loadOp->erase();
Index: mlir/test/Transforms/memref-dataflow-opt.mlir
===================================================================
--- mlir/test/Transforms/memref-dataflow-opt.mlir
+++ mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -300,3 +300,202 @@
 // CHECK-NEXT: %[[LDVAL:.*]] = affine.vector_load
 // CHECK-NEXT: affine.vector_store %[[LDVAL]],{{.*}}
 // CHECK-NEXT: }
+
+// CHECK-LABEL: func @simple_three_loads
+func @simple_three_loads(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    // CHECK-NOT: affine.load
+    %v1 = affine.load %in[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    %v3 = affine.load %in[%i0] : memref<10xf32>
+    %v4 = addf %v2, %v3 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_const_index
+func @nested_loads_const_index(%in : memref<10xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: affine.load
+  %v0 = affine.load %in[%c0] : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT: affine.load
+        %v1 = affine.load %in[%c0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads
+func @nested_loads(%N : index, %in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK-NOT: affine.load
+      %v1 = affine.load %in[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_different_memref_accesses_no_cse
+func @nested_loads_different_memref_accesses_no_cse(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      // CHECK: affine.load
+      %v1 = affine.load %in[%i1] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store
+func @load_load_store(%m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    // CHECK-NOT: affine.load
+    %v1 = affine.load %m[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    affine.store %v2, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_2_loops_no_cse
+func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK: affine.load
+      %v1 = affine.load %m[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+      affine.store %v2, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops_no_cse
+func @load_load_store_3_loops_no_cse(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK: affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+      affine.store %cf1, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops
+func @load_load_store_3_loops(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      // CHECK: affine.load
+      %v0 = affine.load %m[%i0] : memref<10xf32>
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT: affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+    affine.store %cf1, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @loads_in_sibling_loops_const_index_no_cse
+func @loads_in_sibling_loops_const_index_no_cse(%m : memref<10xf32>) {
+  %c0 = constant 0 : index
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+    %v1 = addf %v0, %v0 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_affine_apply
+func @load_load_affine_apply(%in : memref<10x10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %t0 = affine.apply affine_map<(d0, d1) -> (d1 + 1)>(%i0, %i1)
+      %t1 = affine.apply affine_map<(d0, d1) -> (d0)>(%i0, %i1)
+      %idx0 = affine.apply affine_map<(d0, d1) -> (d1)> (%t0, %t1)
+      %idx1 = affine.apply affine_map<(d0, d1) -> (d0 - 1)> (%t0, %t1)
+      // CHECK: affine.load
+      %v0 = affine.load %in[%idx0, %idx1] : memref<10x10xf32>
+      // CHECK-NOT: affine.load
+      %v1 = affine.load %in[%i0, %i1] : memref<10x10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_loads
+func @vector_loads(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_store_load_no_cse
+func @vector_load_store_load_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %in[16*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_affine_apply_store_load
+func @vector_load_affine_apply_store_load(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %cf1 = constant 1 : index
+  affine.for %i = 0 to 15 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %idx = affine.apply affine_map<(d0) -> (d0 + 1)> (%i)
+    affine.vector_store %ld0, %in[32*%idx] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
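
// Illustrative note (not part of the patch): a minimal before/after sketch of
// the combined store-to-load forwarding and loadCSE above, using only ops
// already exercised by the tests. The function and value names are made up
// for illustration.
//
//   func @forward_then_cse(%m : memref<10xf32>) {
//     affine.for %i = 0 to 10 {
//       %v0 = affine.load %m[%i] : memref<10xf32>
//       affine.store %v0, %m[%i] : memref<10xf32>
//       %v1 = affine.load %m[%i] : memref<10xf32>
//       %v2 = affine.load %m[%i] : memref<10xf32>
//       %s = addf %v1, %v2 : f32
//     }
//     return
//   }
//
// After the pass, the store dominates and is the last writer for both %v1 and
// %v2, so both loads are forwarded the stored value %v0 and erased; %s becomes
// addf %v0, %v0, and only the first load and the store remain. Since %m is a
// function argument rather than a local alloc, the memref itself is not
// eliminated.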