diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -367,6 +367,7 @@
   Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D);
   Value *getSelectCondition(Value *A, Value *B);
 
+  Instruction *foldXShrToOverflow(BinaryOperator &I);
   Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV);
   Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);
   Instruction *foldFPSignBitOps(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -839,6 +839,72 @@
   return nullptr;
 }
 
+// Tries to perform
+//   (Xshr (add (zext a, i2^n), (zext b, i2^n)), 2^(n-1))
+//     -> (llvm.uadd.with.overflow a, b).overflow
+// where
+//  - Xshr can be ashr or lshr
+//  - a and b have type i2^(n-1).
+Instruction *InstCombinerImpl::foldXShrToOverflow(BinaryOperator &I) {
+  assert(I.getOpcode() == Instruction::AShr ||
+         I.getOpcode() == Instruction::LShr);
+
+  Instruction *Op0 = dyn_cast<Instruction>(I.getOperand(0));
+  Value *Op1 = I.getOperand(1);
+  Type *Ty = I.getType();
+  const APInt *ShAmtAPInt = nullptr;
+
+  if (!Op0 || !match(Op1, m_APInt(ShAmtAPInt)))
+    return nullptr;
+
+  const unsigned ShAmt = ShAmtAPInt->getZExtValue();
+  const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+  if (!isPowerOf2_32(BitWidth) || !Ty->isIntegerTy(ShAmt * 2))
+    return nullptr;
+
+  Value *X = nullptr, *Y = nullptr;
+  if (!match(Op0, m_Add(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) ||
+      X->getType()->getScalarSizeInBits() != ShAmt ||
+      Y->getType()->getScalarSizeInBits() != ShAmt)
+    return nullptr;
+
+  // Make sure that `Op0` is only used by `I` and by truncs to `ShAmt` bits.
+  SmallVector<TruncInst *> Truncs;
+  for (User *Usr : Op0->users()) {
+    if (Usr == &I)
+      continue;
+
+    TruncInst *Trunc = dyn_cast<TruncInst>(Usr);
+    if (!Trunc || Trunc->getType()->getScalarSizeInBits() != ShAmt)
+      return nullptr;
+
+    Truncs.push_back(Trunc);
+  }
+
+  // Insert at Op0 so that the newly created `UAdd` will dominate its users
+  // (i.e. `Op0`'s users).
+  BasicBlock::iterator RestoreInsPt = Builder.GetInsertPoint();
+  Builder.SetInsertPoint(Op0);
+
+  Value *UAddOverflow = Builder.CreateBinaryIntrinsic(
+      Intrinsic::uadd_with_overflow, X, Y, /* FMFSource */ nullptr, "uaddo");
+  Value *UAdd = Builder.CreateExtractValue(UAddOverflow, 0,
+                                           UAddOverflow->getName() + ".add");
+  Value *Overflow = Builder.CreateExtractValue(
+      UAddOverflow, 1, UAddOverflow->getName() + ".overflow");
+
+  // Replace the uses of the truncated `Op0` with `UAdd`, since `UAddOverflow`
+  // performs the truncated version of the addition performed by `Op0`.
+  for (TruncInst *Trunc : Truncs)
+    replaceInstUsesWith(*Trunc, UAdd);
+
+  Builder.SetInsertPoint(&(*RestoreInsPt));
+
+  // Replace the use of `Op0` by `I` with `Overflow`.
+  return new ZExtInst(Overflow, Ty);
+}
+
 Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
 
@@ -1305,6 +1371,9 @@
     return BinaryOperator::CreateAnd(Mask, X);
   }
 
+  if (Instruction *Overflow = foldXShrToOverflow(I))
+    return Overflow;
+
   return nullptr;
 }
 
@@ -1496,5 +1565,8 @@
     return BinaryOperator::CreateNot(NewAShr);
   }
 
+  if (Instruction *Overflow = foldXShrToOverflow(I))
+    return Overflow;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -431,3 +431,102 @@
   %r = shl nsw i4 2, %a
   ret i4 %r
 }
+
+define i64 @lshr_32_add_zext_basic(i32 %a, i32 %b) {
+; CHECK-LABEL: @lshr_32_add_zext_basic(
+; CHECK-NEXT:    [[UADDO:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[A:%.*]], i32 [[B:%.*]])
+; CHECK-NEXT:    [[UADDO_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UADDO]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[UADDO_OVERFLOW]] to i64
+; CHECK-NEXT:    ret i64 [[LSHR]]
+;
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %lshr = lshr i64 %add, 32
+  ret i64 %lshr
+}
+
+define i64 @ashr_32_add_zext_basic(i32 %a, i32 %b) {
+; CHECK-LABEL: @ashr_32_add_zext_basic(
+; CHECK-NEXT:    [[UADDO:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[A:%.*]], i32 [[B:%.*]])
+; CHECK-NEXT:    [[UADDO_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UADDO]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i1 [[UADDO_OVERFLOW]] to i64
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %lshr = ashr i64 %add, 32
+  ret i64 %lshr
+}
+
+define i32 @lshr_32_add_zext_trunc(i32 %a, i32 %b) {
+; CHECK-LABEL: @lshr_32_add_zext_trunc(
+; CHECK-NEXT:    [[UADDO:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[A:%.*]], i32 [[B:%.*]])
+; CHECK-NEXT:    [[UADDO_ADD:%.*]] = extractvalue { i32, i1 } [[UADDO]], 0
+; CHECK-NEXT:    [[UADDO_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UADDO]], 1
+; CHECK-NEXT:    [[TRUNC_SHR:%.*]] = zext i1 [[UADDO_OVERFLOW]] to i32
+; CHECK-NEXT:    [[RET:%.*]] = add i32 [[UADDO_ADD]], [[TRUNC_SHR]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %trunc.add = trunc i64 %add to i32
+  %shr = lshr i64 %add, 32
+  %trunc.shr = trunc i64 %shr to i32
+  %ret = add i32 %trunc.add, %trunc.shr
+  ret i32 %ret
+}
+
+define <3 x i32> @add3_i96(<3 x i32> %0, <3 x i32> %1) {
+; CHECK-LABEL: @add3_i96(
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <3 x i32> [[TMP0:%.*]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x i32> [[TMP1:%.*]], i64 0
+; CHECK-NEXT:    [[UADDO:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[TMP4]], i32 [[TMP3]])
+; CHECK-NEXT:    [[UADDO_ADD:%.*]] = extractvalue { i32, i1 } [[UADDO]], 0
+; CHECK-NEXT:    [[UADDO_OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UADDO]], 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <3 x i32> [[TMP0]], i64 1
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <3 x i32> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP8:%.*]] = zext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = add nuw nsw i64 [[TMP8]], [[TMP6]]
+; CHECK-NEXT:    [[TMP10:%.*]] = zext i1 [[UADDO_OVERFLOW]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = add nuw nsw i64 [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <3 x i32> [[TMP0]], i64 2
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <3 x i32> [[TMP1]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = add i32 [[TMP13]], [[TMP12]]
+; CHECK-NEXT:    [[TMP15:%.*]] = lshr i64 [[TMP11]], 32
+; CHECK-NEXT:    [[TMP16:%.*]] = trunc i64 [[TMP15]] to i32
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP14]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <3 x i32> undef, i32 [[UADDO_ADD]], i64 0
+; CHECK-NEXT:    [[TMP19:%.*]] = trunc i64 [[TMP11]] to i32
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <3 x i32> [[TMP18]], i32 [[TMP19]], i64 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <3 x i32> [[TMP20]], i32 [[TMP17]], i64 2
+; CHECK-NEXT:    ret <3 x i32> [[TMP21]]
+;
+  %3 = extractelement <3 x i32> %0, i64 0
+  %4 = zext i32 %3 to i64
+  %5 = extractelement <3 x i32> %1, i64 0
+  %6 = zext i32 %5 to i64
+  %7 = add nuw nsw i64 %6, %4
+  %8 = extractelement <3 x i32> %0, i64 1
+  %9 = zext i32 %8 to i64
+  %10 = extractelement <3 x i32> %1, i64 1
+  %11 = zext i32 %10 to i64
+  %12 = add nuw nsw i64 %11, %9
+  %13 = lshr i64 %7, 32
+  %14 = add nuw nsw i64 %12, %13
+  %15 = extractelement <3 x i32> %0, i64 2
+  %16 = extractelement <3 x i32> %1, i64 2
+  %17 = add i32 %16, %15
+  %18 = lshr i64 %14, 32
+  %19 = trunc i64 %18 to i32
+  %20 = add i32 %17, %19
+  %21 = trunc i64 %7 to i32
+  %22 = insertelement <3 x i32> undef, i32 %21, i32 0
+  %23 = trunc i64 %14 to i32
+  %24 = insertelement <3 x i32> %22, i32 %23, i32 1
+  %25 = insertelement <3 x i32> %24, i32 %20, i32 2
+  ret <3 x i32> %25
+}
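
For reviewers, a minimal sketch (not part of the patch) of the source-level idiom this fold targets; the function and variable names are illustrative. Doing a 32-bit addition in 64 bits and shifting the upper half down is a common portable way to extract the carry, and frontends typically lower it to the `(lshr (add (zext a), (zext b)), 32)` plus `trunc` sequence matched by `foldXShrToOverflow`:

```cpp
#include <cstdint>

// Illustrative only: a "portable" add-with-carry written without builtins.
// The widening add, the shift by 32, and the truncation correspond to the
// zext/add/lshr/trunc sequence exercised by @lshr_32_add_zext_trunc above,
// which this patch is intended to rewrite to @llvm.uadd.with.overflow.i32.
uint32_t add_with_carry(uint32_t a, uint32_t b, uint32_t &carry_out) {
  uint64_t wide = uint64_t(a) + uint64_t(b); // add (zext a), (zext b)
  carry_out = uint32_t(wide >> 32);          // lshr ..., 32 -> overflow bit
  return uint32_t(wide);                     // trunc        -> uaddo result
}
```

The @add3_i96 test is the same idiom chained across three 32-bit limbs. As the CHECK lines show, only the lowest limb is converted: the second limb's shift operates on an add that also folds in the incoming carry, so its operand no longer matches the two-zext pattern required by the fold.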