diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11991,6 +11991,13 @@
   QualType CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
                                        ExprResult &RHS,
                                        SourceLocation QuestionLoc);
+
+  /// Type-check a ?: whose condition is an SVE sizeless vector with integer
+  /// elements. Returns the sizeless vector result type, or a null QualType if
+  /// the operands are invalid.
+  QualType CheckSizelessVectorConditionalTypes(ExprResult &Cond,
+                                               ExprResult &LHS, ExprResult &RHS,
+                                               SourceLocation QuestionLoc);
   QualType FindCompositePointerType(SourceLocation Loc, Expr *&E1, Expr *&E2,
                                     bool ConvertArgs = true);
   QualType FindCompositePointerType(SourceLocation Loc,
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -4642,7 +4642,9 @@
     return tmp5;
   }
 
-  if (condExpr->getType()->isVectorType()) {
+  // SVE sizeless vector conditions take the same element-wise path as vectors.
+  if (condExpr->getType()->isVectorType() ||
+      condExpr->getType()->isVLSTBuiltinType()) {
     CGF.incrementProfileCounter(E);
 
     llvm::Value *CondV = CGF.EmitScalarExpr(condExpr);
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -6103,6 +6103,19 @@
   return EltTy->isIntegralType(Ctx);
 }
 
+/// Return true if \p CondTy is an SVE sizeless vector type (a VLST builtin)
+/// whose element type is integral, which makes it usable as the condition of
+/// a vector ?: expression.
+static bool isValidSizelessVectorForConditionalCondition(ASTContext &Ctx,
+                                                         QualType CondTy) {
+  if (!CondTy->isVLSTBuiltinType())
+    return false;
+  const QualType EltTy =
+      cast<BuiltinType>(CondTy.getCanonicalType())->getSveEltType(Ctx);
+  assert(!EltTy->isEnumeralType() && "Vectors cant be enum types");
+  return EltTy->isIntegralType(Ctx);
+}
+
 QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
                                            ExprResult &RHS,
                                            SourceLocation QuestionLoc) {
@@ -6194,6 +6207,116 @@
   return ResultType;
 }
 
