diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -252,6 +252,7 @@ Value blockFirstIndex = b.create(blockIndex, blockSize); // The last one-dimensional index in the block defined by the `blockIndex`: + // blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1 Value blockEnd0 = b.create(blockFirstIndex, blockSize); Value blockEnd1 = b.create(blockEnd0, tripCount); Value blockLastIndex = b.create(blockEnd1, c1); @@ -279,7 +280,7 @@ // iteration coordinate using parallel operation bounds and step: // // computeBlockInductionVars[loopIdx] = - // lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopDdx] + // lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx] SmallVector computeBlockInductionVars(op.getNumLoops()); // We need to know if we are in the first or last iteration of the @@ -329,7 +330,7 @@ // Keep building loop nest. if (loopIdx < op.getNumLoops() - 1) { - // Select nested loop lower/upper bounds depending on out position in + // Select nested loop lower/upper bounds depending on our position in // the multi-dimensional iteration space. auto lb = nb.create(isBlockFirstCoord[loopIdx], blockFirstCoord[loopIdx + 1], c0); diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -68,11 +68,17 @@ %A = memref.alloc() : memref<9xf32> %U = memref.cast %A : memref<9xf32> to memref<*xf32> + scf.parallel (%i) = (%lb) to (%ub) step (%c1) { + memref.store %c0, %A[%i] : memref<9xf32> + } + // 1. 
%i = (0) to (9) step (1)
   scf.parallel (%i) = (%lb) to (%ub) step (%c1) {
     %0 = arith.index_cast %i : index to i32
     %1 = arith.sitofp %0 : i32 to f32
-    memref.store %1, %A[%i] : memref<9xf32>
+    %2 = memref.load %A[%i] : memref<9xf32>
+    %3 = arith.addf %1, %2 : f32
+    memref.store %3, %A[%i] : memref<9xf32>
   }
   // CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8]
   call @print_memref_f32(%U): (memref<*xf32>) -> ()
@@ -85,7 +91,9 @@
   scf.parallel (%i) = (%lb) to (%ub) step (%c2) {
     %0 = arith.index_cast %i : index to i32
     %1 = arith.sitofp %0 : i32 to f32
-    memref.store %1, %A[%i] : memref<9xf32>
+    %2 = memref.load %A[%i] : memref<9xf32>
+    %3 = arith.addf %1, %2 : f32
+    memref.store %3, %A[%i] : memref<9xf32>
   }
   // CHECK: [0, 0, 2, 0, 4, 0, 6, 0, 8]
   call @print_memref_f32(%U): (memref<*xf32>) -> ()
@@ -102,7 +110,9 @@
     %1 = arith.sitofp %0 : i32 to f32
     %2 = arith.constant 20 : index
     %3 = arith.addi %i, %2 : index
-    memref.store %1, %A[%3] : memref<9xf32>
+    %4 = memref.load %A[%3] : memref<9xf32>
+    %5 = arith.addf %1, %4 : f32
+    memref.store %5, %A[%3] : memref<9xf32>
   }
   // CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
   call @print_memref_f32(%U): (memref<*xf32>) -> ()
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -7,45 +7,13 @@
 // RUN:   -convert-scf-to-std \
 // RUN:   -convert-memref-to-llvm \
 // RUN:   -convert-std-to-llvm \
 // RUN:   -reconcile-unrealized-casts \
 // RUN: | mlir-cpu-runner \
 // RUN:     -e entry -entry-point-result=void -O0 \
 // RUN:     -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN:     -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
 // RUN: | FileCheck %s --dump-input=always
 
-// RUN: mlir-opt %s 
-async-parallel-for \ -// RUN: -async-to-async-runtime \ -// RUN: -async-runtime-policy-based-ref-counting \ -// RUN: -arith-expand \ -// RUN: -convert-async-to-llvm \ -// RUN: -convert-scf-to-std \ -// RUN: -convert-memref-to-llvm \ -// RUN: -convert-std-to-llvm \ -// RUN: -reconcile-unrealized-casts \ -// RUN: | mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void -O0 \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ -// RUN: | FileCheck %s --dump-input=always - -// RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ -// RUN: num-workers=20 \ -// RUN: min-task-size=1" \ -// RUN: -async-to-async-runtime \ -// RUN: -async-runtime-ref-counting \ -// RUN: -async-runtime-ref-counting-opt \ -// RUN: -arith-expand \ -// RUN: -convert-async-to-llvm \ -// RUN: -convert-scf-to-std \ -// RUN: -convert-memref-to-llvm \ -// RUN: -convert-std-to-llvm \ -// RUN: -reconcile-unrealized-casts \ -// RUN: | mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void -O0 \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\ -// RUN: | FileCheck %s --dump-input=always func @entry() { %c0 = arith.constant 0.0 : f32 @@ -59,12 +27,18 @@ %A = memref.alloc() : memref<8x8xf32> %U = memref.cast %A : memref<8x8xf32> to memref<*xf32> + scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) { + memref.store %c0, %A[%i, %j] : memref<8x8xf32> + } + // 1. 
(%i, %j) = (0, 0) to (8, 8) step (1, 1)
   scf.parallel (%i, %j) = (%lb, %lb) to (%ub, %ub) step (%c1, %c1) {
     %0 = arith.muli %i, %c8 : index
     %1 = arith.addi %j, %0 : index
     %2 = arith.index_cast %1 : index to i32
     %3 = arith.sitofp %2 : i32 to f32
+    %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+    %5 = arith.addf %3, %4 : f32
-    memref.store %3, %A[%i, %j] : memref<8x8xf32>
+    memref.store %5, %A[%i, %j] : memref<8x8xf32>
   }
 
@@ -88,7 +62,9 @@
       %1 = arith.addi %j, %0 : index
       %2 = arith.index_cast %1 : index to i32
       %3 = arith.sitofp %2 : i32 to f32
-      memref.store %3, %A[%i, %j] : memref<8x8xf32>
+      %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+      %5 = arith.addf %3, %4 : f32
+      memref.store %5, %A[%i, %j] : memref<8x8xf32>
     }
 
 // CHECK: [0, 1, 2, 3, 4, 5, 6, 7]
@@ -111,7 +87,9 @@
       %1 = arith.addi %j, %0 : index
       %2 = arith.index_cast %1 : index to i32
      %3 = arith.sitofp %2 : i32 to f32
-      memref.store %3, %A[%i, %j] : memref<8x8xf32>
+      %4 = memref.load %A[%i, %j] : memref<8x8xf32>
+      %5 = arith.addf %3, %4 : f32
+      memref.store %5, %A[%i, %j] : memref<8x8xf32>
     }
 
 // CHECK: [0, 0, 2, 0, 4, 0, 6, 0]