diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -166,6 +166,16 @@
     return *this;
   }
 
+  /// Return known bits for the tracked value sign-extended or truncated to
+  /// \p BitWidth bits. When \p BitWidth already equals the current width the
+  /// result is simply a copy of *this.
+  KnownBits sextOrTrunc(unsigned BitWidth) const {
+    const unsigned CurrentWidth = getBitWidth();
+    if (BitWidth == CurrentWidth)
+      return *this;
+    return BitWidth > CurrentWidth ? sext(BitWidth) : trunc(BitWidth);
+  }
+
   /// Return a KnownBits with the extracted bits
   /// [bitPosition,bitPosition+numBits).
   KnownBits extractBits(unsigned NumBits, unsigned BitPosition) const {
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -203,4 +203,33 @@
   });
 }
 
+TEST(KnownBitsTest, SExtOrTrunc) {
+  const unsigned NarrowerSize = 4;
+  const unsigned BaseSize = 6;
+  const unsigned WiderSize = 8;
+  // Inputs are BaseSize-bit *signed* values, so positives must be <= 31;
+  // "fits" means a round trip through NarrowerSize signed bits is lossless.
+  const APInt Inputs[] = {
+      APInt(BaseSize, -4, /*isSigned*/ true),  // Negative, fits narrower.
+      APInt(BaseSize, -28, /*isSigned*/ true), // Negative, doesn't fit.
+      APInt(BaseSize, 6),                      // Positive, fits narrower.
+      APInt(BaseSize, 14)};                    // Positive, doesn't fit.
+  auto InitKnownBits = [](KnownBits &Res, const APInt &Input) {
+    Res = KnownBits(Input.getBitWidth());
+    Res.One = Input;
+    Res.Zero = ~Input;
+  };
+  for (unsigned Size : {NarrowerSize, BaseSize, WiderSize}) {
+    for (const APInt &Input : Inputs) {
+      KnownBits Test;
+      InitKnownBits(Test, Input);
+      KnownBits Baseline;
+      InitKnownBits(Baseline, Input.sextOrTrunc(Size));
+      Test = Test.sextOrTrunc(Size);
+      EXPECT_EQ(Test.One, Baseline.One);
+      EXPECT_EQ(Test.Zero, Baseline.Zero);
+    }
+  }
+}
+
 } // end anonymous namespace