diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2535,3 +2535,73 @@
 // CHECK: mulf
 // CHECK-NOT: affine.for
 // CHECK: divf
+
+// -----
+
+// The middle loop accesses %in0 and %in1 through non-affine `load`/`store`
+// ops, while the loops around it use affine accesses on the same memrefs.
+// The expected output keeps all three loops separate (no fusion).
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+                                             %in1 : memref<32xf32>) {
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = load %in0[%d] : memref<32xf32>
+    %rhs = load %in1[%d] : memref<32xf32>
+    %sub = subf %lhs, %rhs : f32
+    store %sub, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %mul = mulf %lhs, %rhs : f32
+    affine.store %mul, %in0[%d] : memref<32xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: subf
+// CHECK: affine.for
+// CHECK: mulf
+
+// -----
+
+// %sum is written by a non-affine `store` inside the first loop and read by a
+// `load` at function top level, outside any loop; the loaded value is then used
+// inside the second loop. The expected output keeps both loops separate.
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+                                                       %in1 : memref<32xf32>) {
+  %sum = alloc() : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    store %add, %sum[] : memref<f32>
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  %load_sum = load %sum[] : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %mul = mulf %lhs, %rhs : f32
+    %sub = subf %mul, %load_sum : f32
+    affine.store %sub, %in0[%d] : memref<32xf32>
+  }
+  dealloc %sum : memref<f32>
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: mulf
+// CHECK: subf