diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -183,6 +183,23 @@ return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector}; } +// Returns `x` if the integer comparison `predicate(x, y)` holds, and `y` otherwise. +static Value createCmpSelect(ImplicitLocOpBuilder &builder, + arith::CmpIPredicate predicate, Value x, Value y) { + Value selector = builder.create(predicate, x, y); + return builder.create(selector, x, y); +} + +// Returns the signed maximum of `a` and `b`. +static Value createMax(ImplicitLocOpBuilder &builder, Value a, Value b) { + return createCmpSelect(builder, arith::CmpIPredicate::sge, a, b); +} + +// Returns the signed minimum of `a` and `b`. +static Value createMin(ImplicitLocOpBuilder &builder, Value a, Value b) { + return createCmpSelect(builder, arith::CmpIPredicate::sle, a, b); +} + // Create a parallel compute fuction from the parallel operation. static ParallelComputeFunction createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) { @@ -252,12 +269,10 @@ Value blockFirstIndex = b.create(blockIndex, blockSize); // The last one-dimensional index in the block defined by the `blockIndex`: - // blockLastIndex = max(blockFirstIndex + blockSize, tripCount) - 1 + // blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1 Value blockEnd0 = b.create(blockFirstIndex, blockSize); - Value blockEnd1 = - b.create(arith::CmpIPredicate::sge, blockEnd0, tripCount); - Value blockEnd2 = b.create(blockEnd1, tripCount, blockEnd0); - Value blockLastIndex = b.create(blockEnd2, c1); + Value blockEnd1 = createMin(b, blockEnd0, tripCount); + Value blockLastIndex = b.create(blockEnd1, c1); // Convert one-dimensional indices to multi-dimensional coordinates. 
auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts); @@ -696,17 +711,9 @@ // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), // ceil_div(minTaskSize, bodySize))) - Value bs0 = b.create(tripCount, maxComputeBlocks); - Value bs1 = - b.create(arith::CmpIPredicate::sge, bs0, minTaskSizeCst); - Value bs2 = b.create(bs1, bs0, minTaskSizeCst); - Value bs3 = - b.create(arith::CmpIPredicate::sle, tripCount, bs2); - Value blockSize0 = b.create(bs3, tripCount, bs2); - Value blockCount0 = b.create(tripCount, blockSize0); - - // Compute balanced block size for the estimated block count. - Value blockSize = b.create(tripCount, blockCount0); + Value bs0 = b.create(tripCount, maxComputeBlocks); + Value bs1 = createMax(b, bs0, minTaskSizeCst); + Value blockSize = createMin(b, tripCount, bs1); Value blockCount = b.create(tripCount, blockSize); // Create a parallel compute function that takes a block id and computes the