[InstCombine] Use getObjectSize in annotateAnyAllocSite

Compute the dereferenceable / dereferenceable_or_null return attribute of
allocation calls with getObjectSize() instead of separately decoding the
constant size arguments of malloc-, aligned_alloc-, realloc-, calloc- and
strdup/strndup-like calls. The power-of-two alignment annotation for
aligned_alloc-like calls is kept as a separate step after the size logic.

Test change (deref-alloc-fns.ll): the old code returned early whenever its
first or second constant operand was zero, which also skipped
aligned_alloc(0, 1024). That call now receives dereferenceable_or_null(1024).
It still gets no alignment attribute, because 0 is not a power of two.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2571,57 +2571,33 @@
 }
 
 void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) {
-  unsigned NumArgs = Call.arg_size();
-  ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0));
-  ConstantInt *Op1C =
-      (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1));
-  // Bail out if the allocation size is zero (or an invalid alignment of zero
-  // with aligned_alloc).
-  if ((Op0C && Op0C->isNullValue()) || (Op1C && Op1C->isNullValue()))
-    return;
-
-  if (isMallocLikeFn(&Call, TLI) && Op0C) {
+  // Derive the dereferenceable byte count from getObjectSize() rather than
+  // decoding each allocator family's size arguments by hand; zero-sized
+  // results are not annotated.
+  uint64_t Size;
+  ObjectSizeOpts Opts;
+  if (getObjectSize(&Call, Size, DL, TLI, Opts) && Size > 0) {
+    // TODO: should be annotating these nonnull
     if (isOpNewLikeFn(&Call, TLI))
       Call.addRetAttr(Attribute::getWithDereferenceableBytes(
-          Call.getContext(), Op0C->getZExtValue()));
+          Call.getContext(), Size));
     else
       Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-          Call.getContext(), Op0C->getZExtValue()));
-  } else if (isAlignedAllocLikeFn(&Call, TLI)) {
-    if (Op1C)
-      Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-          Call.getContext(), Op1C->getZExtValue()));
-    // Add alignment attribute if alignment is a power of two constant.
-    if (Op0C && Op0C->getValue().ult(llvm::Value::MaximumAlignment) &&
-        isKnownNonZero(Call.getOperand(1), DL, 0, &AC, &Call, &DT)) {
-      uint64_t AlignmentVal = Op0C->getZExtValue();
-      if (llvm::isPowerOf2_64(AlignmentVal)) {
-        Call.removeRetAttr(Attribute::Alignment);
-        Call.addRetAttr(Attribute::getWithAlignment(Call.getContext(),
-                                                    Align(AlignmentVal)));
-      }
-    }
-  } else if (isReallocLikeFn(&Call, TLI) && Op1C) {
-    Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-        Call.getContext(), Op1C->getZExtValue()));
-  } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) {
-    bool Overflow;
-    const APInt &N = Op0C->getValue();
-    APInt Size = N.umul_ov(Op1C->getValue(), Overflow);
-    if (!Overflow)
-      Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-          Call.getContext(), Size.getZExtValue()));
-  } else if (isStrdupLikeFn(&Call, TLI)) {
-    uint64_t Len = GetStringLength(Call.getOperand(0));
-    if (Len) {
-      // strdup
-      if (NumArgs == 1)
-        Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-            Call.getContext(), Len));
-      // strndup
-      else if (NumArgs == 2 && Op1C)
-        Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes(
-            Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1)));
+          Call.getContext(), Size));
+  }
+
+  // Add alignment attribute if alignment is a power of two constant.
+  if (!isAlignedAllocLikeFn(&Call, TLI))
+    return;
+
+  ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0));
+  if (Op0C && Op0C->getValue().ult(llvm::Value::MaximumAlignment) &&
+      isKnownNonZero(Call.getOperand(1), DL, 0, &AC, &Call, &DT)) {
+    uint64_t AlignmentVal = Op0C->getZExtValue();
+    if (llvm::isPowerOf2_64(AlignmentVal)) {
+      Call.removeRetAttr(Attribute::Alignment);
+      Call.addRetAttr(Attribute::getWithAlignment(Call.getContext(),
+                                                  Align(AlignmentVal)));
     }
   }
 }
diff --git a/llvm/test/Transforms/InstCombine/deref-alloc-fns.ll b/llvm/test/Transforms/InstCombine/deref-alloc-fns.ll
--- a/llvm/test/Transforms/InstCombine/deref-alloc-fns.ll
+++ b/llvm/test/Transforms/InstCombine/deref-alloc-fns.ll
@@ -77,7 +77,7 @@
 define noalias i8* @aligned_alloc_dynamic_args(i64 %align, i64 %size) {
 ; CHECK-LABEL: @aligned_alloc_dynamic_args(
 ; CHECK-NEXT:    [[CALL:%.*]] = tail call noalias dereferenceable_or_null(1024) i8* @aligned_alloc(i64 [[ALIGN:%.*]], i64 1024)
-; CHECK-NEXT:    [[CALL_1:%.*]] = tail call noalias i8* @aligned_alloc(i64 0, i64 1024)
+; CHECK-NEXT:    [[CALL_1:%.*]] = tail call noalias dereferenceable_or_null(1024) i8* @aligned_alloc(i64 0, i64 1024)
 ; CHECK-NEXT:    [[CALL_2:%.*]] = tail call noalias i8* @aligned_alloc(i64 32, i64 [[SIZE:%.*]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i8* @foo(i8* [[CALL]], i8* [[CALL_1]], i8* [[CALL_2]])
 ; CHECK-NEXT:    ret i8* [[CALL]]