diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -26,6 +26,7 @@
 #include "clang/Basic/CodeGenOptions.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/CodeGen/CGFunctionInfo.h"
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Transforms/Utils/SanitizerStats.h"
@@ -2894,8 +2895,8 @@
 }
 
 void CodeGenFunction::EmitForwardingCallToLambda(
-                                      const CXXMethodDecl *callOperator,
-                                      CallArgList &callArgs) {
+    const CXXMethodDecl *callOperator, CallArgList &callArgs,
+    llvm::CallBase **callOrInvoke) {
   // Get the address of the call operator.
   const CGFunctionInfo &calleeFnInfo =
     CGM.getTypes().arrangeCXXMethodDeclaration(callOperator);
@@ -2921,7 +2922,8 @@
 
   // Now emit our call.
   auto callee = CGCallee::forDirect(calleePtr, GlobalDecl(callOperator));
-  RValue RV = EmitCall(calleeFnInfo, callee, returnSlot, callArgs);
+  RValue RV =
+      EmitCall(calleeFnInfo, callee, returnSlot, callArgs, callOrInvoke);
 
   // If necessary, copy the returned value into the slot.
   if (!resultType->isVoidType() && returnSlot.isNull()) {
@@ -2990,7 +2992,30 @@
     assert(CorrespondingCallOpSpecialization);
     CallOp = cast<CXXMethodDecl>(CorrespondingCallOpSpecialization);
   }
-  EmitForwardingCallToLambda(CallOp, CallArgs);
+  llvm::CallBase *CallOrInvoke = nullptr;
+  EmitForwardingCallToLambda(CallOp, CallArgs, &CallOrInvoke);
+  assert(CallOrInvoke && "forwarding call to the lambda was not emitted");
+
+  // ThisPtr is undef, so strip every parameter attribute under which passing
+  // undef (or null) is immediate UB. Both the call site and the callee are
+  // updated, because attributes on the callee apply to every call of it.
+  // addAttribute() returns a reference to its receiver, so build the mask in
+  // a named local instead of binding a reference into a temporary.
+  llvm::AttributeMask ToRemove =
+      llvm::AttributeFuncs::getUBImplyingAttributes();
+  ToRemove.addAttribute(llvm::Attribute::NonNull);
+  llvm::Function *F = CallOrInvoke->getCalledFunction();
+
+  // Scan this call's arguments rather than ThisPtr->uses(): undef is a
+  // uniqued constant, so its use list is shared with unrelated code.
+  for (unsigned ArgNo = 0, E = CallOrInvoke->arg_size(); ArgNo != E;
+       ++ArgNo) {
+    if (CallOrInvoke->getArgOperand(ArgNo) != ThisPtr)
+      continue;
+    if (F)
+      F->removeParamAttrs(ArgNo, ToRemove);
+    CallOrInvoke->removeParamAttrs(ArgNo, ToRemove);
+  }
 }
 
 void CodeGenFunction::EmitLambdaStaticInvokeBody(const CXXMethodDecl *MD) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -2219,7 +2219,10 @@
   void EmitBlockWithFallThrough(llvm::BasicBlock *BB, const Stmt *S);
 
+  /// Emit a call forwarding \p CallArgs to \p LambdaCallOperator. If
+  /// \p CallOrInvoke is non-null, it receives the emitted call instruction.
   void EmitForwardingCallToLambda(const CXXMethodDecl *LambdaCallOperator,
-                                  CallArgList &CallArgs);
+                                  CallArgList &CallArgs,
+                                  llvm::CallBase **CallOrInvoke = nullptr);
   void EmitLambdaBlockInvokeBody();
   void EmitLambdaDelegatingInvokeBody(const CXXMethodDecl *MD);
   void EmitLambdaStaticInvokeBody(const CXXMethodDecl *MD);
diff --git a/clang/test/CodeGenCXX/lambda-to-function-pointer-conversion.cpp b/clang/test/CodeGenCXX/lambda-to-function-pointer-conversion.cpp
--- a/clang/test/CodeGenCXX/lambda-to-function-pointer-conversion.cpp
+++ b/clang/test/CodeGenCXX/lambda-to-function-pointer-conversion.cpp
@@ -3,9 +3,9 @@
 // This code used to cause an assertion failure in EmitDelegateCallArg.
 
 // CHECK: define internal void @"?__invoke@@?0??test@@YAXXZ@CA@UTrivial@@@Z"(
-// CHECK: call void @"??R@?0??test@@YAXXZ@QEBA@UTrivial@@@Z"(
+// CHECK: call void @"??R@?0??test@@YAXXZ@QEBA@UTrivial@@@Z"(ptr align 1 undef,
 
-// CHECK: define internal void @"??R@?0??test@@YAXXZ@QEBA@UTrivial@@@Z"(
+// CHECK: define internal void @"??R@?0??test@@YAXXZ@QEBA@UTrivial@@@Z"(ptr align 1 %this,
 
 struct Trivial {
   int x;
@@ -16,3 +16,14 @@
 void test() {
   fnptr = [](Trivial a){ (void)a; };
 }
+
+// CHECK: define internal i32 @"?__invoke@@?0??test2@@YAXXZ@CA@H@Z"(
+// CHECK: call void @"??R@?0??test2@@YAXXZ@QEBA@H@Z"(ptr align 1 undef,
+
+// CHECK: define internal void @"??R@?0??test2@@YAXXZ@QEBA@H@Z"(ptr align 1 %this,
+
+Trivial (*fnptr2)(int);
+
+void test2() {
+  fnptr2 = [](int) -> Trivial { return {}; };
+}