Index: lib/Transforms/Scalar/InferAddressSpaces.cpp
===================================================================
--- lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -143,6 +143,8 @@
 
   bool handleComplexPtrUse(User &U, Value *OldV, Value *NewV) const;
 
+  bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
+
   // Changes the generic address expressions in function F to point to specific
   // address spaces if InferredAddrSpace says so. Postorder is the postorder of
   // all generic address expressions in the use-def graph of function F.
@@ -307,6 +309,13 @@
       pushPtrOperand(MTI->getRawSource());
     } else if (auto *II = dyn_cast<IntrinsicInst>(&I))
       collectRewritableIntrinsicOperands(II, &PostorderStack, &Visited);
+    else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
+      // FIXME: Handle vectors of pointers
+      if (Cmp->getOperand(0)->getType()->isPointerTy()) {
+        pushPtrOperand(Cmp->getOperand(0));
+        pushPtrOperand(Cmp->getOperand(1));
+      }
+    }
   }
 
   std::vector<Value *> Postorder; // The resultant postorder.
@@ -661,6 +670,29 @@
   return true;
 }
 
+// \p returns true if it is OK to change the address space of constant \p C with
+// a ConstantExpr addrspacecast.
+bool InferAddressSpaces::isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const {
+  if (C->getType()->getPointerAddressSpace() == NewAS)
+    return true;
+
+  if (isa<ConstantPointerNull>(C) || isa<UndefValue>(C))
+    return true;
+
+  if (auto *Op = dyn_cast<ConstantExpr>(C)) {
+    // If we already have a constant addrspacecast, it should be safe to cast it
+    // off.
+    if (Op->getOpcode() == Instruction::AddrSpaceCast)
+      return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS);
+
+    if (Op->getOpcode() == Instruction::IntToPtr &&
+        Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
+      return true;
+  }
+
+  return false;
+}
+
 static Value::use_iterator skipToNextUser(Value::use_iterator I,
                                           Value::use_iterator E) {
   User *CurUser = I->getUser();
@@ -739,15 +771,38 @@
     }
 
     if (isa<Instruction>(CurUser)) {
+      if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
+        // If we can infer that both pointers are in the same addrspace,
+        // transform e.g.
+        //   %cmp = icmp eq float* %p, %q
+        // into
+        //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
+
+        unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+        int SrcIdx = U.getOperandNo();
+        int OtherIdx = (SrcIdx == 0) ? 1 : 0;
+        Value *OtherSrc = Cmp->getOperand(OtherIdx);
+
+        if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
+          if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+            Cmp->setOperand(OtherIdx, OtherNewV);
+            Cmp->setOperand(SrcIdx, NewV);
+            continue;
+          }
+        }
+
+        // Even if the type mismatches, we can cast the constant.
+        if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
+          if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
+            Cmp->setOperand(SrcIdx, NewV);
+            Cmp->setOperand(OtherIdx,
+              ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType()));
+            continue;
+          }
+        }
+      }
+
       // Otherwise, replaces the use with generic(NewV).
