diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -746,8 +746,21 @@
       collector.collect(&op);
       // Return false if a non 'affine.for' region was found (not currently
       // supported).
-      if (collector.hasNonForRegion)
-        return false;
+      if (collector.hasNonForRegion) {
+        // 'affine.if' inside 'affine.for' should not globally block other
+        // fusions: 'canFuseLoops', which is called later, checks 'affine.if'
+        // when deciding whether a given pair of 'affine.for' src and dest ops
+        // can be fused. Any other region-holding op is still unsupported, even
+        // when it appears alongside an 'affine.if', so keep bailing out then.
+        auto walkResult = forOp->walk([&](Operation *innerOp) {
+          if (innerOp->getNumRegions() != 0 &&
+              !isa<AffineForOp, AffineIfOp>(innerOp))
+            return WalkResult::interrupt();
+          return WalkResult::advance();
+        });
+        if (walkResult.wasInterrupted())
+          return false;
+      }
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -502,6 +502,106 @@
 
 // -----
 
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_fuse_if_inst_in_loop_nest_not_sandwiched() -> memref<10x10xf32> {
+func @should_fuse_if_inst_in_loop_nest_not_sandwiched() -> memref<10x10xf32> {
+  %a = memref.alloc() : memref<10x10xf32>
+  %b = memref.alloc() : memref<10x10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      affine.store %cf7, %a[%i0, %i1] : memref<10x10xf32>
+    }
+  }
+  affine.for %i2 = 0 to 10 {
+    affine.for %i3 = 0 to 10 {
+      %v0 = affine.load %a[%i3, %i2] : memref<10x10xf32>
+      affine.store %v0, %b[%i2, %i3] : memref<10x10xf32>
+    }
+  }
+  affine.for %i4 = 0 to 10 {
+    affine.for %i5 = 0 to 10 {
+      affine.if #set(%i5) {} else {}
+    }
+  }
+
+  // IfOp in ForInst should not prevent fusion if it is not in between the
+  // source and dest ForInst ops.
+
+  // CHECK:      affine.for %{{.*}} = 0 to 10 {
+  // CHECK:        affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:     %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:     affine.if #set(%{{.*}}) {
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NOT:  affine.for
+  return %a : memref<10x10xf32>
+}
+
+// -----
+
+#set = affine_set<(d0) : (d0 - 1 >= 0)>
+
+// CHECK-LABEL: func @should_not_fuse_if_inst_in_loop_nest_between_src_and_dest() -> memref<10x10xf32> {
+func @should_not_fuse_if_inst_in_loop_nest_between_src_and_dest() -> memref<10x10xf32> {
+  %a = memref.alloc() : memref<10x10xf32>
+  %b = memref.alloc() : memref<10x10xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      affine.store %cf7, %a[%i0, %i1] : memref<10x10xf32>
+    }
+  }
+  affine.for %i2 = 0 to 10 {
+    affine.for %i3 = 0 to 10 {
+      affine.if #set(%i3) {
+        affine.store %cf7, %a[%i2, %i3] : memref<10x10xf32>
+      }
+    }
+  }
+  affine.for %i4 = 0 to 10 {
+    affine.for %i5 = 0 to 10 {
+      %v0 = affine.load %a[%i5, %i4] : memref<10x10xf32>
+      affine.store %v0, %b[%i5, %i4] : memref<10x10xf32>
+    }
+  }
+  return %b : memref<10x10xf32>
+
+  // IfOp in ForInst which modifies the memref should prevent fusion if it is in
+  // between the source and dest ForInst ops.
+
+  // CHECK:      affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:     affine.if #set(%{{.*}}) {
+  // CHECK-NEXT:       affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+  // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:     %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+}
+
+// -----
+
 // CHECK-LABEL: func @permute_and_fuse() {
 func @permute_and_fuse() {
   %m = memref.alloc() : memref<10x20x30xf32>