Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -542,6 +542,33 @@
     unsigned Cost = static_cast<T *>(this)->getMemoryOpCost(
         Opcode, VecTy, Alignment, AddressSpace);
 
+    // Legalize the vector type and count the number of legalized instructions
+    // that will actually be used. We scale the cost of the memory operation by
+    // the fraction of legalized instructions that aren't dead. We shouldn't
+    // account for the cost of dead instructions since they will be removed.
+    //
+    // E.g., An interleaved load of factor 8:
+    //   %vec = load <16 x i64>, <16 x i64>* %ptr
+    //   %v0 = shufflevector %vec, undef, <0, 8>
+    //
+    // If <16 x i64> is legalized to 8 v2i64 loads, only 2 of the loads will be
+    // used (those corresponding to elements [0:1] and [8:9] of the unlegalized
+    // type). The other loads are unused.
+    //
+    // We only scale the cost of loads since interleaved store groups aren't
+    // allowed to have gaps.
+    unsigned NumLegalInsts = getTLI()->getTypeLegalizationCost(DL, VecTy).first;
+    if (Opcode == Instruction::Load && NumLegalInsts > 1) {
+      unsigned NumEltsPerLegalInst = NumElts / NumLegalInsts;
+      BitVector UsedInsts(NumLegalInsts, false);
+      for (unsigned Index : Indices)
+        for (unsigned Elt = 0; Elt < NumSubElts; Elt++)
+          UsedInsts.set((Index + Elt * Factor) / NumEltsPerLegalInst);
+      // Multiply before dividing: "Cost *= Count / NumLegalInsts" would
+      // truncate the unsigned quotient to zero whenever Count < NumLegalInsts.
+      Cost = Cost * UsedInsts.count() / NumLegalInsts;
+    }
+
     // Then plus the cost of interleave operation.
     if (Opcode == Instruction::Load) {
       // The interleave cost is similar to extract sub vectors' elements
Index: test/Transforms/LoopVectorize/AArch64/interleaved_cost.ll
===================================================================
--- test/Transforms/LoopVectorize/AArch64/interleaved_cost.ll
+++ test/Transforms/LoopVectorize/AArch64/interleaved_cost.ll
@@ -14,6 +14,7 @@
 ; 8xi8 and 16xi8 are valid i8 vector types, so the cost of the interleaved
 ; access group is 2.
 
+; CHECK: LV: Checking a loop in "test_byte_interleaved_cost"
 ; CHECK: LV: Found an estimated cost of 2 for VF 8 For instruction:   %tmp = load i8, i8* %arrayidx0, align 4
 ; CHECK: LV: Found an estimated cost of 2 for VF 16 For instruction:   %tmp = load i8, i8* %arrayidx0, align 4
 
@@ -37,3 +38,44 @@
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+%ig.factor.8 = type { double*, double, double, double, double, double, double, double }
+define double @wide_interleaved_group(%ig.factor.8* %s, double %a, double %b, i32 %n) {
+entry:
+  br label %for.body
+
+; Check the default cost of a strided load with a factor that is greater than
+; the maximum allowed. In this test, the interleave factor would be 8, which is
+; not supported.
+
+; CHECK: LV: Checking a loop in "wide_interleaved_group"
+; CHECK: LV: Found an estimated cost of 6 for VF 2 For instruction:   %1 = load double, double* %0, align 8
+; CHECK: LV: Found an estimated cost of 0 for VF 2 For instruction:   %5 = load double, double* %4, align 8
+; CHECK: LV: Found an estimated cost of 10 for VF 2 For instruction:   store double %9, double* %10, align 8
+
+for.body:
+  %i = phi i64 [ 0, %entry ], [ %i.next, %for.body ]
+  %r = phi double [ 0.000000e+00, %entry ], [ %12, %for.body ]
+  %0 = getelementptr inbounds %ig.factor.8, %ig.factor.8* %s, i64 %i, i32 2
+  %1 = load double, double* %0, align 8
+  %2 = fcmp fast olt double %1, %a
+  %3 = select i1 %2, double 0.000000e+00, double %1
+  %4 = getelementptr inbounds %ig.factor.8, %ig.factor.8* %s, i64 %i, i32 6
+  %5 = load double, double* %4, align 8
+  %6 = fcmp fast olt double %5, %a
+  %7 = select i1 %6, double 0.000000e+00, double %5
+  %8 = fmul fast double %7, %b
+  %9 = fadd fast double %8, %3
+  %10 = getelementptr inbounds %ig.factor.8, %ig.factor.8* %s, i64 %i, i32 3
+  store double %9, double* %10, align 8
+  %11 = fmul fast double %9, %9
+  %12 = fadd fast double %11, %r
+  %i.next = add nuw nsw i64 %i, 1
+  %13 = trunc i64 %i.next to i32
+  %cond = icmp eq i32 %13, %n
+  br i1 %cond, label %for.exit, label %for.body
+
+for.exit:
+  %r.lcssa = phi double [ %12, %for.body ]
+  ret double %r.lcssa
+}