diff --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp
--- a/clang/lib/CodeGen/TargetInfo.cpp
+++ b/clang/lib/CodeGen/TargetInfo.cpp
@@ -3407,52 +3407,23 @@
   return false;
 }
 
-/// ContainsFloatAtOffset - Return true if the specified LLVM IR type has a
-/// float member at the specified offset.  For example, {int,{float}} has a
-/// float at offset 4.  It is conservatively correct for this routine to return
-/// false.
-static bool ContainsFloatAtOffset(llvm::Type *IRType, unsigned IROffset,
-                                  const llvm::DataLayout &TD) {
-  // Base case if we find a float.
-  if (IROffset == 0 && IRType->isFloatTy())
-    return true;
-
-  // If this is a struct, recurse into the field at the specified offset.
-  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
-    const llvm::StructLayout *SL = TD.getStructLayout(STy);
-    unsigned Elt = SL->getElementContainingOffset(IROffset);
-    IROffset -= SL->getElementOffset(Elt);
-    return ContainsFloatAtOffset(STy->getElementType(Elt), IROffset, TD);
-  }
-
-  // If this is an array, recurse into the field at the specified offset.
-  if (llvm::ArrayType *ATy = dyn_cast<llvm::ArrayType>(IRType)) {
-    llvm::Type *EltTy = ATy->getElementType();
-    unsigned EltSize = TD.getTypeAllocSize(EltTy);
-    IROffset -= IROffset/EltSize*EltSize;
-    return ContainsFloatAtOffset(EltTy, IROffset, TD);
-  }
-
-  return false;
-}
-
-/// ContainsHalfAtOffset - Return true if the specified LLVM IR type has a
-/// half member at the specified offset.  For example, {int,{half}} has a
-/// half at offset 4.  It is conservatively correct for this routine to return
-/// false.
-/// FIXME: Merge with ContainsFloatAtOffset
-static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
-                                 const llvm::DataLayout &TD) {
-  // Base case if we find a float.
-  if (IROffset == 0 && IRType->isHalfTy())
-    return true;
+/// getFPTypeAtOffset - Return the floating point type located exactly at byte
+/// offset IROffset within IRType, or nullptr if no such type is found. Structs
+/// are searched through the member containing the offset. Arrays are searched
+/// by reducing the offset modulo the element size, so an offset past the end
+/// of an array wraps around; callers must keep IROffset inside the type.
+static llvm::Type *getFPTypeAtOffset(llvm::Type *IRType, unsigned IROffset,
+                                     const llvm::DataLayout &TD) {
+  // Base case: the type itself is a floating point type.
+  if (IROffset == 0 && IRType->isFloatingPointTy())
+    return IRType;
 
   // If this is a struct, recurse into the field at the specified offset.
   if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
     const llvm::StructLayout *SL = TD.getStructLayout(STy);
     unsigned Elt = SL->getElementContainingOffset(IROffset);
     IROffset -= SL->getElementOffset(Elt);
-    return ContainsHalfAtOffset(STy->getElementType(Elt), IROffset, TD);
+    return getFPTypeAtOffset(STy->getElementType(Elt), IROffset, TD);
   }
 
   // If this is an array, recurse into the field at the specified offset.
@@ -3460,10 +3431,10 @@
     llvm::Type *EltTy = ATy->getElementType();
     unsigned EltSize = TD.getTypeAllocSize(EltTy);
     IROffset -= IROffset / EltSize * EltSize;
-    return ContainsHalfAtOffset(EltTy, IROffset, TD);
+    return getFPTypeAtOffset(EltTy, IROffset, TD);
   }
 
-  return false;
+  return nullptr;
 }
 
 /// GetSSETypeAtOffset - Return a type that will be passed by the backend in the