+/// Check the operands of a ?: whose condition is an SVE sizeless vector and
+/// compute the sizeless vector result type.
+///
+/// Two vector operands must have the same type; a vector mixed with a scalar
+/// is handled by CheckSizelessVectorOperands; two scalars are converted to a
+/// common element type and splatted to the condition's element count. The
+/// result must match the condition in element count and element size.
+/// Returns a null QualType if the operands are invalid.
+QualType Sema::CheckSizelessVectorConditionalTypes(ExprResult &Cond,
+                                                   ExprResult &LHS,
+                                                   ExprResult &RHS,
+                                                   SourceLocation QuestionLoc) {
+  LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+
+  QualType CondType = Cond.get()->getType();
+  const auto *CondBT = CondType->castAs<BuiltinType>();
+  QualType CondElementTy = CondBT->getSveEltType(Context);
+  llvm::ElementCount CondElementCount =
+      Context.getBuiltinVectorTypeInfo(CondBT).EC;
+
+  QualType LHSType = LHS.get()->getType();
+  const auto *LHSBT =
+      LHSType->isVLSTBuiltinType() ? LHSType->getAs<BuiltinType>() : nullptr;
+  QualType RHSType = RHS.get()->getType();
+  const auto *RHSBT =
+      RHSType->isVLSTBuiltinType() ? RHSType->getAs<BuiltinType>() : nullptr;
+
+  QualType ResultType;
+
+  if (LHSBT && RHSBT) {
+    // If both are sizeless vector types, they must be the same type.
+    if (!Context.hasSameType(LHSType, RHSType)) {
+      Diag(QuestionLoc, diag::err_conditional_vector_mismatched)
+          << LHSType << RHSType;
+      return QualType();
+    }
+    ResultType = LHSType;
+  } else if (LHSBT || RHSBT) {
+    // One vector, one scalar: CheckSizelessVectorOperands converts them to a
+    // common sizeless vector type; a null result means it rejected them.
+    ResultType = CheckSizelessVectorOperands(
+        LHS, RHS, QuestionLoc, /*IsCompAssign*/ false, ACK_Conditional);
+    if (ResultType.isNull())
+      return QualType();
+  } else {
+    // Both are scalars: find a common element type, then splat it.
+    QualType ResultElementTy;
+    LHSType = LHSType.getCanonicalType().getUnqualifiedType();
+    RHSType = RHSType.getCanonicalType().getUnqualifiedType();
+
+    if (Context.hasSameType(LHSType, RHSType))
+      ResultElementTy = LHSType;
+    else
+      ResultElementTy =
+          UsualArithmeticConversions(LHS, RHS, QuestionLoc, ACK_Conditional);
+
+    if (ResultElementTy->isEnumeralType()) {
+      Diag(QuestionLoc, diag::err_conditional_vector_operand_type)
+          << ResultElementTy;
+      return QualType();
+    }
+
+    // Diagnose an element-size mismatch before forming the splat type; the
+    // checks after this branch require ResultType to be a valid sizeless
+    // vector. Name the converted element type: after the usual arithmetic
+    // conversions it can differ from LHSType (e.g. 'c ? 1 : 2.0').
+    if (Context.getTypeSize(ResultElementTy) !=
+        Context.getTypeSize(CondElementTy)) {
+      Diag(QuestionLoc, diag::err_conditional_vector_element_size)
+          << CondType << ResultElementTy;
+      return QualType();
+    }
+
+    // NOTE(review): getScalableVectorType may return a null type for a scalar
+    // with no SVE vector counterpart (e.g. a pointer whose size matches the
+    // condition's elements), which the assert below does not tolerate --
+    // confirm such operands are diagnosed before reaching this point.
+    ResultType = Context.getScalableVectorType(
+        ResultElementTy, CondElementCount.getKnownMinValue());
+
+    LHS = ImpCastExprToType(LHS.get(), ResultType, CK_VectorSplat);
+    RHS = ImpCastExprToType(RHS.get(), ResultType, CK_VectorSplat);
+  }
+
+  assert(!ResultType.isNull() && ResultType->isVLSTBuiltinType() &&
+         "Result should have been a vector type");
+  auto *ResultBuiltinTy = ResultType->castAs<BuiltinType>();
+  QualType ResultElementTy = ResultBuiltinTy->getSveEltType(Context);
+  llvm::ElementCount ResultElementCount =
+      Context.getBuiltinVectorTypeInfo(ResultBuiltinTy).EC;
+
+  // The condition selects element by element, so the result must match it in
+  // both element count and element width.
+  if (ResultElementCount != CondElementCount) {
+    Diag(QuestionLoc, diag::err_conditional_vector_size)
+        << CondType << ResultType;
+    return QualType();
+  }
+
+  if (Context.getTypeSize(ResultElementTy) !=
+      Context.getTypeSize(CondElementTy)) {
+    Diag(QuestionLoc, diag::err_conditional_vector_element_size)
+        << CondType << ResultType;
+    return QualType();
+  }
+
+  return ResultType;
+}
+
 /// Check the operands of ?: under C++ semantics.
 ///
 /// See C++ [expr.cond]. Note that LHS is never null, even for the GNU x ?: y
@@ -6227,10 +6350,14 @@
   bool IsVectorConditional =
       isValidVectorForConditionalCondition(Context, Cond.get()->getType());
 
+  bool IsSizelessVectorConditional =
+      isValidSizelessVectorForConditionalCondition(Context,
+                                                   Cond.get()->getType());
+
   // C++11 [expr.cond]p1
   //   The first expression is contextually converted to bool.
   if (!Cond.get()->isTypeDependent()) {
-    ExprResult CondRes = IsVectorConditional
+    ExprResult CondRes = IsVectorConditional || IsSizelessVectorConditional
                              ? DefaultFunctionArrayLvalueConversion(Cond.get())
                              : CheckCXXBooleanCondition(Cond.get());
     if (CondRes.isInvalid())
@@ -6299,6 +6426,9 @@
   if (IsVectorConditional)
     return CheckVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc);
 
