diff --git a/llvm/include/llvm/ADT/STLFunctionalExtras.h b/llvm/include/llvm/ADT/STLFunctionalExtras.h
--- a/llvm/include/llvm/ADT/STLFunctionalExtras.h
+++ b/llvm/include/llvm/ADT/STLFunctionalExtras.h
@@ -46,6 +46,26 @@
         std::forward<Params>(params)...);
   }
 
+  // Handles construction from an empty callable (for example a default
+  // constructed std::function) and ensures that bool conversion to `false` is
+  // propagated. Any callable that is contextually convertible to bool is
+  // tested this way, so its bool conversion is assumed to mean "non-empty".
+  // `Callable` is deduced from an lvalue so that it keeps the constness of the
+  // referenced object; callback_fn<Callable> then invokes the const-correct
+  // operator().
+  template <typename Callable,
+            std::enable_if_t<std::is_constructible<
+                bool, Callable &>::value> * = nullptr>
+  static decltype(callback) get_callback(Callable &callable) {
+    return callable ? callback_fn<Callable> : nullptr;
+  }
+  template <typename Callable,
+            std::enable_if_t<!std::is_constructible<
+                bool, Callable &>::value> * = nullptr>
+  static decltype(callback) get_callback(Callable &callable) {
+    return callback_fn<Callable>;
+  }
+
 public:
   function_ref() = default;
   function_ref(std::nullptr_t) {}
@@ -61,7 +81,7 @@
                        std::is_convertible<decltype(std::declval<Callable>()(
                                                std::declval<Params>()...)),
                                            Ret>::value> * = nullptr)
-      : callback(callback_fn<typename std::remove_reference<Callable>::type>),
+      : callback(get_callback(callable)),
         callable(reinterpret_cast<intptr_t>(&callable)) {}
 
   Ret operator()(Params ...params) const {
diff --git a/llvm/unittests/ADT/FunctionRefTest.cpp b/llvm/unittests/ADT/FunctionRefTest.cpp
--- a/llvm/unittests/ADT/FunctionRefTest.cpp
+++ b/llvm/unittests/ADT/FunctionRefTest.cpp
@@ -59,4 +59,30 @@
   EXPECT_EQ("string", returns([] { return "hello"; }));
 }
 
+TEST(FunctionRefTest, std_function) {
+  // Check that function_ref constructed from an empty std::function is empty.
+  std::function<void()> f;
+  function_ref<void()> f_ref(f);
+  ASSERT_FALSE((bool)f_ref);
+
+  // A non-empty std::function must still yield a callable function_ref.
+  int calls = 0;
+  std::function<void()> g = [&calls] { ++calls; };
+  function_ref<void()> g_ref(g);
+  ASSERT_TRUE((bool)g_ref);
+  g_ref();
+  EXPECT_EQ(1, calls);
+}
+
+TEST(FunctionRefTest, ConstCallable) {
+  // Constness of the referenced callable must be preserved so that the const
+  // overload of operator() is the one invoked.
+  struct Overloaded {
+    int operator()() { return 1; }
+    int operator()() const { return 2; }
+  };
+  const Overloaded obj{};
+  EXPECT_EQ(2, function_ref<int()>(obj)());
+}
+
 } // namespace