diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2943,6 +2943,23 @@
   unsigned NewWidth = Known.getBitWidth() -
                       std::max(LeadingKnownZeros, LeadingKnownOnes);
 
+  // Change 'switch (zext X)' into 'switch X' first, before trying the more
+  // aggressive trunc rewrite. This implements the same behavior as
+  // InstCombineCompares::foldICmpWithZextOrSext() for switch, and improves
+  // effectiveness of later passes, e.g. global value numbering.
+  Value *CastOp;
+  if (match(Cond, m_ZExtOrSExt(m_Value(CastOp)))) {
+    unsigned CastOpWidth = computeKnownBits(CastOp, 0, &SI).getBitWidth();
+    if (NewWidth > 0 && NewWidth <= CastOpWidth) {
+      for (auto Case : SI.cases()) {
+        APInt TruncatedCase =
+            Case.getCaseValue()->getValue().trunc(CastOpWidth);
+        Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
+      }
+      return replaceOperand(SI, 0, CastOp);
+    }
+  }
+
   // Shrink the condition operand if the new type is smaller than the old type.
   // But do not shrink to a non-standard type, because backend can't generate
   // good code for that yet.
diff --git a/llvm/test/Transforms/InstCombine/narrow-switch.ll b/llvm/test/Transforms/InstCombine/narrow-switch.ll
--- a/llvm/test/Transforms/InstCombine/narrow-switch.ll
+++ b/llvm/test/Transforms/InstCombine/narrow-switch.ll
@@ -260,3 +260,30 @@
   ret void
 }
 
+define void @switch_zext_with_range(i32 *%p) {
+; ALL-LABEL: @switch_zext_with_range(
+; ALL: switch i32
+; ALL-NEXT: i32 0, label
+; ALL-NEXT: i32 1, label
+; ALL-NEXT: i32 2, label
+; ALL-NEXT: i32 3, label
+; ALL-NEXT: ]
+;
+entry:
+  %x = load i32, i32* %p, align 4, !range !0
+  %zx = zext i32 %x to i64
+  switch i64 %zx, label %bb.err [
+    i64 0, label %bb.end
+    i64 1, label %bb.end
+    i64 2, label %bb.end
+    i64 3, label %bb.end
+  ]
+
+bb.end:
+  ret void
+
+bb.err:
+  unreachable
+}
+
+!0 = !{i32 0, i32 4}