diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -22,6 +22,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/Async/IR/AsyncDialect.td" include "mlir/Dialect/Async/IR/AsyncTypes.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -529,4 +530,27 @@ }]; } +def Async_RuntimeNumWorkerThreadsOp : Async_Op<"runtime.num_worker_threads", [InferTypeOpInterface]> { + let summary = "gets the number of threads in the threadpool from the runtime"; + let description = [{ + The `async.runtime.num_worker_threads` operation gets the number of threads + in the threadpool from the runtime. 
+ }]; + + let results = (outs Index:$result); + let assemblyFormat = "attr-dict"; + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.resize(1); + ::mlir::Builder odsBuilder(context); + inferredReturnTypes[0] = odsBuilder.getIndexType(); + return ::mlir::success(); + } + }]; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -25,7 +25,8 @@ Option<"numWorkerThreads", "num-workers", "int32_t", /*default=*/"8", - "The number of available workers to execute async operations.">, + "The number of available workers to execute async operations. If `-1` " + "the value will be retrieved from the runtime.">, Option<"minTaskSize", "min-task-size", "int32_t", /*default=*/"1000", diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -123,6 +123,8 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume); +// Returns the current number of available worker threads in the threadpool. +extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads(); //===----------------------------------------------------------------------===// // Small async runtime support library for testing. 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -58,6 +58,8 @@
     "mlirAsyncRuntimeAwaitValueAndExecute";
 static constexpr const char *kAwaitAllAndExecute =
     "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
+static constexpr const char *kGetNumWorkerThreads =
+    "mlirAsyncRuntimGetNumWorkerThreads";
 
 namespace {
 /// Async Runtime API function types.
@@ -180,6 +182,10 @@
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
+  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
+    return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
+  }
+
   // Auxiliary coroutine resume intrinsic wrapper.
   static Type resumeFunctionType(MLIRContext *ctx) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
@@ -225,6 +231,7 @@
                 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
   addFuncDecl(kAwaitAllAndExecute,
               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -887,6 +894,30 @@
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.num_worker_threads to the corresponding runtime API
+// call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeNumWorkerThreadsOpLowering
+    : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Replace with a runtime API function call.
+    rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads,
+                                        rewriter.getIndexType());
+
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Async reference counting ops lowering (`async.runtime.add_ref` and
 // `async.runtime.drop_ref` to the corresponding API calls).
@@ -993,8 +1024,9 @@
   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
-               RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
-               RuntimeDropRefOpLowering>(converter, ctx);
+               RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
+               RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
+                                                                  ctx);
 
   // Lower async.runtime operations that rely on LLVM type converter to convert
   // from async value payload type to the LLVM type.
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -795,19 +795,64 @@
     numUnrollableLoops++;
   }
 
+  Value numWorkerThreadsVal;
+  if (numWorkerThreads >= 0) {
+    numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
+  } else {
+    numWorkerThreadsVal = b.create<RuntimeNumWorkerThreadsOp>();
+  }
+
   // With large number of threads the value of creating many compute blocks
   // is reduced because the problem typically becomes memory bound. For small
   // number of threads it helps with stragglers.
