diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1728,14 +1728,24 @@
 
   /// Extract the byval type for a call or parameter.
   Type *getParamByValType(unsigned ArgNo) const {
-    Type *Ty = Attrs.getParamByValType(ArgNo);
-    return Ty ? Ty : getArgOperand(ArgNo)->getType()->getPointerElementType();
+    // Prefer the type recorded in the call site's own attribute list.
+    if (auto *Ty = Attrs.getParamByValType(ArgNo))
+      return Ty;
+    // Otherwise consult the attributes of the called function, if there is one.
+    if (const Function *F = getCalledFunction())
+      return F->getAttributes().getParamByValType(ArgNo);
+    return nullptr;
   }
 
   /// Extract the preallocated type for a call or parameter.
   Type *getParamPreallocatedType(unsigned ArgNo) const {
-    Type *Ty = Attrs.getParamPreallocatedType(ArgNo);
-    return Ty ? Ty : getArgOperand(ArgNo)->getType()->getPointerElementType();
+    // Same lookup order as getParamByValType: call site first, then callee.
+    if (auto *Ty = Attrs.getParamPreallocatedType(ArgNo))
+      return Ty;
+    if (const Function *F = getCalledFunction())
+      return F->getAttributes().getParamPreallocatedType(ArgNo);
+    // No typed attribute is available on either the call site or the callee.
+    return nullptr;
   }
 
   /// Extract the number of dereferenceable bytes for a call or
diff --git a/llvm/unittests/IR/AttributesTest.cpp b/llvm/unittests/IR/AttributesTest.cpp
--- a/llvm/unittests/IR/AttributesTest.cpp
+++ b/llvm/unittests/IR/AttributesTest.cpp
@@ -7,8 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/IR/Attributes.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 using namespace llvm;
 
@@ -252,4 +256,26 @@
   }
 }
 
+TEST(Attributes, MismatchedByval) {
+  const char *IRString = R"IR(
+    declare void @f(i32* byval(i32))
+    define void @g() {
+      call void @f(i32* null)
+      ret void
+    }
+  )IR";
+
+  SMDiagnostic Err;
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
+  ASSERT_TRUE(M);
+
+  auto *G = M->getFunction("g");
+  ASSERT_TRUE(G);
+  // The call in @g has no byval attribute of its own; only the declaration of
+  // @f does, so the query must still yield a type for argument 0.
+  auto *I = cast<CallBase>(&G->getEntryBlock().front());
+  ASSERT_TRUE(I->getParamByValType(0));
+}
+
 } // end anonymous namespace