diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -116,6 +116,7 @@
       BinaryOperator &I);
   Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
       BinaryOperator &OldAShr);
+  Instruction *foldXShrToOverflow(BinaryOperator &I);
   Instruction *visitAShr(BinaryOperator &I);
   Instruction *visitLShr(BinaryOperator &I);
   Instruction *commonShiftTransforms(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1035,6 +1035,76 @@
   return nullptr;
 }
 
+// Tries to perform
+//   (Xshr (add (zext a, i2^n), (zext b, i2^n)), 2^(n-1))
+//     -> (llvm.uadd.with.overflow a, b).overflow
+// where Xshr is either ashr or lshr, and a and b have type i2^(n-1).
+Instruction *InstCombinerImpl::foldXShrToOverflow(BinaryOperator &I) {
+  assert(I.getOpcode() == Instruction::AShr ||
+         I.getOpcode() == Instruction::LShr);
+
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Type *Ty = I.getType();
+  const APInt *ShAmtAPInt;
+  if (!match(Op1, m_APInt(ShAmtAPInt)))
+    return nullptr;
+
+  unsigned ShAmt = ShAmtAPInt->getZExtValue();
+  unsigned BitWidth = Ty->getScalarSizeInBits();
+  if (!(isPowerOf2_32(BitWidth) && Ty->isIntegerTy(ShAmt * 2)))
+    return nullptr;
+
+  // Be conservative if the newly created overflow would be a live-out of its
+  // basic block.
+  for (const User *Usr : I.users()) {
+    const Instruction *UsrInst = dyn_cast<Instruction>(Usr);
+    if (UsrInst && UsrInst->getParent() != I.getParent())
+      return nullptr;
+  }
+
+  Value *X = nullptr, *Y = nullptr;
+  if (match(Op0, m_Add(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
+      X->getType()->getScalarSizeInBits() == ShAmt &&
+      Y->getType()->getScalarSizeInBits() == ShAmt) {
+
+    // Make sure that `Op0` is only used by `I` and `ShAmt`-truncates.
+    SmallVector<TruncInst *, 4> Truncs;
+    for (User *Usr : Op0->users()) {
+      if (Usr == &I)
+        continue;
+
+      TruncInst *Trunc = dyn_cast<TruncInst>(Usr);
+      if (!Trunc)
+        return nullptr;
+
+      if (Trunc->getType()->getScalarSizeInBits() != ShAmt)
+        return nullptr;
+
+      Truncs.push_back(Trunc);
+    }
+
+    // If we get here, we can be sure that `Op0` is only used by `Truncs` and
+    // `I`.
+
+    Value *UAddOverflow = Builder.CreateBinaryIntrinsic(
+        Intrinsic::uadd_with_overflow, X, Y, /* FMFSource */ nullptr, "uaddo");
+    Value *UAdd = Builder.CreateExtractValue(UAddOverflow, 0,
+                                             UAddOverflow->getName() + ".add");
+    Value *Overflow = Builder.CreateExtractValue(
+        UAddOverflow, 1, UAddOverflow->getName() + ".overflow");
+
+    // Replace the uses of the truncated `Op0` with `UAdd`, since `UAddOverflow`
+    // performs the truncated version of the addition performed by `Op0`.
+    for (TruncInst *Trunc : Truncs) {
+      Trunc->replaceAllUsesWith(UAdd);
+    }
+
+    // Replace the use of `Op0` by `I` with `Overflow`.
+    return new ZExtInst(Overflow, Ty);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
@@ -1181,6 +1251,9 @@
     }
   }
 
+  if (Instruction *Overflow = foldXShrToOverflow(I))
+    return Overflow;
+
   // Transform (x << y) >> y to x & (-1 >> y)
   Value *X;
   if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -120,3 +120,50 @@
   %E = lshr <4 x i32> , %D
   ret <4 x i32> %E
 }
+
+define i64 @lshr_32_add_zext_basic(i32 %a, i32 %b) {
+; CHECK-LABEL: define i64 @lshr_32_add_zext_basic(
+; CHECK-NEXT:    [[uaddo:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %a, i32 %b)
+; CHECK-NEXT:    [[overflow:%.*]] = extractvalue { i32, i1 } [[uaddo]], 1
+; CHECK-NEXT:    [[zextOverflow:%.*]] = zext i1 [[overflow]] to i64
+; CHECK-NEXT:    ret i64 [[zextOverflow]]
+
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %lshr = lshr i64 %add, 32
+  ret i64 %lshr
+}
+
+define i64 @ashr_32_add_zext_basic(i32 %a, i32 %b) {
+; CHECK-LABEL: define i64 @ashr_32_add_zext_basic(
+; CHECK-NEXT:    [[uaddo:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %a, i32 %b)
+; CHECK-NEXT:    [[overflow:%.*]] = extractvalue { i32, i1 } [[uaddo]], 1
+; CHECK-NEXT:    [[zextOverflow:%.*]] = zext i1 [[overflow]] to i64
+; CHECK-NEXT:    ret i64 [[zextOverflow]]
+
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %ashr = ashr i64 %add, 32
+  ret i64 %ashr
+}
+
+define i32 @lshr_32_add_zext_trunc(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @lshr_32_add_zext_trunc(
+; CHECK-NEXT:    [[uaddo:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %a, i32 %b)
+; CHECK-NEXT:    [[add:%.*]] = extractvalue { i32, i1 } [[uaddo]], 0
+; CHECK-NEXT:    [[overflow:%.*]] = extractvalue { i32, i1 } [[uaddo]], 1
+; CHECK-NEXT:    [[zextOverflow:%.*]] = zext i1 [[overflow]] to i32
+; CHECK-NEXT:    [[ret:%.*]] = add i32 [[add]], [[zextOverflow]]
+; CHECK-NEXT:    ret i32 [[ret]]
+
+  %zext.a = zext i32 %a to i64
+  %zext.b = zext i32 %b to i64
+  %add = add i64 %zext.a, %zext.b
+  %trunc.add = trunc i64 %add to i32
+  %shr = lshr i64 %add, 32
+  %trunc.shr = trunc i64 %shr to i32
+  %ret = add i32 %trunc.add, %trunc.shr
+  ret i32 %ret
+}
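
Note (not part of the patch): a minimal standalone C++ sketch of the arithmetic identity the fold relies on, using the i32 -> i64 case as a concrete instance; the test values are arbitrary. For W-bit unsigned a and b, the high W bits of (zext a + zext b) are exactly the overflow bit of the W-bit unsigned add, and the low W bits are the wrapped W-bit sum, which is why the W-bit trunc users of the add can be redirected to the uadd.with.overflow result.

// Illustrative only: check the identity behind foldXShrToOverflow for W = 32.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Tests[][2] = {
      {0u, 0u}, {1u, 0xFFFFFFFFu}, {0x80000000u, 0x80000000u}, {123u, 456u}};
  for (const auto &T : Tests) {
    uint32_t A = T[0], B = T[1];

    // Wide form, as in the IR before the fold: zext both operands, add, shift.
    uint64_t Wide = (uint64_t)A + (uint64_t)B;
    uint64_t HighWord = Wide >> 32;    // the lshr/ashr i64 ..., 32
    uint32_t LowWord = (uint32_t)Wide; // what a trunc-to-i32 user sees

    // Narrow form, as produced by llvm.uadd.with.overflow.i32.
    uint32_t Sum = A + B;    // wraps modulo 2^32
    bool Overflow = Sum < A; // carry out of the 32-bit add

    assert(HighWord == (Overflow ? 1u : 0u));
    assert(LowWord == Sum);
  }
  return 0;
}

Because the high word of the wide sum is only ever 0 or 1, the sign bit of the wide value is never set, which is why ashr and lshr by the half-width behave identically here and both are accepted by the fold.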