Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1792,20 +1792,55 @@
   // instructions as normal vector adds. This is the only arithmetic vector
   // reduction operation for which we have an instruction.
   static const CostTblEntry CostTblNoPairwise[]{
-      {ISD::ADD, MVT::v8i8,  1},
-      {ISD::ADD, MVT::v16i8, 1},
-      {ISD::ADD, MVT::v4i16, 1},
-      {ISD::ADD, MVT::v8i16, 1},
-      {ISD::ADD, MVT::v4i32, 1},
+      {ISD::ADD, MVT::v8i8,   1},
+      {ISD::ADD, MVT::v16i8,  1},
+      {ISD::ADD, MVT::v4i16,  1},
+      {ISD::ADD, MVT::v8i16,  1},
+      {ISD::ADD, MVT::v4i32,  1},
+      // OR has no single reduction instruction and is expanded instead; these
+      // entries estimate the cost of that expanded sequence.
+      {ISD::OR,  MVT::v8i8,  15},
+      {ISD::OR,  MVT::v16i8, 17},
+      {ISD::OR,  MVT::v4i16,  7},
+      {ISD::OR,  MVT::v8i16,  9},
+      {ISD::OR,  MVT::v2i32,  3},
+      {ISD::OR,  MVT::v4i32,  5},
+      {ISD::OR,  MVT::v2i64,  3},
   };
-
-  if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
-    return LT.first * Entry->Cost;
-
+  switch (ISD) {
+  default:
+    break;
+  case ISD::ADD:
+    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
+      return LT.first * Entry->Cost;
+    break;
+  case ISD::OR:
+    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy)) {
+      auto *ValVTy = cast<FixedVectorType>(ValTy);
+      // The table is indexed by the legalized type MTy, so only use it when
+      // MTy has no more lanes than the original vector and the lane count is
+      // a power of 2; i1 vectors also fall back to the generic cost below.
+      if (!ValVTy->getElementType()->isIntegerTy(1) &&
+          MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
+          isPowerOf2_32(ValVTy->getNumElements())) {
+        InstructionCost ExtraCost = 0;
+        if (LT.first != 1) {
+          // Type needs to be split, so there is an extra cost of LT.first - 1
+          // arithmetic ops.
+          auto *Ty = FixedVectorType::get(ValTy->getElementType(),
+                                          MTy.getVectorNumElements());
+          ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
+          ExtraCost *= LT.first - 1;
+        }
+        return Entry->Cost + ExtraCost;
+      }
+    }
+    break;
+  }
   return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwiseForm,
                                            CostKind);
 }
 
 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
                                                VectorType *Tp,
                                                ArrayRef<int> Mask, int Index,
