CodeExtractor: allow excluding values from the aggregate argument struct

extractCodeRegion() takes an optional ExcludeArgsFromAggregate set. With
AggregateArgs enabled, inputs/outputs in that set are passed as separate
scalar arguments; all remaining values are packed into one struct whose
pointer is appended as the last parameter (scalar inputs first, then scalar
outputs, then the struct pointer).

Scalar argument positions are counted over non-aggregated values only, so
argument naming, swifterror attributes and output stores stay aligned with
the extracted function's parameter list when scalars and aggregated values
are mixed.

diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -137,7 +137,8 @@
     ///
     /// Returns zero when called on a CodeExtractor instance where isEligible
     /// returns false.
-    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC);
+    Function *extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                const ValueSet &ExcludeArgsFromAggregate = {});
 
     /// Verify that assumption cache isn't stale after a region is extracted.
     /// Returns true when verifier finds errors. AssumptionCache is passed as
@@ -212,11 +213,11 @@
     void severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock *> &Exits);
     void splitReturnBlocks();
 
-    Function *constructFunction(const ValueSet &inputs,
-                                const ValueSet &outputs,
-                                BasicBlock *header,
-                                BasicBlock *newRootNode, BasicBlock *newHeader,
-                                Function *oldFunction, Module *M);
+    Function *constructFunction(const ValueSet &inputs, const ValueSet &outputs,
+                                const ValueSet &ExcludeArgsFromAggregate,
+                                BasicBlock *header, BasicBlock *newRootNode,
+                                BasicBlock *newHeader, Function *oldFunction,
+                                Module *M);
 
     void moveCodeToFunction(Function *newFunction);
 
@@ -225,9 +226,10 @@
         DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
         BranchProbabilityInfo *BPI);
 
-    CallInst *emitCallAndSwitchStatement(Function *newFunction,
-                                         BasicBlock *newHeader,
-                                         ValueSet &inputs, ValueSet &outputs);
+    CallInst *
+    emitCallAndSwitchStatement(Function *newFunction, BasicBlock *newHeader,
+                               ValueSet &inputs, ValueSet &outputs,
+                               const ValueSet &ExcludeArgsFromAggregate);
   };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -811,13 +811,11 @@
 
 /// constructFunction - make a function based on inputs and outputs, as follows:
 /// f(in0, ..., inN, out0, ..., outN)
-Function *CodeExtractor::constructFunction(const ValueSet &inputs,
-                                           const ValueSet &outputs,
-                                           BasicBlock *header,
-                                           BasicBlock *newRootNode,
-                                           BasicBlock *newHeader,
-                                           Function *oldFunction,
-                                           Module *M) {
+Function *CodeExtractor::constructFunction(
+    const ValueSet &inputs, const ValueSet &outputs,
+    const ValueSet &ExcludeArgsFromAggregate, BasicBlock *header,
+    BasicBlock *newRootNode, BasicBlock *newHeader, Function *oldFunction,
+    Module *M) {
   LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
   LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
 
@@ -830,20 +828,40 @@
   }
 
   std::vector<Type *> paramTy;
+  std::vector<Type *> ScalarParamTy;
+  std::vector<Type *> AggParamTy;
+  ValueSet StructValues;
 
   // Add the types of the input values to the function's argument list
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
-    paramTy.push_back(value->getType());
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+      AggParamTy.push_back(value->getType());
+      StructValues.insert(value);
+    } else
+      ScalarParamTy.push_back(value->getType());
   }
 
   // Add the types of the output values to the function's argument list.
   for (Value *output : outputs) {
     LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
-    if (AggregateArgs)
-      paramTy.push_back(output->getType());
-    else
-      paramTy.push_back(PointerType::getUnqual(output->getType()));
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      AggParamTy.push_back(output->getType());
+      StructValues.insert(output);
+    } else
+      ScalarParamTy.push_back(PointerType::getUnqual(output->getType()));
+  }
+
+  assert(
+      (ScalarParamTy.size() + AggParamTy.size()) ==
+          (inputs.size() + outputs.size()) &&
+      "Number of scalar and aggregate params does not match inputs, outputs");
+
+  paramTy = ScalarParamTy;
+  StructType *StructTy = nullptr;
+  if (AggregateArgs && !AggParamTy.empty()) {
+    StructTy = StructType::get(M->getContext(), AggParamTy);
+    paramTy.push_back(PointerType::getUnqual(StructTy));
   }
 
   LLVM_DEBUG({
@@ -853,12 +871,6 @@
     dbgs() << ")\n";
   });
 
