diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4568,6 +4568,15 @@
 
 /// Handle icmp (cast x), (cast or constant).
 Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
+  // icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2.
+  // Either or both operands may be such a round trip; strip each one that is.
+  Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0));
+  Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1));
+  if (SimplifiedOp0 || SimplifiedOp1)
+    return new ICmpInst(ICmp.getPredicate(),
+                        SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0),
+                        SimplifiedOp1 ? SimplifiedOp1 : ICmp.getOperand(1));
+
   auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
   if (!CastOp0)
     return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -340,6 +340,10 @@
   /// \see CastInst::isEliminableCastPair
   Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
                                             const CastInst *CI2);
+  /// If \p Val is inttoptr(ptrtoint(P)), neither cast changes the bit width
+  /// and P is in the same address space as \p Val, return P (bitcast to the
+  /// type of \p Val if needed); otherwise return nullptr.
+  Value *simplifyIntToPtrRoundTripCast(Value *Val);
 
   Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And);
   Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -346,6 +346,49 @@
   return true;
 }
 
+/// Strip a lossless inttoptr/ptrtoint round trip:
+///   inttoptr (ptrtoint P) --> P   (bitcast to the inttoptr type if needed)
+///
+/// The fold requires that neither cast changes the bit width (so no address
+/// bits are truncated or zero-extended) and that both pointer types are in the
+/// same address space. Sizes are compared with getTypeSizeInBits on both
+/// sides so that vectors of pointers are measured consistently with their
+/// integer vector counterparts.
+///
+/// Any bitcast that is needed is created through Builder at its current
+/// insertion point.
+///
+/// \returns the simplified pointer, or nullptr if \p Val does not match.
+Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
+  auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
+  if (!IntToPtr)
+    return nullptr;
+
+  // The integer must be exactly as wide as the pointer(s) it is cast to;
+  // otherwise inttoptr truncates or zero-extends the address.
+  Type *CastTy = IntToPtr->getDestTy();
+  if (DL.getTypeSizeInBits(CastTy) !=
+      DL.getTypeSizeInBits(IntToPtr->getSrcTy()))
+    return nullptr;
+
+  auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0));
+  if (!PtrToInt)
+    return nullptr;
+
+  // Likewise, ptrtoint must not truncate the source pointer(s), and the round
+  // trip must not move the pointer into a different address space.
+  Value *Ptr = PtrToInt->getOperand(0);
+  if (CastTy->getPointerAddressSpace() !=
+          Ptr->getType()->getPointerAddressSpace() ||
+      DL.getTypeSizeInBits(Ptr->getType()) !=
+          DL.getTypeSizeInBits(PtrToInt->getDestTy()))
+    return nullptr;
+
+  // At most the pointee type differs now. CreateBitCast hands back Ptr itself
+  // when the types already match, so no redundant cast is emitted.
+  return Builder.CreateBitCast(Ptr, CastTy);
+}
+
 /// This performs a few simplifications for operators that are associative or
 /// commutative:
 ///
diff --git a/llvm/test/Transforms/InstCombine/ptr-int-ptr-icmp.ll b/llvm/test/Transforms/InstCombine/ptr-int-ptr-icmp.ll
--- a/llvm/test/Transforms/InstCombine/ptr-int-ptr-icmp.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-int-ptr-icmp.ll
@@ -8,9 +8,7 @@
 
 define i1 @func(i8* %X, i8* %Y) {
 ; CHECK-LABEL: @func(
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i8* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[Y:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %i = ptrtoint i8* %X to i64
@@ -21,9 +19,8 @@
 
 define i1 @func_pointer_different_types(i16* %X, i8* %Y) {
 ; CHECK-LABEL: @func_pointer_different_types(
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i16* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[X:%.*]] to i8*
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[TMP1]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %i = ptrtoint i16* %X to i64
@@ -37,9 +34,8 @@
 define i1 @func_commutative(i16* %X) {
 ; CHECK-LABEL: @func_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8* @gen8ptr()
-; CHECK-NEXT:    [[I:%.*]] = ptrtoint i16* [[X:%.*]] to i64
-; CHECK-NEXT:    [[P:%.*]] = inttoptr i64 [[I]] to i8*
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Y]], [[P]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[X:%.*]] to i8*
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Y]], [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %Y = call i8* @gen8ptr() ; thwart complexity-based canonicalization