[clangd] PopulateSwitch: handle enumerators that share a value

prepare() now collects the distinct enumerator values, normalized to the
enum's integer width and signedness, and marks each one covered by an
existing `case`. The tweak is offered only while at least one value is
still uncovered; a default label, a GNU range, a non-constant case, or a
case value that matches no enumerator makes it unavailable.

apply() skips an enumerator whose value is already present, so aliases
such as `b = B` yield a single case label.

Adds DenseMapInfo<APSInt> so that APSInt can key a (Small)MapVector.

NOTE(review): apply() inserts Enumerator->getInitVal() without the width
and signedness normalization that prepare() applies. Confirm that
ExistingEnumerators is built from identically-normalized values; otherwise
APSInt comparisons may mix widths or signedness.
NOTE(review): including APSInt.h from DenseMapInfo.h makes every user of
DenseMapInfo.h pull in APSInt.h; consider defining the specialization in
APSInt.h instead.

diff --git a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
--- a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
@@ -40,6 +40,8 @@
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include <cassert>
 #include <string>
@@ -112,15 +114,32 @@
   if (!EnumD)
     return false;
 
-  // We trigger if there are fewer cases than enum values (and no case covers
-  // multiple values). This guarantees we'll have at least one case to insert.
-  // We don't yet determine what the cases are, as that means evaluating
-  // expressions.
-  auto I = EnumD->enumerator_begin();
-  auto E = EnumD->enumerator_end();
+  // An enum without enumerators leaves nothing to insert.
+  if (EnumD->enumerator_begin() == EnumD->enumerator_end())
+    return false;
+
+  ASTContext &Ctx = Sel.AST->getASTContext();
+
+  unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
+  bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
+
+  // Convert a value to the enum's integer width and signedness, so that
+  // enumerator values and case values are comparable map keys.
+  auto Normalize = [&](llvm::APSInt Val) {
+    Val = Val.extOrTrunc(EnumIntWidth);
+    Val.setIsSigned(EnumIsSigned);
+    return Val;
+  };
+
+  // Distinct enumerator values -> whether an existing case covers the value.
+  // Enumerators sharing a value (e.g. `b = B`) collapse into one entry.
+  llvm::SmallMapVector<llvm::APSInt, bool, 32> EnumConstants;
+  for (auto *EnumConstant : EnumD->enumerators())
+    EnumConstants.insert(
+        std::make_pair(Normalize(EnumConstant->getInitVal()), false));
 
-  for (const SwitchCase *CaseList = Switch->getSwitchCaseList();
-       CaseList && I != E; CaseList = CaseList->getNextSwitchCase(), I++) {
+  for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
+       CaseList = CaseList->getNextSwitchCase()) {
     // Default likely intends to cover cases we'd insert.
     if (isa<DefaultStmt>(CaseList))
       return false;
@@ -135,10 +154,19 @@
     const ConstantExpr *CE = dyn_cast<ConstantExpr>(CS->getLHS());
     if (!CE || CE->isValueDependent())
       return false;
+
+    // Only a case value cached as a 64-bit integer result is compared.
+    if (CE->getResultStorageKind() != ConstantExpr::RSK_Int64)
+      return false;
+    auto Iter = EnumConstants.find(Normalize(CE->getResultAsAPSInt()));
+    // A case value that matches no enumerator: don't offer the tweak.
+    if (Iter == EnumConstants.end())
+      return false;
+    Iter->second = true;
   }
 
-  // Only suggest tweak if we have more enumerators than cases.
-  return I != E;
+  // Only suggest the tweak if some enumerator value is still uncovered.
+  return llvm::any_of(EnumConstants, [](auto &Pair) { return !Pair.second; });
 }
 
 Expected<Tweak::Effect> PopulateSwitch::apply(const Selection &Sel) {
@@ -166,7 +194,10 @@
 
   std::string Text;
   for (EnumConstantDecl *Enumerator : EnumD->enumerators()) {
-    if (ExistingEnumerators.contains(Enumerator->getInitVal()))
+    // Try to insert this Enumerator into the set. If this fails, the value was
+    // either already there to begin with or we have already added it using a
+    // different name.
+    if (!ExistingEnumerators.insert(Enumerator->getInitVal()).second)
       continue;
 
     Text += "case ";
diff --git a/clang-tools-extra/clangd/unittests/TweakTests.cpp b/clang-tools-extra/clangd/unittests/TweakTests.cpp
--- a/clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ b/clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -2980,6 +2980,18 @@
             void function() { switch (ns::A) {case ns::A:break;} }
           )"",
       },
+      {
+          // Duplicated constant names
+          Function,
+          R""(enum Enum {A,B,b=B}; ^switch (A) {})"",
+          R""(enum Enum {A,B,b=B}; switch (A) {case A:case B:break;})"",
+      },
+      {
+          // Duplicated constant names all in switch
+          Function,
+          R""(enum Enum {A,B,b=B}; ^switch (A) {case A:case B:break;})"",
+          "unavailable",
+      },
   };
 
   for (const auto &Case : Cases) {
diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h
--- a/llvm/include/llvm/ADT/DenseMapInfo.h
+++ b/llvm/include/llvm/ADT/DenseMapInfo.h
@@ -14,6 +14,7 @@
 #define LLVM_ADT_DENSEMAPINFO_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringRef.h"
@@ -371,6 +372,26 @@
   }
 };
 
+/// Provide DenseMapInfo for APSInt, using the DenseMapInfo for APInt.
+template <> struct DenseMapInfo<APSInt> {
+  static inline APSInt getEmptyKey() {
+    return APSInt(DenseMapInfo<APInt>::getEmptyKey());
+  }
+
+  static inline APSInt getTombstoneKey() {
+    return APSInt(DenseMapInfo<APInt>::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const APSInt &Key) {
+    return DenseMapInfo<APInt>::getHashValue(Key);
+  }
+
+  static bool isEqual(const APSInt &LHS, const APSInt &RHS) {
+    return LHS.getBitWidth() == RHS.getBitWidth() &&
+           LHS.isUnsigned() == RHS.isUnsigned() && LHS == RHS;
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ADT_DENSEMAPINFO_H