-  float overshardingFactor = numWorkerThreads <= 4    ? 8.0
-                             : numWorkerThreads <= 8  ? 4.0
-                             : numWorkerThreads <= 16 ? 2.0
-                             : numWorkerThreads <= 32 ? 1.0
-                             : numWorkerThreads <= 64 ? 0.8
-                                                      : 0.6;
-
-  // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
-      std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
+
+  // Pairs of upper bound of the bracket and factor that the number of workers
+  // needs to be scaled with if it falls in that bucket.
+  SmallVector<std::pair<int, float>> overshardingBrackets = {
+      {4, 8.0f}, {8, 4.0f}, {16, 2.0f}, {32, 1.0f}, {64, 0.8f}};
+
+  auto makeFloat = [](ImplicitLocOpBuilder ib, float val) -> Value {
+    return ib.create<arith::ConstantFloatOp>(llvm::APFloat(val),
+                                             ib.getF32Type());
+  };
+  std::function<Value(ImplicitLocOpBuilder &)> makeSwitch =
+      [&makeFloat](ImplicitLocOpBuilder nb) -> Value {
+    return makeFloat(nb, 0.6f);
+  };
+
+  for (std::pair<int, float> &p : llvm::reverse(overshardingBrackets)) {
+    makeSwitch = [&makeFloat, &numWorkerThreadsVal, &p,
+                  makeSwitch](ImplicitLocOpBuilder &nb) -> Value {
+      auto inBracket = nb.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sle, numWorkerThreadsVal,
+          nb.create<arith::ConstantIndexOp>(p.first));
+      return nb
+          .create<scf::IfOp>(
+              TypeRange(nb.getF32Type()), inBracket,
+              [&](OpBuilder &nestedBuilder, Location nestedLoc) {
+                ImplicitLocOpBuilder nb(nestedLoc, nestedBuilder);
+                nb.create<scf::YieldOp>(makeFloat(nb, p.second));
+              },
+              [&](OpBuilder &nestedBuilder, Location nestedLoc) {
+                ImplicitLocOpBuilder nb(nestedLoc, nestedBuilder);
+                nb.create<scf::YieldOp>(makeSwitch(nb));
+              })
+          .getResults()
+          .front();
+    };
+  }
+  Value scaledWorkers = b.create<arith::IndexCastOp>(
+      b.create<arith::FPToSIOp>(
+          b.create<arith::MulFOp>(
+              makeSwitch(b),
+              b.create<arith::SIToFPOp>(
+                  b.create<arith::IndexCastOp>(numWorkerThreadsVal,
+                                               b.getI32Type()),
+                  b.getF32Type())),
+          b.getI32Type()),
+      b.getIndexType());
+
+  Value maxComputeBlocks = b.create<arith::MaxSIOp>(
+      b.create<arith::ConstantIndexOp>(1), scaledWorkers);
 
   // Compute parallel block size from the parallel problem size:
   //   blockSize = min(tripCount,
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -438,6 +438,8 @@
   }
 }
 
+extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() { return 1; }
+
 //===----------------------------------------------------------------------===//
 // Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
@@ -515,6 +517,8 @@
                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
+               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
 }
diff --git a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -split-input-file -async-parallel-for=num-workers=-1 \
+// RUN: | FileCheck %s --dump-input=always
+
+// We check only a single level of recursion in the switch statement in order
+// to avoid brittleness.
+
+// CHECK-LABEL: @num_worker_threads(
+// CHECK: %[[MEMREF:.*]]: memref<?xf32>
+func @num_worker_threads(%arg0: memref<?xf32>) {
+
+  // CHECK: %[[scalingCst:.*]] = arith.constant {{.*}} : f32
+  // CHECK: scf.if %false {
+  // CHECK: } else {
+  // CHECK: %[[workersIndex:.*]] = async.runtime.num_worker_threads
+  // CHECK: %[[cond:.*]] = arith.cmpi sle, %[[workersIndex]], %c4 : index
+  // CHECK: %[[factor:.*]] = scf.if %[[cond]] -> (f32) {
+  // CHECK:   scf.yield {{.*}} : f32
+  // CHECK: } else {
+  // Skipping redundancy...
+  // CHECK: }
+  // CHECK: %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
+  // CHECK: %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
+  // CHECK: %[[scaledFloat:.*]] = arith.mulf %[[factor]], %[[workersFloat]] : f32
+  // CHECK: %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
+  // CHECK: %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
+
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 100 : index
+  %st = arith.constant 1 : index
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    %one = arith.constant 1.0 : f32
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1207,6 +1207,7 @@
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":InferTypeOpInterfaceTdFiles",
        ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],
@@ -2087,11 +2088,8 @@
     deps = [
         ":AsyncOpsIncGen",
         ":ControlFlowInterfaces",
-        ":Dialect",
         ":IR",
-        ":SideEffectInterfaces",
-        ":StandardOps",
-        ":Support",
+        ":InferTypeOpInterface",
         "//llvm:Support",
     ],
 )