@@ -3471,39 +3442,48 @@
 llvm::Type *X86_64ABIInfo::
 GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
                    QualType SourceTy, unsigned SourceOffset) const {
-  // If the high 32 bits are not used, we have three choices. Single half,
-  // single float or two halfs.
-  if (BitsContainNoUserData(SourceTy, SourceOffset * 8 + 32,
-                            SourceOffset * 8 + 64, getContext())) {
-    if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()))
-      return llvm::Type::getFloatTy(getVMContext());
-    if (ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()))
-      return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()),
-                                        2);
-
-    return llvm::Type::getHalfTy(getVMContext());
-  }
-
-  // We want to pass as <2 x float> if the LLVM IR type contains a float at
-  // offset+0 and offset+4.  Walk the LLVM IR type to find out if this is the
-  // case.
-  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) &&
-      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
-    return llvm::FixedVectorType::get(llvm::Type::getFloatTy(getVMContext()),
-                                      2);
-
-  // We want to pass as <4 x half> if the LLVM IR type contains a half at
-  // offset+0, +2, +4. Walk the LLVM IR type to find out if this is the case.
-  if (ContainsHalfAtOffset(IRType, IROffset, getDataLayout()) &&
-      ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()) &&
-      ContainsHalfAtOffset(IRType, IROffset + 4, getDataLayout()))
-    return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
-
-  // We want to pass as <4 x half> if the LLVM IR type contains a mix of float
-  // and half.
-  // FIXME: Do we have a better representation for the mixed type?
-  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) ||
-      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
+  const llvm::DataLayout &TD = getDataLayout();
+  // Bytes of SourceTy at or after SourceOffset. Offsets at or beyond this
+  // bound are never probed: getFPTypeAtOffset reduces array offsets modulo the
+  // element size, so an out-of-range offset could alias an earlier element.
+  unsigned SourceSize =
+      static_cast<unsigned>(getContext().getTypeSize(SourceTy) / 8) -
+      SourceOffset;
+  llvm::Type *T0 = getFPTypeAtOffset(IRType, IROffset, TD);
+  if (!T0 || T0->isDoubleTy())
+    return llvm::Type::getDoubleTy(getVMContext());
+
+  // Get the adjacent FP type.
+  llvm::Type *T1 = nullptr;
+  unsigned T0Size = TD.getTypeAllocSize(T0);
+  if (SourceSize > T0Size)
+    T1 = getFPTypeAtOffset(IRType, IROffset + T0Size, TD);
+  if (T1 == nullptr) {
+    // Check if IRType is a half + float. float type will be in IROffset+4 due
+    // to its alignment.
+    if (T0->isHalfTy() && SourceSize > 4)
+      T1 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
+    // If we can't get a second FP type, return a simple half or float.
+    // avx512fp16-abi.c:pr51813_2 shows it works to return float for
+    // {float, i8} too.
+    if (T1 == nullptr)
+      return T0;
+  }
+
+  if (T0->isFloatTy() && T1->isFloatTy())
+    return llvm::FixedVectorType::get(T0, 2);
+
+  if (T0->isHalfTy() && T1->isHalfTy()) {
+    // <2 x half> unless a third FP member occupies the upper four bytes.
+    llvm::Type *T2 = nullptr;
+    if (SourceSize > 4)
+      T2 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
+    if (T2 == nullptr)
+      return llvm::FixedVectorType::get(T0, 2);
+    return llvm::FixedVectorType::get(T0, 4);
+  }
+
+  // A mix of half and float is passed as <4 x half>.
+  if (T0->isHalfTy() || T1->isHalfTy())
     return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
 
   return llvm::Type::getDoubleTy(getVMContext());
