Index: lib/IR/Function.cpp
===================================================================
--- lib/IR/Function.cpp
+++ lib/IR/Function.cpp
@@ -154,10 +154,8 @@
 /// it in its containing function.
 bool Argument::hasStructRetAttr() const {
   if (!getType()->isPointerTy()) return false;
-  if (this != getParent()->arg_begin())
-    return false; // StructRet param must be first param
   return getParent()->getAttributes().
-    hasAttribute(1, Attribute::StructRet);
+    hasAttribute(getArgNo()+1, Attribute::StructRet); // Indices are 1-based.
 }
 
 /// hasReturnedAttr - Return true if this argument has the returned attribute on
Index: lib/Transforms/IPO/ArgumentPromotion.cpp
===================================================================
--- lib/Transforms/IPO/ArgumentPromotion.cpp
+++ lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -84,7 +84,8 @@
     bool isSafeToPromoteArgument(Argument *Arg, bool isByVal) const;
     CallGraphNode *DoPromotion(Function *F,
                                SmallPtrSetImpl &ArgsToPromote,
-                               SmallPtrSetImpl &ByValArgsToTransform);
+                               SmallPtrSetImpl &ByValArgsToTransform,
+                               unsigned StructRetArgIndex);
 
     using llvm::Pass::doInitialization;
     bool doInitialization(CallGraph &CG) override;
@@ -241,10 +242,17 @@
   // add it to ArgsToPromote.
   SmallPtrSet ArgsToPromote;
   SmallPtrSet ByValArgsToTransform;
+  unsigned StructRetArgIndex = 0; // 1-based; stays 0 if no sret argument is found.
   for (unsigned i = 0, e = PointerArgs.size(); i != e; ++i) {
     Argument *PtrArg = PointerArgs[i];
     Type *AgTy = cast(PtrArg->getType())->getElementType();
 
+    // Record the 1-based index of the sret argument and skip promoting it.
+    if (PtrArg->hasStructRetAttr()) {
+      StructRetArgIndex = PtrArg->getArgNo() + 1;
+      continue;
+    }
+
     // If this is a byval argument, and if the aggregate type is small, just
     // pass the elements, which is always safe, if the passed value is densely
     // packed or if we can prove the padding bytes are never accessed. This does
@@ -305,7 +313,7 @@
   if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
     return nullptr;
 
-  return DoPromotion(F, ArgsToPromote, ByValArgsToTransform);
+  return DoPromotion(F, ArgsToPromote, ByValArgsToTransform, StructRetArgIndex);
 }
 
 /// AllCallersPassInValidPointerForArgument - Return true if we can prove that
@@ -579,7 +587,8 @@
 /// safe to do so.
 CallGraphNode *ArgPromotion::DoPromotion(Function *F,
                              SmallPtrSetImpl &ArgsToPromote,
-                             SmallPtrSetImpl &ByValArgsToTransform) {
+                             SmallPtrSetImpl &ByValArgsToTransform,
+                             unsigned StructRetArgIndex) {
 
   // Start by computing a new prototype for the function, which is the same as
   // the old function, but has modified arguments.
@@ -616,9 +625,24 @@
                                               PAL.getRetAttributes()));
 
   // First, determine the new argument list
+
+  if (StructRetArgIndex) {
+    // Add the sret argument first now in case the first argument is being
+    // promoted to more than one argument, which would otherwise violate the IR
+    // constraint that 'sret' may only be applied to the first or second
+    // argument.
+    Params.push_back(F->getFunctionType()->getParamType(StructRetArgIndex - 1));
+    AttributeSet attrs = PAL.getParamAttributes(StructRetArgIndex);
+    AttrBuilder B(attrs, StructRetArgIndex);
+    AttributesVec.push_back(
+        AttributeSet::get(F->getContext(), Params.size(), B));
+  }
+
   unsigned ArgIndex = 1;
   for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
        ++I, ++ArgIndex) {
+    if (ArgIndex == StructRetArgIndex)
+      continue; // Already added as the first parameter above.
     if (ByValArgsToTransform.count(I)) {
       // Simple byval argument? Just add all the struct element types.
       Type *AgTy = cast(I->getType())->getElementType();
@@ -751,12 +775,22 @@
       AttributesVec.push_back(AttributeSet::get(F->getContext(),
                                                 CallPAL.getRetAttributes()));
 
+    // Pass the sret operand first, matching the new parameter order.
+    if (StructRetArgIndex) {
+      Args.push_back(CS.getArgument(StructRetArgIndex - 1));
+      AttrBuilder B(CallPAL, StructRetArgIndex);
+      AttributesVec.
+        push_back(AttributeSet::get(F->getContext(), Args.size(), B));
+    }
+
     // Loop over the operands, inserting GEP and loads in the caller as
     // appropriate.
     CallSite::arg_iterator AI = CS.arg_begin();
     ArgIndex = 1;
     for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
-         I != E; ++I, ++AI, ++ArgIndex)
+         I != E; ++I, ++AI, ++ArgIndex) {
+      if (ArgIndex == StructRetArgIndex)
+        continue; // Already passed as the first operand above.
       if (!ArgsToPromote.count(I) && !ByValArgsToTransform.count(I)) {
         Args.push_back(*AI);          // Unmodified argument
 
@@ -822,6 +856,7 @@
           AA.copyValue(OrigLoad, Args.back());
         }
       }
