diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -290,8 +290,7 @@ bool IsSwiftError : 1; bool IsCFGuardTarget : 1; MaybeAlign Alignment = None; - Type *ByValType = nullptr; - Type *PreallocatedType = nullptr; + Type *IndirectType = nullptr; ArgListEntry() : IsSExt(false), IsZExt(false), IsInReg(false), IsSRet(false), diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1728,14 +1728,29 @@ /// Extract the byval type for a call or parameter. Type *getParamByValType(unsigned ArgNo) const { - Type *Ty = Attrs.getParamByValType(ArgNo); - return Ty; + if (auto *Ty = Attrs.getParamByValType(ArgNo)) + return Ty; + if (const Function *F = getCalledFunction()) + return F->getAttributes().getParamByValType(ArgNo); + return nullptr; + } + + /// Extract the inalloca type for a call or parameter. + Type *getParamInAllocaType(unsigned ArgNo) const { + if (auto *Ty = Attrs.getParamInAllocaType(ArgNo)) + return Ty; + if (const Function *F = getCalledFunction()) + return F->getAttributes().getParamInAllocaType(ArgNo); + return nullptr; } /// Extract the preallocated type for a call or parameter. 
Type *getParamPreallocatedType(unsigned ArgNo) const { - Type *Ty = Attrs.getParamPreallocatedType(ArgNo); - return Ty; + if (auto *Ty = Attrs.getParamPreallocatedType(ArgNo)) + return Ty; + if (const Function *F = getCalledFunction()) + return F->getAttributes().getParamPreallocatedType(ArgNo); + return nullptr; } /// Extract the number of dereferenceable bytes for a call or diff --git a/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp b/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp --- a/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/FastISel.cpp @@ -1033,7 +1033,7 @@ for (auto &Arg : CLI.getArgs()) { Type *FinalType = Arg.Ty; if (Arg.IsByVal) - FinalType = cast<PointerType>(Arg.Ty)->getElementType(); + FinalType = Arg.IndirectType; bool NeedsRegBlock = TLI.functionArgumentNeedsConsecutiveRegisters( FinalType, CLI.CallConv, CLI.IsVarArg); @@ -1076,10 +1076,8 @@ } MaybeAlign MemAlign = Arg.Alignment; if (Arg.IsByVal || Arg.IsInAlloca || Arg.IsPreallocated) { - PointerType *Ty = cast<PointerType>(Arg.Ty); - Type *ElementTy = Ty->getElementType(); - unsigned FrameSize = - DL.getTypeAllocSize(Arg.ByValType ? Arg.ByValType : ElementTy); + Type *ElementTy = Arg.IndirectType; + unsigned FrameSize = DL.getTypeAllocSize(ElementTy); // For ByVal, alignment should come from FE. BE will guess if this info // is not there, but there are cases it cannot get right. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -9490,7 +9490,7 @@ // FIXME: Split arguments if CLI.IsPostTypeLegalization Type *FinalType = Args[i].Ty; if (Args[i].IsByVal) - FinalType = cast<PointerType>(Args[i].Ty)->getElementType(); + FinalType = Args[i].IndirectType; bool NeedsRegBlock = functionArgumentNeedsConsecutiveRegisters( FinalType, CLI.CallConv, CLI.IsVarArg); for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues; @@ -9563,11 +9563,9 @@ } Align MemAlign; if (Args[i].IsByVal || Args[i].IsInAlloca || Args[i].IsPreallocated) { - PointerType *Ty = cast<PointerType>(Args[i].Ty); - Type *ElementTy = Ty->getElementType(); + Type *ElementTy = Args[i].IndirectType; - unsigned FrameSize = DL.getTypeAllocSize( - Args[i].ByValType ? Args[i].ByValType : ElementTy); + unsigned FrameSize = DL.getTypeAllocSize(ElementTy); Flags.setByValSize(FrameSize); // info is not there but there are cases it cannot get right. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -119,15 +119,24 @@ IsSwiftAsync = Call->paramHasAttr(ArgIdx, Attribute::SwiftAsync); IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError); Alignment = Call->getParamStackAlign(ArgIdx); - ByValType = nullptr; + assert(IsByVal + IsInAlloca + IsPreallocated <= 1 && + "can't have multiple indirect attributes"); + IndirectType = nullptr; if (IsByVal) { - ByValType = Call->getParamByValType(ArgIdx); + IndirectType = Call->getParamByValType(ArgIdx); + assert(IndirectType && "no byval type?"); if (!Alignment) Alignment = Call->getParamAlign(ArgIdx); } - PreallocatedType = nullptr; - if (IsPreallocated) - PreallocatedType = Call->getParamPreallocatedType(ArgIdx); + if (IsInAlloca) { + IndirectType = Call->getParamInAllocaType(ArgIdx); + assert(IndirectType && "no inalloca type?"); + } + + if (IsPreallocated) { + IndirectType = Call->getParamPreallocatedType(ArgIdx); + assert(IndirectType && "no preallocated type?"); + } } /// Generate a libcall taking the given operands as arguments and returning a diff --git a/llvm/test/CodeGen/Generic/mismatched-byval.ll b/llvm/test/CodeGen/Generic/mismatched-byval.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/Generic/mismatched-byval.ll @@ -0,0 +1,12 @@ +; RUN: llc < %s -o /dev/null + +; Make sure we don't crash if the argument is missing the byval attribute + +declare void @f(i32* byval(i32)) + +define void @g() { + %a = alloca i32 + store i32 2, i32* %a + call void @f(i32* %a) + ret void +}