diff --git a/clang/test/CodeGen/X86/avx512fp16-abi.c b/clang/test/CodeGen/X86/avx512fp16-abi.c
--- a/clang/test/CodeGen/X86/avx512fp16-abi.c
+++ b/clang/test/CodeGen/X86/avx512fp16-abi.c
@@ -1,11 +1,12 @@
-// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 -x c++ -std=c++11 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-CPP
 
 struct half1 {
   _Float16 a;
 };
 
 struct half1 h1(_Float16 a) {
-  // CHECK: define{{.*}}half @h1
+  // CHECK: define{{.*}}half @
   struct half1 x;
   x.a = a;
   return x;
@@ -17,7 +18,7 @@
 };
 
 struct half2 h2(_Float16 a, _Float16 b) {
-  // CHECK: define{{.*}}<2 x half> @h2
+  // CHECK: define{{.*}}<2 x half> @
   struct half2 x;
   x.a = a;
   x.b = b;
@@ -31,7 +32,7 @@
 };
 
 struct half3 h3(_Float16 a, _Float16 b, _Float16 c) {
-  // CHECK: define{{.*}}<4 x half> @h3
+  // CHECK: define{{.*}}<4 x half> @
   struct half3 x;
   x.a = a;
   x.b = b;
@@ -47,7 +48,7 @@
 };
 
 struct half4 h4(_Float16 a, _Float16 b, _Float16 c, _Float16 d) {
-  // CHECK: define{{.*}}<4 x half> @h4
+  // CHECK: define{{.*}}<4 x half> @
   struct half4 x;
   x.a = a;
   x.b = b;
@@ -62,7 +63,7 @@
 };
 
 struct floathalf fh(float a, _Float16 b) {
-  // CHECK: define{{.*}}<4 x half> @fh
+  // CHECK: define{{.*}}<4 x half> @
   struct floathalf x;
   x.a = a;
   x.b = b;
@@ -76,7 +77,7 @@
 };
 
 struct floathalf2 fh2(float a, _Float16 b, _Float16 c) {
-  // CHECK: define{{.*}}<4 x half> @fh2
+  // CHECK: define{{.*}}<4 x half> @
   struct floathalf2 x;
   x.a = a;
   x.b = b;
@@ -90,7 +91,7 @@
 };
 
 struct halffloat hf(_Float16 a, float b) {
-  // CHECK: define{{.*}}<4 x half> @hf
+  // CHECK: define{{.*}}<4 x half> @
   struct halffloat x;
   x.a = a;
   x.b = b;
@@ -104,7 +105,7 @@
 };
 
 struct half2float h2f(_Float16 a, _Float16 b, float c) {
-  // CHECK: define{{.*}}<4 x half> @h2f
+  // CHECK: define{{.*}}<4 x half> @
   struct half2float x;
   x.a = a;
   x.b = b;
@@ -120,7 +121,7 @@
 };
 
 struct floathalf3 fh3(float a, _Float16 b, _Float16 c, _Float16 d) {
-  // CHECK: define{{.*}}{ <4 x half>, half } @fh3
+  // CHECK: define{{.*}}{ <4 x half>, half } @
   struct floathalf3 x;
   x.a = a;
   x.b = b;
@@ -138,7 +139,7 @@
 };
 
 struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
-  // CHECK: define{{.*}}{ <4 x half>, half } @h5
+  // CHECK: define{{.*}}{ <4 x half>, half } @
   struct half5 x;
   x.a = a;
   x.b = b;
@@ -147,3 +148,54 @@
   x.e = e;
   return x;
 }
+
+struct float2 {
+  struct {} s;
+  float a;
+  float b;
+};
+
+float pr51813(struct float2 s) {
+  // CHECK-C: define{{.*}} @pr51813(<2 x float>
+  // CHECK-CPP: define{{.*}} @_Z7pr518136float2(double {{.*}}, float
+  return s.a;
+}
+
+struct float3 {
+  float a;
+  struct {} s;
+  float b;
+};
+
+float pr51813_2(struct float3 s) {
+  // CHECK-C: define{{.*}} @pr51813_2(<2 x float>
+  // CHECK-CPP: define{{.*}} @_Z9pr51813_26float3(double {{.*}}, float
+  return s.a;
+}
+
+struct shalf2 {
+  struct {} s;
+  _Float16 a;
+  _Float16 b;
+};
+
+_Float16 sf2(struct shalf2 s) {
+  // CHECK-C: define{{.*}} @sf2(<2 x half>
+  // CHECK-CPP: define{{.*}} @_Z3sf26shalf2(double {{.*}}
+  return s.a;
+};
+
+struct halfs2 {
+  _Float16 a;
+  struct {} s1;
+  _Float16 b;
+  struct {} s2;
+};
+
+// NOTE(review): fs2 takes 'struct shalf2', so 'struct halfs2' above is never
+// exercised -- presumably a typo; confirm the intended parameter type.
+_Float16 fs2(struct shalf2 s) {
+  // CHECK-C: define{{.*}} @fs2(<2 x half>
+  // CHECK-CPP: define{{.*}} @_Z3fs26shalf2(double {{.*}}
+  return s.a;
+};