Index: mlir/lib/Transforms/MemRefDataFlowOpt.cpp
===================================================================
--- mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements a pass to forward memref stores to loads, thereby
-// potentially getting rid of intermediate memref's entirely.
+// potentially getting rid of intermediate memref's entirely. It also removes
+// redundant loads.
 // TODO: In the future, similar techniques could be used to eliminate
 // dead memref store's and perform more complex forwarding when support for
 // SSA scalars live out of 'affine.for'/'affine.if' statements is available.
@@ -29,21 +30,24 @@
 using namespace mlir;
 
 namespace {
-// The store to load forwarding relies on three conditions:
+// The store to load forwarding and load CSE rely on three conditions:
 //
-// 1) they need to have mathematically equivalent affine access functions
-// (checked after full composition of load/store operands); this implies that
-// they access the same single memref element for all iterations of the common
-// surrounding loop,
+// 1) store/load and load need to have mathematically equivalent affine access
+// functions (checked after full composition of load/store operands); this
+// implies that they access the same single memref element for all iterations
+// of the common surrounding loop,
 //
-// 2) the store op should dominate the load op,
+// 2) the store/load op should dominate the load op,
 //
-// 3) among all op's that satisfy both (1) and (2), the one that postdominates
-// all store op's that have a dependence into the load, is provably the last
-// writer to the particular memref location being loaded at the load op, and its
-// store value can be forwarded to the load. Note that the only dependences
-// that are to be considered are those that are satisfied at the block* of the
-// innermost common surrounding loop of the <store, load> being considered.
+// 3) among all op's that satisfy both (1) and (2), for store to load
+// forwarding, the one that postdominates all store op's that have a dependence
+// into the load is provably the last writer to the particular memref location
+// being loaded at the load op, and its store value can be forwarded to the
+// load; for load CSE, any op that postdominates all store op's that have a
+// dependence into the load can be forwarded, and the first one found is
+// chosen. Note that the only dependences that are to be considered are those
+// that are satisfied at the block* of the innermost common surrounding loop of
+// the <store/load, load> being considered.
 //
 // (* A dependence being satisfied at a block: a dependence that is satisfied by
 // virtue of the destination operation appearing textually / lexically after
@@ -65,10 +69,12 @@
   void runOnFunction() override;
 
   void forwardStoreToLoad(AffineReadOpInterface loadOp);
+  void loadCSE(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
-  // Load op's whose results were replaced by those forwarded from stores.
+  // Load op's whose results were replaced by those forwarded from dominating
+  // stores or loads.
   SmallVector<Operation *, 8> loadOpsToErase;
 
   DominanceInfo *domInfo = nullptr;
@@ -177,6 +183,83 @@
   loadOpsToErase.push_back(loadOp);
 }
 
+void MemRefDataFlowOpt::loadCSE(AffineReadOpInterface loadOp) {
+  // The list of load op candidates for forwarding that satisfy conditions
+  // (1) and (2) above - they will be filtered later when checking (3).
+  SmallVector<Operation *, 8> fwdingCandidates;
+  SmallVector<Operation *, 8> storeOps;
+  unsigned minSurroundingLoops = getNestingDepth(loadOp);
+  MemRefAccess memRefAccess(loadOp);
+  // First pass over the use list to get 1) the minimum number of surrounding
+  // loops common between the load op and a load op candidate, with min taken
+  // across all load op candidates; 2) load op candidates; 3) store ops.
+  for (auto *user : loadOp.getMemRef().getUsers()) {
+    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(user)) {
+      storeOps.push_back(storeOp);
+    } else if (auto aLoadOp = dyn_cast<AffineReadOpInterface>(user)) {
+      MemRefAccess otherMemRefAccess(aLoadOp);
+      if (aLoadOp != loadOp &&
+          std::find(loadOpsToErase.begin(), loadOpsToErase.end(), aLoadOp) ==
+              loadOpsToErase.end() &&
+          memRefAccess == otherMemRefAccess &&
+          domInfo->dominates(aLoadOp, loadOp)) {
+        fwdingCandidates.push_back(aLoadOp);
+        unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *aLoadOp);
+        minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+      }
+    }
+  }
+
+  // No forwarding candidate.
+  if (fwdingCandidates.empty())
+    return;
+
+  // Store ops that have a dependence into the load.
+  SmallVector<Operation *, 8> depSrcStores;
+
+  for (auto *storeOp : storeOps) {
+    MemRefAccess srcAccess(storeOp);
+    MemRefAccess destAccess(loadOp);
+    // Find stores that may be reaching the load.
+    FlatAffineConstraints dependenceConstraints;
+    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
+    unsigned d;
+    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+    for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, destAccess, d, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (hasDependence(result))
+        break;
+    }
+    if (d == minSurroundingLoops)
+      continue;
+
+    // Stores that *may* be reaching the load.
+    depSrcStores.push_back(storeOp);
+  }
+
+  // 3. Of all the load op's that meet the above criteria, pick the first one
+  // found that postdominates all 'depSrcStores' (if one exists).
+  Operation *firstLoadOp = nullptr;
+  for (auto *candidate : fwdingCandidates) {
+    if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
+          return postDomInfo->postDominates(candidate, depStore);
+        })) {
+      firstLoadOp = candidate;
+      break;
+    }
+  }
+  if (!firstLoadOp)
+    return;
+
+  // Perform the actual load to load forwarding.
+  Value loadVal = cast<AffineReadOpInterface>(firstLoadOp).getValue();
+  loadOp.getValue().replaceAllUsesWith(loadVal);
+  // Record this to erase later.
+  loadOpsToErase.push_back(loadOp);
+}
+
 void MemRefDataFlowOpt::runOnFunction() {
   // Only supports single block functions at the moment.
   FuncOp f = getFunction();
@@ -198,6 +281,14 @@
   for (auto *loadOp : loadOpsToErase)
     loadOp->erase();
 
+  // Walk the remaining load's and perform load to load forwarding.
+  loadOpsToErase.clear();
+  f.walk([&](AffineReadOpInterface loadOp) { loadCSE(loadOp); });
+
+  // Erase all load op's whose results were replaced with prior loads.
+  for (auto *loadOp : loadOpsToErase)
+    loadOp->erase();
+
   // Check if the store fwd'ed memrefs are now left with only stores and can
   // thus be completely deleted. Note: the canonicalize pass should be able
   // to do this as well, but we'll do it here since we collected these anyway.
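The MLIR below is a minimal illustrative sketch of the rewrite that loadCSE performs; it is not part of the patch, and the function name @load_cse_sketch is made up. The second load has the same affine access function as the first and is dominated by it, and the trailing store has no dependence into either load at the relevant loop depths, so uses of %v1 are replaced with %v0 and the second load is erased:

  func @load_cse_sketch(%m : memref<10xf32>) {
    affine.for %i = 0 to 10 {
      %v0 = affine.load %m[%i] : memref<10xf32>
      // Same access as %v0, dominated by it, and no store with a dependence
      // into it: loadCSE rewrites uses of %v1 to %v0 and erases this load.
      %v1 = affine.load %m[%i] : memref<10xf32>
      %v2 = addf %v0, %v1 : f32
      affine.store %v2, %m[%i] : memref<10xf32>
    }
    return
  }

This mirrors the @load_load_store test case added below.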
Index: mlir/test/Transforms/memref-dataflow-opt.mlir
===================================================================
--- mlir/test/Transforms/memref-dataflow-opt.mlir
+++ mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -300,3 +300,202 @@
 // CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
 // CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
 // CHECK-NEXT: }
+
+// CHECK-LABEL: func @simple_three_loads
+func @simple_three_loads(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    // CHECK-NOT: affine.load
+    %v1 = affine.load %in[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    %v3 = affine.load %in[%i0] : memref<10xf32>
+    %v4 = addf %v2, %v3 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_const_index
+func @nested_loads_const_index(%in : memref<10xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: affine.load
+  %v0 = affine.load %in[%c0] : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT: affine.load
+        %v1 = affine.load %in[%c0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads
+func @nested_loads(%N : index, %in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK-NOT: affine.load
+      %v1 = affine.load %in[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_different_memref_accesses_no_cse
+func @nested_loads_different_memref_accesses_no_cse(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      // CHECK: affine.load
+      %v1 = affine.load %in[%i1] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store
+func @load_load_store(%m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    // CHECK-NOT: affine.load
+    %v1 = affine.load %m[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    affine.store %v2, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_2_loops_no_cse
+func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK: affine.load
+      %v1 = affine.load %m[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+      affine.store %v2, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops_no_cse
+func @load_load_store_3_loops_no_cse(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK: affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+      affine.store %cf1, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops
+func @load_load_store_3_loops(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      // CHECK: affine.load
+      %v0 = affine.load %m[%i0] : memref<10xf32>
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT: affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+    affine.store %cf1, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @loads_in_sibling_loops_const_index_no_cse
+func @loads_in_sibling_loops_const_index_no_cse(%m : memref<10xf32>) {
+  %c0 = constant 0 : index
+  affine.for %i0 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    // CHECK: affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+    %v1 = addf %v0, %v0 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_affine_apply
+func @load_load_affine_apply(%in : memref<10x10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %t0 = affine.apply affine_map<(d0, d1) -> (d1 + 1)>(%i0, %i1)
+      %t1 = affine.apply affine_map<(d0, d1) -> (d0)>(%i0, %i1)
+      %idx0 = affine.apply affine_map<(d0, d1) -> (d1)> (%t0, %t1)
+      %idx1 = affine.apply affine_map<(d0, d1) -> (d0 - 1)> (%t0, %t1)
+      // CHECK: affine.load
+      %v0 = affine.load %in[%idx0, %idx1] : memref<10x10xf32>
+      // CHECK-NOT: affine.load
+      %v1 = affine.load %in[%i0, %i1] : memref<10x10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_loads
+func @vector_loads(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_store_load_no_cse
+func @vector_load_store_load_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %in[16*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_affine_apply_store_load
+func @vector_load_affine_apply_store_load(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %cf1 = constant 1 : index
+  affine.for %i = 0 to 15 {
+    // CHECK: affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %idx = affine.apply affine_map<(d0) -> (d0 + 1)> (%i)
+    affine.vector_store %ld0, %in[32*%idx] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT: affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
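Note: the new test cases are driven by the RUN line that already sits at the top of memref-dataflow-opt.mlir, which is outside this hunk. Assuming the pass keeps its existing registration under the -memref-dataflow-opt flag, that line is along the lines of the following (shown here only as an assumed sketch, not asserted from the patch):

  // RUN: mlir-opt %s -memref-dataflow-opt | FileCheck %s

so no test-harness changes are needed; the patch only appends new CHECK-LABEL blocks.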