Index: lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp =================================================================== --- lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -272,23 +272,6 @@ /// /// Verified in @i32_add in split-gep.ll bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP); - /// For each array index that is in the form of zext(a), convert it to sext(a) - /// if we can prove zext(a) <= max signed value of typeof(a). We prefer - /// sext(a) to zext(a), because in the special case where x + y >= 0 and - /// (x >= 0 or y >= 0), function CanTraceInto can split sext(x + y), - /// while no such case exists for zext(x + y). - /// - /// Note that - /// zext(x + y) = zext(x) + zext(y) - /// is wrong, e.g., - /// zext i32(UINT_MAX + 1) to i64 != - /// (zext i32 UINT_MAX to i64) + (zext i32 1 to i64) - /// - /// Returns true if the module changes. - /// - /// Verified in @inbounds_zext_add in split-gep.ll and @sum_of_array3 in - /// split-gep-and-gvn.ll - bool convertInBoundsZExtToSExt(GetElementPtrInst *GEP); const DataLayout *DL; }; @@ -613,43 +596,6 @@ return Changed; } -bool -SeparateConstOffsetFromGEP::convertInBoundsZExtToSExt(GetElementPtrInst *GEP) { - if (!GEP->isInBounds()) - return false; - - // TODO: consider alloca - GlobalVariable *UnderlyingObject = - dyn_cast<GlobalVariable>(GEP->getPointerOperand()); - if (UnderlyingObject == nullptr) - return false; - - uint64_t ObjectSize = - DL->getTypeAllocSize(UnderlyingObject->getType()->getElementType()); - gep_type_iterator GTI = gep_type_begin(*GEP); - bool Changed = false; - for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end(); I != E; - ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { - if (ZExtInst *Extended = dyn_cast<ZExtInst>(*I)) { - unsigned SrcBitWidth = - cast<IntegerType>(Extended->getSrcTy())->getBitWidth(); - // For GEP operand zext(a), if a <= max signed value of typeof(a), then - // the sign bit of a is zero and sext(a) = zext(a).
Because the GEP is - // in bounds, we know a <= ObjectSize, so the condition can be reduced - // to ObjectSize <= max signed value of typeof(a). - if (ObjectSize <= - APInt::getSignedMaxValue(SrcBitWidth).getZExtValue()) { - *I = new SExtInst(Extended->getOperand(0), Extended->getType(), - Extended->getName(), GEP); - Changed = true; - } - } - } - } - return Changed; -} - int64_t SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, bool &NeedsExtraction) { @@ -684,9 +630,7 @@ if (GEP->hasAllConstantIndices()) return false; - bool Changed = false; - Changed |= canonicalizeArrayIndicesToPointerSize(GEP); - Changed |= convertInBoundsZExtToSExt(GEP); + bool Changed = canonicalizeArrayIndicesToPointerSize(GEP); bool NeedsExtraction; int64_t AccumulativeByteOffset = accumulateByteOffset(GEP, NeedsExtraction); Index: test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep-and-gvn.ll =================================================================== --- test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep-and-gvn.ll +++ test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep-and-gvn.ll @@ -99,8 +99,17 @@ ; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 32 ; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 33 -; Similar to @sum_of_array3, but extends array indices using zext instead of -; sext. e.g., array[zext(x + 1)][zext(y + 1)]. + +; This function loads +; array[zext(x)][zext(y)] +; array[zext(x)][zext(y +nuw 1)] +; array[zext(x +nuw 1)][zext(y)] +; array[zext(x +nuw 1)][zext(y +nuw 1)]. +; +; This function is similar to @sum_of_array, but it +; 1) extends array indices using zext instead of sext; +; 2) annotates the addition with "nuw"; otherwise, zext(x + 1) => zext(x) + 1 +; may be invalid. 
define void @sum_of_array3(i32 %x, i32 %y, float* nocapture %output) { .preheader: %0 = zext i32 %y to i64 @@ -109,13 +118,13 @@ %3 = addrspacecast float addrspace(3)* %2 to float* %4 = load float* %3, align 4 %5 = fadd float %4, 0.000000e+00 - %6 = add i32 %y, 1 + %6 = add nuw i32 %y, 1 %7 = zext i32 %6 to i64 %8 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %1, i64 %7 %9 = addrspacecast float addrspace(3)* %8 to float* %10 = load float* %9, align 4 %11 = fadd float %5, %10 - %12 = add i32 %x, 1 + %12 = add nuw i32 %x, 1 %13 = zext i32 %12 to i64 %14 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %13, i64 %0 %15 = addrspacecast float addrspace(3)* %14 to float* @@ -139,3 +148,49 @@ ; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 1 ; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 32 ; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 33 + + +; This function loads +; array[zext(x)][zext(y)] +; array[zext(x)][zext(y) + 1] +; array[zext(x) + 1][zext(y)] +; array[zext(x) + 1][zext(y) + 1]. +; +; We expect the generated code to reuse the computation of +; &array[zext(x)][zext(y)]. See the expected IR and PTX for details.
+define void @sum_of_array4(i32 %x, i32 %y, float* nocapture %output) { +.preheader: + %0 = zext i32 %y to i64 + %1 = zext i32 %x to i64 + %2 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %1, i64 %0 + %3 = addrspacecast float addrspace(3)* %2 to float* + %4 = load float* %3, align 4 + %5 = fadd float %4, 0.000000e+00 + %6 = add i64 %0, 1 + %7 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %1, i64 %6 + %8 = addrspacecast float addrspace(3)* %7 to float* + %9 = load float* %8, align 4 + %10 = fadd float %5, %9 + %11 = add i64 %1, 1 + %12 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %11, i64 %0 + %13 = addrspacecast float addrspace(3)* %12 to float* + %14 = load float* %13, align 4 + %15 = fadd float %10, %14 + %16 = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %11, i64 %6 + %17 = addrspacecast float addrspace(3)* %16 to float* + %18 = load float* %17, align 4 + %19 = fadd float %15, %18 + store float %19, float* %output, align 4 + ret void +} +; PTX-LABEL: sum_of_array4( +; PTX: ld.shared.f32 {{%f[0-9]+}}, {{\[}}[[BASE_REG:%(rd|r)[0-9]+]]{{\]}} +; PTX: ld.shared.f32 {{%f[0-9]+}}, {{\[}}[[BASE_REG]]+4{{\]}} +; PTX: ld.shared.f32 {{%f[0-9]+}}, {{\[}}[[BASE_REG]]+128{{\]}} +; PTX: ld.shared.f32 {{%f[0-9]+}}, {{\[}}[[BASE_REG]]+132{{\]}} + +; IR-LABEL: @sum_of_array4( +; IR: [[BASE_PTR:%[a-zA-Z0-9]+]] = getelementptr inbounds [32 x [32 x float]] addrspace(3)* @array, i64 0, i64 %{{[a-zA-Z0-9]+}}, i64 %{{[a-zA-Z0-9]+}} +; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 1 +; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 32 +; IR: getelementptr float addrspace(3)* [[BASE_PTR]], i64 33 Index: test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll =================================================================== --- test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll +++ 
test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll @@ -234,28 +234,3 @@ ; CHECK-LABEL: @and( ; CHECK: getelementptr [32 x [32 x float]]* @float_2d_array ; CHECK-NOT: getelementptr - -; if zext(a + b) <= max signed value of typeof(a + b), then we can prove -; a + b >= 0 and zext(a + b) == sext(a + b). If we can prove further a or b is -; non-negative, we have zext(a + b) == sext(a) + sext(b). -define float* @inbounds_zext_add(i32 %i, i4 %j) { -entry: - %0 = add i32 %i, 1 - %1 = zext i32 %0 to i64 - ; Because zext(i + 1) is an index of an in bounds GEP based on - ; float_2d_array, zext(i + 1) <= sizeof(float_2d_array) = 4096. - ; Furthermore, since typeof(i + 1) is i32 and 4096 < 2^31, we are sure the - ; sign bit of i + 1 is 0. This implies zext(i + 1) = sext(i + 1). - %2 = add i4 %j, 2 - %3 = zext i4 %2 to i64 - ; In this case, typeof(j + 2) is i4, so zext(j + 2) <= 4096 does not imply - ; the sign bit of j + 2 is 0. - %p = getelementptr inbounds [32 x [32 x float]]* @float_2d_array, i64 0, i64 %1, i64 %3 - ret float* %p -} -; CHECK-LABEL: @inbounds_zext_add( -; CHECK-NOT: add -; CHECK: add i4 %j, 2 -; CHECK: sext -; CHECK: getelementptr [32 x [32 x float]]* @float_2d_array, i64 0, i64 %{{[a-zA-Z0-9]+}}, i64 %{{[a-zA-Z0-9]+}} -; CHECK: getelementptr float* %{{[a-zA-Z0-9]+}}, i64 32