diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -137,7 +137,8 @@
     ///
     /// Returns zero when called on a CodeExtractor instance where isEligible
-    /// returns false.
-    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
+    /// returns false. Values in \p ExcludeArgsFromAggregate stay scalar args.
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                const ValueSet &ExcludeArgsFromAggregate = {});
 
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
@@ -212,11 +213,11 @@
     void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
-    Function *constructFunction(const ValueSet &inputs,
-                                const ValueSet &outputs,
-                                BasicBlock *header,
-                                BasicBlock *newRootNode, BasicBlock *newHeader,
-                                Function *oldFunction, Module *M);
+    Function *constructFunction(const ValueSet &inputs, const ValueSet &outputs,
+                                const ValueSet &ExcludeArgsFromAggregate,
+                                BasicBlock *header, BasicBlock *newRootNode,
+                                BasicBlock *newHeader, Function *oldFunction,
+                                Module *M);
 
     void moveCodeToFunction(Function *newFunction);
 
@@ -225,9 +226,10 @@
         DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
         BranchProbabilityInfo *BPI);
 
-    CallInst *emitCallAndSwitchStatement(Function *newFunction,
-                                         BasicBlock *newHeader,
-                                         ValueSet &inputs, ValueSet &outputs);
+    CallInst *
+    emitCallAndSwitchStatement(Function *newFunction, BasicBlock *newHeader,
+                               ValueSet &inputs, ValueSet &outputs,
+                               const ValueSet &ExcludeArgsFromAggregate);
   };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -811,13 +811,11 @@
 
 /// constructFunction - make a function based on inputs and outputs, as follows:
-/// f(in0, ..., inN, out0, ..., outN)
+/// f(scalar_in0, ..., scalar_out0, ..., [pointer to aggregate struct])
-Function *CodeExtractor::constructFunction(const ValueSet &inputs,
-                                           const ValueSet &outputs,
-                                           BasicBlock *header,
-                                           BasicBlock *newRootNode,
-                                           BasicBlock *newHeader,
-                                           Function *oldFunction,
-                                           Module *M) {
+Function *CodeExtractor::constructFunction(
+    const ValueSet &inputs, const ValueSet &outputs,
+    const ValueSet &ExcludeArgsFromAggregate, BasicBlock *header,
+    BasicBlock *newRootNode, BasicBlock *newHeader, Function *oldFunction,
+    Module *M) {
   LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
   LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
 
@@ -830,20 +828,40 @@
   }
 
   std::vector<Type *> paramTy;
+  std::vector<Type *> ScalarParamTy;
+  std::vector<Type *> AggParamTy;
+  ValueSet StructValues;
 
   // Add the types of the input values to the function's argument list
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
-    paramTy.push_back(value->getType());
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+      AggParamTy.push_back(value->getType());
+      StructValues.insert(value);
+    } else
+      ScalarParamTy.push_back(value->getType());
   }
 
   // Add the types of the output values to the function's argument list.
   for (Value *output : outputs) {
     LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
-    if (AggregateArgs)
-      paramTy.push_back(output->getType());
-    else
-      paramTy.push_back(PointerType::getUnqual(output->getType()));
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      AggParamTy.push_back(output->getType());
+      StructValues.insert(output);
+    } else
+      ScalarParamTy.push_back(PointerType::getUnqual(output->getType()));
+  }
+
+  assert(
+      (ScalarParamTy.size() + AggParamTy.size()) ==
+          (inputs.size() + outputs.size()) &&
+      "Number of scalar and aggregate params does not match inputs, outputs");
+
+  paramTy = ScalarParamTy;
+  StructType *StructTy = nullptr;
+  if (AggregateArgs && !AggParamTy.empty()) {
+    StructTy = StructType::get(M->getContext(), AggParamTy);
+    paramTy.push_back(PointerType::getUnqual(StructTy));
   }
 
   LLVM_DEBUG({
@@ -853,12 +871,6 @@
     dbgs() << ")\n";
   });
 
-  StructType *StructTy = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
-    StructTy = StructType::get(M->getContext(), paramTy);
-    paramTy.clear();
-    paramTy.push_back(PointerType::getUnqual(StructTy));
-  }
   FunctionType *funcType =
                   FunctionType::get(RetTy, paramTy,
                                     AllowVarArgs && oldFunction->isVarArg());
@@ -981,24 +993,27 @@
   }
   newFunction->getBasicBlockList().push_back(newRootNode);
 
