diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -22,6 +22,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/Async/IR/AsyncDialect.td" include "mlir/Dialect/Async/IR/AsyncTypes.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -529,4 +530,15 @@ }]; } +def Async_RuntimeNumWorkerThreadsOp : Async_Op<"runtime.num_worker_threads"> { + let summary = "gets the number of threads in the threadpool from the runtime"; + let description = [{ + The `async.runtime.num_worker_threads` operation gets the number of threads + in the threadpool from the runtime. + }]; + + let results = (outs Index:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -25,7 +25,8 @@ Option<"numWorkerThreads", "num-workers", "int32_t", /*default=*/"8", - "The number of available workers to execute async operations.">, + "The number of available workers to execute async operations.
If `-1` " + "the value will be retrieved from the runtime.">, Option<"minTaskSize", "min-task-size", "int32_t", /*default=*/"1000", diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -123,6 +123,9 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume); +// Returns the current number of available worker threads in the threadpool. +extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads(); + //===----------------------------------------------------------------------===// // Small async runtime support library for testing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -58,6 +58,8 @@ "mlirAsyncRuntimeAwaitValueAndExecute"; static constexpr const char *kAwaitAllAndExecute = "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; +static constexpr const char *kGetNumWorkerThreads = + "mlirAsyncRuntimGetNumWorkerThreads"; namespace { /// Async Runtime API function types. @@ -180,6 +182,10 @@ return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); } + static FunctionType getNumWorkerThreadsFunctionType(MLIRContext *ctx) { + return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); + } + // Auxiliary coroutine resume intrinsic wrapper.
static Type resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMVoidType::get(ctx); @@ -225,6 +231,8 @@ AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); + addFuncDecl(kGetNumWorkerThreads, + AsyncAPI::getNumWorkerThreadsFunctionType(ctx)); } //===----------------------------------------------------------------------===// @@ -887,6 +895,30 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Convert async.runtime.num_worker_threads to the corresponding runtime API +// call. +//===----------------------------------------------------------------------===// + +namespace { +class RuntimeNumWorkerThreadsOpLowering + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Replace with a runtime API function call. + rewriter.replaceOpWithNewOp(op, kGetNumWorkerThreads, + rewriter.getIndexType()); + + return success(); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Async reference counting ops lowering (`async.runtime.add_ref` and // `async.runtime.drop_ref` to the corresponding API calls). @@ -993,8 +1025,9 @@ patterns.add(converter, ctx); + RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, + RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter, + ctx); // Lower async.runtime operations that rely on LLVM type converter to convert // from async value payload type to the LLVM type.
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -795,19 +795,56 @@ numUnrollableLoops++; } + Value numWorkerThreadsVal; + if (numWorkerThreads >= 0) + numWorkerThreadsVal = b.create(numWorkerThreads); + else + numWorkerThreadsVal = b.create(); + // With large number of threads the value of creating many compute blocks - // is reduced because the problem typically becomes memory bound. For small - // number of threads it helps with stragglers. - float overshardingFactor = numWorkerThreads <= 4 ? 8.0 - : numWorkerThreads <= 8 ? 4.0 - : numWorkerThreads <= 16 ? 2.0 - : numWorkerThreads <= 32 ? 1.0 - : numWorkerThreads <= 64 ? 0.8 - : 0.6; - - // Do not overload worker threads with too many compute blocks. - Value maxComputeBlocks = b.create( - std::max(1, static_cast(numWorkerThreads * overshardingFactor))); + // is reduced because the problem typically becomes memory bound. For this + // reason we scale the number of workers using an equivalent to the + // following logic: + // float overshardingFactor = numWorkerThreads <= 4 ? 8.0 + // : numWorkerThreads <= 8 ? 4.0 + // : numWorkerThreads <= 16 ? 2.0 + // : numWorkerThreads <= 32 ? 1.0 + // : numWorkerThreads <= 64 ? 0.8 + // : 0.6; + + // Pairs of non-inclusive lower end of the bracket and factor that the + // number of workers needs to be scaled with if it falls in that bucket.
+ const SmallVector> overshardingBrackets = { + {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}}; + const float initialOvershardingFactor = 8.0f; + + auto makeFloat = [](ImplicitLocOpBuilder &ib, float val) -> Value { + return ib.create(llvm::APFloat(val), + ib.getF32Type()); + }; + + Value scalingFactor = makeFloat(b, initialOvershardingFactor); + for (const std::pair &p : overshardingBrackets) { + Value bracketBegin = b.create(p.first); + Value inBracket = b.create( + arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); + Value bracketScalingFactor = makeFloat(b, p.second); + scalingFactor = + b.create(inBracket, bracketScalingFactor, scalingFactor); + } + Value numWorkersIndex = + b.create(numWorkerThreadsVal, b.getI32Type()); + Value numWorkersFloat = + b.create(numWorkersIndex, b.getF32Type()); + Value scaledNumWorkers = + b.create(scalingFactor, numWorkersFloat); + Value scaledNumInt = + b.create(scaledNumWorkers, b.getI32Type()); + Value scaledWorkers = + b.create(scaledNumInt, b.getIndexType()); + + Value maxComputeBlocks = b.create( + b.create(1), scaledWorkers); // Compute parallel block size from the parallel problem size: // blockSize = min(tripCount, diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -438,6 +438,10 @@ } } +extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() { + return getDefaultAsyncRuntime()->getThreadPool().getThreadCount(); +} + //===----------------------------------------------------------------------===// // Small async runtime support library for testing.
//===----------------------------------------------------------------------===// @@ -515,6 +519,8 @@ &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup); exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute", &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute); + exportSymbol("mlirAsyncRuntimGetNumWorkerThreads", + &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads); exportSymbol("mlirAsyncRuntimePrintCurrentThreadId", &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); } diff --git a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -split-input-file -async-parallel-for=num-workers=-1 \ +// RUN: | FileCheck %s + +// CHECK-LABEL: @num_worker_threads( +// CHECK: %[[MEMREF:.*]]: memref +func @num_worker_threads(%arg0: memref) { + + // CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32 + // CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index + // CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32 + // CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index + // CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32 + // CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index + // CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index + // CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32 + // CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index + // CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32 + // CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads : index + // CHECK: %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index + // CHECK: %[[scalingFactor4:.*]] = select %[[inBracket4]],
%[[scalingCst4]], %[[scalingCstInit]] : f32 + // CHECK: %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index + // CHECK: %[[scalingFactor8:.*]] = select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32 + // CHECK: %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index + // CHECK: %[[scalingFactor16:.*]] = select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32 + // CHECK: %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index + // CHECK: %[[scalingFactor32:.*]] = select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32 + // CHECK: %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index + // CHECK: %[[scalingFactor64:.*]] = select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32 + // CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32 + // CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32 + // CHECK: %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32 + // CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32 + // CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index + + %lb = arith.constant 0 : index + %ub = arith.constant 100 : index + %st = arith.constant 1 : index + scf.parallel (%i) = (%lb) to (%ub) step (%st) { + %one = arith.constant 1.0 : f32 + memref.store %one, %arg0[%i] : memref + } + + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1207,6 +1207,7 @@ includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ], @@ -2089,6 +2090,7 @@ ":ControlFlowInterfaces", ":Dialect",
":IR", + ":InferTypeOpInterface", ":SideEffectInterfaces", ":StandardOps", ":Support",