Index: docs/LangRef.rst =================================================================== --- docs/LangRef.rst +++ docs/LangRef.rst @@ -575,8 +575,8 @@ LLVM IR optionally allows the frontend to denote pointers in certain address spaces as "non-integral" via the :ref:`datalayout string`. Non-integral pointer types represent pointers that have an *unspecified* bitwise -representation; that is, the integral representation may be target dependent or -unstable (not backed by a fixed integer). +representation (including the null value); that is, the integral representation +may be target dependent or unstable (not backed by a fixed integer). ``inttoptr`` instructions converting integers to non-integral pointer types are ill-typed, and so are ``ptrtoint`` instructions converting values of Index: include/llvm/Analysis/ValueTracking.h =================================================================== --- include/llvm/Analysis/ValueTracking.h +++ include/llvm/Analysis/ValueTracking.h @@ -223,7 +223,7 @@ /// 0.0 etc. If the value can't be handled with a repeated byte store (e.g. /// i16 0x1234), return null. If the value is entirely undef and padding, /// return undef. - Value *isBytewiseValue(Value *V); + Value *isBytewiseValue(Value *V, const DataLayout &DL); /// Given an aggregrate and an sequence of indices, see if the scalar value /// indexed is already around as a register, for example if it were inserted Index: lib/Analysis/ValueTracking.cpp =================================================================== --- lib/Analysis/ValueTracking.cpp +++ lib/Analysis/ValueTracking.cpp @@ -3107,12 +3107,39 @@ return true; } -Value *llvm::isBytewiseValue(Value *V) { +// Whether a type contains non-integral pointer. 
+static bool containsNonIntegralPointer(Type *T, const DataLayout &DL) { + if (DL.getNonIntegralAddressSpaces().empty()) + return false; + + switch (T->getTypeID()) { + case Type::PointerTyID: + return DL.isNonIntegralPointerType(T); + case Type::ArrayTyID: + return containsNonIntegralPointer(T->getArrayElementType(), DL); + case Type::VectorTyID: + return containsNonIntegralPointer(T->getVectorElementType(), DL); + case Type::StructTyID: { + for (unsigned I = 0, N = T->getStructNumElements(); I < N; ++I) + if (containsNonIntegralPointer(T->getStructElementType(I), DL)) + return true; + return false; + } + default: + return false; + } +} + +Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { // All byte-wide stores are splatable, even of arbitrary variables. if (V->getType()->isIntegerTy(8)) return V; + // The bit pattern of non-integral pointer is technically undefined. + if (containsNonIntegralPointer(V->getType(), DL)) + return nullptr; + LLVMContext &Ctx = V->getContext(); // Undef don't care. @@ -3146,7 +3173,7 @@ else if (CFP->getType()->isDoubleTy()) Ty = Type::getInt64Ty(Ctx); // Don't handle long double formats, which have strange constraints. - return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty)) : nullptr; + return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL) : nullptr; } // We can handle constant integers that are multiple of 8 bits. @@ -3174,20 +3201,20 @@ if (ConstantDataSequential *CA = dyn_cast(C)) { Value *Val = UndefInt8; for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I) - if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I))))) + if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL)))) return nullptr; return Val; } if (isa(C)) { Constant *Splat = cast(C)->getSplatValue(); - return Splat ? isBytewiseValue(Splat) : nullptr; + return Splat ? 
isBytewiseValue(Splat, DL) : nullptr; } if (isa(C) || isa(C)) { Value *Val = UndefInt8; for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I) - if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I))))) + if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I), DL)))) return nullptr; return Val; } Index: lib/Transforms/Scalar/LoopIdiomRecognize.cpp =================================================================== --- lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -430,7 +430,7 @@ // turned into a memset of i8 -1, assuming that all the consecutive bytes // are stored. A store of i32 0x01020304 can never be turned into a memset, // but it can be turned into memset_pattern if the target supports it. - Value *SplatValue = isBytewiseValue(StoredVal); + Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; // Note: memset and memset_pattern on unordered-atomic is yet not supported @@ -607,7 +607,7 @@ Constant *FirstPatternValue = nullptr; if (For == ForMemset::Yes) - FirstSplatValue = isBytewiseValue(FirstStoredVal); + FirstSplatValue = isBytewiseValue(FirstStoredVal, *DL); else FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL); @@ -640,7 +640,7 @@ Constant *SecondPatternValue = nullptr; if (For == ForMemset::Yes) - SecondSplatValue = isBytewiseValue(SecondStoredVal); + SecondSplatValue = isBytewiseValue(SecondStoredVal, *DL); else SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL); @@ -860,7 +860,7 @@ Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, bool NegStride, bool IsLoopMemset) { - Value *SplatValue = isBytewiseValue(StoredVal); + Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; if (!SplatValue) Index: lib/Transforms/Scalar/MemCpyOptimizer.cpp =================================================================== --- 
lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -413,7 +413,7 @@ if (!NextStore->isSimple()) break; // Check to see if this stored value is of the same byte-splattable value. - Value *StoredByte = isBytewiseValue(NextStore->getOperand(0)); + Value *StoredByte = isBytewiseValue(NextStore->getOperand(0), DL); if (isa(ByteVal) && StoredByte) ByteVal = StoredByte; if (ByteVal != StoredByte) @@ -750,7 +750,7 @@ // byte at a time like "0" or "-1" or any width, as well as things like // 0xA0A0A0A0 and 0.0. auto *V = SI->getOperand(0); - if (Value *ByteVal = isBytewiseValue(V)) { + if (Value *ByteVal = isBytewiseValue(V, DL)) { if (Instruction *I = tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) { BBI = I->getIterator(); // Don't invalidate iterator. @@ -1225,10 +1225,12 @@ return false; } + const DataLayout &DL = M->getModule()->getDataLayout(); + // If copying from a constant, try to turn the memcpy into a memset. if (GlobalVariable *GV = dyn_cast(M->getSource())) if (GV->isConstant() && GV->hasDefinitiveInitializer()) - if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) { + if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), DL)) { IRBuilder<> Builder(M); Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), M->getDestAlignment(), false); Index: test/Transforms/MemCpyOpt/non-integral-ptr.ll =================================================================== --- /dev/null +++ test/Transforms/MemCpyOpt/non-integral-ptr.ll @@ -0,0 +1,31 @@ +; RUN: opt -memcpyopt -S < %s | FileCheck %s + +; Check that non-integral pointer writes are not optimized to memset. 
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128-ni:1"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture, i8* nocapture, i64, i1) nounwind
+
+@x = global i8 addrspace(1)* zeroinitializer
+@const0 = private constant i8 addrspace(1)* null
+
+define void @test_memcpy() {
+  ; CHECK-LABEL: test_memcpy
+  ; CHECK: @llvm.memcpy
+  ; CHECK-NOT: @llvm.memset
+  %1 = bitcast i8 addrspace(1)** @x to i8*
+  %2 = bitcast i8 addrspace(1)** @const0 to i8*
+  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %1, i8* align 8 %2, i64 8, i1 false)
+  ret void
+}
+
+@y = global [3 x i8 addrspace(1)*] zeroinitializer
+
+define void @test_aggregate() {
+  ; CHECK-LABEL: test_aggregate
+  ; CHECK: store
+  ; CHECK-NOT: @llvm.memset
+  store [3 x i8 addrspace(1)*] zeroinitializer, [3 x i8 addrspace(1)*]* @y
+  ret void
+}
Index: unittests/Analysis/ValueTrackingTest.cpp
===================================================================
--- unittests/Analysis/ValueTrackingTest.cpp
+++ unittests/Analysis/ValueTrackingTest.cpp
@@ -616,3 +616,36 @@
       "declare i16 @llvm.fshl.i16(i16, i16, i16)\n");
   expectKnownBits(/*zero*/ 15u, /*one*/ 3840u);
 }
