diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3627,6 +3627,12 @@
   "the vecreturn attribute can only be used on a POD (plain old data) class or structure (i.e. no virtual functions)">;
 def err_sme_attr_mismatch : Error<
   "function declared %0 was previously declared %1, which has different SME function attributes">;
+def err_sme_call_in_non_sme_target : Error<
+  "call to a streaming function requires 'sme'">;
+def err_sme_definition_using_sm_in_non_sme_target : Error<
+  "function executed in streaming-SVE mode requires 'sme'">;
+def err_sme_definition_using_za_in_non_sme_target : Error<
+  "function using ZA state requires 'sme'">;
 def err_cconv_change : Error<
   "function declared '%0' here was previously declared "
   "%select{'%2'|without calling convention}1">;
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -6669,8 +6669,8 @@
 }
 
 /// Handles the checks for format strings, non-POD arguments to vararg
-/// functions, NULL arguments passed to non-NULL parameters, and diagnose_if
-/// attributes.
+/// functions, NULL arguments passed to non-NULL parameters, diagnose_if
+/// attributes and AArch64 SME attributes.
 void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
                      const Expr *ThisArg, ArrayRef<const Expr *> Args,
                      bool IsMemberFunction, SourceLocation Loc,
@@ -6751,6 +6751,20 @@
                           ArgTy, ParamTy);
       }
     }
+
+    // If the callee has an AArch64 SME attribute to indicate that it is an
+    // __arm_streaming function, then the caller requires SME to be available.
+    FunctionProtoType::ExtProtoInfo ExtInfo = Proto->getExtProtoInfo();
+    if (ExtInfo.AArch64SMEAttributes & FunctionType::SME_PStateSMEnabledMask) {
+      if (auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
+        llvm::StringMap<bool> CallerFeatureMap;
+        Context.getFunctionFeatureMap(CallerFeatureMap, CallerFD);
+        if (!CallerFeatureMap.contains("sme"))
+          Diag(Loc, diag::err_sme_call_in_non_sme_target);
+      } else if (!Context.getTargetInfo().hasFeature("sme")) {
+        Diag(Loc, diag::err_sme_call_in_non_sme_target);
+      }
+    }
   }
 
   if (FDecl && FDecl->hasAttr<AllocAlignAttr>()) {
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -12140,6 +12140,33 @@
     if (!Redeclaration && LangOpts.CUDA)
       checkCUDATargetOverload(NewFD, Previous);
   }
+
+  // Check if the function definition uses any AArch64 SME features without
+  // having the '+sme' feature enabled.
+  if (DeclIsDefn) {
+    bool UsesSM = NewFD->hasAttr<ArmLocallyStreamingAttr>();
+    bool UsesZA = NewFD->hasAttr<ArmNewZAAttr>();
+    if (const auto *FPT = NewFD->getType()->getAs<FunctionProtoType>()) {
+      FunctionProtoType::ExtProtoInfo EPI = FPT->getExtProtoInfo();
+      UsesSM |=
+          EPI.AArch64SMEAttributes & FunctionType::SME_PStateSMEnabledMask;
+      UsesZA |= EPI.AArch64SMEAttributes & FunctionType::SME_PStateZASharedMask;
+    }
+
+    if (UsesSM || UsesZA) {
+      llvm::StringMap<bool> FeatureMap;
+      Context.getFunctionFeatureMap(FeatureMap, NewFD);
+      if (!FeatureMap.contains("sme")) {
+        if (UsesSM)
+          Diag(NewFD->getLocation(),
+               diag::err_sme_definition_using_sm_in_non_sme_target);
+        else
+          Diag(NewFD->getLocation(),
+               diag::err_sme_definition_using_za_in_non_sme_target);
+      }
+    }
+  }
+
   return Redeclaration;
 }
 
diff --git a/clang/test/Sema/aarch64-sme-func-attrs-without-target-feature.cpp b/clang/test/Sema/aarch64-sme-func-attrs-without-target-feature.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/Sema/aarch64-sme-func-attrs-without-target-feature.cpp
@@ -0,0 +1,48 @@
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -fsyntax-only -verify %s
+
+// This test is testing the diagnostics that Clang emits when compiling without '+sme'.
+
+void streaming_compatible_def() __arm_streaming_compatible {} // OK
+void streaming_def() __arm_streaming { } // expected-error {{function executed in streaming-SVE mode requires 'sme'}}
+void shared_za_def() __arm_shared_za { } // expected-error {{function using ZA state requires 'sme'}}
+__arm_new_za void new_za_def() { } // expected-error {{function using ZA state requires 'sme'}}
+__arm_locally_streaming void locally_streaming_def() { } // expected-error {{function executed in streaming-SVE mode requires 'sme'}}
+void streaming_shared_za_def() __arm_streaming __arm_shared_za { } // expected-error {{function executed in streaming-SVE mode requires 'sme'}}
+
+// It should work fine when we explicitly add the target("sme") attribute.
+__attribute__((target("sme"))) void streaming_compatible_def_sme_attr() __arm_streaming_compatible {} // OK
+__attribute__((target("sme"))) void streaming_def_sme_attr() __arm_streaming { } // OK
+__attribute__((target("sme"))) void shared_za_def_sme_attr() __arm_shared_za { } // OK
+__arm_new_za __attribute__((target("sme"))) void new_za_def_sme_attr() {} // OK
+__arm_locally_streaming __attribute__((target("sme"))) void locally_streaming_def_sme_attr() {} // OK
+
+// Test that it also works with the target("sme2") attribute.
+__attribute__((target("sme2"))) void streaming_def_sme2_attr() __arm_streaming { } // OK
+
+// No code is generated for declarations, so it should be fine to declare using the attribute.
+void streaming_compatible_decl() __arm_streaming_compatible; // OK
+void streaming_decl() __arm_streaming; // OK
+void shared_za_decl() __arm_shared_za; // OK
+
+void non_streaming_decl();
+void non_streaming_def(void (*streaming_fn_ptr)(void) __arm_streaming,
+                       void (*streaming_compatible_fn_ptr)(void) __arm_streaming_compatible) {
+  streaming_compatible_decl(); // OK
+  streaming_compatible_fn_ptr(); // OK
+  streaming_decl(); // expected-error {{call to a streaming function requires 'sme'}}
+  streaming_fn_ptr(); // expected-error {{call to a streaming function requires 'sme'}}
+}
+
+void streaming_compatible_def2(void (*streaming_fn_ptr)(void) __arm_streaming,
+                               void (*streaming_compatible_fn_ptr)(void) __arm_streaming_compatible)
+                               __arm_streaming_compatible {
+  non_streaming_decl(); // OK
+  streaming_compatible_decl(); // OK
+  streaming_compatible_fn_ptr(); // OK
+  streaming_decl(); // expected-error {{call to a streaming function requires 'sme'}}
+  streaming_fn_ptr(); // expected-error {{call to a streaming function requires 'sme'}}
+}
+
+// Also test when call-site is not a function.
+int streaming_decl_ret_int() __arm_streaming;
+int x = streaming_decl_ret_int(); // expected-error {{call to a streaming function requires 'sme'}}
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -134,52 +134,6 @@
 // Mode selection and state access instructions
 //===----------------------------------------------------------------------===//
 
-// SME defines three pstate fields to set or clear PSTATE.SM, PSTATE.ZA, or
-// both fields:
-//
-//   MSR SVCRSM, #<imm1>
-//   MSR SVCRZA, #<imm1>
-//   MSR SVCRSMZA, #<imm1>
-//
-// It's tricky to using the existing pstate operand defined in
-// AArch64SystemOperands.td since it only encodes 5 bits including op1;op2,
-// when these fields are also encoded in CRm[3:1].
-def MSRpstatesvcrImm1
-  : PstateWriteSimple<(ins svcr_op:$pstatefield, timm0_1:$imm), "msr",
-                      "\t$pstatefield, $imm">,
-    Sched<[WriteSys]> {
-  bits<3> pstatefield;
-  bit imm;
-  let Inst{18-16} = 0b011; // op1
-  let Inst{11-9} = pstatefield;
-  let Inst{8} = imm;
-  let Inst{7-5} = 0b011; // op2
-}
-
-def : InstAlias<"smstart",    (MSRpstatesvcrImm1 0b011, 0b1)>;
-def : InstAlias<"smstart sm", (MSRpstatesvcrImm1 0b001, 0b1)>;
-def : InstAlias<"smstart za", (MSRpstatesvcrImm1 0b010, 0b1)>;
-
-def : InstAlias<"smstop",     (MSRpstatesvcrImm1 0b011, 0b0)>;
-def : InstAlias<"smstop sm",  (MSRpstatesvcrImm1 0b001, 0b0)>;
-def : InstAlias<"smstop za",  (MSRpstatesvcrImm1 0b010, 0b0)>;
-
-
-// Pseudo to match to smstart/smstop. This expands:
-//
-// pseudonode (pstate_za|pstate_sm), before_call, expected_value
-//
-// Into:
-//
-// if (before_call != expected_value)
-//   node (pstate_za|pstate_sm)
-//
-// where node can be either 'smstart' or 'smstop'.
-def MSRpstatePseudo :
-  Pseudo<(outs),
-           (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
-    Sched<[WriteSys]>;
-
 // Pseudo to conditionally restore ZA state. This expands:
 //
 // pseudonode tpidr2_el0, tpidr2obj, restore_routine
@@ -226,12 +180,6 @@
 def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)),    // after call
           (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
 
-// The generic case which gets expanded to a pseudo node.
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-
 // Read and write TPIDR2_EL0
 def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;
@@ -243,6 +191,31 @@
           (OBSCURE_COPY GPR64:$idx)>;
 } // End let Predicates = [HasSME]
 