-  StructType *StructTy = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
-    StructTy = StructType::get(M->getContext(), paramTy);
-    paramTy.clear();
-    paramTy.push_back(PointerType::getUnqual(StructTy));
-  }
   FunctionType *funcType =
                   FunctionType::get(RetTy, paramTy,
                                     AllowVarArgs && oldFunction->isVarArg());
@@ -981,24 +993,27 @@
   }
   newFunction->getBasicBlockList().push_back(newRootNode);
 
-  // Create an iterator to name all of the arguments we inserted.
-  Function::arg_iterator AI = newFunction->arg_begin();
+  // Create scalar and aggregate iterators to name all of the arguments we
+  // inserted.
+  Function::arg_iterator ScalarAI = newFunction->arg_begin();
+  Function::arg_iterator AggAI = std::next(ScalarAI, ScalarParamTy.size());
 
   // Rewrite all users of the inputs in the extracted region to use the
   // arguments (or appropriate addressing into struct) instead.
-  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
     Value *RewriteVal;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(inputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
       Instruction *TI = newFunction->begin()->getTerminator();
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
-      RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+          StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+      RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
                                 "loadgep_" + inputs[i]->getName(), TI);
+      ++aggIdx;
     } else
-      RewriteVal = &*AI++;
+      RewriteVal = &*ScalarAI++;
 
     std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
     for (User *use : Users)
@@ -1008,12 +1023,14 @@
   }
 
   // Set names for input and output arguments.
-  if (!AggregateArgs) {
-    AI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
-      AI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
-      AI->setName(outputs[i]->getName()+".out");
+  if (!AggregateArgs || !ScalarParamTy.empty()) {
+    ScalarAI = newFunction->arg_begin(); // Advanced for scalar values only.
+    for (Value *input : inputs)
+      if (!StructValues.contains(input))
+        (ScalarAI++)->setName(input->getName());
+    for (Value *output : outputs)
+      if (!StructValues.contains(output))
+        (ScalarAI++)->setName(output->getName() + ".out");
   }
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1116,13 +1133,13 @@
 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
 /// the call instruction, splitting any PHI nodes in the header block as
 /// necessary.
-CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
-                                                    BasicBlock *codeReplacer,
-                                                    ValueSet &inputs,
-                                                    ValueSet &outputs) {
+CallInst *CodeExtractor::emitCallAndSwitchStatement(
+    Function *newFunction, BasicBlock *codeReplacer, ValueSet &inputs,
+    ValueSet &outputs, const ValueSet &ExcludeArgsFromAggregate) {
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
-  std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
+  std::vector<Value *> params, ReloadOutputs, Reloads;
+  ValueSet StructValues;
 
   Module *M = newFunction->getParent();
   LLVMContext &Context = M->getContext();
@@ -1130,23 +1147,24 @@
   CallInst *call = nullptr;
 
   // Add inputs as params, or to be filled into the struct
-  unsigned ArgNo = 0;
+  unsigned ScalarInputArgNo = 0;
   SmallVector<unsigned, 1> SwiftErrorArgs;
   for (Value *input : inputs) {
-    if (AggregateArgs)
-      StructValues.push_back(input);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+      StructValues.insert(input);
     else {
       params.push_back(input);
       if (input->isSwiftError())
-        SwiftErrorArgs.push_back(ArgNo);
+        SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo; // Counts scalar (non-aggregated) inputs only.
     }
-    ++ArgNo;
   }
 
   // Create allocas for the outputs
+  unsigned ScalarOutputArgNo = 0;
   for (Value *output : outputs) {
-    if (AggregateArgs) {
-      StructValues.push_back(output);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      StructValues.insert(output);
     } else {
       AllocaInst *alloca =
         new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@@ -1154,12 +1172,14 @@
                        &codeReplacer->getParent()->front().front());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
+      ++ScalarOutputArgNo;
     }
   }
 
   StructType *StructArgTy = nullptr;
   AllocaInst *Struct = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
+  unsigned NumAggregatedInputs = 0;
+  if (AggregateArgs && !StructValues.empty()) {
     std::vector<Type *> ArgTypes;
     for (ValueSet::iterator v = StructValues.begin(),
            ve = StructValues.end(); v != ve; ++v)
@@ -1172,14 +1192,18 @@
                             &codeReplacer->getParent()->front().front());
     params.push_back(Struct);
 
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
-      codeReplacer->getInstList().push_back(GEP);
-      new StoreInst(StructValues[i], GEP, codeReplacer);
+    // Store aggregated inputs in the struct.
+    for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+      if (inputs.contains(StructValues[i])) {
+        Value *Idx[2];
+        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+        GetElementPtrInst *GEP = GetElementPtrInst::Create(
+            StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+        codeReplacer->getInstList().push_back(GEP);
+        new StoreInst(StructValues[i], GEP, codeReplacer);
+        NumAggregatedInputs++;
+      }
     }
   }
 
