Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2008,6 +2008,254 @@
   return ExtractValueInst::Create(Call, 1, "uadd.overflow");
 }
 
+/// \brief Recognize and process idiom involving test for multiplication
+/// overflow.
+///
+/// The caller has matched a pattern of the form:
+///   I = cmp u (mul(zext A, zext B)), V
+/// The function checks if this is a test for overflow and if so replaces
+/// multiplication with call to 'mul.with.overflow' intrinsic.
+///
+/// \param I Compare instruction.
+/// \param MulVal Result of 'mul' instruction.  It is one of the arguments of
+///               the compare instruction.  Must be of integer type.
+/// \param OtherVal The other argument of compare instruction.
+/// \param IC The instruction combiner; its builder and worklist are used.
+/// \returns Instruction which must replace the compare instruction, nullptr if
+///          no replacement required.
+static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
+                                         Value *OtherVal, InstCombiner &IC) {
+  assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
+  assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
+  assert(isa<IntegerType>(MulVal->getType()));
+  Instruction *MulInstr = cast<Instruction>(MulVal);
+  assert(MulInstr->getOpcode() == Instruction::Mul);
+
+  Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
+              *RHS = cast<Instruction>(MulInstr->getOperand(1));
+  assert(LHS->getOpcode() == Instruction::ZExt);
+  assert(RHS->getOpcode() == Instruction::ZExt);
+  Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
+
+  // Calculate type and width of the result produced by mul.with.overflow.
+  Type *TyA = A->getType(), *TyB = B->getType();
+  unsigned WidthA = TyA->getPrimitiveSizeInBits(),
+           WidthB = TyB->getPrimitiveSizeInBits();
+  unsigned MulWidth;
+  Type *MulType;
+  if (WidthB > WidthA) {
+    MulWidth = WidthB;
+    MulType = TyB;
+  } else {
+    MulWidth = WidthA;
+    MulType = TyA;
+  }
+
+  // In order to replace the original mul with a narrower mul.with.overflow,
+  // all uses must ignore upper bits of the product.  The number of used low
+  // bits must be not greater than the width of mul.with.overflow.
+  if (MulVal->hasNUsesOrMore(2))
+    for (User *U : MulVal->users()) {
+      if (U == &I)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        // Check if truncation ignores bits above MulWidth.
+        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
+        if (TruncWidth > MulWidth)
+          return nullptr;
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        // Check if & ignores bits above MulWidth.
+        if (BO->getOpcode() != Instruction::And)
+          return nullptr;
+        // The mask may be on either side.  Nothing is modified here because the
+        // idiom may still be rejected below.
+        ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1));
+        if (!CI)
+          CI = dyn_cast<ConstantInt>(BO->getOperand(0));
+        if (!CI)
+          return nullptr;
+        const APInt &CVal = CI->getValue();
+        if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth)
+          return nullptr;
+      } else {
+        return nullptr;
+      }
+    }
+
+  // Recognize patterns
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp eq/neq mulval, zext trunc mulval
+    if (Instruction *ZextInstr = dyn_cast<Instruction>(OtherVal))
+      if (ZextInstr->getOpcode() == Instruction::ZExt &&
+          ZextInstr->hasOneUse()) {
+        Value *ZextArg = ZextInstr->getOperand(0);
+        if (Instruction *TruncInstr = dyn_cast<Instruction>(ZextArg))
+          if (TruncInstr->getOpcode() == Instruction::Trunc &&
+              TruncInstr->getOperand(0) == MulVal &&
+              TruncInstr->getType()->getPrimitiveSizeInBits() == MulWidth) {
+            break; // Recognized
+          }
+      }
+
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(OtherVal))
+      if (BO->getOpcode() == Instruction::And &&
+          BO->getOperand(0) == MulVal)
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+          APInt CVal = CI->getValue() + 1;
+          if (CVal.isPowerOf2()) {
+            unsigned MaskWidth = CVal.logBase2();
+            if (MaskWidth == MulWidth)
+              break; // Recognized
+          }
+        }
+    return nullptr;
+
+  case ICmpInst::ICMP_UGT:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ugt mulval, max
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getMaxValue(MulWidth);
+      MaxVal = MaxVal.zext(CI->getBitWidth());
+      if (MaxVal.eq(CI->getValue())) {
+        break; // Recognized
+      }
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_UGE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp uge mulval, max+1
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
+      if (MaxVal.eq(CI->getValue())) {
+        break; // Recognized
+      }
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_ULE:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ule mulval, max
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getMaxValue(MulWidth);
+      MaxVal = MaxVal.zext(CI->getBitWidth());
+      if (MaxVal.eq(CI->getValue())) {
+        break; // Recognized
+      }
+    }
+    return nullptr;
+
+  case ICmpInst::ICMP_ULT:
+    // Recognize pattern:
+    //   mulval = mul(zext A, zext B)
+    //   cmp ult mulval, max + 1
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
+      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
+      if (MaxVal.eq(CI->getValue())) {
+        break; // Recognized
+      }
+    }
+    return nullptr;
+
+  default:
+    return nullptr;
+  }
+
+  InstCombiner::BuilderTy *Builder = IC.Builder;
+  Builder->SetInsertPoint(MulInstr);
+  Module *M = I.getParent()->getParent()->getParent();
+
+  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
+  Value *MulA = A, *MulB = B;
+  if (WidthA < MulWidth)
+    MulA = Builder->CreateZExt(A, MulType);
+  if (WidthB < MulWidth)
+    MulB = Builder->CreateZExt(B, MulType);
+  Value *F =
+      Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
+  CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
+  IC.Worklist.Add(MulInstr);
+
+  // If there are uses of mul result other than the comparison, replace them
+  // with the new mul.
+  if (MulVal->hasNUsesOrMore(2)) {
+    Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value");
+    // Advance the iterator before rewriting U: TI->setOperand below drops U
+    // from the users of MulVal.
+    for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) {
+      User *U = *UI++;
+      if (U == &I || U == OtherVal)
+        continue;
+      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
+        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
+          IC.ReplaceInstUsesWith(*TI, Mul);
+        else
+          TI->setOperand(0, Mul);
+      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
+        assert(BO->getOpcode() == Instruction::And);
+        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        // The mask may be either operand of the 'and'.
+        ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1));
+        if (!CI)
+          CI = cast<ConstantInt>(BO->getOperand(0));
+        APInt Mask = CI->getValue();
+        Mask = Mask.trunc(MulWidth);
+        Value *ShortAnd = Builder->CreateAnd(Mul, Mask);
+        Instruction *Zext =
+            cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType()));
+        IC.Worklist.Add(Zext);
+        IC.ReplaceInstUsesWith(*BO, Zext);
+      } else {
+        llvm_unreachable("Unexpected Binary operation");
+      }
+      IC.Worklist.Add(cast<Instruction>(U));
+    }
+  }
+  if (isa<Instruction>(OtherVal))
+    IC.Worklist.Add(cast<Instruction>(OtherVal));
+
+  // The original icmp gets replaced with the overflow value, maybe inverted
+  // depending on predicate.
+  bool Inverse = false;
+  switch (I.getPredicate()) {
+  case ICmpInst::ICMP_NE:
+    break;
+  case ICmpInst::ICMP_EQ:
+    Inverse = true;
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    if (I.getOperand(0) == MulVal)
+      break;
+    Inverse = true;
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (I.getOperand(1) == MulVal)
+      break;
+    Inverse = true;
+    break;
+  default:
+    llvm_unreachable("Unexpected predicate");
+  }
+  if (Inverse) {
+    Value *Res = Builder->CreateExtractValue(Call, 1);
+    return BinaryOperator::CreateNot(Res);
+  }
+
+  return ExtractValueInst::Create(Call, 1);
+}
+
 // DemandedBitsLHSMask - When performing a comparison against a constant,
 // it is possible that not all the bits in the LHS are demanded.  This helper
 // method computes the mask that IS demanded.
