Index: lib/CodeGen/CGCall.h =================================================================== --- lib/CodeGen/CGCall.h +++ lib/CodeGen/CGCall.h @@ -204,12 +204,9 @@ assert(isVirtual()); return VirtualInfo.Addr; } - - llvm::FunctionType *getFunctionType() const { - if (isVirtual()) - return VirtualInfo.FTy; - return cast<llvm::FunctionType>( - getFunctionPointer()->getType()->getPointerElementType()); + llvm::FunctionType *getVirtualFunctionType() const { + assert(isVirtual()); + return VirtualInfo.FTy; } /// If this is a delayed callee computation of some sort, prepare Index: lib/CodeGen/CGCall.cpp =================================================================== --- lib/CodeGen/CGCall.cpp +++ lib/CodeGen/CGCall.cpp @@ -3826,7 +3826,7 @@ QualType RetTy = CallInfo.getReturnType(); const ABIArgInfo &RetAI = CallInfo.getReturnInfo(); - llvm::FunctionType *IRFuncTy = Callee.getFunctionType(); + llvm::FunctionType *IRFuncTy = getTypes().GetFunctionType(CallInfo); #ifndef NDEBUG if (!(CallInfo.isVariadic() && CallInfo.getArgStruct())) { @@ -3837,8 +3837,13 @@ // // In other cases, we assert that the types match up (until pointers stop // having pointee types). - llvm::FunctionType *IRFuncTyFromInfo = getTypes().GetFunctionType(CallInfo); - assert(IRFuncTy == IRFuncTyFromInfo); + llvm::Type *TypeFromVal; + if (Callee.isVirtual()) + TypeFromVal = Callee.getVirtualFunctionType(); + else + TypeFromVal = + Callee.getFunctionPointer()->getType()->getPointerElementType(); + assert(IRFuncTy == TypeFromVal); } #endif @@ -4207,8 +4212,8 @@ // cases, we can't do any parameter mismatch checks. Give up and bitcast // the callee. 
unsigned CalleeAS = CalleePtr->getType()->getPointerAddressSpace(); - auto FnTy = getTypes().GetFunctionType(CallInfo)->getPointerTo(CalleeAS); - CalleePtr = Builder.CreateBitCast(CalleePtr, FnTy); + CalleePtr = + Builder.CreateBitCast(CalleePtr, IRFuncTy->getPointerTo(CalleeAS)); } else { llvm::Type *LastParamTy = IRFuncTy->getParamType(IRFuncTy->getNumParams() - 1); @@ -4240,19 +4245,20 @@ // // This makes the IR nicer, but more importantly it ensures that we // can inline the function at -O0 if it is marked always_inline. - auto simplifyVariadicCallee = [](llvm::Value *Ptr) -> llvm::Value* { - llvm::FunctionType *CalleeFT = - cast<llvm::FunctionType>(Ptr->getType()->getPointerElementType()); + auto simplifyVariadicCallee = [](llvm::FunctionType *CalleeFT, + llvm::Value *Ptr) -> llvm::Function * { if (!CalleeFT->isVarArg()) - return Ptr; + return nullptr; - llvm::ConstantExpr *CE = dyn_cast<llvm::ConstantExpr>(Ptr); - if (!CE || CE->getOpcode() != llvm::Instruction::BitCast) - return Ptr; + // Get underlying value if it's a bitcast + if (llvm::ConstantExpr *CE = dyn_cast<llvm::ConstantExpr>(Ptr)) { + if (CE->getOpcode() == llvm::Instruction::BitCast) + Ptr = CE->getOperand(0); + } - llvm::Function *OrigFn = dyn_cast<llvm::Function>(CE->getOperand(0)); + llvm::Function *OrigFn = dyn_cast<llvm::Function>(Ptr); if (!OrigFn) - return Ptr; + return nullptr; llvm::FunctionType *OrigFT = OrigFn->getFunctionType(); @@ -4261,15 +4267,19 @@ if (OrigFT->isVarArg() || OrigFT->getNumParams() != CalleeFT->getNumParams() || OrigFT->getReturnType() != CalleeFT->getReturnType()) - return Ptr; + return nullptr; for (unsigned i = 0, e = OrigFT->getNumParams(); i != e; ++i) if (OrigFT->getParamType(i) != CalleeFT->getParamType(i)) - return Ptr; + return nullptr; return OrigFn; }; - CalleePtr = simplifyVariadicCallee(CalleePtr); + + if (llvm::Function *OrigFn = simplifyVariadicCallee(IRFuncTy, CalleePtr)) { + CalleePtr = OrigFn; + IRFuncTy = OrigFn->getFunctionType(); + } // 3. Perform the actual call. @@ -4364,10 +4374,10 @@ // Emit the actual call/invoke instruction. 
llvm::CallBase *CI; if (!InvokeDest) { - CI = Builder.CreateCall(CalleePtr, IRCallArgs, BundleList); + CI = Builder.CreateCall(IRFuncTy, CalleePtr, IRCallArgs, BundleList); } else { llvm::BasicBlock *Cont = createBasicBlock("invoke.cont"); - CI = Builder.CreateInvoke(CalleePtr, Cont, InvokeDest, IRCallArgs, + CI = Builder.CreateInvoke(IRFuncTy, CalleePtr, Cont, InvokeDest, IRCallArgs, BundleList); EmitBlock(Cont); } @@ -4591,7 +4601,7 @@ if (isVirtual()) { const CallExpr *CE = getVirtualCallExpr(); return CGF.CGM.getCXXABI().getVirtualFunctionPointer( - CGF, getVirtualMethodDecl(), getThisAddress(), getFunctionType(), + CGF, getVirtualMethodDecl(), getThisAddress(), getVirtualFunctionType(), CE ? CE->getBeginLoc() : SourceLocation()); }