Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -440,6 +440,77 @@
     return &I;
   }
 
+  /// replacePtrUsesAndPropagateAddrSpace - Recursively update the users of a
+  /// pointer with a new value, propagating the new address space to all users.
+  void replacePtrUsesAndPropagateAddrSpace(Value *OldPtr, Value *NewPtr) {
+    unsigned NewAddrSpace = NewPtr->getType()->getPointerAddressSpace();
+
+    std::vector<User *> Users(OldPtr->users().begin(), OldPtr->users().end());
+    for (auto U = Users.begin(); U != Users.end(); ++U) {
+      if (auto *BI = dyn_cast<BitCastInst>(*U)) {
+        // Create and insert a new bitcast with the same destination element
+        // type, but pointing into the new address space.
+        Type *DestType =
+            BI->getDestTy()->getPointerElementType()->getPointerTo(
+                NewAddrSpace);
+        auto *NewBI = new BitCastInst(NewPtr, DestType);
+        NewBI->insertAfter(BI);
+        replacePtrUsesAndPropagateAddrSpace(BI, NewBI);
+      } else if (auto *ASCI = dyn_cast<AddrSpaceCastInst>(*U)) {
+        if (NewPtr->getType() == ASCI->getType()) {
+          // If the type already matches, remove the cast completely.
+          ASCI->replaceAllUsesWith(NewPtr);
+        } else {
+          // Replace the addrspacecast with a bitcast into the new address
+          // space.
+          auto *NewBI = new BitCastInst(
+              NewPtr,
+              ASCI->getType()->getPointerElementType()->getPointerTo(
+                  NewAddrSpace));
+          NewBI->insertAfter(ASCI);
+          replacePtrUsesAndPropagateAddrSpace(ASCI, NewBI);
+        }
+      } else if (auto *GEP = dyn_cast<GetElementPtrInst>(*U)) {
+        // Create and insert a new getelementptr based on the new pointer.
+        std::vector<Value *> Indices(GEP->idx_begin(), GEP->idx_end());
+        auto *NewGEP = GetElementPtrInst::Create(nullptr, NewPtr, Indices);
+        NewGEP->insertAfter(GEP);
+        replacePtrUsesAndPropagateAddrSpace(GEP, NewGEP);
+      } else if (auto *LI = dyn_cast<LoadInst>(*U)) {
+        // Create and insert a new load through the new pointer; the old load
+        // becomes trivially dead and is cleaned up by the worklist.
+        auto *NewLoad = new LoadInst(NewPtr);
+        NewLoad->setAlignment(LI->getAlignment());
+        NewLoad->setVolatile(LI->isVolatile());
+        NewLoad->insertAfter(LI);
+        LI->replaceAllUsesWith(NewLoad);
+      } else if (auto *MI = dyn_cast<MemTransferInst>(*U)) {
+        assert(MI->getArgOperand(1) == OldPtr);
+
+        // Get the equivalent intrinsic with the new source address space.
+        auto *MIType = MI->getFunctionType();
+        Type *Types[] = {
+            MIType->getParamType(0),
+            MIType->getParamType(1)->getPointerElementType()->getPointerTo(
+                NewAddrSpace),
+            MIType->getParamType(2)};
+        auto *NewMIFunc = Intrinsic::getDeclaration(MI->getModule(),
+                                                    MI->getIntrinsicID(), Types);
+        NewMIFunc->setAttributes(MI->getCalledFunction()->getAttributes());
+
+        // Create and insert the new call instruction.
+        std::vector<Value *> Args(MI->arg_begin(), MI->arg_end());
+        Args[1] = NewPtr;
+        auto *NewMI = CallInst::Create(NewMIFunc, Args);
+        NewMI->insertAfter(MI);
+        eraseInstFromFunction(*MI);
+      } else if (isa<CallInst>(*U)) {
+        // By this point the address space should match the original (callers
+        // verify the argument address spaces first), so just replace the use.
+        assert(OldPtr->getType() == NewPtr->getType());
+        cast<Instruction>(*U)->replaceUsesOfWith(OldPtr, NewPtr);
+      } else {
+        llvm_unreachable("Can't update address space for this instruction");
+      }
+    }
+  }
+
   /// Creates a result tuple for an overflow intrinsic \p II with a given
   /// \p Result and a constant \p Overflow value.
   Instruction *CreateOverflowTuple(IntrinsicInst *II, Value *Result,
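For context, here is a sketch of the rewrite the new helper performs. The IR below is illustrative only; the names %a, %i, %p, %v and the addrspace(1) global @g are hypothetical and appear nowhere in the patch. Given OldPtr = %a in addrspace(0) and NewPtr = @g in addrspace(1), each user in the chain is recreated against the new pointer:

  ; before: the chain is rooted at the alloca-derived pointer %a
  %p = getelementptr i32, i32* %a, i64 %i
  %v = load i32, i32* %p

  ; after replacePtrUsesAndPropagateAddrSpace(%a, @g): the clones carry the
  ; new address space, and the dead originals are left for instcombine's DCE
  %p1 = getelementptr i32, i32 addrspace(1)* @g, i64 %i
  %v1 = load i32, i32 addrspace(1)* %p1
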
Index: lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -59,6 +59,8 @@
   // ahead and replace the value with the global, this lets the caller quickly
   // eliminate the markers.
 
+  SmallVector<unsigned, 4> ArgAddrSpaces;
+
   SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect;
   ValuesToInspect.emplace_back(V, false);
   while (!ValuesToInspect.empty()) {
@@ -94,6 +96,10 @@
         unsigned DataOpNo = CS.getDataOperandNo(&U);
         bool IsArgOperand = CS.isArgOperand(&U);
 
+        // Track the address spaces of function call arguments.
+        if (IsArgOperand && !isa<MemTransferInst>(I))
+          ArgAddrSpaces.emplace_back(U->getType()->getPointerAddressSpace());
+
         // Inalloca arguments are clobbered by the call.
         if (IsArgOperand && CS.isInAllocaArgument(DataOpNo))
           return false;
@@ -152,6 +158,17 @@
         TheCopy = MI;
     }
   }
+
+  // Check that the address spaces of all function call arguments match the
+  // source of the memory transfer.
+  if (TheCopy) {
+    for (unsigned AS : ArgAddrSpaces) {
+      if (AS != TheCopy->getSource()->getType()->getPointerAddressSpace() &&
+          AS != TheCopy->getRawSource()->getType()->getPointerAddressSpace())
+        return false;
+    }
+  }
+
   return true;
 }
@@ -292,13 +309,21 @@
       DEBUG(dbgs() << "  memcpy = " << *Copy << '\n');
       for (unsigned i = 0, e = ToDelete.size(); i != e; ++i)
         eraseInstFromFunction(*ToDelete[i]);
+
       Constant *TheSrc = cast<Constant>(Copy->getSource());
-      Constant *Cast
-        = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType());
-      Instruction *NewI = replaceInstUsesWith(AI, Cast);
+
       eraseInstFromFunction(*Copy);
       ++NumGlobalCopies;
-      return NewI;
+
+      // Replace the alloca uses with the global (casting as necessary), and
+      // propagate the new address space through all of the alloca uses.
+      Constant *Cast = ConstantExpr::getPointerCast(
+          TheSrc,
+          AI.getType()->getPointerElementType()->getPointerTo(
+              TheSrc->getType()->getPointerAddressSpace()));
+      replacePtrUsesAndPropagateAddrSpace(&AI, Cast);
+
+      return &AI;
     }
   }
 }
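One note on the guard added to isOnlyCopiedFromConstantGlobal: if the alloca's address escapes into a call whose argument address space differs from the memcpy source's, the fold must be rejected, since rewriting the argument would change the callee-visible pointer type. A hypothetical sketch (the callee @use_p0 and global @src are not part of the patch; @test12 below exercises the real case):

  %a = bitcast %U* %A to i8*
  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* @src, i64 20, i32 16, i1 false)
  call void @use_p0(i8* %a) readonly
  ; ArgAddrSpaces records 0 while the source is in addrspace(1), so
  ; isOnlyCopiedFromConstantGlobal returns false and the memcpy is kept
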
Index: test/Transforms/InstCombine/memcpy-from-global.ll
===================================================================
--- test/Transforms/InstCombine/memcpy-from-global.ll
+++ test/Transforms/InstCombine/memcpy-from-global.ll
@@ -46,6 +46,7 @@
 
 @G = constant %T {i8 1, [123 x i8] zeroinitializer }
 @H = constant [2 x %U] zeroinitializer, align 16
+@I = addrspace(1) constant [2 x %U] zeroinitializer, align 16
 
 define void @test2() {
   %A = alloca %T
@@ -60,7 +61,7 @@
 ; CHECK-NEXT: getelementptr inbounds [124 x i8], [124 x i8]*
 
 ; use @G instead of %A
-; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* getelementptr inbounds (%T, %T* @G, i64 0, i32 0)
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull %{{.*}}, i8* getelementptr inbounds (%T, %T* @G, i64 0, i32 0)
   call void @llvm.memcpy.p0i8.p0i8.i64(i8* %a, i8* bitcast (%T* @G to i8*), i64 124, i32 4, i1 false)
   call void @llvm.memcpy.p0i8.p0i8.i64(i8* %b, i8* %a, i64 124, i32 4, i1 false)
   call void @bar(i8* %b)
@@ -83,7 +84,7 @@
 ; CHECK-NEXT: addrspacecast
 
 ; use @G instead of %A
-; CHECK-NEXT: call void @llvm.memcpy.p1i8.p1i8.i64(i8 addrspace(1)* %{{.*}},
+; CHECK-NEXT: call void @llvm.memcpy.p1i8.p0i8.i64(i8 addrspace(1)* %{{.*}},
   call void @llvm.memcpy.p1i8.p0i8.i64(i8 addrspace(1)* %a, i8* bitcast (%T* @G to i8*), i64 124, i32 4, i1 false)
   call void @llvm.memcpy.p1i8.p1i8.i64(i8 addrspace(1)* %b, i8 addrspace(1)* %a, i64 124, i32 4, i1 false)
   call void @bar_as1(i8 addrspace(1)* %b)
@@ -204,3 +205,79 @@
 ; CHECK-NEXT: call void @bar(i8* bitcast (%U* getelementptr inbounds ([2 x %U], [2 x %U]* @H, i64 0, i64 1) to i8*))
   ret void
 }
+
+define void @test10(i64 %index, i32* %out) {
+  %A = alloca %U, align 16
+  %a = bitcast %U* %A to i8*
+  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*), i64 20, i32 16, i1 false)
+  %tmp1 = bitcast i8* %a to i32*
+  %tmp2 = getelementptr i32, i32* %tmp1, i64 %index
+  %tmp3 = load i32, i32* %tmp2
+  store i32 %tmp3, i32* %out
+; CHECK-LABEL: @test10(
+; CHECK-NOT: alloca
+; CHECK-NOT: memcpy
+; CHECK-NOT: addrspacecast
+; Ensure that the load still happens through addrspace(1).
+; CHECK: load i32, i32 addrspace(1)*
+  ret void
+}
+
+; Similar to @test10, but the load from the alloca is done through a different
+; address space.
+define void @test10_addrspacecast(i64 %index, i32* %out) {
+  %A = alloca %U, align 16
+  %a = bitcast %U* %A to i8*
+  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*), i64 20, i32 16, i1 false)
+  %tmp1 = addrspacecast i8* %a to i32 addrspace(2)*
+  %tmp2 = getelementptr i32, i32 addrspace(2)* %tmp1, i64 %index
+  %tmp3 = load i32, i32 addrspace(2)* %tmp2
+  store i32 %tmp3, i32* %out
+; CHECK-LABEL: @test10_addrspacecast(
+; CHECK-NOT: alloca
+; CHECK-NOT: memcpy
+; CHECK-NOT: addrspacecast
+; Ensure that the load still happens through addrspace(1).
+; CHECK: load i32, i32 addrspace(1)*
+  ret void
+}
+
+define void @test11() {
+  %A = alloca %U, align 16
+  %a = bitcast %U* %A to i8*
+  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*), i64 20, i32 16, i1 false)
+  %tmp = addrspacecast i8* %a to i8 addrspace(1)*
+  call void @bar_as1(i8 addrspace(1)* %tmp) readonly
+; CHECK-LABEL: @test11(
+; CHECK-NEXT: call void @bar_as1(i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*))
+  ret void
+}
+
+; Must retain the memcpy, as the source address space doesn't match the
+; function argument.
+define void @test12() {
+  %A = alloca %U, align 16
+  %a = bitcast %U* %A to i8*
+  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*), i64 20, i32 16, i1 false)
+  call void @bar(i8* %a) readonly
+; CHECK-LABEL: @test12(
+; CHECK: llvm.memcpy
+; CHECK: bar
+  ret void
+}
+
+; Similar to @test12, but ensures that we can handle multiple function calls
+; with different address spaces. The memcpy must be retained, as the
+; addrspace(0) argument to @bar doesn't match the source.
+define void @test13() {
+  %A = alloca %U, align 16
+  %a = bitcast %U* %A to i8*
+  call void @llvm.memcpy.p0i8.p1i8.i64(i8* %a, i8 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i8 addrspace(1)*), i64 20, i32 16, i1 false)
+  %tmp = addrspacecast i8* %a to i8 addrspace(1)*
+  call void @bar_as1(i8 addrspace(1)* %tmp) readonly
+  call void @bar(i8* %a) readonly
+; CHECK-LABEL: @test13(
+; CHECK: llvm.memcpy
+; CHECK: bar_as1
+; CHECK: bar
+  ret void
+}
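For reference, the output I'd expect for @test10 after this patch is roughly the following (illustrative; value names and the exact constant-expression spelling may differ once instcombine folds the intermediate bitcast):

  define void @test10(i64 %index, i32* %out) {
    %1 = getelementptr i32, i32 addrspace(1)* bitcast ([2 x %U] addrspace(1)* @I to i32 addrspace(1)*), i64 %index
    %2 = load i32, i32 addrspace(1)* %1
    store i32 %2, i32* %out
    ret void
  }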