diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -168,7 +168,7 @@ /// /// Based on the blocks used when constructing the code extractor, /// determine whether it is eligible for extraction. - /// + /// /// Checks that varargs handling (with vastart and vaend) is only done in /// the outlined blocks. bool isEligible() const; @@ -214,6 +214,10 @@ /// original block will be added to the outline region. BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock); + /// Exclude a value from aggregate argument passing when extracting a code + /// region, passing it instead as a scalar. + void excludeArgFromAggregate(Value *Arg); + private: struct LifetimeMarkerInfo { bool SinkLifeStart = false; @@ -222,6 +226,8 @@ Instruction *LifeEnd = nullptr; }; + ValueSet ExcludeArgsFromAggregate; + LifetimeMarkerInfo getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC, Instruction *Addr, BasicBlock *ExitBlock) const; diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -830,20 +830,40 @@ } std::vector paramTy; + std::vector ScalarParamTy; + std::vector AggParamTy; + ValueSet StructValues; // Add the types of the input values to the function's argument list for (Value *value : inputs) { LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n"); - paramTy.push_back(value->getType()); + if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) { + AggParamTy.push_back(value->getType()); + StructValues.insert(value); + } else + ScalarParamTy.push_back(value->getType()); } // Add the types of the output values to the function's argument list. 
for (Value *output : outputs) { LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n"); - if (AggregateArgs) - paramTy.push_back(output->getType()); - else - paramTy.push_back(PointerType::getUnqual(output->getType())); + if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) { + AggParamTy.push_back(output->getType()); + StructValues.insert(output); + } else + ScalarParamTy.push_back(PointerType::getUnqual(output->getType())); + } + + assert( + (ScalarParamTy.size() + AggParamTy.size()) == + (inputs.size() + outputs.size()) && + "Number of scalar and aggregate params does not match inputs, outputs"); + + paramTy = ScalarParamTy; + StructType *StructTy = nullptr; + if (AggregateArgs && !AggParamTy.empty()) { + StructTy = StructType::get(M->getContext(), AggParamTy); + paramTy.push_back(PointerType::getUnqual(StructTy)); } LLVM_DEBUG({ @@ -853,12 +873,6 @@ dbgs() << ")\n"; }); - StructType *StructTy = nullptr; - if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { - StructTy = StructType::get(M->getContext(), paramTy); - paramTy.clear(); - paramTy.push_back(PointerType::getUnqual(StructTy)); - } FunctionType *funcType = FunctionType::get(RetTy, paramTy, AllowVarArgs && oldFunction->isVarArg()); @@ -986,24 +1000,27 @@ } newFunction->getBasicBlockList().push_back(newRootNode); - // Create an iterator to name all of the arguments we inserted. - Function::arg_iterator AI = newFunction->arg_begin(); + // Create scalar and aggregate iterators to name all of the arguments we + // inserted. + Function::arg_iterator ScalarAI = newFunction->arg_begin(); + Function::arg_iterator AggAI = std::next(ScalarAI, ScalarParamTy.size()); // Rewrite all users of the inputs in the extracted region to use the // arguments (or appropriate addressing into struct) instead. 
- for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) { Value *RewriteVal; - if (AggregateArgs) { + if (AggregateArgs && StructValues.contains(inputs[i])) { Value *Idx[2]; Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); - Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); + Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); Instruction *TI = newFunction->begin()->getTerminator(); GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI); - RewriteVal = new LoadInst(StructTy->getElementType(i), GEP, + StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI); + RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP, "loadgep_" + inputs[i]->getName(), TI); + ++aggIdx; } else - RewriteVal = &*AI++; + RewriteVal = &*ScalarAI++; std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); for (User *use : Users) @@ -1013,12 +1030,14 @@ } // Set names for input and output arguments. 
-  if (!AggregateArgs) {
-    AI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
-      AI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
-      AI->setName(outputs[i]->getName()+".out");
+  if (!AggregateArgs || !ScalarParamTy.empty()) {
+    ScalarAI = newFunction->arg_begin(); // Bumped per scalar arg only.
+    for (unsigned i = 0, e = inputs.size(); i != e; ++i)
+      if (!StructValues.contains(inputs[i]))
+        (ScalarAI++)->setName(inputs[i]->getName());
+    for (unsigned i = 0, e = outputs.size(); i != e; ++i)
+      if (!StructValues.contains(outputs[i]))
+        (ScalarAI++)->setName(outputs[i]->getName() + ".out");
   }
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1127,7 +1146,8 @@
                                                     ValueSet &outputs) {
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
-  std::vector params, StructValues, ReloadOutputs, Reloads;
+  std::vector params, ReloadOutputs, Reloads;
+  ValueSet StructValues;
 
   Module *M = newFunction->getParent();
   LLVMContext &Context = M->getContext();
@@ -1135,23 +1155,24 @@
   CallInst *call = nullptr;
 
   // Add inputs as params, or to be filled into the struct
-  unsigned ArgNo = 0;
+  unsigned ScalarInputArgNo = 0;
   SmallVector SwiftErrorArgs;
   for (Value *input : inputs) {
-    if (AggregateArgs)
-      StructValues.push_back(input);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+      StructValues.insert(input);
     else {
       params.push_back(input);
       if (input->isSwiftError())
-        SwiftErrorArgs.push_back(ArgNo);
+        SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo; // Count scalar params only, not aggregated ones.
     }
-    ++ArgNo;
   }
 
   // Create allocas for the outputs
+  unsigned ScalarOutputArgNo = 0;
   for (Value *output : outputs) {
-    if (AggregateArgs) {
-      StructValues.push_back(output);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      StructValues.insert(output);
     } else {
       AllocaInst *alloca =
         new
AllocaInst(output->getType(), DL.getAllocaAddrSpace(), @@ -1159,12 +1180,14 @@ &codeReplacer->getParent()->front().front()); ReloadOutputs.push_back(alloca); params.push_back(alloca); + ++ScalarOutputArgNo; } } StructType *StructArgTy = nullptr; AllocaInst *Struct = nullptr; - if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { + unsigned NumAggregatedInputs = 0; + if (AggregateArgs && !StructValues.empty()) { std::vector ArgTypes; for (Value *V : StructValues) ArgTypes.push_back(V->getType()); @@ -1176,14 +1199,18 @@ &codeReplacer->getParent()->front().front()); params.push_back(Struct); - for (unsigned i = 0, e = inputs.size(); i != e; ++i) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); - codeReplacer->getInstList().push_back(GEP); - new StoreInst(StructValues[i], GEP, codeReplacer); + // Store aggregated inputs in the struct. + for (unsigned i = 0, e = StructValues.size(); i != e; ++i) { + if (inputs.contains(StructValues[i])) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); + codeReplacer->getInstList().push_back(GEP); + new StoreInst(StructValues[i], GEP, codeReplacer); + NumAggregatedInputs++; + } } } @@ -1206,24 +1233,24 @@ newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); } - Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); - unsigned FirstOut = inputs.size(); - if (!AggregateArgs) - std::advance(OutputArgBegin, inputs.size()); - - // Reload the outputs passed in by reference. 
- for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + // Reload the outputs passed in by reference, use the struct if output is in + // the aggregate or reload from the scalar argument. + for (unsigned i = 0, e = outputs.size(), scalarIdx = 0, + aggIdx = NumAggregatedInputs; + i != e; ++i) { Value *Output = nullptr; - if (AggregateArgs) { + if (AggregateArgs && StructValues.contains(outputs[i])) { Value *Idx[2]; Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); codeReplacer->getInstList().push_back(GEP); Output = GEP; + ++aggIdx; } else { - Output = ReloadOutputs[i]; + Output = ReloadOutputs[scalarIdx]; + ++scalarIdx; } LoadInst *load = new LoadInst(outputs[i]->getType(), Output, outputs[i]->getName() + ".reload", @@ -1305,8 +1332,13 @@ // Store the arguments right after the definition of output value. // This should be proceeded after creating exit stubs to be ensure that invoke // result restore will be placed in the outlined function. 
- Function::arg_iterator OAI = OutputArgBegin; - for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin(); + std::advance(ScalarOutputArgBegin, ScalarInputArgNo); + Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin(); + std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo); + + for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e; + ++i) { auto *OutI = dyn_cast(outputs[i]); if (!OutI) continue; @@ -1326,23 +1358,27 @@ assert((InsertBefore->getFunction() == newFunction || Blocks.count(InsertBefore->getParent())) && "InsertPt should be in new function"); - assert(OAI != newFunction->arg_end() && - "Number of output arguments should match " - "the amount of defined values"); - if (AggregateArgs) { + if (AggregateArgs && StructValues.contains(outputs[i])) { + assert(AggOutputArgBegin != newFunction->arg_end() && + "Number of aggregate output arguments should match " + "the number of defined values"); Value *Idx[2]; Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), + StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(), InsertBefore); new StoreInst(outputs[i], GEP, InsertBefore); + ++aggIdx; // Since there should be only one struct argument aggregating - // all the output values, we shouldn't increment OAI, which always - // points to the struct argument, in this case. + // all the output values, we shouldn't increment AggOutputArgBegin, which + // always points to the struct argument, in this case. 
} else { - new StoreInst(outputs[i], &*OAI, InsertBefore); - ++OAI; + assert(ScalarOutputArgBegin != newFunction->arg_end() && + "Number of scalar output arguments should match " + "the number of defined values"); + new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore); + ++ScalarOutputArgBegin; } } @@ -1844,3 +1880,7 @@ } return false; } + +void CodeExtractor::excludeArgFromAggregate(Value *Arg) { + ExcludeArgsFromAggregate.insert(Arg); +} diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -188,7 +188,7 @@ EXPECT_TRUE(NextReturn); ConstantInt *CINext = dyn_cast(NextReturn->getReturnValue()); EXPECT_TRUE(CINext->getLimitedValue() == 0u); - + EXPECT_FALSE(verifyFunction(*Outlined)); EXPECT_FALSE(verifyFunction(*Func)); } @@ -245,7 +245,7 @@ EXPECT_TRUE(NextReturn); ConstantInt *CINext = dyn_cast(NextReturn->getReturnValue()); EXPECT_TRUE(CINext->getLimitedValue() == 0u); - + EXPECT_FALSE(verifyFunction(*Outlined)); EXPECT_FALSE(verifyFunction(*Func)); } @@ -504,4 +504,54 @@ EXPECT_FALSE(verifyFunction(*Outlined)); EXPECT_FALSE(verifyFunction(*Func)); } + +TEST(CodeExtractor, PartialAggregateArgs) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"ir( + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" + target triple = "x86_64-unknown-linux-gnu" + + declare void @use(i32) + + define void @foo(i32 %a, i32 %b, i32 %c) { + entry: + br label %extract + + extract: + call void @use(i32 %a) + call void @use(i32 %b) + call void @use(i32 %c) + br label %exit + + exit: + ret void + } + )ir", + Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector Blocks{getBlockByName(Func, "extract")}; + + // Create the CodeExtractor with arguments aggregation enabled. 
+ CodeExtractor CE(Blocks, /* DominatorTree */ nullptr, + /* AggregateArgs */ true); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + SetVector Inputs, Outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + CE.findInputsOutputs(Inputs, Outputs, SinkingCands); + // Exclude the first input from the argument aggregate. + CE.excludeArgFromAggregate(Inputs[0]); + + Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); + EXPECT_TRUE(Outlined); + // Expect 2 arguments in the outlined function: the excluded input and the + // struct aggregate for the remaining inputs. + EXPECT_EQ(Outlined->arg_size(), 2U); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} } // end anonymous namespace