diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -168,6 +168,11 @@
   /// range with a single element.
   Constant *getConstant(const ValueLatticeElement &LV) const;
 
+  /// Return the Constant that the lattice value of \p V resolves to, or
+  /// nullptr if \p V (or, for a struct value, any of its elements) is
+  /// overdefined. Unknown/undef lattice elements become UndefValue.
+  Constant *getConstantOrNull(Value *V) const;
+
   /// Return a reference to the set of argument tracked functions.
   SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions();
 
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -658,6 +658,14 @@
   return TotalCost + Bonus;
 }
 
+/// Return true if a formal parameter of type \p Ty may be specialized on.
+/// Pointers are always allowed; integer, floating-point and struct values
+/// only when EnableSpecializationForLiteralConstant is set.
+static bool isValidArgumentType(Type *Ty) {
+  return Ty->isPointerTy() || (EnableSpecializationForLiteralConstant &&
+         (Ty->isIntegerTy() || Ty->isFloatingPointTy() || Ty->isStructTy()));
+}
+
 /// Determine if it is possible to specialise the function for constant values
 /// of the formal parameter \p A.
 bool FunctionSpecializer::isArgumentInteresting(Argument *A) {
@@ -665,16 +673,8 @@
   if (A->user_empty())
     return false;
 
-  // For now, don't attempt to specialize functions based on the values of
-  // composite types.
   Type *ArgTy = A->getType();
-  if (!ArgTy->isSingleValueType())
-    return false;
-
-  // Specialization of integer and floating point types needs to be explicitly
-  // enabled.
-  if (!EnableSpecializationForLiteralConstant &&
-      (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy()))
+  if (!isValidArgumentType(ArgTy))
     return false;
 
   // SCCP solver does not record an argument that will be constructed on
@@ -685,21 +685,21 @@
   // Check the lattice value and decide if we should attemt to specialize,
   // based on this argument. No point in specialization, if the lattice value
   // is already a constant.
-  const ValueLatticeElement &LV = Solver.getLatticeValueFor(A);
-  if (LV.isUnknownOrUndef() || LV.isConstant() ||
-      (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) {
+  bool IsOverdefined = ArgTy->isStructTy()
+    ? any_of(Solver.getStructLatticeValueFor(A), SCCPSolver::isOverdefined)
+    : SCCPSolver::isOverdefined(Solver.getLatticeValueFor(A));
+
+  if (IsOverdefined)
+    LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
+                      << A->getNameOrAsOperand() << "\n");
+  else
     LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter "
                       << A->getNameOrAsOperand() << " is already constant\n");
-    return false;
-  }
 
-  LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter "
-                    << A->getNameOrAsOperand() << "\n");
-
-  return true;
+  return IsOverdefined;
 }
 
-/// Check if the valuy \p V (an actual argument) is a constant or can only
+/// Check if the value \p V (an actual argument) is a constant or can only
 /// have a constant value. Return that constant.
 Constant *FunctionSpecializer::getCandidateConstant(Value *V) {
   if (isa<PoisonValue>(V))
@@ -719,18 +719,8 @@
   // Select for possible specialisation values that are constants or
   // are deduced to be constants or constant ranges with a single element.
   Constant *C = dyn_cast<Constant>(V);
-  if (!C) {
-    const ValueLatticeElement &LV = Solver.getLatticeValueFor(V);
-    if (LV.isConstant())
-      C = LV.getConstant();
-    else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) {
-      assert(V->getType()->isIntegerTy() && "Non-integral constant range");
-      C = Constant::getIntegerValue(V->getType(),
-                                    *LV.getConstantRange().getSingleElement());
-    } else
-      return nullptr;
-  }
-
+  if (!C)
+    C = Solver.getConstantOrNull(V);
   return C;
 }
 
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -73,30 +73,9 @@
 }
 
 bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
-  Constant *Const = nullptr;
-  if (V->getType()->isStructTy()) {
-    std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V);
-    if (llvm::any_of(IVs, isOverdefined))
-      return false;
-    std::vector<Constant *> ConstVals;
-    auto *ST = cast<StructType>(V->getType());
-    for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
-      ValueLatticeElement V = IVs[i];
-      ConstVals.push_back(SCCPSolver::isConstant(V)
-                              ? getConstant(V)
-                              : UndefValue::get(ST->getElementType(i)));
-    }
-    Const = ConstantStruct::get(ST, ConstVals);
-  } else {
-    const ValueLatticeElement &IV = getLatticeValueFor(V);
-    if (isOverdefined(IV))
-      return false;
-
-    Const = SCCPSolver::isConstant(IV) ? getConstant(IV)
-                                       : UndefValue::get(V->getType());
-  }
-  assert(Const && "Constant is nullptr here!");
-
+  Constant *Const = getConstantOrNull(V);
+  if (!Const)
+    return false;
   // Replacing `musttail` instructions with constant breaks `musttail` invariant
   // unless the call itself can be removed.
   // Calls with "clang.arc.attachedcall" implicitly use the return value and
@@ -734,6 +713,8 @@
 
   Constant *getConstant(const ValueLatticeElement &LV) const;
 
+  Constant *getConstantOrNull(Value *V) const;
+
   SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() {
     return TrackingIncomingArguments;
   }
@@ -833,6 +814,32 @@
   return nullptr;
 }
 
+Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
+  Constant *Const = nullptr;
+  if (V->getType()->isStructTy()) {
+    std::vector<ValueLatticeElement> LVs = getStructLatticeValueFor(V);
+    if (any_of(LVs, SCCPSolver::isOverdefined))
+      return nullptr;
+    std::vector<Constant *> ConstVals;
+    auto *ST = cast<StructType>(V->getType());
+    for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
+      const ValueLatticeElement &LV = LVs[I];
+      ConstVals.push_back(SCCPSolver::isConstant(LV)
+                              ? getConstant(LV)
+                              : UndefValue::get(ST->getElementType(I)));
+    }
+    Const = ConstantStruct::get(ST, ConstVals);
+  } else {
+    const ValueLatticeElement &LV = getLatticeValueFor(V);
+    if (SCCPSolver::isOverdefined(LV))
+      return nullptr;
+    Const = SCCPSolver::isConstant(LV) ? getConstant(LV)
+                                       : UndefValue::get(V->getType());
+  }
+  assert(Const && "Constant is nullptr here!");
+  return Const;
+}
+
 void SCCPInstVisitor::markArgInFuncSpecialization(
     Function *F, const SmallVectorImpl<ArgInfo> &Args) {
   assert(!Args.empty() && "Specialization without arguments");
@@ -847,22 +854,25 @@
     LLVM_DEBUG(dbgs() << "SCCP: Marking argument "
                       << NewArg->getNameOrAsOperand() << "\n");
 
-    if (Iter != Args.end() && OldArg == Iter->Formal) {
-      // Mark the argument constants in the new function.
-      markConstant(NewArg, Iter->Actual);
-      ++Iter;
-    } else if (ValueState.count(OldArg)) {
-      // For the remaining arguments in the new function, copy the lattice state
-      // over from the old function.
-      //
-      // Note: This previously looked like this:
-      // ValueState[NewArg] = ValueState[OldArg];
-      // This is incorrect because the DenseMap class may resize the underlying
-      // memory when inserting `NewArg`, which will invalidate the reference to
-      // `OldArg`. Instead, we make sure `NewArg` exists before setting it.
-      auto &NewValue = ValueState[NewArg];
-      NewValue = ValueState[OldArg];
-      pushToWorkList(NewValue, NewArg);
+    // Mark the argument constants in the new function
+    // or copy the lattice state over from the old function.
+    Value *V = OldArg;
+    if (Iter != Args.end() && OldArg == Iter->Formal) {
+      V = Iter->Actual;
+      ++Iter;
+    }
+
+    // Copy the lattice value of V before looking up NewArg: inserting NewArg
+    // may resize the underlying DenseMap and invalidate any reference into
+    // it, including a reference to the state of V.
+    if (auto *STy = dyn_cast<StructType>(NewArg->getType())) {
+      for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) {
+        ValueLatticeElement LV = getStructValueState(V, I);
+        mergeInValue(getStructValueState(NewArg, I), NewArg, LV);
+      }
+    } else {
+      ValueLatticeElement LV = getValueState(V);
+      mergeInValue(getValueState(NewArg), NewArg, LV);
     }
   }
 }
@@ -1945,6 +1955,10 @@
   return Visitor->getConstant(LV);
 }
 
+Constant *SCCPSolver::getConstantOrNull(Value *V) const {
+  return Visitor->getConstantOrNull(V);
+}
+
 SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() {
   return Visitor->getArgumentTrackedFunctions();
 }
diff --git a/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/constant-struct.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+
+; RUN: opt -passes="ipsccp" -force-function-specialization \
+; RUN:     -function-specialization-for-literal-constant -S < %s | FileCheck %s
+
+define i32 @foo(i32 %y0, i32 %y1) {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[Y:%.*]] = insertvalue { i32, i32 } undef, i32 [[Y0:%.*]], 0
+; CHECK-NEXT:    [[YY:%.*]] = insertvalue { i32, i32 } [[Y]], i32 [[Y1:%.*]], 1
+; CHECK-NEXT:    [[CALL:%.*]] = tail call i32 @add.1({ i32, i32 } { i32 2, i32 3 }, { i32, i32 } [[YY]])
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+entry:
+  %y = insertvalue { i32, i32 } undef, i32 %y0, 0
+  %yy = insertvalue { i32, i32 } %y, i32 %y1, 1
+  %call = tail call i32 @add({i32, i32} {i32 2, i32 3}, {i32, i32} %yy)
+  ret i32 %call
+}
+
+define i32 @bar(i32 %x0, i32 %x1) {
+; CHECK-LABEL: @bar(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[X:%.*]] = insertvalue { i32, i32 } undef, i32 [[X0:%.*]], 0
+; CHECK-NEXT:    [[XX:%.*]] = insertvalue { i32, i32 } [[X]], i32 [[X1:%.*]], 1
+; CHECK-NEXT:    [[CALL:%.*]] = tail call i32 @add.2({ i32, i32 } [[XX]], { i32, i32 } { i32 3, i32 2 })
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+entry:
+  %x = insertvalue { i32, i32 } undef, i32 %x0, 0
+  %xx = insertvalue { i32, i32 } %x, i32 %x1, 1
+  %call = tail call i32 @add({i32, i32} %xx, {i32, i32} {i32 3, i32 2})
+  ret i32 %call
+}
+
+define internal i32 @add({i32, i32} %x, {i32, i32} %y) {
+entry:
+  %x0 = extractvalue {i32, i32} %x, 0
+  %y0 = extractvalue {i32, i32} %y, 0
+  %add0 = add nsw i32 %x0, %y0
+  %x1 = extractvalue {i32, i32} %x, 1
+  %y1 = extractvalue {i32, i32} %y, 1
+  %add1 = add nsw i32 %x1, %y1
+  %mul = mul i32 %add0, %add1
+  ret i32 %mul
+}
+