Index: llvm/test/Analysis/CostModel/AArch64/reduce-or.ll
===================================================================
--- llvm/test/Analysis/CostModel/AArch64/reduce-or.ll
+++ llvm/test/Analysis/CostModel/AArch64/reduce-or.ll
@@ -11,8 +11,24 @@
 ; CHECK-NEXT: Cost Model: Found an estimated cost of 91 for instruction: %V32 = call i1 @llvm.vector.reduce.or.v32i1(<32 x i1> undef)
 ; CHECK-NEXT: Cost Model: Found an estimated cost of 181 for instruction: %V64 = call i1 @llvm.vector.reduce.or.v64i1(<64 x i1> undef)
 ; CHECK-NEXT: Cost Model: Found an estimated cost of 362 for instruction: %V128 = call i1 @llvm.vector.reduce.or.v128i1(<128 x i1> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %V1i8 = call i8 @llvm.vector.reduce.or.v1i8(<1 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 13 for instruction: %V3i8 = call i8 @llvm.vector.reduce.or.v3i8(<3 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 7 for instruction: %V4i8 = call i8 @llvm.vector.reduce.or.v4i8(<4 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %V8i8 = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 17 for instruction: %V16i8 = call i8 @llvm.vector.reduce.or.v16i8(<16 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %V32i8 = call i8 @llvm.vector.reduce.or.v32i8(<32 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %V64i8 = call i8 @llvm.vector.reduce.or.v64i8(<64 x i8> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 7 for instruction: %V4i16 = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %V8i16 = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %V16i16 = call i16 @llvm.vector.reduce.or.v16i16(<16 x i16> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %V2i32 = call i32 @llvm.vector.reduce.or.v2i32(<2 x i32> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %V4i32 = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %V8i32 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 3 for instruction: %V2i64 = call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 4 for instruction: %V4i64 = call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> undef)
 ; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret i32 undef
 ;
+
   %V1 = call i1 @llvm.vector.reduce.or.v1i1(<1 x i1> undef)
   %V2 = call i1 @llvm.vector.reduce.or.v2i1(<2 x i1> undef)
   %V4 = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> undef)
@@ -21,9 +37,25 @@
   %V32 = call i1 @llvm.vector.reduce.or.v32i1(<32 x i1> undef)
   %V64 = call i1 @llvm.vector.reduce.or.v64i1(<64 x i1> undef)
   %V128 = call i1 @llvm.vector.reduce.or.v128i1(<128 x i1> undef)
+
+  %V1i8 = call i8 @llvm.vector.reduce.or.v1i8(<1 x i8> undef)
+  %V3i8 = call i8 @llvm.vector.reduce.or.v3i8(<3 x i8> undef)
+  %V4i8 = call i8 @llvm.vector.reduce.or.v4i8(<4 x i8> undef)
+  %V8i8 = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> undef)
+  %V16i8 = call i8 @llvm.vector.reduce.or.v16i8(<16 x i8> undef)
+  %V32i8 = call i8 @llvm.vector.reduce.or.v32i8(<32 x i8> undef)
+  %V64i8 = call i8 @llvm.vector.reduce.or.v64i8(<64 x i8> undef)
+  %V4i16 = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> undef)
+  %V8i16 = call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> undef)
+  %V16i16 = call i16 @llvm.vector.reduce.or.v16i16(<16 x i16> undef)
+  %V2i32 = call i32 @llvm.vector.reduce.or.v2i32(<2 x i32> undef)
+  %V4i32 = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> undef)
+  %V8i32 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> undef)
+  %V2i64 = call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> undef)
+  %V4i64 = call i64 @llvm.vector.reduce.or.v4i64(<4 x i64> undef)
   ret i32 undef
 }
 
 declare i1 @llvm.vector.reduce.or.v1i1(<1 x i1>)
 declare i1 @llvm.vector.reduce.or.v2i1(<2 x i1>)
 declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)
@@ -32,3 +64,18 @@
 declare i1 @llvm.vector.reduce.or.v32i1(<32 x i1>)
 declare i1 @llvm.vector.reduce.or.v64i1(<64 x i1>)
 declare i1 @llvm.vector.reduce.or.v128i1(<128 x i1>)
+declare i8 @llvm.vector.reduce.or.v1i8(<1 x i8>)
+declare i8 @llvm.vector.reduce.or.v3i8(<3 x i8>)
+declare i8 @llvm.vector.reduce.or.v4i8(<4 x i8>)
+declare i8 @llvm.vector.reduce.or.v8i8(<8 x i8>)
+declare i8 @llvm.vector.reduce.or.v16i8(<16 x i8>)
+declare i8 @llvm.vector.reduce.or.v32i8(<32 x i8>)
+declare i8 @llvm.vector.reduce.or.v64i8(<64 x i8>)
+declare i16 @llvm.vector.reduce.or.v4i16(<4 x i16>)
+declare i16 @llvm.vector.reduce.or.v8i16(<8 x i16>)
+declare i16 @llvm.vector.reduce.or.v16i16(<16 x i16>)
+declare i32 @llvm.vector.reduce.or.v2i32(<2 x i32>)
+declare i32 @llvm.vector.reduce.or.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.or.v8i32(<8 x i32>)
+declare i64 @llvm.vector.reduce.or.v2i64(<2 x i64>)
+declare i64 @llvm.vector.reduce.or.v4i64(<4 x i64>)