Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -259,6 +259,39 @@
   }
 };
 
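+// A set of virtual call sites that share a vtable slot and, when stored in
+// VTableSlotInfo::ConstCSInfo, the same list of constant integer arguments.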
+struct CallSiteInfo {
+  std::vector<VirtualCallSite> CallSites;
+};
+
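+// All virtual call sites for a particular vtable slot, bucketed by the
+// constant integer arguments they pass.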
+struct VTableSlotInfo {
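+  // Call sites that cannot be bucketed by constant arguments: the call's
+  // return type is not an integer of at most 64 bits, or some argument
+  // (excluding "this") is not a constant integer of at most 64 bits.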
+  CallSiteInfo CSInfo;
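+
+  // Call sites whose arguments (excluding "this") are all constant integers
+  // of at most 64 bits, keyed by those argument values.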
+  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
+
+  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+
+private:
+  CallSiteInfo &findCallSiteInfo(CallSite CS);
+};
+
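+// Returns the bucket that CS belongs to: the ConstCSInfo entry for its
+// argument list if the call returns an integer of at most 64 bits and every
+// argument after "this" is a constant integer of at most 64 bits, otherwise
+// the catch-all CSInfo bucket.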
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+  std::vector<uint64_t> Args;
+  auto *CSType = dyn_cast<IntegerType>(CS.getType());
+  if (!CSType || CSType->getBitWidth() > 64)
+    return CSInfo;
+  for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+    auto *CI = dyn_cast<ConstantInt>(Arg);
+    if (!CI || CI->getBitWidth() > 64)
+      return CSInfo;
+    Args.push_back(CI->getZExtValue());
+  }
+  return ConstCSInfo[Args];
+}
+
+void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+                                 unsigned *NumUnsafeUses) {
+  findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
+}
+
 struct DevirtModule {
   Module &M;
   IntegerType *Int8Ty;
@@ -267,7 +300,7 @@
 
   bool RemarksEnabled;
 
-  MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
+  MapVector<VTableSlot, VTableSlotInfo> CallSlots;
 
   // This map keeps track of the number of "unsafe" uses of a loaded function
   // pointer. The key is the associated llvm.type.test intrinsic call generated
@@ -299,18 +332,17 @@
                             const std::set<TypeMemberInfo> &TypeMemberInfos,
                             uint64_t ByteOffset);
   bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-                           MutableArrayRef<VirtualCallSite> CallSites);
+                           VTableSlotInfo &SlotInfo);
   bool tryEvaluateFunctionsWithArgs(
       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-      ArrayRef<ConstantInt *> Args);
-  bool tryUniformRetValOpt(IntegerType *RetType,
-                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-                           MutableArrayRef<VirtualCallSite> CallSites);
+      ArrayRef<uint64_t> Args);
+  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+                           CallSiteInfo &CSInfo);
   bool tryUniqueRetValOpt(unsigned BitWidth,
                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-                          MutableArrayRef<VirtualCallSite> CallSites);
+                          CallSiteInfo &CSInfo);
   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-                           ArrayRef<VirtualCallSite> CallSites);
+                           VTableSlotInfo &SlotInfo);
 
   void rebuildGlobal(VTableBits &B);
 
@@ -445,7 +477,7 @@
 
 bool DevirtModule::trySingleImplDevirt(
     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-    MutableArrayRef<VirtualCallSite> CallSites) {
+    VTableSlotInfo &SlotInfo) {
   // See if the program contains a single implementation of this virtual
   // function.
   Function *TheFn = TargetsForSlot[0].Fn;
@@ -456,36 +488,44 @@
   if (RemarksEnabled)
     TargetsForSlot[0].WasDevirt = true;
   // If so, update each call site to call that implementation directly.
-  for (auto &&VCallSite : CallSites) {
-    if (RemarksEnabled)
-      VCallSite.emitRemark("single-impl", TheFn->getName());
-    VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
-        TheFn, VCallSite.CS.getCalledValue()->getType()));
-    // This use is no longer unsafe.
-    if (VCallSite.NumUnsafeUses)
-      --*VCallSite.NumUnsafeUses;
-  }
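+  // Call sites are grouped into buckets: one for calls that do not have all
+  // constant integer arguments, plus one bucket per constant argument list.
+  // Apply the rewrite to every bucket.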
+  auto Apply = [&](CallSiteInfo &CSInfo) {
+    for (auto &&VCallSite : CSInfo.CallSites) {
+      if (RemarksEnabled)
+        VCallSite.emitRemark("single-impl", TheFn->getName());
+      VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
+          TheFn, VCallSite.CS.getCalledValue()->getType()));
+      // This use is no longer unsafe.
+      if (VCallSite.NumUnsafeUses)
+        --*VCallSite.NumUnsafeUses;
+    }
+  };
+  Apply(SlotInfo.CSInfo);
+  for (auto &P : SlotInfo.ConstCSInfo)
+    Apply(P.second);
   return true;
 }
 
 bool DevirtModule::tryEvaluateFunctionsWithArgs(
     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-    ArrayRef<ConstantInt *> Args) {
+    ArrayRef<uint64_t> Args) {
   // Evaluate each function and store the result in each target's RetVal
   // field.
   for (VirtualCallTarget &Target : TargetsForSlot) {
     if (Target.Fn->arg_size() != Args.size() + 1)
       return false;
-    for (unsigned I = 0; I != Args.size(); ++I)
-      if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
-          Args[I]->getType())
-        return false;
 
     Evaluator Eval(M.getDataLayout(), nullptr);
     SmallVector<Constant *, 2> EvalArgs;
     EvalArgs.push_back(
         Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
-    EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
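+    // Arguments are tracked as uint64_t values; materialize them as
+    // ConstantInts of the callee's parameter types, giving up if a parameter
+    // is not an integer type.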
+    for (unsigned I = 0; I != Args.size(); ++I) {
+      auto *ArgTy = dyn_cast<IntegerType>(
+          Target.Fn->getFunctionType()->getParamType(I + 1));
+      if (!ArgTy)
+        return false;
+      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
+    }
+
     Constant *RetVal;
     if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
         !isa<ConstantInt>(RetVal))
@@ -496,8 +536,7 @@
 }
 
 bool DevirtModule::tryUniformRetValOpt(
-    IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-    MutableArrayRef<VirtualCallSite> CallSites) {
+    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo) {
   // Uniform return value optimization. If all functions return the same
   // constant, replace all calls with that constant.
   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
@@ -505,10 +544,10 @@
     if (Target.RetVal != TheRetVal)
       return false;
 
-  auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
-  for (auto Call : CallSites)
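+  // Create the constant with each call site's own return type, which may
+  // differ from the callee's return type.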
+  for (auto Call : CSInfo.CallSites)
     Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(),
-                         RemarksEnabled, TheRetValConst);
+                         RemarksEnabled,
+                         ConstantInt::get(Call.CS->getType(), TheRetVal));
   if (RemarksEnabled)
     for (auto &&Target : TargetsForSlot)
       Target.WasDevirt = true;
@@ -517,7 +556,7 @@
 
 bool DevirtModule::tryUniqueRetValOpt(
     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-    MutableArrayRef<VirtualCallSite> CallSites) {
+    CallSiteInfo &CSInfo) {
   // IsOne controls whether we look for a 0 or a 1.
   auto tryUniqueRetValOptFor = [&](bool IsOne) {
     const TypeMemberInfo *UniqueMember = nullptr;
@@ -534,12 +573,13 @@
     assert(UniqueMember);
 
     // Replace each call with the comparison.
-    for (auto &&Call : CallSites) {
+    for (auto &&Call : CSInfo.CallSites) {
       IRBuilder<> B(Call.CS.getInstruction());
       Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy);
       OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
       Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                 Call.VTable, OneAddr);
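+      // Widen the i1 comparison result to the integer type the call site
+      // expects.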
+      Cmp = B.CreateZExt(Cmp, Call.CS->getType());
       Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(),
                            RemarksEnabled, Cmp);
     }
@@ -562,7 +602,7 @@
 
 bool DevirtModule::tryVirtualConstProp(
     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
-    ArrayRef<VirtualCallSite> CallSites) {
+    VTableSlotInfo &SlotInfo) {
   // This only works if the function returns an integer.
   auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
   if (!RetType)
@@ -581,42 +621,11 @@
       return false;
   }
 
-  // Group call sites by the list of constant arguments they pass.
-  // The comparator ensures deterministic ordering.
-  struct ByAPIntValue {
-    bool operator()(const std::vector<ConstantInt *> &A,
-                    const std::vector<ConstantInt *> &B) const {
-      return std::lexicographical_compare(
-          A.begin(), A.end(), B.begin(), B.end(),
-          [](ConstantInt *AI, ConstantInt *BI) {
-            return AI->getValue().ult(BI->getValue());
-          });
-    }
-  };
-  std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
-           ByAPIntValue>
-      VCallSitesByConstantArg;
-  for (auto &&VCallSite : CallSites) {
-    std::vector<ConstantInt *> Args;
-    if (VCallSite.CS.getType() != RetType)
-      continue;
-    for (auto &&Arg :
-         make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
-      if (!isa<ConstantInt>(Arg))
-        break;
-      Args.push_back(cast<ConstantInt>(&Arg));
-    }
-    if (Args.size() + 1 != VCallSite.CS.arg_size())
-      continue;
-
-    VCallSitesByConstantArg[Args].push_back(VCallSite);
-  }
-
-  for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
+  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
     if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
       continue;
 
-    if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
+    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second))
       continue;
 
     if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
@@ -660,20 +669,23 @@
         Target.WasDevirt = true;
 
     // Rewrite each call to a load from OffsetByte/OffsetBit.
-    for (auto Call : CSByConstantArg.second) {
+    for (auto Call : CSByConstantArg.second.CallSites) {
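+      // Rewrite against the type the call site expects; findCallSiteInfo
+      // guarantees that it is an integer type of at most 64 bits.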
+      auto *CSRetType = cast<IntegerType>(Call.CS.getType());
       IRBuilder<> B(Call.CS.getInstruction());
       Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
-      if (BitWidth == 1) {
+      if (CSRetType->getBitWidth() == 1) {
         Value *Bits = B.CreateLoad(Addr);
         Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
         Value *BitsAndBit = B.CreateAnd(Bits, Bit);
-        auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
+        Value *IsBitSet =
+            B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
+        IsBitSet = B.CreateZExt(IsBitSet, Call.CS->getType());
         Call.replaceAndErase("virtual-const-prop-1-bit",
                              TargetsForSlot[0].Fn->getName(),
                              RemarksEnabled, IsBitSet);
       } else {
-        Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
-        Value *Val = B.CreateLoad(RetType, ValAddr);
+        Value *ValAddr = B.CreateBitCast(Addr, CSRetType->getPointerTo());
+        Value *Val = B.CreateLoad(CSRetType, ValAddr);
         Call.replaceAndErase("virtual-const-prop",
                              TargetsForSlot[0].Fn->getName(),
                              RemarksEnabled, Val);
@@ -766,8 +778,8 @@
       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
       if (SeenPtrs.insert(Ptr).second) {
         for (DevirtCallSite Call : DevirtCalls) {
-          CallSlots[{TypeId, Call.Offset}].push_back(
-              {CI->getArgOperand(0), Call.CS, nullptr});
+          CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0),
+                                                       Call.CS, nullptr);
         }
       }
     }
@@ -853,8 +865,8 @@
     if (HasNonCallUses)
       ++NumUnsafeUses;
     for (DevirtCallSite Call : DevirtCalls) {
-      CallSlots[{TypeId, Call.Offset}].push_back(
-          {Ptr, Call.CS, &NumUnsafeUses});
+      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+                                                   &NumUnsafeUses);
     }
 
     CI->eraseFromParent();
Index: llvm/test/Transforms/WholeProgramDevirt/vcp-too-wide-ints.ll
===================================================================
--- llvm/test/Transforms/WholeProgramDevirt/vcp-too-wide-ints.ll
+++ llvm/test/Transforms/WholeProgramDevirt/vcp-too-wide-ints.ll
@@ -3,19 +3,21 @@
 target datalayout = "e-p:64:64"
 target triple = "x86_64-unknown-linux-gnu"
 
-@vt1 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i128)* @vf1 to i8*)], !type !0
-@vt2 = constant [1 x i8*] [i8* bitcast (i128 (i8*, i128)* @vf2 to i8*)], !type !0
+@vt1 = constant [1 x i8*] [i8* bitcast (i64 (i8*, i128)* @vf1 to i8*)], !type !0
+@vt2 = constant [1 x i8*] [i8* bitcast (i64 (i8*, i128)* @vf2 to i8*)], !type !0
 
-define i128 @vf1(i8* %this, i128 %arg) readnone {
-  ret i128 %arg
+define i64 @vf1(i8* %this, i128 %arg) readnone {
+  %argtrunc = trunc i128 %arg to i64
+  ret i64 %argtrunc
 }
 
-define i128 @vf2(i8* %this, i128 %arg) readnone {
-  ret i128 %arg
+define i64 @vf2(i8* %this, i128 %arg) readnone {
+  %argtrunc = trunc i128 %arg to i64
+  ret i64 %argtrunc
 }
 
-; CHECK: define i128 @call
-define i128 @call(i8* %obj) {
+; CHECK: define i128 @call1
+define i128 @call1(i8* %obj) {
   %vtableptr = bitcast i8* %obj to [1 x i8*]**
   %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
   %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
@@ -29,6 +31,21 @@
   ret i128 %result
 }
 
+; CHECK: define i64 @call2
+define i64 @call2(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
+  %fptr = load i8*, i8** %fptrptr
+  %fptr_casted = bitcast i8* %fptr to i64 (i8*, i128)*
+  ; CHECK: call i64 %
+  %result = call i64 %fptr_casted(i8* %obj, i128 1)
+  ret i64 %result
+}
+
 declare i1 @llvm.type.test(i8*, metadata)
 declare void @llvm.assume(i1)
 
Index: llvm/test/Transforms/WholeProgramDevirt/vcp-type-mismatch.ll
===================================================================
--- llvm/test/Transforms/WholeProgramDevirt/vcp-type-mismatch.ll
+++ llvm/test/Transforms/WholeProgramDevirt/vcp-type-mismatch.ll
@@ -24,8 +24,8 @@
   %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
   %fptr = load i8*, i8** %fptrptr
   %fptr_casted = bitcast i8* %fptr to i32 (i8*, i64)*
-  ; CHECK: call i32 %
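+  ; Devirtualization now succeeds despite the prototype mismatch: constant
+  ; arguments are tracked by value and the result is materialized with the
+  ; call site's return type.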
   %result = call i32 %fptr_casted(i8* %obj, i64 1)
+  ; CHECK: ret i32 1
   ret i32 %result
 }
 
@@ -54,8 +54,8 @@
   %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
   %fptr = load i8*, i8** %fptrptr
   %fptr_casted = bitcast i8* %fptr to i64 (i8*, i32)*
-  ; CHECK: call i64 %
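+  ; As above, the devirtualized result is created with this call site's
+  ; return type (i64).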
   %result = call i64 %fptr_casted(i8* %obj, i32 1)
+  ; CHECK: ret i64 1
   ret i64 %result
 }