+  if (IsSizelessVectorConditional)
+    return CheckSizelessVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc);
+
   // C++11 [expr.cond]p3
   //   Otherwise, if the second and third operand have different types, and
   //   either has (cv) class type [...] an attempt is made to convert each of
diff --git a/clang/test/CodeGenCXX/aarch64-sve-vector-conditional-op.cpp b/clang/test/CodeGenCXX/aarch64-sve-vector-conditional-op.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenCXX/aarch64-sve-vector-conditional-op.cpp
@@ -0,0 +1,224 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sve \
+// RUN: -fallow-half-arguments-and-returns -disable-O0-optnone \
+// RUN: -emit-llvm -o - %s | opt -S -sroa | FileCheck %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sve.h>
+
+// CHECK-LABEL: @_Z9cond_boolu10__SVBool_tu10__SVBool_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 16 x i1> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 16 x i1> [[CMP]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 16 x i1> [[VECTOR_COND]], <vscale x 16 x i1> [[A]], <vscale x 16 x i1> [[B]]
+// CHECK-NEXT:    ret <vscale x 16 x i1> [[VECTOR_SELECT]]
+//
+svbool_t cond_bool(svbool_t a, svbool_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z7cond_i8u10__SVInt8_tu10__SVInt8_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 16 x i8> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 16 x i1> [[CMP]] to <vscale x 16 x i8>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 16 x i8> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 16 x i1> [[VECTOR_COND]], <vscale x 16 x i8> [[A]], <vscale x 16 x i8> [[B]]
+// CHECK-NEXT:    ret <vscale x 16 x i8> [[VECTOR_SELECT]]
+//
+svint8_t cond_i8(svint8_t a, svint8_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z7cond_u8u11__SVUint8_tu11__SVUint8_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 16 x i8> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 16 x i1> [[CMP]] to <vscale x 16 x i8>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 16 x i8> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 16 x i1> [[VECTOR_COND]], <vscale x 16 x i8> [[A]], <vscale x 16 x i8> [[B]]
+// CHECK-NEXT:    ret <vscale x 16 x i8> [[VECTOR_SELECT]]
+//
+svuint8_t cond_u8(svuint8_t a, svuint8_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_i16u11__SVInt16_tu11__SVInt16_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 8 x i16> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 8 x i1> [[CMP]] to <vscale x 8 x i16>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 8 x i16> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 8 x i1> [[VECTOR_COND]], <vscale x 8 x i16> [[A]], <vscale x 8 x i16> [[B]]
+// CHECK-NEXT:    ret <vscale x 8 x i16> [[VECTOR_SELECT]]
+//
+svint16_t cond_i16(svint16_t a, svint16_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_u16u12__SVUint16_tu12__SVUint16_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 8 x i16> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 8 x i1> [[CMP]] to <vscale x 8 x i16>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 8 x i16> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 8 x i1> [[VECTOR_COND]], <vscale x 8 x i16> [[A]], <vscale x 8 x i16> [[B]]
+// CHECK-NEXT:    ret <vscale x 8 x i16> [[VECTOR_SELECT]]
+//
+svuint16_t cond_u16(svuint16_t a, svuint16_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_i32u11__SVInt32_tu11__SVInt32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 4 x i32> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x i32> [[A]], <vscale x 4 x i32> [[B]]
+// CHECK-NEXT:    ret <vscale x 4 x i32> [[VECTOR_SELECT]]
+//
+svint32_t cond_i32(svint32_t a, svint32_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_u32u12__SVUint32_tu12__SVUint32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 4 x i32> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x i32> [[A]], <vscale x 4 x i32> [[B]]
+// CHECK-NEXT:    ret <vscale x 4 x i32> [[VECTOR_SELECT]]
+//
+svuint32_t cond_u32(svuint32_t a, svuint32_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_i64u11__SVInt64_tu11__SVInt64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i64> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x i64> [[A]], <vscale x 2 x i64> [[B]]
+// CHECK-NEXT:    ret <vscale x 2 x i64> [[VECTOR_SELECT]]
+//
+svint64_t cond_i64(svint64_t a, svint64_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_u64u12__SVUint64_tu12__SVUint64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i64> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x i64> [[A]], <vscale x 2 x i64> [[B]]
+// CHECK-NEXT:    ret <vscale x 2 x i64> [[VECTOR_SELECT]]
+//
+svuint64_t cond_u64(svuint64_t a, svuint64_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_f16u13__SVFloat16_tu13__SVFloat16_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = fcmp olt <vscale x 8 x half> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 8 x i1> [[CMP]] to <vscale x 8 x i16>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 8 x i16> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 8 x i1> [[VECTOR_COND]], <vscale x 8 x half> [[A]], <vscale x 8 x half> [[B]]
+// CHECK-NEXT:    ret <vscale x 8 x half> [[VECTOR_SELECT]]
+//
+svfloat16_t cond_f16(svfloat16_t a, svfloat16_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_f32u13__SVFloat32_tu13__SVFloat32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = fcmp olt <vscale x 4 x float> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x float> [[A]], <vscale x 4 x float> [[B]]
+// CHECK-NEXT:    ret <vscale x 4 x float> [[VECTOR_SELECT]]
+//
+svfloat32_t cond_f32(svfloat32_t a, svfloat32_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z8cond_f64u13__SVFloat64_tu13__SVFloat64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = fcmp olt <vscale x 2 x double> [[A:%.*]], [[B:%.*]]
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x double> [[A]], <vscale x 2 x double> [[B]]
+// CHECK-NEXT:    ret <vscale x 2 x double> [[VECTOR_SELECT]]
+//
+svfloat64_t cond_f64(svfloat64_t a, svfloat64_t b) {
+  return a < b ? a : b;
+}
+
+// CHECK-LABEL: @_Z14cond_i32_splatu11__SVInt32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 4 x i32> [[A:%.*]], zeroinitializer
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x i32> [[A]], <vscale x 4 x i32> zeroinitializer
+// CHECK-NEXT:    ret <vscale x 4 x i32> [[VECTOR_SELECT]]
+//
+svint32_t cond_i32_splat(svint32_t a) {
+  return a < 0 ? a : 0;
+}
+
+// CHECK-LABEL: @_Z14cond_u32_splatu12__SVUint32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 4 x i32> [[A:%.*]], shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i32 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x i32> [[A]], <vscale x 4 x i32> shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i32 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+// CHECK-NEXT:    ret <vscale x 4 x i32> [[VECTOR_SELECT]]
+//
+svuint32_t cond_u32_splat(svuint32_t a) {
+  return a < 1u ? a : 1u;
+}
+
+// CHECK-LABEL: @_Z14cond_i64_splatu11__SVInt64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i64> [[A:%.*]], zeroinitializer
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x i64> [[A]], <vscale x 2 x i64> zeroinitializer
+// CHECK-NEXT:    ret <vscale x 2 x i64> [[VECTOR_SELECT]]
+//
+svint64_t cond_i64_splat(svint64_t a) {
+  return a < 0l ? a : 0l;
+}
+
+// CHECK-LABEL: @_Z14cond_u64_splatu12__SVUint64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = icmp ult <vscale x 2 x i64> [[A:%.*]], shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i32 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x i64> [[A]], <vscale x 2 x i64> shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 1, i32 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+// CHECK-NEXT:    ret <vscale x 2 x i64> [[VECTOR_SELECT]]
+//
+svuint64_t cond_u64_splat(svuint64_t a) {
+  return a < 1ul ? a : 1ul;
+}
+
+// CHECK-LABEL: @_Z14cond_f32_splatu13__SVFloat32_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = fcmp olt <vscale x 4 x float> [[A:%.*]], zeroinitializer
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 4 x i1> [[CMP]] to <vscale x 4 x i32>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 4 x i32> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 4 x i1> [[VECTOR_COND]], <vscale x 4 x float> [[A]], <vscale x 4 x float> zeroinitializer
+// CHECK-NEXT:    ret <vscale x 4 x float> [[VECTOR_SELECT]]
+//
+svfloat32_t cond_f32_splat(svfloat32_t a) {
+  return a < 0.f ? a : 0.f;
+}
+
+// CHECK-LABEL: @_Z14cond_f64_splatu13__SVFloat64_t(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[CMP:%.*]] = fcmp olt <vscale x 2 x double> [[A:%.*]], zeroinitializer
+// CHECK-NEXT:    [[CONV:%.*]] = zext <vscale x 2 x i1> [[CMP]] to <vscale x 2 x i64>
+// CHECK-NEXT:    [[VECTOR_COND:%.*]] = icmp ne <vscale x 2 x i64> [[CONV]], zeroinitializer
+// CHECK-NEXT:    [[VECTOR_SELECT:%.*]] = select <vscale x 2 x i1> [[VECTOR_COND]], <vscale x 2 x double> [[A]], <vscale x 2 x double> zeroinitializer
+// CHECK-NEXT:    ret <vscale x 2 x double> [[VECTOR_SELECT]]
+//
+svfloat64_t cond_f64_splat(svfloat64_t a) {
+  return a < 0. ? a : 0.;
+}
+
diff --git a/clang/test/SemaCXX/aarch64-sve-vector-conditional-op.cpp b/clang/test/SemaCXX/aarch64-sve-vector-conditional-op.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/aarch64-sve-vector-conditional-op.cpp
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -verify -triple aarch64-none-linux-gnu -target-feature +sve -fallow-half-arguments-and-returns -fsyntax-only %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sve.h>
+
+void cond(svint8_t i8, svint16_t i16, svint32_t i32, svint64_t i64,
+          svuint8_t u8, svuint16_t u16, svuint32_t u32, svuint64_t u64,
+          svfloat16_t f16, svfloat32_t f32, svfloat64_t f64,
+          svbool_t b) {
+  (void)(i8 ? i16 : i16); // expected-error{{vector condition type 'svint8_t' (aka '__SVInt8_t') and result type 'svint16_t' (aka '__SVInt16_t') do not have the same number of elements}}
+  (void)(i8 ? i32 : i32); // expected-error{{vector condition type 'svint8_t' (aka '__SVInt8_t') and result type 'svint32_t' (aka '__SVInt32_t') do not have the same number of elements}}
+  (void)(i8 ? i64 : i64); // expected-error{{vector condition type 'svint8_t' (aka '__SVInt8_t') and result type 'svint64_t' (aka '__SVInt64_t') do not have the same number of elements}}
+
+  (void)(i16 ? i16 : i8); // expected-error{{vector operands to the vector conditional must be the same type}}
+  (void)(i16 ? i16 : i32); // expected-error{{vector operands to the vector conditional must be the same type}}
+  (void)(i16 ? i16 : i64); // expected-error{{vector operands to the vector conditional must be the same type}}
+
+  (void)(i16 ? i8 : i16); // expected-error{{vector operands to the vector conditional must be the same type}}
+  (void)(i16 ? i32 : i16); // expected-error{{vector operands to the vector conditional must be the same type}}
+  (void)(i16 ? i64 : i16); // expected-error{{vector operands to the vector conditional must be the same type}}
+
+  (void)(f32 < 0.f ? 1. : 0.); // expected-error {{vector condition type '__SVInt32_t' and result type 'double' do not have elements of the same size}}
+  (void)(f64 < 0. ? 1.f : 0.f); // expected-error {{vector condition type '__SVInt64_t' and result type 'float' do not have elements of the same size}}
+
+  // Mixed scalar operands: the diagnostic must name the converted result type.
+  (void)(i32 ? 1 : 2.0); // expected-error {{vector condition type 'svint32_t' (aka '__SVInt32_t') and result type 'double' do not have elements of the same size}}
+}