-  // Create an iterator to name all of the arguments we inserted.
-  Function::arg_iterator AI = newFunction->arg_begin();
+  // Create scalar and aggregate iterators to name all of the arguments we
+  // inserted.
+  Function::arg_iterator ScalarAI = newFunction->arg_begin();
+  Function::arg_iterator AggAI = std::next(ScalarAI, ScalarParamTy.size());
 
   // Rewrite all users of the inputs in the extracted region to use the
   // arguments (or appropriate addressing into struct) instead.
-  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
     Value *RewriteVal;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(inputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
       Instruction *TI = newFunction->begin()->getTerminator();
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
-      RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+          StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+      RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
                                 "loadgep_" + inputs[i]->getName(), TI);
+      ++aggIdx;
     } else
-      RewriteVal = &*AI++;
+      RewriteVal = &*ScalarAI++;
 
     std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
     for (User *use : Users)
@@ -1008,12 +1023,14 @@
   }
 
   // Set names for input and output arguments.
-  if (!AggregateArgs) {
-    AI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
-      AI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
-      AI->setName(outputs[i]->getName()+".out");
+  if (!AggregateArgs || !ScalarParamTy.empty()) {
+    ScalarAI = newFunction->arg_begin();
+    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
+      if (!StructValues.contains(inputs[i]))
+        ScalarAI->setName(inputs[i]->getName());
+    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
+      if (!StructValues.contains(outputs[i]))
+        ScalarAI->setName(outputs[i]->getName() + ".out");
   }
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1116,13 +1133,13 @@
 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
 /// the call instruction, splitting any PHI nodes in the header block as
 /// necessary.
-CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
-                                                    BasicBlock *codeReplacer,
-                                                    ValueSet &inputs,
-                                                    ValueSet &outputs) {
+CallInst *CodeExtractor::emitCallAndSwitchStatement(
+    Function *newFunction, BasicBlock *codeReplacer, ValueSet &inputs,
+    ValueSet &outputs, const ValueSet &ExcludeArgsFromAggregate) {
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
-  std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
+  std::vector<Value *> params, ReloadOutputs, Reloads;
+  ValueSet StructValues;
 
   Module *M = newFunction->getParent();
   LLVMContext &Context = M->getContext();
@@ -1130,23 +1147,24 @@
   CallInst *call = nullptr;
 
   // Add inputs as params, or to be filled into the struct
-  unsigned ArgNo = 0;
+  unsigned ScalarInputArgNo = 0;
   SmallVector<unsigned, 1> SwiftErrorArgs;
   for (Value *input : inputs) {
-    if (AggregateArgs)
-      StructValues.push_back(input);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+      StructValues.insert(input);
     else {
       params.push_back(input);
       if (input->isSwiftError())
-        SwiftErrorArgs.push_back(ArgNo);
+        SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo; // Aggregated inputs occupy no scalar argument slot.
     }
-    ++ArgNo;
   }
 
   // Create allocas for the outputs
+  unsigned ScalarOutputArgNo = 0;
   for (Value *output : outputs) {
-    if (AggregateArgs) {
-      StructValues.push_back(output);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      StructValues.insert(output);
     } else {
       AllocaInst *alloca =
         new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@@ -1154,12 +1172,14 @@
                        &codeReplacer->getParent()->front().front());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
+      ++ScalarOutputArgNo;
     }
   }
 
   StructType *StructArgTy = nullptr;
   AllocaInst *Struct = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
+  unsigned NumAggregatedInputs = 0;
+  if (AggregateArgs && !StructValues.empty()) {
     std::vector<Type *> ArgTypes;
     for (ValueSet::iterator v = StructValues.begin(),
            ve = StructValues.end(); v != ve; ++v)
@@ -1172,14 +1192,18 @@
                             &codeReplacer->getParent()->front().front());
     params.push_back(Struct);
 
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
-      codeReplacer->getInstList().push_back(GEP);
-      new StoreInst(StructValues[i], GEP, codeReplacer);
+    // Store aggregated inputs in the struct.
+    for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+      if (inputs.contains(StructValues[i])) {
+        Value *Idx[2];
+        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+        GetElementPtrInst *GEP = GetElementPtrInst::Create(
+            StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+        codeReplacer->getInstList().push_back(GEP);
+        new StoreInst(StructValues[i], GEP, codeReplacer);
+        NumAggregatedInputs++;
+      }
     }
   }
 
@@ -1202,24 +1226,24 @@
     newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
   }
 
