diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -343,6 +343,7 @@ Instruction::CastOps isEliminableCastPair(const CastInst *CI1, const CastInst *CI2); Value *simplifyIntToPtrRoundTripCast(Value *Val); + Value *isIntToPtrRoundTripCast(Value *Val); /* If Val is an inttoptr whose operand is a ptrtoint and both casts pass the size checks in the definition, returns the ptrtoint's pointer operand; otherwise returns nullptr. NOTE(review): the "is" prefix reads like a bool predicate although this returns a Value*; consider a name such as getIntToPtrRoundTripSource. */ Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -349,6 +349,14 @@ // Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast. // inttoptr ( ptrtoint (x) ) --> x Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { + if (Value *Ptr = isIntToPtrRoundTripCast(Val)) { /* Materialize the round-trip source as a bit-or-pointer cast to Val's type. The new cast is inserted before Val itself; before this patch it was inserted before the ptrtoint. */ + return CastInst::CreateBitOrPointerCast(Ptr, Val->getType(), "", + cast(Val)); /* NOTE(review): llvm::cast needs an explicit target type here, presumably cast<Instruction>(Val); the template argument looks lost in extraction, so confirm against the real patch. */ + } + return nullptr; +} + +Value *InstCombinerImpl::isIntToPtrRoundTripCast(Value *Val) { /* Query helper: on the visible match path it returns the ptrtoint's pointer operand directly and builds no new cast; creating the cast is left to callers such as simplifyIntToPtrRoundTripCast. The checks visible here require the inttoptr and the ptrtoint each to preserve bit width, plus an address-space comparison whose left-hand side is elided from this hunk. */ auto *IntToPtr = dyn_cast(Val); /* NOTE(review): template argument appears lost, presumably dyn_cast<IntToPtrInst>(Val); confirm. */ if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) == DL.getTypeSizeInBits(IntToPtr->getSrcTy())) { @@ -359,8 +367,7 @@ PtrToInt->getSrcTy()->getPointerAddressSpace() && DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) == DL.getTypeSizeInBits(PtrToInt->getDestTy())) { - return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, - "", PtrToInt); + return PtrToInt->getOperand(0); } } return nullptr;