Index: lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -112,6 +112,11 @@
   int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment,
                       unsigned AddressSpace);
 
+  int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
+                            ArrayRef<Type *> Tys, FastMathFlags FMF);
+  int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
+                            ArrayRef<Value *> Args, FastMathFlags FMF);
+
   int getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys);
 
   void getUnrollingPreferences(Loop *L, TTI::UnrollingPreferences &UP);
Index: lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -495,6 +495,46 @@
   return LT.first;
 }
 
+int AArch64TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
+                                          ArrayRef<Type *> Tys, FastMathFlags FMF) {
+  static const CostTblEntry KryoCostTbl[] = {
+    { ISD::FSQRT, MVT::f32,   5 },
+    { ISD::FSQRT, MVT::f64,   6 },
+    { ISD::FSQRT, MVT::v2f32, 7 },
+    { ISD::FSQRT, MVT::v4f32, 7 },
+    { ISD::FSQRT, MVT::v2f64, 7 },
+  };
+
+  unsigned ISD = ISD::DELETED_NODE;
+  switch (IID) {
+  default:
+    break;
+  case Intrinsic::sqrt:
+    ISD = ISD::FSQRT;
+    break;
+  }
+
+  // Legalize the type.
+  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, RetTy);
+  MVT MTy = LT.second;
+
+  // Attempt to look up the cost in the subtarget-specific table.
+  switch (ST->getProcFamily()) {
+  default:
+    break;
+  case AArch64Subtarget::Kryo:
+    if (const auto *Entry = CostTableLookup(KryoCostTbl, ISD, MTy))
+      return LT.first * Entry->Cost;
+    break;
+  }
+  return BaseT::getIntrinsicInstrCost(IID, RetTy, Tys, FMF);
+}
+
+int AArch64TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy,
+                                          ArrayRef<Value *> Args, FastMathFlags FMF) {
+  return BaseT::getIntrinsicInstrCost(IID, RetTy, Args, FMF);
+}
+
 int AArch64TTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy,
                                                unsigned Factor,
                                                ArrayRef<unsigned> Indices,
