diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1645,6 +1645,10 @@ // Add edge from 'newMemRef' node to dstNode. mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } + // One or more entries for 'newMemRef' alloc op are inserted into + // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to + // reallocate, re-fetch dstNode by its id so it is not left stale. + dstNode = mdg->getNode(dstId); } // Collect dst loop stats after memref privatization transformation. diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -3115,3 +3115,186 @@ // CHECK-NEXT: affine.load // CHECK-NEXT: mulf // CHECK-NEXT: affine.store + +// ----- + +// CHECK-LABEL: func @fuse_large_number_of_loops +func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x10xf32, 1>, %arg2: memref<20x10xf32, 1>, %arg3: memref<20x10xf32, 1>, %arg4: memref<20x10xf32, 1>, %arg5: memref<f32, 1>, %arg6: memref<f32, 1>, %arg7: memref<f32, 1>, %arg8: memref<f32, 1>, %arg9: memref<20x10xf32, 1>, %arg10: memref<20x10xf32, 1>, %arg11: memref<20x10xf32, 1>, %arg12: memref<20x10xf32, 1>) { + %cst = constant 1.000000e+00 : f32 + %0 = memref.alloc() : memref<f32, 1> + affine.store %cst, %0[] : memref<f32, 1> + %1 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg6[] : memref<f32, 1> + affine.store %21, %1[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %2 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %arg3[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %2[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %3 = memref.alloc() : memref<f32, 1> + %4 = affine.load %arg6[] : memref<f32, 1> + %5 = 
affine.load %0[] : memref<f32, 1> + %6 = subf %5, %4 : f32 + affine.store %6, %3[] : memref<f32, 1> + %7 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %3[] : memref<f32, 1> + affine.store %21, %7[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %8 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %7[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %8[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %9 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %9[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %9[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %2[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = addf %22, %21 : f32 + affine.store %23, %arg11[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %10 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %1[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %arg2[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %10[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %8[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %10[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = addf %22, %21 : f32 + affine.store %23, %arg10[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %11 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = 
affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %arg10[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %11[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %12 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %11[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %arg11[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = subf %22, %21 : f32 + affine.store %23, %12[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %13 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg7[] : memref<f32, 1> + affine.store %21, %13[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %14 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg4[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %13[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %14[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %15 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg8[] : memref<f32, 1> + affine.store %21, %15[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %16 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %15[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %12[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = addf %22, %21 : f32 + affine.store %23, %16[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %17 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %16[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = math.sqrt %21 : f32 + affine.store %22, %17[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %18 = memref.alloc() : memref<20x10xf32, 1> + affine.for 
%arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg5[] : memref<f32, 1> + affine.store %21, %18[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %19 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg1[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %18[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = mulf %22, %21 : f32 + affine.store %23, %19[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + %20 = memref.alloc() : memref<20x10xf32, 1> + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %17[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %19[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = divf %22, %21 : f32 + affine.store %23, %20[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %20[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %14[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = addf %22, %21 : f32 + affine.store %23, %arg12[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + affine.for %arg13 = 0 to 20 { + affine.for %arg14 = 0 to 10 { + %21 = affine.load %arg12[%arg13, %arg14] : memref<20x10xf32, 1> + %22 = affine.load %arg0[%arg13, %arg14] : memref<20x10xf32, 1> + %23 = subf %22, %21 : f32 + affine.store %23, %arg9[%arg13, %arg14] : memref<20x10xf32, 1> + } + } + return +} +// CHECK: affine.for +// CHECK: affine.for +// CHECK-NOT: affine.for