+    } // End of the loop over the callee's formal arguments.
 
     // Push any varargs arguments on the list.
     for (; AI != CS.arg_end(); ++AI, ++ArgIndex) {
@@ -881,10 +916,21 @@
   NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
 
   // Loop over the argument list, transferring uses of the old arguments over to
-  // the new arguments, also transferring over the names as well.
-  //
-  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
-       I2 = NF->arg_begin(); I != E; ++I) {
+  // the new arguments and moving the names across; the sret argument, which is
+  // now the first new argument, is handled before the loop.
+  Function::arg_iterator I2 = NF->arg_begin();
+  if (StructRetArgIndex) {
+    Function::arg_iterator I = F->arg_begin();
+    std::advance(I, StructRetArgIndex - 1); // The index is 1-based.
+    I->replaceAllUsesWith(I2);
+    I2->takeName(I);
+    AA.replaceWithNewValue(I, I2);
+    ++I2; // The remaining new arguments follow the sret argument.
+  }
+
+  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
+    if (I->hasStructRetAttr())
+      continue; // Already transferred above.
     if (!ArgsToPromote.count(I) && !ByValArgsToTransform.count(I)) {
       // If this is an unmodified argument, move the name and users over to the
       // new version.
Index: test/Transforms/ArgumentPromotion/sret.ll
===================================================================
--- /dev/null
+++ test/Transforms/ArgumentPromotion/sret.ll
@@ -0,0 +1,33 @@
+; RUN: opt < %s -argpromotion -S | FileCheck %s
+; Promoting %this into two i32 arguments must leave the sret pointer first.
+target datalayout = "e-m:w-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-pc-windows-msvc"
+; @add only loads the two fields of %this and stores their sum through %r.
+; CHECK: define internal void @add(i32* sret %[[SR:.*]], i32 %[[THIS1:.*]], i32 %[[THIS2:.*]])
+define internal void @add({i32, i32}* %this, i32* sret %r) {
+  %ap = getelementptr {i32, i32}, {i32, i32}* %this, i32 0, i32 0
+  %bp = getelementptr {i32, i32}, {i32, i32}* %this, i32 0, i32 1
+  %a = load i32, i32* %ap
+  %b = load i32, i32* %bp
+  ; CHECK: %[[AB:.*]] = add i32 %[[THIS1]], %[[THIS2]]
+  %ab = add i32 %a, %b
+  ; CHECK: store i32 %[[AB]], i32* %[[SR]]
+  store i32 %ab, i32* %r
+  ret void
+}
+; The caller is expected to load both fields and pass the sret pointer first.
+; CHECK: define void @f()
+define void @f() {
+  ; CHECK: %[[R:.*]] = alloca i32
+  ; CHECK: %[[PAIR:.*]] = alloca { i32, i32 }
+  %r = alloca i32
+  %pair = alloca {i32, i32}
+
+  ; CHECK: %[[PAIR1P:.*]] = getelementptr { i32, i32 }, { i32, i32 }* %[[PAIR]], i64 0, i32 0
+  ; CHECK: %[[PAIR1:.*]] = load i32, i32* %[[PAIR1P]]
+  ; CHECK: %[[PAIR2P:.*]] = getelementptr { i32, i32 }, { i32, i32 }* %[[PAIR]], i64 0, i32 1
+  ; CHECK: %[[PAIR2:.*]] = load i32, i32* %[[PAIR2P]]
+  ; CHECK: call void @add(i32* sret %[[R]], i32 %[[PAIR1]], i32 %[[PAIR2]])
+  call void @add({i32, i32}* %pair, i32* sret %r)
+  ret void
+}