diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -252,12 +252,9 @@ Value blockFirstIndex = b.create(blockIndex, blockSize); // The last one-dimensional index in the block defined by the `blockIndex`: - // blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1 Value blockEnd0 = b.create(blockFirstIndex, blockSize); - Value blockEnd1 = - b.create(arith::CmpIPredicate::sge, blockEnd0, tripCount); - Value blockEnd2 = b.create(blockEnd1, tripCount, blockEnd0); - Value blockLastIndex = b.create(blockEnd2, c1); + Value blockEnd1 = b.create(blockEnd0, tripCount); + Value blockLastIndex = b.create(blockEnd1, c1); // Convert one-dimensional indices to multi-dimensional coordinates. auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts); @@ -696,17 +693,9 @@ // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), // ceil_div(minTaskSize, bodySize))) - Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = - b.create(arith::CmpIPredicate::sge, bs0, minTaskSizeCst); - Value bs2 = b.create(bs1, bs0, minTaskSizeCst); - Value bs3 = - b.create(arith::CmpIPredicate::sle, tripCount, bs2); - Value blockSize0 = b.create(bs3, tripCount, bs2); - Value blockCount0 = b.create(tripCount, blockSize0); - - // Compute balanced block size for the estimated block count. - Value blockSize = b.create(tripCount, blockCount0); + Value bs0 = b.create(tripCount, maxComputeBlocks); + Value bs1 = b.create(bs0, minTaskSizeCst); + Value blockSize = b.create(tripCount, bs1); Value blockCount = b.create(tripCount, blockSize); // Create a parallel compute function that takes a block id and computes the diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -2,6 +2,7 @@ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ // RUN: -async-runtime-ref-counting-opt \ +// RUN: -arith-expand \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \ @@ -16,6 +17,7 @@ // RUN: mlir-opt %s -async-parallel-for \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-policy-based-ref-counting \ +// RUN: -arith-expand \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \ @@ -33,6 +35,7 @@ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ // RUN: -async-runtime-ref-counting-opt \ +// RUN: -arith-expand \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \