diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h --- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h +++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h @@ -168,17 +168,18 @@ /// range with a single element. Constant *getConstant(const ValueLatticeElement &LV) const; + /// Return either a Constant or nullptr for a given Value. + Constant *getConstantOrNull(Value *V) const; + /// Return a reference to the set of argument tracked functions. SmallPtrSetImpl &getArgumentTrackedFunctions(); - /// Mark the constant arguments of a new function specialization. \p F points - /// to the cloned function and \p Args contains a list of constant arguments - /// represented as pairs of {formal,actual} values (the formal argument is - /// associated with the original function definition). All other arguments of - /// the specialization inherit the lattice state of their corresponding values - /// in the original function. - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); + /// Set the Lattice Value for the arguments of a specialization \p F. + /// If an argument is Constant then its lattice value is marked with the + /// corresponding actual argument in \p Args. Otherwise, its lattice value + /// is inherited (copied) from the corresponding formal argument in \p Args. + void setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl &Args); /// Mark all of the blocks in function \p F non-executable. 
Clients can used /// this method to erase a function from the module (e.g., if it has been diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -536,7 +536,7 @@ // Initialize the lattice state of the arguments of the function clone, // marking the argument on which we specialized the function constant // with the given value. - Solver.markArgInFuncSpecialization(Clone, S.Args); + Solver.setLatticeValueForSpecializationArguments(Clone, S.Args); Solver.addArgumentTrackedFunction(Clone); Solver.markBlockExecutable(&Clone->front()); @@ -666,16 +666,9 @@ if (A->user_empty()) return false; - // For now, don't attempt to specialize functions based on the values of - // composite types. - Type *ArgTy = A->getType(); - if (!ArgTy->isSingleValueType()) - return false; - - // Specialization of integer and floating point types needs to be explicitly - // enabled. - if (!SpecializeLiteralConstant && - (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy())) + Type *Ty = A->getType(); + if (!Ty->isPointerTy() && (!SpecializeLiteralConstant || + (!Ty->isIntegerTy() && !Ty->isFloatingPointTy() && !Ty->isStructTy()))) return false; // SCCP solver does not record an argument that will be constructed on @@ -686,21 +679,22 @@ // Check the lattice value and decide if we should attemt to specialize, // based on this argument. No point in specialization, if the lattice value // is already a constant. 
- const ValueLatticeElement &LV = Solver.getLatticeValueFor(A); - if (LV.isUnknownOrUndef() || LV.isConstant() || - (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter " - << A->getNameOrAsOperand() << " is already constant\n"); - return false; - } - - LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter " - << A->getNameOrAsOperand() << "\n"); + bool IsOverdefined = Ty->isStructTy() + ? any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined) + : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A)); - return true; + LLVM_DEBUG( + if (IsOverdefined) + dbgs() << "FnSpecialization: Found interesting parameter " + << A->getNameOrAsOperand() << "\n"; + else + dbgs() << "FnSpecialization: Nothing to do, parameter " + << A->getNameOrAsOperand() << " is already constant\n"; + ); + return IsOverdefined; } -/// Check if the valuy \p V (an actual argument) is a constant or can only +/// Check if the value \p V (an actual argument) is a constant or can only /// have a constant value. Return that constant. Constant *FunctionSpecializer::getCandidateConstant(Value *V) { if (isa(V)) @@ -720,18 +714,8 @@ // Select for possible specialisation values that are constants or // are deduced to be constants or constant ranges with a single element. 
Constant *C = dyn_cast(V); - if (!C) { - const ValueLatticeElement &LV = Solver.getLatticeValueFor(V); - if (LV.isConstant()) - C = LV.getConstant(); - else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) { - assert(V->getType()->isIntegerTy() && "Non-integral constant range"); - C = Constant::getIntegerValue(V->getType(), - *LV.getConstantRange().getSingleElement()); - } else - return nullptr; - } - + if (!C) + C = Solver.getConstantOrNull(V); return C; } diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -73,30 +73,9 @@ } bool SCCPSolver::tryToReplaceWithConstant(Value *V) { - Constant *Const = nullptr; - if (V->getType()->isStructTy()) { - std::vector IVs = getStructLatticeValueFor(V); - if (llvm::any_of(IVs, isOverdefined)) - return false; - std::vector ConstVals; - auto *ST = cast(V->getType()); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - ValueLatticeElement V = IVs[i]; - ConstVals.push_back(SCCPSolver::isConstant(V) - ? getConstant(V) - : UndefValue::get(ST->getElementType(i))); - } - Const = ConstantStruct::get(ST, ConstVals); - } else { - const ValueLatticeElement &IV = getLatticeValueFor(V); - if (isOverdefined(IV)) - return false; - - Const = SCCPSolver::isConstant(IV) ? getConstant(IV) - : UndefValue::get(V->getType()); - } - assert(Const && "Constant is nullptr here!"); - + Constant *Const = getConstantOrNull(V); + if (!Const) + return false; // Replacing `musttail` instructions with constant breaks `musttail` invariant // unless the call itself can be removed. 
// Calls with "clang.arc.attachedcall" implicitly use the return value and @@ -734,12 +713,14 @@ Constant *getConstant(const ValueLatticeElement &LV) const; + Constant *getConstantOrNull(Value *V) const; + SmallPtrSetImpl &getArgumentTrackedFunctions() { return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, - const SmallVectorImpl &Args); + void setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl &Args); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -833,36 +814,71 @@ return nullptr; } -void SCCPInstVisitor::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { +Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector LVs = getStructLatticeValueFor(V); + if (any_of(LVs, SCCPSolver::isOverdefined)) + return nullptr; + std::vector ConstVals; + auto *ST = cast(V->getType()); + for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) { + ValueLatticeElement LV = LVs[I]; + ConstVals.push_back(SCCPSolver::isConstant(LV) + ? getConstant(LV) + : UndefValue::get(ST->getElementType(I))); + } + Const = ConstantStruct::get(ST, ConstVals); + } else { + const ValueLatticeElement &LV = getLatticeValueFor(V); + if (SCCPSolver::isOverdefined(LV)) + return nullptr; + Const = SCCPSolver::isConstant(LV) ? 
getConstant(LV) + : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + return Const; +} + +void SCCPInstVisitor::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl &Args) { assert(!Args.empty() && "Specialization without arguments"); assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && "Functions should have the same number of arguments"); auto Iter = Args.begin(); - Argument *NewArg = F->arg_begin(); - Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); + Function::arg_iterator NewArg = F->arg_begin(); + Function::arg_iterator OldArg = Args[0].Formal->getParent()->arg_begin(); for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { LLVM_DEBUG(dbgs() << "SCCP: Marking argument " << NewArg->getNameOrAsOperand() << "\n"); - if (Iter != Args.end() && OldArg == Iter->Formal) { - // Mark the argument constants in the new function. - markConstant(NewArg, Iter->Actual); + // Mark the argument constants in the new function + // or copy the lattice state over from the old function. + if (Iter != Args.end() && Iter->Formal == &*OldArg) { + if (auto *STy = dyn_cast(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement &NewValue = StructValueState[{&*NewArg, I}]; + NewValue.markConstant(Iter->Actual->getAggregateElement(I)); + } + } else { + ValueState[&*NewArg].markConstant(Iter->Actual); + } ++Iter; - } else if (ValueState.count(OldArg)) { - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. - // - // Note: This previously looked like this: - // ValueState[NewArg] = ValueState[OldArg]; - // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `NewArg`, which will invalidate the reference to - // `OldArg`. Instead, we make sure `NewArg` exists before setting it. 
- auto &NewValue = ValueState[NewArg]; - NewValue = ValueState[OldArg]; - pushToWorkList(NewValue, NewArg); + } else { + // Read the old lattice value into a local copy before indexing the new + // key: operator[] may insert and grow the map, which would invalidate a + // reference obtained from an earlier lookup. + if (auto *STy = dyn_cast(NewArg->getType())) { + for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { + ValueLatticeElement OldValue = StructValueState[{&*OldArg, I}]; + StructValueState[{&*NewArg, I}] = OldValue; + } + } else { + ValueLatticeElement OldValue = ValueState[&*OldArg]; + ValueState[&*NewArg] = OldValue; + } } } } @@ -1945,13 +1961,17 @@ return Visitor->getConstant(LV); } +Constant *SCCPSolver::getConstantOrNull(Value *V) const { + return Visitor->getConstantOrNull(V); +} + SmallPtrSetImpl &SCCPSolver::getArgumentTrackedFunctions() { return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization( - Function *F, const SmallVectorImpl &Args) { - Visitor->markArgInFuncSpecialization(F, Args); +void SCCPSolver::setLatticeValueForSpecializationArguments(Function *F, + const SmallVectorImpl &Args) { + Visitor->setLatticeValueForSpecializationArguments(F, Args); } void SCCPSolver::markFunctionUnreachable(Function *F) { diff --git a/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll @@ -0,0 +1,46 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py + +; RUN: opt -passes="ipsccp" -force-specialization \ +; RUN: -funcspec-for-literal-constant -S < %s | FileCheck %s + +define i32 @foo(i32 %y0, i32 %y1) { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[Y:%.*]] = insertvalue { i32, i32 } undef, i32 [[Y0:%.*]], 0 +; CHECK-NEXT: [[YY:%.*]] = insertvalue { i32, i32 } [[Y]], i32 [[Y1:%.*]], 1 +; CHECK-NEXT: 
[[CALL:%.*]] = tail call i32 @add.1({ i32, i32 } { i32 2, i32 3 }, { i32, i32 } [[YY]]) +; CHECK-NEXT: ret i32 [[CALL]] +; +entry: + %y = insertvalue { i32, i32 } undef, i32 %y0, 0 + %yy = 
insertvalue { i32, i32 } %y, i32 %y1, 1 + %call = tail call i32 @add({i32, i32} {i32 2, i32 3}, {i32, i32} %yy) + ret i32 %call +} + +define i32 @bar(i32 %x0, i32 %x1) { +; CHECK-LABEL: @bar( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[X:%.*]] = insertvalue { i32, i32 } undef, i32 [[X0:%.*]], 0 +; CHECK-NEXT: [[XX:%.*]] = insertvalue { i32, i32 } [[X]], i32 [[X1:%.*]], 1 +; CHECK-NEXT: [[CALL:%.*]] = tail call i32 @add.2({ i32, i32 } [[XX]], { i32, i32 } { i32 3, i32 2 }) +; CHECK-NEXT: ret i32 [[CALL]] +; +entry: + %x = insertvalue { i32, i32 } undef, i32 %x0, 0 + %xx = insertvalue { i32, i32 } %x, i32 %x1, 1 + %call = tail call i32 @add({i32, i32} %xx, {i32, i32} {i32 3, i32 2}) + ret i32 %call +} + +define internal i32 @add({i32, i32} %x, {i32, i32} %y) { +entry: + %x0 = extractvalue {i32, i32} %x, 0 + %y0 = extractvalue {i32, i32} %y, 0 + %add0 = add nsw i32 %x0, %y0 + %x1 = extractvalue {i32, i32} %x, 1 + %y1 = extractvalue {i32, i32} %y, 1 + %add1 = add nsw i32 %x1, %y1 + %mul = mul i32 %add0, %add1 + ret i32 %mul +}