-      // TODO: Some optimization opportunities are missed. For example, in
-      //   %0 = icmp eq float* %p, %q
-      // if both p and q are inferred to be shared, we can rewrite %0 as
-      //   %0 = icmp eq float addrspace(3)* %new_p, %new_q
-      // instead of currently
-      //   %generic_p = addrspacecast float addrspace(3)* %new_p to float*
-      //   %generic_q = addrspacecast float addrspace(3)* %new_q to float*
-      //   %0 = icmp eq float* %generic_p, %generic_q
       if (Instruction *I = dyn_cast<Instruction>(V)) {
         BasicBlock::iterator InsertPos = std::next(I->getIterator());
         while (isa<PHINode>(InsertPos))
Index: test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
===================================================================
--- /dev/null
+++ test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
@@ -0,0 +1,142 @@
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -infer-address-spaces %s | FileCheck %s
+
+; CHECK-LABEL: @icmp_flat_cmp_self(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, %group.ptr.0
+define i1 @icmp_flat_cmp_self(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, %cast0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_flat_flat_from_group(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, %group.ptr.1
+define i1 @icmp_flat_flat_from_group(i32 addrspace(3)* %group.ptr.0, i32 addrspace(3)* %group.ptr.1) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cast1 = addrspacecast i32 addrspace(3)* %group.ptr.1 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, %cast1
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_mismatch_flat_from_group_private(
+; CHECK: %1 = addrspacecast i32* %private.ptr.0 to i32 addrspace(4)*
+; CHECK: %2 = addrspacecast i32 addrspace(3)* %group.ptr.1 to i32 addrspace(4)*
+; CHECK: %cmp = icmp eq i32 addrspace(4)* %1, %2
+define i1 @icmp_mismatch_flat_from_group_private(i32* %private.ptr.0, i32 addrspace(3)* %group.ptr.1) #0 {
+  %cast0 = addrspacecast i32* %private.ptr.0 to i32 addrspace(4)*
+  %cast1 = addrspacecast i32 addrspace(3)* %group.ptr.1 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, %cast1
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_flat_group_flat(
+; CHECK: %1 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+; CHECK: %cmp = icmp eq i32 addrspace(4)* %1, %flat.ptr.1
+define i1 @icmp_flat_group_flat(i32 addrspace(3)* %group.ptr.0, i32 addrspace(4)* %flat.ptr.1) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, %flat.ptr.1
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_flat_flat_group(
+; CHECK: %1 = addrspacecast i32 addrspace(3)* %group.ptr.1 to i32 addrspace(4)*
+; CHECK: %cmp = icmp eq i32 addrspace(4)* %flat.ptr.0, %1
+define i1 @icmp_flat_flat_group(i32 addrspace(4)* %flat.ptr.0, i32 addrspace(3)* %group.ptr.1) #0 {
+  %cast1 = addrspacecast i32 addrspace(3)* %group.ptr.1 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %flat.ptr.0, %cast1
+  ret i1 %cmp
+}
+
+; Keeping as cmp addrspace(3)* is better
+; CHECK-LABEL: @icmp_flat_to_group_cmp(
+; CHECK: %cast0 = addrspacecast i32 addrspace(4)* %flat.ptr.0 to i32 addrspace(3)*
+; CHECK: %cast1 = addrspacecast i32 addrspace(4)* %flat.ptr.1 to i32 addrspace(3)*
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %cast0, %cast1
+define i1 @icmp_flat_to_group_cmp(i32 addrspace(4)* %flat.ptr.0, i32 addrspace(4)* %flat.ptr.1) #0 {
+  %cast0 = addrspacecast i32 addrspace(4)* %flat.ptr.0 to i32 addrspace(3)*
+  %cast1 = addrspacecast i32 addrspace(4)* %flat.ptr.1 to i32 addrspace(3)*
+  %cmp = icmp eq i32 addrspace(3)* %cast0, %cast1
+  ret i1 %cmp
+}
+
+; FIXME: Should be able to ask target about how to constant fold the
+; constant cast if this is OK to change if 0 is a valid pointer.
+
+; CHECK-LABEL: @icmp_group_flat_cmp_null(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, addrspacecast (i32 addrspace(4)* null to i32 addrspace(3)*)
+define i1 @icmp_group_flat_cmp_null(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, null
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_group_flat_cmp_constant_inttoptr(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, addrspacecast (i32 addrspace(4)* inttoptr (i64 400 to i32 addrspace(4)*) to i32 addrspace(3)*)
+define i1 @icmp_group_flat_cmp_constant_inttoptr(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, inttoptr (i64 400 to i32 addrspace(4)*)
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_mismatch_flat_group_private_cmp_null(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, addrspacecast (i32* null to i32 addrspace(3)*)
+define i1 @icmp_mismatch_flat_group_private_cmp_null(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, addrspacecast (i32* null to i32 addrspace(4)*)
+  ret i1 %cmp
+}
+
+@lds0 = internal addrspace(3) global i32 0, align 4
+@global0 = internal addrspace(1) global i32 0, align 4
+
+; CHECK-LABEL: @icmp_mismatch_flat_group_global_cmp_gv(
+; CHECK: %1 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+; CHECK: %cmp = icmp eq i32 addrspace(4)* %1, addrspacecast (i32 addrspace(1)* @global0 to i32 addrspace(4)*)
+define i1 @icmp_mismatch_flat_group_global_cmp_gv(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, addrspacecast (i32 addrspace(1)* @global0 to i32 addrspace(4)*)
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_mismatch_group_global_cmp_gv_gv(
+; CHECK: %cmp = icmp eq i32 addrspace(4)* addrspacecast (i32 addrspace(3)* @lds0 to i32 addrspace(4)*), addrspacecast (i32 addrspace(1)* @global0 to i32 addrspace(4)*)
+define i1 @icmp_mismatch_group_global_cmp_gv_gv(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cmp = icmp eq i32 addrspace(4)* addrspacecast (i32 addrspace(3)* @lds0 to i32 addrspace(4)*), addrspacecast (i32 addrspace(1)* @global0 to i32 addrspace(4)*)
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_group_flat_cmp_undef(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* %group.ptr.0, undef
+define i1 @icmp_group_flat_cmp_undef(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* %cast0, undef
+  ret i1 %cmp
+}
+
+; Test non-canonical orders
+; CHECK-LABEL: @icmp_mismatch_flat_group_private_cmp_null_swap(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* addrspacecast (i32* null to i32 addrspace(3)*), %group.ptr.0
+define i1 @icmp_mismatch_flat_group_private_cmp_null_swap(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* addrspacecast (i32* null to i32 addrspace(4)*), %cast0
+  ret i1 %cmp
+}
+
+; CHECK-LABEL: @icmp_group_flat_cmp_undef_swap(
+; CHECK: %cmp = icmp eq i32 addrspace(3)* undef, %group.ptr.0
+define i1 @icmp_group_flat_cmp_undef_swap(i32 addrspace(3)* %group.ptr.0) #0 {
+  %cast0 = addrspacecast i32 addrspace(3)* %group.ptr.0 to i32 addrspace(4)*
+  %cmp = icmp eq i32 addrspace(4)* undef, %cast0
+  ret i1 %cmp
+}
+
+; TODO: Should be handled
+; CHECK-LABEL: @icmp_flat_flat_from_group_vector(
+; CHECK: %cmp = icmp eq <2 x i32 addrspace(4)*> %cast0, %cast1
+define <2 x i1> @icmp_flat_flat_from_group_vector(<2 x i32 addrspace(3)*> %group.ptr.0, <2 x i32 addrspace(3)*> %group.ptr.1) #0 {
+  %cast0 = addrspacecast <2 x i32 addrspace(3)*> %group.ptr.0 to <2 x i32 addrspace(4)*>
+  %cast1 = addrspacecast <2 x i32 addrspace(3)*> %group.ptr.1 to <2 x i32 addrspace(4)*>
+  %cmp = icmp eq <2 x i32 addrspace(4)*> %cast0, %cast1
+  ret <2 x i1> %cmp
+}
+
+attributes #0 = { nounwind }
Index: test/Transforms/InferAddressSpaces/AMDGPU/infer-address-space.ll
===================================================================
--- test/Transforms/InferAddressSpaces/AMDGPU/infer-address-space.ll
+++ test/Transforms/InferAddressSpaces/AMDGPU/infer-address-space.ll
@@ -106,8 +106,7 @@
 
 ; CHECK-LABEL: @loop(
 ; CHECK: %p = bitcast [10 x float] addrspace(3)* @array to float addrspace(3)*
-; CHECK: %0 = addrspacecast float addrspace(3)* %p to float addrspace(4)*
-; CHECK: %end = getelementptr float, float addrspace(4)* %0, i64 10
+; CHECK: %end = getelementptr float, float addrspace(3)* %p, i64 10
 ; CHECK: br label %loop
 
 ; CHECK: loop: ; preds = %loop, %entry
@@ -115,8 +114,8 @@
 ; CHECK: %v = load float, float addrspace(3)* %i
 ; CHECK: call void @use(float %v)
 ; CHECK: %i2 = getelementptr float, float addrspace(3)* %i, i64 1
-; CHECK: %1 = addrspacecast float addrspace(3)* %i2 to float addrspace(4)*
-; CHECK: %exit_cond = icmp eq float addrspace(4)* %1, %end
+; CHECK: %exit_cond = icmp eq float addrspace(3)* %i2, %end
+
 ; CHECK: br i1 %exit_cond, label %exit, label %loop
 define void @loop() #0 {
 entry: