diff --git a/clang/lib/Basic/Targets/AArch64.cpp b/clang/lib/Basic/Targets/AArch64.cpp
--- a/clang/lib/Basic/Targets/AArch64.cpp
+++ b/clang/lib/Basic/Targets/AArch64.cpp
@@ -1288,8 +1288,9 @@
     Info.setAllowsRegister();
     return true;
   case 'U':
-    if (Name[1] == 'p' && (Name[2] == 'l' || Name[2] == 'a')) {
-      // SVE predicate registers ("Upa"=P0-15, "Upl"=P0-P7)
+    if (Name[1] == 'p' &&
+        (Name[2] == 'l' || Name[2] == 'a' || Name[2] == 'h')) {
+      // SVE predicate registers ("Upa"=P0-P15, "Upl"=P0-P7, "Uph"=P8-P15)
       Info.setAllowsRegister();
       Name += 2;
       return true;
diff --git a/clang/test/CodeGen/aarch64-sve-inline-asm-datatypes.c b/clang/test/CodeGen/aarch64-sve-inline-asm-datatypes.c
--- a/clang/test/CodeGen/aarch64-sve-inline-asm-datatypes.c
+++ b/clang/test/CodeGen/aarch64-sve-inline-asm-datatypes.c
@@ -168,6 +168,29 @@
 SVBOOL_TEST_UPL(__SVInt64_t, d) ;
 // CHECK: call <vscale x 16 x i1> asm sideeffect "fadd $0.d, $1.d, $2.d, $3.d\0A", "=w,@3Upl,w,w"(<vscale x 16 x i1> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
 
+#define SVBOOL_TEST_UPH(DT, KIND)\
+__SVBool_t func_bool_uph_##KIND(__SVBool_t in1, DT in2, DT in3)\
+{\
+  __SVBool_t out;\
+  asm volatile (\
+    "fadd %[out]." #KIND ", %[in1]." #KIND ", %[in2]." #KIND ", %[in3]." #KIND "\n"\
+    : [out] "=w" (out)\
+    :  [in1] "Uph" (in1),\
+      [in2] "w" (in2),\
+      [in3] "w" (in3)\
+    :);\
+  return out;\
+}
+
+SVBOOL_TEST_UPH(__SVInt8_t, b) ;
+// CHECK: call <vscale x 16 x i1> asm sideeffect "fadd $0.b, $1.b, $2.b, $3.b\0A", "=w,@3Uph,w,w"(<vscale x 16 x i1> %in1, <vscale x 16 x i8> %in2, <vscale x 16 x i8> %in3)
+SVBOOL_TEST_UPH(__SVInt16_t, h) ;
+// CHECK: call <vscale x 16 x i1> asm sideeffect "fadd $0.h, $1.h, $2.h, $3.h\0A", "=w,@3Uph,w,w"(<vscale x 16 x i1> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
+SVBOOL_TEST_UPH(__SVInt32_t, s) ;
+// CHECK: call <vscale x 16 x i1> asm sideeffect "fadd $0.s, $1.s, $2.s, $3.s\0A", "=w,@3Uph,w,w"(<vscale x 16 x i1> %in1, <vscale x 4 x i32> %in2, <vscale x 4 x i32> %in3)
+SVBOOL_TEST_UPH(__SVInt64_t, d) ;
+// CHECK: call <vscale x 16 x i1> asm sideeffect "fadd $0.d, $1.d, $2.d, $3.d\0A", "=w,@3Uph,w,w"(<vscale x 16 x i1> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
+
 #define SVFLOAT_TEST(DT,KIND)\
 DT func_float_##DT##KIND(DT inout1, DT in2)\
 {\
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4997,7 +4997,8 @@
 - ``w``: A 32, 64, or 128-bit floating-point, SIMD or SVE vector register.
 - ``x``: Like w, but restricted to registers 0 to 15 inclusive.
 - ``y``: Like w, but restricted to SVE vector registers Z0 to Z7 inclusive.
-- ``Upl``: One of the low eight SVE predicate registers (P0 to P7)
+- ``Uph``: One of the upper eight SVE predicate registers (P8 to P15)
+- ``Upl``: One of the lower eight SVE predicate registers (P0 to P7)
 - ``Upa``: Any of the SVE predicate registers (P0 to P15)
 
 AMDGPU:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9987,19 +9987,33 @@
   return "r";
 }
 
-enum PredicateConstraint {
-  Upl,
-  Upa,
-  Invalid
-};
+enum PredicateConstraint { Uph, Upl, Upa, Invalid };
 
 static PredicateConstraint parsePredicateConstraint(StringRef Constraint) {
-  PredicateConstraint P = PredicateConstraint::Invalid;
-  if (Constraint == "Upa")
-    P = PredicateConstraint::Upa;
-  if (Constraint == "Upl")
-    P = PredicateConstraint::Upl;
-  return P;
+  return StringSwitch<PredicateConstraint>(Constraint)
+      .Case("Uph", PredicateConstraint::Uph)
+      .Case("Upl", PredicateConstraint::Upl)
+      .Case("Upa", PredicateConstraint::Upa)
+      .Default(PredicateConstraint::Invalid);
+}
+
+/// Return the register class selected by an SVE predicate constraint, or
+/// nullptr if \p Constraint is Invalid or \p VT is not a scalable vector of i1.
+static const TargetRegisterClass *
+getPredicateRegisterClass(PredicateConstraint Constraint, EVT VT) {
+  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
+    return nullptr;
+
+  switch (Constraint) {
+  default:
+    return nullptr;
+  case PredicateConstraint::Uph:
+    return &AArch64::PPR_p8to15RegClass;
+  case PredicateConstraint::Upl:
+    return &AArch64::PPR_3bRegClass;
+  case PredicateConstraint::Upa:
+    return &AArch64::PPRRegClass;
+  }
 }
 
 // The set of cc code supported is from
@@ -10191,13 +10205,8 @@
     }
   } else {
     PredicateConstraint PC = parsePredicateConstraint(Constraint);
-    if (PC != PredicateConstraint::Invalid) {
-      if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
-        return std::make_pair(0U, nullptr);
-      bool restricted = (PC == PredicateConstraint::Upl);
-      return restricted ? std::make_pair(0U, &AArch64::PPR_3bRegClass)
-                        : std::make_pair(0U, &AArch64::PPRRegClass);
-    }
+    if (const TargetRegisterClass *RegClass = getPredicateRegisterClass(PC, VT))
+      return std::make_pair(0U, RegClass);
   }
   if (StringRef("{cc}").equals_insensitive(Constraint) ||
       parseConstraintCode(Constraint) != AArch64CC::Invalid)
diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
@@ -68,3 +68,14 @@
   %1 = tail call <vscale x 4 x i32> asm "incp $0.s, $1", "=w,@3Upa,0"(<vscale x 16 x i1> %Pg, <vscale x 4 x i32> %Zn)
   ret <vscale x 4 x i32> %1
 }
+
+; Function Attrs: nounwind readnone
+; CHECK: [[ARG1:%[0-9]+]]:zpr = COPY $z1
+; CHECK: [[ARG2:%[0-9]+]]:zpr = COPY $z0
+; CHECK: [[ARG3:%[0-9]+]]:ppr = COPY $p0
+; CHECK: [[ARG4:%[0-9]+]]:ppr_p8to15 = COPY [[ARG3]]
+; CHECK: INLINEASM {{.*}} [[ARG4]]
+define <vscale x 8 x half> @test_svfadd_f16_Uph_constraint(<vscale x 16 x i1> %Pg, <vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm) {
+  %1 = tail call <vscale x 8 x half> asm "fadd $0.h, $1/m, $2.h, $3.h", "=w,@3Uph,w,w"(<vscale x 16 x i1> %Pg, <vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm)
+  ret <vscale x 8 x half> %1
+}