Index: llvm/lib/Target/SPIRV/CMakeLists.txt
===================================================================
--- llvm/lib/Target/SPIRV/CMakeLists.txt
+++ llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -15,6 +15,7 @@
 add_llvm_target(SPIRVCodeGen
   SPIRVAsmPrinter.cpp
+  SPIRVBuiltins.cpp
   SPIRVCallLowering.cpp
   SPIRVDuplicatesTracker.cpp
   SPIRVEmitIntrinsics.cpp
@@ -38,6 +39,7 @@
   AsmPrinter
   CodeGen
   Core
+  Demangle
   GlobalISel
   MC
   SPIRVDesc
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -181,6 +181,11 @@
 #include "SPIRVGenTables.inc"
 } // namespace KernelProfilingInfo
 
+namespace InstructionSet {
+#define GET_InstructionSet_DECL
+#include "SPIRVGenTables.inc"
+} // namespace InstructionSet
+
 namespace OpenCLExtInst {
 #define GET_OpenCLExtInst_DECL
 #include "SPIRVGenTables.inc"
@@ -196,12 +201,11 @@
 #include "SPIRVGenTables.inc"
 } // namespace Opcode
 
-enum class InstructionSet : uint32_t {
-  OpenCL_std = 0,
-  GLSL_std_450 = 1,
-  SPV_AMD_shader_trinary_minmax = 2,
+struct ExtendedBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  uint32_t Number;
 };
-std::string getExtInstSetName(InstructionSet e);
 } // namespace SPIRV
 
 using CapabilityList = SmallVector<SPIRV::Capability::Capability, 8>;
@@ -226,6 +230,12 @@
 bool getSpirvBuiltInIdByName(StringRef Name, SPIRV::BuiltIn::BuiltIn &BI);
 
+std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set);
+SPIRV::InstructionSet::InstructionSet
+getExtInstSetFromString(std::string SetName);
+std::string getExtInstName(SPIRV::InstructionSet::InstructionSet Set,
+                           uint32_t InstructionNumber);
+
 // Return a string representation of the operands from startIndex onwards.
 // Templated to allow both MachineInstr and MCInst to use the same logic.
 template <class InstType>
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -41,6 +41,7 @@
 using namespace OperandCategory;
 using namespace Extension;
 using namespace Capability;
+using namespace InstructionSet;
 #define GET_SymbolicOperands_DECL
 #define GET_SymbolicOperands_IMPL
 #define GET_ExtensionEntries_DECL
@@ -50,19 +51,6 @@
 #define GET_ExtendedBuiltins_DECL
 #define GET_ExtendedBuiltins_IMPL
 #include "SPIRVGenTables.inc"
-
-#define CASE(CLASS, ATTR)                                                      \
-  case CLASS::ATTR:                                                            \
-    return #ATTR;
-std::string getExtInstSetName(InstructionSet e) {
-  switch (e) {
-    CASE(InstructionSet, OpenCL_std)
-    CASE(InstructionSet, GLSL_std_450)
-    CASE(InstructionSet, SPV_AMD_shader_trinary_minmax)
-    break;
-  }
-  llvm_unreachable("Unexpected operand");
-}
 } // namespace SPIRV
 
 std::string
@@ -185,4 +173,38 @@
   BI = static_cast<SPIRV::BuiltIn::BuiltIn>(Lookup->Value);
   return true;
 }
+
+std::string getExtInstSetName(SPIRV::InstructionSet::InstructionSet Set) {
+  switch (Set) {
+  case SPIRV::InstructionSet::OpenCL_std:
+    return "OpenCL.std";
+  case SPIRV::InstructionSet::GLSL_std_450:
+    return "GLSL.std.450";
+  case SPIRV::InstructionSet::SPV_AMD_shader_trinary_minmax:
+    return "SPV_AMD_shader_trinary_minmax";
+  }
+  return "UNKNOWN_EXT_INST_SET";
+}
+
+SPIRV::InstructionSet::InstructionSet
+getExtInstSetFromString(std::string SetName) {
+  for (auto Set : {SPIRV::InstructionSet::GLSL_std_450,
+                   SPIRV::InstructionSet::OpenCL_std}) {
+    if (SetName == getExtInstSetName(Set))
+      return Set;
+  }
+  llvm_unreachable("UNKNOWN_EXT_INST_SET");
+}
+
+std::string getExtInstName(SPIRV::InstructionSet::InstructionSet Set,
+                           uint32_t InstructionNumber) {
+  const SPIRV::ExtendedBuiltin *Lookup =
+      SPIRV::lookupExtendedBuiltinBySetAndNumber(Set, InstructionNumber);
+
+  if (!Lookup)
+    return "UNKNOWN_EXT_INST";
+
+  return Lookup->Name.str();
+}
+} // namespace llvm
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
@@ -14,11 +14,13 @@
 #define LLVM_LIB_TARGET_SPIRV_INSTPRINTER_SPIRVINSTPRINTER_H
 
 #include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/MC/MCInstPrinter.h"
 
 namespace llvm {
 class SPIRVInstPrinter : public MCInstPrinter {
 private:
+  SmallDenseMap<unsigned, SPIRV::InstructionSet::InstructionSet> ExtInstSetIDs;
   void recordOpExtInstImport(const MCInst *MI);
 
 public:
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -60,7 +60,10 @@
 }
 
 void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) {
-  // TODO: insert {Reg, Set} into ExtInstSetIDs map.
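+  // Remember which instruction set the OpExtInstImport result <id> refers to,
+  // so printExtension can later map instruction numbers back to their names.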
+  Register Reg = MI->getOperand(0).getReg();
+  auto Name = getSPIRVStringOperand(*MI, 1);
+  auto Set = getExtInstSetFromString(Name);
+  ExtInstSetIDs.insert({Reg, Set});
 }
 
 void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
@@ -306,7 +309,10 @@
 
 void SPIRVInstPrinter::printExtension(const MCInst *MI, unsigned OpNo,
                                       raw_ostream &O) {
-  llvm_unreachable("Unimplemented printExtension");
+  auto SetReg = MI->getOperand(2).getReg();
+  auto Set = ExtInstSetIDs[SetReg];
+  auto Op = MI->getOperand(OpNo).getImm();
+  O << getExtInstName(Set, Op);
 }
 
 template
Index: llvm/lib/Target/SPIRV/SPIRV.td
===================================================================
--- llvm/lib/Target/SPIRV/SPIRV.td
+++ llvm/lib/Target/SPIRV/SPIRV.td
@@ -11,6 +11,7 @@
 include "SPIRVRegisterInfo.td"
 include "SPIRVRegisterBanks.td"
 include "SPIRVInstrInfo.td"
+include "SPIRVBuiltins.td"
 
 def SPIRVInstrInfo : InstrInfo;
Index: llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -269,7 +269,8 @@
     MCInst Inst;
     Inst.setOpcode(SPIRV::OpExtInstImport);
     Inst.addOperand(MCOperand::createReg(Reg));
-    addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)),
+    addStringImm(getExtInstSetName(
+                     static_cast<SPIRV::InstructionSet::InstructionSet>(Set)),
                  Inst);
     outputMCInst(Inst);
   }
Index: llvm/lib/Target/SPIRV/SPIRVBuiltins.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -0,0 +1,40 @@
+//===-- SPIRVBuiltins.h - SPIR-V Built-in Functions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lowering builtin function calls and types using their demangled names.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVBUILTINS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVBUILTINS_H
+
+#include "SPIRVGlobalRegistry.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+
+namespace llvm {
+namespace SPIRV {
+/// Lowers a builtin function call using the provided \p DemangledCall skeleton
+/// and external instruction \p Set.
+///
+/// \return a pair of booleans: the first is true if the call was recognized
+/// as a builtin, the second indicates whether the lowering succeeded.
+///
+/// \p DemangledCall is the skeleton of the lowered builtin function call.
+/// \p Set is the external instruction set containing the given builtin.
+/// \p OrigRet is the single original virtual return register if defined,
+/// Register(0) otherwise. \p OrigRetTy is the type of the \p OrigRet. \p Args
+/// are the arguments of the lowered builtin call.
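+///
+/// For example, a \p DemangledCall skeleton for an OpenCL built-in might look
+/// like "fmin(float, float)" (the name followed by parenthesized arg types).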
+std::pair<bool, bool>
+lowerBuiltin(const StringRef DemangledCall, InstructionSet::InstructionSet Set,
+             MachineIRBuilder &MIRBuilder, const Register OrigRet,
+             const Type *OrigRetTy, const SmallVectorImpl<Register> &Args,
+             SPIRVGlobalRegistry *GR);
+} // namespace SPIRV
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVBUILTINS_H
Index: llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -0,0 +1,1617 @@
+//===- SPIRVBuiltins.cpp - SPIR-V Built-in Functions ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering builtin function calls and types using their
+// demangled names and TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVBuiltins.h"
+#include "SPIRV.h"
+#include "SPIRVUtils.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include <string>
+#include <tuple>
+
+#define DEBUG_TYPE "spirv-builtins"
+
+namespace llvm {
+namespace SPIRV {
+#define GET_BuiltinGroup_DECL
+#include "SPIRVGenTables.inc"
+
+struct DemangledBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  BuiltinGroup Group;
+  uint8_t MinNumArgs;
+  uint8_t MaxNumArgs;
+};
+
+#define GET_DemangledBuiltins_DECL
+#define GET_DemangledBuiltins_IMPL
+
+struct IncomingCall {
+  const std::string BuiltinName;
+  const DemangledBuiltin *Builtin;
+
+  const Register ReturnRegister;
+  const SPIRVType *ReturnType;
+  const SmallVectorImpl<Register> &Arguments;
+
+  IncomingCall(const std::string BuiltinName, const DemangledBuiltin *Builtin,
+               const Register ReturnRegister, const SPIRVType *ReturnType,
+               const SmallVectorImpl<Register> &Arguments)
+      : BuiltinName(BuiltinName), Builtin(Builtin),
+        ReturnRegister(ReturnRegister), ReturnType(ReturnType),
+        Arguments(Arguments) {}
+};
+
+struct NativeBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  uint32_t Opcode;
+};
+
+#define GET_NativeBuiltins_DECL
+#define GET_NativeBuiltins_IMPL
+
+struct GroupBuiltin {
+  StringRef Name;
+  uint32_t Opcode;
+  uint32_t GroupOperation;
+  bool IsElect;
+  bool IsAllOrAny;
+  bool IsAllEqual;
+  bool IsBallot;
+  bool IsInverseBallot;
+  bool IsBallotBitExtract;
+  bool IsBallotFindBit;
+  bool IsLogical;
+  bool NoGroupOperation;
+  bool HasBoolArg;
+};
+
+#define GET_GroupBuiltins_DECL
+#define GET_GroupBuiltins_IMPL
+
+struct GetBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  BuiltIn::BuiltIn Value;
+};
+
+using namespace BuiltIn;
+#define GET_GetBuiltins_DECL
+#define GET_GetBuiltins_IMPL
+
+struct ImageQueryBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  uint32_t Component;
+};
+
+#define GET_ImageQueryBuiltins_DECL
+#define GET_ImageQueryBuiltins_IMPL
+
+struct ConvertBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  bool IsDestinationSigned;
+  bool IsSaturated;
+  bool IsRounded;
+  FPRoundingMode::FPRoundingMode RoundingMode;
+};
+
+struct VectorLoadStoreBuiltin {
+  StringRef Name;
+  InstructionSet::InstructionSet Set;
+  uint32_t Number;
+  bool IsRounded;
+  FPRoundingMode::FPRoundingMode RoundingMode;
+};
+
+using namespace FPRoundingMode;
+#define GET_ConvertBuiltins_DECL
+#define GET_ConvertBuiltins_IMPL
+
+using namespace InstructionSet;
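+
+// The GET_*_DECL/GET_*_IMPL macros pull in the TableGen-generated record
+// tables and lookup helpers (e.g. lookupBuiltin, lookupNativeBuiltin) from
+// SPIRVGenTables.inc, driven by the definitions in SPIRVBuiltins.td.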
+#define GET_VectorLoadStoreBuiltins_DECL
+#define GET_VectorLoadStoreBuiltins_IMPL
+
+#define GET_CLMemoryScope_DECL
+#define GET_CLSamplerAddressingMode_DECL
+#define GET_CLMemoryFenceFlags_DECL
+#define GET_ExtendedBuiltins_DECL
+#include "SPIRVGenTables.inc"
+} // namespace SPIRV
+
+//===----------------------------------------------------------------------===//
+// Misc functions for looking up builtins and verifying requirements using
+// TableGen records
+//===----------------------------------------------------------------------===//
+
+/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
+/// the provided \p DemangledCall and specified \p Set.
+///
+/// The lookup follows the following algorithm, returning the first successful
+/// match:
+/// 1. Search with the plain demangled name (expecting a 1:1 match).
+/// 2. Search with the prefix before or suffix after the demangled name
+/// signifying the type of the first argument.
+///
+/// \returns Wrapper around the demangled call and found builtin definition.
+static std::unique_ptr<SPIRV::IncomingCall>
+lookupBuiltin(StringRef DemangledCall,
+              SPIRV::InstructionSet::InstructionSet Set,
+              Register ReturnRegister, const SPIRVType *ReturnType,
+              const SmallVectorImpl<Register> &Arguments) {
+  // Extract the builtin function name and types of arguments from the call
+  // skeleton.
+  std::string BuiltinName =
+      DemangledCall.substr(0, DemangledCall.find('(')).str();
+
+  // Check if the extracted name contains type information between angle
+  // brackets. If so, the builtin is an instantiated template - needs to have
+  // the information after angle brackets and return type removed.
+  if (BuiltinName.find('<') && BuiltinName.back() == '>') {
+    BuiltinName = BuiltinName.substr(0, BuiltinName.find('<'));
+    BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(" ") + 1);
+  }
+
+  // If the extracted name begins with "__spirv_ImageSampleExplicitLod" and
+  // contains return type information at the end after "_R", extract the plain
+  // builtin name without the type information.
+  if (StringRef(BuiltinName).contains("__spirv_ImageSampleExplicitLod") &&
+      StringRef(BuiltinName).contains("_R")) {
+    BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R"));
+  }
+
+  SmallVector<StringRef, 10> BuiltinArgumentTypes;
+  StringRef BuiltinArgs =
+      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
+  BuiltinArgs.split(BuiltinArgumentTypes, ',', -1, false);
+
+  // Look up the builtin in the defined set. Start with the plain demangled
+  // name, expecting a 1:1 match in the defined builtin set.
+  const SPIRV::DemangledBuiltin *Builtin;
+  if ((Builtin = SPIRV::lookupBuiltin(BuiltinName, Set)))
+    return std::make_unique<SPIRV::IncomingCall>(
+        BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
+
+  // If the initial look up was unsuccessful and the demangled call takes at
+  // least 1 argument, add a prefix or suffix signifying the type of the first
+  // argument and repeat the search.
+  if (BuiltinArgumentTypes.size() >= 1) {
+    char FirstArgumentType = BuiltinArgumentTypes[0][0];
+    // Prefix to be added to the builtin's name for lookup.
+    // For example, OpenCL "abs" taking an unsigned value has a prefix "u_".
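+    // GLSL.std.450 instead uses bare "u"/"s" prefixes, and both sets use "f"
+    // for floating-point arguments (see the switch below).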
+    std::string Prefix;
+
+    switch (FirstArgumentType) {
+    // Unsigned:
+    case 'u':
+      if (Set == SPIRV::InstructionSet::OpenCL_std)
+        Prefix = "u_";
+      else if (Set == SPIRV::InstructionSet::GLSL_std_450)
+        Prefix = "u";
+      break;
+    // Signed:
+    case 'c':
+    case 's':
+    case 'i':
+    case 'l':
+      if (Set == SPIRV::InstructionSet::OpenCL_std)
+        Prefix = "s_";
+      else if (Set == SPIRV::InstructionSet::GLSL_std_450)
+        Prefix = "s";
+      break;
+    // Floating-point:
+    case 'f':
+    case 'd':
+    case 'h':
+      if (Set == SPIRV::InstructionSet::OpenCL_std ||
+          Set == SPIRV::InstructionSet::GLSL_std_450)
+        Prefix = "f";
+      break;
+    }
+
+    // If argument-type name prefix was added, look up the builtin again.
+    if (!Prefix.empty() &&
+        (Builtin = SPIRV::lookupBuiltin(Prefix + BuiltinName, Set)))
+      return std::make_unique<SPIRV::IncomingCall>(
+          BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
+
+    // If lookup with a prefix failed, find a suffix to be added to the
+    // builtin's name for lookup. For example, OpenCL "group_reduce_max" taking
+    // an unsigned value has a suffix "u".
+    std::string Suffix;
+
+    switch (FirstArgumentType) {
+    // Unsigned:
+    case 'u':
+      Suffix = "u";
+      break;
+    // Signed:
+    case 'c':
+    case 's':
+    case 'i':
+    case 'l':
+      Suffix = "s";
+      break;
+    // Floating-point:
+    case 'f':
+    case 'd':
+    case 'h':
+      Suffix = "f";
+      break;
+    }
+
+    // If argument-type name suffix was added, look up the builtin again.
+    if (!Suffix.empty() &&
+        (Builtin = SPIRV::lookupBuiltin(BuiltinName + Suffix, Set)))
+      return std::make_unique<SPIRV::IncomingCall>(
+          BuiltinName, Builtin, ReturnRegister, ReturnType, Arguments);
+  }
+
+  // No builtin with such name was found in the set.
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for building misc instructions
+//===----------------------------------------------------------------------===//
+
+/// Helper function building either a resulting scalar or vector bool register
+/// depending on the expected \p ResultType.
+///
+/// \returns Tuple of the resulting register and its type.
+static std::tuple<Register, SPIRVType *>
+buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
+                  SPIRVGlobalRegistry *GR) {
+  LLT Type;
+  SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+
+  if (ResultType->getOpcode() == SPIRV::OpTypeVector) {
+    unsigned VectorElements = ResultType->getOperand(2).getImm();
+    BoolType =
+        GR->getOrCreateSPIRVVectorType(BoolType, VectorElements, MIRBuilder);
+    const FixedVectorType *LLVMVectorType =
+        cast<FixedVectorType>(GR->getTypeForSPIRVType(BoolType));
+    Type = LLT::vector(LLVMVectorType->getElementCount(), 1);
+  } else {
+    Type = LLT::scalar(1);
+  }
+
+  Register ResultRegister =
+      MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
+  GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
+  return std::make_tuple(ResultRegister, BoolType);
+}
+
+/// Helper function for building either a vector or scalar select instruction
+/// depending on the expected \p ResultType.
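+/// Scalar sources select between the constants 1 and 0; vector sources select
+/// between all-ones and all-zeros vector constants of the return type.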
+static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
+                            Register ReturnRegister, Register SourceRegister,
+                            const SPIRVType *ReturnType,
+                            SPIRVGlobalRegistry *GR) {
+  Register TrueConst, FalseConst;
+
+  if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
+    unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
+    uint64_t AllOnes = APInt::getAllOnesValue(Bits).getZExtValue();
+    TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
+    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
+  } else {
+    TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
+    FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
+  }
+  return MIRBuilder.buildSelect(ReturnRegister, SourceRegister, TrueConst,
+                                FalseConst);
+}
+
+/// Helper function for building a load instruction loading into the
+/// \p DestinationReg.
+static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
+                              MachineIRBuilder &MIRBuilder,
+                              SPIRVGlobalRegistry *GR, LLT LowLevelType,
+                              Register DestinationReg = Register(0)) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  if (!DestinationReg.isValid()) {
+    DestinationReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    MRI->setType(DestinationReg, LLT::scalar(32));
+    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
+  }
+  // TODO: consider using correct address space and alignment (p0 is canonical
+  // type for selection though).
+  MachinePointerInfo PtrInfo = MachinePointerInfo();
+  MIRBuilder.buildLoad(DestinationReg, PtrRegister, PtrInfo, Align());
+  return DestinationReg;
+}
+
+/// Helper function for building a load instruction for loading the builtin
+/// global variable associated with \p BuiltinValue.
+static Register buildBuiltinVariableLoad(MachineIRBuilder &MIRBuilder,
+                                         SPIRVType *VariableType,
+                                         SPIRVGlobalRegistry *GR,
+                                         SPIRV::BuiltIn::BuiltIn BuiltinValue,
+                                         LLT LLType,
+                                         Register Reg = Register(0)) {
+  Register NewRegister =
+      MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setType(NewRegister,
+                               LLT::pointer(0, GR->getPointerSize()));
+  SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
+      VariableType, MIRBuilder, SPIRV::StorageClass::Input);
+  GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
+
+  // Set up the global OpVariable with the necessary builtin decorations.
+  Register Variable = GR->buildGlobalVariable(
+      NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
+      SPIRV::StorageClass::Input, nullptr, true, true,
+      SPIRV::LinkageType::Import, MIRBuilder, false);
+
+  // Load the value from the global variable.
+  Register LoadedRegister =
+      buildLoadInst(VariableType, Variable, MIRBuilder, GR, LLType, Reg);
+  MIRBuilder.getMRI()->setType(LoadedRegister, LLType);
+  return LoadedRegister;
+}
+
+/// Helper external function for inserting an ASSIGN_TYPE instruction between
+/// \p Reg and its definition, setting the new register as a destination of the
+/// definition and assigning a SPIRVType to both registers. If SpirvTy is
+/// provided, use it as the SPIRVType in ASSIGN_TYPE, otherwise create it from
+/// \p Ty. Defined in SPIRVPreLegalizer.cpp.
+extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIB,
+                                  MachineRegisterInfo &MRI);
+
+// TODO: Move to TableGen.
+static SPIRV::MemorySemantics::MemorySemantics
+getSPIRVMemSemantics(std::memory_order MemOrder) {
+  switch (MemOrder) {
+  case std::memory_order::memory_order_relaxed:
+    return SPIRV::MemorySemantics::None;
+  case std::memory_order::memory_order_acquire:
+    return SPIRV::MemorySemantics::Acquire;
+  case std::memory_order::memory_order_release:
+    return SPIRV::MemorySemantics::Release;
+  case std::memory_order::memory_order_acq_rel:
+    return SPIRV::MemorySemantics::AcquireRelease;
+  case std::memory_order::memory_order_seq_cst:
+    return SPIRV::MemorySemantics::SequentiallyConsistent;
+  default:
+    llvm_unreachable("Unknown CL memory order");
+  }
+}
+
+static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
+  switch (ClScope) {
+  case SPIRV::CLMemoryScope::memory_scope_work_item:
+    return SPIRV::Scope::Invocation;
+  case SPIRV::CLMemoryScope::memory_scope_work_group:
+    return SPIRV::Scope::Workgroup;
+  case SPIRV::CLMemoryScope::memory_scope_device:
+    return SPIRV::Scope::Device;
+  case SPIRV::CLMemoryScope::memory_scope_all_svm_devices:
+    return SPIRV::Scope::CrossDevice;
+  case SPIRV::CLMemoryScope::memory_scope_sub_group:
+    return SPIRV::Scope::Subgroup;
+  }
+  llvm_unreachable("Unknown CL memory scope");
+}
+
+static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
+                                    SPIRVGlobalRegistry *GR,
+                                    unsigned BitWidth = 32) {
+  SPIRVType *IntType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
+  return GR->buildConstantInt(Val, MIRBuilder, IntType);
+}
+
+/// Helper function for building an atomic load instruction.
+static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
+                                MachineIRBuilder &MIRBuilder,
+                                SPIRVGlobalRegistry *GR) {
+  Register PtrRegister = Call->Arguments[0];
+  // TODO: if true insert call to __translate_ocl_memory_scope before
+  // OpAtomicLoad and the function implementation. We can use Translator's
+  // output for transcoding/atomic_explicit_arguments.cl as an example.
+  Register ScopeRegister;
+  if (Call->Arguments.size() > 1)
+    ScopeRegister = Call->Arguments[1];
+  else
+    ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
+
+  Register MemSemanticsReg;
+  if (Call->Arguments.size() > 2) {
+    // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
+    MemSemanticsReg = Call->Arguments[2];
+  } else {
+    int Semantics =
+        SPIRV::MemorySemantics::SequentiallyConsistent |
+        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
+    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+  }
+
+  MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(PtrRegister)
+      .addUse(ScopeRegister)
+      .addUse(MemSemanticsReg);
+  return true;
+}
+
+/// Helper function for building an atomic store instruction.
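+/// Uses Device scope and SequentiallyConsistent semantics combined with the
+/// storage-class semantics of the pointer operand; no explicit scope or order
+/// arguments are handled here.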
+static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
+                                 MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  Register ScopeRegister =
+      buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
+  Register PtrRegister = Call->Arguments[0];
+  int Semantics =
+      SPIRV::MemorySemantics::SequentiallyConsistent |
+      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
+  Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+
+  MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
+      .addUse(PtrRegister)
+      .addUse(ScopeRegister)
+      .addUse(MemSemanticsReg)
+      .addUse(Call->Arguments[1]);
+  return true;
+}
+
+/// Helper function for building an atomic compare-exchange instruction.
+static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
+                                           MachineIRBuilder &MIRBuilder,
+                                           SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+  bool IsCmpxchg = Call->Builtin->Name.contains("cmpxchg");
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+  Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object).
+  Register ExpectedArg = Call->Arguments[1]; // Comparator (C *expected).
+  Register Desired = Call->Arguments[2];     // Value (C Desired).
+  SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
+  LLT DesiredLLT = MRI->getType(Desired);
+
+  assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
+         SPIRV::OpTypePointer);
+  unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
+  assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
+                   : ExpectedType == SPIRV::OpTypePointer);
+  assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));
+
+  SPIRVType *SpvObjectPtrTy = GR->getSPIRVTypeForVReg(ObjectPtr);
+  assert(SpvObjectPtrTy->getOperand(2).isReg() && "SPIRV type is expected");
+  auto StorageClass = static_cast<SPIRV::StorageClass::StorageClass>(
+      SpvObjectPtrTy->getOperand(1).getImm());
+  auto MemSemStorage = getMemSemanticsForStorageClass(StorageClass);
+
+  Register MemSemEqualReg;
+  Register MemSemUnequalReg;
+  uint64_t MemSemEqual =
+      IsCmpxchg
+          ? SPIRV::MemorySemantics::None
+          : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
+  uint64_t MemSemUnequal =
+      IsCmpxchg
+          ? SPIRV::MemorySemantics::None
+          : SPIRV::MemorySemantics::SequentiallyConsistent | MemSemStorage;
+  if (Call->Arguments.size() >= 4) {
+    assert(Call->Arguments.size() >= 5 &&
+           "Need 5+ args for explicit atomic cmpxchg");
+    auto MemOrdEq =
+        static_cast<std::memory_order>(getIConstVal(Call->Arguments[3], MRI));
+    auto MemOrdNeq =
+        static_cast<std::memory_order>(getIConstVal(Call->Arguments[4], MRI));
+    MemSemEqual = getSPIRVMemSemantics(MemOrdEq) | MemSemStorage;
+    MemSemUnequal = getSPIRVMemSemantics(MemOrdNeq) | MemSemStorage;
+    if (MemOrdEq == MemSemEqual)
+      MemSemEqualReg = Call->Arguments[3];
+    if (MemOrdNeq == MemSemUnequal)
+      MemSemUnequalReg = Call->Arguments[4];
+  }
+  if (!MemSemEqualReg.isValid())
+    MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
+  if (!MemSemUnequalReg.isValid())
+    MemSemUnequalReg = buildConstantIntReg(MemSemUnequal, MIRBuilder, GR);
+
+  Register ScopeReg;
+  auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
+  if (Call->Arguments.size() >= 6) {
+    assert(Call->Arguments.size() == 6 &&
+           "Extra args for explicit atomic cmpxchg");
+    auto ClScope = static_cast<SPIRV::CLMemoryScope>(
+        getIConstVal(Call->Arguments[5], MRI));
+    Scope = getSPIRVScope(ClScope);
+    if (ClScope == static_cast<unsigned>(Scope))
+      ScopeReg = Call->Arguments[5];
+  }
+  if (!ScopeReg.isValid())
+    ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
+
+  Register Expected = IsCmpxchg
+                          ? ExpectedArg
+                          : buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
+                                          GR, LLT::scalar(32));
+  MRI->setType(Expected, DesiredLLT);
+  Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
+                            : Call->ReturnRegister;
+  GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
+
+  SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
+  MIRBuilder.buildInstr(Opcode)
+      .addDef(Tmp)
+      .addUse(GR->getSPIRVTypeID(IntTy))
+      .addUse(ObjectPtr)
+      .addUse(ScopeReg)
+      .addUse(MemSemEqualReg)
+      .addUse(MemSemUnequalReg)
+      .addUse(Desired)
+      .addUse(Expected);
+  if (!IsCmpxchg) {
+    MIRBuilder.buildInstr(SPIRV::OpStore).addUse(ExpectedArg).addUse(Tmp);
+    MIRBuilder.buildICmp(CmpInst::ICMP_EQ, Call->ReturnRegister, Tmp, Expected);
+  }
+  return true;
+}
+
+/// Helper function for building an atomic read-modify-write instruction.
+static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
+                               MachineIRBuilder &MIRBuilder,
+                               SPIRVGlobalRegistry *GR) {
+  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register ScopeRegister;
+  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
+  if (Call->Arguments.size() >= 4) {
+    assert(Call->Arguments.size() == 4 && "Extra args for explicit atomic RMW");
+    auto CLScope = static_cast<SPIRV::CLMemoryScope>(
+        getIConstVal(Call->Arguments[3], MRI));
+    Scope = getSPIRVScope(CLScope);
+    if (CLScope == static_cast<unsigned>(Scope))
+      ScopeRegister = Call->Arguments[3];
+  }
+  if (!ScopeRegister.isValid())
+    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+
+  Register PtrRegister = Call->Arguments[0];
+  Register MemSemanticsReg;
+  unsigned Semantics = SPIRV::MemorySemantics::None;
+  if (Call->Arguments.size() >= 3) {
+    std::memory_order Order =
+        static_cast<std::memory_order>(getIConstVal(Call->Arguments[2], MRI));
+    Semantics =
+        getSPIRVMemSemantics(Order) |
+        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
+    if (Order == Semantics)
+      MemSemanticsReg = Call->Arguments[2];
+  }
+  if (!MemSemanticsReg.isValid())
+    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+
+  MIRBuilder.buildInstr(Opcode)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(PtrRegister)
+      .addUse(ScopeRegister)
+      .addUse(MemSemanticsReg)
+      .addUse(Call->Arguments[1]);
+  return true;
+}
+
+/// Helper function for building barriers, i.e., memory/control ordering
+/// operations.
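+/// Maps the OpenCL mem_fence flags (CLK_*_MEM_FENCE) and, for OpMemoryBarrier,
+/// an explicit memory-order argument onto SPIR-V memory semantics and scopes.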
+static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
+                             MachineIRBuilder &MIRBuilder,
+                             SPIRVGlobalRegistry *GR) {
+  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
+  unsigned MemSemantics = SPIRV::MemorySemantics::None;
+
+  if (MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE)
+    MemSemantics |= SPIRV::MemorySemantics::WorkgroupMemory;
+
+  if (MemFlags & SPIRV::CLK_GLOBAL_MEM_FENCE)
+    MemSemantics |= SPIRV::MemorySemantics::CrossWorkgroupMemory;
+
+  if (MemFlags & SPIRV::CLK_IMAGE_MEM_FENCE)
+    MemSemantics |= SPIRV::MemorySemantics::ImageMemory;
+
+  if (Opcode == SPIRV::OpMemoryBarrier) {
+    std::memory_order MemOrder =
+        static_cast<std::memory_order>(getIConstVal(Call->Arguments[1], MRI));
+    MemSemantics = getSPIRVMemSemantics(MemOrder) | MemSemantics;
+  } else {
+    MemSemantics |= SPIRV::MemorySemantics::SequentiallyConsistent;
+  }
+
+  Register MemSemanticsReg;
+  if (MemFlags == MemSemantics)
+    MemSemanticsReg = Call->Arguments[0];
+  else
+    MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
+
+  Register ScopeReg;
+  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
+  SPIRV::Scope::Scope MemScope = Scope;
+  if (Call->Arguments.size() >= 2) {
+    assert(
+        ((Opcode != SPIRV::OpMemoryBarrier && Call->Arguments.size() == 2) ||
+         (Opcode == SPIRV::OpMemoryBarrier && Call->Arguments.size() == 3)) &&
+        "Extra args for explicitly scoped barrier");
+    Register ScopeArg = (Opcode == SPIRV::OpMemoryBarrier) ? Call->Arguments[2]
+                                                           : Call->Arguments[1];
+    SPIRV::CLMemoryScope CLScope =
+        static_cast<SPIRV::CLMemoryScope>(getIConstVal(ScopeArg, MRI));
+    MemScope = getSPIRVScope(CLScope);
+    if (!(MemFlags & SPIRV::CLK_LOCAL_MEM_FENCE) ||
+        (Opcode == SPIRV::OpMemoryBarrier))
+      Scope = MemScope;
+
+    if (CLScope == static_cast<unsigned>(Scope))
+      ScopeReg = ScopeArg;
+  }
+
+  if (!ScopeReg.isValid())
+    ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
+
+  auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
+  if (Opcode != SPIRV::OpMemoryBarrier)
+    MIB.addUse(buildConstantIntReg(MemScope, MIRBuilder, GR));
+  MIB.addUse(MemSemanticsReg);
+  return true;
+}
+
+static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
+  switch (dim) {
+  case SPIRV::Dim::DIM_1D:
+  case SPIRV::Dim::DIM_Buffer:
+    return 1;
+  case SPIRV::Dim::DIM_2D:
+  case SPIRV::Dim::DIM_Cube:
+  case SPIRV::Dim::DIM_Rect:
+    return 2;
+  case SPIRV::Dim::DIM_3D:
+    return 3;
+  default:
+    llvm_unreachable("Cannot get num components for given Dim");
+  }
+}
+
+/// Helper function for obtaining the number of size components.
+static unsigned getNumSizeComponents(SPIRVType *imgType) {
+  assert(imgType->getOpcode() == SPIRV::OpTypeImage);
+  auto dim = static_cast<SPIRV::Dim::Dim>(imgType->getOperand(2).getImm());
+  unsigned numComps = getNumComponentsForDim(dim);
+  bool arrayed = imgType->getOperand(4).getImm() == 1;
+  return arrayed ? numComps + 1 : numComps;
+}
+
+//===----------------------------------------------------------------------===//
+// Implementation functions for each builtin group
+//===----------------------------------------------------------------------===//
+
+static bool generateExtInst(const SPIRV::IncomingCall *Call,
+                            MachineIRBuilder &MIRBuilder,
+                            SPIRVGlobalRegistry *GR) {
+  // Lookup the extended instruction number in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  uint32_t Number =
+      SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
+
+  // Build extended instruction.
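+  // OpExtInst operands: result, result type, the extended instruction set
+  // (printed via its OpExtInstImport <id>), the instruction number within the
+  // set, and finally the call arguments.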
+  auto MIB =
+      MIRBuilder.buildInstr(SPIRV::OpExtInst)
+          .addDef(Call->ReturnRegister)
+          .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+          .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
+          .addImm(Number);
+
+  for (auto Argument : Call->Arguments)
+    MIB.addUse(Argument);
+  return true;
+}
+
+static bool generateRelationalInst(const SPIRV::IncomingCall *Call,
+                                   MachineIRBuilder &MIRBuilder,
+                                   SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+  Register CompareRegister;
+  SPIRVType *RelationType;
+  std::tie(CompareRegister, RelationType) =
+      buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
+
+  // Build relational instruction.
+  auto MIB = MIRBuilder.buildInstr(Opcode)
+                 .addDef(CompareRegister)
+                 .addUse(GR->getSPIRVTypeID(RelationType));
+
+  for (auto Argument : Call->Arguments)
+    MIB.addUse(Argument);
+
+  // Build select instruction.
+  return buildSelectInst(MIRBuilder, Call->ReturnRegister, CompareRegister,
+                         Call->ReturnType, GR);
+}
+
+static bool generateGroupInst(const SPIRV::IncomingCall *Call,
+                              MachineIRBuilder &MIRBuilder,
+                              SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  const SPIRV::GroupBuiltin *GroupBuiltin =
+      SPIRV::lookupGroupBuiltin(Builtin->Name);
+  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register Arg0;
+  if (GroupBuiltin->HasBoolArg) {
+    Register ConstRegister = Call->Arguments[0];
+    auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
+    // TODO: support non-constant bool values.
+    assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
+           "Only constant bool value args are supported");
+    if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
+        SPIRV::OpTypeBool)
+      Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
+                                  GR->getOrCreateSPIRVBoolType(MIRBuilder));
+  }
+
+  Register GroupResultRegister = Call->ReturnRegister;
+  SPIRVType *GroupResultType = Call->ReturnType;
+
+  // TODO: maybe we need to check whether the result type is already boolean
+  // and in this case do not insert select instruction.
+  const bool HasBoolReturnTy =
+      GroupBuiltin->IsElect || GroupBuiltin->IsAllOrAny ||
+      GroupBuiltin->IsAllEqual || GroupBuiltin->IsLogical ||
+      GroupBuiltin->IsInverseBallot || GroupBuiltin->IsBallotBitExtract;
+
+  if (HasBoolReturnTy)
+    std::tie(GroupResultRegister, GroupResultType) =
+        buildBoolRegister(MIRBuilder, Call->ReturnType, GR);
+
+  auto Scope = Builtin->Name.startswith("sub_group") ? SPIRV::Scope::Subgroup
+                                                     : SPIRV::Scope::Workgroup;
+  Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+
+  // Build work/sub group instruction.
+  auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
+                 .addDef(GroupResultRegister)
+                 .addUse(GR->getSPIRVTypeID(GroupResultType))
+                 .addUse(ScopeRegister);
+
+  if (!GroupBuiltin->NoGroupOperation)
+    MIB.addImm(GroupBuiltin->GroupOperation);
+  if (Call->Arguments.size() > 0) {
+    MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
+    for (unsigned i = 1; i < Call->Arguments.size(); i++)
+      MIB.addUse(Call->Arguments[i]);
+  }
+
+  // Build select instruction.
+  if (HasBoolReturnTy)
+    buildSelectInst(MIRBuilder, Call->ReturnRegister, GroupResultRegister,
+                    Call->ReturnType, GR);
+  return true;
+}
+
+// These queries ask for a single size_t result for a given dimension index,
+// e.g. size_t get_global_id(uint dimindex).
+// In SPIR-V, the builtins corresponding to these values are all vec3 types,
+// so we need to extract the correct index or return defaultVal (0 or 1
+// depending on the query). We also handle extending or truncating in case
+// size_t does not match the expected result type's bitwidth.
+//
+// For a constant index >= 3 we generate:
+// %res = OpConstant %SizeT 0
+//
+// For other indices we generate:
+// %g = OpVariable %ptr_V3_SizeT Input
+// OpDecorate %g BuiltIn XXX
+// OpDecorate %g LinkageAttributes "__spirv_BuiltInXXX"
+// OpDecorate %g Constant
+// %loadedVec = OpLoad %V3_SizeT %g
+//
+// Then, if the index is constant < 3, we generate:
+// %res = OpCompositeExtract %SizeT %loadedVec idx
+// If the index is dynamic, we generate:
+// %tmp = OpVectorExtractDynamic %SizeT %loadedVec %idx
+// %cmp = OpULessThan %bool %idx %const_3
+// %res = OpSelect %SizeT %cmp %tmp %const_0
+//
+// If the bitwidth of %res does not match the expected return type, we add an
+// extend or truncate.
+static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
+                              MachineIRBuilder &MIRBuilder,
+                              SPIRVGlobalRegistry *GR,
+                              SPIRV::BuiltIn::BuiltIn BuiltinValue,
+                              uint64_t DefaultValue) {
+  Register IndexRegister = Call->Arguments[0];
+  const unsigned ResultWidth = Call->ReturnType->getOperand(1).getImm();
+  const unsigned PointerSize = GR->getPointerSize();
+  const SPIRVType *PointerSizeType =
+      GR->getOrCreateSPIRVIntegerType(PointerSize, MIRBuilder);
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  auto IndexInstruction = getDefInstrMaybeConstant(IndexRegister, MRI);
+
+  // Set up the final register to do truncation or extension on at the end.
+  Register ToTruncate = Call->ReturnRegister;
+
+  // If the index is constant, we can statically determine if it is in range.
+  bool IsConstantIndex =
+      IndexInstruction->getOpcode() == TargetOpcode::G_CONSTANT;
+
+  // If it's out of range (max dimension is 3), we can just return the constant
+  // default value (0 or 1 depending on which query function).
+  if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
+    Register defaultReg = Call->ReturnRegister;
+    if (PointerSize != ResultWidth) {
+      defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg,
+                                MIRBuilder.getMF());
+      ToTruncate = defaultReg;
+    }
+    auto NewRegister =
+        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+    MIRBuilder.buildCopy(defaultReg, NewRegister);
+  } else { // If it could be in range, we need to load from the given builtin.
+    auto Vec3Ty =
+        GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
+    Register LoadedVector =
+        buildBuiltinVariableLoad(MIRBuilder, Vec3Ty, GR, BuiltinValue,
+                                 LLT::fixed_vector(3, PointerSize));
+    // Set up the vreg to extract the result to (possibly a new temporary one).
+    Register Extracted = Call->ReturnRegister;
+    if (!IsConstantIndex || PointerSize != ResultWidth) {
+      Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
+    }
+    // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
+    // handled later: extr = spv_extractelt LoadedVector, IndexRegister.
+    MachineInstrBuilder ExtractInst = MIRBuilder.buildIntrinsic(
+        Intrinsic::spv_extractelt, ArrayRef<Register>{Extracted}, true);
+    ExtractInst.addUse(LoadedVector).addUse(IndexRegister);
+
+    // If the index is dynamic, we need to check whether it is < 3, and then
+    // use a select.
+    if (!IsConstantIndex) {
+      insertAssignInstr(Extracted, nullptr, PointerSizeType, GR, MIRBuilder,
+                        *MRI);
+
+      auto IndexType = GR->getSPIRVTypeForVReg(IndexRegister);
+      auto BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+
+      Register CompareRegister =
+          MRI->createGenericVirtualRegister(LLT::scalar(1));
+      GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
+
+      // Use G_ICMP to check if idxVReg < 3.
+      MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
+                           GR->buildConstantInt(3, MIRBuilder, IndexType));
+
+      // Get constant for the default value (0 or 1 depending on which
+      // function).
+      Register DefaultRegister =
+          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+
+      // Get a register for the selection result (possibly a new temporary one).
+      Register SelectionResult = Call->ReturnRegister;
+      if (PointerSize != ResultWidth) {
+        SelectionResult =
+            MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+        GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
+                                  MIRBuilder.getMF());
+      }
+      // Create the final G_SELECT to return the extracted value or the default.
+      MIRBuilder.buildSelect(SelectionResult, CompareRegister, Extracted,
+                             DefaultRegister);
+      ToTruncate = SelectionResult;
+    } else {
+      ToTruncate = Extracted;
+    }
+  }
+  // Alter the result's bitwidth if it does not match the SizeT value extracted.
+  if (PointerSize != ResultWidth)
+    MIRBuilder.buildZExtOrTrunc(Call->ReturnRegister, ToTruncate);
+  return true;
+}
+
+static bool generateBuiltinVar(const SPIRV::IncomingCall *Call,
+                               MachineIRBuilder &MIRBuilder,
+                               SPIRVGlobalRegistry *GR) {
+  // Lookup the builtin variable record.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  SPIRV::BuiltIn::BuiltIn Value =
+      SPIRV::lookupGetBuiltin(Builtin->Name, Builtin->Set)->Value;
+
+  if (Value == SPIRV::BuiltIn::GlobalInvocationId)
+    return genWorkgroupQuery(Call, MIRBuilder, GR, Value, 0);
+
+  // Build a load instruction for the builtin variable.
+  unsigned BitWidth = GR->getScalarOrVectorBitWidth(Call->ReturnType);
+  LLT LLType;
+  if (Call->ReturnType->getOpcode() == SPIRV::OpTypeVector)
+    LLType =
+        LLT::fixed_vector(Call->ReturnType->getOperand(2).getImm(), BitWidth);
+  else
+    LLType = LLT::scalar(BitWidth);
+
+  return buildBuiltinVariableLoad(MIRBuilder, Call->ReturnType, GR, Value,
+                                  LLType, Call->ReturnRegister);
+}
+
+static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
+                               MachineIRBuilder &MIRBuilder,
+                               SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+  switch (Opcode) {
+  case SPIRV::OpAtomicLoad:
+    return buildAtomicLoadInst(Call, MIRBuilder, GR);
+  case SPIRV::OpAtomicStore:
+    return buildAtomicStoreInst(Call, MIRBuilder, GR);
+  case SPIRV::OpAtomicCompareExchange:
+  case SPIRV::OpAtomicCompareExchangeWeak:
+    return buildAtomicCompareExchangeInst(Call, MIRBuilder, GR);
+  case SPIRV::OpAtomicIAdd:
+  case SPIRV::OpAtomicISub:
+  case SPIRV::OpAtomicOr:
+  case SPIRV::OpAtomicXor:
+  case SPIRV::OpAtomicAnd:
+    return buildAtomicRMWInst(Call, Opcode, MIRBuilder, GR);
+  case SPIRV::OpMemoryBarrier:
+    return buildBarrierInst(Call, SPIRV::OpMemoryBarrier, MIRBuilder, GR);
+  default:
+    return false;
+  }
+}
+
+static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
+                                MachineIRBuilder &MIRBuilder,
+                                SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+  return buildBarrierInst(Call, Opcode, MIRBuilder, GR);
+}
+
+static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  unsigned Opcode = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode();
+  bool IsVec = Opcode == SPIRV::OpTypeVector;
+  // Use OpDot only in case of vector args and OpFMul in case of scalar args.
+  MIRBuilder.buildInstr(IsVec ? SPIRV::OpDot : SPIRV::OpFMulS)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(Call->Arguments[0])
+      .addUse(Call->Arguments[1]);
+  return true;
+}
+
+static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
+                                 MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  // Lookup the builtin record.
+  SPIRV::BuiltIn::BuiltIn Value =
+      SPIRV::lookupGetBuiltin(Call->Builtin->Name, Call->Builtin->Set)->Value;
+  uint64_t IsDefault = (Value == SPIRV::BuiltIn::GlobalSize ||
+                        Value == SPIRV::BuiltIn::WorkgroupSize ||
+                        Value == SPIRV::BuiltIn::EnqueuedWorkgroupSize);
+  return genWorkgroupQuery(Call, MIRBuilder, GR, Value, IsDefault ? 1 : 0);
+}
+
+static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  // Lookup the image size query component number in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  uint32_t Component =
+      SPIRV::lookupImageQueryBuiltin(Builtin->Name, Builtin->Set)->Component;
+  // Query result may either be a vector or a scalar. If return type is not a
+  // vector, expect only a single size component. Otherwise get the number of
+  // expected components.
+  SPIRVType *RetTy = Call->ReturnType;
+  unsigned NumExpectedRetComponents = RetTy->getOpcode() == SPIRV::OpTypeVector
+                                          ? RetTy->getOperand(2).getImm()
+                                          : 1;
+  // Get the actual number of query result/size components.
+  SPIRVType *ImgType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+  unsigned NumActualRetComponents = getNumSizeComponents(ImgType);
+  Register QueryResult = Call->ReturnRegister;
+  SPIRVType *QueryResultType = Call->ReturnType;
+  if (NumExpectedRetComponents != NumActualRetComponents) {
+    QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
+        LLT::fixed_vector(NumActualRetComponents, 32));
+    SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
+    QueryResultType = GR->getOrCreateSPIRVVectorType(
+        IntTy, NumActualRetComponents, MIRBuilder);
+    GR->assignSPIRVTypeToVReg(QueryResultType, QueryResult, MIRBuilder.getMF());
+  }
+  bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
+  unsigned Opcode =
+      IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
+  auto MIB = MIRBuilder.buildInstr(Opcode)
+                 .addDef(QueryResult)
+                 .addUse(GR->getSPIRVTypeID(QueryResultType))
+                 .addUse(Call->Arguments[0]);
+  if (!IsDimBuf)
+    MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Lod id.
+  if (NumExpectedRetComponents == NumActualRetComponents)
+    return true;
+  if (NumExpectedRetComponents == 1) {
+    // Only 1 component is expected, build OpCompositeExtract instruction.
+    unsigned ExtractedComposite =
+        Component == 3 ? NumActualRetComponents - 1 : Component;
+    assert(ExtractedComposite < NumActualRetComponents &&
+           "Invalid composite index!");
+    MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(QueryResult)
+        .addImm(ExtractedComposite);
+  } else {
+    // More than 1 component is expected, fill a new vector.
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle)
+                   .addDef(Call->ReturnRegister)
+                   .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+                   .addUse(QueryResult)
+                   .addUse(QueryResult);
+    for (unsigned i = 0; i < NumExpectedRetComponents; ++i)
+      MIB.addImm(i < NumActualRetComponents ? i : 0xffffffff);
+  }
+  return true;
+}
+
+static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
+                                       MachineIRBuilder &MIRBuilder,
+                                       SPIRVGlobalRegistry *GR) {
+  // TODO: Add support for other image query builtins.
+  Register Image = Call->Arguments[0];
+
+  assert(Call->ReturnType->getOpcode() == SPIRV::OpTypeInt &&
+         "Image samples query result must be of int type!");
+  assert(GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm() == 1 &&
+         "Image must be of 2D dimensionality");
+  MIRBuilder.buildInstr(SPIRV::OpImageQuerySamples)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(Image);
+  return true;
+}
+
+// TODO: Move to TableGen.
+static SPIRV::SamplerAddressingMode::SamplerAddressingMode
+getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
+  switch (Bitmask & SPIRV::CLK_ADDRESS_MODE_MASK) {
+  case SPIRV::CLK_ADDRESS_CLAMP:
+    return SPIRV::SamplerAddressingMode::Clamp;
+  case SPIRV::CLK_ADDRESS_CLAMP_TO_EDGE:
+    return SPIRV::SamplerAddressingMode::ClampToEdge;
+  case SPIRV::CLK_ADDRESS_REPEAT:
+    return SPIRV::SamplerAddressingMode::Repeat;
+  case SPIRV::CLK_ADDRESS_MIRRORED_REPEAT:
+    return SPIRV::SamplerAddressingMode::RepeatMirrored;
+  case SPIRV::CLK_ADDRESS_NONE:
+    return SPIRV::SamplerAddressingMode::None;
+  default:
+    llvm_unreachable("Unknown CL address mode");
+  }
+}
+
+static unsigned getSamplerParamFromBitmask(unsigned Bitmask) {
+  return (Bitmask & SPIRV::CLK_NORMALIZED_COORDS_TRUE) ? 1 : 0;
+}
+
+static SPIRV::SamplerFilterMode::SamplerFilterMode
+getSamplerFilterModeFromBitmask(unsigned Bitmask) {
+  if (Bitmask & SPIRV::CLK_FILTER_LINEAR)
+    return SPIRV::SamplerFilterMode::Linear;
+  if (Bitmask & SPIRV::CLK_FILTER_NEAREST)
+    return SPIRV::SamplerFilterMode::Nearest;
+  return SPIRV::SamplerFilterMode::Nearest;
+}
+
+static bool generateReadImageInst(const StringRef DemangledCall,
+                                  const SPIRV::IncomingCall *Call,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  Register Image = Call->Arguments[0];
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+  if (DemangledCall.contains_insensitive("ocl_sampler")) {
+    Register Sampler = Call->Arguments[1];
+
+    if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
+        getDefInstrMaybeConstant(Sampler, MRI)->getOperand(1).isCImm()) {
+      uint64_t SamplerMask = getIConstVal(Sampler, MRI);
+      Sampler = GR->buildConstantSampler(
+          Register(), getSamplerAddressingModeFromBitmask(SamplerMask),
+          getSamplerParamFromBitmask(SamplerMask),
+          getSamplerFilterModeFromBitmask(SamplerMask), MIRBuilder,
+          GR->getSPIRVTypeForVReg(Sampler));
+    }
+    SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
+    SPIRVType *SampledImageType =
+        GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
+    Register SampledImage = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+
+    MIRBuilder.buildInstr(SPIRV::OpSampledImage)
+        .addDef(SampledImage)
+        .addUse(GR->getSPIRVTypeID(SampledImageType))
+        .addUse(Image)
+        .addUse(Sampler);
+
+    Register Lod = GR->buildConstantFP(APFloat::getZero(APFloat::IEEEsingle()),
+                                       MIRBuilder);
+    SPIRVType *TempType = Call->ReturnType;
+    bool NeedsExtraction = false;
+    if (TempType->getOpcode() != SPIRV::OpTypeVector) {
+      TempType =
+          GR->getOrCreateSPIRVVectorType(Call->ReturnType, 4, MIRBuilder);
+      NeedsExtraction = true;
+    }
+    LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
+    Register TempRegister = MRI->createGenericVirtualRegister(LLType);
+    GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
+
+    MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
+        .addDef(NeedsExtraction ? TempRegister : Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(TempType))
+        .addUse(SampledImage)
+        .addUse(Call->Arguments[2]) // Coordinate.
+        .addImm(SPIRV::ImageOperand::Lod)
+        .addUse(Lod);
+
+    if (NeedsExtraction)
+      MIRBuilder.buildInstr(SPIRV::OpCompositeExtract)
+          .addDef(Call->ReturnRegister)
+          .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+          .addUse(TempRegister)
+          .addImm(0);
+  } else if (DemangledCall.contains_insensitive("msaa")) {
+    MIRBuilder.buildInstr(SPIRV::OpImageRead)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(Image)
+        .addUse(Call->Arguments[1]) // Coordinate.
+        .addImm(SPIRV::ImageOperand::Sample)
+        .addUse(Call->Arguments[2]);
+  } else {
+    MIRBuilder.buildInstr(SPIRV::OpImageRead)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(Image)
+        .addUse(Call->Arguments[1]); // Coordinate.
+  }
+  return true;
+}
+
+static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
+                                   MachineIRBuilder &MIRBuilder,
+                                   SPIRVGlobalRegistry *GR) {
+  MIRBuilder.buildInstr(SPIRV::OpImageWrite)
+      .addUse(Call->Arguments[0])  // Image.
+      .addUse(Call->Arguments[1])  // Coordinate.
+      .addUse(Call->Arguments[2]); // Texel.
+  return true;
+}
+
+static bool generateSampleImageInst(const StringRef DemangledCall,
+                                    const SPIRV::IncomingCall *Call,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRVGlobalRegistry *GR) {
+  if (Call->Builtin->Name.contains_insensitive(
+          "__translate_sampler_initializer")) {
+    // Build sampler literal.
+    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MIRBuilder.getMRI());
+    Register Sampler = GR->buildConstantSampler(
+        Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
+        getSamplerParamFromBitmask(Bitmask),
+        getSamplerFilterModeFromBitmask(Bitmask), MIRBuilder, Call->ReturnType);
+    return Sampler.isValid();
+  } else if (Call->Builtin->Name.contains_insensitive("__spirv_SampledImage")) {
+    // Create OpSampledImage.
+    Register Image = Call->Arguments[0];
+    SPIRVType *ImageType = GR->getSPIRVTypeForVReg(Image);
+    SPIRVType *SampledImageType =
+        GR->getOrCreateOpTypeSampledImage(ImageType, MIRBuilder);
+    Register SampledImage =
+        Call->ReturnRegister.isValid()
+            ? Call->ReturnRegister
+            : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+    MIRBuilder.buildInstr(SPIRV::OpSampledImage)
+        .addDef(SampledImage)
+        .addUse(GR->getSPIRVTypeID(SampledImageType))
+        .addUse(Image)
+        .addUse(Call->Arguments[1]); // Sampler.
+    return true;
+  } else if (Call->Builtin->Name.contains_insensitive(
+                 "__spirv_ImageSampleExplicitLod")) {
+    // Sample an image using an explicit level of detail.
+    std::string ReturnType = DemangledCall.str();
+    if (DemangledCall.contains("_R")) {
+      ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
+      ReturnType = ReturnType.substr(0, ReturnType.find('('));
+    }
+    SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Type))
+        .addUse(Call->Arguments[0]) // Image.
+        .addUse(Call->Arguments[1]) // Coordinate.
+        .addImm(SPIRV::ImageOperand::Lod)
+        .addUse(Call->Arguments[3]);
+    return true;
+  }
+  return false;
+}
+
+static bool generateSelectInst(const SPIRV::IncomingCall *Call,
+                               MachineIRBuilder &MIRBuilder) {
+  MIRBuilder.buildSelect(Call->ReturnRegister, Call->Arguments[0],
+                         Call->Arguments[1], Call->Arguments[2]);
+  return true;
+}
+
+static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
+                                     MachineIRBuilder &MIRBuilder,
+                                     SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+
+  switch (Opcode) {
+  case SPIRV::OpSpecConstant: {
+    // Build the SpecID decoration.
+    unsigned SpecId =
+        static_cast<unsigned>(getIConstVal(Call->Arguments[0], MRI));
+    buildOpDecorate(Call->ReturnRegister, MIRBuilder, SPIRV::Decoration::SpecId,
+                    {SpecId});
+    // Determine the constant MI.
+    Register ConstRegister = Call->Arguments[1];
+    const MachineInstr *Const = getDefInstrMaybeConstant(ConstRegister, MRI);
+    assert(Const &&
+           (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
+            Const->getOpcode() == TargetOpcode::G_FCONSTANT) &&
+           "Argument should be either an int or floating-point constant");
+    // Determine the opcode and build the OpSpecConstant MI.
+    const MachineOperand &ConstOperand = Const->getOperand(1);
+    if (Call->ReturnType->getOpcode() == SPIRV::OpTypeBool) {
+      assert(ConstOperand.isCImm() && "Int constant operand is expected");
+      Opcode = ConstOperand.getCImm()->getValue().getZExtValue()
+                   ? SPIRV::OpSpecConstantTrue
+                   : SPIRV::OpSpecConstantFalse;
+    }
+    auto MIB = MIRBuilder.buildInstr(Opcode)
+                   .addDef(Call->ReturnRegister)
+                   .addUse(GR->getSPIRVTypeID(Call->ReturnType));
+
+    if (Call->ReturnType->getOpcode() != SPIRV::OpTypeBool) {
+      if (Const->getOpcode() == TargetOpcode::G_CONSTANT)
+        addNumImm(ConstOperand.getCImm()->getValue(), MIB);
+      else
+        addNumImm(ConstOperand.getFPImm()->getValueAPF().bitcastToAPInt(), MIB);
+    }
+    return true;
+  }
+  case SPIRV::OpSpecConstantComposite: {
+    auto MIB = MIRBuilder.buildInstr(Opcode)
+                   .addDef(Call->ReturnRegister)
+                   .addUse(GR->getSPIRVTypeID(Call->ReturnType));
+    for (unsigned i = 0; i < Call->Arguments.size(); i++)
+      MIB.addUse(Call->Arguments[i]);
+    return true;
+  }
+  default:
+    return false;
+  }
+}
+
+static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
+                                MachineIRBuilder &MIRBuilder,
+                                SPIRVGlobalRegistry *GR) {
+  // Lookup the instruction opcode in the TableGen records.
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+
+  switch (Opcode) {
+  case SPIRV::OpRetainEvent:
+  case SPIRV::OpReleaseEvent:
+    return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
+  case SPIRV::OpCreateUserEvent:
+  case SPIRV::OpGetDefaultQueue:
+    return MIRBuilder.buildInstr(Opcode)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType));
+  case SPIRV::OpIsValidEvent:
+    return MIRBuilder.buildInstr(Opcode)
+        .addDef(Call->ReturnRegister)
+        .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+        .addUse(Call->Arguments[0]);
+  case SPIRV::OpSetUserEventStatus:
+    return MIRBuilder.buildInstr(Opcode)
+        .addUse(Call->Arguments[0])
+        .addUse(Call->Arguments[1]);
+  case SPIRV::OpCaptureEventProfilingInfo:
+    return MIRBuilder.buildInstr(Opcode)
+        .addUse(Call->Arguments[0])
+        .addUse(Call->Arguments[1])
+        .addUse(Call->Arguments[2]);
+  case SPIRV::OpBuildNDRange: {
+    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+    SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+    assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
+           PtrType->getOperand(2).isReg());
+    Register TypeReg = PtrType->getOperand(2).getReg();
+    SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
+    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF());
+    // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
+    // three other arguments, so pass zero constant on absence.
+    unsigned NumArgs = Call->Arguments.size();
+    assert(NumArgs >= 2);
+    Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
+    Register LocalWorkSize =
+        NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
+    Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
Register(0) : Call->Arguments[1]; + if (NumArgs < 4) { + Register Const; + SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize); + if (SpvTy->getOpcode() == SPIRV::OpTypePointer) { + MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize); + assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) && + DefInstr->getOperand(3).isReg()); + Register GWSPtr = DefInstr->getOperand(3).getReg(); + // TODO: Maybe simplify generation of the type of the fields. + unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2; + unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32; + Type *BaseTy = IntegerType::get( + MIRBuilder.getMF().getFunction().getContext(), BitWidth); + Type *FieldTy = ArrayType::get(BaseTy, Size); + SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder); + GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass); + GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, + MIRBuilder.getMF()); + MIRBuilder.buildInstr(SPIRV::OpLoad) + .addDef(GlobalWorkSize) + .addUse(GR->getSPIRVTypeID(SpvFieldTy)) + .addUse(GWSPtr); + Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy); + } else { + Const = GR->buildConstantInt(0, MIRBuilder, SpvTy); + } + if (!LocalWorkSize.isValid()) + LocalWorkSize = Const; + if (!GlobalWorkOffset.isValid()) + GlobalWorkOffset = Const; + } + MIRBuilder.buildInstr(Opcode) + .addDef(TmpReg) + .addUse(TypeReg) + .addUse(GlobalWorkSize) + .addUse(LocalWorkSize) + .addUse(GlobalWorkOffset); + return MIRBuilder.buildInstr(SPIRV::OpStore) + .addUse(Call->Arguments[0]) + .addUse(TmpReg); + } + default: + return false; + } +} + +static bool generateAsyncCopy(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + // Lookup the instruction opcode in the TableGen records. + const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; + unsigned Opcode = + SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; + auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR); + + switch (Opcode) { + case SPIRV::OpGroupAsyncCopy: + return MIRBuilder.buildInstr(Opcode) + .addDef(Call->ReturnRegister) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(Scope) + .addUse(Call->Arguments[0]) + .addUse(Call->Arguments[1]) + .addUse(Call->Arguments[2]) + .addUse(buildConstantIntReg(1, MIRBuilder, GR)) + .addUse(Call->Arguments[3]); + case SPIRV::OpGroupWaitEvents: + return MIRBuilder.buildInstr(Opcode) + .addUse(Scope) + .addUse(Call->Arguments[0]) + .addUse(Call->Arguments[1]); + default: + return false; + } +} + +static bool generateConvertInst(const StringRef DemangledCall, + const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + // Lookup the conversion builtin in the TableGen records. + const SPIRV::ConvertBuiltin *Builtin = + SPIRV::lookupConvertBuiltin(Call->Builtin->Name, Call->Builtin->Set); + + if (Builtin->IsSaturated) + buildOpDecorate(Call->ReturnRegister, MIRBuilder, + SPIRV::Decoration::SaturatedConversion, {}); + if (Builtin->IsRounded) + buildOpDecorate(Call->ReturnRegister, MIRBuilder, + SPIRV::Decoration::FPRoundingMode, {Builtin->RoundingMode}); + + unsigned Opcode = SPIRV::OpNop; + if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) { + // Int -> ... + if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) { + // Int -> Int + if (Builtin->IsSaturated) + Opcode = Builtin->IsDestinationSigned ? 
SPIRV::OpSatConvertUToS
+                                              : SPIRV::OpSatConvertSToU;
+      else
+        Opcode = Builtin->IsDestinationSigned ? SPIRV::OpUConvert
+                                              : SPIRV::OpSConvert;
+    } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
+                                          SPIRV::OpTypeFloat)) {
+      // Int -> Float
+      bool IsSourceSigned =
+          DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
+      Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+    }
+  } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
+                                        SPIRV::OpTypeFloat)) {
+    // Float -> ...
+    if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
+      // Float -> Int
+      Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
+                                            : SPIRV::OpConvertFToU;
+    else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
+                                        SPIRV::OpTypeFloat))
+      // Float -> Float
+      Opcode = SPIRV::OpFConvert;
+  }
+
+  assert(Opcode != SPIRV::OpNop &&
+         "Conversion between the types not implemented!");
+
+  MIRBuilder.buildInstr(Opcode)
+      .addDef(Call->ReturnRegister)
+      .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+      .addUse(Call->Arguments[0]);
+  return true;
+}
+
+static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
+                                        MachineIRBuilder &MIRBuilder,
+                                        SPIRVGlobalRegistry *GR) {
+  // Lookup the vector load/store builtin in the TableGen records.
+  const SPIRV::VectorLoadStoreBuiltin *Builtin =
+      SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
+                                          Call->Builtin->Set);
+  // Build extended instruction.
+  auto MIB =
+      MIRBuilder.buildInstr(SPIRV::OpExtInst)
+          .addDef(Call->ReturnRegister)
+          .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+          .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
+          .addImm(Builtin->Number);
+  for (auto Argument : Call->Arguments)
+    MIB.addUse(Argument);
+
+  // The rounding mode should be passed as the last argument in the MI for
+  // builtins like "vstorea_halfn_r".
+  if (Builtin->IsRounded)
+    MIB.addImm(static_cast<uint32_t>(Builtin->RoundingMode));
+  return true;
+}
+
+/// Lowers a builtin function call using the provided \p DemangledCall skeleton
+/// and external instruction \p Set.
+namespace SPIRV {
+std::pair<bool, bool>
+lowerBuiltin(const StringRef DemangledCall, InstructionSet::InstructionSet Set,
+             MachineIRBuilder &MIRBuilder, const Register OrigRet,
+             const Type *OrigRetTy, const SmallVectorImpl<Register> &Args,
+             SPIRVGlobalRegistry *GR) {
+  LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
+
+  // SPIR-V type and return register.
+  Register ReturnRegister = OrigRet;
+  SPIRVType *ReturnType = nullptr;
+  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
+    ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
+  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
+    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
+    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
+    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
+  }
+
+  // Lookup the builtin in the TableGen records.
+  std::unique_ptr<const IncomingCall> Call =
+      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+
+  if (!Call) {
+    LLVM_DEBUG(dbgs() << "Builtin record was not found!");
+    return {false, false};
+  }
+
+  // TODO: check if the provided args meet the builtin requirements.
+  assert(Args.size() >= Call->Builtin->MinNumArgs &&
+         "Too few arguments to generate the builtin");
+  if (Call->Builtin->MaxNumArgs && Args.size() > Call->Builtin->MaxNumArgs)
+    LLVM_DEBUG(dbgs() << "More arguments provided than required!");
+
+  // Match the builtin with implementation based on the grouping.
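+  // For example (illustrative): "fmax" belongs to the Extended group and is
+  // lowered to an OpExtInst selecting instruction 27 of the OpenCL.std set,
+  // while "isnan" belongs to the Relational group and maps directly onto the
+  // native OpIsNan instruction.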
+ switch (Call->Builtin->Group) { + case SPIRV::Extended: + return {true, generateExtInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Relational: + return {true, generateRelationalInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Group: + return {true, generateGroupInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Variable: + return {true, generateBuiltinVar(Call.get(), MIRBuilder, GR)}; + case SPIRV::Atomic: + return {true, generateAtomicInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Barrier: + return {true, generateBarrierInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Dot: + return {true, generateDotOrFMulInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::GetQuery: + return {true, generateGetQueryInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::ImageSizeQuery: + return {true, generateImageSizeQueryInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::ImageMiscQuery: + return {true, generateImageMiscQueryInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::ReadImage: + return {true, + generateReadImageInst(DemangledCall, Call.get(), MIRBuilder, GR)}; + case SPIRV::WriteImage: + return {true, generateWriteImageInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::SampleImage: + return {true, + generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR)}; + case SPIRV::Select: + return {true, generateSelectInst(Call.get(), MIRBuilder)}; + case SPIRV::SpecConstant: + return {true, generateSpecConstantInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::Enqueue: + return {true, generateEnqueueInst(Call.get(), MIRBuilder, GR)}; + case SPIRV::AsyncCopy: + return {true, generateAsyncCopy(Call.get(), MIRBuilder, GR)}; + case SPIRV::Convert: + return {true, + generateConvertInst(DemangledCall, Call.get(), MIRBuilder, GR)}; + case SPIRV::VectorLoadStore: + return {true, generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR)}; + } + return {true, false}; +} +} // namespace SPIRV +} // namespace llvm Index: llvm/lib/Target/SPIRV/SPIRVBuiltins.td =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -0,0 +1,1266 @@ +//===-- SPIRVBuiltins.td - Describe SPIRV Builtins ---------*- tablegen -*-===// + // + // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // TableGen records defining implementation details of demangled builtin + // functions and types. 
+//
+//===----------------------------------------------------------------------===//
+
+// Define SPIR-V external builtin/instruction sets
+def InstructionSet : GenericEnum {
+  let FilterClass = "InstructionSet";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class InstructionSet<bits<32> value> {
+  string Name = NAME;
+  bits<32> Value = value;
+}
+
+def OpenCL_std : InstructionSet<0>;
+def GLSL_std_450 : InstructionSet<1>;
+def SPV_AMD_shader_trinary_minmax : InstructionSet<2>;
+
+// Define various builtin groups
+def BuiltinGroup : GenericEnum {
+  let FilterClass = "BuiltinGroup";
+}
+
+class BuiltinGroup;
+
+def Extended : BuiltinGroup;
+def Relational : BuiltinGroup;
+def Group : BuiltinGroup;
+def Variable : BuiltinGroup;
+def Atomic : BuiltinGroup;
+def Barrier : BuiltinGroup;
+def Dot : BuiltinGroup;
+def GetQuery : BuiltinGroup;
+def ImageSizeQuery : BuiltinGroup;
+def ImageMiscQuery : BuiltinGroup;
+def Convert : BuiltinGroup;
+def ReadImage : BuiltinGroup;
+def WriteImage : BuiltinGroup;
+def SampleImage : BuiltinGroup;
+def Select : BuiltinGroup;
+def SpecConstant : BuiltinGroup;
+def Enqueue : BuiltinGroup;
+def AsyncCopy : BuiltinGroup;
+def VectorLoadStore : BuiltinGroup;
+
+//===----------------------------------------------------------------------===//
+// Class defining a demangled builtin record. The information in the record
+// should be used to expand the builtin into either native SPIR-V instructions
+// or an external call (in case of builtins without a direct mapping).
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+// group specifies to which implementation group the given record belongs.
+// minNumArgs is the minimum required number of arguments for lowering.
+// maxNumArgs specifies the maximum used number of arguments for lowering.
+//===----------------------------------------------------------------------===//
+class DemangledBuiltin<string name, InstructionSet set, BuiltinGroup group,
+                       bits<8> minNumArgs, bits<8> maxNumArgs> {
+  string Name = name;
+  InstructionSet Set = set;
+  BuiltinGroup Group = group;
+  bits<8> MinNumArgs = minNumArgs;
+  bits<8> MaxNumArgs = maxNumArgs;
+}
+
+// Table gathering all the builtins.
+def DemangledBuiltins : GenericTable {
+  let FilterClass = "DemangledBuiltin";
+  let Fields = ["Name", "Set", "Group", "MinNumArgs", "MaxNumArgs"];
+  string TypeOf_Set = "InstructionSet";
+  string TypeOf_Group = "BuiltinGroup";
+}
+
+// Function to lookup builtins by their demangled name and set.
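+// The SearchableTables backend is expected to generate a lookup helper in
+// SPIRVGenTables.inc roughly of the form (sketch, the exact shape is up to
+// the backend):
+//   const DemangledBuiltin *lookupBuiltin(StringRef Name,
+//                                         InstructionSet::InstructionSet Set);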
+def lookupBuiltin : SearchIndex {
+  let Table = DemangledBuiltins;
+  let Key = ["Name", "Set"];
+}
+
+// Dot builtin record:
+def : DemangledBuiltin<"dot", OpenCL_std, Dot, 2, 2>;
+
+// Image builtin records:
+def : DemangledBuiltin<"read_imagei", OpenCL_std, ReadImage, 2, 4>;
+def : DemangledBuiltin<"read_imageui", OpenCL_std, ReadImage, 2, 4>;
+def : DemangledBuiltin<"read_imagef", OpenCL_std, ReadImage, 2, 4>;
+
+def : DemangledBuiltin<"write_imagef", OpenCL_std, WriteImage, 3, 4>;
+def : DemangledBuiltin<"write_imagei", OpenCL_std, WriteImage, 3, 4>;
+def : DemangledBuiltin<"write_imageui", OpenCL_std, WriteImage, 3, 4>;
+def : DemangledBuiltin<"write_imageh", OpenCL_std, WriteImage, 3, 4>;
+
+def : DemangledBuiltin<"__translate_sampler_initializer", OpenCL_std, SampleImage, 1, 1>;
+def : DemangledBuiltin<"__spirv_SampledImage", OpenCL_std, SampleImage, 2, 2>;
+def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage, 3, 4>;
+
+// Select builtin record:
+def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;
+
+//===----------------------------------------------------------------------===//
+// Class defining an extended builtin record used for lowering into an
+// OpExtInst instruction.
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+// number specifies the number of the instruction in the external set.
+//===----------------------------------------------------------------------===//
+class ExtendedBuiltin<string name, InstructionSet set, bits<32> number> {
+  string Name = name;
+  InstructionSet Set = set;
+  bits<32> Number = number;
+}
+
+// Table gathering all the extended builtins.
+def ExtendedBuiltins : GenericTable {
+  let FilterClass = "ExtendedBuiltin";
+  let Fields = ["Name", "Set", "Number"];
+  string TypeOf_Set = "InstructionSet";
+}
+
+// Function to lookup extended builtins by their name and set.
+def lookupExtendedBuiltin : SearchIndex {
+  let Table = ExtendedBuiltins;
+  let Key = ["Name", "Set"];
+}
+
+// Function to lookup extended builtins by their set and number.
+def lookupExtendedBuiltinBySetAndNumber : SearchIndex {
+  let Table = ExtendedBuiltins;
+  let Key = ["Set", "Number"];
+}
+
+// OpenCL extended instruction enums
+def OpenCLExtInst : GenericEnum {
+  let FilterClass = "OpenCLExtInst";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class OpenCLExtInst<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+// GLSL extended instruction enums
+def GLSLExtInst : GenericEnum {
+  let FilterClass = "GLSLExtInst";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class GLSLExtInst<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding extended builtin record.
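+// For example (illustrative), the single line
+//   defm : DemangledExtendedBuiltin<"acos", OpenCL_std, 0>;
+// expands to a DemangledBuiltin record, an ExtendedBuiltin record and an
+// OpenCLExtInst enum entry, keeping all three tables in sync from one place.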
+multiclass DemangledExtendedBuiltin<string name, InstructionSet set, bits<32> number> {
+  def : DemangledBuiltin<name, set, Extended, 1, 3>;
+  def : ExtendedBuiltin<name, set, number>;
+
+  if !eq(set, OpenCL_std) then {
+    def : OpenCLExtInst<name, number>;
+  }
+
+  if !eq(set, GLSL_std_450) then {
+    def : GLSLExtInst<name, number>;
+  }
+}
+
+// Extended builtin records:
+defm : DemangledExtendedBuiltin<"acos", OpenCL_std, 0>;
+defm : DemangledExtendedBuiltin<"acosh", OpenCL_std, 1>;
+defm : DemangledExtendedBuiltin<"acospi", OpenCL_std, 2>;
+defm : DemangledExtendedBuiltin<"asin", OpenCL_std, 3>;
+defm : DemangledExtendedBuiltin<"asinh", OpenCL_std, 4>;
+defm : DemangledExtendedBuiltin<"asinpi", OpenCL_std, 5>;
+defm : DemangledExtendedBuiltin<"atan", OpenCL_std, 6>;
+defm : DemangledExtendedBuiltin<"atan2", OpenCL_std, 7>;
+defm : DemangledExtendedBuiltin<"atanh", OpenCL_std, 8>;
+defm : DemangledExtendedBuiltin<"atanpi", OpenCL_std, 9>;
+defm : DemangledExtendedBuiltin<"atan2pi", OpenCL_std, 10>;
+defm : DemangledExtendedBuiltin<"cbrt", OpenCL_std, 11>;
+defm : DemangledExtendedBuiltin<"ceil", OpenCL_std, 12>;
+defm : DemangledExtendedBuiltin<"copysign", OpenCL_std, 13>;
+defm : DemangledExtendedBuiltin<"cos", OpenCL_std, 14>;
+defm : DemangledExtendedBuiltin<"cosh", OpenCL_std, 15>;
+defm : DemangledExtendedBuiltin<"cospi", OpenCL_std, 16>;
+defm : DemangledExtendedBuiltin<"erfc", OpenCL_std, 17>;
+defm : DemangledExtendedBuiltin<"erf", OpenCL_std, 18>;
+defm : DemangledExtendedBuiltin<"exp", OpenCL_std, 19>;
+defm : DemangledExtendedBuiltin<"exp2", OpenCL_std, 20>;
+defm : DemangledExtendedBuiltin<"exp10", OpenCL_std, 21>;
+defm : DemangledExtendedBuiltin<"expm1", OpenCL_std, 22>;
+defm : DemangledExtendedBuiltin<"fabs", OpenCL_std, 23>;
+defm : DemangledExtendedBuiltin<"fdim", OpenCL_std, 24>;
+defm : DemangledExtendedBuiltin<"floor", OpenCL_std, 25>;
+defm : DemangledExtendedBuiltin<"fma", OpenCL_std, 26>;
+defm : DemangledExtendedBuiltin<"fmax", OpenCL_std, 27>;
+defm : DemangledExtendedBuiltin<"fmin", OpenCL_std, 28>;
+defm : DemangledExtendedBuiltin<"fmod", OpenCL_std, 29>;
+defm : DemangledExtendedBuiltin<"fract", OpenCL_std, 30>;
+defm : DemangledExtendedBuiltin<"frexp", OpenCL_std, 31>;
+defm : DemangledExtendedBuiltin<"hypot", OpenCL_std, 32>;
+defm : DemangledExtendedBuiltin<"ilogb", OpenCL_std, 33>;
+defm : DemangledExtendedBuiltin<"ldexp", OpenCL_std, 34>;
+defm : DemangledExtendedBuiltin<"lgamma", OpenCL_std, 35>;
+defm : DemangledExtendedBuiltin<"lgamma_r", OpenCL_std, 36>;
+defm : DemangledExtendedBuiltin<"log", OpenCL_std, 37>;
+defm : DemangledExtendedBuiltin<"log2", OpenCL_std, 38>;
+defm : DemangledExtendedBuiltin<"log10", OpenCL_std, 39>;
+defm : DemangledExtendedBuiltin<"log1p", OpenCL_std, 40>;
+defm : DemangledExtendedBuiltin<"logb", OpenCL_std, 41>;
+defm : DemangledExtendedBuiltin<"mad", OpenCL_std, 42>;
+defm : DemangledExtendedBuiltin<"maxmag", OpenCL_std, 43>;
+defm : DemangledExtendedBuiltin<"minmag", OpenCL_std, 44>;
+defm : DemangledExtendedBuiltin<"modf", OpenCL_std, 45>;
+defm : DemangledExtendedBuiltin<"nan", OpenCL_std, 46>;
+defm : DemangledExtendedBuiltin<"nextafter", OpenCL_std, 47>;
+defm : DemangledExtendedBuiltin<"pow", OpenCL_std, 48>;
+defm : DemangledExtendedBuiltin<"pown", OpenCL_std, 49>;
+defm : DemangledExtendedBuiltin<"powr", OpenCL_std, 50>;
+defm : DemangledExtendedBuiltin<"remainder", OpenCL_std, 51>;
+defm : DemangledExtendedBuiltin<"remquo", OpenCL_std, 52>;
+defm : DemangledExtendedBuiltin<"rint", OpenCL_std, 53>;
+defm : DemangledExtendedBuiltin<"rootn", OpenCL_std, 54>;
+defm : DemangledExtendedBuiltin<"round", OpenCL_std, 55>;
+defm :
DemangledExtendedBuiltin<"rsqrt", OpenCL_std, 56>; +defm : DemangledExtendedBuiltin<"sin", OpenCL_std, 57>; +defm : DemangledExtendedBuiltin<"sincos", OpenCL_std, 58>; +defm : DemangledExtendedBuiltin<"sinh", OpenCL_std, 59>; +defm : DemangledExtendedBuiltin<"sinpi", OpenCL_std, 60>; +defm : DemangledExtendedBuiltin<"sqrt", OpenCL_std, 61>; +defm : DemangledExtendedBuiltin<"tan", OpenCL_std, 62>; +defm : DemangledExtendedBuiltin<"tanh", OpenCL_std, 63>; +defm : DemangledExtendedBuiltin<"tanpi", OpenCL_std, 64>; +defm : DemangledExtendedBuiltin<"tgamma", OpenCL_std, 65>; +defm : DemangledExtendedBuiltin<"trunc", OpenCL_std, 66>; +defm : DemangledExtendedBuiltin<"half_cos", OpenCL_std, 67>; +defm : DemangledExtendedBuiltin<"half_divide", OpenCL_std, 68>; +defm : DemangledExtendedBuiltin<"half_exp", OpenCL_std, 69>; +defm : DemangledExtendedBuiltin<"half_exp2", OpenCL_std, 70>; +defm : DemangledExtendedBuiltin<"half_exp10", OpenCL_std, 71>; +defm : DemangledExtendedBuiltin<"half_log", OpenCL_std, 72>; +defm : DemangledExtendedBuiltin<"half_log2", OpenCL_std, 73>; +defm : DemangledExtendedBuiltin<"half_log10", OpenCL_std, 74>; +defm : DemangledExtendedBuiltin<"half_powr", OpenCL_std, 75>; +defm : DemangledExtendedBuiltin<"half_recip", OpenCL_std, 76>; +defm : DemangledExtendedBuiltin<"half_rsqrt", OpenCL_std, 77>; +defm : DemangledExtendedBuiltin<"half_sin", OpenCL_std, 78>; +defm : DemangledExtendedBuiltin<"half_sqrt", OpenCL_std, 79>; +defm : DemangledExtendedBuiltin<"half_tan", OpenCL_std, 80>; +defm : DemangledExtendedBuiltin<"native_cos", OpenCL_std, 81>; +defm : DemangledExtendedBuiltin<"native_divide", OpenCL_std, 82>; +defm : DemangledExtendedBuiltin<"native_exp", OpenCL_std, 83>; +defm : DemangledExtendedBuiltin<"native_exp2", OpenCL_std, 84>; +defm : DemangledExtendedBuiltin<"native_exp10", OpenCL_std, 85>; +defm : DemangledExtendedBuiltin<"native_log", OpenCL_std, 86>; +defm : DemangledExtendedBuiltin<"native_log2", OpenCL_std, 87>; +defm : DemangledExtendedBuiltin<"native_log10", OpenCL_std, 88>; +defm : DemangledExtendedBuiltin<"native_powr", OpenCL_std, 89>; +defm : DemangledExtendedBuiltin<"native_recip", OpenCL_std, 90>; +defm : DemangledExtendedBuiltin<"native_rsqrt", OpenCL_std, 91>; +defm : DemangledExtendedBuiltin<"native_sin", OpenCL_std, 92>; +defm : DemangledExtendedBuiltin<"native_sqrt", OpenCL_std, 93>; +defm : DemangledExtendedBuiltin<"native_tan", OpenCL_std, 94>; +defm : DemangledExtendedBuiltin<"s_abs", OpenCL_std, 141>; +defm : DemangledExtendedBuiltin<"s_abs_diff", OpenCL_std, 142>; +defm : DemangledExtendedBuiltin<"s_add_sat", OpenCL_std, 143>; +defm : DemangledExtendedBuiltin<"u_add_sat", OpenCL_std, 144>; +defm : DemangledExtendedBuiltin<"s_hadd", OpenCL_std, 145>; +defm : DemangledExtendedBuiltin<"u_hadd", OpenCL_std, 146>; +defm : DemangledExtendedBuiltin<"s_rhadd", OpenCL_std, 147>; +defm : DemangledExtendedBuiltin<"u_rhadd", OpenCL_std, 148>; +defm : DemangledExtendedBuiltin<"s_clamp", OpenCL_std, 149>; +defm : DemangledExtendedBuiltin<"u_clamp", OpenCL_std, 150>; +defm : DemangledExtendedBuiltin<"clz", OpenCL_std, 151>; +defm : DemangledExtendedBuiltin<"ctz", OpenCL_std, 152>; +defm : DemangledExtendedBuiltin<"s_mad_hi", OpenCL_std, 153>; +defm : DemangledExtendedBuiltin<"u_mad_sat", OpenCL_std, 154>; +defm : DemangledExtendedBuiltin<"s_mad_sat", OpenCL_std, 155>; +defm : DemangledExtendedBuiltin<"s_max", OpenCL_std, 156>; +defm : DemangledExtendedBuiltin<"u_max", OpenCL_std, 157>; +defm : DemangledExtendedBuiltin<"s_min", OpenCL_std, 158>; +defm : 
DemangledExtendedBuiltin<"u_min", OpenCL_std, 159>; +defm : DemangledExtendedBuiltin<"s_mul_hi", OpenCL_std, 160>; +defm : DemangledExtendedBuiltin<"rotate", OpenCL_std, 161>; +defm : DemangledExtendedBuiltin<"s_sub_sat", OpenCL_std, 162>; +defm : DemangledExtendedBuiltin<"u_sub_sat", OpenCL_std, 163>; +defm : DemangledExtendedBuiltin<"u_upsample", OpenCL_std, 164>; +defm : DemangledExtendedBuiltin<"s_upsample", OpenCL_std, 165>; +defm : DemangledExtendedBuiltin<"popcount", OpenCL_std, 166>; +defm : DemangledExtendedBuiltin<"s_mad24", OpenCL_std, 167>; +defm : DemangledExtendedBuiltin<"u_mad24", OpenCL_std, 168>; +defm : DemangledExtendedBuiltin<"s_mul24", OpenCL_std, 169>; +defm : DemangledExtendedBuiltin<"u_mul24", OpenCL_std, 170>; +defm : DemangledExtendedBuiltin<"u_abs", OpenCL_std, 201>; +defm : DemangledExtendedBuiltin<"u_abs_diff", OpenCL_std, 202>; +defm : DemangledExtendedBuiltin<"u_mul_hi", OpenCL_std, 203>; +defm : DemangledExtendedBuiltin<"u_mad_hi", OpenCL_std, 204>; +defm : DemangledExtendedBuiltin<"fclamp", OpenCL_std, 95>; +defm : DemangledExtendedBuiltin<"degrees", OpenCL_std, 96>; +defm : DemangledExtendedBuiltin<"fmax_common", OpenCL_std, 97>; +defm : DemangledExtendedBuiltin<"fmin_common", OpenCL_std, 98>; +defm : DemangledExtendedBuiltin<"mix", OpenCL_std, 99>; +defm : DemangledExtendedBuiltin<"radians", OpenCL_std, 100>; +defm : DemangledExtendedBuiltin<"step", OpenCL_std, 101>; +defm : DemangledExtendedBuiltin<"smoothstep", OpenCL_std, 102>; +defm : DemangledExtendedBuiltin<"sign", OpenCL_std, 103>; +defm : DemangledExtendedBuiltin<"cross", OpenCL_std, 104>; +defm : DemangledExtendedBuiltin<"distance", OpenCL_std, 105>; +defm : DemangledExtendedBuiltin<"length", OpenCL_std, 106>; +defm : DemangledExtendedBuiltin<"normalize", OpenCL_std, 107>; +defm : DemangledExtendedBuiltin<"fast_distance", OpenCL_std, 108>; +defm : DemangledExtendedBuiltin<"fast_length", OpenCL_std, 109>; +defm : DemangledExtendedBuiltin<"fast_normalize", OpenCL_std, 110>; +defm : DemangledExtendedBuiltin<"bitselect", OpenCL_std, 186>; +defm : DemangledExtendedBuiltin<"select", OpenCL_std, 187>; +defm : DemangledExtendedBuiltin<"vloadn", OpenCL_std, 171>; +defm : DemangledExtendedBuiltin<"vstoren", OpenCL_std, 172>; +defm : DemangledExtendedBuiltin<"vload_half", OpenCL_std, 173>; +defm : DemangledExtendedBuiltin<"vload_halfn", OpenCL_std, 174>; +defm : DemangledExtendedBuiltin<"vstore_half", OpenCL_std, 175>; +defm : DemangledExtendedBuiltin<"vstore_half_r", OpenCL_std, 176>; +defm : DemangledExtendedBuiltin<"vstore_halfn", OpenCL_std, 177>; +defm : DemangledExtendedBuiltin<"vstore_halfn_r", OpenCL_std, 178>; +defm : DemangledExtendedBuiltin<"vloada_halfn", OpenCL_std, 179>; +defm : DemangledExtendedBuiltin<"vstorea_halfn", OpenCL_std, 180>; +defm : DemangledExtendedBuiltin<"vstorea_halfn_r", OpenCL_std, 181>; +defm : DemangledExtendedBuiltin<"shuffle", OpenCL_std, 182>; +defm : DemangledExtendedBuiltin<"shuffle2", OpenCL_std, 183>; +defm : DemangledExtendedBuiltin<"printf", OpenCL_std, 184>; +defm : DemangledExtendedBuiltin<"prefetch", OpenCL_std, 185>; + +defm : DemangledExtendedBuiltin<"Round", GLSL_std_450, 1>; +defm : DemangledExtendedBuiltin<"RoundEven", GLSL_std_450, 2>; +defm : DemangledExtendedBuiltin<"Trunc", GLSL_std_450, 3>; +defm : DemangledExtendedBuiltin<"FAbs", GLSL_std_450, 4>; +defm : DemangledExtendedBuiltin<"SAbs", GLSL_std_450, 5>; +defm : DemangledExtendedBuiltin<"FSign", GLSL_std_450, 6>; +defm : DemangledExtendedBuiltin<"SSign", GLSL_std_450, 7>; +defm : 
DemangledExtendedBuiltin<"Floor", GLSL_std_450, 8>; +defm : DemangledExtendedBuiltin<"Ceil", GLSL_std_450, 9>; +defm : DemangledExtendedBuiltin<"Fract", GLSL_std_450, 10>; +defm : DemangledExtendedBuiltin<"Radians", GLSL_std_450, 11>; +defm : DemangledExtendedBuiltin<"Degrees", GLSL_std_450, 12>; +defm : DemangledExtendedBuiltin<"Sin", GLSL_std_450, 13>; +defm : DemangledExtendedBuiltin<"Cos", GLSL_std_450, 14>; +defm : DemangledExtendedBuiltin<"Tan", GLSL_std_450, 15>; +defm : DemangledExtendedBuiltin<"Asin", GLSL_std_450, 16>; +defm : DemangledExtendedBuiltin<"Acos", GLSL_std_450, 17>; +defm : DemangledExtendedBuiltin<"Atan", GLSL_std_450, 18>; +defm : DemangledExtendedBuiltin<"Sinh", GLSL_std_450, 19>; +defm : DemangledExtendedBuiltin<"Cosh", GLSL_std_450, 20>; +defm : DemangledExtendedBuiltin<"Tanh", GLSL_std_450, 21>; +defm : DemangledExtendedBuiltin<"Asinh", GLSL_std_450, 22>; +defm : DemangledExtendedBuiltin<"Acosh", GLSL_std_450, 23>; +defm : DemangledExtendedBuiltin<"Atanh", GLSL_std_450, 24>; +defm : DemangledExtendedBuiltin<"Atan2", GLSL_std_450, 25>; +defm : DemangledExtendedBuiltin<"Pow", GLSL_std_450, 26>; +defm : DemangledExtendedBuiltin<"Exp", GLSL_std_450, 27>; +defm : DemangledExtendedBuiltin<"Log", GLSL_std_450, 28>; +defm : DemangledExtendedBuiltin<"Exp2", GLSL_std_450, 29>; +defm : DemangledExtendedBuiltin<"Log2", GLSL_std_450, 30>; +defm : DemangledExtendedBuiltin<"Sqrt", GLSL_std_450, 31>; +defm : DemangledExtendedBuiltin<"InverseSqrt", GLSL_std_450, 32>; +defm : DemangledExtendedBuiltin<"Determinant", GLSL_std_450, 33>; +defm : DemangledExtendedBuiltin<"MatrixInverse", GLSL_std_450, 34>; +defm : DemangledExtendedBuiltin<"Modf", GLSL_std_450, 35>; +defm : DemangledExtendedBuiltin<"ModfStruct", GLSL_std_450, 36>; +defm : DemangledExtendedBuiltin<"FMin", GLSL_std_450, 37>; +defm : DemangledExtendedBuiltin<"UMin", GLSL_std_450, 38>; +defm : DemangledExtendedBuiltin<"SMin", GLSL_std_450, 39>; +defm : DemangledExtendedBuiltin<"FMax", GLSL_std_450, 40>; +defm : DemangledExtendedBuiltin<"UMax", GLSL_std_450, 41>; +defm : DemangledExtendedBuiltin<"SMax", GLSL_std_450, 42>; +defm : DemangledExtendedBuiltin<"FClamp", GLSL_std_450, 43>; +defm : DemangledExtendedBuiltin<"UClamp", GLSL_std_450, 44>; +defm : DemangledExtendedBuiltin<"SClamp", GLSL_std_450, 45>; +defm : DemangledExtendedBuiltin<"FMix", GLSL_std_450, 46>; +defm : DemangledExtendedBuiltin<"Step", GLSL_std_450, 48>; +defm : DemangledExtendedBuiltin<"SmoothStep", GLSL_std_450, 49>; +defm : DemangledExtendedBuiltin<"Fma", GLSL_std_450, 50>; +defm : DemangledExtendedBuiltin<"Frexp", GLSL_std_450, 51>; +defm : DemangledExtendedBuiltin<"FrexpStruct", GLSL_std_450, 52>; +defm : DemangledExtendedBuiltin<"Ldexp", GLSL_std_450, 53>; +defm : DemangledExtendedBuiltin<"PackSnorm4x8", GLSL_std_450, 54>; +defm : DemangledExtendedBuiltin<"PackUnorm4x8", GLSL_std_450, 55>; +defm : DemangledExtendedBuiltin<"PackSnorm2x16", GLSL_std_450, 56>; +defm : DemangledExtendedBuiltin<"PackUnorm2x16", GLSL_std_450, 57>; +defm : DemangledExtendedBuiltin<"PackHalf2x16", GLSL_std_450, 58>; +defm : DemangledExtendedBuiltin<"PackDouble2x32", GLSL_std_450, 59>; +defm : DemangledExtendedBuiltin<"UnpackSnorm2x16", GLSL_std_450, 60>; +defm : DemangledExtendedBuiltin<"UnpackUnorm2x16", GLSL_std_450, 61>; +defm : DemangledExtendedBuiltin<"UnpackHalf2x16", GLSL_std_450, 62>; +defm : DemangledExtendedBuiltin<"UnpackSnorm4x8", GLSL_std_450, 63>; +defm : DemangledExtendedBuiltin<"UnpackUnorm4x8", GLSL_std_450, 64>; +defm : 
DemangledExtendedBuiltin<"UnpackDouble2x32", GLSL_std_450, 65>; +defm : DemangledExtendedBuiltin<"Length", GLSL_std_450, 66>; +defm : DemangledExtendedBuiltin<"Distance", GLSL_std_450, 67>; +defm : DemangledExtendedBuiltin<"Cross", GLSL_std_450, 68>; +defm : DemangledExtendedBuiltin<"Normalize", GLSL_std_450, 69>; +defm : DemangledExtendedBuiltin<"FaceForward", GLSL_std_450, 70>; +defm : DemangledExtendedBuiltin<"Reflect", GLSL_std_450, 71>; +defm : DemangledExtendedBuiltin<"Refract", GLSL_std_450, 72>; +defm : DemangledExtendedBuiltin<"FindILsb", GLSL_std_450, 73>; +defm : DemangledExtendedBuiltin<"FindSMsb", GLSL_std_450, 74>; +defm : DemangledExtendedBuiltin<"FindUMsb", GLSL_std_450, 75>; +defm : DemangledExtendedBuiltin<"InterpolateAtCentroid", GLSL_std_450, 76>; +defm : DemangledExtendedBuiltin<"InterpolateAtSample", GLSL_std_450, 77>; +defm : DemangledExtendedBuiltin<"InterpolateAtOffset", GLSL_std_450, 78>; +defm : DemangledExtendedBuiltin<"NMin", GLSL_std_450, 79>; +defm : DemangledExtendedBuiltin<"NMax", GLSL_std_450, 80>; +defm : DemangledExtendedBuiltin<"NClamp", GLSL_std_450, 81>; + +//===----------------------------------------------------------------------===// +// Class defining an native builtin record used for direct translation into a +// SPIR-V instruction. +// +// name is the demangled name of the given builtin. +// set specifies which external instruction set the builtin belongs to. +// opcode specifies the SPIR-V operation code of the generated instruction. +//===----------------------------------------------------------------------===// +class NativeBuiltin { + string Name = name; + InstructionSet Set = set; + Op Opcode = operation; +} + +// Table gathering all the native builtins. +def NativeBuiltins : GenericTable { + let FilterClass = "NativeBuiltin"; + let Fields = ["Name", "Set", "Opcode"]; + string TypeOf_Set = "InstructionSet"; +} + +// Function to lookup native builtins by their name and set. +def lookupNativeBuiltin : SearchIndex { + let Table = NativeBuiltins; + let Key = ["Name", "Set"]; +} + +// Multiclass used to define at the same time both an incoming builtin record +// and a corresponding native builtin record. 
+multiclass DemangledNativeBuiltin<string name, InstructionSet set,
+                                  BuiltinGroup group, bits<8> minNumArgs,
+                                  bits<8> maxNumArgs, Op operation> {
+  def : DemangledBuiltin<name, set, group, minNumArgs, maxNumArgs>;
+  def : NativeBuiltin<name, set, operation>;
+}
+
+// Relational builtin records:
+defm : DemangledNativeBuiltin<"isequal", OpenCL_std, Relational, 2, 2, OpFOrdEqual>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdEqual", OpenCL_std, Relational, 2, 2, OpFOrdEqual>;
+defm : DemangledNativeBuiltin<"isnotequal", OpenCL_std, Relational, 2, 2, OpFUnordNotEqual>;
+defm : DemangledNativeBuiltin<"__spirv_FUnordNotEqual", OpenCL_std, Relational, 2, 2, OpFUnordNotEqual>;
+defm : DemangledNativeBuiltin<"isgreater", OpenCL_std, Relational, 2, 2, OpFOrdGreaterThan>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdGreaterThan", OpenCL_std, Relational, 2, 2, OpFOrdGreaterThan>;
+defm : DemangledNativeBuiltin<"isgreaterequal", OpenCL_std, Relational, 2, 2, OpFOrdGreaterThanEqual>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdGreaterThanEqual", OpenCL_std, Relational, 2, 2, OpFOrdGreaterThanEqual>;
+defm : DemangledNativeBuiltin<"isless", OpenCL_std, Relational, 2, 2, OpFOrdLessThan>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdLessThan", OpenCL_std, Relational, 2, 2, OpFOrdLessThan>;
+defm : DemangledNativeBuiltin<"islessequal", OpenCL_std, Relational, 2, 2, OpFOrdLessThanEqual>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdLessThanEqual", OpenCL_std, Relational, 2, 2, OpFOrdLessThanEqual>;
+defm : DemangledNativeBuiltin<"islessgreater", OpenCL_std, Relational, 2, 2, OpFOrdNotEqual>;
+defm : DemangledNativeBuiltin<"__spirv_FOrdNotEqual", OpenCL_std, Relational, 2, 2, OpFOrdNotEqual>;
+defm : DemangledNativeBuiltin<"isordered", OpenCL_std, Relational, 2, 2, OpOrdered>;
+defm : DemangledNativeBuiltin<"__spirv_Ordered", OpenCL_std, Relational, 2, 2, OpOrdered>;
+defm : DemangledNativeBuiltin<"isunordered", OpenCL_std, Relational, 2, 2, OpUnordered>;
+defm : DemangledNativeBuiltin<"__spirv_Unordered", OpenCL_std, Relational, 2, 2, OpUnordered>;
+defm : DemangledNativeBuiltin<"isfinite", OpenCL_std, Relational, 1, 1, OpIsFinite>;
+defm : DemangledNativeBuiltin<"__spirv_IsFinite", OpenCL_std, Relational, 1, 1, OpIsFinite>;
+defm : DemangledNativeBuiltin<"isinf", OpenCL_std, Relational, 1, 1, OpIsInf>;
+defm : DemangledNativeBuiltin<"__spirv_IsInf", OpenCL_std, Relational, 1, 1, OpIsInf>;
+defm : DemangledNativeBuiltin<"isnan", OpenCL_std, Relational, 1, 1, OpIsNan>;
+defm : DemangledNativeBuiltin<"__spirv_IsNan", OpenCL_std, Relational, 1, 1, OpIsNan>;
+defm : DemangledNativeBuiltin<"isnormal", OpenCL_std, Relational, 1, 1, OpIsNormal>;
+defm : DemangledNativeBuiltin<"__spirv_IsNormal", OpenCL_std, Relational, 1, 1, OpIsNormal>;
+defm : DemangledNativeBuiltin<"signbit", OpenCL_std, Relational, 1, 1, OpSignBitSet>;
+defm : DemangledNativeBuiltin<"__spirv_SignBitSet", OpenCL_std, Relational, 1, 1, OpSignBitSet>;
+defm : DemangledNativeBuiltin<"any", OpenCL_std, Relational, 1, 1, OpAny>;
+defm : DemangledNativeBuiltin<"__spirv_Any", OpenCL_std, Relational, 1, 1, OpAny>;
+defm : DemangledNativeBuiltin<"all", OpenCL_std, Relational, 1, 1, OpAll>;
+defm : DemangledNativeBuiltin<"__spirv_All", OpenCL_std, Relational, 1, 1, OpAll>;
+
+// Atomic builtin records:
+defm : DemangledNativeBuiltin<"atomic_load", OpenCL_std, Atomic, 1, 1, OpAtomicLoad>;
+defm : DemangledNativeBuiltin<"atomic_load_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicLoad>;
+defm : DemangledNativeBuiltin<"atomic_store", OpenCL_std, Atomic, 2, 2, OpAtomicStore>;
+defm : DemangledNativeBuiltin<"atomic_store_explicit", OpenCL_std, Atomic, 2, 2, OpAtomicStore>;
+defm :
DemangledNativeBuiltin<"atomic_compare_exchange_strong", OpenCL_std, Atomic, 3, 6, OpAtomicCompareExchange>; +defm : DemangledNativeBuiltin<"atomic_compare_exchange_strong_explicit", OpenCL_std, Atomic, 5, 6, OpAtomicCompareExchange>; +defm : DemangledNativeBuiltin<"atomic_compare_exchange_weak", OpenCL_std, Atomic, 3, 6, OpAtomicCompareExchangeWeak>; +defm : DemangledNativeBuiltin<"atomic_compare_exchange_weak_explicit", OpenCL_std, Atomic, 5, 6, OpAtomicCompareExchangeWeak>; +defm : DemangledNativeBuiltin<"atom_cmpxchg", OpenCL_std, Atomic, 3, 6, OpAtomicCompareExchange>; +defm : DemangledNativeBuiltin<"atomic_cmpxchg", OpenCL_std, Atomic, 3, 6, OpAtomicCompareExchange>; +defm : DemangledNativeBuiltin<"atom_add", OpenCL_std, Atomic, 2, 4, OpAtomicIAdd>; +defm : DemangledNativeBuiltin<"atomic_add", OpenCL_std, Atomic, 2, 4, OpAtomicIAdd>; +defm : DemangledNativeBuiltin<"atom_sub", OpenCL_std, Atomic, 2, 4, OpAtomicISub>; +defm : DemangledNativeBuiltin<"atomic_sub", OpenCL_std, Atomic, 2, 4, OpAtomicISub>; +defm : DemangledNativeBuiltin<"atom_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>; +defm : DemangledNativeBuiltin<"atomic_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>; +defm : DemangledNativeBuiltin<"atom_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>; +defm : DemangledNativeBuiltin<"atomic_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>; +defm : DemangledNativeBuiltin<"atom_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>; +defm : DemangledNativeBuiltin<"atomic_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>; +defm : DemangledNativeBuiltin<"atomic_exchange", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>; +defm : DemangledNativeBuiltin<"atomic_exchange_explicit", OpenCL_std, Atomic, 2, 4, OpAtomicExchange>; +defm : DemangledNativeBuiltin<"atomic_work_item_fence", OpenCL_std, Atomic, 1, 3, OpMemoryBarrier>; +defm : DemangledNativeBuiltin<"atomic_fetch_add", OpenCL_std, Atomic, 2, 4, OpAtomicIAdd>; +defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAtomicISub>; +defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>; +defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>; +defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>; +defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>; +defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>; +defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>; +defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>; +defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>; + +// Barrier builtin records: +defm : DemangledNativeBuiltin<"barrier", OpenCL_std, Barrier, 1, 3, OpControlBarrier>; +defm : DemangledNativeBuiltin<"work_group_barrier", OpenCL_std, Barrier, 1, 3, OpControlBarrier>; + +// Kernel enqueue builtin records: +defm : DemangledNativeBuiltin<"retain_event", OpenCL_std, Enqueue, 1, 1, OpRetainEvent>; +defm : DemangledNativeBuiltin<"release_event", OpenCL_std, Enqueue, 1, 1, OpReleaseEvent>; +defm : DemangledNativeBuiltin<"create_user_event", OpenCL_std, Enqueue, 0, 0, OpCreateUserEvent>; +defm : DemangledNativeBuiltin<"is_valid_event", OpenCL_std, Enqueue, 1, 1, OpIsValidEvent>; +defm : DemangledNativeBuiltin<"set_user_event_status", OpenCL_std, Enqueue, 2, 2, OpSetUserEventStatus>; +defm : 
DemangledNativeBuiltin<"capture_event_profiling_info", OpenCL_std, Enqueue, 3, 3, OpCaptureEventProfilingInfo>; +defm : DemangledNativeBuiltin<"get_default_queue", OpenCL_std, Enqueue, 0, 0, OpGetDefaultQueue>; +defm : DemangledNativeBuiltin<"ndrange_1D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>; +defm : DemangledNativeBuiltin<"ndrange_2D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>; +defm : DemangledNativeBuiltin<"ndrange_3D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>; + +// Spec constant builtin record: +defm : DemangledNativeBuiltin<"__spirv_SpecConstant", OpenCL_std, SpecConstant, 2, 2, OpSpecConstant>; +defm : DemangledNativeBuiltin<"__spirv_SpecConstantComposite", OpenCL_std, SpecConstant, 1, 0, OpSpecConstantComposite>; + +// Async Copy and Prefetch builtin records: +defm : DemangledNativeBuiltin<"async_work_group_copy", OpenCL_std, AsyncCopy, 4, 4, OpGroupAsyncCopy>; +defm : DemangledNativeBuiltin<"wait_group_events", OpenCL_std, AsyncCopy, 2, 2, OpGroupWaitEvents>; + +//===----------------------------------------------------------------------===// +// Class defining a work/sub group builtin that should be translated into a +// SPIR-V instruction using the defined properties. +// +// name is the demangled name of the given builtin. +// opcode specifies the SPIR-V operation code of the generated instruction. +//===----------------------------------------------------------------------===// +class GroupBuiltin { + string Name = name; + Op Opcode = operation; + bits<32> GroupOperation = !cond(!not(!eq(!find(name, "group_reduce"), -1)) : Reduce.Value, + !not(!eq(!find(name, "group_scan_inclusive"), -1)) : InclusiveScan.Value, + !not(!eq(!find(name, "group_scan_exclusive"), -1)) : ExclusiveScan.Value, + !not(!eq(!find(name, "group_ballot_bit_count"), -1)) : Reduce.Value, + !not(!eq(!find(name, "group_ballot_inclusive_scan"), -1)) : InclusiveScan.Value, + !not(!eq(!find(name, "group_ballot_exclusive_scan"), -1)) : ExclusiveScan.Value, + !not(!eq(!find(name, "group_non_uniform_reduce"), -1)) : Reduce.Value, + !not(!eq(!find(name, "group_non_uniform_scan_inclusive"), -1)) : InclusiveScan.Value, + !not(!eq(!find(name, "group_non_uniform_scan_exclusive"), -1)) : ExclusiveScan.Value, + !not(!eq(!find(name, "group_non_uniform_reduce_logical"), -1)) : Reduce.Value, + !not(!eq(!find(name, "group_non_uniform_scan_inclusive_logical"), -1)) : InclusiveScan.Value, + !not(!eq(!find(name, "group_non_uniform_scan_exclusive_logical"), -1)) : ExclusiveScan.Value, + !not(!eq(!find(name, "group_clustered_reduce"), -1)) : ClusteredReduce.Value, + !not(!eq(!find(name, "group_clustered_reduce_logical"), -1)) : ClusteredReduce.Value, + true : 0); + bit IsElect = !eq(operation, OpGroupNonUniformElect); + bit IsAllOrAny = !or(!eq(operation, OpGroupAll), + !eq(operation, OpGroupAny), + !eq(operation, OpGroupNonUniformAll), + !eq(operation, OpGroupNonUniformAny)); + bit IsAllEqual = !eq(operation, OpGroupNonUniformAllEqual); + bit IsBallot = !eq(operation, OpGroupNonUniformBallot); + bit IsInverseBallot = !eq(operation, OpGroupNonUniformInverseBallot); + bit IsBallotBitExtract = !eq(operation, OpGroupNonUniformBallotBitExtract); + bit IsBallotFindBit = !or(!eq(operation, OpGroupNonUniformBallotFindLSB), + !eq(operation, OpGroupNonUniformBallotFindMSB)); + bit IsLogical = !or(!eq(operation, OpGroupNonUniformLogicalAnd), + !eq(operation, OpGroupNonUniformLogicalOr), + !eq(operation, OpGroupNonUniformLogicalXor)); + bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual, + IsBallot, IsInverseBallot, + 
IsBallotBitExtract, IsBallotFindBit,
+                             !eq(operation, OpGroupNonUniformShuffle),
+                             !eq(operation, OpGroupNonUniformShuffleXor),
+                             !eq(operation, OpGroupNonUniformShuffleUp),
+                             !eq(operation, OpGroupNonUniformShuffleDown),
+                             !eq(operation, OpGroupBroadcast),
+                             !eq(operation, OpGroupNonUniformBroadcast),
+                             !eq(operation, OpGroupNonUniformBroadcastFirst));
+  bit HasBoolArg = !or(!and(IsAllOrAny, !eq(IsAllEqual, false)), IsBallot, IsLogical);
+}
+
+// Table gathering all the work/sub group builtins.
+def GroupBuiltins : GenericTable {
+  let FilterClass = "GroupBuiltin";
+  let Fields = ["Name", "Opcode", "GroupOperation", "IsElect", "IsAllOrAny",
+                "IsAllEqual", "IsBallot", "IsInverseBallot", "IsBallotBitExtract",
+                "IsBallotFindBit", "IsLogical", "NoGroupOperation", "HasBoolArg"];
+}
+
+// Function to lookup group builtins by their name.
+def lookupGroupBuiltin : SearchIndex {
+  let Table = GroupBuiltins;
+  let Key = ["Name"];
+}
+
+// Multiclass used to define at the same time both incoming builtin records
+// and corresponding work/sub group builtin records.
+defvar OnlyWork = 0; defvar OnlySub = 1; defvar WorkOrSub = 2;
+multiclass DemangledGroupBuiltin<string name, int level, Op operation> {
+  assert !and(!ge(level, 0), !le(level, 2)), "group level is invalid: " # level;
+
+  if !or(!eq(level, OnlyWork), !eq(level, WorkOrSub)) then {
+    def : DemangledBuiltin<!strconcat("work_", name), OpenCL_std, Group, 0, 0>;
+    def : GroupBuiltin<!strconcat("work_", name), operation>;
+  }
+
+  if !or(!eq(level, OnlySub), !eq(level, WorkOrSub)) then {
+    def : DemangledBuiltin<!strconcat("sub_", name), OpenCL_std, Group, 0, 0>;
+    def : GroupBuiltin<!strconcat("sub_", name), operation>;
+  }
+}
+
+defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
+defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
+defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
+defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
+defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
+
+// cl_khr_subgroup_non_uniform_vote
+defm : DemangledGroupBuiltin<"group_elect", OnlySub, OpGroupNonUniformElect>;
+defm : DemangledGroupBuiltin<"group_non_uniform_all", OnlySub, OpGroupNonUniformAll>;
+defm : DemangledGroupBuiltin<"group_non_uniform_any", OnlySub, OpGroupNonUniformAny>;
+defm : DemangledGroupBuiltin<"group_non_uniform_all_equal", OnlySub, OpGroupNonUniformAllEqual>;
+
+// cl_khr_subgroup_ballot
+defm : DemangledGroupBuiltin<"group_ballot", OnlySub, OpGroupNonUniformBallot>;
+defm : DemangledGroupBuiltin<"group_inverse_ballot", OnlySub, OpGroupNonUniformInverseBallot>;
+defm : DemangledGroupBuiltin<"group_ballot_bit_extract", OnlySub, OpGroupNonUniformBallotBitExtract>;
+defm : DemangledGroupBuiltin<"group_ballot_bit_count", OnlySub, OpGroupNonUniformBallotBitCount>;
+defm : DemangledGroupBuiltin<"group_ballot_inclusive_scan", OnlySub, OpGroupNonUniformBallotBitCount>;
+defm : DemangledGroupBuiltin<"group_ballot_exclusive_scan", OnlySub, OpGroupNonUniformBallotBitCount>;
+defm : DemangledGroupBuiltin<"group_ballot_find_lsb", OnlySub, OpGroupNonUniformBallotFindLSB>;
+defm : DemangledGroupBuiltin<"group_ballot_find_msb", OnlySub, OpGroupNonUniformBallotFindMSB>;
+
+// cl_khr_subgroup_shuffle
+defm : DemangledGroupBuiltin<"group_shuffle", OnlySub, OpGroupNonUniformShuffle>;
+defm : DemangledGroupBuiltin<"group_shuffle_xor", OnlySub, OpGroupNonUniformShuffleXor>;
+
+// cl_khr_subgroup_shuffle_relative
+defm : DemangledGroupBuiltin<"group_shuffle_up", OnlySub, OpGroupNonUniformShuffleUp>;
+defm : DemangledGroupBuiltin<"group_shuffle_down", OnlySub, OpGroupNonUniformShuffleDown>;
+
+defm :
DemangledGroupBuiltin<"group_iadd", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_reduce_adds", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_adds", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>; + +defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>; +defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>; + +defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>; +defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>; + +defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>; +defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>; + +defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>; +defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>; + +defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>; +defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>; + +defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>; +defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>; + +defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>; +defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>; +defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>; +defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>; + +// cl_khr_subgroup_non_uniform_arithmetic +defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_addu", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_adds", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_addu", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_adds", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_addu", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_adds", WorkOrSub, OpGroupNonUniformIAdd>; +defm : 
DemangledGroupBuiltin<"group_clustered_reduce_addu", WorkOrSub, OpGroupNonUniformIAdd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_adds", WorkOrSub, OpGroupNonUniformIAdd>; + +defm : DemangledGroupBuiltin<"group_non_uniform_fadd", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_addf", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_addh", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_addd", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_addf", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_addh", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_addd", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_addf", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_addh", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_addd", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_addf", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_addh", WorkOrSub, OpGroupNonUniformFAdd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_addd", WorkOrSub, OpGroupNonUniformFAdd>; + +defm : DemangledGroupBuiltin<"group_non_uniform_imul", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_mulu", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_muls", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_mulu", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_muls", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_mulu", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_muls", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_mulu", WorkOrSub, OpGroupNonUniformIMul>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_muls", WorkOrSub, OpGroupNonUniformIMul>; + +defm : DemangledGroupBuiltin<"group_non_uniform_fmul", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_mulf", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_mulh", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_muld", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_mulf", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_mulh", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_muld", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_mulf", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_mulh", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_muld", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_mulf", WorkOrSub, OpGroupNonUniformFMul>; +defm : 
DemangledGroupBuiltin<"group_clustered_reduce_mulh", WorkOrSub, OpGroupNonUniformFMul>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_muld", WorkOrSub, OpGroupNonUniformFMul>; + +defm : DemangledGroupBuiltin<"group_non_uniform_smin", WorkOrSub, OpGroupNonUniformSMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_mins", WorkOrSub, OpGroupNonUniformSMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_mins", WorkOrSub, OpGroupNonUniformSMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_mins", WorkOrSub, OpGroupNonUniformSMin>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_mins", WorkOrSub, OpGroupNonUniformSMin>; + + +defm : DemangledGroupBuiltin<"group_non_uniform_umin", WorkOrSub, OpGroupNonUniformUMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_minu", WorkOrSub, OpGroupNonUniformUMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_minu", WorkOrSub, OpGroupNonUniformUMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_minu", WorkOrSub, OpGroupNonUniformUMin>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_minu", WorkOrSub, OpGroupNonUniformUMin>; + +defm : DemangledGroupBuiltin<"group_non_uniform_fmin", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_minf", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_minh", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_mind", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_minf", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_minh", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_mind", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_minf", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_minh", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_mind", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_minf", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_minh", WorkOrSub, OpGroupNonUniformFMin>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_mind", WorkOrSub, OpGroupNonUniformFMin>; + +defm : DemangledGroupBuiltin<"group_non_uniform_smax", WorkOrSub, OpGroupNonUniformSMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_maxs", WorkOrSub, OpGroupNonUniformSMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_maxs", WorkOrSub, OpGroupNonUniformSMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_maxs", WorkOrSub, OpGroupNonUniformSMax>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_maxs", WorkOrSub, OpGroupNonUniformSMax>; + +defm : DemangledGroupBuiltin<"group_non_uniform_umax", WorkOrSub, OpGroupNonUniformUMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_maxu", WorkOrSub, OpGroupNonUniformUMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_maxu", WorkOrSub, OpGroupNonUniformUMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_maxu", WorkOrSub, OpGroupNonUniformUMax>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_maxu", WorkOrSub, OpGroupNonUniformUMax>; + +defm : DemangledGroupBuiltin<"group_non_uniform_fmax", 
WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_maxf", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_maxh", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_maxd", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_maxf", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_maxh", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_maxd", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_maxf", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_maxh", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_maxd", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_maxf", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_maxh", WorkOrSub, OpGroupNonUniformFMax>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_maxd", WorkOrSub, OpGroupNonUniformFMax>; + +defm : DemangledGroupBuiltin<"group_non_uniform_iand", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_andu", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_ands", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_andu", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_ands", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_andu", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_ands", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_andu", WorkOrSub, OpGroupNonUniformBitwiseAnd>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_ands", WorkOrSub, OpGroupNonUniformBitwiseAnd>; + +defm : DemangledGroupBuiltin<"group_non_uniform_ior", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_oru", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_ors", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_oru", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_ors", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_oru", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_ors", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_oru", WorkOrSub, OpGroupNonUniformBitwiseOr>; +defm : DemangledGroupBuiltin<"group_clustered_reduce_ors", WorkOrSub, OpGroupNonUniformBitwiseOr>; + +defm : DemangledGroupBuiltin<"group_non_uniform_ixor", WorkOrSub, OpGroupNonUniformBitwiseXor>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_xoru", WorkOrSub, OpGroupNonUniformBitwiseXor>; +defm : DemangledGroupBuiltin<"group_non_uniform_reduce_xors", WorkOrSub, OpGroupNonUniformBitwiseXor>; +defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_xoru", WorkOrSub, 
OpGroupNonUniformBitwiseXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_xors", WorkOrSub, OpGroupNonUniformBitwiseXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_xoru", WorkOrSub, OpGroupNonUniformBitwiseXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_xors", WorkOrSub, OpGroupNonUniformBitwiseXor>;
+defm : DemangledGroupBuiltin<"group_clustered_reduce_xoru", WorkOrSub, OpGroupNonUniformBitwiseXor>;
+defm : DemangledGroupBuiltin<"group_clustered_reduce_xors", WorkOrSub, OpGroupNonUniformBitwiseXor>;
+
+defm : DemangledGroupBuiltin<"group_non_uniform_logical_iand", WorkOrSub, OpGroupNonUniformLogicalAnd>;
+defm : DemangledGroupBuiltin<"group_non_uniform_reduce_logical_ands", WorkOrSub, OpGroupNonUniformLogicalAnd>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_ands", WorkOrSub, OpGroupNonUniformLogicalAnd>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_ands", WorkOrSub, OpGroupNonUniformLogicalAnd>;
+defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_and", WorkOrSub, OpGroupNonUniformLogicalAnd>;
+
+defm : DemangledGroupBuiltin<"group_non_uniform_logical_ior", WorkOrSub, OpGroupNonUniformLogicalOr>;
+defm : DemangledGroupBuiltin<"group_non_uniform_reduce_logical_ors", WorkOrSub, OpGroupNonUniformLogicalOr>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_ors", WorkOrSub, OpGroupNonUniformLogicalOr>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_ors", WorkOrSub, OpGroupNonUniformLogicalOr>;
+defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_or", WorkOrSub, OpGroupNonUniformLogicalOr>;
+
+defm : DemangledGroupBuiltin<"group_non_uniform_logical_ixor", WorkOrSub, OpGroupNonUniformLogicalXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_reduce_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
+defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
+defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
+
+//===----------------------------------------------------------------------===//
+// Class defining a get builtin record used for lowering builtin calls such as
+// "get_sub_group_eq_mask" or "get_global_id" to SPIR-V instructions.
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+// value specifies the value of the BuiltIn enum.
+//===----------------------------------------------------------------------===//
+class GetBuiltin<string name, InstructionSet set, BuiltIn value> {
+  string Name = name;
+  InstructionSet Set = set;
+  BuiltIn Value = value;
+}
+
+// Table gathering all the get builtin records.
+def GetBuiltins : GenericTable {
+  let FilterClass = "GetBuiltin";
+  let Fields = ["Name", "Set", "Value"];
+  string TypeOf_Set = "InstructionSet";
+  string TypeOf_Value = "BuiltIn";
+}
+
+// Function to lookup get builtin records by their name and set.
+def lookupGetBuiltin : SearchIndex {
+  let Table = GetBuiltins;
+  let Key = ["Name", "Set"];
+}
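[Editor's note] The GenericTable/SearchIndex pair above is consumed from C++ through a lookup function emitted by the SearchableTables backend. A minimal sketch of how the lowering code could query it; the GET_GetBuiltins_* macro names and the exact generated signature follow the usual SearchableTables conventions and are assumptions, not code from this patch:

// Sketch only: the generated table and lookup are assumed to follow the
// standard SearchableTables naming; nothing below is literally in the patch.
#include "llvm/ADT/StringRef.h"
namespace llvm {
namespace SPIRV {
#define GET_GetBuiltins_DECL
#define GET_GetBuiltins_IMPL
#include "SPIRVGenTables.inc"
} // namespace SPIRV

SPIRV::BuiltIn::BuiltIn
getBuiltInForGetter(StringRef DemangledName,
                    SPIRV::InstructionSet::InstructionSet Set) {
  // Name and instruction set together form the key, so the same demangled
  // name may resolve to different records in different extended sets.
  const SPIRV::GetBuiltin *R = SPIRV::lookupGetBuiltin(DemangledName, Set);
  assert(R && "not a get builtin");
  return R->Value;
}
} // namespace llvm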
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding get builtin record.
+multiclass DemangledGetBuiltin<string name, InstructionSet set, BuiltinGroup group, BuiltIn value> {
+  def : DemangledBuiltin<name, set, group, 0, 1>;
+  def : GetBuiltin<name, set, value>;
+}
+
+// Builtin variable records:
+defm : DemangledGetBuiltin<"get_sub_group_eq_mask", OpenCL_std, Variable, SubgroupEqMask>;
+defm : DemangledGetBuiltin<"get_sub_group_ge_mask", OpenCL_std, Variable, SubgroupGeMask>;
+defm : DemangledGetBuiltin<"get_sub_group_gt_mask", OpenCL_std, Variable, SubgroupGtMask>;
+defm : DemangledGetBuiltin<"get_sub_group_le_mask", OpenCL_std, Variable, SubgroupLeMask>;
+defm : DemangledGetBuiltin<"get_sub_group_lt_mask", OpenCL_std, Variable, SubgroupLtMask>;
+defm : DemangledGetBuiltin<"__spirv_BuiltInGlobalLinearId", OpenCL_std, Variable, GlobalLinearId>;
+defm : DemangledGetBuiltin<"__spirv_BuiltInGlobalInvocationId", OpenCL_std, Variable, GlobalInvocationId>;
+
+// GetQuery builtin records:
+defm : DemangledGetBuiltin<"get_local_id", OpenCL_std, GetQuery, LocalInvocationId>;
+defm : DemangledGetBuiltin<"get_global_id", OpenCL_std, GetQuery, GlobalInvocationId>;
+defm : DemangledGetBuiltin<"get_local_size", OpenCL_std, GetQuery, WorkgroupSize>;
+defm : DemangledGetBuiltin<"get_global_size", OpenCL_std, GetQuery, GlobalSize>;
+defm : DemangledGetBuiltin<"get_group_id", OpenCL_std, GetQuery, WorkgroupId>;
+defm : DemangledGetBuiltin<"get_enqueued_local_size", OpenCL_std, GetQuery, EnqueuedWorkgroupSize>;
+defm : DemangledGetBuiltin<"get_num_groups", OpenCL_std, GetQuery, NumWorkgroups>;
+
+//===----------------------------------------------------------------------===//
+// Class defining an image query builtin record used for lowering the OpenCL
+// "get_image_*" calls into OpImageQuerySize/OpImageQuerySizeLod instructions.
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+// component specifies the unsigned number of the query component.
+//===----------------------------------------------------------------------===//
+class ImageQueryBuiltin<string name, InstructionSet set, bits<32> component> {
+  string Name = name;
+  InstructionSet Set = set;
+  bits<32> Component = component;
+}
+
+// Table gathering all the image query builtins.
+def ImageQueryBuiltins : GenericTable {
+  let FilterClass = "ImageQueryBuiltin";
+  let Fields = ["Name", "Set", "Component"];
+  string TypeOf_Set = "InstructionSet";
+}
+
+// Function to lookup image query builtins by their name and set.
+def lookupImageQueryBuiltin : SearchIndex {
+  let Table = ImageQueryBuiltins;
+  let Key = ["Name", "Set"];
+}
+
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding image query builtin record.
+multiclass DemangledImageQueryBuiltin<string name, InstructionSet set, int component> {
+  def : DemangledBuiltin<name, set, ImageSizeQuery, 1, 1>;
+  def : ImageQueryBuiltin<name, set, component>;
+}
+
+// Image query builtin records:
+defm : DemangledImageQueryBuiltin<"get_image_width", OpenCL_std, 0>;
+defm : DemangledImageQueryBuiltin<"get_image_height", OpenCL_std, 1>;
+defm : DemangledImageQueryBuiltin<"get_image_depth", OpenCL_std, 2>;
+defm : DemangledImageQueryBuiltin<"get_image_dim", OpenCL_std, 0>;
+defm : DemangledImageQueryBuiltin<"get_image_array_size", OpenCL_std, 3>;
+
+defm : DemangledNativeBuiltin<"get_image_num_samples", OpenCL_std, ImageMiscQuery, 1, 1, OpImageQuerySamples>;
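[Editor's note] The Component field tells the lowering which element of the size vector produced by OpImageQuerySize/OpImageQuerySizeLod to extract. A hedged sketch of that consumption; the generated SPIRV::ImageQueryBuiltin struct and SPIRV::lookupImageQueryBuiltin function are assumed per SearchableTables conventions, and the helper name is illustrative:

// Sketch only: e.g. "get_image_height" maps to Component == 1, so the
// lowering would OpCompositeExtract element 1 of the queried size vector.
#include "llvm/ADT/StringRef.h"
using namespace llvm;

static unsigned getQueryComponent(StringRef DemangledName) {
  const SPIRV::ImageQueryBuiltin *Rec = SPIRV::lookupImageQueryBuiltin(
      DemangledName, SPIRV::InstructionSet::OpenCL_std);
  return Rec ? Rec->Component : 0; // 0 also covers "get_image_width"/"_dim"
}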
+//===----------------------------------------------------------------------===//
+// Class defining a "convert_destType<_sat><_roundingMode>" call record for
+// lowering into OpConvert instructions.
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+//===----------------------------------------------------------------------===//
+class ConvertBuiltin<string name, InstructionSet set> {
+  string Name = name;
+  InstructionSet Set = set;
+  bit IsDestinationSigned = !eq(!find(name, "convert_u"), -1);
+  bit IsSaturated = !not(!eq(!find(name, "_sat"), -1));
+  bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
+  FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
+                                      !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+                                      !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+                                      !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+                                      true : RTE);
+}
+
+// Table gathering all the convert builtins.
+def ConvertBuiltins : GenericTable {
+  let FilterClass = "ConvertBuiltin";
+  let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", "IsRounded", "RoundingMode"];
+  string TypeOf_Set = "InstructionSet";
+  string TypeOf_RoundingMode = "FPRoundingMode";
+}
+
+// Function to lookup convert builtins by their name and set.
+def lookupConvertBuiltin : SearchIndex {
+  let Table = ConvertBuiltins;
+  let Key = ["Name", "Set"];
+}
+
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding convert builtin record.
+multiclass DemangledConvertBuiltin<string name, InstructionSet set> {
+  // Create records for scalar and 2, 3, 4, 8, and 16 element vector conversions.
+  foreach i = ["", "2", "3", "4", "8", "16"] in {
+    // Also create records for each rounding mode.
+    foreach j = ["", "_rte", "_rtz", "_rtp", "_rtn"] in {
+      def : DemangledBuiltin<!strconcat(name, i, j), set, Convert, 1, 1>;
+      def : ConvertBuiltin<!strconcat(name, i, j), set>;
+
+      // Create records with the "_sat" modifier for all conversions except
+      // those targeting floating-point types.
+      if !eq(!find(name, "float"), -1) then {
+        def : DemangledBuiltin<!strconcat(name, i, "_sat", j), set, Convert, 1, 1>;
+        def : ConvertBuiltin<!strconcat(name, i, "_sat", j), set>;
+      }
+    }
+  }
+}
+
+// Explicit conversion builtin records:
+defm : DemangledConvertBuiltin<"convert_char", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_uchar", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_short", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_ushort", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_int", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_uint", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_long", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_ulong", OpenCL_std>;
+defm : DemangledConvertBuiltin<"convert_float", OpenCL_std>;
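[Editor's note] The !find-based fields in ConvertBuiltin are plain substring tests on the builtin's name. A standalone C++ rendering of the same decomposition, worked on "convert_uchar4_sat_rte"; the helper and struct names here are illustrative only, not code from the patch:

#include <cassert>
#include <string>

// Mirrors ConvertBuiltin's !find logic: every flag is derived from the name.
struct ConvertFlags {
  bool DestSigned, Saturated, Rounded;
};

static ConvertFlags parseConvertName(const std::string &Name) {
  auto Has = [&](const char *S) { return Name.find(S) != std::string::npos; };
  return {/*DestSigned=*/!Has("convert_u"), /*Saturated=*/Has("_sat"),
          /*Rounded=*/Has("_rt")};
}

int main() {
  ConvertFlags F = parseConvertName("convert_uchar4_sat_rte");
  // Unsigned destination, saturated, rounded; "_rte" selects RTE mode.
  assert(!F.DestSigned && F.Saturated && F.Rounded);
}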
+//===----------------------------------------------------------------------===//
+// Class defining a vector data load/store builtin record used for lowering
+// into OpExtInst instruction.
+//
+// name is the demangled name of the given builtin.
+// set specifies which external instruction set the builtin belongs to.
+// number specifies the number of the instruction in the external set.
+//===----------------------------------------------------------------------===//
+class VectorLoadStoreBuiltin<string name, InstructionSet set, bits<32> number> {
+  string Name = name;
+  InstructionSet Set = set;
+  bits<32> Number = number;
+  bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
+  FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
+                                      !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+                                      !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+                                      !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+                                      true : RTE);
+}
+
+// Table gathering all the vector data load/store builtins.
+def VectorLoadStoreBuiltins : GenericTable {
+  let FilterClass = "VectorLoadStoreBuiltin";
+  let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
+  string TypeOf_Set = "InstructionSet";
+  string TypeOf_RoundingMode = "FPRoundingMode";
+}
+
+// Function to lookup vector data load/store builtins by their name and set.
+def lookupVectorLoadStoreBuiltin : SearchIndex {
+  let Table = VectorLoadStoreBuiltins;
+  let Key = ["Name", "Set"];
+}
+
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding vector data load/store builtin record.
+multiclass DemangledVectorLoadStoreBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, int number> {
+  def : DemangledBuiltin<name, OpenCL_std, VectorLoadStore, minNumArgs, maxNumArgs>;
+  def : VectorLoadStoreBuiltin<name, OpenCL_std, number>;
+}
+
+// Create records for scalar and 2, 3, 4, 8, and 16 vector element counts.
+// The instruction numbers are those of the OpenCL.std extended set
+// (vloadn = 171, vstoren = 172, vload_half = 173, ..., vstorea_halfn_r = 181).
+foreach i = ["", "2", "3", "4", "8", "16"] in {
+  if !eq(i, "") then {
+    defm : DemangledVectorLoadStoreBuiltin<"vload_half", 2, 2, 173>;
+    defm : DemangledVectorLoadStoreBuiltin<"vstore_half", 3, 3, 175>;
+  } else {
+    defm : DemangledVectorLoadStoreBuiltin<!strconcat("vload_half", i), 2, 2, 174>;
+    defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstore_half", i), 3, 3, 177>;
+  }
+  defm : DemangledVectorLoadStoreBuiltin<!strconcat("vloada_half", i), 2, 2, 179>;
+  defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstorea_half", i), 3, 3, 180>;
+  defm : DemangledVectorLoadStoreBuiltin<!strconcat("vload", i), 2, 2, 171>;
+  defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstore", i), 3, 3, 172>;
+
+  // Also create records for each rounding mode.
+  foreach j = ["_rte", "_rtz", "_rtp", "_rtn"] in {
+    if !eq(i, "") then {
+      defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstore_half", j), 3, 3, 176>;
+    } else {
+      defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstore_half", i, j), 3, 3, 178>;
+    }
+    defm : DemangledVectorLoadStoreBuiltin<!strconcat("vstorea_half", i, j), 3, 3, 181>;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Class defining implementation details of demangled builtin types. The info
+// in the record is used for lowering into OpType.
+//
+// name is the demangled name of the given builtin.
+// operation specifies the SPIR-V opcode the StructType should be lowered to.
+//===----------------------------------------------------------------------===//
+class DemangledType<string name, Op operation> {
+  string Name = name;
+  Op Opcode = operation;
+}
+
+// Table gathering all the demangled type records.
+def DemangledTypes : GenericTable {
+  let FilterClass = "DemangledType";
+  let Fields = ["Name", "Opcode"];
+}
+
+// Function to lookup builtin types by their demangled name.
+def lookupType : SearchIndex {
+  let Table = DemangledTypes;
+  let Key = ["Name"];
+}
+
+// OpenCL builtin types:
+def : DemangledType<"opencl.reserve_id_t", OpTypeReserveId>;
+def : DemangledType<"opencl.event_t", OpTypeEvent>;
+def : DemangledType<"opencl.queue_t", OpTypeQueue>;
+def : DemangledType<"opencl.sampler_t", OpTypeSampler>;
+def : DemangledType<"opencl.clk_event_t", OpTypeDeviceEvent>;
+
+// Class defining lowering details for various variants of image type identifiers.
+class ImageType<string name> {
+  string Name = name;
+  AccessQualifier Qualifier = !cond(!not(!eq(!find(name, "_ro_t"), -1)) : ReadOnly,
+                                    !not(!eq(!find(name, "_wo_t"), -1)) : WriteOnly,
+                                    !not(!eq(!find(name, "_rw_t"), -1)) : ReadWrite,
+                                    true : ReadOnly);
+  Dim Dimensionality = !cond(!not(!eq(!find(name, "buffer"), -1)) : DIM_Buffer,
+                             !not(!eq(!find(name, "image1"), -1)) : DIM_1D,
+                             !not(!eq(!find(name, "image2"), -1)) : DIM_2D,
+                             !not(!eq(!find(name, "image3"), -1)) : DIM_3D);
+  bit Arrayed = !not(!eq(!find(name, "array"), -1));
+  bit Depth = !not(!eq(!find(name, "depth"), -1));
+}
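[Editor's note] Every field of ImageType is recovered from substrings of the demangled identifier. A self-contained C++ rendering of the same decoding rules, worked on "opencl.image2d_array_depth_ro_t"; the struct and helper names are illustrative only:

#include <cassert>
#include <string>

struct ImageInfo {
  enum { ReadOnly, WriteOnly, ReadWrite } AQ;
  enum { Buffer, D1, D2, D3 } Dim;
  bool Arrayed, Depth;
};

static ImageInfo decodeImageName(const std::string &N) {
  auto Has = [&](const char *S) { return N.find(S) != std::string::npos; };
  ImageInfo I;
  // Plain "_t" (no explicit qualifier) defaults to read-only, as in !cond.
  I.AQ = Has("_wo_t")   ? ImageInfo::WriteOnly
         : Has("_rw_t") ? ImageInfo::ReadWrite
                        : ImageInfo::ReadOnly;
  I.Dim = Has("buffer")   ? ImageInfo::Buffer
          : Has("image1") ? ImageInfo::D1
          : Has("image2") ? ImageInfo::D2
                          : ImageInfo::D3;
  I.Arrayed = Has("array");
  I.Depth = Has("depth");
  return I;
}

int main() {
  ImageInfo I = decodeImageName("opencl.image2d_array_depth_ro_t");
  assert(I.AQ == ImageInfo::ReadOnly && I.Dim == ImageInfo::D2 &&
         I.Arrayed && I.Depth);
}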
+// Table gathering all the image type records.
+def ImageTypes : GenericTable {
+  let FilterClass = "ImageType";
+  let Fields = ["Name", "Qualifier", "Dimensionality", "Arrayed", "Depth"];
+  string TypeOf_Qualifier = "AccessQualifier";
+  string TypeOf_Dimensionality = "Dim";
+}
+
+// Function to lookup builtin image types by their demangled name.
+def lookupImageType : SearchIndex {
+  let Table = ImageTypes;
+  let Key = ["Name"];
+}
+
+// Multiclass used to define at the same time a DemangledType record used
+// for matching an incoming demangled string to the OpTypeImage opcode and
+// an ImageType containing the lowering details.
+multiclass DemangledImageType<string name> {
+  def : DemangledType<name, OpTypeImage>;
+  def : ImageType<name>;
+}
+
+foreach aq = ["_t", "_ro_t", "_wo_t", "_rw_t"] in {
+  defm : DemangledImageType<!strconcat("opencl.image1d", aq)>;
+  defm : DemangledImageType<!strconcat("opencl.image1d_array", aq)>;
+  defm : DemangledImageType<!strconcat("opencl.image1d_buffer", aq)>;
+
+  foreach a1 = ["", "_array"] in {
+    foreach a2 = ["", "_msaa"] in {
+      foreach a3 = ["", "_depth"] in {
+        defm : DemangledImageType<!strconcat("opencl.image2d", a1, a2, a3, aq)>;
+      }
+    }
+  }
+
+  defm : DemangledImageType<!strconcat("opencl.image3d", aq)>;
+}
+
+// Class defining lowering details for various variants of pipe type identifiers.
+class PipeType<string name> {
+  string Name = name;
+  AccessQualifier Qualifier = !cond(!not(!eq(!find(name, "_ro_t"), -1)) : ReadOnly,
+                                    !not(!eq(!find(name, "_wo_t"), -1)) : WriteOnly,
+                                    !not(!eq(!find(name, "_rw_t"), -1)) : ReadWrite,
+                                    true : ReadOnly);
+}
+
+// Table gathering all the pipe type records.
+def PipeTypes : GenericTable {
+  let FilterClass = "PipeType";
+  let Fields = ["Name", "Qualifier"];
+  string TypeOf_Qualifier = "AccessQualifier";
+}
+
+// Function to lookup builtin pipe types by their demangled name.
+def lookupPipeType : SearchIndex {
+  let Table = PipeTypes;
+  let Key = ["Name"];
+}
+
+// Multiclass used to define at the same time a DemangledType record used
+// for matching an incoming demangled string to the OpTypePipe opcode and
+// a PipeType containing the lowering details.
+multiclass DemangledPipeType<string name> {
+  def : DemangledType<name, OpTypePipe>;
+  def : PipeType<name>;
+}
+
+foreach aq = ["_t", "_ro_t", "_wo_t", "_rw_t"] in {
+  defm : DemangledPipeType<!strconcat("opencl.pipe", aq)>;
+}
+
+//===----------------------------------------------------------------------===//
+// Classes defining various OpenCL enums.
+//===----------------------------------------------------------------------===//
+
+// OpenCL memory_scope enum
+def CLMemoryScope : GenericEnum {
+  let FilterClass = "CLMemoryScope";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class CLMemoryScope<bits<32> value> {
+  string Name = NAME;
+  bits<32> Value = value;
+}
+
+def memory_scope_work_item : CLMemoryScope<0>;
+def memory_scope_work_group : CLMemoryScope<1>;
+def memory_scope_device : CLMemoryScope<2>;
+def memory_scope_all_svm_devices : CLMemoryScope<3>;
+def memory_scope_sub_group : CLMemoryScope<4>;
+
+// OpenCL sampler addressing mode/bitmask enum
+def CLSamplerAddressingMode : GenericEnum {
+  let FilterClass = "CLSamplerAddressingMode";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class CLSamplerAddressingMode<bits<32> value> {
+  string Name = NAME;
+  bits<32> Value = value;
+}
+
+def CLK_ADDRESS_NONE : CLSamplerAddressingMode<0x0>;
+def CLK_ADDRESS_CLAMP : CLSamplerAddressingMode<0x4>;
+def CLK_ADDRESS_CLAMP_TO_EDGE : CLSamplerAddressingMode<0x2>;
+def CLK_ADDRESS_REPEAT : CLSamplerAddressingMode<0x6>;
+def CLK_ADDRESS_MIRRORED_REPEAT : CLSamplerAddressingMode<0x8>;
+def CLK_ADDRESS_MODE_MASK : CLSamplerAddressingMode<0xE>;
+def CLK_NORMALIZED_COORDS_FALSE : CLSamplerAddressingMode<0x0>;
+def CLK_NORMALIZED_COORDS_TRUE : CLSamplerAddressingMode<0x1>;
+def CLK_FILTER_NEAREST : CLSamplerAddressingMode<0x10>;
+def CLK_FILTER_LINEAR : CLSamplerAddressingMode<0x20>;
+
+// OpenCL memory fences
+def CLMemoryFenceFlags : GenericEnum {
+  let FilterClass = "CLMemoryFenceFlags";
+  let NameField = "Name";
+  let ValueField = "Value";
+}
+
+class CLMemoryFenceFlags<bits<32> value> {
+  string Name = NAME;
+  bits<32> Value = value;
+}
+
+def CLK_LOCAL_MEM_FENCE : CLMemoryFenceFlags<0x1>;
+def CLK_GLOBAL_MEM_FENCE : CLMemoryFenceFlags<0x2>;
+def CLK_IMAGE_MEM_FENCE : CLMemoryFenceFlags<0x4>;
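[Editor's note] The CLK_* values mirror the OpenCL host-side sampler bitmask, so a literal sampler initializer can be split back into the three operands of OpConstantSampler. A standalone illustration of that decomposition; the filter mask 0x30 is inferred from the records above (0x10 | 0x20) and is an assumption, as is everything else in this sketch:

#include <cassert>
#include <cstdint>

// Values copied from the CLSamplerAddressingMode records above.
enum : uint32_t {
  CLK_ADDRESS_CLAMP_TO_EDGE = 0x2,
  CLK_ADDRESS_MODE_MASK = 0xE,
  CLK_NORMALIZED_COORDS_TRUE = 0x1,
  CLK_FILTER_LINEAR = 0x20,
};

int main() {
  // A sampler literal is a bitwise OR of one value from each group; the
  // lowering masks each group apart to build the OpConstantSampler operands.
  uint32_t S =
      CLK_ADDRESS_CLAMP_TO_EDGE | CLK_NORMALIZED_COORDS_TRUE | CLK_FILTER_LINEAR;
  assert((S & CLK_ADDRESS_MODE_MASK) == CLK_ADDRESS_CLAMP_TO_EDGE);
  assert((S & 0x1) == CLK_NORMALIZED_COORDS_TRUE);
  assert((S & 0x30) == CLK_FILTER_LINEAR); // assumed filter-mode mask
}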
Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -14,6 +14,7 @@
 #include "SPIRVCallLowering.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
+#include "SPIRVBuiltins.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVISelLowering.h"
 #include "SPIRVRegisterInfo.h"
@@ -284,6 +285,28 @@
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
+  std::string FuncName = Info.Callee.getGlobal()->getGlobalIdentifier();
+  std::string DemangledName = mayBeOclOrSpirvBuiltin(FuncName);
+  const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
+  // TODO: check that it's OCL builtin, then apply OpenCL_std.
+  if (!DemangledName.empty() && CF && CF->isDeclaration() &&
+      ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+    const Type *OrigRetTy = Info.OrigRet.Ty;
+    if (FTy)
+      OrigRetTy = FTy->getReturnType();
+    SmallVector<Register, 8> ArgVRegs;
+    for (auto Arg : Info.OrigArgs) {
+      assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
+      ArgVRegs.push_back(Arg.Regs[0]);
+      SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
+      GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
+    }
+    auto Res =
+        SPIRV::lowerBuiltin(DemangledName, SPIRV::InstructionSet::OpenCL_std,
+                            MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR);
+    if (Res.first)
+      return Res.second;
+  }
   if (CF && CF->isDeclaration() &&
       !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
     // Emit the type info and forward function declaration to the first MBB
@@ -324,7 +347,6 @@
       return false;
     MIB.addUse(Arg.Regs[0]);
   }
-  const auto &STI = MF.getSubtarget();
-  return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
-                              *STI.getRegBankInfo());
+  return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
+                              *ST->getRegBankInfo());
 }
Index: llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -50,8 +50,122 @@
   const SmallVector<DTSortableEntry *, 2> &getDeps() const { return Deps; }
   void addDep(DTSortableEntry *E) { Deps.push_back(E); }
 };
+
+struct SpecialTypeDescriptor {
+  enum SpecialTypeKind {
+    STK_Empty = 0,
+    STK_Image,
+    STK_SampledImage,
+    STK_Sampler,
+    STK_Pipe,
+    STK_Last = -1
+  };
+  SpecialTypeKind Kind;
+
+  unsigned Hash;
+
+  SpecialTypeDescriptor() = delete;
+  SpecialTypeDescriptor(SpecialTypeKind K) : Kind(K) { Hash = Kind; }
+
+  unsigned getHash() const { return Hash; }
+
+  virtual ~SpecialTypeDescriptor() {}
+};
+
+struct ImageTypeDescriptor : public SpecialTypeDescriptor {
+  union ImageAttrs {
+    struct BitFlags {
+      unsigned Dim : 3;
+      unsigned Depth : 2;
+      unsigned Arrayed : 1;
+      unsigned MS : 1;
+      unsigned Sampled : 2;
+      unsigned ImageFormat : 6;
+      unsigned AQ : 2;
+    } Flags;
+    unsigned Val;
+  };
+
+  ImageTypeDescriptor(const Type *SampledTy, unsigned Dim, unsigned Depth,
+                      unsigned Arrayed, unsigned MS, unsigned Sampled,
+                      unsigned ImageFormat, unsigned AQ = 0)
+      : SpecialTypeDescriptor(SpecialTypeKind::STK_Image) {
+    ImageAttrs Attrs;
+    Attrs.Val = 0;
+    Attrs.Flags.Dim = Dim;
+    Attrs.Flags.Depth = Depth;
+    Attrs.Flags.Arrayed = Arrayed;
+    Attrs.Flags.MS = MS;
+    Attrs.Flags.Sampled = Sampled;
+    Attrs.Flags.ImageFormat = ImageFormat;
+    Attrs.Flags.AQ = AQ;
+    Hash = (DenseMapInfo<Type *>().getHashValue(SampledTy) & 0xffff) ^
+           ((Attrs.Val << 8) | Kind);
+  }
+
+  static bool classof(const SpecialTypeDescriptor *TD) {
+    return TD->Kind == SpecialTypeKind::STK_Image;
+  }
+};
+
+struct SampledImageTypeDescriptor : public SpecialTypeDescriptor {
+  SampledImageTypeDescriptor(const Type *SampledTy, const MachineInstr *ImageTy)
+      : SpecialTypeDescriptor(SpecialTypeKind::STK_SampledImage) {
+    assert(ImageTy->getOpcode() == SPIRV::OpTypeImage);
+    ImageTypeDescriptor TD(
+        SampledTy, ImageTy->getOperand(2).getImm(),
+        ImageTy->getOperand(3).getImm(), ImageTy->getOperand(4).getImm(),
+        ImageTy->getOperand(5).getImm(), ImageTy->getOperand(6).getImm(),
+        ImageTy->getOperand(7).getImm(), ImageTy->getOperand(8).getImm());
+    Hash = TD.getHash() ^ Kind;
+  }
+
+  static bool classof(const SpecialTypeDescriptor *TD) {
+    return TD->Kind == SpecialTypeKind::STK_SampledImage;
+  }
+};
+
+struct SamplerTypeDescriptor : public SpecialTypeDescriptor {
+  SamplerTypeDescriptor()
+      : SpecialTypeDescriptor(SpecialTypeKind::STK_Sampler) {
+    Hash = Kind;
+  }
+
+  static bool classof(const SpecialTypeDescriptor *TD) {
+    return TD->Kind == SpecialTypeKind::STK_Sampler;
+  }
+};
+
+struct PipeTypeDescriptor : public SpecialTypeDescriptor {
+
+  PipeTypeDescriptor(uint8_t AQ)
+      : SpecialTypeDescriptor(SpecialTypeKind::STK_Pipe) {
+    Hash = (AQ << 8) | Kind;
+  }
+
+  static bool classof(const SpecialTypeDescriptor *TD) {
+    return TD->Kind == SpecialTypeKind::STK_Pipe;
+  }
+};
 } // namespace SPIRV
 
+template <> struct DenseMapInfo<SPIRV::SpecialTypeDescriptor> {
+  static inline SPIRV::SpecialTypeDescriptor getEmptyKey() {
+    return SPIRV::SpecialTypeDescriptor(
+        SPIRV::SpecialTypeDescriptor::STK_Empty);
+  }
+  static inline SPIRV::SpecialTypeDescriptor getTombstoneKey() {
+    return SPIRV::SpecialTypeDescriptor(SPIRV::SpecialTypeDescriptor::STK_Last);
+  }
+  static unsigned getHashValue(SPIRV::SpecialTypeDescriptor Val) {
+    return Val.getHash();
+  }
+  static bool isEqual(SPIRV::SpecialTypeDescriptor LHS,
+                      SPIRV::SpecialTypeDescriptor RHS) {
+    return getHashValue(LHS) == getHashValue(RHS);
  }
+};
+
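[Editor's note] DenseMapInfo above compares only the precomputed hashes, so two descriptors are "equal" exactly when their packed hashes match; that is what lets structurally identical special types dedupe to a single register. A hedged usage sketch (hypothetical caller, assuming the patch's header is on the include path; not code from the patch):

// Sketch only: keying a DenseMap by descriptor via the specialization above.
#include "SPIRVDuplicatesTracker.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/Register.h"
using namespace llvm;

void dedupeExample(const Type *SampledTy, Register R) {
  DenseMap<SPIRV::SpecialTypeDescriptor, Register> Cache;
  // Two descriptors built from identical image parameters pack the same
  // hash, so the lookup below finds the entry inserted first.
  SPIRV::ImageTypeDescriptor TD(SampledTy, /*Dim=*/1, /*Depth=*/0,
                                /*Arrayed=*/0, /*MS=*/0, /*Sampled=*/1,
                                /*ImageFormat=*/0);
  Cache.try_emplace(TD, R);
  bool Found =
      Cache.count(SPIRV::ImageTypeDescriptor(SampledTy, 1, 0, 0, 0, 1, 0));
  (void)Found; // true: same packed hash
}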
 template <typename KeyTy> class SPIRVDuplicatesTrackerBase {
 public:
   // NOTE: using MapVector instead of DenseMap helps getting everything ordered
@@ -107,12 +221,17 @@
 template <typename T>
 class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase<T> {};
 
+template <>
+class SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor>
+    : public SPIRVDuplicatesTrackerBase<SPIRV::SpecialTypeDescriptor> {};
+
 class SPIRVGeneralDuplicatesTracker {
   SPIRVDuplicatesTracker<Type> TT;
   SPIRVDuplicatesTracker<Constant> CT;
   SPIRVDuplicatesTracker<GlobalVariable> GT;
   SPIRVDuplicatesTracker<Function> FT;
   SPIRVDuplicatesTracker<Argument> AT;
+  SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
 
   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
   // to use flat data structure.
@@ -150,6 +269,11 @@
     AT.add(Arg, MF, R);
   }
 
+  void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
+           Register R) {
+    ST.add(TD, MF, R);
+  }
+
   Register find(const Type *T, const MachineFunction *MF) {
     return TT.find(const_cast<Type *>(T), MF);
   }
@@ -170,6 +294,11 @@
     return AT.find(const_cast<Argument *>(Arg), MF);
   }
 
+  Register find(const SPIRV::SpecialTypeDescriptor &TD,
+                const MachineFunction *MF) {
+    return ST.find(TD, MF);
+  }
+
   const SPIRVDuplicatesTracker<Type> *getTypes() { return &TT; }
 };
 } // namespace llvm
Index: llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -39,6 +39,7 @@
   prebuildReg2Entry(GT, Reg2Entry);
   prebuildReg2Entry(FT, Reg2Entry);
   prebuildReg2Entry(AT, Reg2Entry);
+  prebuildReg2Entry(ST, Reg2Entry);
 
   for (auto &Op2E : Reg2Entry) {
     SPIRV::DTSortableEntry *E = Op2E.second;
Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -38,6 +38,11 @@
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
 
+  // Look for an equivalent of the given special type in the duplicates
+  // tracker; return its defining instruction if found, nullptr otherwise.
+  const MachineInstr *checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
+                                        MachineIRBuilder &MIRBuilder);
+
   SmallPtrSet<SPIRVType *, 4> TypesInProcessing;
   DenseMap<const Type *, SPIRVType *> ForwardPointerTypes;
@@ -131,6 +136,11 @@
     return Res->second;
   }
 
+  // Either generate a new OpTypeXXX instruction or return an existing one
+  // corresponding to the given string containing the name of the builtin type.
+  SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
+                                        MachineIRBuilder &MIRBuilder);
+
   // Return the SPIR-V type instruction corresponding to the given VReg, or
   // nullptr if no such type instruction exists.
   SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
@@ -202,6 +212,16 @@
       uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
   SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
+  Register getOrCreateIntCompositeOrNull(uint64_t Val, MachineInstr &I,
+                                         SPIRVType *SpvType,
+                                         const SPIRVInstrInfo &TII,
+                                         Constant *CA, unsigned BitWidth,
+                                         unsigned ElemCnt);
+  Register getOrCreateIntCompositeOrNull(uint64_t Val,
+                                         MachineIRBuilder &MIRBuilder,
+                                         SPIRVType *SpvType, bool EmitIR,
+                                         Constant *CA, unsigned BitWidth,
+                                         unsigned ElemCnt);
 
 public:
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
@@ -213,6 +233,18 @@
   Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
                                     SPIRVType *SpvType, const SPIRVInstrInfo &TII);
+  Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
+                                   SPIRVType *SpvType,
+                                   const SPIRVInstrInfo &TII);
+  Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
+                                    SPIRVType *SpvType, bool EmitIR = true);
+  Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
+                                   SPIRVType *SpvType, bool EmitIR = true);
+
+  Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
+                                unsigned FilterMode,
+                                MachineIRBuilder &MIRBuilder,
+                                SPIRVType *SpvType);
   Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType,
                             const SPIRVInstrInfo &TII);
   Register buildGlobalVariable(Register Reg, SPIRVType *BaseType,
@@ -244,6 +276,9 @@
   SPIRVType *getOrCreateSPIRVPointerType(
       SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
       SPIRV::StorageClass::StorageClass SClass = SPIRV::StorageClass::Function);
+  SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
+                                           MachineIRBuilder &MIRBuilder);
+
   SPIRVType *getOrCreateOpTypeFunctionWithArgs(
       const Type *Ty, SPIRVType *RetType,
       const SmallVectorImpl<SPIRVType *> &ArgTypes,
Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -243,21 +243,13 @@
   return Res;
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
-                                              SPIRVType *SpvType,
-                                              const SPIRVInstrInfo &TII) {
-  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
-  assert(LLVMTy->isVectorTy());
-  const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
-  Type *LLVMBaseTy = LLVMVecTy->getElementType();
+Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
+    uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
+    unsigned ElemCnt) {
   // Find a constant vector in DT or build a new one.
-  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto ConstVec =
-      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
-  Register Res = DT.find(ConstVec, CurMF);
+  Register Res = DT.find(CA, CurMF);
   if (!Res.isValid()) {
-    unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
     // error on validation.
@@ -269,9 +261,8 @@
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
-    const unsigned ElemCnt = SpvType->getOperand(2).getImm();
-    assignVectTypeToVReg(SpvBaseType, ElemCnt, SpvVecConst, I, TII);
-    DT.add(ConstVec, CurMF, SpvVecConst);
+    assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
+    DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
     MachineBasicBlock &BB = *I.getParent();
     if (Val) {
@@ -294,6 +285,133 @@
   return Res;
 }
 
+Register
+SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
+                                              SPIRVType *SpvType,
+                                              const SPIRVInstrInfo &TII) {
+  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+  assert(LLVMTy->isVectorTy());
+  const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
+  Type *LLVMBaseTy = LLVMVecTy->getElementType();
+  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+  auto ConstVec =
+      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
+  unsigned BW = getScalarOrVectorBitWidth(SpvType);
+  return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
+                                       SpvType->getOperand(2).getImm());
+}
+
+Register
+SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
+                                             SPIRVType *SpvType,
+                                             const SPIRVInstrInfo &TII) {
+  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+  assert(LLVMTy->isArrayTy());
+  const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
+  Type *LLVMBaseTy = LLVMArrTy->getElementType();
+  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+  auto ConstArr =
+      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
+  SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+  unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
+  return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
+                                       LLVMArrTy->getNumElements());
+}
+
+Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
+    uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
+    Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
+  Register Res = DT.find(CA, CurMF);
+  if (!Res.isValid()) {
+    Register SpvScalConst;
+    if (Val || EmitIR) {
+      SPIRVType *SpvBaseType =
+          getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
+      SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
+    }
+    LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
+    Register SpvVecConst =
+        CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
+    DT.add(CA, CurMF, SpvVecConst);
+    if (EmitIR) {
+      MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
+    } else {
+      if (Val) {
+        auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
+                       .addDef(SpvVecConst)
+                       .addUse(getSPIRVTypeID(SpvType));
+        for (unsigned i = 0; i < ElemCnt; ++i)
+          MIB.addUse(SpvScalConst);
+      } else {
+        MIRBuilder.buildInstr(SPIRV::OpConstantNull)
+            .addDef(SpvVecConst)
+            .addUse(getSPIRVTypeID(SpvType));
+      }
+    }
+    return SpvVecConst;
+  }
+  return Res;
+}
+
+Register
+SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
+                                              MachineIRBuilder &MIRBuilder,
+                                              SPIRVType *SpvType, bool EmitIR) {
+  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+  assert(LLVMTy->isVectorTy());
+  const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
+  Type *LLVMBaseTy = LLVMVecTy->getElementType();
+  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+  auto ConstVec =
+      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
+  unsigned BW = getScalarOrVectorBitWidth(SpvType);
+  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
+                                       ConstVec, BW,
+                                       SpvType->getOperand(2).getImm());
+}
+
+Register
+SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
+                                             MachineIRBuilder &MIRBuilder,
+                                             SPIRVType *SpvType, bool EmitIR) {
+  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+  assert(LLVMTy->isArrayTy());
+  const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
+  Type *LLVMBaseTy = LLVMArrTy->getElementType();
+  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+  auto ConstArr =
+      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
+  SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
+  unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
+  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
+                                       ConstArr, BW,
+                                       LLVMArrTy->getNumElements());
+}
+
+Register SPIRVGlobalRegistry::buildConstantSampler(
+    Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilterMode,
+    MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
+  SPIRVType *SampTy;
+  if (SpvType)
+    SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
+  else
+    SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
+
+  auto Sampler =
+      ResReg.isValid()
+          ? ResReg
+          : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+  auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
+                 .addDef(Sampler)
+                 .addUse(getSPIRVTypeID(SampTy))
+                 .addImm(AddrMode)
+                 .addImm(Param)
+                 .addImm(FilterMode);
+  assert(Res->getOperand(0).isReg());
+  return Res->getOperand(0).getReg();
+}
+
 Register SPIRVGlobalRegistry::buildGlobalVariable(
     Register ResVReg, SPIRVType *BaseType, StringRef Name,
     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
@@ -369,6 +487,12 @@
   if (HasLinkageTy)
     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                     {static_cast<uint32_t>(LinkageType)}, Name);
+
+  SPIRV::BuiltIn::BuiltIn BuiltInId;
+  if (getSpirvBuiltInIdByName(Name, BuiltInId))
+    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
+                    {static_cast<uint32_t>(BuiltInId)});
+
   return Reg;
 }
@@ -680,6 +804,69 @@
                                      Type->getOperand(1).getImm());
 }
 
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
+    SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
+  SPIRV::SampledImageTypeDescriptor TD(
+      SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
+          ImageType->getOperand(1).getReg())),
+      ImageType);
+  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
+    return Res;
+  Register ResVReg = createTypeVReg(MIRBuilder);
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
+                 .addDef(ResVReg)
+                 .addUse(getSPIRVTypeID(ImageType));
+  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
+  return MIB;
+}
+
+const MachineInstr *
+SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
+                                       MachineIRBuilder &MIRBuilder) {
+  Register Reg = DT.find(TD, &MIRBuilder.getMF());
+  if (Reg.isValid())
+    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
+  return nullptr;
+}
+
+// TODO: maybe use tablegen to implement this.
+SPIRVType *
+SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
+                                                MachineIRBuilder &MIRBuilder) {
+  unsigned VecElts = 0;
+  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
+
+  // Parse the type name in either "typeN" or "type vector[N]" format, where
+  // N is the number of elements of the vector.
+  Type *Type;
+  if (TypeStr.startswith("void")) {
+    Type = Type::getVoidTy(Ctx);
+    TypeStr = TypeStr.substr(strlen("void"));
+  } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
+    Type = Type::getInt32Ty(Ctx);
+    TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
                                        : TypeStr.substr(strlen("uint"));
+  } else if (TypeStr.startswith("float")) {
+    Type = Type::getFloatTy(Ctx);
+    TypeStr = TypeStr.substr(strlen("float"));
+  } else if (TypeStr.startswith("half")) {
+    Type = Type::getHalfTy(Ctx);
+    TypeStr = TypeStr.substr(strlen("half"));
+  } else if (TypeStr.startswith("opencl.sampler_t")) {
+    Type = StructType::create(Ctx, "opencl.sampler_t");
+  } else
+    llvm_unreachable("Unable to recognize SPIRV type name.");
+  if (TypeStr.startswith(" vector[")) {
+    TypeStr = TypeStr.substr(strlen(" vector["));
+    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
+  }
+  TypeStr.getAsInteger(10, VecElts);
+  auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
+  if (VecElts > 0)
+    SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
+  return SpirvTy;
+}
+
 SPIRVType *
 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
                                                  MachineIRBuilder &MIRBuilder) {
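[Editor's note] getOrCreateSPIRVTypeByName accepts names such as "float" or "uint vector[3]". A hedged usage sketch, assuming the patch's headers are in scope; the wrapper function itself is hypothetical:

// Sketch only: demonstrates the two accepted name formats.
SPIRVType *buildExampleTypes(SPIRVGlobalRegistry &GR,
                             MachineIRBuilder &MIRBuilder) {
  // "float" parses to a scalar 32-bit OpTypeFloat.
  SPIRVType *F32 = GR.getOrCreateSPIRVTypeByName("float", MIRBuilder);
  (void)F32;
  // "uint vector[3]" parses the element type first ("uint" collapses to a
  // 32-bit integer), then wraps it in a 3-element OpTypeVector.
  return GR.getOrCreateSPIRVTypeByName("uint vector[3]", MIRBuilder);
}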
Index: llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -699,6 +699,10 @@
 def OpCaptureEventProfilingInfo: Op<302, (outs),
                   (ins ID:$event, ID:$info, ID:$value),
                   "OpCaptureEventProfilingInfo $event $info $value">;
+def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type),
+                  "$res = OpGetDefaultQueue $type">;
+def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO),
+                  "$res = OpBuildNDRange $type $GWS $LWS $GWO">;
 
 // TODO: 3.42.23. Pipe Instructions
Index: llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -219,11 +219,12 @@
       }
       MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg());
       I.removeFromParent();
+      return true;
     } else if (I.getNumDefs() == 1) {
       // Make all vregs 32 bits (for SPIR-V IDs).
       MRI->setType(I.getOperand(0).getReg(), LLT::scalar(32));
     }
-    return true;
+    return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
   }
 
   if (I.getNumOperands() != I.getNumExplicitOperands()) {
Index: llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -465,8 +465,8 @@
   if (!Req.IsSatisfiable)
     report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
 
-  if (Req.Cap.hasValue())
-    addCapabilities({Req.Cap.getValue()});
+  if (Req.Cap.has_value())
+    addCapabilities({Req.Cap.value()});
 
   addExtensions(Req.Exts);
Index: llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -189,11 +189,12 @@
 // Insert ASSIGN_TYPE instruction between Reg and its definition, set NewReg as
 // a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
+// It is also used in SPIRVBuiltins.cpp.
 // TODO: maybe move to SPIRVUtils.
-static Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
-                                  SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIB,
-                                  MachineRegisterInfo &MRI) {
+namespace llvm {
+Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+                           SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
+                           MachineRegisterInfo &MRI) {
   MachineInstr *Def = MRI.getVRegDef(Reg);
   assert((Ty || SpirvTy) && "Either LLVM or SPIRV type is expected.");
   MIB.setInsertPt(*Def->getParent(),
@@ -219,6 +220,7 @@
   MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
   return NewReg;
 }
+} // namespace llvm
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder MIB) {
Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVSubtarget.h
+++ llvm/lib/Target/SPIRV/SPIRVSubtarget.h
@@ -39,6 +39,7 @@
   uint32_t OpenCLVersion;
 
   SmallSet<SPIRV::Extension::Extension, 4> AvailableExtensions;
+  SmallSet<SPIRV::InstructionSet::InstructionSet, 4> AvailableExtInstSets;
   std::unique_ptr<SPIRVGlobalRegistry> GR;
 
   SPIRVInstrInfo InstrInfo;
@@ -51,9 +52,10 @@
   std::unique_ptr<LegalizerInfo> Legalizer;
   std::unique_ptr<InstructionSelector> InstSelector;
 
-  // TODO: Initialise the available extensions based on
-  // the environment settings.
+  // TODO: Initialise the available extensions, extended instruction sets
+  // based on the environment settings.
   void initAvailableExtensions();
+  void initAvailableExtInstSets();
 
 public:
   // This constructor initializes the data members to match that
@@ -78,6 +80,7 @@
   bool hasOpenCLFullProfile() const { return true; }
   bool hasOpenCLImageSupport() const { return true; }
   bool canUseExtension(SPIRV::Extension::Extension E) const;
+  bool canUseExtInstSet(SPIRV::InstructionSet::InstructionSet E) const;
 
   SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); }
 
Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -46,7 +46,10 @@
       PointerSize(computePointerSize(TT)), SPIRVVersion(0), OpenCLVersion(0),
       InstrInfo(), FrameLowering(initSubtargetDependencies(CPU, FS)),
       TLInfo(TM, *this) {
+  // The order of initialization is important.
   initAvailableExtensions();
+  initAvailableExtInstSets();
+
   GR = std::make_unique<SPIRVGlobalRegistry>(PointerSize);
   CallLoweringInfo = std::make_unique<SPIRVCallLowering>(TLInfo, GR.get());
   Legalizer = std::make_unique<SPIRVLegalizerInfo>(*this);
@@ -69,6 +72,11 @@
   return AvailableExtensions.contains(E);
 }
 
+bool SPIRVSubtarget::canUseExtInstSet(
+    SPIRV::InstructionSet::InstructionSet E) const {
+  return AvailableExtInstSets.contains(E);
+}
+
 bool SPIRVSubtarget::isAtLeastSPIRVVer(uint32_t VerToCompareTo) const {
   return isAtLeastVer(SPIRVVersion, VerToCompareTo);
 }
@@ -91,3 +99,20 @@
   AvailableExtensions.insert(
       SPIRV::Extension::SPV_KHR_no_integer_wrap_decoration);
 }
+
+// TODO: use command line args for this rather than just defaults.
+// Must have called initAvailableExtensions first.
+void SPIRVSubtarget::initAvailableExtInstSets() {
+  AvailableExtInstSets.clear();
+  if (!isOpenCLEnv())
+    AvailableExtInstSets.insert(SPIRV::InstructionSet::GLSL_std_450);
+  else
+    AvailableExtInstSets.insert(SPIRV::InstructionSet::OpenCL_std);
+
+  // Handle extended instruction sets from extensions.
+  if (canUseExtension(
+          SPIRV::Extension::SPV_AMD_shader_trinary_minmax_extension)) {
+    AvailableExtInstSets.insert(
+        SPIRV::InstructionSet::SPV_AMD_shader_trinary_minmax);
+  }
+}
Index: llvm/lib/Target/SPIRV/SPIRVUtils.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -83,5 +83,9 @@
 
 // Get type of i-th operand of the metadata node.
 Type *getMDOperandAsType(const MDNode *N, unsigned I);
+
+// Return a demangled name with argument type info, as produced by
+// itaniumDemangle(). If the parser fails, return only the function name.
+std::string mayBeOclOrSpirvBuiltin(StringRef Name);
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
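[Editor's note] A hedged illustration of what mayBeOclOrSpirvBuiltin (declared above, defined in the next hunk) returns for the three classes of input it distinguishes; the expected results are inferred from the logic, not recorded outputs:

// Sketch only, assuming SPIRVUtils.h is on the include path.
#include "SPIRVUtils.h"
using namespace llvm;

void demangleExamples() {
  // Itanium-mangled OpenCL builtin: demangled with parameter types,
  // e.g. "get_global_id(unsigned int)".
  std::string A = mayBeOclOrSpirvBuiltin("_Z13get_global_idj");

  // Non-mangled OCL builtin on the allow-list: returned verbatim.
  std::string B = mayBeOclOrSpirvBuiltin("__enqueue_kernel_basic");

  // Anything else yields an empty string, i.e. "not a builtin".
  std::string C = mayBeOclOrSpirvBuiltin("my_helper");
  (void)A; (void)B; (void)C;
}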
Index: llvm/lib/Target/SPIRV/SPIRVUtils.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/Demangle/Demangle.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 
 namespace llvm {
@@ -238,4 +239,96 @@
 Type *getMDOperandAsType(const MDNode *N, unsigned I) {
   return cast<ValueAsMetadata>(N->getOperand(I))->getType();
 }
+
+// The set of names is borrowed from the SPIR-V translator.
+// TODO: may be implemented in SPIRVBuiltins.td.
+static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) {
+  return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" ||
+         MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" ||
+         MangledName == "write_pipe_4" || MangledName == "read_pipe_4" ||
+         MangledName == "reserve_write_pipe" ||
+         MangledName == "reserve_read_pipe" ||
+         MangledName == "commit_write_pipe" ||
+         MangledName == "commit_read_pipe" ||
+         MangledName == "work_group_reserve_write_pipe" ||
+         MangledName == "work_group_reserve_read_pipe" ||
+         MangledName == "work_group_commit_write_pipe" ||
+         MangledName == "work_group_commit_read_pipe" ||
+         MangledName == "get_pipe_num_packets_ro" ||
+         MangledName == "get_pipe_max_packets_ro" ||
+         MangledName == "get_pipe_num_packets_wo" ||
+         MangledName == "get_pipe_max_packets_wo" ||
+         MangledName == "sub_group_reserve_write_pipe" ||
+         MangledName == "sub_group_reserve_read_pipe" ||
+         MangledName == "sub_group_commit_write_pipe" ||
+         MangledName == "sub_group_commit_read_pipe" ||
+         MangledName == "to_global" || MangledName == "to_local" ||
+         MangledName == "to_private";
+}
+
+static bool isEnqueueKernelBI(const StringRef MangledName) {
+  return MangledName == "__enqueue_kernel_basic" ||
+         MangledName == "__enqueue_kernel_basic_events" ||
+         MangledName == "__enqueue_kernel_varargs" ||
+         MangledName == "__enqueue_kernel_events_varargs";
+}
+
+static bool isKernelQueryBI(const StringRef MangledName) {
+  return MangledName == "__get_kernel_work_group_size_impl" ||
+         MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" ||
+         MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" ||
+         MangledName == "__get_kernel_preferred_work_group_size_multiple_impl";
+}
+
+static bool isNonMangledOCLBuiltin(StringRef Name) {
+  if (!Name.startswith("__"))
+    return false;
+
+  return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) ||
+         isPipeOrAddressSpaceCastBI(Name.drop_front(2)) ||
+         Name == "__translate_sampler_initializer";
+}
+
+std::string mayBeOclOrSpirvBuiltin(StringRef Name) {
+  bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name);
+  bool IsNonMangledSPIRV = Name.startswith("__spirv_");
+  bool IsMangled = Name.startswith("_Z");
+
+  if (!IsNonMangledOCL && !IsNonMangledSPIRV && !IsMangled)
+    return std::string();
+
+  // Try to use the itanium demangler.
+  size_t n;
+  int Status;
+  char *DemangledName = itaniumDemangle(Name.data(), nullptr, &n, &Status);
+
+  if (Status == demangle_success) {
+    std::string Result = DemangledName;
+    free(DemangledName);
+    return Result;
+  }
+  free(DemangledName);
+
+  // Otherwise use simple demangling to return the function name.
+  if (IsNonMangledOCL || IsNonMangledSPIRV)
+    return Name.str();
+
+  // Here we check for C++ mangling; an explicit check of the source language
+  // may be needed. OpenCL C++ built-ins are declared in the cl namespace.
+  // TODO: consider using the 'St' abbreviation for the cl namespace mangling,
+  // similar to ::std:: in C++.
+  size_t Start, Len = 0;
+  size_t DemangledNameLenStart = 2;
+  if (Name.startswith("_ZN")) {
+    // Skip CV and ref qualifiers.
+    size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3);
+    // All built-ins are in the ::cl:: namespace.
+    if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv")
+      return std::string();
+    DemangledNameLenStart = NameSpaceStart + 11;
+  }
+  Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
+  Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart)
+      .getAsInteger(10, Len);
+  return Name.substr(Start, Len).str();
+}
 } // namespace llvm
Index: llvm/test/CodeGen/SPIRV/builtin_vars-decorate.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/SPIRV/builtin_vars-decorate.ll
@@ -0,0 +1,59 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+
+; CHECK: OpName %[[#WD:]] "__spirv_BuiltInWorkDim"
+; CHECK: OpName %[[#GS:]] "__spirv_BuiltInGlobalSize"
+; CHECK: OpName %[[#GII:]] "__spirv_BuiltInGlobalInvocationId"
+; CHECK: OpName %[[#WS:]] "__spirv_BuiltInWorkgroupSize"
+; CHECK: OpName %[[#EWS:]] "__spirv_BuiltInEnqueuedWorkgroupSize"
+; CHECK: OpName %[[#LLI:]] "__spirv_BuiltInLocalInvocationId"
+; CHECK: OpName %[[#NW:]] "__spirv_BuiltInNumWorkgroups"
+; CHECK: OpName %[[#WI:]] "__spirv_BuiltInWorkgroupId"
+; CHECK: OpName %[[#GO:]] "__spirv_BuiltInGlobalOffset"
+; CHECK: OpName %[[#GLI:]] "__spirv_BuiltInGlobalLinearId"
+; CHECK: OpName %[[#LLII:]] "__spirv_BuiltInLocalInvocationIndex"
+; CHECK: OpName %[[#SS:]] "__spirv_BuiltInSubgroupSize"
+; CHECK: OpName %[[#SMS:]] "__spirv_BuiltInSubgroupMaxSize"
+; CHECK: OpName %[[#NS:]] "__spirv_BuiltInNumSubgroups"
+; CHECK: OpName %[[#NES:]] "__spirv_BuiltInNumEnqueuedSubgroups"
+; CHECK: OpName %[[#SI:]] "__spirv_BuiltInSubgroupId"
+; CHECK: OpName %[[#SLII:]] "__spirv_BuiltInSubgroupLocalInvocationId"
+
+; CHECK-DAG: OpDecorate %[[#NW]] BuiltIn NumWorkgroups
+; CHECK-DAG: OpDecorate %[[#WS]] BuiltIn WorkgroupSize
+; CHECK-DAG: OpDecorate %[[#WI]] BuiltIn WorkgroupId
+; CHECK-DAG: OpDecorate %[[#LLI]] BuiltIn LocalInvocationId
+; CHECK-DAG: OpDecorate %[[#GII]] BuiltIn GlobalInvocationId
+; CHECK-DAG: OpDecorate %[[#LLII]] BuiltIn LocalInvocationIndex
+; CHECK-DAG: OpDecorate %[[#WD]] BuiltIn WorkDim
+; CHECK-DAG: OpDecorate %[[#GS]] BuiltIn GlobalSize
+; CHECK-DAG: OpDecorate %[[#EWS]] BuiltIn EnqueuedWorkgroupSize
+; CHECK-DAG: OpDecorate %[[#GO]] BuiltIn GlobalOffset
+; CHECK-DAG: OpDecorate %[[#GLI]] BuiltIn GlobalLinearId
+; CHECK-DAG: OpDecorate %[[#SS]] BuiltIn SubgroupSize
+; CHECK-DAG: OpDecorate %[[#SMS]] BuiltIn SubgroupMaxSize
+; CHECK-DAG: OpDecorate %[[#NS]] BuiltIn NumSubgroups
+; CHECK-DAG: OpDecorate %[[#NES]]
BuiltIn NumEnqueuedSubgroups +; CHECK-DAG: OpDecorate %[[#SI]] BuiltIn SubgroupId +; CHECK-DAG: OpDecorate %[[#SLII]] BuiltIn SubgroupLocalInvocationId +@__spirv_BuiltInWorkDim = external addrspace(1) global i32 +@__spirv_BuiltInGlobalSize = external addrspace(1) global <3 x i32> +@__spirv_BuiltInGlobalInvocationId = external addrspace(1) global <3 x i32> +@__spirv_BuiltInWorkgroupSize = external addrspace(1) global <3 x i32> +@__spirv_BuiltInEnqueuedWorkgroupSize = external addrspace(1) global <3 x i32> +@__spirv_BuiltInLocalInvocationId = external addrspace(1) global <3 x i32> +@__spirv_BuiltInNumWorkgroups = external addrspace(1) global <3 x i32> +@__spirv_BuiltInWorkgroupId = external addrspace(1) global <3 x i32> +@__spirv_BuiltInGlobalOffset = external addrspace(1) global <3 x i32> +@__spirv_BuiltInGlobalLinearId = external addrspace(1) global i32 +@__spirv_BuiltInLocalInvocationIndex = external addrspace(1) global i32 +@__spirv_BuiltInSubgroupSize = external addrspace(1) global i32 +@__spirv_BuiltInSubgroupMaxSize = external addrspace(1) global i32 +@__spirv_BuiltInNumSubgroups = external addrspace(1) global i32 +@__spirv_BuiltInNumEnqueuedSubgroups = external addrspace(1) global i32 +@__spirv_BuiltInSubgroupId = external addrspace(1) global i32 +@__spirv_BuiltInSubgroupLocalInvocationId = external addrspace(1) global i32 + +define spir_kernel void @_Z1wv() { +entry: + ret void +} Index: llvm/test/CodeGen/SPIRV/capability-Int64Atomics.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/SPIRV/capability-Int64Atomics.ll @@ -0,0 +1,19 @@ +; OpenCL C source: +; #pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable +; #pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable +; +; void foo (volatile atomic_long *object, long desired) { +; atomic_fetch_xor(object, desired); +;} + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s + +; CHECK: OpCapability Int64Atomics + +define spir_func void @foo(i64 addrspace(4)* %object, i64 %desired) { +entry: + %call = tail call spir_func i64 @_Z16atomic_fetch_xorPVU3AS4U7_Atomicll(i64 addrspace(4)* %object, i64 %desired) + ret void +} + +declare spir_func i64 @_Z16atomic_fetch_xorPVU3AS4U7_Atomicll(i64 addrspace(4)*, i64) Index: llvm/test/CodeGen/SPIRV/empty-module.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/SPIRV/empty-module.ll @@ -0,0 +1,8 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s + +; CHECK-DAG: OpCapability Addresses +; CHECK-DAG: OpCapability Linkage +; CHECK-DAG: OpCapability Kernel +; CHECK: %1 = OpExtInstImport "OpenCL.std" +; CHECK: OpMemoryModel Physical64 OpenCL +; CHECK: OpSource Unknown 0 Index: llvm/test/CodeGen/SPIRV/spirv-tools-dis.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/SPIRV/spirv-tools-dis.ll @@ -0,0 +1,13 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s + +; CHECK: %{{[0-9]+}} = OpExtInstImport "OpenCL.std" +; CHECK: %{{[0-9]+}} = OpTypeInt 32 0 + +define spir_kernel void @foo(i32 addrspace(1)* %a) { +entry: + %a.addr = alloca i32 addrspace(1)*, align 4 + store i32 addrspace(1)* %a, i32 addrspace(1)** %a.addr, align 4 + %0 = load i32 addrspace(1)*, i32 addrspace(1)** %a.addr, align 4 + store i32 0, i32 addrspace(1)* %0, align 4 + ret void +} Index: llvm/test/CodeGen/SPIRV/transcoding/builtin_calls.ll 
=================================================================== --- /dev/null +++ llvm/test/CodeGen/SPIRV/transcoding/builtin_calls.ll @@ -0,0 +1,16 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV + +; CHECK-SPIRV-DAG: OpDecorate %[[Id:[0-9]+]] BuiltIn GlobalInvocationId +; CHECK-SPIRV-DAG: OpDecorate %[[Id:[0-9]+]] BuiltIn GlobalLinearId +; CHECK-SPIRV: %[[Id:[0-9]+]] = OpVariable %{{[0-9]+}} +; CHECK-SPIRV: %[[Id:[0-9]+]] = OpVariable %{{[0-9]+}} + +define spir_kernel void @f(){ +entry: + %0 = call spir_func i32 @_Z29__spirv_BuiltInGlobalLinearIdv() + %1 = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) + ret void +} + +declare spir_func i32 @_Z29__spirv_BuiltInGlobalLinearIdv() +declare spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32)