-  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
-  unsigned FirstOut = inputs.size();
-  if (!AggregateArgs)
-    std::advance(OutputArgBegin, inputs.size());
-
-  // Reload the outputs passed in by reference.
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  // Reload the outputs passed in by reference, use the struct if output is in
+  // the aggregate or reload from the scalar argument.
+  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+                aggIdx = NumAggregatedInputs;
+       i != e; ++i) {
     Value *Output = nullptr;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
           StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
       codeReplacer->getInstList().push_back(GEP);
       Output = GEP;
+      ++aggIdx;
     } else {
-      Output = ReloadOutputs[i];
+      Output = ReloadOutputs[scalarIdx];
+      ++scalarIdx;
     }
     LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
                                   outputs[i]->getName() + ".reload",
@@ -1289,8 +1313,13 @@
   // Store the arguments right after the definition of output value.
   // This should be proceeded after creating exit stubs to be ensure that invoke
   // result restore will be placed in the outlined function.
-  Function::arg_iterator OAI = OutputArgBegin;
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+  std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+  Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+  std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+  for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+       ++i) {
     auto *OutI = dyn_cast<Instruction>(outputs[i]);
     if (!OutI)
       continue;
@@ -1310,23 +1339,27 @@
     assert((InsertBefore->getFunction() == newFunction ||
             Blocks.count(InsertBefore->getParent())) &&
            "InsertPt should be in new function");
-    assert(OAI != newFunction->arg_end() &&
-           "Number of output arguments should match "
-           "the amount of defined values");
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
+      assert(AggOutputArgBegin != newFunction->arg_end() &&
+             "Number of aggregate output arguments should match "
+             "the number of defined values");
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+          StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
           InsertBefore);
       new StoreInst(outputs[i], GEP, InsertBefore);
+      ++aggIdx;
       // Since there should be only one struct argument aggregating
-      // all the output values, we shouldn't increment OAI, which always
-      // points to the struct argument, in this case.
+      // all the output values, we shouldn't increment AggOutputArgBegin, which
+      // always points to the struct argument, in this case.
     } else {
-      new StoreInst(outputs[i], &*OAI, InsertBefore);
-      ++OAI;
+      assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+             "Number of scalar output arguments should match "
+             "the number of defined values");
+      new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+      ++ScalarOutputArgBegin;
     }
   }
 
@@ -1566,7 +1599,8 @@
 }
 
 Function *
-CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                 const ValueSet &ExcludeArgsFromAggregate) {
   if (!isEligible())
     return nullptr;
 
@@ -1698,9 +1732,9 @@
   eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
-  Function *newFunction =
-      constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
-                        oldFunction, oldFunction->getParent());
+  Function *newFunction = constructFunction(
+      inputs, outputs, ExcludeArgsFromAggregate, header, newFuncRoot,
+      codeReplacer, oldFunction, oldFunction->getParent());
 
   // Update the entry count of the function.
   if (BFI) {
@@ -1711,8 +1745,8 @@
     BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
   }
 
-  CallInst *TheCall =
-      emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
+  CallInst *TheCall = emitCallAndSwitchStatement(
+      newFunction, codeReplacer, inputs, outputs, ExcludeArgsFromAggregate);
 
   moveCodeToFunction(newFunction);
 
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -331,4 +331,63 @@
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
+
+TEST(CodeExtractor, PartialAggregateArgs) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+    target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+    target triple = "x86_64-unknown-linux-gnu"
+
+    declare void @use(i32)
+
+    define void @foo(i32 %a, i32 %b, i32 %c) {
+    entry:
+      br label %extract
+
+    extract:
+      call void @use(i32 %a)
+      call void @use(i32 %b)
+      call void @use(i32 %c)
+      br label %exit
+
+    exit:
+      ret void
+    }
+  )ir",
+                                                Err, Ctx));
+  // A parse failure would make every later dereference crash; stop early.
+  ASSERT_TRUE(M);
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+  // Create the CodeExtractor with arguments aggregation enabled.
+  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+                   /* AggregateArgs */ true);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+  // %a, %b and %c are the only values flowing into the extracted block.
+  ASSERT_EQ(Inputs.size(), 3U);
+
+  SetVector<Value *> ExcludeAggArgs;
+  // Exclude the first input from the argument aggregate.
+  ExcludeAggArgs.insert(Inputs[0]);
+  Function *Outlined = CE.extractCodeRegion(CEAC, ExcludeAggArgs);
+  // Fatal check: the assertions below dereference Outlined.
+  ASSERT_TRUE(Outlined);
+  // Expect 2 arguments in the outlined function: the excluded input and the
+  // struct aggregate for the remaining inputs.
+  ASSERT_EQ(Outlined->arg_size(), 2U);
+  // Scalar arguments come first; the aggregate pointer is the last argument.
+  EXPECT_TRUE(Outlined->getArg(0)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(Outlined->getArg(1)->getType()->isPointerTy());
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
 } // end anonymous namespace