diff --git a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp --- a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp +++ b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp @@ -40,7 +40,8 @@ #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" #include "clang/Tooling/Core/Replacement.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include #include @@ -57,11 +58,27 @@ } private: + class ExpectedCase { + public: + ExpectedCase(const EnumConstantDecl *Decl) : Data(Decl, false) {} + bool isCovered() const { return Data.getInt(); } + void setCovered(bool Val = true) { Data.setInt(Val); } + const EnumConstantDecl *getEnumConstant() const { + return Data.getPointer(); + } + + private: + llvm::PointerIntPair Data; + }; + const DeclContext *DeclCtx = nullptr; const SwitchStmt *Switch = nullptr; const CompoundStmt *Body = nullptr; const EnumType *EnumT = nullptr; const EnumDecl *EnumD = nullptr; + // Maps the Enum values to the EnumConstantDecl and a bool signifying if it's + // covered in the switch. + llvm::MapVector ExpectedCases; }; REGISTER_TWEAK(PopulateSwitch) @@ -112,21 +129,34 @@ if (!EnumD) return false; - // We trigger if there are fewer cases than enum values (and no case covers - // multiple values). This guarantees we'll have at least one case to insert. - // We don't yet determine what the cases are, as that means evaluating - // expressions. - auto I = EnumD->enumerator_begin(); - auto E = EnumD->enumerator_end(); + // We trigger if there are any values in the enum that aren't covered by the + // switch. 
+ + ASTContext &Ctx = Sel.AST->getASTContext(); + + unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0)); + bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType(); - for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); - CaseList && I != E; CaseList = CaseList->getNextSwitchCase(), I++) { + auto Normalize = [&](llvm::APSInt Val) { + Val = Val.extOrTrunc(EnumIntWidth); + Val.setIsSigned(EnumIsSigned); + return Val; + }; + + for (auto *EnumConstant : EnumD->enumerators()) { + ExpectedCases.insert( + std::make_pair(Normalize(EnumConstant->getInitVal()), EnumConstant)); + } + + for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList; + CaseList = CaseList->getNextSwitchCase()) { // Default likely intends to cover cases we'd insert. if (isa(CaseList)) return false; const CaseStmt *CS = cast(CaseList); - // Case statement covers multiple values, so just counting doesn't work. + + // GNU range cases are rare, we don't support them. if (CS->caseStmtIsGNURange()) return false; @@ -135,48 +165,36 @@ const ConstantExpr *CE = dyn_cast(CS->getLHS()); if (!CE || CE->isValueDependent()) return false; + + // Unsure if this case could ever come up, but prevents an unreachable + // executing in getResultAsAPSInt. + if (CE->getResultStorageKind() == ConstantExpr::RSK_None) + return false; + auto Iter = ExpectedCases.find(Normalize(CE->getResultAsAPSInt())); + if (Iter != ExpectedCases.end()) + Iter->second.setCovered(); } - // Only suggest tweak if we have more enumerators than cases. - return I != E; + return !llvm::all_of(ExpectedCases, + [](auto &Pair) { return Pair.second.isCovered(); }); } Expected PopulateSwitch::apply(const Selection &Sel) { ASTContext &Ctx = Sel.AST->getASTContext(); - // Get the enum's integer width and signedness, for adjusting case literals. 
- unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0)); - bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType(); - - llvm::SmallSet ExistingEnumerators; - for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList; - CaseList = CaseList->getNextSwitchCase()) { - const CaseStmt *CS = cast(CaseList); - assert(!CS->caseStmtIsGNURange()); - const ConstantExpr *CE = cast(CS->getLHS()); - assert(!CE->isValueDependent()); - llvm::APSInt Val = CE->getResultAsAPSInt(); - Val = Val.extOrTrunc(EnumIntWidth); - Val.setIsSigned(EnumIsSigned); - ExistingEnumerators.insert(Val); - } - SourceLocation Loc = Body->getRBracLoc(); ASTContext &DeclASTCtx = DeclCtx->getParentASTContext(); - std::string Text; - for (EnumConstantDecl *Enumerator : EnumD->enumerators()) { - if (ExistingEnumerators.contains(Enumerator->getInitVal())) + llvm::SmallString<256> Text; + for (auto &EnumConstant : ExpectedCases) { + // Skip any enum constants already covered + if (EnumConstant.second.isCovered()) continue; - Text += "case "; - Text += getQualification(DeclASTCtx, DeclCtx, Loc, EnumD); - if (EnumD->isScoped()) { - Text += EnumD->getName(); - Text += "::"; - } - Text += Enumerator->getName(); - Text += ":"; + Text.append({"case ", getQualification(DeclASTCtx, DeclCtx, Loc, EnumD)}); + if (EnumD->isScoped()) + Text.append({EnumD->getName(), "::"}); + Text.append({EnumConstant.second.getEnumConstant()->getName(), ":"}); } assert(!Text.empty() && "No enumerators to insert!"); diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp --- a/clang-tools-extra/clangd/unittests/TweakTests.cpp +++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp @@ -2980,6 +2980,18 @@ void function() { switch (ns::A) {case ns::A:break;} } )"", }, + { + // Duplicated constant names + Function, + R""(enum Enum {A,B,b=B}; ^switch (A) {})"", + R""(enum Enum {A,B,b=B}; switch (A) {case A:case B:break;})"", + }, + { + // Duplicated constant 
names all in switch + Function, + R""(enum Enum {A,B,b=B}; ^switch (A) {case A:case B:break;})"", + "unavailable", + }, }; for (const auto &Case : Cases) { diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h --- a/llvm/include/llvm/ADT/DenseMapInfo.h +++ b/llvm/include/llvm/ADT/DenseMapInfo.h @@ -14,6 +14,7 @@ #define LLVM_ADT_DENSEMAPINFO_H #include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/StringRef.h" @@ -371,6 +372,26 @@ } }; +/// Provide DenseMapInfo for APSInt, using the DenseMapInfo for APInt. +template <> struct DenseMapInfo { + static inline APSInt getEmptyKey() { + return APSInt(DenseMapInfo::getEmptyKey()); + } + + static inline APSInt getTombstoneKey() { + return APSInt(DenseMapInfo::getTombstoneKey()); + } + + static unsigned getHashValue(const APSInt &Key) { + return static_cast(hash_value(Key)); + } + + static bool isEqual(const APSInt &LHS, const APSInt &RHS) { + return LHS.getBitWidth() == RHS.getBitWidth() && + LHS.isUnsigned() == RHS.isUnsigned() && LHS == RHS; + } +}; + } // end namespace llvm #endif // LLVM_ADT_DENSEMAPINFO_H