diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -596,7 +596,7 @@
   matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     TypeConverter *converter = getTypeConverter();
-    Type resultType = op->getResultTypes()[0];
+    Type resultType = op.getResult().getType();

     rewriter.replaceOpWithNewOp<CallOp>(
         op, kCreateGroup, converter->convertType(resultType), operands);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -90,6 +90,14 @@
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+
+  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                       int32_t targetBlockSize) {
+    this->asyncDispatch = asyncDispatch;
+    this->numWorkerThreads = numWorkerThreads;
+    this->targetBlockSize = targetBlockSize;
+  }
+
   void runOnOperation() override;
 };
@@ -127,7 +135,7 @@
 // Converts one-dimensional iteration index in the [0, tripCount) interval
 // into multidimensional iteration coordinate.
 static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
-                                      const SmallVector<Value> &tripCounts) {
+                                      ArrayRef<BlockArgument> tripCounts) {
   SmallVector<Value> coords(tripCounts.size());
   assert(!tripCounts.empty() && "tripCounts must be not empty");
@@ -184,7 +192,6 @@
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);

   ModuleOp module = op->getParentOfType<ModuleOp>();
-  b.setInsertionPointToStart(&module->getRegion(0).front());

   ParallelComputeFunctionType computeFuncType =
       getParallelComputeFunctionType(op, rewriter);
@@ -204,12 +211,13 @@
   unsigned offset = 0; // argument offset for arguments decoding

-  // Load multiple arguments into values vector.
-  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
-    SmallVector<Value> values(num_arguments);
-    for (unsigned i = 0; i < num_arguments; ++i)
-      values[i] = block->getArgument(offset++);
-    return values;
+  // Returns `numArguments` arguments starting from `offset` and updates offset
+  // by moving forward to the next argument.
+  auto getArguments = [&](unsigned numArguments) -> ArrayRef<BlockArgument> {
+    auto args = block->getArguments();
+    auto slice = args.drop_front(offset).take_front(numArguments);
+    offset += numArguments;
+    return {slice.begin(), slice.end()};
   };

   // Block iteration position defined by the block index and size.
@@ -217,11 +225,11 @@
   Value blockSize = block->getArgument(offset++);

   // Constants used below.
-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);

   // Multi-dimensional parallel iteration space defined by the loop trip counts.
-  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());
+  ArrayRef<BlockArgument> tripCounts = getArguments(op.getNumLoops());

   // Compute a product of trip counts to get the size of the flattened
   // one-dimensional iteration space.
@@ -229,35 +237,34 @@
   for (unsigned i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

-  // Parallel operation lower bound, upper bound and step.
-  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
-  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
-  SmallVector<Value> step = getArguments(op.getNumLoops());
+  // Parallel operation lower bound and step.
+  ArrayRef<BlockArgument> lowerBound = getArguments(op.getNumLoops());
+  offset += op.getNumLoops(); // skip upper bound arguments
+  ArrayRef<BlockArgument> step = getArguments(op.getNumLoops());

   // Remaining arguments are implicit captures of the parallel operation.
-  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);
+  ArrayRef<BlockArgument> captures = getArguments(block->getNumArguments() - offset);

   // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
   //   blockFirstIndex = blockIndex * blockSize
   Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);

   // The last one-dimensional index in the block defined by the `blockIndex`:
-  //   blockLastIndex = max((blockIndex + 1) * blockSize, tripCount) - 1
-  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
-  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
-  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
-  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
-  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);
+  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
+  Value blockEnd0 = b.create<AddIOp>(blockFirstIndex, blockSize);
+  Value blockEnd1 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd0, tripCount);
+  Value blockEnd2 = b.create<SelectOp>(blockEnd1, tripCount, blockEnd0);
+  Value blockLastIndex = b.create<SubIOp>(blockEnd2, c1);

   // Convert one-dimensional indices to multi-dimensional coordinates.
   auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
   auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

-  // Compute compute loops upper bounds from the block last coordinates:
+  // Compute loops upper bounds derived from the block last coordinates:
   //   blockEndCoord[i] = blockLastCoord[i] + 1
   //
   // Block first and last coordinates can be the same along the outer compute
-  // dimension when inner compute dimension containts multple blocks.
+  // dimension when inner compute dimension contains multiple blocks.
   SmallVector<Value> blockEndCoord(op.getNumLoops());
   for (size_t i = 0; i < blockLastCoord.size(); ++i)
     blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);
@@ -312,7 +319,7 @@
       isBlockLastCoord[loopIdx] =
           nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

-      // Check if the previous loop is in its first of last iteration.
+      // Check if the previous loop is in its first or last iteration.
       if (loopIdx > 0) {
         isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
             isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
@@ -380,7 +387,6 @@
   ImplicitLocOpBuilder b(loc, rewriter);

   ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
-  b.setInsertionPointToStart(&module->getRegion(0).front());

   ArrayRef<Type> computeFuncInputTypes =
       computeFunc.func.type().cast<FunctionType>().getInputs();
@@ -408,8 +414,8 @@
   b.setInsertionPointToEnd(block);

   Type indexTy = b.getIndexType();
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
-  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));
+  Value c1 = b.create<ConstantIndexOp>(1);
+  Value c2 = b.create<ConstantIndexOp>(2);

   // Get the async group that will track async dispatch completion.
   Value group = block->getArgument(0);
@@ -439,14 +445,14 @@
   }

   // Setup the async dispatch loop body: recursively call dispatch function
-  // for second the half of the original range and go to the next iteration.
+  // for the second half of the original range and go to the next iteration.
   {
     b.setInsertionPointToEnd(after);
     Value start = after->getArgument(0);
     Value end = after->getArgument(1);
     Value distance = b.create<SubIOp>(end, start);
     Value halfDistance = b.create<SignedDivIOp>(distance, c2);
-    Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);
+    Value midIndex = b.create<AddIOp>(start, halfDistance);

     // Call parallel compute function inside the async.execute region.
     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
@@ -466,7 +472,7 @@
     auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                        executeBodyBuilder);
     b.create<AddToGroupOp>(indexTy, execute.token(), group);
-    b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
+    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
   }

   // After dispatching async operations to process the tail of the block range
@@ -498,8 +504,8 @@
   FuncOp asyncDispatchFunction =
       createAsyncDispatchFunction(parallelComputeFunction, rewriter);

-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);

   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -535,8 +541,8 @@
   FuncOp compute = parallelComputeFunction.func;

-  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
-  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value c1 = b.create<ConstantIndexOp>(1);

   // Create an async.group to wait on all async tokens from the concurrent
   // execution of multiple parallel compute function. First block will be
@@ -617,19 +623,16 @@
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);

-  auto indexTy = b.getIndexType();
-
   // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks = b.create<ConstantOp>(
-      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));
+  Value maxComputeBlocks =
+      b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);

   // Target block size from the pass parameters.
-  Value targetComputeBlockSize =
-      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));
+  Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);

   // Compute parallel block size from the parallel problem size:
   //   blockSize = min(tripCount,
-  //                   max(divup(tripCount, maxComputeBlocks),
+  //                   max(ceil_div(tripCount, maxComputeBlocks),
   //                       targetComputeBlockSize))
   Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
   Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
@@ -653,7 +656,7 @@
                                  blockCount, tripCounts);
   }

-  // Parallel operation was replaces with a block iteration loop.
+  // Parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);

   return success();
@@ -673,3 +676,10 @@
 std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<Pass>
+mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
+                                 int32_t targetBlockSize) {
+  return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
+                                                targetBlockSize);
+}
diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -18,18 +18,33 @@
 // CHECK:         memref.store

 // CHECK-LABEL: func private @async_dispatch_fn
+// CHECK-SAME:  (
 // CHECK-SAME:    %[[GROUP:arg0]]: !async.group,
 // CHECK-SAME:    %[[BLOCK_START:arg1]]: index
 // CHECK-SAME:    %[[BLOCK_END:arg2]]: index
-
-// CHECK:         scf.while (%[[S:.*]] = %[[BLOCK_START]],
-// CHECK-SAME:               %[[E:.*]] = %[[BLOCK_END]])
+// CHECK-SAME:  )
+// CHECK:         %[[C1:.*]] = constant 1 : index
+// CHECK:         %[[C2:.*]] = constant 2 : index
+// CHECK:         scf.while (%[[S0:.*]] = %[[BLOCK_START]],
+// CHECK-SAME:               %[[E0:.*]] = %[[BLOCK_END]])
+// While loop `before` block decides if we need to dispatch more tasks.
+// CHECK:         {
+// CHECK:           %[[DIFF0:.*]] = subi %[[E0]], %[[S0]]
+// CHECK:           %[[COND:.*]] = cmpi sgt, %[[DIFF0]], %[[C1]]
+// CHECK:           scf.condition(%[[COND]])
+// While loop `after` block splits the range in half and submits an async task
+// to process the second half using a call to the same dispatch function.
 // CHECK:         } do {
+// CHECK:         ^bb0(%[[S1:.*]]: index, %[[E1:.*]]: index):
+// CHECK:           %[[DIFF1:.*]] = subi %[[E1]], %[[S1]]
+// CHECK:           %[[HALF:.*]] = divi_signed %[[DIFF1]], %[[C2]]
+// CHECK:           %[[MID:.*]] = addi %[[S1]], %[[HALF]]
 // CHECK:           %[[TOKEN:.*]] = async.execute
 // CHECK:             call @async_dispatch_fn
-// CHECK:          async.add_to_group
+// CHECK:           async.add_to_group
+// CHECK:           scf.yield %[[S1]], %[[MID]]
 // CHECK:         }
-
+// After async dispatch the first block is processed in the caller thread.
 // CHECK:         call @parallel_compute_fn(%[[BLOCK_START]]

 // -----
diff --git a/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
--- a/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
@@ -1,6 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=false \
 // RUN:   | FileCheck %s --dump-input=always

+// The structure of @parallel_compute_fn is checked in the async dispatch test.
+// Here we only check the structure of the sequential dispatch loop.
+
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
   // CHECK: %[[GROUP:.*]] = async.create_group
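
The standalone C++ sketch below is not part of the patch; it restates, on plain
integers, the block decomposition and index delinearization that the generated
parallel compute function expresses as IR (blockFirstIndex/blockLastIndex and
the `delinearize` helper). All names here are illustrative.

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Convert a flattened index into per-loop coordinates (innermost dimension
    // last), using the usual mod/div recurrence that `delinearize` builds in IR.
    static std::vector<int64_t>
    delinearizeIndex(int64_t index, const std::vector<int64_t> &tripCounts) {
      assert(!tripCounts.empty() && "tripCounts must be not empty");
      std::vector<int64_t> coords(tripCounts.size());
      for (std::size_t i = tripCounts.size(); i-- > 0;) {
        coords[i] = index % tripCounts[i];
        index /= tripCounts[i];
      }
      return coords;
    }

    // Mirror of the block bounds computed in the parallel compute function:
    //   blockFirstIndex = blockIndex * blockSize
    //   blockLastIndex  = min(blockFirstIndex + blockSize, tripCount) - 1
    static std::pair<int64_t, int64_t>
    blockRange(int64_t blockIndex, int64_t blockSize, int64_t tripCount) {
      int64_t blockFirstIndex = blockIndex * blockSize;
      int64_t blockEnd = blockFirstIndex + blockSize;
      int64_t blockLastIndex = (blockEnd >= tripCount ? tripCount : blockEnd) - 1;
      return {blockFirstIndex, blockLastIndex};
    }

    // Example: tripCounts = {8, 10} (tripCount = 80), blockIndex = 3,
    // blockSize = 16 gives the flattened range [48, 63], i.e. coordinates
    // (4, 8) through (6, 3) after delinearization.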
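The block size heuristic added in the pass can be read the same way. The sketch
below (again illustrative, not the pass's code) evaluates the formula from the
updated comment on ordinary integers; `maxOversharding` stands for the
pass-internal `kMaxOversharding` constant that bounds how many compute blocks
are created per worker thread.

    #include <algorithm>
    #include <cstdint>

    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, numWorkerThreads * maxOversharding),
    //                       targetBlockSize))
    static int64_t computeBlockSize(int64_t tripCount, int64_t numWorkerThreads,
                                    int64_t targetBlockSize,
                                    int64_t maxOversharding) {
      int64_t maxComputeBlocks = numWorkerThreads * maxOversharding;
      int64_t ceilDiv = (tripCount + maxComputeBlocks - 1) / maxComputeBlocks;
      return std::min(tripCount, std::max(ceilDiv, targetBlockSize));
    }

    // With 8 workers and maxOversharding = 4 (32 blocks at most):
    //   tripCount = 1'000'000, targetBlockSize = 1000 -> blockSize = 31250 (~32 blocks)
    //   tripCount = 10'000,    targetBlockSize = 1000 -> blockSize = 1000  (10 blocks)
    // so small problems are not split into more blocks than the target size allows.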
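Finally, the recursive async dispatch that the pass encodes with scf.while,
async.execute and async.add_to_group corresponds roughly to the fork-join
sketch below. It is not the generated code: the real dispatch function adds
every token to a single async.group that the caller awaits, whereas this
sketch simply waits on the tasks spawned at each recursion level; the work
decomposition over blocks is the same.

    #include <cstdint>
    #include <future>
    #include <vector>

    // Stand-in for the generated @parallel_compute_fn: process one block of the
    // flattened iteration space.
    static void processBlock(int64_t blockIndex) { (void)blockIndex; /* ... */ }

    // Stand-in for the generated @async_dispatch_fn over the block range
    // [blockStart, blockEnd).
    static void dispatchBlocks(int64_t blockStart, int64_t blockEnd) {
      std::vector<std::future<void>> spawned;
      // Keep splitting the range while it holds more than one block, handing the
      // second half to a new task that recurses the same way.
      while (blockEnd - blockStart > 1) {
        int64_t mid = blockStart + (blockEnd - blockStart) / 2;
        spawned.push_back(
            std::async(std::launch::async, dispatchBlocks, mid, blockEnd));
        blockEnd = mid;
      }
      // The first block of the original range is processed in the caller thread.
      processBlock(blockStart);
      for (std::future<void> &f : spawned)
        f.get();
    }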