diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -744,10 +744,6 @@ // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&op); - // Return false if a non 'affine.for' region was found (not currently - // supported). - if (collector.hasNonForRegion) - return false; Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { node.loads.push_back(opInst); @@ -775,9 +771,6 @@ auto memref = cast(op).getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (op.getNumRegions() != 0) { - // Return false if another region is found (not currently supported). - return false; } else if (op.getNumResults() > 0 && !op.use_empty()) { // Create graph node for top-level producer of SSA values, which // could be used by loop nest nodes. diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -445,8 +445,8 @@ #set0 = affine_set<(d0) : (1 == 0)> -// CHECK-LABEL: func @should_not_fuse_if_inst_at_top_level() { -func @should_not_fuse_if_inst_at_top_level() { +// CHECK-LABEL: func @should_fuse_if_inst_at_top_level() -> memref<10xf32> { +func @should_fuse_if_inst_at_top_level() -> memref<10xf32> { %m = memref.alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -459,14 +459,56 @@ %c0 = constant 4 : index affine.if #set0(%c0) { } - // Top-level IfOp should prevent fusion. + // Top-level IfOp should not prevent fusion of the others. // CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } + // CHECK-NOT: affine.for + return %m : memref<10xf32> +} + +// ----- + +#set = affine_set<(d0) : (d0 - 1 >= 0)> + +func @should_fuse_if_inside_affine_for() -> memref<10x10xf32> { + %a = memref.alloc() : memref<10x10xf32> + %b = memref.alloc() : memref<10x10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.store %cf7, %a[%i0, %i1] : memref<10x10xf32> + } + } + affine.for %i2 = 0 to 10 { + affine.for %i3 = 0 to 10 { + %v0 = affine.load %a[%i3, %i2] : memref<10x10xf32> + affine.store %v0, %b[%i2, %i3] : memref<10x10xf32> + } + } + affine.for %i4 = 0 to 10 { + affine.for %i5 = 0 to 10 { + affine.if #set(%i5) {} else {} + } + } + // IfOp inside AffineForOp should not prevent fusion of the others. // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32> + // CHECK-NEXT: } // CHECK-NEXT: } - return + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.if #set(%{{.*}}) { + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NOT: affine.for + return %a : memref<10x10xf32> } // -----