diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1698,7 +1698,22 @@ } /// Extract the alignment of the return value. - MaybeAlign getRetAlign() const { return Attrs.getRetAlignment(); } + /// If CallSiteOnly is true, only the attribute attached to this call site is + /// consulted. Otherwise, if getCalledFunction() is non-null, the callee's + /// attribute is consulted too and the larger of the two values is returned. + MaybeAlign getRetAlign(bool CallSiteOnly = true) const { + MaybeAlign A = Attrs.getRetAlignment(); + auto *Fn = getCalledFunction(); + if (!CallSiteOnly && Fn) { + if (Fn->hasAttribute(AttributeList::ReturnIndex, Attribute::Alignment)) { + A = max(A, Fn->getAttribute(AttributeList::ReturnIndex, + Attribute::Alignment) + .getAlignment() + .getValue()); + } + } + return A; + } /// Extract the alignment for a call or parameter (0=unknown). LLVM_ATTRIBUTE_DEPRECATED(unsigned getParamAlignment(unsigned ArgNo) const, @@ -1709,8 +1724,17 @@ } /// Extract the alignment for a call or parameter (0=unknown). - MaybeAlign getParamAlign(unsigned ArgNo) const { - return Attrs.getParamAlignment(ArgNo); + /// If CallSiteOnly is true, only the attribute attached to this call site is + /// consulted. Otherwise, if getCalledFunction() is non-null, the callee's + /// attribute is consulted too and the larger of the two values is returned. + MaybeAlign getParamAlign(unsigned ArgNo, bool CallSiteOnly = true) const { + MaybeAlign A = Attrs.getParamAlignment(ArgNo); + auto *Fn = getCalledFunction(); + if (!CallSiteOnly && Fn && ArgNo < Fn->arg_size() && + Fn->hasParamAttribute(ArgNo, Attribute::Alignment)) { + A = max(A, Fn->getParamAlign(ArgNo).getValue()); + } + return A; } /// Extract the byval type for a call or parameter. @@ -1727,14 +1751,37 @@ /// Extract the number of dereferenceable bytes for a call or /// parameter (0=unknown).
- uint64_t getDereferenceableBytes(unsigned i) const { - return Attrs.getDereferenceableBytes(i); + /// If CallSiteOnly is true, only the attribute attached to this call site is + /// consulted. Otherwise, if getCalledFunction() is non-null, the callee's + /// attribute is consulted too and the larger of the two values is returned. + uint64_t getDereferenceableBytes(unsigned i, bool CallSiteOnly = true) const { + uint64_t Bytes = Attrs.getDereferenceableBytes(i); + auto *Fn = getCalledFunction(); + if (!CallSiteOnly && Fn) { + bool IdxInRange = i == AttributeList::ReturnIndex || + (i - AttributeList::FirstArgIndex) < Fn->arg_size(); + if (IdxInRange) + Bytes = std::max(Bytes, Fn->getDereferenceableBytes(i)); + } + return Bytes; } /// Extract the number of dereferenceable_or_null bytes for a call or /// parameter (0=unknown). - uint64_t getDereferenceableOrNullBytes(unsigned i) const { - return Attrs.getDereferenceableOrNullBytes(i); + /// If CallSiteOnly is true, only the attribute attached to this call site is + /// consulted. Otherwise, if getCalledFunction() is non-null, the callee's + /// attribute is consulted too and the larger of the two values is returned. + uint64_t getDereferenceableOrNullBytes(unsigned i, + bool CallSiteOnly = true) const { + uint64_t Bytes = Attrs.getDereferenceableOrNullBytes(i); + auto *Fn = getCalledFunction(); + if (!CallSiteOnly && Fn) { + bool IdxInRange = i == AttributeList::ReturnIndex || + (i - AttributeList::FirstArgIndex) < Fn->arg_size(); + if (IdxInRange) + Bytes = std::max(Bytes, Fn->getDereferenceableOrNullBytes(i)); + } + return Bytes; } /// Return true if the return value is known to be not null. diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -452,11 +452,13 @@ /// FIXME: Remove this function once transition to Align is over. /// Use getDestAlign() instead.
unsigned getDestAlignment() const { - if (auto MA = getParamAlign(ARG_DEST)) + if (auto MA = getParamAlign(ARG_DEST, /*CallSiteOnly=*/false)) return MA->value(); return 0; } - MaybeAlign getDestAlign() const { return getParamAlign(ARG_DEST); } + MaybeAlign getDestAlign() const { + return getParamAlign(ARG_DEST, /*CallSiteOnly=*/false); + } /// Set the specified arguments of the instruction. void setDest(Value *Ptr) { @@ -517,13 +519,13 @@ /// FIXME: Remove this function once transition to Align is over. /// Use getSourceAlign() instead. unsigned getSourceAlignment() const { - if (auto MA = BaseCL::getParamAlign(ARG_SOURCE)) + if (auto MA = BaseCL::getParamAlign(ARG_SOURCE, /*CallSiteOnly=*/false)) return MA->value(); return 0; } MaybeAlign getSourceAlign() const { - return BaseCL::getParamAlign(ARG_SOURCE); + return BaseCL::getParamAlign(ARG_SOURCE, /*CallSiteOnly=*/false); } void setSource(Value *Ptr) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4753,9 +4753,7 @@ } else Result = lowerRangeToAssertZExt(DAG, I, Result); - MaybeAlign Alignment = I.getRetAlign(); - if (!Alignment) - Alignment = F->getAttributes().getRetAlignment(); + MaybeAlign Alignment = I.getRetAlign(/*CallSiteOnly=*/false); // Insert `assertalign` node if there's an alignment.
if (InsertAssertAlign && Alignment) { Result = diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -301,9 +301,9 @@ if (hasRetAttr(Attribute::NonNull)) return true; - if (getDereferenceableBytes(AttributeList::ReturnIndex) > 0 && - !NullPointerIsDefined(getCaller(), - getType()->getPointerAddressSpace())) + if (getDereferenceableBytes(AttributeList::ReturnIndex, + /*CallSiteOnly=*/false) > 0 && + !NullPointerIsDefined(getCaller(), getType()->getPointerAddressSpace())) return true; return false; diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp --- a/llvm/lib/IR/Value.cpp +++ b/llvm/lib/IR/Value.cpp @@ -746,10 +746,11 @@ CanBeNull = true; } } else if (const auto *Call = dyn_cast<CallBase>(this)) { - DerefBytes = Call->getDereferenceableBytes(AttributeList::ReturnIndex); + DerefBytes = Call->getDereferenceableBytes(AttributeList::ReturnIndex, + /*CallSiteOnly=*/false); if (DerefBytes == 0) { - DerefBytes = - Call->getDereferenceableOrNullBytes(AttributeList::ReturnIndex); + DerefBytes = Call->getDereferenceableOrNullBytes( + AttributeList::ReturnIndex, /*CallSiteOnly=*/false); CanBeNull = true; } } else if (const LoadInst *LI = dyn_cast<LoadInst>(this)) { diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -14,6 +14,7 @@ AttributesTest.cpp BasicBlockTest.cpp CFGBuilder.cpp + CallBaseTest.cpp ConstantRangeTest.cpp ConstantsTest.cpp DataLayoutTest.cpp diff --git a/llvm/unittests/IR/CallBaseTest.cpp b/llvm/unittests/IR/CallBaseTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/IR/CallBaseTest.cpp @@ -0,0 +1,122 @@ +//===--------------- CallBaseTest.cpp - CallBase Unittests ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C); + if (!Mod) + Err.print("CallBaseTests", errs()); + return Mod; +} + +TEST(CallBase, Call) { + LLVMContext C; + + const char *IR = + "declare align 4 dereferenceable(100) dereferenceable_or_null(200) " + " i8* @g(i8* align(16) " + " dereferenceable(960) " + " dereferenceable_or_null(9600) %p)\n" + "define void @f1() {\n" + " call align 2 dereferenceable(60) dereferenceable_or_null(80) " + " i8* @g(i8* align(8) " + " dereferenceable(480) " + " dereferenceable_or_null(4800) null)\n" + " ret void\n" + "}\n" + "define void @f2() {\n" + " call align 8 dereferenceable(112) dereferenceable_or_null(208) " + " i8* @g(i8* align(32) " + " dereferenceable(1920) " + " dereferenceable_or_null(19200) null)\n" + " ret void\n" + "}\n"; + + std::unique_ptr<Module> M = parseIR(C, IR); + ASSERT_TRUE(M); + + auto RetIdx = AttributeList::ReturnIndex; + auto FstIdx = AttributeList::FirstArgIndex; + + { + Function *F = M->getFunction("f1"); + ASSERT_NE(F, nullptr); + const CallBase *I = cast<CallBase>(&*F->begin()->begin()); + MaybeAlign A; + + // I->getXX(_, true) < I->getXX(_, false) because the callsite attributes + // in f1 have smaller values than the attributes in the declaration of g.
+ + ASSERT_EQ(I->getDereferenceableBytes(RetIdx, true), 60u); + ASSERT_EQ(I->getDereferenceableBytes(RetIdx, false), 100u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(RetIdx, true), 80u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(RetIdx, false), 200u); + + A = I->getRetAlign(true); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 2u); + A = I->getRetAlign(false); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 4u); + + ASSERT_EQ(I->getDereferenceableBytes(FstIdx, true), 480u); + ASSERT_EQ(I->getDereferenceableBytes(FstIdx, false), 960u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(FstIdx, true), 4800u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(FstIdx, false), 9600u); + + A = I->getParamAlign(0, true); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 8u); + A = I->getParamAlign(0, false); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 16u); + } + + { + Function *F = M->getFunction("f2"); + ASSERT_NE(F, nullptr); + const CallBase *I = cast<CallBase>(&*F->begin()->begin()); + MaybeAlign A; + + // I->getXX(_, true) == I->getXX(_, false) because the callsite attributes + // in f2 have larger values than the attributes in the declaration of g.
+ + ASSERT_EQ(I->getDereferenceableBytes(RetIdx, true), 112u); + ASSERT_EQ(I->getDereferenceableBytes(RetIdx, false), 112u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(RetIdx, true), 208u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(RetIdx, false), 208u); + + A = I->getRetAlign(true); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 8u); + A = I->getRetAlign(false); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 8u); + + ASSERT_EQ(I->getDereferenceableBytes(FstIdx, true), 1920u); + ASSERT_EQ(I->getDereferenceableBytes(FstIdx, false), 1920u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(FstIdx, true), 19200u); + ASSERT_EQ(I->getDereferenceableOrNullBytes(FstIdx, false), 19200u); + + A = I->getParamAlign(0, true); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 32u); + A = I->getParamAlign(0, false); + ASSERT_TRUE(A.hasValue()); + ASSERT_EQ(A.getValue().value(), 32u); + } +}