Index: lib/Transforms/Scalar/InferAddressSpaces.cpp =================================================================== --- lib/Transforms/Scalar/InferAddressSpaces.cpp +++ lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -138,7 +138,7 @@ // Tries to infer the specific address space of each address expression in // Postorder. - void inferAddressSpaces(const std::vector &Postorder, + void inferAddressSpaces(ArrayRef Postorder, ValueToAddrSpaceMapTy *InferredAddrSpace) const; bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const; @@ -147,7 +147,7 @@ // address spaces if InferredAddrSpace says so. Postorder is the postorder of // all flat expressions in the use-def graph of function F. bool - rewriteWithNewAddressSpaces(const std::vector &Postorder, + rewriteWithNewAddressSpaces(ArrayRef Postorder, const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const; @@ -162,7 +162,7 @@ std::vector> &PostorderStack, DenseSet &Visited) const; - std::vector collectFlatAddressExpressions(Function &F) const; + std::vector collectFlatAddressExpressions(Function &F) const; Value *cloneValueWithNewAddressSpace( Value *V, unsigned NewAddrSpace, @@ -274,16 +274,36 @@ Value *V, std::vector> &PostorderStack, DenseSet &Visited) const { assert(V->getType()->isPointerTy()); + + // Generic addressing expressions may be hidden in nested constant + // expressions. + if (ConstantExpr *CE = dyn_cast(V)) { + // TODO: Look in non-address parts, like icmp operands. + if (isAddressExpression(*CE) && Visited.insert(CE).second) + PostorderStack.push_back(std::make_pair(CE, false)); + + return; + } + if (isAddressExpression(*V) && V->getType()->getPointerAddressSpace() == FlatAddrSpace) { - if (Visited.insert(V).second) + if (Visited.insert(V).second) { PostorderStack.push_back(std::make_pair(V, false)); + + Operator *Op = cast(V); + for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) { + if (ConstantExpr *CE = dyn_cast(Op->getOperand(I))) { + if (isAddressExpression(*CE) && Visited.insert(CE).second) + PostorderStack.emplace_back(CE, false); + } + } + } } } // Returns all flat address expressions in function F. The elements are ordered // ordered in postorder. -std::vector +std::vector InferAddressSpaces::collectFlatAddressExpressions(Function &F) const { // This function implements a non-recursive postorder traversal of a partial // use-def graph of function F. @@ -332,18 +352,19 @@ } } - std::vector Postorder; // The resultant postorder. + std::vector Postorder; // The resultant postorder. while (!PostorderStack.empty()) { + Value *TopVal = PostorderStack.back().first; // If the operands of the expression on the top are already explored, // adds that expression to the resultant postorder. if (PostorderStack.back().second) { - Postorder.push_back(PostorderStack.back().first); + Postorder.push_back(TopVal); PostorderStack.pop_back(); continue; } // Otherwise, adds its operands to the stack and explores them. PostorderStack.back().second = true; - for (Value *PtrOperand : getPointerOperands(*PostorderStack.back().first)) { + for (Value *PtrOperand : getPointerOperands(*TopVal)) { appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack, Visited); } @@ -562,7 +583,7 @@ return false; // Collects all flat address expressions in postorder. - std::vector Postorder = collectFlatAddressExpressions(F); + std::vector Postorder = collectFlatAddressExpressions(F); // Runs a data-flow analysis to refine the address spaces of every expression // in Postorder. @@ -575,7 +596,7 @@ } void InferAddressSpaces::inferAddressSpaces( - const std::vector &Postorder, + ArrayRef Postorder, ValueToAddrSpaceMapTy *InferredAddrSpace) const { SetVector Worklist(Postorder.begin(), Postorder.end()); // Initially, all expressions are in the uninitialized address space. @@ -787,7 +808,7 @@ } bool InferAddressSpaces::rewriteWithNewAddressSpaces( - const std::vector &Postorder, + ArrayRef Postorder, const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { // For each address expression to be modified, creates a clone of it with its // pointer operands converted to the new address space. Since the pointer @@ -818,7 +839,9 @@ std::vector DeadInstructions; // Replaces the uses of the old address expressions with the new ones. - for (Value *V : Postorder) { + for (const WeakVH &WVH : Postorder) { + assert(WVH && "value was unexpectedly deleted"); + Value *V = WVH; Value *NewV = ValueWithNewAddrSpace.lookup(V); if (NewV == nullptr) continue; @@ -826,6 +849,17 @@ DEBUG(dbgs() << "Replacing the uses of " << *V << "\n with\n " << *NewV << '\n'); + if (Constant *C = dyn_cast(V)) { + Constant *Replace = ConstantExpr::getAddrSpaceCast(cast(NewV), + C->getType()); + if (C != Replace) { + DEBUG(dbgs() << "Inserting replacement const cast: " + << Replace << ": " << *Replace << '\n'); + C->replaceAllUsesWith(Replace); + V = Replace; + } + } + Value::use_iterator I, E, Next; for (I = V->use_begin(), E = V->use_end(); I != E; ) { Use &U = *I; Index: test/Transforms/InferAddressSpaces/AMDGPU/infer-getelementptr.ll =================================================================== --- test/Transforms/InferAddressSpaces/AMDGPU/infer-getelementptr.ll +++ test/Transforms/InferAddressSpaces/AMDGPU/infer-getelementptr.ll @@ -15,9 +15,8 @@ ret void } -; FIXME: Should be able to eliminate inner constantexpr addrspacecast. ; CHECK-LABEL: @constexpr_gep_addrspacecast( -; CHECK: %gep0 = getelementptr inbounds double, double addrspace(3)* addrspacecast (double addrspace(4)* getelementptr ([648 x double], [648 x double] addrspace(4)* addrspacecast ([648 x double] addrspace(3)* @lds to [648 x double] addrspace(4)*), i64 0, i64 384) to double addrspace(3)*), i64 %idx0 +; CHECK-NEXT: %gep0 = getelementptr inbounds double, double addrspace(3)* getelementptr inbounds ([648 x double], [648 x double] addrspace(3)* @lds, i64 0, i64 384), i64 %idx0 ; CHECK-NEXT: store double 1.000000e+00, double addrspace(3)* %gep0, align 8 define void @constexpr_gep_addrspacecast(i64 %idx0, i64 %idx1) { %gep0 = getelementptr inbounds double, double addrspace(4)* getelementptr ([648 x double], [648 x double] addrspace(4)* addrspacecast ([648 x double] addrspace(3)* @lds to [648 x double] addrspace(4)*), i64 0, i64 384), i64 %idx0 @@ -54,3 +53,21 @@ store i32 99, i32 addrspace(4)* %p3 ret void } + +; CHECK-LABEL: @repeated_constexpr_gep_addrspacecast( +; CHECK-NEXT: %gep0 = getelementptr inbounds double, double addrspace(3)* getelementptr inbounds ([648 x double], [648 x double] addrspace(3)* @lds, i64 0, i64 384), i64 %idx0 +; CHECK-NEXT: store double 1.000000e+00, double addrspace(3)* %gep0, align 8 +; CHECK-NEXT: %gep1 = getelementptr inbounds double, double addrspace(3)* getelementptr inbounds ([648 x double], [648 x double] addrspace(3)* @lds, i64 0, i64 384), i64 %idx1 +; CHECK-NEXT: store double 1.000000e+00, double addrspace(3)* %gep1, align 8 +; CHECK-NEXT: ret void +define void @repeated_constexpr_gep_addrspacecast(i64 %idx0, i64 %idx1) { + %gep0 = getelementptr inbounds double, double addrspace(4)* getelementptr ([648 x double], [648 x double] addrspace(4)* addrspacecast ([648 x double] addrspace(3)* @lds to [648 x double] addrspace(4)*), i64 0, i64 384), i64 %idx0 + %asc0 = addrspacecast double addrspace(4)* %gep0 to double addrspace(3)* + store double 1.0, double addrspace(3)* %asc0, align 8 + + %gep1 = getelementptr inbounds double, double addrspace(4)* getelementptr ([648 x double], [648 x double] addrspace(4)* addrspacecast ([648 x double] addrspace(3)* @lds to [648 x double] addrspace(4)*), i64 0, i64 384), i64 %idx1 + %asc1 = addrspacecast double addrspace(4)* %gep1 to double addrspace(3)* + store double 1.0, double addrspace(3)* %asc1, align 8 + + ret void +} Index: test/Transforms/InferAddressSpaces/NVPTX/bug31948.ll =================================================================== --- test/Transforms/InferAddressSpaces/NVPTX/bug31948.ll +++ test/Transforms/InferAddressSpaces/NVPTX/bug31948.ll @@ -10,7 +10,7 @@ ; CHECK: %tmp = load float*, float* addrspace(3)* getelementptr inbounds (%struct.bar, %struct.bar addrspace(3)* @var1, i64 0, i32 1), align 8 ; CHECK: %tmp1 = load float, float* %tmp, align 4 ; CHECK: store float %conv1, float* %tmp, align 4 -; CHECK: store i32 32, i32 addrspace(3)* addrspacecast (i32* bitcast (float** getelementptr (%struct.bar, %struct.bar* addrspacecast (%struct.bar addrspace(3)* @var1 to %struct.bar*), i64 0, i32 1) to i32*) to i32 addrspace(3)*), align 4 +; CHECK: store i32 32, i32 addrspace(3)* bitcast (float* addrspace(3)* getelementptr inbounds (%struct.bar, %struct.bar addrspace(3)* @var1, i64 0, i32 1) to i32 addrspace(3)*), align 4 define void @bug31948(float %a, float* nocapture readnone %x, float* nocapture readnone %y) local_unnamed_addr #0 { entry: %tmp = load float*, float** getelementptr (%struct.bar, %struct.bar* addrspacecast (%struct.bar addrspace(3)* @var1 to %struct.bar*), i64 0, i32 1), align 8