diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
@@ -50,14 +50,13 @@ void AffineParallelize::runOnFunction() {
 
   FuncOp f = getFunction();
 
-  // The walker proceeds in post-order, but we need to process outer loops first
-  // to control the number of outer parallel loops, so push candidate loops to
-  // the front of a deque.
-  std::deque<ParallelizationCandidate> parallelizableLoops;
-  f.walk([&](AffineForOp loop) {
+  // The walker proceeds in pre-order to process the outer loops first
+  // and control the number of outer parallel loops.
+  std::vector<ParallelizationCandidate> parallelizableLoops;
+  f.walk<WalkOrder::PreOrder>([&](AffineForOp loop) {
     SmallVector<LoopReduction> reductions;
     if (isLoopParallel(loop, parallelReductions ? &reductions : nullptr))
       parallelizableLoops.emplace_back(loop, std::move(reductions));
   });
 
   for (const ParallelizationCandidate &candidate : parallelizableLoops) {
diff --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir
--- a/mlir/test/Dialect/Affine/parallelize.mlir
+++ b/mlir/test/Dialect/Affine/parallelize.mlir
@@ -155,6 +155,27 @@
   return
 }
 
+// MAX-NESTED-LABEL: @max_nested_1
+func @max_nested_1(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf32>, %arg2: memref<4096x4096xf32>) {
+  %0 = memref.alloc() : memref<4096x4096xf32>
+  // MAX-NESTED: affine.parallel
+  affine.for %arg3 = 0 to 4096 {
+    // MAX-NESTED-NEXT: affine.for
+    affine.for %arg4 = 0 to 4096 {
+      // MAX-NESTED-NEXT: affine.for
+      affine.for %arg5 = 0 to 4096 {
+        %1 = affine.load %arg0[%arg3, %arg5] : memref<4096x4096xf32>
+        %2 = affine.load %arg1[%arg5, %arg4] : memref<4096x4096xf32>
+        %3 = affine.load %0[%arg3, %arg4] : memref<4096x4096xf32>
+        %4 = mulf %1, %2 : f32
+        %5 = addf %3, %4 : f32
+        affine.store %5, %0[%arg3, %arg4] : memref<4096x4096xf32>
+      }
+    }
+  }
+  return
+}
+
 // CHECK-LABEL: @iter_args
 // REDUCE-LABEL: @iter_args
 func @iter_args(%in: memref<10xf32>) {