diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -4654,8 +4654,16 @@
   }
 
   if (!Call.use_empty() && !Call.isMustTailCall())
-    if (Value *ReturnedArg = Call.getReturnedArgOperand())
-      return replaceInstUsesWith(Call, ReturnedArg);
+    if (Value *ReturnedArg = Call.getReturnedArgOperand()) {
+      Type *CallTy = Call.getType();
+      Type *RetArgTy = ReturnedArg->getType();
+      // Forward the argument only through a no-op cast (identical types,
+      // same-address-space pointers, same-sized vectors); never emit an
+      // addrspacecast. CreateBitOrPointerCast folds away for equal types.
+      if (RetArgTy->canLosslesslyBitCastTo(CallTy))
+        return replaceInstUsesWith(
+            Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy));
+    }
 
   if (isAllocLikeFn(&Call, &TLI))
     return visitAllocSite(Call);
diff --git a/llvm/test/Transforms/InstCombine/call-returned.ll b/llvm/test/Transforms/InstCombine/call-returned.ll
--- a/llvm/test/Transforms/InstCombine/call-returned.ll
+++ b/llvm/test/Transforms/InstCombine/call-returned.ll
@@ -3,6 +3,8 @@
 
 declare i32 @passthru_i32(i32 returned)
 declare i8* @passthru_p8(i8* returned)
+declare i8* @passthru_p8fromi32(i32* returned)
+declare <2 x i64> @passthru_v2i64fromv4i32(<4 x i32> returned)
 
 define i32 @returned_const_int_arg() {
 ; CHECK-LABEL: @returned_const_int_arg(
@@ -22,6 +24,35 @@
   ret i8* %x
 }
 
+define i8* @returned_const_ptr_arg_casted() {
+; CHECK-LABEL: @returned_const_ptr_arg_casted(
+; CHECK-NEXT:    [[X:%.*]] = call i8* @passthru_p8fromi32(i32* null)
+; CHECK-NEXT:    ret i8* null
+;
+  %x = call i8* @passthru_p8fromi32(i32* null)
+  ret i8* %x
+}
+
+define i8* @returned_ptr_arg_casted(i32* %a) {
+; CHECK-LABEL: @returned_ptr_arg_casted(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i32* [[A:%.*]] to i8*
+; CHECK-NEXT:    [[X:%.*]] = call i8* @passthru_p8fromi32(i32* [[A]])
+; CHECK-NEXT:    ret i8* [[TMP1]]
+;
+  %x = call i8* @passthru_p8fromi32(i32* %a)
+  ret i8* %x
+}
+
+define <2 x i64> @returned_vec_arg_casted(<4 x i32> %a) {
+; CHECK-LABEL: @returned_vec_arg_casted(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[A:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[X:%.*]] = call <2 x i64> @passthru_v2i64fromv4i32(<4 x i32> [[A]])
+; CHECK-NEXT:    ret <2 x i64> [[TMP1]]
+;
+  %x = call <2 x i64> @passthru_v2i64fromv4i32(<4 x i32> %a)
+  ret <2 x i64> %x
+}
+
 define i32 @returned_var_arg(i32 %arg) {
 ; CHECK-LABEL: @returned_var_arg(
 ; CHECK-NEXT:    [[X:%.*]] = call i32 @passthru_i32(i32 [[ARG:%.*]])