@@ -1202,24 +1226,24 @@
     newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
   }
 
-  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
-  unsigned FirstOut = inputs.size();
-  if (!AggregateArgs)
-    std::advance(OutputArgBegin, inputs.size());
-
-  // Reload the outputs passed in by reference.
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  // Reload the outputs passed in by reference, use the struct if output is in
+  // the aggregate or reload from the scalar argument.
+  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+                aggIdx = NumAggregatedInputs;
+       i != e; ++i) {
     Value *Output = nullptr;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
           StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
       codeReplacer->getInstList().push_back(GEP);
       Output = GEP;
+      ++aggIdx;
     } else {
-      Output = ReloadOutputs[i];
+      Output = ReloadOutputs[scalarIdx];
+      ++scalarIdx;
     }
     LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
                                   outputs[i]->getName() + ".reload",
@@ -1289,8 +1313,13 @@
   // Store the arguments right after the definition of output value.
   // This should be proceeded after creating exit stubs to be ensure that invoke
   // result restore will be placed in the outlined function.
-  Function::arg_iterator OAI = OutputArgBegin;
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+  std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+  Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+  std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+  for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+       ++i) {
     auto *OutI = dyn_cast<Instruction>(outputs[i]);
     if (!OutI)
       continue;
@@ -1310,23 +1339,27 @@
     assert((InsertBefore->getFunction() == newFunction ||
             Blocks.count(InsertBefore->getParent())) &&
            "InsertPt should be in new function");
-    assert(OAI != newFunction->arg_end() &&
-           "Number of output arguments should match "
-           "the amount of defined values");
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
+      assert(AggOutputArgBegin != newFunction->arg_end() &&
+             "Number of aggregate output arguments should match "
+             "the number of defined values");
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+          StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
           InsertBefore);
       new StoreInst(outputs[i], GEP, InsertBefore);
+      ++aggIdx;
       // Since there should be only one struct argument aggregating
-      // all the output values, we shouldn't increment OAI, which always
-      // points to the struct argument, in this case.
+      // all the output values, we shouldn't increment AggOutputArgBegin, which
+      // always points to the struct argument, in this case.
     } else {
-      new StoreInst(outputs[i], &*OAI, InsertBefore);
-      ++OAI;
+      assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+             "Number of scalar output arguments should match "
+             "the number of defined values");
+      new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+      ++ScalarOutputArgBegin;
     }
   }
 
@@ -1566,7 +1599,8 @@
 }
 
 Function *
-CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+                                 const ValueSet &ExcludeArgsFromAggregate) {
   if (!isEligible())
     return nullptr;
 
@@ -1698,9 +1732,9 @@
   eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
-  Function *newFunction =
-      constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
-                        oldFunction, oldFunction->getParent());
+  Function *newFunction = constructFunction(
+      inputs, outputs, ExcludeArgsFromAggregate, header, newFuncRoot,
+      codeReplacer, oldFunction, oldFunction->getParent());
 
   // Update the entry count of the function.
   if (BFI) {
@@ -1711,8 +1745,8 @@
     BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
   }
 
-  CallInst *TheCall =
-      emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
+  CallInst *TheCall = emitCallAndSwitchStatement(
+      newFunction, codeReplacer, inputs, outputs, ExcludeArgsFromAggregate);
 
   moveCodeToFunction(newFunction);
 