diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -179,8 +179,8 @@ // which contain accesses to the same memref 'value'. If the value is a // non-memref value, then the dependence is between a graph node which // defines an SSA value and another graph node which uses the SSA value - // (e.g. a constant operation defining a value which is used inside a loop - // nest). + // (e.g. a constant or load operation defining a value which is used inside + // a loop nest). Value value; }; @@ -369,6 +369,14 @@ return outEdgeCount; } + /// Return all defining nodes of a given node. + void gatherDefiningNodes(unsigned id, DenseSet &definingNodes) { + for (MemRefDependenceGraph::Edge edge : inEdges[id]) + // Defining node is the one on an edge with non-memref value. + if (!edge.value.getType().isa()) + definingNodes.insert(edge.id); + } + // Computes and returns an insertion point operation, before which the // the fused loop nest can be inserted while preserving // dependences. Returns nullptr if no such insertion point is found. @@ -376,6 +384,18 @@ if (outEdges.count(srcId) == 0) return getNode(dstId)->op; + // Skip if there is any defining node of 'dstId' that depends on 'srcId'. + DenseSet definingNodes; + gatherDefiningNodes(dstId, definingNodes); + if (llvm::any_of(definingNodes, [&](unsigned id) { + return hasDependencePath(srcId, id); + })) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: a defining op with a user in the dst " + "loop has dependence from the src loop\n"); + return nullptr; + } + // Build set of insts in range (srcId, dstId) which depend on 'srcId'. SmallPtrSet srcDepInsts; for (auto &outEdge : outEdges[srcId]) @@ -783,10 +803,11 @@ } // Add dependence edges between nodes which produce SSA values and their - // users. + // users. Load ops can be considered as the ones producing SSA values. 
for (auto &idAndNode : nodes) { const Node &node = idAndNode.second; - if (!node.loads.empty() || !node.stores.empty()) + // Stores don't define SSA values, skip them. + if (!node.stores.empty()) continue; auto *opInst = node.op; for (auto value : opInst->getResults()) { @@ -955,7 +976,7 @@ /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and /// 'dstId'), if there is any non-affine operation accessing 'memref', return -/// false. Otherwise, return true. +/// true. Otherwise, return false. static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg) { @@ -1388,6 +1409,10 @@ // Skip if 'dstNode' is not a loop nest. if (!isa(dstNode->op)) continue; + // Skip if 'dstNode' is a loop nest returning values. + // TODO: support loop nests that return values. + if (dstNode->op->getNumResults() > 0) + continue; LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); @@ -1418,6 +1443,11 @@ LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId << " for dst loop " << dstId << "\n"); + // Skip if 'srcNode' is a loop nest returning values. + // TODO: support loop nests that return values. + if (isa(srcNode->op) && srcNode->op->getNumResults() > 0) + continue; + DenseSet producerConsumerMemrefs; gatherProducerConsumerMemrefs(srcId, dstId, mdg, producerConsumerMemrefs); @@ -1450,7 +1480,6 @@ continue; } - // Compute an operation list insertion point for the fused loop // nest which preserves dependences. 
Operation *fusedLoopInsPoint = mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2835,3 +2835,143 @@ return } + +// ----- + +// CHECK-LABEL: func @should_fuse_defining_node_has_no_dependence_from_source_node +func @should_fuse_defining_node_has_no_dependence_from_source_node( + %a : memref<10xf32>, %b : memref) -> () { + affine.for %i0 = 0 to 10 { + %0 = affine.load %b[] : memref + affine.store %0, %a[%i0] : memref<10xf32> + } + %0 = affine.load %b[] : memref + affine.for %i1 = 0 to 10 { + %1 = affine.load %a[%i1] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // Loops '%i0' and '%i1' should be fused even though there is a defining + // node between the loops. It is because the node has no dependence from '%i0'. + // CHECK: affine.load %{{.*}}[] : memref + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[] : memref + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + // CHECK-NOT: affine.for + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_defining_node_has_dependence_from_source_loop +func @should_not_fuse_defining_node_has_dependence_from_source_loop( + %a : memref<10xf32>, %b : memref) -> () { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %b[] : memref + affine.store %cst, %a[%i0] : memref<10xf32> + } + %0 = affine.load %b[] : memref + affine.for %i1 = 0 to 10 { + %1 = affine.load %a[%i1] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // Loops '%i0' and '%i1' should not be fused because the defining node + // of '%0' used in '%i1' has dependence from loop '%i0'. 
+ // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[] : memref + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop +func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop( + %a : memref<10xf32>, %b : memref<10xf32>, %c : memref) -> () { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %a[%i0] : memref<10xf32> + affine.store %cst, %b[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %1 = affine.load %b[%i1] : memref<10xf32> + affine.store %1, %c[] : memref + } + %0 = affine.load %c[] : memref + affine.for %i2 = 0 to 10 { + %1 = affine.load %a[%i2] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // When loops '%i0' and '%i2' are evaluated first, they should not be + // fused. The defining node of '%0' in loop '%i2' has transitive dependence + // from loop '%i0'. After that, loops '%i0' and '%i1' are evaluated, and they + // will be fused as usual. + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref + // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[] : memref + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + // CHECK-NOT: affine.for + return +} + +// ----- + +// TODO: fuse loop nests that return values. 
+// CHECK-LABEL: func @should_not_fuse_dest_loop_nest_return_value +func @should_not_fuse_dest_loop_nest_return_value( + %a : memref<10xf32>) -> () { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %a[%i0] : memref<10xf32> + } + %b = affine.for %i1 = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 { + %load_a = affine.load %a[%i1] : memref<10xf32> + affine.yield %load_a: f32 + } + + // CHECK: affine.for %{{.*}} = {{.*}} + // CHECK: {{.*}} = affine.for %{{.*}} = {{.*}} + + return +} + +// ----- + +// TODO: fuse loop nests that return values. +// CHECK-LABEL: func @should_not_fuse_src_loop_nest_return_value +func @should_not_fuse_src_loop_nest_return_value( + %a : memref<10xf32>) -> () { + %cst = constant 1.000000e+00 : f32 + %b = affine.for %i = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 { + %c = addf %b_iter, %b_iter : f32 + affine.store %c, %a[%i] : memref<10xf32> + affine.yield %c: f32 + } + affine.for %i1 = 0 to 10 { + %1 = affine.load %a[%i1] : memref<10xf32> + } + + // CHECK: {{.*}} = affine.for %{{.*}} = {{.*}} + // CHECK: affine.for %{{.*}} = {{.*}} + + return +}