Index: llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
+++ llvm/trunk/lib/Transforms/Utils/LoopUtils.cpp
@@ -98,6 +98,7 @@
 
   SmallVector<Instruction *, 8> Worklist;
   bool FoundOneOperand = false;
+  unsigned DstSize = RT->getPrimitiveSizeInBits();
   Worklist.push_back(Exit);
 
   // Traverse the instructions in the reduction expression, beginning with the
@@ -120,11 +121,16 @@
 
       // If the operand is not in Visited, it is not a reduction operation, but
      // it does feed into one. Make sure it is either a single-use sign- or
-      // zero-extend of the recurrence type.
+      // zero-extend instruction.
       CastInst *Cast = dyn_cast<CastInst>(J);
       bool IsSExtInst = isa<SExtInst>(J);
-      if (!Cast || !Cast->hasOneUse() || Cast->getSrcTy() != RT ||
-          !(isa<ZExtInst>(J) || IsSExtInst))
+      if (!Cast || !Cast->hasOneUse() || !(isa<ZExtInst>(J) || IsSExtInst))
+        return false;
+
+      // Ensure the source type of the extend is no larger than the reduction
+      // type. It is not necessary for the types to be identical.
+      unsigned SrcSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
+      if (SrcSize > DstSize)
         return false;
 
       // Furthermore, ensure that all such extends are of the same kind.
@@ -136,9 +142,11 @@
         IsSigned = IsSExtInst;
       }
 
-      // Lastly, add the sign- or zero-extend to CI so that we can avoid
-      // accounting for it in the cost model.
-      CI.insert(Cast);
+      // Lastly, if the source type of the extend matches the reduction type,
+      // add the extend to CI so that we can avoid accounting for it in the
+      // cost model.
+      if (SrcSize == DstSize)
+        CI.insert(Cast);
     }
   }
   return true;
Index: llvm/trunk/test/Transforms/LoopVectorize/AArch64/reduction-small-size.ll
===================================================================
--- llvm/trunk/test/Transforms/LoopVectorize/AArch64/reduction-small-size.ll
+++ llvm/trunk/test/Transforms/LoopVectorize/AArch64/reduction-small-size.ll
@@ -66,9 +66,9 @@
   br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body
 }
 
-; CHECK-LABEL: @reduction_i16
+; CHECK-LABEL: @reduction_i16_1
 ;
-; short reduction_i16(short *a, short *b, int n) {
+; short reduction_i16_1(short *a, short *b, int n) {
 ;   short sum = 0;
 ;   for (int i = 0; i < n; ++i)
 ;     sum += (a[i] + b[i]);
@@ -92,7 +92,7 @@
 ; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <8 x i16>
 ; CHECK: zext i16 [[Rdx]] to i32
 ;
-define i16 @reduction_i16(i16* nocapture readonly %a, i16* nocapture readonly %b, i32 %n) {
+define i16 @reduction_i16_1(i16* nocapture readonly %a, i16* nocapture readonly %b, i32 %n) {
 entry:
   %cmp.16 = icmp sgt i32 %n, 0
   br i1 %cmp.16, label %for.body.preheader, label %for.cond.cleanup
@@ -126,3 +126,66 @@
   %exitcond = icmp eq i32 %lftr.wideiv, %n
   br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body
 }
+
+; CHECK-LABEL: @reduction_i16_2
+;
+; short reduction_i16_2(char *a, char *b, int n) {
+;   short sum = 0;
+;   for (int i = 0; i < n; ++i)
+;     sum += (a[i] + b[i]);
+;   return sum;
+; }
+;
+; CHECK: vector.body:
+; CHECK: phi <8 x i16>
+; CHECK: [[Ld1:%[a-zA-Z0-9.]+]] = load <8 x i8>
+; CHECK: zext <8 x i8> [[Ld1]] to <8 x i16>
+; CHECK: [[Ld2:%[a-zA-Z0-9.]+]] = load <8 x i8>
+; CHECK: zext <8 x i8> [[Ld2]] to <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: add <8 x i16>
+;
+; CHECK: middle.block:
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: shufflevector <8 x i16>
+; CHECK: add <8 x i16>
+; CHECK: [[Rdx:%[a-zA-Z0-9.]+]] = extractelement <8 x i16>
+; CHECK: zext i16 [[Rdx]] to i32
+;
+define i16 @reduction_i16_2(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %n) {
+entry:
+  %cmp.14 = icmp sgt i32 %n, 0
+  br i1 %cmp.14, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.for.cond.cleanup_crit_edge:
+  %add5.lcssa = phi i32 [ %add5, %for.body ]
+  %conv6 = trunc i32 %add5.lcssa to i16
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  %sum.0.lcssa = phi i16 [ %conv6, %for.cond.for.cond.cleanup_crit_edge ], [ 0, %entry ]
+  ret i16 %sum.0.lcssa
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %sum.015 = phi i32 [ %add5, %for.body ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds i8, i8* %a, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx, align 1
+  %conv = zext i8 %0 to i32
+  %arrayidx2 = getelementptr inbounds i8, i8* %b, i64 %indvars.iv
+  %1 = load i8, i8* %arrayidx2, align 1
+  %conv3 = zext i8 %1 to i32
+  %conv4.13 = and i32 %sum.015, 65535
+  %add = add nuw nsw i32 %conv, %conv4.13
+  %add5 = add nuw nsw i32 %add, %conv3
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %n
+  br i1 %exitcond, label %for.cond.for.cond.cleanup_crit_edge, label %for.body
+}