diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -155,6 +155,15 @@
                                  BasicBlock::iterator &ScanFrom,
                                  unsigned MaxInstsToScan, AAResults *AA,
                                  bool *IsLoadCSE, unsigned *NumScanedInst);
+
+/// Returns true if a pointer value \p A can be replaced with another pointer
+/// value \p B if they are deemed equal through some means (e.g. information
+/// from conditions).
+/// NOTE: the current implementation is incomplete and unsound. It does not
+/// reject all invalid cases yet, but will be made stricter in the future. In
+/// particular this means returning true means unknown if replacement is safe.
+bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
+                               Instruction *CtxI);
 }
 
 #endif
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -503,3 +503,23 @@
   // block.
   return nullptr;
 }
+
+bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
+                                     Instruction *CtxI) {
+  Type *Ty = A->getType();
+  assert(Ty == B->getType() && Ty->isPointerTy() &&
+         "values must have matching pointer types");
+
+  // NOTE: The checks in the function are incomplete and currently miss illegal
+  // cases! The current implementation is a starting point and the
+  // implementation should be made stricter over time.
+  if (auto *C = dyn_cast<Constant>(B)) {
+    // Do not allow replacing a pointer with a constant pointer, unless it is
+    // either null or at least one byte is dereferenceable.
+    APInt OneByte(DL.getPointerTypeSizeInBits(Ty), 1);
+    return C->isNullValue() ||
+           isDereferenceableAndAlignedPointer(B, Align(1), OneByte, DL, CtxI);
+  }
+
+  return true;
+}
diff --git a/llvm/unittests/Analysis/LoadsTest.cpp b/llvm/unittests/Analysis/LoadsTest.cpp
--- a/llvm/unittests/Analysis/LoadsTest.cpp
+++ b/llvm/unittests/Analysis/LoadsTest.cpp
@@ -59,3 +59,42 @@
   ASSERT_TRUE(CI);
   ASSERT_TRUE(CI->equalsInt(42));
 }
+
+TEST(LoadsTest, CanReplacePointersIfEqual) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+@y = common global [1 x i32] zeroinitializer, align 4
+@x = common global [1 x i32] zeroinitializer, align 4
+
+declare void @use(i32*)
+
+define void @f(i32* %p) {
+  call void @use(i32* getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0))
+  call void @use(i32* getelementptr inbounds (i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @x, i64 0, i64 0), i64 1))
+  ret void
+}
+)IR");
+  auto &DL = M->getDataLayout();
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+
+  // NOTE: the implementation of canReplacePointersIfEqual is incomplete.
+  // Currently only the cases it returns false for are really sound and
+  // returning true means unknown.
+  Value *P = &*F->arg_begin();
+  auto InstIter = F->front().begin();
+  Value *ConstDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
+  // ConstDerefPtr is a constant pointer that is provably de-referenceable. We
+  // can replace an arbitrary pointer with it.
+  EXPECT_TRUE(canReplacePointersIfEqual(P, ConstDerefPtr, DL, nullptr));
+
+  ++InstIter;
+  Value *ConstUnDerefPtr = *cast<CallInst>(&*InstIter)->arg_begin();
+  // ConstUnDerefPtr is a constant pointer that is provably not
+  // de-referenceable. We cannot replace an arbitrary pointer with it.
+  EXPECT_FALSE(
+      canReplacePointersIfEqual(ConstDerefPtr, ConstUnDerefPtr, DL, nullptr));
+}