LoopFusion: add dependence-graph nodes for top-level call ops with memref operands

While building the graph of top-level ops, a call op that takes at least one
memref-typed operand now gets its own node. Call ops that return memrefs are
already handled by the preceding "op with used results" branch. The two new
tests check that a call taking a memref between producer and consumer loops
leaves both loops in place, while a call with no memref operands still allows
the loops to be fused into one.

NOTE(review): this patch was recovered from whitespace-collapsed text. The
template arguments (CallOpInterface, MemRefType, MemoryEffectOpInterface),
the indentation, and the position of the blank context line in the test hunk
were reconstructed -- confirm against the upstream tree before applying.

diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -784,6 +784,15 @@
       // could be used by loop nest nodes.
       Node node(nextNodeId++, &op);
       nodes.insert({node.id, node});
+    } else if (isa<CallOpInterface>(op)) {
+      // Create graph node for top-level Call Op that takes any argument of
+      // memref type. Call Op that returns one or more memref type results
+      // is already taken care of, by the previous conditions.
+      if (llvm::any_of(op.getOperandTypes(),
+                       [&](Type t) { return t.isa<MemRefType>(); })) {
+        Node node(nextNodeId++, &op);
+        nodes.insert({node.id, node});
+      }
     } else if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
       // Create graph node for top-level op, which could have a memory write
       // side effect.
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3016,3 +3016,55 @@
 
   return
 }
+
+// -----
+
+func private @some_function(memref<16xf32>)
+func @call_op_prevents_fusion(%arg0: memref<16xf32>){
+  %A = alloc() : memref<16xf32>
+  %cst_1 = constant 1.000000e+00 : f32
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %arg0[%arg1] : memref<16xf32>
+    affine.store %a, %A[%arg1] : memref<16xf32>
+  }
+  call @some_function(%A) : (memref<16xf32>) -> ()
+  %B = alloc() : memref<16xf32>
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %A[%arg1] : memref<16xf32>
+    %b = addf %cst_1, %a : f32
+    affine.store %b, %B[%arg1] : memref<16xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @call_op_prevents_fusion
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      affine.store
+// CHECK:         call
+// CHECK:         affine.for
+// CHECK-NEXT:      affine.load
+// CHECK-NEXT:      addf
+// CHECK-NEXT:      affine.store
+
+// -----
+
+func private @some_function()
+func @call_op_does_not_prevent_fusion(%arg0: memref<16xf32>){
+  %A = alloc() : memref<16xf32>
+  %cst_1 = constant 1.000000e+00 : f32
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %arg0[%arg1] : memref<16xf32>
+    affine.store %a, %A[%arg1] : memref<16xf32>
+  }
+  call @some_function() : () -> ()
+  %B = alloc() : memref<16xf32>
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %A[%arg1] : memref<16xf32>
+    %b = addf %cst_1, %a : f32
+    affine.store %b, %B[%arg1] : memref<16xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @call_op_does_not_prevent_fusion
+// CHECK:         affine.for
+// CHECK-NOT:     affine.for