diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1995,7 +1995,34 @@
     }
     LLVM_FALLTHROUGH;
   }
-  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_add: {
+    if (IID == Intrinsic::vector_reduce_add) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      // Convert vector_reduce_add(SExt(<n x i1>)) to
+      // -ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      // Convert vector_reduce_add(<n x i1>) to
+      // Trunc(ctpop(bitcast <n x i1> to in)).
+      Value *Arg = II->getArgOperand(0);
+      Value *Vect;
+      if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
+        if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
+          if (FTy->getElementType() == Builder.getInt1Ty()) {
+            Value *V = Builder.CreateBitCast(
+                Vect, Builder.getIntNTy(FTy->getNumElements()));
+            Value *Res = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+            if (Res->getType() != II->getType())
+              Res = Builder.CreateZExtOrTrunc(Res, II->getType());
+            if (Arg != Vect &&
+                cast<Instruction>(Arg)->getOpcode() == Instruction::SExt)
+              Res = Builder.CreateNeg(Res);
+            replaceInstUsesWith(CI, Res);
+            return eraseInstFromFunction(CI);
+          }
+      }
+    }
+    LLVM_FALLTHROUGH;
+  }
   case Intrinsic::vector_reduce_mul:
   case Intrinsic::vector_reduce_xor:
   case Intrinsic::vector_reduce_umax:
diff --git a/llvm/test/Transforms/InstCombine/reduction-add-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-add-sext-zext-i1.ll
--- a/llvm/test/Transforms/InstCombine/reduction-add-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-add-sext-zext-i1.ll
@@ -3,8 +3,11 @@
 
 define i1 @reduce_add_self(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_self(
-; CHECK-NEXT:    [[RES:%.*]] = call i1 @llvm.vector.reduce.add.v8i1(<8 x i1> [[X:%.*]])
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i8 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %res = call i1 @llvm.vector.reduce.add.v8i32(<8 x i1> %x)
   ret i1 %res
@@ -12,9 +15,11 @@
 
 define i32 @reduce_add_sext(<4 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_sext(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP1]]), !range [[RNG1:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i4 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP4:%.*]] = sub nsw i32 0, [[TMP3]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %res = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %sext)
@@ -23,9 +28,10 @@
 
 define i64 @reduce_add_zext(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_zext(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT:    ret i64 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP2]] to i64
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
   %zext = zext <8 x i1> %x to <8 x i64>
   %res = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %zext)
@@ -34,9 +40,10 @@
 
 define i16 @reduce_add_sext_same(<16 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_sext_same(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[RES:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[SEXT]])
-; CHECK-NEXT:    ret i16 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.ctpop.i16(i16 [[TMP1]]), !range [[RNG2:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = sub nsw i16 0, [[TMP2]]
+; CHECK-NEXT:    ret i16 [[TMP3]]
 ;
   %sext = sext <16 x i1> %x to <16 x i16>
   %res = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %sext)
@@ -45,9 +52,11 @@
 
 define i8 @reduce_add_zext_long(<128 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_zext_long(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i8 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT:    [[TMP2:%.*]] = call i128 @llvm.ctpop.i128(i128 [[TMP1]]), !range [[RNG3:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i128 [[TMP2]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = sub i8 0, [[TMP3]]
+; CHECK-NEXT:    ret i8 [[TMP4]]
 ;
   %sext = sext <128 x i1> %x to <128 x i8>
   %res = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> %sext)
@@ -57,11 +66,14 @@
 @glob = external global i8, align 1
 define i8 @reduce_add_zext_long_external_use(<128 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_zext_long_external_use(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <128 x i1> [[X:%.*]] to <128 x i8>
-; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> [[SEXT]])
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <128 x i8> [[SEXT]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <128 x i1> [[X:%.*]] to i128
+; CHECK-NEXT:    [[TMP2:%.*]] = call i128 @llvm.ctpop.i128(i128 [[TMP1]]), !range [[RNG3]]
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i128 [[TMP2]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = sub i8 0, [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <128 x i1> [[X]], i32 0
+; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[TMP5]] to i8
 ; CHECK-NEXT:    store i8 [[EXT]], i8* @glob, align 1
-; CHECK-NEXT:    ret i8 [[RES]]
+; CHECK-NEXT:    ret i8 [[TMP4]]
 ;
   %sext = sext <128 x i1> %x to <128 x i8>
   %res = call i8 @llvm.vector.reduce.add.v128i8(<128 x i8> %sext)
@@ -73,11 +85,13 @@
 @glob1 = external global i64, align 8
 define i64 @reduce_add_zext_external_use(<8 x i1> %x) {
 ; CHECK-LABEL: @reduce_add_zext_external_use(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <8 x i1> [[X:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[RES:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[ZEXT]])
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <8 x i64> [[ZEXT]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = call i8 @llvm.ctpop.i8(i8 [[TMP1]]), !range [[RNG0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP2]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <8 x i1> [[X]], i32 0
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[TMP4]] to i64
 ; CHECK-NEXT:    store i64 [[EXT]], i64* @glob1, align 8
-; CHECK-NEXT:    ret i64 [[RES]]
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
   %zext = zext <8 x i1> %x to <8 x i64>
   %res = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> %zext)