diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,19 +70,20 @@ namespace { // LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not an IfInst was encountered in the loop nest. +// operations, and whether or not a region-holding operation other than a +// ForOp or IfOp was encountered in the loop nest. struct LoopNestStateCollector { SmallVector forOps; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasNonForRegion = false; + bool hasNonForRegionExceptIfOp = false; void collect(Operation *opToWalk) { opToWalk->walk([&](Operation *op) { if (isa(op)) forOps.push_back(cast(op)); - else if (op->getNumRegions() != 0) - hasNonForRegion = true; + else if (op->getNumRegions() != 0 && !isa(op)) + hasNonForRegionExceptIfOp = true; else if (isa(op)) loadOpInsts.push_back(op); else if (isa(op)) @@ -745,8 +746,8 @@ LoopNestStateCollector collector; collector.collect(&op); // Return false if a non 'affine.for' region was found (not currently - // supported). - if (collector.hasNonForRegion) + // supported), except 'affine.if' which is handled by 'canFuseLoops'. 
+ if (collector.hasNonForRegionExceptIfOp) return false; Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -502,6 +502,83 @@ // ----- +#set = affine_set<(d0) : (d0 - 1 >= 0)> + +// CHECK-LABEL: func @should_fuse_if_inst_in_loop_nest_not_sandwiched() -> memref<10xf32> { +func @should_fuse_if_inst_in_loop_nest_not_sandwiched() -> memref<10xf32> { + %a = memref.alloc() : memref<10xf32> + %b = memref.alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %a[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %a[%i1] : memref<10xf32> + affine.store %v0, %b[%i1] : memref<10xf32> + } + affine.for %i2 = 0 to 10 { + affine.if #set(%i2) { + %v0 = affine.load %b[%i2] : memref<10xf32> + } + } + + // IfOp in ForInst should not prevent fusion if it is not in between the + // source and dest ForInst ops. 
+ + // CHECK: affine.for + // CHECK-NEXT: affine.store + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.if + // CHECK-NEXT: affine.load + // CHECK-NOT: affine.for + // CHECK: return + + return %a : memref<10xf32> +} + +// ----- + +#set = affine_set<(d0) : (d0 - 1 >= 0)> + +// CHECK-LABEL: func @should_not_fuse_if_inst_in_loop_nest_between_src_and_dest() -> memref<10xf32> { +func @should_not_fuse_if_inst_in_loop_nest_between_src_and_dest() -> memref<10xf32> { + %a = memref.alloc() : memref<10xf32> + %b = memref.alloc() : memref<10xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %a[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + affine.if #set(%i1) { + affine.store %cf7, %a[%i1] : memref<10xf32> + } + } + affine.for %i3 = 0 to 10 { + %v0 = affine.load %a[%i3] : memref<10xf32> + affine.store %v0, %b[%i3] : memref<10xf32> + } + return %b : memref<10xf32> + + // IfOp in ForInst which modifies the memref should prevent fusion if it is in + // between the source and dest ForInst ops. + + // CHECK: affine.for + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.if + // CHECK-NEXT: affine.store + // CHECK: affine.for + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.store + // CHECK: return +} + +// ----- + // CHECK-LABEL: func @permute_and_fuse() { func @permute_and_fuse() { %m = memref.alloc() : memref<10x20x30xf32>