diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -162,11 +162,14 @@
 struct ParallelLoopFusion
     : public LoopParallelLoopFusionBase<ParallelLoopFusion> {
   void runOnOperation() override {
-    for (Region &region : getOperation()->getRegions())
-      naivelyFuseParallelOps(region);
+    // Visit the root op and every op nested under it, so that parallel loops
+    // living in nested regions (e.g. inside another loop.parallel) are fused.
+    getOperation()->walk([&](Operation *child) {
+      for (Region &region : child->getRegions())
+        naivelyFuseParallelOps(region);
+    });
   }
 };
-
 } // namespace
 
 std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
diff --git a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
--- a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
@@ -307,3 +307,53 @@
 // CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
 // CHECK:        loop.parallel
 // CHECK:        loop.parallel
+
+// -----
+
+func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+                    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() : memref<2x2xf32>
+  loop.parallel (%k) = (%c0) to (%c2) step (%c1) {
+    loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      %B_elem = load %B[%i, %j] : memref<2x2xf32>
+      %C_elem = load %C[%i, %j] : memref<2x2xf32>
+      %sum_elem = addf %B_elem, %C_elem : f32
+      store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+      loop.yield
+    }
+    loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+      %A_elem = load %A[%i, %j] : memref<2x2xf32>
+      %product_elem = mulf %sum_elem, %A_elem : f32
+      store %product_elem, %result[%i, %j] : memref<2x2xf32>
+      loop.yield
+    }
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @nested_fuse
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C2:%.*]] = constant 2 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[SUM:%.*]] = alloc()
+// CHECK:      loop.parallel
+// CHECK:        loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:          [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:          store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:          [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:          store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:          loop.yield
+// CHECK:        }
+// CHECK:      }
+// CHECK:      dealloc [[SUM]]