+
+// isBytewiseValue must reject constants whose type is, or contains, a
+// non-integral pointer, since such pointers have no defined bit pattern.
+TEST(ValueTracking, IsBytewiseValueNonIntegral) {
+  StringRef Assembly = "%t = type { i64, i64 addrspace(10)* }"
+                       "@x = private constant i64 addrspace(10)* null"
+                       "@y = private constant [3 x i64 addrspace(10)*] zeroinitializer"
+                       "@z = private constant %t zeroinitializer";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  // Use a gtest assertion rather than assert() so the check is not compiled
+  // out of NDEBUG builds.
+  ASSERT_TRUE(M != nullptr) << "Bad assembly?";
+
+  // Set address space 10 as non-integral.
+  std::string DLStr = M->getDataLayoutStr();
+  if (!DLStr.empty())
+    DLStr += "-";
+  DLStr += "ni:10";
+  DataLayout DL(DLStr);
+  M->setDataLayout(DL);
+
+  // Query the initializers, not the globals: the globals above carry no
+  // addrspace qualifier, so a GlobalValue's own type is a plain pointer to the
+  // global rather than the pointer, array and struct types under test.
+  for (const char *Name : {"x", "y", "z"}) {
+    GlobalVariable *GV = M->getNamedGlobal(Name);
+    ASSERT_TRUE(GV != nullptr) << "Bad assembly?";
+    EXPECT_FALSE(isBytewiseValue(GV->getInitializer(), DL)) << "@" << Name;
+  }
+}