diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,19 +70,20 @@ namespace { // LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not an IfInst was encountered in the loop nest. +// operations, and whether or not a region-holding op other than 'affine.for' +// and 'affine.if' was encountered in the loop nest. struct LoopNestStateCollector { SmallVector forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasNonForRegion = false; + bool hasNonAffineRegionOp = false; void collect(Operation *opToWalk) { opToWalk->walk([&](Operation *op) { if (isa(op)) forOps.push_back(cast(op)); - else if (op->getNumRegions() != 0) - hasNonForRegion = true; + else if (op->getNumRegions() != 0 && !isa(op)) + hasNonAffineRegionOp = true; else if (isa(op)) loadOpInsts.push_back(op); else if (isa(op)) @@ -744,9 +745,9 @@ // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&op); - // Return false if a non 'affine.for' region was found (not currently - // supported). - if (collector.hasNonForRegion) + // Return false if a region-holding op other than 'affine.for' and + // 'affine.if' was found (not currently supported). 
+ if (collector.hasNonAffineRegionOp) return false; Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -445,8 +445,8 @@ #set0 = affine_set<(d0) : (1 == 0)> -// CHECK-LABEL: func @should_not_fuse_if_inst_at_top_level() { -func @should_not_fuse_if_inst_at_top_level() { +// CHECK-LABEL: func @should_not_fuse_if_op_at_top_level() { +func @should_not_fuse_if_op_at_top_level() { %m = memref.alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -473,8 +473,8 @@ #set0 = affine_set<(d0) : (1 == 0)> -// CHECK-LABEL: func @should_not_fuse_if_inst_in_loop_nest() { -func @should_not_fuse_if_inst_in_loop_nest() { +// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest() { +func @should_not_fuse_if_op_in_loop_nest() { %m = memref.alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 %c4 = constant 4 : index @@ -488,7 +488,7 @@ %v0 = affine.load %m[%i1] : memref<10xf32> } - // IfOp in ForInst should prevent fusion. + // IfOp in ForOp should prevent fusion. 
// CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } @@ -502,6 +502,83 @@ // ----- +#set = affine_set<(d0) : (d0 - 1 >= 0)> + +// CHECK-LABEL: func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> { +func @should_fuse_if_op_in_loop_nest_not_sandwiched() -> memref<10xf32> { + %a = memref.alloc() : memref<10xf32> + %b = memref.alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %a[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %a[%i1] : memref<10xf32> + affine.store %v0, %b[%i1] : memref<10xf32> + } + affine.for %i2 = 0 to 10 { + affine.if #set(%i2) { + %v0 = affine.load %b[%i2] : memref<10xf32> + } + } + + // IfOp in ForOp should not prevent fusion if it is not located between the + // source and dest ForOp ops. + + // CHECK: affine.for + // CHECK-NEXT: affine.store + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.if + // CHECK-NEXT: affine.load + // CHECK-NOT: affine.for + // CHECK: return + + return %a : memref<10xf32> +} + +// ----- + +#set = affine_set<(d0) : (d0 - 1 >= 0)> + +// CHECK-LABEL: func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> { +func @should_not_fuse_if_op_in_loop_nest_between_src_and_dest() -> memref<10xf32> { + %a = memref.alloc() : memref<10xf32> + %b = memref.alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %a[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + affine.if #set(%i1) { + affine.store %cf7, %a[%i1] : memref<10xf32> + } + } + affine.for %i3 = 0 to 10 { + %v0 = affine.load %a[%i3] : memref<10xf32> + affine.store %v0, %b[%i3] : memref<10xf32> + } + return %b : memref<10xf32> + + // IfOp in ForOp which modifies the memref should prevent fusion if it is + // located between the source and dest ForOp ops. 
+ + // CHECK: affine.for + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.if + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.store + // CHECK: return +} + +// ----- + // CHECK-LABEL: func @permute_and_fuse() { func @permute_and_fuse() { %m = memref.alloc() : memref<10x20x30xf32>