diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -179,8 +179,8 @@ // which contain accesses to the same memref 'value'. If the value is a // non-memref value, then the dependence is between a graph node which // defines an SSA value and another graph node which uses the SSA value - // (e.g. a constant operation defining a value which is used inside a loop - // nest). + // (e.g. a constant or load operation defining a value which is used inside + // a loop nest). Value value; }; @@ -702,6 +702,15 @@ } } +/// Adds to 'definingNodes' the ids of nodes defining SSA values used by 'id'. +void gatherDefiningNodes(unsigned id, MemRefDependenceGraph *mdg, + DenseSet<unsigned> &definingNodes) { + for (MemRefDependenceGraph::Edge edge : mdg->inEdges[id]) + // Defining node is the one on an edge with non-memref value. + if (!edge.value.getType().isa<MemRefType>()) + definingNodes.insert(edge.id); +} + } // end anonymous namespace // Initializes the data dependence graph by walking operations in 'f'. @@ -783,10 +792,10 @@ } // Add dependence edges between nodes which produce SSA values and their - // users. + // users. Load ops can be considered as the ones producing SSA values.
for (auto &idAndNode : nodes) { const Node &node = idAndNode.second; - if (!node.loads.empty() || !node.stores.empty()) + if (!node.stores.empty()) continue; auto *opInst = node.op; for (auto value : opInst->getResults()) { @@ -955,7 +964,7 @@ /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and /// 'dstId'), if there is any non-affine operation accessing 'memref', return -/// false. Otherwise, return true. +/// true. Otherwise, return false. static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, Value memref, MemRefDependenceGraph *mdg) { @@ -1450,6 +1459,21 @@ continue; } + // Gather all defining nodes of 'dstNode'. + DenseSet<unsigned> definingNodes; + gatherDefiningNodes(dstNode->id, mdg, definingNodes); + + // Skip if there is any defining node of 'dstNode' that has dependency + // on 'srcNode'. + if (llvm::any_of(definingNodes, [&](unsigned id) { + return mdg->hasDependencePath(srcId, id); + })) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: defining ops in between the loops and " + "having dependency on the source loop.\n"); + continue; + } + // Compute an operation list insertion point for the fused loop // nest which preserves dependences. Operation *fusedLoopInsPoint = diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2835,3 +2835,100 @@ return } + +// ----- + +// CHECK-LABEL: func @should_fuse_defining_node_has_no_dependence_on_source_node +func @should_fuse_defining_node_has_no_dependence_on_source_node( + %a : memref<10xf32>, %b : memref<f32>) -> () { + affine.for %i0 = 0 to 10 { + %0 = affine.load %b[] : memref<f32> + affine.store %0, %a[%i0] : memref<10xf32> + } + %0 = affine.load %b[] : memref<f32> + affine.for %i1 = 0 to 10 { + %1 = affine.load %a[%i1] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // Loops '%i0' and '%i1' should be fused even though there is a defining + // node between the loops.
It is because the node has no dependence on '%i0'. + // CHECK: affine.load %{{.*}}[] : memref<f32> + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[] : memref<f32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + // CHECK-NOT: affine.for + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_defining_node_has_dependence_on_source_loop +func @should_not_fuse_defining_node_has_dependence_on_source_loop( + %a : memref<10xf32>, %b : memref<f32>) -> () { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %b[] : memref<f32> + affine.store %cst, %a[%i0] : memref<10xf32> + } + %0 = affine.load %b[] : memref<f32> + affine.for %i1 = 0 to 10 { + %1 = affine.load %a[%i1] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // Loops '%i0' and '%i1' should not be fused because the defining node + // of '%0' used in '%i1' has dependence on loop '%i0'.
+ // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref<f32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[] : memref<f32> + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + return +} + +// ----- + +// CHECK-LABEL: func @should_not_fuse_defining_node_has_transitive_dependence_on_source_loop +func @should_not_fuse_defining_node_has_transitive_dependence_on_source_loop( + %a : memref<10xf32>, %b : memref<10xf32>, %c : memref<f32>) -> () { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %a[%i0] : memref<10xf32> + affine.store %cst, %b[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %1 = affine.load %b[%i1] : memref<10xf32> + affine.store %1, %c[] : memref<f32> + } + %0 = affine.load %c[] : memref<f32> + affine.for %i2 = 0 to 10 { + %1 = affine.load %a[%i2] : memref<10xf32> + %2 = divf %0, %1 : f32 + } + + // When loops '%i0' and '%i2' are evaluated first, they should not be + // fused. The defining node of '%0' in loop '%i2' has transitive dependence on + // loop '%i0'. After that, loops '%i0' and '%i1' are evaluated, and they will + // be fused as usual. + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref<f32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[] : memref<f32> + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: divf + // CHECK-NEXT: } + // CHECK-NOT: affine.for + return +}