Index: llvm/include/llvm/IR/Attributes.h
===================================================================
--- llvm/include/llvm/IR/Attributes.h
+++ llvm/include/llvm/IR/Attributes.h
@@ -112,6 +112,13 @@
   static Attribute getWithByRefType(LLVMContext &Context, Type *Ty);
   static Attribute getWithPreallocatedType(LLVMContext &Context, Type *Ty);
 
+  /// For a typed attribute, return the equivalent attribute with the type
+  /// changed to \p ReplacementTy.
+  Attribute getWithNewType(LLVMContext &Context, Type *ReplacementTy) {
+    assert(isTypeAttribute() && "this requires a typed attribute");
+    return get(Context, getKindAsEnum(), ReplacementTy);
+  }
+
   static Attribute::AttrKind getAttrKindFromName(StringRef AttrName);
 
   static StringRef getNameFromAttrKind(Attribute::AttrKind AttrKind);
@@ -510,6 +517,18 @@
     return removeAttributes(C, ArgNo + FirstArgIndex);
   }
 
+  /// Replace the type contained by the typed attribute \p Kind at attribute
+  /// index \p Index with \p ReplacementTy, preserving all other attributes.
+  /// \p Kind must already be present at \p Index.
+  LLVM_NODISCARD AttributeList replaceAttributeType(LLVMContext &C,
+                                                    unsigned Index,
+                                                    Attribute::AttrKind Kind,
+                                                    Type *ReplacementTy) const {
+    Attribute Attr = getAttribute(Index, Kind);
+    auto Attrs = removeAttribute(C, Index, Kind);
+    return Attrs.addAttribute(C, Index, Attr.getWithNewType(C, ReplacementTy));
+  }
+
   /// \brief Add the dereferenceable attribute to the attribute set at the given
   /// index. Returns a new list because attribute lists are immutable.
   LLVM_NODISCARD AttributeList addDereferenceableAttr(LLVMContext &C,
Index: llvm/lib/Bitcode/Writer/ValueEnumerator.cpp
===================================================================
--- llvm/lib/Bitcode/Writer/ValueEnumerator.cpp
+++ llvm/lib/Bitcode/Writer/ValueEnumerator.cpp
@@ -973,6 +973,8 @@
     EnumerateValue(&I);
     if (I.hasAttribute(Attribute::ByVal))
       EnumerateType(I.getParamByValType());
+    else if (I.hasAttribute(Attribute::StructRet))
+      EnumerateType(I.getParamStructRetType());
   }
   FirstFuncConstantID = Values.size();
 
Index: llvm/lib/IR/Core.cpp
===================================================================
--- llvm/lib/IR/Core.cpp
+++ llvm/lib/IR/Core.cpp
@@ -146,6 +146,11 @@
     return wrap(Attribute::getWithByValType(Ctx, NULL));
   }
 
+  if (AttrKind == Attribute::AttrKind::StructRet) {
+    // Same as byval: sret is a typed attribute, so provide a NULL type.
+    return wrap(Attribute::getWithStructRetType(Ctx, NULL));
+  }
+
   return wrap(Attribute::get(Ctx, AttrKind, Val));
 }
 
Index: llvm/lib/Linker/IRMover.cpp
===================================================================
--- llvm/lib/Linker/IRMover.cpp
+++ llvm/lib/Linker/IRMover.cpp
@@ -640,14 +640,14 @@
 
 AttributeList IRLinker::mapAttributeTypes(LLVMContext &C, AttributeList Attrs) {
   for (unsigned i = 0; i < Attrs.getNumAttrSets(); ++i) {
-    if (Attrs.hasAttribute(i, Attribute::ByVal)) {
-      Type *Ty = Attrs.getAttribute(i, Attribute::ByVal).getValueAsType();
-      if (!Ty)
-        continue;
-
-      Attrs = Attrs.removeAttribute(C, i, Attribute::ByVal);
-      Attrs = Attrs.addAttribute(
-          C, i, Attribute::getWithByValType(C, TypeMap.get(Ty)));
+    for (Attribute::AttrKind TypedAttr :
+         {Attribute::ByVal, Attribute::StructRet}) {
+      if (Attrs.hasAttribute(i, TypedAttr)) {
+        if (Type *Ty = Attrs.getAttribute(i, TypedAttr).getValueAsType()) {
+          Attrs = Attrs.replaceAttributeType(C, i, TypedAttr, TypeMap.get(Ty));
+          break;
+        }
+      }
     }
   }
   return Attrs;
Index: llvm/lib/Transforms/Utils/ValueMapper.cpp
===================================================================
--- llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -900,14 +900,13 @@
     LLVMContext &C = CB->getContext();
     AttributeList Attrs = CB->getAttributes();
     for (unsigned i = 0; i < Attrs.getNumAttrSets(); ++i) {
-      if (Attrs.hasAttribute(i, Attribute::ByVal)) {
-        Type *Ty = Attrs.getAttribute(i, Attribute::ByVal).getValueAsType();
-        if (!Ty)
-          continue;
-
-        Attrs = Attrs.removeAttribute(C, i, Attribute::ByVal);
-        Attrs = Attrs.addAttribute(
-            C, i, Attribute::getWithByValType(C, TypeMapper->remapType(Ty)));
+      for (Attribute::AttrKind TypedAttr :
+           {Attribute::ByVal, Attribute::StructRet}) {
+        if (Type *Ty = Attrs.getAttribute(i, TypedAttr).getValueAsType()) {
+          Attrs = Attrs.replaceAttributeType(C, i, TypedAttr,
+                                             TypeMapper->remapType(Ty));
+          break;
+        }
       }
     }
     CB->setAttributes(Attrs);
Index: llvm/test/Linker/Inputs/sret-type-input.ll
===================================================================
--- /dev/null
+++ llvm/test/Linker/Inputs/sret-type-input.ll
@@ -0,0 +1,13 @@
+%a = type { i64 }
+%struct = type { i32, i8 }
+
+define void @g(%a* sret(%a)) {
+  ret void
+}
+
+declare void @baz(%struct* sret(%struct))
+
+define void @foo(%struct* sret(%struct) %a) {
+  call void @baz(%struct* sret(%struct) %a)
+  ret void
+}
Index: llvm/test/Linker/sret-types.ll
===================================================================
--- /dev/null
+++ llvm/test/Linker/sret-types.ll
@@ -0,0 +1,25 @@
+; RUN: llvm-link %s %p/Inputs/sret-type-input.ll -S | FileCheck %s
+
+%a = type { i64 }
+%struct = type { i32, i8 }
+
+; CHECK-LABEL: define void @f(%a* sret(%a) %0)
+define void @f(%a* sret(%a)) {
+  ret void
+}
+
+; CHECK-LABEL: define void @bar(
+; CHECK: call void @foo(%struct* sret(%struct) %ptr)
+define void @bar() {
+  %ptr = alloca %struct
+  call void @foo(%struct* sret(%struct) %ptr)
+  ret void
+}
+
+; CHECK-LABEL: define void @g(%a* sret(%a) %0)
+
+; CHECK-LABEL: define void @foo(%struct* sret(%struct) %a)
+; CHECK-NEXT: call void @baz(%struct* sret(%struct) %a)
+declare void @foo(%struct* sret(%struct) %a)
+
+; CHECK: declare void @baz(%struct* sret(%struct))