diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -240,10 +240,9 @@
 }
 
 // Create a parallel compute fuction from the parallel operation.
-static ParallelComputeFunction
-createParallelComputeFunction(scf::ParallelOp op,
-                              ParallelComputeFunctionBounds bounds,
-                              PatternRewriter &rewriter) {
+static ParallelComputeFunction createParallelComputeFunction(
+    scf::ParallelOp op, ParallelComputeFunctionBounds bounds,
+    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard guard(rewriter);
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
@@ -384,17 +383,26 @@
 
       // Keep building loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
-        // Select nested loop lower/upper bounds depending on our position in
-        // the multi-dimensional iteration space.
-        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
-                                      blockFirstCoord[loopIdx + 1], c0);
+        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
+          // For block-aligned loops we always iterate starting from 0 up to
+          // the loop trip counts.
+          nb.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
+                                workLoopBuilder(loopIdx + 1));
+
+        } else {
+          // Select nested loop lower/upper bounds depending on our position in
+          // the multi-dimensional iteration space.
+          auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
+                                        blockFirstCoord[loopIdx + 1], c0);
+
+          auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
+                                        blockEndCoord[loopIdx + 1],
+                                        tripCounts[loopIdx + 1]);
+
+          nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
+                                workLoopBuilder(loopIdx + 1));
+        }
+
-
-        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
-                                      blockEndCoord[loopIdx + 1],
-                                      tripCounts[loopIdx + 1]);
-
-        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
-                              workLoopBuilder(loopIdx + 1));
         nb.create<scf::YieldOp>(loc);
         return;
       }
@@ -731,6 +739,46 @@
   auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
     ImplicitLocOpBuilder nb(loc, nestedBuilder);
 
+    // Collect statically known constants defining the loop nest in the parallel
+    // compute function. LLVM can't always push constants across the non-trivial
+    // async dispatch call graph; by providing these values explicitly we can
+    // build a more efficient loop nest and rely on better constant folding,
+    // loop unrolling and vectorization.
+    ParallelComputeFunctionBounds staticBounds = {
+        integerConstants(tripCounts),
+        integerConstants(op.lowerBound()),
+        integerConstants(op.upperBound()),
+        integerConstants(op.step()),
+    };
+
+    // Find how many inner iteration dimensions are statically known, and
+    // whether their product is smaller than `512`. We align the parallel
+    // compute block size to the product of the statically known dimensions,
+    // so that the inner loops are guaranteed to execute from 0 to their loop
+    // trip counts. This elides dynamic loop boundaries and gives LLVM an
+    // opportunity to unroll the loops. The constant `512` is arbitrary; it
+    // should depend on how many iterations LLVM typically decides to unroll.
+    static constexpr int64_t maxIterations = 512;
+
+    // The number of inner loops with a statically known number of iterations
+    // not exceeding the `maxIterations` value.
+    int numUnrollableLoops = 0;
+
+    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
+
+    SmallVector<int64_t> numIterations(op.getNumLoops());
+    numIterations.back() = getInt(staticBounds.tripCounts.back());
+
+    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
+      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
+      int64_t innerIterations = numIterations[i + 1];
+      numIterations[i] = tripCount * innerIterations;
+
+      // Update the number of inner loops that we can potentially unroll.
+      if (innerIterations > 0 && innerIterations <= maxIterations)
+        numUnrollableLoops++;
+    }
+
     // With large number of threads the value of creating many compute blocks
     // is reduced because the problem typically becomes memory bound. For small
     // number of threads it helps with stragglers.
@@ -755,24 +803,25 @@
     Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
     Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSizeCst);
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
-    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-    // Collect statically known constants defining the loop nest in the parallel
-    // compute function. LLVM can't always push constants across the non-trivial
-    // async dispatch call graph, by providing these values explicitly we can
-    // choose to build more efficient loop nest, and rely on a better constant
-    // folding, loop unrolling and vectorization.
-    ParallelComputeFunctionBounds staticBounds = {
-        integerConstants(tripCounts),
-        integerConstants(op.lowerBound()),
-        integerConstants(op.upperBound()),
-        integerConstants(op.step()),
-    };
+    // Align the block size to be a multiple of the statically known number
+    // of iterations in the inner loops.
+    if (numUnrollableLoops > 0 && minTaskSize >= maxIterations) {
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value bs2 = b.create<arith::MulIOp>(
+          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      blockSize = b.create<arith::MinSIOp>(tripCount, bs2);
+    }
+
+    // Compute the number of parallel compute blocks.
+    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-    // Create a parallel compute function that takes a block id and computes the
-    // parallel operation body for a subset of iteration space.
+    // Create a parallel compute function that takes a block id and computes
+    // the parallel operation body for a subset of iteration space.
     ParallelComputeFunction parallelComputeFunction =
-        createParallelComputeFunction(op, staticBounds, rewriter);
+        createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
+                                      rewriter);
 
     // Dispatch parallel compute function using async recursive work splitting,
     // or by submitting compute task sequentially from a caller thread.
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -65,4 +65,44 @@
 // CHECK:       scf.for %[[I:arg[0-9]+]]
 // CHECK:         %[[TMP:.*]] = arith.muli %[[I]], %[[CSTEP]]
 // CHECK:         %[[IDX:.*]] = arith.addi %[[LB]], %[[TMP]]
-// CHECK:         memref.store
\ No newline at end of file
+// CHECK:         memref.store
+
+// -----
+
+// Check that when the inner loop trip count is statically known, the block
+// size is aligned to it and the inner loop uses static loop bounds.
+
+// CHECK-LABEL: func @sink_constant_step(
+func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
+  %one = arith.constant 1.0 : f32
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  scf.parallel (%i, %j) = (%lb, %c0) to (%ub, %c10) step (%c1, %c1) {
+    memref.store %one, %arg0[%i, %j] : memref<?x10xf32>
+  }
+
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-SAME:    %[[BLOCK_INDEX:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[BLOCK_SIZE:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[TRIP_COUNT0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[TRIP_COUNT1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[LB0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[LB1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[UB0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[UB1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[STEP0:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[STEP1:arg[0-9]+]]: index,
+// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
+// CHECK-SAME:  ) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[C10:.*]] = arith.constant 10 : index
+// CHECK:         scf.for %[[I:arg[0-9]+]]
+// CHECK-NOT:       select
+// CHECK:           scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
\ No newline at end of file
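
For reference, the following standalone scalar sketch (not part of the patch) walks through the block-size arithmetic that the dispatch code above builds as IR: pick a block size from the trip count, clamp it between the minimum task size and the trip count, then round it up to a multiple of the statically known inner-loop iteration product. The helper names and the example numbers are illustrative assumptions, not code from AsyncParallelFor.cpp.

// Illustrative sketch only: mirrors the ceil-div / max / min / round-up
// sequence from the patch with plain integers instead of MLIR Values.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

// ceil(a / b) for positive integers.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// blockSize = min(tripCount, max(ceil_div(tripCount, maxComputeBlocks),
//                                minTaskSize)), rounded up to a multiple of
// the statically known inner-loop iteration product.
static int64_t alignedBlockSize(int64_t tripCount, int64_t maxComputeBlocks,
                                int64_t minTaskSize, int64_t innerIterations) {
  int64_t bs0 = ceilDiv(tripCount, maxComputeBlocks);
  int64_t bs1 = std::max(bs0, minTaskSize);
  int64_t blockSize = std::min(tripCount, bs1);

  // Align the block size so every compute block covers whole inner loops and
  // the compute function can run the inner dimensions from 0 to their static
  // trip counts (no per-block `select` of lower/upper bounds).
  int64_t bs2 = ceilDiv(blockSize, innerIterations) * innerIterations;
  return std::min(tripCount, bs2);
}

int main() {
  // Hypothetical numbers: a 1000x10 iteration space (tripCount = 10000),
  // 32 compute blocks, minTaskSize = 512, inner dimension of 10 iterations.
  int64_t blockSize = alignedBlockSize(/*tripCount=*/10000,
                                       /*maxComputeBlocks=*/32,
                                       /*minTaskSize=*/512,
                                       /*innerIterations=*/10);
  assert(blockSize % 10 == 0); // 512 rounded up to 520.
  std::printf("aligned block size: %lld\n", static_cast<long long>(blockSize));
  return 0;
}

Rounding the block size up to a multiple of the inner iteration count is what lets the generated @parallel_compute_fn run block-aligned inner loops from 0 to their static trip counts, which is exactly what the new sink_constant_step test checks via the static scf.for bounds and the CHECK-NOT on select.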