Make CallBase ABI-type accessors independent of pointee types.

getParamByValType() and getParamPreallocatedType() no longer fall back to
the pointer element type of the argument. They now consult the call site's
attributes first, then the attributes of the directly called function
(getCalledFunction()), and return nullptr if neither provides a type.
Callers must therefore be prepared for a null result, e.g. for indirect
calls whose call site lacks the attribute.

A matching getParamInAllocaType() accessor is added. The new unit test
covers call sites whose byval/preallocated/inalloca attribute is present
only on the callee declaration.

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1728,14 +1728,29 @@
 
   /// Extract the byval type for a call or parameter.
   Type *getParamByValType(unsigned ArgNo) const {
-    Type *Ty = Attrs.getParamByValType(ArgNo);
-    return Ty ? Ty : getArgOperand(ArgNo)->getType()->getPointerElementType();
+    if (auto *Ty = Attrs.getParamByValType(ArgNo))
+      return Ty;
+    if (const Function *F = getCalledFunction())
+      return F->getAttributes().getParamByValType(ArgNo);
+    return nullptr;
   }
 
   /// Extract the preallocated type for a call or parameter.
   Type *getParamPreallocatedType(unsigned ArgNo) const {
-    Type *Ty = Attrs.getParamPreallocatedType(ArgNo);
-    return Ty ? Ty : getArgOperand(ArgNo)->getType()->getPointerElementType();
+    if (auto *Ty = Attrs.getParamPreallocatedType(ArgNo))
+      return Ty;
+    if (const Function *F = getCalledFunction())
+      return F->getAttributes().getParamPreallocatedType(ArgNo);
+    return nullptr;
+  }
+
+  /// Extract the inalloca type for a call or parameter.
+  Type *getParamInAllocaType(unsigned ArgNo) const {
+    if (auto *Ty = Attrs.getParamInAllocaType(ArgNo))
+      return Ty;
+    if (const Function *F = getCalledFunction())
+      return F->getAttributes().getParamInAllocaType(ArgNo);
+    return nullptr;
   }
 
   /// Extract the number of dereferenceable bytes for a call or
diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp
--- a/llvm/unittests/IR/AttributesTest.cpp
+++ b/llvm/unittests/IR/AttributesTest.cpp
@@ -7,8 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -252,4 +256,44 @@
   }
 }
 
+TEST(Attributes, MismatchedABIAttrs) {
+  const char *IRString = R"IR(
+    declare void @f1(i32* byval(i32))
+    define void @g() {
+      call void @f1(i32* null)
+      ret void
+    }
+    declare void @f2(i32* preallocated(i32))
+    define void @h() {
+      call void @f2(i32* null)
+      ret void
+    }
+    declare void @f3(i32* inalloca(i32))
+    define void @i() {
+      call void @f3(i32* null)
+      ret void
+    }
+  )IR";
+
+  SMDiagnostic Err;
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
+  ASSERT_TRUE(M);
+
+  {
+    auto *I = cast<CallBase>(&M->getFunction("g")->getEntryBlock().front());
+    ASSERT_TRUE(I->isByValArgument(0));
+    ASSERT_TRUE(I->getParamByValType(0));
+  }
+  {
+    auto *I = cast<CallBase>(&M->getFunction("h")->getEntryBlock().front());
+    ASSERT_TRUE(I->getParamPreallocatedType(0));
+  }
+  {
+    auto *I = cast<CallBase>(&M->getFunction("i")->getEntryBlock().front());
+    ASSERT_TRUE(I->isInAllocaArgument(0));
+    ASSERT_TRUE(I->getParamInAllocaType(0));
+  }
+}
+
 } // end anonymous namespace