Index: clang/docs/LanguageExtensions.rst
===================================================================
--- clang/docs/LanguageExtensions.rst
+++ clang/docs/LanguageExtensions.rst
@@ -631,6 +631,7 @@
 =========================================== ================================================================ =========================================
  T __builtin_elementwise_abs(T x)            return the absolute value of a number x; the absolute value of   signed integer and floating point types
                                              the most negative integer remains the most negative integer
+ T __builtin_elementwise_fma(T x, T y, T z)  fused multiply add, (x * y) + z.                                 floating point types
  T __builtin_elementwise_ceil(T x)           return the smallest integral value greater than or equal to x    floating point types
  T __builtin_elementwise_sin(T x)            return the sine of x interpreted as an angle in radians          floating point types
  T __builtin_elementwise_cos(T x)            return the cosine of x interpreted as an angle in radians        floating point types
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -671,6 +671,7 @@
 BUILTIN(__builtin_elementwise_trunc, "v.", "nct")
 BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct")
 BUILTIN(__builtin_elementwise_copysign, "v.", "nct")
+BUILTIN(__builtin_elementwise_fma, "v.", "nct")
 BUILTIN(__builtin_elementwise_add_sat, "v.", "nct")
 BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct")
 BUILTIN(__builtin_reduce_max, "v.", "nct")
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -13531,6 +13531,7 @@
   bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
 
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -3118,6 +3118,8 @@
         emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc"));
   case Builtin::BI__builtin_elementwise_copysign:
     return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign));
+  case Builtin::BI__builtin_elementwise_fma:
+    return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma));
   case Builtin::BI__builtin_elementwise_add_sat:
   case Builtin::BI__builtin_elementwise_sub_sat: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -2626,20 +2626,16 @@
       return ExprError();
 
     QualType ArgTy = TheCall->getArg(0)->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-    if (!EltTy->isFloatingType()) {
-      Diag(TheCall->getArg(0)->getBeginLoc(),
-           diag::err_builtin_invalid_arg_type)
-          << 1 << /* float ty*/ 5 << ArgTy;
-
+    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
+                                      ArgTy, 1))
+      return ExprError();
+    break;
+  }
+  case Builtin::BI__builtin_elementwise_fma: {
+    if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return ExprError();
-    }
     break;
   }
-
   // These builtins restrict the element type to integer
   // types only.
   case Builtin::BI__builtin_elementwise_add_sat:
@@ -17877,6 +17873,54 @@
   return false;
 }
 
+/// Perform semantic checking for a call to a three-operand elementwise math
+/// builtin such as __builtin_elementwise_fma.
+///
+/// The call must have exactly three arguments. Each argument undergoes the
+/// usual unary conversions and is validated with
+/// checkFPMathBuiltinElementType; all three converted arguments must then
+/// share one canonical type, which becomes the type of the call.
+///
+/// \returns true if an error was diagnosed, false otherwise.
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+  if (checkArgCount(*this, TheCall, 3))
+    return true;
+
+  Expr *Args[3];
+  for (int I = 0; I < 3; ++I) {
+    ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I));
+    if (Converted.isInvalid())
+      return true;
+    Args[I] = Converted.get();
+  }
+
+  int ArgOrdinal = 1;
+  for (Expr *Arg : Args) {
+    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                      ArgOrdinal++))
+      return true;
+  }
+
+  // There is no implicit conversion between operands: the second and third
+  // must have the same canonical type as the first.
+  for (int I = 1; I < 3; ++I) {
+    if (Args[0]->getType().getCanonicalType() !=
+        Args[I]->getType().getCanonicalType()) {
+      return Diag(Args[0]->getBeginLoc(),
+                  diag::err_typecheck_call_different_arg_types)
+             << Args[0]->getType() << Args[I]->getType();
+    }
+  }
+
+  // Store all three converted operands back into the call, including the
+  // first, so every argument is the expression that was actually checked.
+  for (int I = 0; I < 3; ++I)
+    TheCall->setArg(I, Args[I]);
+
+  TheCall->setType(Args[0]->getType());
+  return false;
+}
+
 bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) {
   if (checkArgCount(*this, TheCall, 1))
     return true;
Index: clang/test/CodeGen/builtins-elementwise-math.c
===================================================================
--- clang/test/CodeGen/builtins-elementwise-math.c
+++ clang/test/CodeGen/builtins-elementwise-math.c
@@ -1,5 +1,9 @@
 // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
+typedef _Float16 half;
+
+typedef half half2 __attribute__((ext_vector_type(2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
 typedef short int si8 __attribute__((ext_vector_type(8)));
 typedef unsigned int u4 __attribute__((ext_vector_type(4)));
@@ -525,3 +529,77 @@
   // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> [[V2F64]])
   v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64);
 }
+
+void test_builtin_elementwise_fma(float f32, double f64,
+                                  float2 v2f32, float4 v4f32,
+                                  double2 v2f64, double3 v3f64,
+                                  const float4 c_v4f32,
+                                  half f16, half2 v2f16) {
+  // CHECK-LABEL: define void @test_builtin_elementwise_fma(
+  // CHECK: [[F32_0:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]])
+  float f2 = __builtin_elementwise_fma(f32, f32, f32);
+
+  // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  double d2 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32);
+
+
+  // FIXME: Are we really still doing the 3 vector load workaround
+  // CHECK: [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector
+  // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]])
+  v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64);
+
+  // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  v2f64 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32);
+
+  // CHECK: [[F16_0:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]])
+  half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);
+
+  // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16);
+
+  // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement
+  // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16);
+
+  // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>)
+  half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0);
+
+}
Index: clang/test/Sema/builtins-elementwise-math.c
===================================================================
--- clang/test/Sema/builtins-elementwise-math.c
+++ clang/test/Sema/builtins-elementwise-math.c
@@ -4,6 +4,8 @@
 typedef double double4 __attribute__((ext_vector_type(4)));
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
+
+typedef int int2 __attribute__((ext_vector_type(2)));
 typedef int int3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned4 __attribute__((ext_vector_type(4)));
@@ -572,3 +574,84 @@
   float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32);
   // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}}
 }
+
+void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16,
+                                  double f64, double2 v2f64, double2 v3f64,
+                                  float f32, float2 v2f32, float v3f32, float4 v4f32,
+                                  const float4 c_v4f32,
+                                  int3 v3i32, int *ptr) {
+
+  f32 = __builtin_elementwise_fma();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+
+  f32 = __builtin_elementwise_fma(f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+
+  f32 = __builtin_elementwise_fma(f32, f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f32, f32);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+
+  f32 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f32 = __builtin_elementwise_fma(f32, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f32 = __builtin_elementwise_fma(f32, f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}}
+
+  i32 = __builtin_elementwise_fma(i32, i32, i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int')}}
+
+  v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, i32);
+  // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
+
+  f32 = __builtin_elementwise_fma(f32, i32, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was 'int')}}
+
+  v2f32 = __builtin_elementwise_fma(v2f32, v2f32, v2i32);
+  // expected-error@-1 {{3rd argument must be a floating point type (was 'int2' (vector of 2 'int' values))}}
+
+
+  _Complex float c1, c2, c3;
+  c1 = __builtin_elementwise_fma(c1, f32, f32);
+  // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}}
+
+  c2 = __builtin_elementwise_fma(f32, c2, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}}
+
+  c3 = __builtin_elementwise_fma(f32, f32, c3);
+  // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}}
+}