Index: test/Analysis/CostModel/AArch64/arith-fp.ll
===================================================================
--- /dev/null
+++ test/Analysis/CostModel/AArch64/arith-fp.ll
@@ -0,0 +1,36 @@
+; RUN: opt < %s -enable-no-nans-fp-math -cost-model -analyze | FileCheck %s
+; RUN: opt < %s -enable-no-nans-fp-math -cost-model -analyze -mcpu=kryo | FileCheck %s --check-prefix=KRYO
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+define i32 @fsqrt(i32 %arg) {
+  %F32 = call float @llvm.sqrt.f32(float undef)
+; CHECK: cost of 1 {{.*}} %F32 = call float @llvm.sqrt.f32
+; KRYO: cost of 5 {{.*}} %F32 = call float @llvm.sqrt.f32
+
+  %V2F32 = call <2 x float> @llvm.sqrt.v2f32(<2 x float> undef)
+; CHECK: cost of 1 {{.*}} %V2F32 = call <2 x float> @llvm.sqrt.v2f32
+; KRYO: cost of 7 {{.*}} %V2F32 = call <2 x float> @llvm.sqrt.v2f32
+
+  %V4F32 = call <4 x float> @llvm.sqrt.v4f32(<4 x float> undef)
+; CHECK: cost of 1 {{.*}} %V4F32 = call <4 x float> @llvm.sqrt.v4f32
+; KRYO: cost of 7 {{.*}} %V4F32 = call <4 x float> @llvm.sqrt.v4f32
+
+  %F64 = call double @llvm.sqrt.f64(double undef)
+; CHECK: cost of 1 {{.*}} %F64 = call double @llvm.sqrt.f64
+; KRYO: cost of 6 {{.*}} %F64 = call double @llvm.sqrt.f64
+
+  %V2F64 = call <2 x double> @llvm.sqrt.v2f64(<2 x double> undef)
+; CHECK: cost of 1 {{.*}} %V2F64 = call <2 x double> @llvm.sqrt.v2f64
+; KRYO: cost of 7 {{.*}} %V2F64 = call <2 x double> @llvm.sqrt.v2f64
+
+  ret i32 undef
+}
+
+declare float @llvm.sqrt.f32(float)
+declare <2 x float> @llvm.sqrt.v2f32(<2 x float>)
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>)
+
+declare double @llvm.sqrt.f64(double)
+declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)
Index: test/Transforms/SLPVectorizer/AArch64/sqrt.ll
===================================================================
--- /dev/null
+++ test/Transforms/SLPVectorizer/AArch64/sqrt.ll
@@ -0,0 +1,63 @@
+; RUN: opt < %s -basicaa -slp-vectorizer -S -mtriple=aarch64-unknown-linux-gnu -mcpu=kryo | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+; CHECK-LABEL: @test1(
+; CHECK: fmul fast float
+; CHECK: fmul fast float
+; CHECK: fdiv fast <2 x float>
+; CHECK: call fast <2 x float> @llvm.sqrt.v2f32(<2 x float>
+; CHECK: call fast <2 x float> @llvm.sqrt.v2f32(<2 x float>
+; CHECK: ret float
+define float @test1(float %t1, float %t2, float %t3, float %z1, float %z2) {
+entry:
+  %mul = fmul fast float %t1, 2.000000e+00
+  %mul4 = fmul fast float %mul, %z1
+  %div = fdiv fast float %mul4, %t2
+  %0 = tail call fast float @llvm.sqrt.f32(float %div)
+  %div7 = fdiv fast float %mul4, %t3
+  %1 = tail call fast float @llvm.sqrt.f32(float %div7)
+  %mul9 = fmul fast float %mul, %z2
+  %div10 = fdiv fast float %mul9, %t2
+  %2 = tail call fast float @llvm.sqrt.f32(float %div10)
+  %div13 = fdiv fast float %mul9, %t3
+  %3 = tail call fast float @llvm.sqrt.f32(float %div13)
+  %cmp14 = fcmp fast ogt float %0, %2
+  %cond = select i1 %cmp14, float %0, float %2
+  %cmp15 = fcmp fast ogt float %1, %3
+  %cond19 = select i1 %cmp15, float %1, float %3
+  %add = fadd fast float %cond, %cond19
+  ret float %add
+}
+
+; CHECK-LABEL: @test2(
+; CHECK: fmul fast double
+; CHECK: fmul fast <2 x double>
+; CHECK: fdiv fast <2 x double>
+; CHECK: call fast <2 x double> @llvm.sqrt.v2f64(<2 x double>
+; CHECK: call fast <2 x double> @llvm.sqrt.v2f64(<2 x double>
+; CHECK: ret double
+define double @test2(double %t1, double %t2, double %t3, double %z1, double %z2) {
+entry:
+  %mul = fmul fast double %t1, 2.000000e+00
+  %mul4 = fmul fast double %mul, %z1
+  %div = fdiv fast double %mul4, %t2
+  %0 = tail call fast double @llvm.sqrt.f64(double %div)
+  %div7 = fdiv fast double %mul4, %t3
+  %1 = tail call fast double @llvm.sqrt.f64(double %div7)
+  %mul9 = fmul fast double %mul, %z2
+  %div10 = fdiv fast double %mul9, %t2
+  %2 = tail call fast double @llvm.sqrt.f64(double %div10)
+  %div13 = fdiv fast double %mul9, %t3
+  %3 = tail call fast double @llvm.sqrt.f64(double %div13)
+  %cmp14 = fcmp fast ogt double %0, %2
+  %cond = select i1 %cmp14, double %0, double %2
+  %cmp15 = fcmp fast ogt double %1, %3
+  %cond19 = select i1 %cmp15, double %1, double %3
+  %add = fadd fast double %cond, %cond19
+  ret double %add
+}
+
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)