@@ -2877,6 +3125,20 @@
         (Op0 == A || Op0 == B))
       if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
         return R;
+
+    // (zext a) * (zext b)  --> llvm.umul.with.overflow.
+    if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+      if (isa<IntegerType>(A->getType())) {
+        if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this))
+          return R;
+      }
+    }
+    if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+      if (isa<IntegerType>(A->getType())) {
+        if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this))
+          return R;
+      }
+    }
   }
 
   if (I.isEquality()) {
Index: test/Transforms/InstCombine/overflow-mul.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/overflow-mul.ll
@@ -0,0 +1,164 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+; return mul(zext x, zext y) > MAX
+define i32 @pr4917_1(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_1(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %overflow = icmp ugt i64 %mul64, 4294967295
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+
+; return mul(zext x, zext y) >= MAX+1
+define i32 @pr4917_1a(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_1a(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %overflow = icmp uge i64 %mul64, 4294967296
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+
+; mul(zext x, zext y) > MAX
+; mul(x, y) is used
+define i32 @pr4917_2(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_2(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %overflow = icmp ugt i64 %mul64, 4294967295
+; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
+  %mul32 = trunc i64 %mul64 to i32
+; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
+  %retval = select i1 %overflow, i32 %mul32, i32 111
+; CHECK: select i1 [[OVFL]], i32 [[VAL]]
+  ret i32 %retval
+}
+
+; return mul(zext x, zext y) > MAX
+; mul is used in non-truncate
+define i64 @pr4917_3(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_3(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+  %mul64 = mul i64 %l, %r
+; CHECK-NOT: umul.with.overflow.i32
+  %overflow = icmp ugt i64 %mul64, 4294967295
+  %retval = select i1 %overflow, i64 %mul64, i64 111
+  ret i64 %retval
+}
+
+; return mul(zext x, zext y) <= MAX
+define i32 @pr4917_4(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_4(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %overflow = icmp ule i64 %mul64, 4294967295
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+; CHECK: xor
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+
+; return mul(zext x, zext y) < MAX+1
+define i32 @pr4917_4a(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4917_4a(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+; CHECK-NOT: zext i32
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %overflow = icmp ult i64 %mul64, 4294967296
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+; CHECK: xor
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+
+; operands of mul are of different size
+define i32 @pr4917_5(i32 %x, i8 %y) nounwind {
+; CHECK-LABEL: @pr4917_5(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i8 %y to i64
+; CHECK: [[Y:%.*]] = zext i8 %y to i32
+  %mul64 = mul i64 %l, %r
+  %overflow = icmp ugt i64 %mul64, 4294967295
+  %mul32 = trunc i64 %mul64 to i32
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 [[Y]])
+; CHECK-DAG: [[VAL:%.*]] = extractvalue { i32, i1 } [[MUL]], 0
+; CHECK-DAG: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
+  %retval = select i1 %overflow, i32 %mul32, i32 111
+; CHECK: select i1 [[OVFL]], i32 [[VAL]]
+  ret i32 %retval
+}
+
+; mul(zext x, zext y) != zext trunc mul
+define i32 @pr4918_1(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_1(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %part32 = trunc i64 %mul64 to i32
+  %part64 = zext i32 %part32 to i64
+  %overflow = icmp ne i64 %mul64, %part64
+; CHECK: [[OVFL:%.*]] = extractvalue { i32, i1 } [[MUL]], 1
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+
+; mul(zext x, zext y) == zext trunc mul
+define i32 @pr4918_2(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_2(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %part32 = trunc i64 %mul64 to i32
+  %part64 = zext i32 %part32 to i64
+  %overflow = icmp eq i64 %mul64, %part64
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+  %retval = zext i1 %overflow to i32
+; CHECK: xor
+  ret i32 %retval
+}
+
+; zext trunc mul != mul(zext x, zext y)
+define i32 @pr4918_3(i32 %x, i32 %y) nounwind {
+; CHECK-LABEL: @pr4918_3(
+entry:
+  %l = zext i32 %x to i64
+  %r = zext i32 %y to i64
+  %mul64 = mul i64 %l, %r
+; CHECK: [[MUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
+  %part32 = trunc i64 %mul64 to i32
+  %part64 = zext i32 %part32 to i64
+  %overflow = icmp ne i64 %part64, %mul64
+; CHECK: extractvalue { i32, i1 } [[MUL]], 1
+  %retval = zext i1 %overflow to i32
+  ret i32 %retval
+}
+