+// Pseudo to match to smstart/smstop. This expands:
+//
+// pseudonode (pstate_za|pstate_sm), before_call, expected_value
+//
+// Into:
+//
+// if (before_call != expected_value)
+//   node (pstate_za|pstate_sm)
+//
+// where node can be either 'smstart' or 'smstop'.
+//
+// This pseudo and corresponding patterns don't need to be predicated by SME,
+// because when they're emitted for streaming-compatible functions and run
+// in a non-SME context the generated code-paths will never execute any
+// SME instructions.
+def MSRpstatePseudo :
+  Pseudo<(outs),
+           (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
+    Sched<[WriteSys]>;
+
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+
 //===----------------------------------------------------------------------===//
 // SME2 Instructions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -442,8 +442,6 @@
 
   assert((!StreamingSVEMode || I->hasSME()) &&
          "Expected SME to be available");
-  assert((!StreamingCompatibleSVEMode || I->hasSVEorSME()) &&
-         "Expected SVE or SME to be available");
 
   return I.get();
 }
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -190,6 +190,42 @@
     : Pat<(intrinsic imm_ty:$tile, (pg_ty PPR3bAny:$Pn), (pg_ty PPR3bAny:$Pm), vt:$Zn, vt:$Zm),
           (!cast<Instruction>(name # _PSEUDO) $tile, $Pn, $Pm, $Zn, $Zm)>;
 
+
+//===----------------------------------------------------------------------===//
+// SME smstart/smstop
+//===----------------------------------------------------------------------===//
+
+// SME defines three pstate fields to set or clear PSTATE.SM, PSTATE.ZA, or
+// both fields:
+//
+//   MSR SVCRSM, #<imm1>
+//   MSR SVCRZA, #<imm1>
+//   MSR SVCRSMZA, #<imm1>
+//
+// It's tricky to use the existing pstate operand defined in
+// AArch64SystemOperands.td since it only encodes 5 bits including op1;op2,
+// when these fields are also encoded in CRm[3:1].
+def MSRpstatesvcrImm1
+  : PstateWriteSimple<(ins svcr_op:$pstatefield, timm0_1:$imm), "msr",
+                      "\t$pstatefield, $imm">,
+    Sched<[WriteSys]> {
+  bits<3> pstatefield;
+  bit imm;
+  let Inst{18-16} = 0b011; // op1
+  let Inst{11-9} = pstatefield;
+  let Inst{8} = imm;
+  let Inst{7-5} = 0b011; // op2
+}
+
+def : InstAlias<"smstart",    (MSRpstatesvcrImm1 0b011, 0b1)>;
+def : InstAlias<"smstart sm", (MSRpstatesvcrImm1 0b001, 0b1)>;
+def : InstAlias<"smstart za", (MSRpstatesvcrImm1 0b010, 0b1)>;
+
+def : InstAlias<"smstop",     (MSRpstatesvcrImm1 0b011, 0b0)>;
+def : InstAlias<"smstop sm",  (MSRpstatesvcrImm1 0b001, 0b0)>;
+def : InstAlias<"smstop za",  (MSRpstatesvcrImm1 0b010, 0b0)>;
+
+
 //===----------------------------------------------------------------------===//
 // SME Outer Products
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll b/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-call-streaming-compatible-to-normal-fn-wihout-sme-attr.ll
@@ -0,0 +1,41 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s | FileCheck %s
+
+; Verify that the following code can be compiled without +sme, because if the
+; call is not entered in streaming-SVE mode at runtime, the codepath leading
+; to the smstop/smstart pair will not be executed either.
+
+target triple = "aarch64"
+
+define void @streaming_compatible() #0 {
+; CHECK-LABEL: streaming_compatible:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbz x19, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    bl non_streaming
+; CHECK-NEXT:    tbz x19, #0, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @non_streaming()
+  ret void
+}
+
+declare void @non_streaming()
+
+attributes #0 = { nounwind "aarch64_pstate_sm_compatible" }