diff --git a/clang/include/clang/Basic/BuiltinsRISCV.def b/clang/include/clang/Basic/BuiltinsRISCV.def
--- a/clang/include/clang/Basic/BuiltinsRISCV.def
+++ b/clang/include/clang/Basic/BuiltinsRISCV.def
@@ -79,5 +79,9 @@
 TARGET_BUILTIN(__builtin_riscv_sm3p0, "LiLi", "nc", "zksh")
 TARGET_BUILTIN(__builtin_riscv_sm3p1, "LiLi", "nc", "zksh")
 
+// Zihintntl extension
+TARGET_BUILTIN(__builtin_riscv_ntl_load, "v.", "t", "experimental-zihintntl")
+TARGET_BUILTIN(__builtin_riscv_ntl_store, "v.", "t", "experimental-zihintntl")
+
 #undef BUILTIN
 #undef TARGET_BUILTIN
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19821,6 +19821,47 @@
     IntrinsicTypes = {ResultType};
     break;
 
+  // Zihintntl
+  case RISCV::BI__builtin_riscv_ntl_load: {
+    llvm::Type *ResTy = ConvertType(E->getType());
+    ConstantInt *Mode = llvm::dyn_cast<ConstantInt>(Ops[1]);
+
+    assert(
+        Mode &&
+        "__builtin_riscv_ntl_load's domain argument value must be constant.");
+
+    llvm::MDNode *Node = llvm::MDNode::get(
+        getLLVMContext(),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Mode->getZExtValue())));
+
+    int Width = ResTy->getPrimitiveSizeInBits();
+    LoadInst *Load = Builder.CreateLoad(
+        Address(Ops[0], ResTy, CharUnits::fromQuantity(Width / 8)));
+
+    Load->setMetadata(CGM.getModule().getMDKindID("nontemporal"), Node);
+
+    return Load;
+  }
+  case RISCV::BI__builtin_riscv_ntl_store: {
+    ConstantInt *Mode = llvm::dyn_cast<ConstantInt>(Ops[2]);
+
+    assert(
+        Mode &&
+        "__builtin_riscv_ntl_store's domain argument value must be constant.");
+
+    llvm::MDNode *Node = llvm::MDNode::get(
+        getLLVMContext(),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Mode->getZExtValue())));
+
+    Value *BC = Builder.CreateBitCast(
+        Ops[0], llvm::PointerType::getUnqual(Ops[1]->getType()), "cast");
+
+    StoreInst *Store = Builder.CreateDefaultAlignedStore(Ops[1], BC);
+    Store->setMetadata(CGM.getModule().getMDKindID("nontemporal"), Node);
+
+    return Store;
+  }
+
   // Vector builtins are handled from here.
#include "clang/Basic/riscv_vector_builtin_cg.inc" } diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -98,6 +98,10 @@ htmxlintrin.h ) +set(riscv_files + riscv_ntlh.h + ) + set(systemz_files s390intrin.h vecintrin.h @@ -243,6 +247,7 @@ ${opencl_files} ${ppc_files} ${ppc_htm_files} + ${riscv_files} ${systemz_files} ${ve_files} ${x86_files} @@ -424,7 +429,7 @@ add_header_target("mips-resource-headers" "${mips_msa_files}") add_header_target("ppc-resource-headers" "${ppc_files};${ppc_wrapper_files}") add_header_target("ppc-htm-resource-headers" "${ppc_htm_files}") -add_header_target("riscv-resource-headers" "${riscv_generated_files}") +add_header_target("riscv-resource-headers" "${riscv_files};${riscv_generated_files}") add_header_target("systemz-resource-headers" "${systemz_files}") add_header_target("ve-resource-headers" "${ve_files}") add_header_target("webassembly-resource-headers" "${webassembly_files}") @@ -547,6 +552,12 @@ EXCLUDE_FROM_ALL COMPONENT riscv-resource-headers) +install( + FILES ${riscv_files} + DESTINATION ${header_install_dir} + EXCLUDE_FROM_ALL + COMPONENT riscv-resource-headers) + install( FILES ${systemz_files} DESTINATION ${header_install_dir} diff --git a/clang/lib/Headers/riscv_ntlh.h b/clang/lib/Headers/riscv_ntlh.h new file mode 100644 --- /dev/null +++ b/clang/lib/Headers/riscv_ntlh.h @@ -0,0 +1,10 @@ +enum { + __RISCV_NTLH_INNERMOST_PRIVATE = 2, + __RISCV_NTLH_ALL_PRIVATE, + __RISCV_NTLH_INNERMOST_SHARED, + __RISCV_NTLH_ALL +}; + +#define __rv_ntl_load(PTR, DOMAIN) __builtin_riscv_ntl_load(PTR, DOMAIN) +#define __rv_ntl_store(PTR, VAL, DOMAIN) \ + __builtin_riscv_ntl_store(PTR, VAL, DOMAIN) diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -4652,6 +4652,65 @@ // Check if rnum is in [0, 10] case RISCV::BI__builtin_riscv_aes64ks1i_64: return SemaBuiltinConstantArgRange(TheCall, 1, 0, 10); + case RISCV::BI__builtin_riscv_ntl_load: + case RISCV::BI__builtin_riscv_ntl_store: + DeclRefExpr *DRE = + cast(TheCall->getCallee()->IgnoreParenCasts()); + assert((BuiltinID == RISCV::BI__builtin_riscv_ntl_store || + BuiltinID == RISCV::BI__builtin_riscv_ntl_load) && + "Unexpected RISC-V nontemporal load/store builtin!"); + bool IsStore = BuiltinID == RISCV::BI__builtin_riscv_ntl_store; + unsigned NumArgs = IsStore ? 3 : 2; + + if (checkArgCount(*this, TheCall, NumArgs)) + return true; + + // Domain value should be compile-time constant. 
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -4652,6 +4652,65 @@
   // Check if rnum is in [0, 10]
   case RISCV::BI__builtin_riscv_aes64ks1i_64:
     return SemaBuiltinConstantArgRange(TheCall, 1, 0, 10);
+  case RISCV::BI__builtin_riscv_ntl_load:
+  case RISCV::BI__builtin_riscv_ntl_store: {
+    DeclRefExpr *DRE =
+        cast<DeclRefExpr>(TheCall->getCallee()->IgnoreParenCasts());
+    assert((BuiltinID == RISCV::BI__builtin_riscv_ntl_store ||
+            BuiltinID == RISCV::BI__builtin_riscv_ntl_load) &&
+           "Unexpected RISC-V nontemporal load/store builtin!");
+    bool IsStore = BuiltinID == RISCV::BI__builtin_riscv_ntl_store;
+    unsigned NumArgs = IsStore ? 3 : 2;
+
+    if (checkArgCount(*this, TheCall, NumArgs))
+      return true;
+
+    // The domain value must be a compile-time constant with
+    // 2 <= domain <= 5.
+    if (SemaBuiltinConstantArgRange(TheCall, NumArgs - 1, 2, 5))
+      return true;
+
+    Expr *PointerArg = TheCall->getArg(0);
+    ExprResult PointerArgResult =
+        DefaultFunctionArrayLvalueConversion(PointerArg);
+
+    if (PointerArgResult.isInvalid())
+      return true;
+    PointerArg = PointerArgResult.get();
+
+    const PointerType *PtrType = PointerArg->getType()->getAs<PointerType>();
+    if (!PtrType) {
+      Diag(DRE->getBeginLoc(), diag::err_nontemporal_builtin_must_be_pointer)
+          << PointerArg->getType() << PointerArg->getSourceRange();
+      return true;
+    }
+
+    QualType ValType = PtrType->getPointeeType();
+    ValType = ValType.getUnqualifiedType();
+    if (!ValType->isIntegerType() && !ValType->isAnyPointerType() &&
+        !ValType->isBlockPointerType() && !ValType->isFloatingType() &&
+        !ValType->isVectorType()) {
+      Diag(DRE->getBeginLoc(),
+           diag::err_nontemporal_builtin_must_be_pointer_intfltptr_or_vector)
+          << PointerArg->getType() << PointerArg->getSourceRange();
+      return true;
+    }
+
+    if (!IsStore) {
+      TheCall->setType(ValType);
+      return false;
+    }
+
+    ExprResult ValArg = TheCall->getArg(1);
+    InitializedEntity Entity = InitializedEntity::InitializeParameter(
+        Context, ValType, /*consume*/ false);
+    ValArg = PerformCopyInitialization(Entity, SourceLocation(), ValArg);
+    if (ValArg.isInvalid())
+      return true;
+
+    TheCall->setArg(1, ValArg.get());
+    TheCall->setType(Context.VoidTy);
+    return false;
+  }
   }
 
   return false;
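For reference, a sketch of inputs these checks reject (hypothetical snippet; the diagnostics reuse the existing err_nontemporal_builtin_* wording shared with the __builtin_nontemporal_* builtins):

```c
#include <riscv_ntlh.h>

struct S { int x; };

void rejected(struct S *p, int d) {
  __rv_ntl_load(&p->x, 6); /* rejected: domain must be a constant in [2, 5] */
  __rv_ntl_load(&p->x, d); /* rejected: domain must be a compile-time constant */
  __rv_ntl_load(p, 2);     /* rejected: pointee must be an integer, floating,
                              pointer, or vector type */
}
```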
diff --git a/clang/test/CodeGen/RISCV/ntlh-intrinsics/riscv32-zihintntl.c b/clang/test/CodeGen/RISCV/ntlh-intrinsics/riscv32-zihintntl.c
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGen/RISCV/ntlh-intrinsics/riscv32-zihintntl.c
@@ -0,0 +1,122 @@
+// RUN: %clang_cc1 -triple riscv32 -target-feature +experimental-zihintntl -emit-llvm %s -o - \
+// RUN:   | FileCheck %s
+
+#include <riscv_ntlh.h>
+
+signed char sc;
+unsigned char uc;
+signed short ss;
+unsigned short us;
+signed int si;
+unsigned int ui;
+signed long long sll;
+unsigned long long ull;
+_Float16 h1, h2;
+float f1, f2;
+double d1, d2;
+
+// clang-format off
+void ntl_all_sizes() {                                       // CHECK-LABEL: ntl_all_sizes
+  uc = __rv_ntl_load(&sc, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i8{{.*}}align 1, !nontemporal !4
+  sc = __rv_ntl_load(&uc, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i8{{.*}}align 1, !nontemporal !4
+  us = __rv_ntl_load(&ss, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i16{{.*}}align 2, !nontemporal !4
+  ss = __rv_ntl_load(&us, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i16{{.*}}align 2, !nontemporal !4
+  ui = __rv_ntl_load(&si, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i32{{.*}}align 4, !nontemporal !4
+  si = __rv_ntl_load(&ui, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load i32{{.*}}align 4, !nontemporal !4
+  ull = __rv_ntl_load(&sll, __RISCV_NTLH_INNERMOST_PRIVATE); // CHECK: load i64{{.*}}align 8, !nontemporal !4
+  sll = __rv_ntl_load(&ull, __RISCV_NTLH_INNERMOST_PRIVATE); // CHECK: load i64{{.*}}align 8, !nontemporal !4
+  h1 = __rv_ntl_load(&h2, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load half{{.*}}align 2, !nontemporal !4
+  f1 = __rv_ntl_load(&f2, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load float{{.*}}align 4, !nontemporal !4
+  d1 = __rv_ntl_load(&d2, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: load double{{.*}}align 8, !nontemporal !4
+
+  uc = __rv_ntl_load(&sc, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i8{{.*}}align 1, !nontemporal !5
+  sc = __rv_ntl_load(&uc, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i8{{.*}}align 1, !nontemporal !5
+  us = __rv_ntl_load(&ss, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i16{{.*}}align 2, !nontemporal !5
+  ss = __rv_ntl_load(&us, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i16{{.*}}align 2, !nontemporal !5
+  ui = __rv_ntl_load(&si, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i32{{.*}}align 4, !nontemporal !5
+  si = __rv_ntl_load(&ui, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load i32{{.*}}align 4, !nontemporal !5
+  ull = __rv_ntl_load(&sll, __RISCV_NTLH_ALL_PRIVATE);       // CHECK: load i64{{.*}}align 8, !nontemporal !5
+  sll = __rv_ntl_load(&ull, __RISCV_NTLH_ALL_PRIVATE);       // CHECK: load i64{{.*}}align 8, !nontemporal !5
+  h1 = __rv_ntl_load(&h2, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load half{{.*}}align 2, !nontemporal !5
+  f1 = __rv_ntl_load(&f2, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load float{{.*}}align 4, !nontemporal !5
+  d1 = __rv_ntl_load(&d2, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: load double{{.*}}align 8, !nontemporal !5
+
+  uc = __rv_ntl_load(&sc, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i8{{.*}}align 1, !nontemporal !6
+  sc = __rv_ntl_load(&uc, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i8{{.*}}align 1, !nontemporal !6
+  us = __rv_ntl_load(&ss, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i16{{.*}}align 2, !nontemporal !6
+  ss = __rv_ntl_load(&us, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i16{{.*}}align 2, !nontemporal !6
+  ui = __rv_ntl_load(&si, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i32{{.*}}align 4, !nontemporal !6
+  si = __rv_ntl_load(&ui, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load i32{{.*}}align 4, !nontemporal !6
+  ull = __rv_ntl_load(&sll, __RISCV_NTLH_INNERMOST_SHARED);  // CHECK: load i64{{.*}}align 8, !nontemporal !6
+  sll = __rv_ntl_load(&ull, __RISCV_NTLH_INNERMOST_SHARED);  // CHECK: load i64{{.*}}align 8, !nontemporal !6
+  h1 = __rv_ntl_load(&h2, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load half{{.*}}align 2, !nontemporal !6
+  f1 = __rv_ntl_load(&f2, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load float{{.*}}align 4, !nontemporal !6
+  d1 = __rv_ntl_load(&d2, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: load double{{.*}}align 8, !nontemporal !6
+
+  uc = __rv_ntl_load(&sc, __RISCV_NTLH_ALL);                 // CHECK: load i8{{.*}}align 1, !nontemporal !7
+  sc = __rv_ntl_load(&uc, __RISCV_NTLH_ALL);                 // CHECK: load i8{{.*}}align 1, !nontemporal !7
+  us = __rv_ntl_load(&ss, __RISCV_NTLH_ALL);                 // CHECK: load i16{{.*}}align 2, !nontemporal !7
+  ss = __rv_ntl_load(&us, __RISCV_NTLH_ALL);                 // CHECK: load i16{{.*}}align 2, !nontemporal !7
+  ui = __rv_ntl_load(&si, __RISCV_NTLH_ALL);                 // CHECK: load i32{{.*}}align 4, !nontemporal !7
+  si = __rv_ntl_load(&ui, __RISCV_NTLH_ALL);                 // CHECK: load i32{{.*}}align 4, !nontemporal !7
+  ull = __rv_ntl_load(&sll, __RISCV_NTLH_ALL);               // CHECK: load i64{{.*}}align 8, !nontemporal !7
+  sll = __rv_ntl_load(&ull, __RISCV_NTLH_ALL);               // CHECK: load i64{{.*}}align 8, !nontemporal !7
+  h1 = __rv_ntl_load(&h2, __RISCV_NTLH_ALL);                 // CHECK: load half{{.*}}align 2, !nontemporal !7
+  f1 = __rv_ntl_load(&f2, __RISCV_NTLH_ALL);                 // CHECK: load float{{.*}}align 4, !nontemporal !7
+  d1 = __rv_ntl_load(&d2, __RISCV_NTLH_ALL);                 // CHECK: load double{{.*}}align 8, !nontemporal !7
+
+  __rv_ntl_store(&uc, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i8{{.*}}align 1, !nontemporal !4
+  __rv_ntl_store(&sc, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i8{{.*}}align 1, !nontemporal !4
+  __rv_ntl_store(&us, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i16{{.*}}align 2, !nontemporal !4
+  __rv_ntl_store(&ss, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i16{{.*}}align 2, !nontemporal !4
+  __rv_ntl_store(&ui, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i32{{.*}}align 4, !nontemporal !4
+  __rv_ntl_store(&si, 1, __RISCV_NTLH_INNERMOST_PRIVATE);   // CHECK: store i32{{.*}}align 4, !nontemporal !4
+  __rv_ntl_store(&ull, 1, __RISCV_NTLH_INNERMOST_PRIVATE);  // CHECK: store i64{{.*}}align 8, !nontemporal !4
+  __rv_ntl_store(&sll, 1, __RISCV_NTLH_INNERMOST_PRIVATE);  // CHECK: store i64{{.*}}align 8, !nontemporal !4
+  __rv_ntl_store(&h1, 1.0, __RISCV_NTLH_INNERMOST_PRIVATE); // CHECK: store half{{.*}}align 2, !nontemporal !4
+  __rv_ntl_store(&f1, 1.0, __RISCV_NTLH_INNERMOST_PRIVATE); // CHECK: store float{{.*}}align 4, !nontemporal !4
+  __rv_ntl_store(&d1, 1.0, __RISCV_NTLH_INNERMOST_PRIVATE); // CHECK: store double{{.*}}align 8, !nontemporal !4
+
+  __rv_ntl_store(&uc, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i8{{.*}}align 1, !nontemporal !5
+  __rv_ntl_store(&sc, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i8{{.*}}align 1, !nontemporal !5
+  __rv_ntl_store(&us, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i16{{.*}}align 2, !nontemporal !5
+  __rv_ntl_store(&ss, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i16{{.*}}align 2, !nontemporal !5
+  __rv_ntl_store(&ui, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i32{{.*}}align 4, !nontemporal !5
+  __rv_ntl_store(&si, 1, __RISCV_NTLH_ALL_PRIVATE);         // CHECK: store i32{{.*}}align 4, !nontemporal !5
+  __rv_ntl_store(&ull, 1, __RISCV_NTLH_ALL_PRIVATE);        // CHECK: store i64{{.*}}align 8, !nontemporal !5
+  __rv_ntl_store(&sll, 1, __RISCV_NTLH_ALL_PRIVATE);        // CHECK: store i64{{.*}}align 8, !nontemporal !5
+  __rv_ntl_store(&h1, 1.0, __RISCV_NTLH_ALL_PRIVATE);       // CHECK: store half{{.*}}align 2, !nontemporal !5
+  __rv_ntl_store(&f1, 1.0, __RISCV_NTLH_ALL_PRIVATE);       // CHECK: store float{{.*}}align 4, !nontemporal !5
+  __rv_ntl_store(&d1, 1.0, __RISCV_NTLH_ALL_PRIVATE);       // CHECK: store double{{.*}}align 8, !nontemporal !5
+
+  __rv_ntl_store(&uc, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i8{{.*}}align 1, !nontemporal !6
+  __rv_ntl_store(&sc, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i8{{.*}}align 1, !nontemporal !6
+  __rv_ntl_store(&us, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i16{{.*}}align 2, !nontemporal !6
+  __rv_ntl_store(&ss, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i16{{.*}}align 2, !nontemporal !6
+  __rv_ntl_store(&ui, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i32{{.*}}align 4, !nontemporal !6
+  __rv_ntl_store(&si, 1, __RISCV_NTLH_INNERMOST_SHARED);    // CHECK: store i32{{.*}}align 4, !nontemporal !6
+  __rv_ntl_store(&ull, 1, __RISCV_NTLH_INNERMOST_SHARED);   // CHECK: store i64{{.*}}align 8, !nontemporal !6
+  __rv_ntl_store(&sll, 1, __RISCV_NTLH_INNERMOST_SHARED);   // CHECK: store i64{{.*}}align 8, !nontemporal !6
+  __rv_ntl_store(&h1, 1.0, __RISCV_NTLH_INNERMOST_SHARED);  // CHECK: store half{{.*}}align 2, !nontemporal !6
+  __rv_ntl_store(&f1, 1.0, __RISCV_NTLH_INNERMOST_SHARED);  // CHECK: store float{{.*}}align 4, !nontemporal !6
+  __rv_ntl_store(&d1, 1.0, __RISCV_NTLH_INNERMOST_SHARED);  // CHECK: store double{{.*}}align 8, !nontemporal !6
+
+  __rv_ntl_store(&uc, 1, __RISCV_NTLH_ALL);                 // CHECK: store i8{{.*}}align 1, !nontemporal !7
+  __rv_ntl_store(&sc, 1, __RISCV_NTLH_ALL);                 // CHECK: store i8{{.*}}align 1, !nontemporal !7
+  __rv_ntl_store(&us, 1, __RISCV_NTLH_ALL);                 // CHECK: store i16{{.*}}align 2, !nontemporal !7
+  __rv_ntl_store(&ss, 1, __RISCV_NTLH_ALL);                 // CHECK: store i16{{.*}}align 2, !nontemporal !7
+  __rv_ntl_store(&ui, 1, __RISCV_NTLH_ALL);                 // CHECK: store i32{{.*}}align 4, !nontemporal !7
+  __rv_ntl_store(&si, 1, __RISCV_NTLH_ALL);                 // CHECK: store i32{{.*}}align 4, !nontemporal !7
+  __rv_ntl_store(&ull, 1, __RISCV_NTLH_ALL);                // CHECK: store i64{{.*}}align 8, !nontemporal !7
+  __rv_ntl_store(&sll, 1, __RISCV_NTLH_ALL);                // CHECK: store i64{{.*}}align 8, !nontemporal !7
+  __rv_ntl_store(&h1, 1.0, __RISCV_NTLH_ALL);               // CHECK: store half{{.*}}align 2, !nontemporal !7
+  __rv_ntl_store(&f1, 1.0, __RISCV_NTLH_ALL);               // CHECK: store float{{.*}}align 4, !nontemporal !7
+  __rv_ntl_store(&d1, 1.0, __RISCV_NTLH_ALL);               // CHECK: store double{{.*}}align 8, !nontemporal !7
+
+}
+// clang-format on
+
+// CHECK: !4 = !{i32 2}
+// CHECK: !5 = !{i32 3}
+// CHECK: !6 = !{i32 4}
+// CHECK: !7 = !{i32 5}
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -433,6 +433,13 @@
     return MachineMemOperand::MONone;
   }
 
+  /// This callback is used to inspect load/store SDNodes.
+  /// The default implementation returns no target-specific flags.
+  virtual MachineMemOperand::Flags
+  getTargetMMOFlags(const MemSDNode &Node) const {
+    return MachineMemOperand::MONone;
+  }
+
   MachineMemOperand::Flags
   getLoadMemOperandFlags(const LoadInst &LI, const DataLayout &DL,
                          AssumptionCache *AC = nullptr,
@@ -672,6 +679,13 @@
     return false;
   }
 
+  /// Return true if it is valid to merge the target MMO flags of two SDNodes.
+  virtual bool
+  areTwoSDNodeTargetMMOFlagsMergeable(const MemSDNode &NodeX,
+                                      const MemSDNode &NodeY) const {
+    return true;
+  }
+
   /// Use bitwise logic to make pairs of compares more efficient. For example:
   /// and (seteq A, B), (seteq C, D) --> seteq (or (xor A, B), (xor C, D)), 0
   /// This should be true when it takes more than one instruction to lower
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19366,6 +19366,8 @@
       // Don't mix temporal stores with non-temporal stores.
       if (St->isNonTemporal() != Other->isNonTemporal())
         return false;
+      if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*St, *Other))
+        return false;
       SDValue OtherBC = peekThroughBitcasts(Other->getValue());
       // Allow merging constants of different types as integers.
       bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
@@ -19391,6 +19393,9 @@
         // Don't mix temporal loads with non-temporal loads.
         if (cast<LoadSDNode>(Val)->isNonTemporal() != OtherLd->isNonTemporal())
           return false;
+        if (!TLI.areTwoSDNodeTargetMMOFlagsMergeable(*cast<LoadSDNode>(Val),
+                                                     *OtherLd))
+          return false;
         if (!(LBasePtr.equalBaseIndex(LPtr, DAG)))
           return false;
         break;
@@ -20015,10 +20020,14 @@
   if (IsNonTemporalLoad)
     LdMMOFlags |= MachineMemOperand::MONonTemporal;
 
+  LdMMOFlags |= TLI.getTargetMMOFlags(*FirstLoad);
+
   MachineMemOperand::Flags StMMOFlags = IsNonTemporalStore
                                             ? MachineMemOperand::MONonTemporal
                                             : MachineMemOperand::MONone;
 
+  StMMOFlags |= TLI.getTargetMMOFlags(*StoreNodes[0].MemNode);
+
   SDValue NewLoad, NewStore;
   if (UseVectorTy || !DoIntegerTruncate) {
     NewLoad = DAG.getLoad(
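The two DAGCombiner guards above keep the load/store mergers from folding accesses with different non-temporal domains into one wider access. A sketch of the pattern that must stay unmerged (hypothetical example, not taken from the test suite):

```c
#include <riscv_ntlh.h>

void no_merge(short *p) {
  /* Two adjacent i16 stores; without areTwoSDNodeTargetMMOFlagsMergeable()
     the combiner could merge them into one i32 store and lose one of the
     two distinct domain hints. */
  __rv_ntl_store(&p[0], 0, __RISCV_NTLH_INNERMOST_PRIVATE);
  __rv_ntl_store(&p[1], 0, __RISCV_NTLH_ALL);
}
```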
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -474,6 +474,16 @@
   // This method returns the name of a target specific DAG node.
   const char *getTargetNodeName(unsigned Opcode) const override;
 
+  MachineMemOperand::Flags
+  getTargetMMOFlags(const Instruction &I) const override;
+
+  MachineMemOperand::Flags
+  getTargetMMOFlags(const MemSDNode &Node) const override;
+
+  bool
+  areTwoSDNodeTargetMMOFlagsMergeable(const MemSDNode &NodeX,
+                                      const MemSDNode &NodeY) const override;
+
   ConstraintType getConstraintType(StringRef Constraint) const override;
 
   unsigned getInlineAsmMemConstraint(StringRef ConstraintCode) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15050,6 +15050,56 @@
   return Reg;
 }
 
+MachineMemOperand::Flags
+RISCVTargetLowering::getTargetMMOFlags(const Instruction &I) const {
+  const MDNode *NontemporalInfo = I.getMetadata(LLVMContext::MD_nontemporal);
+
+  if (NontemporalInfo == nullptr)
+    return MachineMemOperand::MONone;
+
+  // 1 is the default value and works as __RISCV_NTLH_ALL.
+  // 2 -> __RISCV_NTLH_INNERMOST_PRIVATE
+  // 3 -> __RISCV_NTLH_ALL_PRIVATE
+  // 4 -> __RISCV_NTLH_INNERMOST_SHARED
+  // 5 -> __RISCV_NTLH_ALL
+  int NontemporalLevel =
+      cast<ConstantInt>(
+          cast<ConstantAsMetadata>(NontemporalInfo->getOperand(0))->getValue())
+          ->getZExtValue();
+
+  assert((1 <= NontemporalLevel && NontemporalLevel <= 5) &&
+         "RISC-V target doesn't support this non-temporal domain.");
+
+  // Map the default value onto __RISCV_NTLH_ALL.
+  if (NontemporalLevel == 1)
+    NontemporalLevel = 5;
+
+  NontemporalLevel -= 2;
+  MachineMemOperand::Flags Flags = MachineMemOperand::MONone;
+  if (NontemporalLevel & 0b1)
+    Flags |= MONontemporalBit0;
+  if (NontemporalLevel & 0b10)
+    Flags |= MONontemporalBit1;
+
+  return Flags;
+}
+
+MachineMemOperand::Flags
+RISCVTargetLowering::getTargetMMOFlags(const MemSDNode &Node) const {
+  MachineMemOperand::Flags NodeFlags = Node.getMemOperand()->getFlags();
+  MachineMemOperand::Flags TargetFlags = MachineMemOperand::MONone;
+  TargetFlags |= (NodeFlags & MONontemporalBit0);
+  TargetFlags |= (NodeFlags & MONontemporalBit1);
+
+  return TargetFlags;
+}
+
+bool RISCVTargetLowering::areTwoSDNodeTargetMMOFlagsMergeable(
+    const MemSDNode &NodeX, const MemSDNode &NodeY) const {
+  return getTargetMMOFlags(NodeX) == getTargetMMOFlags(NodeY);
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVInsertNTLHInsts.cpp b/llvm/lib/Target/RISCV/RISCVInsertNTLHInsts.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertNTLHInsts.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertNTLHInsts.cpp
@@ -67,11 +67,27 @@
         continue;
       MachineMemOperand *MMO = *(MBBI.memoperands_begin());
       if (MMO->isNonTemporal()) {
+        uint64_t NontemporalMode = 0;
+        if (MMO->getFlags() & MONontemporalBit0)
+          NontemporalMode += 0b1;
+        if (MMO->getFlags() & MONontemporalBit1)
+          NontemporalMode += 0b10;
+
+        static const uint16_t NTLOpc[] = {
+            RISCV::PseudoNTLP1, RISCV::PseudoNTLPALL, RISCV::PseudoNTLS1,
+            RISCV::PseudoNTLALL};
+        static const uint16_t CNTLOpc[] = {
+            RISCV::PseudoCNTLP1, RISCV::PseudoCNTLPALL, RISCV::PseudoCNTLS1,
+            RISCV::PseudoCNTLALL};
+
+        unsigned CurrNTLOpc;
         DebugLoc DL = MBBI.getDebugLoc();
         if (ST.hasStdExtCOrZca() && ST.enableRVCHintInstrs())
-          BuildMI(MBB, MBBI, DL, TII->get(RISCV::PseudoCNTLALL));
+          CurrNTLOpc = CNTLOpc[NontemporalMode];
         else
-          BuildMI(MBB, MBBI, DL, TII->get(RISCV::PseudoNTLALL));
+          CurrNTLOpc = NTLOpc[NontemporalMode];
+
+        BuildMI(MBB, MBBI, DL, TII->get(CurrNTLOpc));
         Changed = true;
       }
     }
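Summarizing the encoding carried through the hooks above (a standalone restatement for review, not code in the patch; DomainToOpcodeIndex is a hypothetical name):

```c
/* !nontemporal !{i32 N} -> two target MMO flag bits -> one NTL pseudo.
   Plain !nontemporal metadata (N == 1) is treated as __RISCV_NTLH_ALL. */
static unsigned DomainToOpcodeIndex(int N) {
  if (N == 1)
    N = 5; /* default domain behaves like __RISCV_NTLH_ALL */
  /* Bit 0 of the result is MONontemporalBit0, bit 1 is MONontemporalBit1;
     0 -> ntl.p1, 1 -> ntl.pall, 2 -> ntl.s1, 3 -> ntl.all. */
  return N - 2;
}
```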
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -25,6 +25,11 @@
 
 class RISCVSubtarget;
 
+static const MachineMemOperand::Flags MONontemporalBit0 =
+    MachineMemOperand::MOTargetFlag1;
+static const MachineMemOperand::Flags MONontemporalBit1 =
+    MachineMemOperand::MOTargetFlag2;
+
 namespace RISCVCC {
 
 enum CondCode {
@@ -238,6 +243,9 @@
     return hasAllNBitUsers(MI, MRI, 32);
   }
 
+  ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
+  getSerializableMachineMemOperandTargetFlags() const override;
+
 protected:
   const RISCVSubtarget &STI;
 };
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2848,6 +2848,14 @@
   return true;
 }
 
+ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
+RISCVInstrInfo::getSerializableMachineMemOperandTargetFlags() const {
+  static const std::pair<MachineMemOperand::Flags, const char *>
+      TargetFlags[] = {{MONontemporalBit0, "riscv-non-temporal-domain-bit-0"},
+                       {MONontemporalBit1, "riscv-non-temporal-domain-bit-1"}};
+  return makeArrayRef(TargetFlags);
+}
+
 // Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
 bool RISCV::isSEXT_W(const MachineInstr &MI) {
   return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZihintntl.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZihintntl.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZihintntl.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZihintntl.td
@@ -11,12 +11,24 @@
 ///
 //===----------------------------------------------------------------------===//
 
-let hasSideEffects = 0, mayLoad = 0, mayStore = 0, Size = 4 in {
-  def PseudoNTLALL : Pseudo<(outs), (ins), [], "ntl.all">,
+let hasSideEffects = 0, mayLoad = 0, mayStore = 0, Size = 4, isCodeGenOnly = 1 in {
+  def PseudoNTLP1 : Pseudo<(outs), (ins), [], "ntl.p1">,
+                    PseudoInstExpansion<(ADD X0, X0, X2)>;
+  def PseudoNTLPALL : Pseudo<(outs), (ins), [], "ntl.pall">,
+                      PseudoInstExpansion<(ADD X0, X0, X3)>;
+  def PseudoNTLS1 : Pseudo<(outs), (ins), [], "ntl.s1">,
+                    PseudoInstExpansion<(ADD X0, X0, X4)>;
+  def PseudoNTLALL : Pseudo<(outs), (ins), [], "ntl.all">,
                      PseudoInstExpansion<(ADD X0, X0, X5)>;
 }
 
-let hasSideEffects = 0, mayLoad = 0, mayStore = 0, Size = 2 in {
-  def PseudoCNTLALL : Pseudo<(outs), (ins), [], "c.ntl.all">,
+let hasSideEffects = 0, mayLoad = 0, mayStore = 0, Size = 2, isCodeGenOnly = 1 in {
+  def PseudoCNTLP1 : Pseudo<(outs), (ins), [], "c.ntl.p1">,
+                     PseudoInstExpansion<(C_ADD_HINT X0, X0, X2)>;
+  def PseudoCNTLPALL : Pseudo<(outs), (ins), [], "c.ntl.pall">,
+                       PseudoInstExpansion<(C_ADD_HINT X0, X0, X3)>;
+  def PseudoCNTLS1 : Pseudo<(outs), (ins), [], "c.ntl.s1">,
+                     PseudoInstExpansion<(C_ADD_HINT X0, X0, X4)>;
+  def PseudoCNTLALL : Pseudo<(outs), (ins), [], "c.ntl.all">,
                       PseudoInstExpansion<(C_ADD_HINT X0, X0, X5)>;
 }
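End to end, the expected lowering is sketched below; the assembly in the comments is inferred from the pseudo expansions above (assuming -march=rv64i_zihintntl without the C extension), not captured from an actual run:

```c
#include <riscv_ntlh.h>

int g;

int load_all_private(void) {
  /* Domain 3 (__RISCV_NTLH_ALL_PRIVATE) sets MONontemporalBit0 only,
     selecting PseudoNTLPALL, so the load should come out roughly as:
         ntl.pall          # encoded as add x0, x0, x3
         lw a0, 0(a0)      # the guarded access */
  return __rv_ntl_load(&g, __RISCV_NTLH_ALL_PRIVATE);
}
```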