diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -476,4 +476,16 @@
 extern "C" MLIR_CRUNNERUTILS_EXPORT void printFlops(double flops);
 extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
 
+//===----------------------------------------------------------------------===//
+// Runtime support library for random number generation.
+//===----------------------------------------------------------------------===//
+// Uses a seed to initialize a random generator and returns the generator.
+// The caller owns the returned generator and must release it with rtdrand.
+extern "C" MLIR_CRUNNERUTILS_EXPORT void *rtsrand(uint64_t s);
+// Returns a random number in the range of [0, m). Returns 0 when m == 0,
+// since that range is empty.
+extern "C" MLIR_CRUNNERUTILS_EXPORT uint64_t rtrand(void *, uint64_t m);
+// Deletes the random number generator.
+extern "C" MLIR_CRUNNERUTILS_EXPORT void rtdrand(void *);
+
 #endif // MLIR_EXECUTIONENGINE_CRUNNERUTILS_H
diff --git a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
--- a/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
+++ b/mlir/lib/ExecutionEngine/CRunnerUtils.cpp
@@ -29,6 +29,7 @@
 #include <cinttypes>
 #include <cstdio>
 #include <cstdlib>
+#include <random>
 #include <string.h>
 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
@@ -148,4 +149,28 @@
 #endif
 }
 
+extern "C" void *rtsrand(uint64_t s) {
+  // Standard mersenne_twister_engine seeded with s (std::mt19937 reduces the
+  // seed modulo 2^32). The heap-allocated generator is owned by the caller
+  // and must be released with rtdrand.
+  return new std::mt19937(s);
+}
+
+extern "C" uint64_t rtrand(void *g, uint64_t m) {
+  // The range [0, m) is empty for m == 0; return 0 instead of letting m - 1
+  // wrap around to UINT64_MAX.
+  if (m == 0)
+    return 0;
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  // uniform_int_distribution bounds are inclusive, so the documented
+  // half-open range [0, m) maps to [0, m - 1].
+  std::uniform_int_distribution<uint64_t> distrib(0, m - 1);
+  return distrib(*generator);
+}
+
+extern "C" void rtdrand(void *g) {
+  std::mt19937 *generator = static_cast<std::mt19937 *>(g);
+  delete generator;
+}
+
 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS