diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -32,15 +32,18 @@
 #define MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
 #endif // _WIN32
 
+namespace mlir {
+namespace runtime {
+
 //===----------------------------------------------------------------------===//
 // Async runtime API.
 //===----------------------------------------------------------------------===//
 
 // Runtime implementation of `async.token` data type.
-typedef struct AsyncToken MLIR_AsyncToken;
+typedef struct AsyncToken AsyncToken;
 
 // Runtime implementation of `async.group` data type.
-typedef struct AsyncGroup MLIR_AsyncGroup;
+typedef struct AsyncGroup AsyncGroup;
 
 // Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
 // function is a coroutine handle and a resume function that continue coroutine
@@ -102,4 +105,7 @@
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 mlirAsyncRuntimePrintCurrentThreadId();
 
+} // namespace runtime
+} // namespace mlir
+
 #endif // MLIR_EXECUTIONENGINE_ASYNCRUNTIME_H_
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -24,10 +24,14 @@
 #include <thread>
 #include <vector>
 
+using namespace mlir::runtime;
+
 //===----------------------------------------------------------------------===//
 // Async runtime API.
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace runtime {
 namespace {
 
 // Forward declare class defined below.
@@ -66,12 +70,6 @@
   std::atomic<int32_t> numRefCountedObjects;
 };
 
-// Returns the default per-process instance of an async runtime.
-AsyncRuntime *getDefaultAsyncRuntimeInstance() {
-  static auto runtime = std::make_unique<AsyncRuntime>();
-  return runtime.get();
-}
-
 // -------------------------------------------------------------------------- //
 // A base class for all reference counted objects created by the async runtime.
 // -------------------------------------------------------------------------- //
@@ -110,6 +108,12 @@
 
 } // namespace
 
+// Returns the default per-process instance of an async runtime.
+static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+  static auto runtime = std::make_unique<AsyncRuntime>();
+  return runtime.get();
+}
+
 struct AsyncToken : public RefCounted {
   // AsyncToken created with a reference count of 2 because it will be returned
   // to the `async.execute` caller and also will be later on emplaced by the
@@ -140,6 +144,14 @@
   std::vector<std::function<void()>> awaiters;
 };
 
+} // namespace runtime
+} // namespace mlir
+
+// The C API below is defined at global scope. Each definition keeps
+// `extern "C"` so that it names the same C-linkage function declared inside
+// `mlir::runtime` in AsyncRuntime.h; without it the definitions would get C++
+// linkage and conflict with (rather than define) those declarations.
+
 // Adds references to reference counted runtime object.
 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
   RefCounted *refCounted = static_cast<RefCounted *>(ptr);