diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -84,8 +84,20 @@
 //===----------------------------------------------------------------------===//
 
 /// Creates an MLIR context and transfers its ownership to the caller.
+/// This sets the default multithreading option (enabled).
 MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void);
 
+/// Creates an MLIR context with multithreading explicitly enabled or disabled
+/// and transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirContext
+mlirContextCreateWithThreading(bool threadingEnabled);
+
+/// Creates an MLIR context with multithreading explicitly enabled or disabled,
+/// making the dialects from the provided DialectRegistry available to it, and
+/// transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry(
+    MlirDialectRegistry registry, bool threadingEnabled);
+
 /// Checks if two contexts are equal.
 MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
 
@@ -144,6 +156,16 @@
 MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
                                                          MlirStringRef name);
 
+/// Sets the thread pool of the context explicitly, enabling multithreading in
+/// the context. The context must have been created with multithreading
+/// disabled (e.g. via mlirContextCreateWithThreading(false)), and the thread
+/// pool is not owned by the context: it must outlive the context. This API
+/// should be used to avoid re-creating thread pools in long-running
+/// applications that perform multiple compilations; see the C++ documentation
+/// for MLIRContext::setThreadPool for details.
+MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
+                                                 MlirLlvmThreadPool threadPool);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -56,6 +56,8 @@
   };                                                                           \
   typedef struct name name
 
+/// Re-export llvm::ThreadPool so as to avoid including the LLVM C API directly.
+DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void);
 DEFINE_C_API_STRUCT(MlirTypeID, const void);
 DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void);
 
@@ -138,6 +140,17 @@
   return res;
 }
 
+//===----------------------------------------------------------------------===//
+// MlirLlvmThreadPool.
+//===----------------------------------------------------------------------===//
+
+/// Creates an LLVM thread pool. This is re-exported here to avoid pulling in
+/// the LLVM headers directly.
+MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void);
+
+/// Destroys an LLVM thread pool. Contexts using it must be destroyed first.
+MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool);
+
 //===----------------------------------------------------------------------===//
 // TypeID API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -21,6 +21,10 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/StringRef.h"
 
+namespace llvm {
+class ThreadPool;
+} // namespace llvm
+
 /// Converts a StringRef into its MLIR C API equivalent.
 inline MlirStringRef wrap(llvm::StringRef ref) {
   return mlirStringRefCreate(ref.data(), ref.size());
@@ -41,6 +45,7 @@
   return mlir::success(mlirLogicalResultIsSuccess(res));
 }
 
+DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool)
 DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -39,6 +39,23 @@
   return wrap(context);
 }
 
+static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) {
+  return threadingEnabled ? MLIRContext::Threading::ENABLED
+                          : MLIRContext::Threading::DISABLED;
+}
+
+MlirContext mlirContextCreateWithThreading(bool threadingEnabled) {
+  auto *context = new MLIRContext(toThreadingEnum(threadingEnabled));
+  return wrap(context);
+}
+
+MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry,
+                                          bool threadingEnabled) {
+  auto *context =
+      new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled));
+  return wrap(context);
+}
+
 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
   return unwrap(ctx1) == unwrap(ctx2);
 }
@@ -84,6 +101,11 @@
   unwrap(context)->loadAllAvailableDialects();
 }
 
+void mlirContextSetThreadPool(MlirContext context,
+                              MlirLlvmThreadPool threadPool) {
+  unwrap(context)->setThreadPool(*unwrap(threadPool));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/CAPI/Support.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ThreadPool.h"
 
 #include <cstring>
 
@@ -20,6 +21,17 @@
          llvm::StringRef(other.data, other.length);
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM ThreadPool API.
+//===----------------------------------------------------------------------===//
+MlirLlvmThreadPool mlirLlvmThreadPoolCreate() {
+  return wrap(new llvm::ThreadPool());
+}
+
+void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) {
+  delete unwrap(threadPool);
+}
+
 //===----------------------------------------------------------------------===//
 // TypeID API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -2210,6 +2210,18 @@
   return 0;
 }
 
+void testExplicitThreadPools(void) {
+  MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate();
+  MlirDialectRegistry registry = mlirDialectRegistryCreate();
+  mlirRegisterAllDialects(registry);
+  MlirContext context =
+      mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false);
+  mlirContextSetThreadPool(context, threadPool);
+  mlirContextDestroy(context);
+  mlirDialectRegistryDestroy(registry);
+  mlirLlvmThreadPoolDestroy(threadPool);
+}
+
 void testDiagnostics(void) {
   MlirContext ctx = mlirContextCreate();
   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -2310,6 +2322,7 @@
 
   mlirContextDestroy(ctx);
 
+  testExplicitThreadPools();
   testDiagnostics();
   return 0;
 }