Index: llvm/lib/Target/SPIRV/CMakeLists.txt =================================================================== --- llvm/lib/Target/SPIRV/CMakeLists.txt +++ llvm/lib/Target/SPIRV/CMakeLists.txt @@ -15,13 +15,17 @@ add_llvm_target(SPIRVCodeGen SPIRVAsmPrinter.cpp SPIRVCallLowering.cpp + SPIRVGlobalRegistry.cpp SPIRVInstrInfo.cpp + SPIRVInstructionSelector.cpp SPIRVISelLowering.cpp + SPIRVLegalizerInfo.cpp SPIRVMCInstLower.cpp SPIRVRegisterBankInfo.cpp SPIRVRegisterInfo.cpp SPIRVSubtarget.cpp SPIRVTargetMachine.cpp + SPIRVUtils.cpp LINK_COMPONENTS Analysis Index: llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt =================================================================== --- llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt +++ llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt @@ -1,4 +1,5 @@ add_llvm_component_library(LLVMSPIRVDesc + SPIRVBaseInfo.cpp SPIRVMCAsmInfo.cpp SPIRVMCTargetDesc.cpp SPIRVTargetStreamer.cpp Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h @@ -0,0 +1,739 @@ +//===-- SPIRVBaseInfo.h - Top level definitions for SPIRV ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains small standalone helper functions and enum definitions for +// the SPIRV target useful for the compiler back-end and the MC libraries. +// As such, it deliberately does not include references to LLVM core +// code gen types, passes, etc.. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+
+#include "llvm/ADT/StringRef.h"
+// std::string is the return type of the bit-flag name helpers declared below.
+#include <string>
+
+namespace llvm {
+namespace SPIRV {
+// Every enum in this namespace is followed by the declaration of its
+// get<Enum>Name() helper. Enumerator values are fixed numeric encodings and
+// are not contiguous, so they must not be used as dense array indices.
+enum class Capability : uint32_t {
+  Matrix = 0,
+  Shader = 1,
+  Geometry = 2,
+  Tessellation = 3,
+  Addresses = 4,
+  Linkage = 5,
+  Kernel = 6,
+  Vector16 = 7,
+  Float16Buffer = 8,
+  Float16 = 9,
+  Float64 = 10,
+  Int64 = 11,
+  Int64Atomics = 12,
+  ImageBasic = 13,
+  ImageReadWrite = 14,
+  ImageMipmap = 15,
+  Pipes = 17,
+  Groups = 18,
+  DeviceEnqueue = 19,
+  LiteralSampler = 20,
+  AtomicStorage = 21,
+  Int16 = 22,
+  TessellationPointSize = 23,
+  GeometryPointSize = 24,
+  ImageGatherExtended = 25,
+  StorageImageMultisample = 27,
+  UniformBufferArrayDynamicIndexing = 28,
+  // NOTE(review): "Dymnamic" looks like a typo for "Dynamic". The spelling is
+  // kept because SPIRVBaseInfo.cpp refers to this enumerator by name.
+  SampledImageArrayDymnamicIndexing = 29,
+  ClipDistance = 32,
+  CullDistance = 33,
+  ImageCubeArray = 34,
+  SampleRateShading = 35,
+  ImageRect = 36,
+  SampledRect = 37,
+  GenericPointer = 38,
+  Int8 = 39,
+  InputAttachment = 40,
+  SparseResidency = 41,
+  MinLod = 42,
+  Sampled1D = 43,
+  Image1D = 44,
+  SampledCubeArray = 45,
+  SampledBuffer = 46,
+  ImageBuffer = 47,
+  ImageMSArray = 48,
+  StorageImageExtendedFormats = 49,
+  ImageQuery = 50,
+  DerivativeControl = 51,
+  InterpolationFunction = 52,
+  TransformFeedback = 53,
+  GeometryStreams = 54,
+  StorageImageReadWithoutFormat = 55,
+  StorageImageWriteWithoutFormat = 56,
+  MultiViewport = 57,
+  SubgroupDispatch = 58,
+  NamedBarrier = 59,
+  PipeStorage = 60,
+  GroupNonUniform = 61,
+  GroupNonUniformVote = 62,
+  GroupNonUniformArithmetic = 63,
+  GroupNonUniformBallot = 64,
+  GroupNonUniformShuffle = 65,
+  GroupNonUniformShuffleRelative = 66,
+  GroupNonUniformClustered = 67,
+  GroupNonUniformQuad = 68,
+  SubgroupBallotKHR = 4423,
+  DrawParameters = 4427,
+  SubgroupVoteKHR = 4431,
+  StorageBuffer16BitAccess = 4433,
+  StorageUniform16 = 4434,
+  StoragePushConstant16 = 4435,
+  StorageInputOutput16 = 4436,
+  DeviceGroup = 4437,
+  MultiView = 4439,
+  VariablePointersStorageBuffer = 4441,
+  VariablePointers = 4442,
+  AtomicStorageOps = 4445,
+  SampleMaskPostDepthCoverage = 4447,
+  StorageBuffer8BitAccess = 4448,
+  UniformAndStorageBuffer8BitAccess = 4449,
+  StoragePushConstant8 = 4450,
+  DenormPreserve = 4464,
+  DenormFlushToZero = 4465,
+  SignedZeroInfNanPreserve = 4466,
+  RoundingModeRTE = 4467,
+  RoundingModeRTZ = 4468,
+  Float16ImageAMD = 5008,
+  ImageGatherBiasLodAMD = 5009,
+  FragmentMaskAMD = 5010,
+  StencilExportEXT = 5013,
+  ImageReadWriteLodAMD = 5015,
+  SampleMaskOverrideCoverageNV = 5249,
+  GeometryShaderPassthroughNV = 5251,
+  ShaderViewportIndexLayerEXT = 5254,
+  ShaderViewportMaskNV = 5255,
+  ShaderStereoViewNV = 5259,
+  PerViewAttributesNV = 5260,
+  FragmentFullyCoveredEXT = 5265,
+  MeshShadingNV = 5266,
+  ShaderNonUniformEXT = 5301,
+  RuntimeDescriptorArrayEXT = 5302,
+  InputAttachmentArrayDynamicIndexingEXT = 5303,
+  UniformTexelBufferArrayDynamicIndexingEXT = 5304,
+  StorageTexelBufferArrayDynamicIndexingEXT = 5305,
+  UniformBufferArrayNonUniformIndexingEXT = 5306,
+  SampledImageArrayNonUniformIndexingEXT = 5307,
+  StorageBufferArrayNonUniformIndexingEXT = 5308,
+  StorageImageArrayNonUniformIndexingEXT = 5309,
+  InputAttachmentArrayNonUniformIndexingEXT = 5310,
+  UniformTexelBufferArrayNonUniformIndexingEXT = 5311,
+  StorageTexelBufferArrayNonUniformIndexingEXT = 5312,
+  RayTracingNV = 5340,
+  SubgroupShuffleINTEL = 5568,
+  SubgroupBufferBlockIOINTEL = 5569,
+  SubgroupImageBlockIOINTEL = 5570,
+  SubgroupImageMediaBlockIOINTEL = 5579,
+  SubgroupAvcMotionEstimationINTEL = 5696,
+  SubgroupAvcMotionEstimationIntraINTEL = 5697,
+  SubgroupAvcMotionEstimationChromaINTEL = 5698,
+  GroupNonUniformPartitionedNV = 5297,
+  VulkanMemoryModelKHR = 5345,
+  VulkanMemoryModelDeviceScopeKHR = 5346,
+  ImageFootprintNV = 5282,
+  FragmentBarycentricNV = 5284,
+  ComputeDerivativeGroupQuadsNV = 5288,
+  ComputeDerivativeGroupLinearNV = 5350,
+  FragmentDensityEXT = 5291,
+  PhysicalStorageBufferAddressesEXT = 5347,
+  CooperativeMatrixNV = 5357,
+};
+StringRef getCapabilityName(Capability e);
+
+enum class SourceLanguage : uint32_t {
+  Unknown = 0,
+  ESSL = 1,
+  GLSL = 2,
+  OpenCL_C = 3,
+  OpenCL_CPP = 4,
+  HLSL = 5,
+};
+StringRef getSourceLanguageName(SourceLanguage e);
+
+enum class AddressingModel : uint32_t {
+  Logical = 0,
+  Physical32 = 1,
+  Physical64 = 2,
+  PhysicalStorageBuffer64EXT = 5348,
+};
+StringRef getAddressingModelName(AddressingModel e);
+
+enum class ExecutionModel : uint32_t {
+  Vertex = 0,
+  TessellationControl = 1,
+  TessellationEvaluation = 2,
+  Geometry = 3,
+  Fragment = 4,
+  GLCompute = 5,
+  Kernel = 6,
+  TaskNV = 5267,
+  MeshNV = 5268,
+  RayGenerationNV = 5313,
+  IntersectionNV = 5314,
+  AnyHitNV = 5315,
+  ClosestHitNV = 5316,
+  MissNV = 5317,
+  CallableNV = 5318,
+};
+StringRef getExecutionModelName(ExecutionModel e);
+
+enum class MemoryModel : uint32_t {
+  Simple = 0,
+  GLSL450 = 1,
+  OpenCL = 2,
+  VulkanKHR = 3,
+};
+StringRef getMemoryModelName(MemoryModel e);
+
+enum class ExecutionMode : uint32_t {
+  Invocations = 0,
+  SpacingEqual = 1,
+  SpacingFractionalEven = 2,
+  SpacingFractionalOdd = 3,
+  VertexOrderCw = 4,
+  VertexOrderCcw = 5,
+  PixelCenterInteger = 6,
+  OriginUpperLeft = 7,
+  OriginLowerLeft = 8,
+  EarlyFragmentTests = 9,
+  PointMode = 10,
+  Xfb = 11,
+  DepthReplacing = 12,
+  DepthGreater = 14,
+  DepthLess = 15,
+  DepthUnchanged = 16,
+  LocalSize = 17,
+  LocalSizeHint = 18,
+  InputPoints = 19,
+  InputLines = 20,
+  InputLinesAdjacency = 21,
+  Triangles = 22,
+  InputTrianglesAdjacency = 23,
+  Quads = 24,
+  Isolines = 25,
+  OutputVertices = 26,
+  OutputPoints = 27,
+  OutputLineStrip = 28,
+  OutputTriangleStrip = 29,
+  VecTypeHint = 30,
+  ContractionOff = 31,
+  Initializer = 33,
+  Finalizer = 34,
+  SubgroupSize = 35,
+  SubgroupsPerWorkgroup = 36,
+  SubgroupsPerWorkgroupId = 37,
+  LocalSizeId = 38,
+  LocalSizeHintId = 39,
+  PostDepthCoverage = 4446,
+  DenormPreserve = 4459,
+  DenormFlushToZero = 4460,
+  SignedZeroInfNanPreserve = 4461,
+  RoundingModeRTE = 4462,
+  RoundingModeRTZ = 4463,
+  StencilRefReplacingEXT = 5027,
+  OutputLinesNV = 5269,
+  DerivativeGroupQuadsNV = 5289,
+  DerivativeGroupLinearNV = 5290,
+  OutputTrianglesNV = 5298,
+};
+StringRef getExecutionModeName(ExecutionMode e);
+
+enum class StorageClass : uint32_t {
+  UniformConstant = 0,
+  Input = 1,
+  Uniform = 2,
+  Output = 3,
+  Workgroup = 4,
+  CrossWorkgroup = 5,
+  Private = 6,
+  Function = 7,
+  Generic = 8,
+  PushConstant = 9,
+  AtomicCounter = 10,
+  Image = 11,
+  StorageBuffer = 12,
+  CallableDataNV = 5328,
+  IncomingCallableDataNV = 5329,
+  RayPayloadNV = 5338,
+  HitAttributeNV = 5339,
+  IncomingRayPayloadNV = 5342,
+  ShaderRecordBufferNV = 5343,
+  PhysicalStorageBufferEXT = 5349,
+};
+StringRef getStorageClassName(StorageClass e);
+
+enum class Dim : uint32_t {
+  DIM_1D = 0,
+  DIM_2D = 1,
+  DIM_3D = 2,
+  DIM_Cube = 3,
+  DIM_Rect = 4,
+  DIM_Buffer = 5,
+  DIM_SubpassData = 6,
+};
+StringRef getDimName(Dim e);
+
+enum class SamplerAddressingMode : uint32_t {
+  None = 0,
+  ClampToEdge = 1,
+  Clamp = 2,
+  Repeat = 3,
+  RepeatMirrored = 4,
+};
+StringRef getSamplerAddressingModeName(SamplerAddressingMode e);
+
+enum class SamplerFilterMode : uint32_t {
+  Nearest = 0,
+  Linear = 1,
+};
+StringRef getSamplerFilterModeName(SamplerFilterMode e);
+
+enum class ImageFormat : uint32_t {
+  Unknown = 0,
+  Rgba32f = 1,
+  Rgba16f = 2,
+  R32f = 3,
+  Rgba8 = 4,
+  Rgba8Snorm = 5,
+  Rg32f = 6,
+  Rg16f = 7,
+  R11fG11fB10f = 8,
+  R16f = 9,
+  Rgba16 = 10,
+  Rgb10A2 = 11,
+  Rg16 = 12,
+  Rg8 = 13,
+  R16 = 14,
+  R8 = 15,
+  Rgba16Snorm = 16,
+  Rg16Snorm = 17,
+  Rg8Snorm = 18,
+  R16Snorm = 19,
+  R8Snorm = 20,
+  Rgba32i = 21,
+  Rgba16i = 22,
+  Rgba8i = 23,
+  R32i = 24,
+  Rg32i = 25,
+  Rg16i = 26,
+  Rg8i = 27,
+  R16i = 28,
+  R8i = 29,
+  Rgba32ui = 30,
+  Rgba16ui = 31,
+  Rgba8ui = 32,
+  R32ui = 33,
+  Rgb10a2ui = 34,
+  Rg32ui = 35,
+  Rg16ui = 36,
+  Rg8ui = 37,
+  R16ui = 38,
+  R8ui = 39,
+};
+StringRef getImageFormatName(ImageFormat e);
+
+enum class ImageChannelOrder : uint32_t {
+  R = 0,
+  A = 1,
+  RG = 2,
+  RA = 3,
+  RGB = 4,
+  RGBA = 5,
+  BGRA = 6,
+  ARGB = 7,
+  Intensity = 8,
+  Luminance = 9,
+  Rx = 10,
+  RGx = 11,
+  RGBx = 12,
+  Depth = 13,
+  DepthStencil = 14,
+  sRGB = 15,
+  sRGBx = 16,
+  sRGBA = 17,
+  sBGRA = 18,
+  ABGR = 19,
+};
+StringRef getImageChannelOrderName(ImageChannelOrder e);
+
+enum class ImageChannelDataType : uint32_t {
+  SnormInt8 = 0,
+  SnormInt16 = 1,
+  UnormInt8 = 2,
+  UnormInt16 = 3,
+  UnormShort565 = 4,
+  UnormShort555 = 5,
+  UnormInt101010 = 6,
+  SignedInt8 = 7,
+  SignedInt16 = 8,
+  SignedInt32 = 9,
+  UnsignedInt8 = 10,
+  UnsignedInt16 = 11,
+  // NOTE(review): "Unsigend" looks like a typo for "Unsigned". The spelling is
+  // kept so that existing users of the name keep compiling.
+  UnsigendInt32 = 12,
+  HalfFloat = 13,
+  Float = 14,
+  UnormInt24 = 15,
+  UnormInt101010_2 = 16,
+};
+StringRef getImageChannelDataTypeName(ImageChannelDataType e);
+
+// Bit-flag enum: every non-zero enumerator is a distinct single bit. Its name
+// helper takes a raw uint32_t and returns std::string instead of StringRef,
+// presumably so that OR-ed combinations can be described (confirm in the
+// .cpp). The same applies to FPFastMathMode, SelectionControl, LoopControl,
+// FunctionControl, MemorySemantics and MemoryOperand below.
+enum class ImageOperand : uint32_t {
+  None = 0x0,
+  Bias = 0x1,
+  Lod = 0x2,
+  Grad = 0x4,
+  ConstOffset = 0x8,
+  Offset = 0x10,
+  ConstOffsets = 0x20,
+  Sample = 0x40,
+  MinLod = 0x80,
+  MakeTexelAvailableKHR = 0x100,
+  MakeTexelVisibleKHR = 0x200,
+  NonPrivateTexelKHR = 0x400,
+  VolatileTexelKHR = 0x800,
+  SignExtend = 0x1000,
+  ZeroExtend = 0x2000,
+};
+std::string getImageOperandName(uint32_t e);
+
+enum class FPFastMathMode : uint32_t {
+  None = 0x0,
+  NotNaN = 0x1,
+  NotInf = 0x2,
+  NSZ = 0x4,
+  AllowRecip = 0x8,
+  Fast = 0x10,
+};
+std::string getFPFastMathModeName(uint32_t e);
+
+enum class FPRoundingMode : uint32_t {
+  RTE = 0,
+  RTZ = 1,
+  RTP = 2,
+  RTN = 3,
+};
+StringRef getFPRoundingModeName(FPRoundingMode e);
+
+enum class LinkageType : uint32_t {
+  Export = 0,
+  Import = 1,
+};
+StringRef getLinkageTypeName(LinkageType e);
+
+enum class AccessQualifier : uint32_t {
+  ReadOnly = 0,
+  WriteOnly = 1,
+  ReadWrite = 2,
+};
+StringRef getAccessQualifierName(AccessQualifier e);
+
+enum class FunctionParameterAttribute : uint32_t {
+  Zext = 0,
+  Sext = 1,
+  ByVal = 2,
+  Sret = 3,
+  NoAlias = 4,
+  NoCapture = 5,
+  NoWrite = 6,
+  NoReadWrite = 7,
+};
+StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e);
+
+enum class Decoration : uint32_t {
+  RelaxedPrecision = 0,
+  SpecId = 1,
+  Block = 2,
+  BufferBlock = 3,
+  RowMajor = 4,
+  ColMajor = 5,
+  ArrayStride = 6,
+  MatrixStride = 7,
+  GLSLShared = 8,
+  GLSLPacked = 9,
+  CPacked = 10,
+  BuiltIn = 11,
+  NoPerspective = 13,
+  Flat = 14,
+  Patch = 15,
+  Centroid = 16,
+  Sample = 17,
+  Invariant = 18,
+  Restrict = 19,
+  Aliased = 20,
+  Volatile = 21,
+  Constant = 22,
+  Coherent = 23,
+  NonWritable = 24,
+  NonReadable = 25,
+  Uniform = 26,
+  UniformId = 27,
+  SaturatedConversion = 28,
+  Stream = 29,
+  Location = 30,
+  Component = 31,
+  Index = 32,
+  Binding = 33,
+  DescriptorSet = 34,
+  Offset = 35,
+  XfbBuffer = 36,
+  XfbStride = 37,
+  FuncParamAttr = 38,
+  FPRoundingMode = 39,
+  FPFastMathMode = 40,
+  LinkageAttributes = 41,
+  NoContraction = 42,
+  InputAttachmentIndex = 43,
+  Alignment = 44,
+  MaxByteOffset = 45,
+  AlignmentId = 46,
+  MaxByteOffsetId = 47,
+  NoSignedWrap = 4469,
+  NoUnsignedWrap = 4470,
+  ExplicitInterpAMD = 4999,
+  OverrideCoverageNV = 5248,
+  PassthroughNV = 5250,
+  ViewportRelativeNV = 5252,
+  SecondaryViewportRelativeNV = 5256,
+  PerPrimitiveNV = 5271,
+  PerViewNV = 5272,
+  PerVertexNV = 5273,
+  NonUniformEXT = 5300,
+  CountBuffer = 5634,
+  UserSemantic = 5635,
+  RestrictPointerEXT = 5355,
+  AliasedPointerEXT = 5356,
+};
+StringRef getDecorationName(Decoration e);
+
+enum class BuiltIn : uint32_t {
+  Position = 0,
+  PointSize = 1,
+  ClipDistance = 3,
+  CullDistance = 4,
+  VertexId = 5,
+  InstanceId = 6,
+  PrimitiveId = 7,
+  InvocationId = 8,
+  Layer = 9,
+  ViewportIndex = 10,
+  TessLevelOuter = 11,
+  TessLevelInner = 12,
+  TessCoord = 13,
+  PatchVertices = 14,
+  FragCoord = 15,
+  PointCoord = 16,
+  FrontFacing = 17,
+  SampleId = 18,
+  SamplePosition = 19,
+  SampleMask = 20,
+  FragDepth = 22,
+  HelperInvocation = 23,
+  NumWorkgroups = 24,
+  WorkgroupSize = 25,
+  WorkgroupId = 26,
+  LocalInvocationId = 27,
+  GlobalInvocationId = 28,
+  LocalInvocationIndex = 29,
+  WorkDim = 30,
+  GlobalSize = 31,
+  EnqueuedWorkgroupSize = 32,
+  GlobalOffset = 33,
+  GlobalLinearId = 34,
+  SubgroupSize = 36,
+  SubgroupMaxSize = 37,
+  NumSubgroups = 38,
+  NumEnqueuedSubgroups = 39,
+  SubgroupId = 40,
+  SubgroupLocalInvocationId = 41,
+  VertexIndex = 42,
+  InstanceIndex = 43,
+  SubgroupEqMask = 4416,
+  SubgroupGeMask = 4417,
+  SubgroupGtMask = 4418,
+  SubgroupLeMask = 4419,
+  SubgroupLtMask = 4420,
+  BaseVertex = 4424,
+  BaseInstance = 4425,
+  DrawIndex = 4426,
+  DeviceIndex = 4438,
+  ViewIndex = 4440,
+  BaryCoordNoPerspAMD = 4492,
+  BaryCoordNoPerspCentroidAMD = 4493,
+  BaryCoordNoPerspSampleAMD = 4494,
+  BaryCoordSmoothAMD = 4495,
+  BaryCoordSmoothCentroid = 4496,
+  BaryCoordSmoothSample = 4497,
+  BaryCoordPullModel = 4498,
+  FragStencilRefEXT = 5014,
+  ViewportMaskNV = 5253,
+  SecondaryPositionNV = 5257,
+  SecondaryViewportMaskNV = 5258,
+  PositionPerViewNV = 5261,
+  ViewportMaskPerViewNV = 5262,
+  FullyCoveredEXT = 5264,
+  TaskCountNV = 5274,
+  PrimitiveCountNV = 5275,
+  PrimitiveIndicesNV = 5276,
+  ClipDistancePerViewNV = 5277,
+  CullDistancePerViewNV = 5278,
+  LayerPerViewNV = 5279,
+  MeshViewCountNV = 5280,
+  MeshViewIndices = 5281,
+  BaryCoordNV = 5286,
+  BaryCoordNoPerspNV = 5287,
+  FragSizeEXT = 5292,
+  FragInvocationCountEXT = 5293,
+  LaunchIdNV = 5319,
+  LaunchSizeNV = 5320,
+  WorldRayOriginNV = 5321,
+  WorldRayDirectionNV = 5322,
+  ObjectRayOriginNV = 5323,
+  ObjectRayDirectionNV = 5324,
+  RayTminNV = 5325,
+  RayTmaxNV = 5326,
+  InstanceCustomIndexNV = 5327,
+  ObjectToWorldNV = 5330,
+  WorldToObjectNV = 5331,
+  HitTNV = 5332,
+  HitKindNV = 5333,
+  IncomingRayFlagsNV = 5351,
+};
+StringRef getBuiltInName(BuiltIn e);
+
+enum class SelectionControl : uint32_t {
+  None = 0x0,
+  Flatten = 0x1,
+  DontFlatten = 0x2,
+};
+std::string getSelectionControlName(uint32_t e);
+
+enum class LoopControl : uint32_t {
+  None = 0x0,
+  Unroll = 0x1,
+  DontUnroll = 0x2,
+  DependencyInfinite = 0x4,
+  DependencyLength = 0x8,
+  MinIterations = 0x10,
+  MaxIterations = 0x20,
+  IterationMultiple = 0x40,
+  PeelCount = 0x80,
+  PartialCount = 0x100,
+};
+std::string getLoopControlName(uint32_t e);
+
+enum class FunctionControl : uint32_t {
+  None = 0x0,
+  Inline = 0x1,
+  DontInline = 0x2,
+  Pure = 0x4,
+  Const = 0x8,
+};
+std::string getFunctionControlName(uint32_t e);
+
+enum class MemorySemantics : uint32_t {
+  None = 0x0,
+  Acquire = 0x2,
+  Release = 0x4,
+  AcquireRelease = 0x8,
+  SequentiallyConsistent = 0x10,
+  UniformMemory = 0x40,
+  SubgroupMemory = 0x80,
+  WorkgroupMemory = 0x100,
+  CrossWorkgroupMemory = 0x200,
+  AtomicCounterMemory = 0x400,
+  ImageMemory = 0x800,
+  OutputMemoryKHR = 0x1000,
+  MakeAvailableKHR = 0x2000,
+  MakeVisibleKHR = 0x4000,
+};
+std::string getMemorySemanticsName(uint32_t e);
+
+enum class MemoryOperand : uint32_t {
+  None = 0x0,
+  Volatile = 0x1,
+  Aligned = 0x2,
+  Nontemporal = 0x4,
+  MakePointerAvailableKHR = 0x8,
+  MakePointerVisibleKHR = 0x10,
+  NonPrivatePointerKHR = 0x20,
+};
+std::string getMemoryOperandName(uint32_t e);
+
+enum class Scope : uint32_t {
+  CrossDevice = 0,
+  Device = 1,
+  Workgroup = 2,
+  Subgroup = 3,
+  Invocation = 4,
+  QueueFamilyKHR = 5,
+};
+StringRef getScopeName(Scope e);
+
+enum class GroupOperation : uint32_t {
+  Reduce = 0,
+  InclusiveScan = 1,
+  ExclusiveScan = 2,
+  ClusteredReduce = 3,
+  PartitionedReduceNV = 6,
+  PartitionedInclusiveScanNV = 7,
+  PartitionedExclusiveScanNV = 8,
+};
+StringRef getGroupOperationName(GroupOperation e);
+
+enum class KernelEnqueueFlags : uint32_t {
+  NoWait = 0,
+  WaitKernel = 1,
+  WaitWorkGroup = 2,
+};
+StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e);
+
+// NOTE(review): the values are written in hex like the bit-flag enums above,
+// yet the name helper takes the enum itself and returns StringRef.
+enum class KernelProfilingInfo : uint32_t {
+  None = 0x0,
+  CmdExecTime = 0x1,
+};
+StringRef getKernelProfilingInfoName(KernelProfilingInfo e);
+} // namespace SPIRV
+} // namespace llvm
+
+// Return a string representation of the operands
from startIndex onwards.
+// Templated to allow both MachineInstr and MCInst to use the same logic.
+//
+// The string is decoded from consecutive immediate operands: each immediate
+// is treated as one 32-bit word holding up to four characters, lowest byte
+// first. Decoding stops at the first zero byte, at the first operand that is
+// not an immediate, or when the operands run out, whichever comes first.
+// InstType only needs getNumOperands() and getOperand(i), and the operand
+// type needs isImm() and getImm().
+template <typename InstType>
+std::string getSPIRVStringOperand(const InstType &MI, unsigned StartIndex) {
+  std::string s; // Iteratively append to this string.
+
+  const unsigned NumOps = MI.getNumOperands();
+  for (unsigned i = StartIndex; i < NumOps; ++i) {
+    const auto &Op = MI.getOperand(i);
+    if (!Op.isImm()) // Stop if we hit a non-immediate (e.g. register) operand.
+      break;
+    assert((Op.getImm() >> 32) == 0 && "Imm operand should be i32 word");
+    const uint32_t Imm = Op.getImm(); // Each i32 word is up to 4 characters.
+    for (unsigned ShiftAmount = 0; ShiftAmount < 32; ShiftAmount += 8) {
+      char c = (Imm >> ShiftAmount) & 0xff;
+      if (c == 0) // A null-terminator character ends the whole string.
+        return s;
+      s += c; // Otherwise, append the character to the result string.
+    }
+  }
+  return s;
+}
+
+#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -0,0 +1,1591 @@
+//===-- SPIRVBaseInfo.cpp - Top level definitions for SPIRV ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains small standalone helper functions and enum definitions for
+// the SPIRV target useful for the compiler back-end and the MC libraries.
+// As such, it deliberately does not include references to LLVM core
+// code gen types, passes, etc..
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVBaseInfo.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+using namespace SPIRV;
+// Implement getEnumName(Enum e) helper functions.
+
+// Return the spelled-out name of a Capability enumerator, or "UNKNOWN_ENUM"
+// for a value that has no case below. Each returned string is the enumerator
+// identifier itself; the only exception is SampledImageArrayDymnamicIndexing,
+// whose identifier is misspelled in the header and whose string carries the
+// corrected spelling. The llvm_unreachable after the switch can never execute
+// because the default label returns.
+StringRef SPIRV::getCapabilityName(Capability e) {
+  switch (e) {
+  case Capability::Matrix:
+    return "Matrix";
+  case Capability::Shader:
+    return "Shader";
+  case Capability::Geometry:
+    return "Geometry";
+  case Capability::Tessellation:
+    return "Tessellation";
+  case Capability::Addresses:
+    return "Addresses";
+  case Capability::Linkage:
+    return "Linkage";
+  case Capability::Kernel:
+    return "Kernel";
+  case Capability::Vector16:
+    return "Vector16";
+  case Capability::Float16Buffer:
+    return "Float16Buffer";
+  case Capability::Float16:
+    return "Float16";
+  case Capability::Float64:
+    return "Float64";
+  case Capability::Int64:
+    return "Int64";
+  case Capability::Int64Atomics:
+    return "Int64Atomics";
+  case Capability::ImageBasic:
+    return "ImageBasic";
+  case Capability::ImageReadWrite:
+    return "ImageReadWrite";
+  case Capability::ImageMipmap:
+    return "ImageMipmap";
+  case Capability::Pipes:
+    return "Pipes";
+  case Capability::Groups:
+    return "Groups";
+  case Capability::DeviceEnqueue:
+    return "DeviceEnqueue";
+  case Capability::LiteralSampler:
+    return "LiteralSampler";
+  case Capability::AtomicStorage:
+    return "AtomicStorage";
+  case Capability::Int16:
+    return "Int16";
+  case Capability::TessellationPointSize:
+    return "TessellationPointSize";
+  case Capability::GeometryPointSize:
+    return "GeometryPointSize";
+  case Capability::ImageGatherExtended:
+    return "ImageGatherExtended";
+  case Capability::StorageImageMultisample:
+    return "StorageImageMultisample";
+  case Capability::UniformBufferArrayDynamicIndexing:
+    return "UniformBufferArrayDynamicIndexing";
+  case Capability::SampledImageArrayDymnamicIndexing:
+    return "SampledImageArrayDynamicIndexing";
+  case Capability::ClipDistance:
+    return "ClipDistance";
+  case Capability::CullDistance:
+    return "CullDistance";
+  case Capability::ImageCubeArray:
+    return "ImageCubeArray";
+  case Capability::SampleRateShading:
+    return "SampleRateShading";
+  case Capability::ImageRect:
+    return "ImageRect";
+  case Capability::SampledRect:
+    return "SampledRect";
+  case Capability::GenericPointer:
+    return "GenericPointer";
+  case Capability::Int8:
+    return "Int8";
+  case Capability::InputAttachment:
+    return "InputAttachment";
+  case Capability::SparseResidency:
+    return "SparseResidency";
+  case Capability::MinLod:
+    return "MinLod";
+  case Capability::Sampled1D:
+    return "Sampled1D";
+  case Capability::Image1D:
+    return "Image1D";
+  case Capability::SampledCubeArray:
+    return "SampledCubeArray";
+  case Capability::SampledBuffer:
+    return "SampledBuffer";
+  case Capability::ImageBuffer:
+    return "ImageBuffer";
+  case Capability::ImageMSArray:
+    return "ImageMSArray";
+  case Capability::StorageImageExtendedFormats:
+    return "StorageImageExtendedFormats";
+  case Capability::ImageQuery:
+    return "ImageQuery";
+  case Capability::DerivativeControl:
+    return "DerivativeControl";
+  case Capability::InterpolationFunction:
+    return "InterpolationFunction";
+  case Capability::TransformFeedback:
+    return "TransformFeedback";
+  case Capability::GeometryStreams:
+    return "GeometryStreams";
+  case Capability::StorageImageReadWithoutFormat:
+    return "StorageImageReadWithoutFormat";
+  case Capability::StorageImageWriteWithoutFormat:
+    return "StorageImageWriteWithoutFormat";
+  case Capability::MultiViewport:
+    return "MultiViewport";
+  case Capability::SubgroupDispatch:
+    return "SubgroupDispatch";
+  case Capability::NamedBarrier:
+    return "NamedBarrier";
+  case Capability::PipeStorage:
+    return "PipeStorage";
+  case Capability::GroupNonUniform:
+    return "GroupNonUniform";
+  case Capability::GroupNonUniformVote:
+    return "GroupNonUniformVote";
+  case Capability::GroupNonUniformArithmetic:
+    return "GroupNonUniformArithmetic";
+  case Capability::GroupNonUniformBallot:
+    return "GroupNonUniformBallot";
+  case Capability::GroupNonUniformShuffle:
+    return "GroupNonUniformShuffle";
+  case Capability::GroupNonUniformShuffleRelative:
+    return "GroupNonUniformShuffleRelative";
+  case Capability::GroupNonUniformClustered:
+    return "GroupNonUniformClustered";
+  case Capability::GroupNonUniformQuad:
+    return "GroupNonUniformQuad";
+  case Capability::SubgroupBallotKHR:
+    return "SubgroupBallotKHR";
+  case Capability::DrawParameters:
+    return "DrawParameters";
+  case Capability::SubgroupVoteKHR:
+    return "SubgroupVoteKHR";
+  case Capability::StorageBuffer16BitAccess:
+    return "StorageBuffer16BitAccess";
+  case Capability::StorageUniform16:
+    return "StorageUniform16";
+  case Capability::StoragePushConstant16:
+    return "StoragePushConstant16";
+  case Capability::StorageInputOutput16:
+    return "StorageInputOutput16";
+  case Capability::DeviceGroup:
+    return "DeviceGroup";
+  case Capability::MultiView:
+    return "MultiView";
+  case Capability::VariablePointersStorageBuffer:
+    return "VariablePointersStorageBuffer";
+  case Capability::VariablePointers:
+    return "VariablePointers";
+  case Capability::AtomicStorageOps:
+    return "AtomicStorageOps";
+  case Capability::SampleMaskPostDepthCoverage:
+    return "SampleMaskPostDepthCoverage";
+  case Capability::StorageBuffer8BitAccess:
+    return "StorageBuffer8BitAccess";
+  case Capability::UniformAndStorageBuffer8BitAccess:
+    return "UniformAndStorageBuffer8BitAccess";
+  case Capability::StoragePushConstant8:
+    return "StoragePushConstant8";
+  case Capability::DenormPreserve:
+    return "DenormPreserve";
+  case Capability::DenormFlushToZero:
+    return "DenormFlushToZero";
+  case Capability::SignedZeroInfNanPreserve:
+    return "SignedZeroInfNanPreserve";
+  case Capability::RoundingModeRTE:
+    return "RoundingModeRTE";
+  case Capability::RoundingModeRTZ:
+    return "RoundingModeRTZ";
+  case Capability::Float16ImageAMD:
+    return "Float16ImageAMD";
+  case Capability::ImageGatherBiasLodAMD:
+    return "ImageGatherBiasLodAMD";
+  case Capability::FragmentMaskAMD:
+    return "FragmentMaskAMD";
+  case Capability::StencilExportEXT:
+    return "StencilExportEXT";
+  case Capability::ImageReadWriteLodAMD:
+    return "ImageReadWriteLodAMD";
+  case Capability::SampleMaskOverrideCoverageNV:
+    return "SampleMaskOverrideCoverageNV";
+  case Capability::GeometryShaderPassthroughNV:
+    return "GeometryShaderPassthroughNV";
+  case Capability::ShaderViewportIndexLayerEXT:
+    return "ShaderViewportIndexLayerEXT";
+  case Capability::ShaderViewportMaskNV:
+    return "ShaderViewportMaskNV";
+  case Capability::ShaderStereoViewNV:
+    return "ShaderStereoViewNV";
+  case Capability::PerViewAttributesNV:
+    return "PerViewAttributesNV";
+  case Capability::FragmentFullyCoveredEXT:
+    return "FragmentFullyCoveredEXT";
+  case Capability::MeshShadingNV:
+    return "MeshShadingNV";
+  case Capability::ShaderNonUniformEXT:
+    return "ShaderNonUniformEXT";
+  case Capability::RuntimeDescriptorArrayEXT:
+    return "RuntimeDescriptorArrayEXT";
+  case Capability::InputAttachmentArrayDynamicIndexingEXT:
+    return "InputAttachmentArrayDynamicIndexingEXT";
+  case Capability::UniformTexelBufferArrayDynamicIndexingEXT:
+    return "UniformTexelBufferArrayDynamicIndexingEXT";
+  case Capability::StorageTexelBufferArrayDynamicIndexingEXT:
+    return "StorageTexelBufferArrayDynamicIndexingEXT";
+  case Capability::UniformBufferArrayNonUniformIndexingEXT:
+    return "UniformBufferArrayNonUniformIndexingEXT";
+  case Capability::SampledImageArrayNonUniformIndexingEXT:
+    return "SampledImageArrayNonUniformIndexingEXT";
+  case Capability::StorageBufferArrayNonUniformIndexingEXT:
+    return "StorageBufferArrayNonUniformIndexingEXT";
+  case Capability::StorageImageArrayNonUniformIndexingEXT:
+    return "StorageImageArrayNonUniformIndexingEXT";
+  case Capability::InputAttachmentArrayNonUniformIndexingEXT:
+    return "InputAttachmentArrayNonUniformIndexingEXT";
+  case Capability::UniformTexelBufferArrayNonUniformIndexingEXT:
+    return "UniformTexelBufferArrayNonUniformIndexingEXT";
+  case Capability::StorageTexelBufferArrayNonUniformIndexingEXT:
+    return "StorageTexelBufferArrayNonUniformIndexingEXT";
+  case Capability::RayTracingNV:
+    return "RayTracingNV";
+  case Capability::SubgroupShuffleINTEL:
+    return "SubgroupShuffleINTEL";
+  case Capability::SubgroupBufferBlockIOINTEL:
+    return "SubgroupBufferBlockIOINTEL";
+  case Capability::SubgroupImageBlockIOINTEL:
+    return "SubgroupImageBlockIOINTEL";
+  case Capability::SubgroupImageMediaBlockIOINTEL:
+    return "SubgroupImageMediaBlockIOINTEL";
+  case Capability::SubgroupAvcMotionEstimationINTEL:
+    return "SubgroupAvcMotionEstimationINTEL";
+  case Capability::SubgroupAvcMotionEstimationIntraINTEL:
+    return "SubgroupAvcMotionEstimationIntraINTEL";
+  case Capability::SubgroupAvcMotionEstimationChromaINTEL:
+    return "SubgroupAvcMotionEstimationChromaINTEL";
+  case Capability::GroupNonUniformPartitionedNV:
+    return "GroupNonUniformPartitionedNV";
+  case Capability::VulkanMemoryModelKHR:
+    return "VulkanMemoryModelKHR";
+  case Capability::VulkanMemoryModelDeviceScopeKHR:
+    return "VulkanMemoryModelDeviceScopeKHR";
+  case Capability::ImageFootprintNV:
+    return "ImageFootprintNV";
+  case Capability::FragmentBarycentricNV:
+    return "FragmentBarycentricNV";
+  case Capability::ComputeDerivativeGroupQuadsNV:
+    return "ComputeDerivativeGroupQuadsNV";
+  case Capability::ComputeDerivativeGroupLinearNV:
+    return "ComputeDerivativeGroupLinearNV";
+  case Capability::FragmentDensityEXT:
+    return "FragmentDensityEXT";
+  case Capability::PhysicalStorageBufferAddressesEXT:
+    return "PhysicalStorageBufferAddressesEXT";
+  case Capability::CooperativeMatrixNV:
+    return "CooperativeMatrixNV";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getSourceLanguageName(SourceLanguage e) {
+  switch (e) {
+  case SourceLanguage::Unknown:
+    return "Unknown";
+  case SourceLanguage::ESSL:
+    return "ESSL";
+  case SourceLanguage::GLSL:
+    return "GLSL";
+  case SourceLanguage::OpenCL_C:
+    return "OpenCL_C";
+  case SourceLanguage::OpenCL_CPP:
+    return "OpenCL_CPP";
+  case SourceLanguage::HLSL:
+    return "HLSL";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getExecutionModelName(ExecutionModel e) {
+  switch (e) {
+  case ExecutionModel::Vertex:
+    return "Vertex";
+  case ExecutionModel::TessellationControl:
+    return "TessellationControl";
+  case ExecutionModel::TessellationEvaluation:
+    return "TessellationEvaluation";
+  case ExecutionModel::Geometry:
+    return "Geometry";
+  case ExecutionModel::Fragment:
+    return "Fragment";
+  case ExecutionModel::GLCompute:
+    return "GLCompute";
+  case ExecutionModel::Kernel:
+    return "Kernel";
+  case ExecutionModel::TaskNV:
+    return "TaskNV";
+  case ExecutionModel::MeshNV:
+    return "MeshNV";
+  case ExecutionModel::RayGenerationNV:
+    return "RayGenerationNV";
+  case ExecutionModel::IntersectionNV:
+    return "IntersectionNV";
+  case ExecutionModel::AnyHitNV:
+    return "AnyHitNV";
+  case ExecutionModel::ClosestHitNV:
+    return "ClosestHitNV";
+  case ExecutionModel::MissNV:
+    return "MissNV";
+  case ExecutionModel::CallableNV:
+    return "CallableNV";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+// Return the spelled-out name of an AddressingModel enumerator, or
+// "UNKNOWN_ENUM" for a value that has no case below. Every string matches the
+// enumerator identifier.
+StringRef SPIRV::getAddressingModelName(AddressingModel e) {
+  switch (e) {
+  case AddressingModel::Logical:
+    return "Logical";
+  case AddressingModel::Physical32:
+    return "Physical32";
+  case AddressingModel::Physical64:
+    return "Physical64";
+  case AddressingModel::PhysicalStorageBuffer64EXT:
+    return "PhysicalStorageBuffer64EXT";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getMemoryModelName(MemoryModel e) {
+  switch (e) {
+  case MemoryModel::Simple:
+    return "Simple";
+  case MemoryModel::GLSL450:
+    return "GLSL450";
+  case MemoryModel::OpenCL:
+    return "OpenCL";
+  case MemoryModel::VulkanKHR:
+    return "VulkanKHR";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getExecutionModeName(ExecutionMode e) {
+  switch (e) {
+  case ExecutionMode::Invocations:
+    return "Invocations";
+  case ExecutionMode::SpacingEqual:
+    return "SpacingEqual";
+  case ExecutionMode::SpacingFractionalEven:
+    return "SpacingFractionalEven";
+  case ExecutionMode::SpacingFractionalOdd:
+    return "SpacingFractionalOdd";
+  case ExecutionMode::VertexOrderCw:
+    return "VertexOrderCw";
+  case ExecutionMode::VertexOrderCcw:
+    return "VertexOrderCcw";
+  case ExecutionMode::PixelCenterInteger:
+    return "PixelCenterInteger";
+  case ExecutionMode::OriginUpperLeft:
+    return "OriginUpperLeft";
+  case ExecutionMode::OriginLowerLeft:
+    return "OriginLowerLeft";
+  case ExecutionMode::EarlyFragmentTests:
+    return "EarlyFragmentTests";
+  case ExecutionMode::PointMode:
+    return "PointMode";
+  case ExecutionMode::Xfb:
+    return "Xfb";
+  case ExecutionMode::DepthReplacing:
+    return "DepthReplacing";
+  case ExecutionMode::DepthGreater:
+    return "DepthGreater";
+  case ExecutionMode::DepthLess:
+    return "DepthLess";
+  case ExecutionMode::DepthUnchanged:
+    return "DepthUnchanged";
+  case ExecutionMode::LocalSize:
+    return "LocalSize";
+  case ExecutionMode::LocalSizeHint:
+    return "LocalSizeHint";
+  case ExecutionMode::InputPoints:
+    return "InputPoints";
+  case ExecutionMode::InputLines:
+    return "InputLines";
+  case ExecutionMode::InputLinesAdjacency:
+    return "InputLinesAdjacency";
+  case ExecutionMode::Triangles:
+    return "Triangles";
+  case ExecutionMode::InputTrianglesAdjacency:
+    return "InputTrianglesAdjacency";
+  case ExecutionMode::Quads:
+    return "Quads";
+  case ExecutionMode::Isolines:
+    return "Isolines";
+  case ExecutionMode::OutputVertices:
+    return "OutputVertices";
+  case ExecutionMode::OutputPoints:
+    return "OutputPoints";
+  case ExecutionMode::OutputLineStrip:
+    return "OutputLineStrip";
+  case ExecutionMode::OutputTriangleStrip:
+    return "OutputTriangleStrip";
+  case ExecutionMode::VecTypeHint:
+    return "VecTypeHint";
+  case ExecutionMode::ContractionOff:
+    return "ContractionOff";
+  case ExecutionMode::Initializer:
+    return "Initializer";
+  case ExecutionMode::Finalizer:
+    return "Finalizer";
+  case ExecutionMode::SubgroupSize:
+    return "SubgroupSize";
+  case ExecutionMode::SubgroupsPerWorkgroup:
+    return "SubgroupsPerWorkgroup";
+  case ExecutionMode::SubgroupsPerWorkgroupId:
+    return "SubgroupsPerWorkgroupId";
+  case ExecutionMode::LocalSizeId:
+    return "LocalSizeId";
+  case ExecutionMode::LocalSizeHintId:
+    return "LocalSizeHintId";
+  case ExecutionMode::PostDepthCoverage:
+    return "PostDepthCoverage";
+  case ExecutionMode::DenormPreserve:
+    return "DenormPreserve";
+  case ExecutionMode::DenormFlushToZero:
+    return "DenormFlushToZero";
+  case ExecutionMode::SignedZeroInfNanPreserve:
+    return "SignedZeroInfNanPreserve";
+  case ExecutionMode::RoundingModeRTE:
+    return "RoundingModeRTE";
+  case ExecutionMode::RoundingModeRTZ:
+    return "RoundingModeRTZ";
+  case ExecutionMode::StencilRefReplacingEXT:
+    return "StencilRefReplacingEXT";
+  case ExecutionMode::OutputLinesNV:
+    return "OutputLinesNV";
+  case ExecutionMode::DerivativeGroupQuadsNV:
+    return "DerivativeGroupQuadsNV";
+  case ExecutionMode::DerivativeGroupLinearNV:
+    return "DerivativeGroupLinearNV";
+  case ExecutionMode::OutputTrianglesNV:
+    return "OutputTrianglesNV";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+// Return the spelled-out name of a StorageClass enumerator, or "UNKNOWN_ENUM"
+// for a value that has no case below. Every string matches the enumerator
+// identifier.
+StringRef SPIRV::getStorageClassName(StorageClass e) {
+  switch (e) {
+  case StorageClass::UniformConstant:
+    return "UniformConstant";
+  case StorageClass::Input:
+    return "Input";
+  case StorageClass::Uniform:
+    return "Uniform";
+  case StorageClass::Output:
+    return "Output";
+  case StorageClass::Workgroup:
+    return "Workgroup";
+  case StorageClass::CrossWorkgroup:
+    return "CrossWorkgroup";
+  case StorageClass::Private:
+    return "Private";
+  case StorageClass::Function:
+    return "Function";
+  case StorageClass::Generic:
+    return "Generic";
+  case StorageClass::PushConstant:
+    return "PushConstant";
+  case StorageClass::AtomicCounter:
+    return "AtomicCounter";
+  case StorageClass::Image:
+    return "Image";
+  case StorageClass::StorageBuffer:
+    return "StorageBuffer";
+  case StorageClass::CallableDataNV:
+    return "CallableDataNV";
+  case StorageClass::IncomingCallableDataNV:
+    return "IncomingCallableDataNV";
+  case StorageClass::RayPayloadNV:
+    return "RayPayloadNV";
+  case StorageClass::HitAttributeNV:
+    return "HitAttributeNV";
+  case StorageClass::IncomingRayPayloadNV:
+    return "IncomingRayPayloadNV";
+  case StorageClass::ShaderRecordBufferNV:
+    return "ShaderRecordBufferNV";
+  case StorageClass::PhysicalStorageBufferEXT:
+    return "PhysicalStorageBufferEXT";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getDimName(Dim dim) {
+  switch (dim) {
+  case Dim::DIM_1D:
+    return "1D";
+  case Dim::DIM_2D:
+    return "2D";
+  case Dim::DIM_3D:
+    return "3D";
+  case Dim::DIM_Cube:
+    return "Cube";
+  case Dim::DIM_Rect:
+    return "Rect";
+  case Dim::DIM_Buffer:
+    return "Buffer";
+  case Dim::DIM_SubpassData:
+    return "SubpassData";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getSamplerAddressingModeName(SamplerAddressingMode e) {
+  switch (e) {
+  case SamplerAddressingMode::None:
+    return "None";
+  case SamplerAddressingMode::ClampToEdge:
+    return "ClampToEdge";
+  case SamplerAddressingMode::Clamp:
+    return "Clamp";
+  case SamplerAddressingMode::Repeat:
+    return "Repeat";
+  case SamplerAddressingMode::RepeatMirrored:
+    return "RepeatMirrored";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getSamplerFilterModeName(SamplerFilterMode e) {
+  switch (e) {
+  case SamplerFilterMode::Nearest:
+    return "Nearest";
+  case SamplerFilterMode::Linear:
+    return "Linear";
+  default:
+    return "UNKNOWN_ENUM";
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef SPIRV::getImageFormatName(ImageFormat e) {
+  switch (e) {
+  case ImageFormat::Unknown:
+    return "Unknown";
+  case ImageFormat::Rgba32f:
+    return "Rgba32f";
+  case
ImageFormat::Rgba16f: + return "Rgba16f"; + case ImageFormat::R32f: + return "R32f"; + case ImageFormat::Rgba8: + return "Rgba8"; + case ImageFormat::Rgba8Snorm: + return "Rgba8Snorm"; + case ImageFormat::Rg32f: + return "Rg32f"; + case ImageFormat::Rg16f: + return "Rg16f"; + case ImageFormat::R11fG11fB10f: + return "R11fG11fB10f"; + case ImageFormat::R16f: + return "R16f"; + case ImageFormat::Rgba16: + return "Rgba16"; + case ImageFormat::Rgb10A2: + return "Rgb10A2"; + case ImageFormat::Rg16: + return "Rg16"; + case ImageFormat::Rg8: + return "Rg8"; + case ImageFormat::R16: + return "R16"; + case ImageFormat::R8: + return "R8"; + case ImageFormat::Rgba16Snorm: + return "Rgba16Snorm"; + case ImageFormat::Rg16Snorm: + return "Rg16Snorm"; + case ImageFormat::Rg8Snorm: + return "Rg8Snorm"; + case ImageFormat::R16Snorm: + return "R16Snorm"; + case ImageFormat::R8Snorm: + return "R8Snorm"; + case ImageFormat::Rgba32i: + return "Rgba32i"; + case ImageFormat::Rgba16i: + return "Rgba16i"; + case ImageFormat::Rgba8i: + return "Rgba8i"; + case ImageFormat::R32i: + return "R32i"; + case ImageFormat::Rg32i: + return "Rg32i"; + case ImageFormat::Rg16i: + return "Rg16i"; + case ImageFormat::Rg8i: + return "Rg8i"; + case ImageFormat::R16i: + return "R16i"; + case ImageFormat::R8i: + return "R8i"; + case ImageFormat::Rgba32ui: + return "Rgba32ui"; + case ImageFormat::Rgba16ui: + return "Rgba16ui"; + case ImageFormat::Rgba8ui: + return "Rgba8ui"; + case ImageFormat::R32ui: + return "R32ui"; + case ImageFormat::Rgb10a2ui: + return "Rgb10a2ui"; + case ImageFormat::Rg32ui: + return "Rg32ui"; + case ImageFormat::Rg16ui: + return "Rg16ui"; + case ImageFormat::Rg8ui: + return "Rg8ui"; + case ImageFormat::R16ui: + return "R16ui"; + case ImageFormat::R8ui: + return "R8ui"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getImageChannelOrderName(ImageChannelOrder e) { + switch (e) { + case ImageChannelOrder::R: + return "R"; + case 
ImageChannelOrder::A: + return "A"; + case ImageChannelOrder::RG: + return "RG"; + case ImageChannelOrder::RA: + return "RA"; + case ImageChannelOrder::RGB: + return "RGB"; + case ImageChannelOrder::RGBA: + return "RGBA"; + case ImageChannelOrder::BGRA: + return "BGRA"; + case ImageChannelOrder::ARGB: + return "ARGB"; + case ImageChannelOrder::Intensity: + return "Intensity"; + case ImageChannelOrder::Luminance: + return "Luminance"; + case ImageChannelOrder::Rx: + return "Rx"; + case ImageChannelOrder::RGx: + return "RGx"; + case ImageChannelOrder::RGBx: + return "RGBx"; + case ImageChannelOrder::Depth: + return "Depth"; + case ImageChannelOrder::DepthStencil: + return "DepthStencil"; + case ImageChannelOrder::sRGB: + return "sRGB"; + case ImageChannelOrder::sRGBx: + return "sRGBx"; + case ImageChannelOrder::sRGBA: + return "sRGBA"; + case ImageChannelOrder::sBGRA: + return "sBGRA"; + case ImageChannelOrder::ABGR: + return "ABGR"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getImageChannelDataTypeName(ImageChannelDataType e) { + switch (e) { + case ImageChannelDataType::SnormInt8: + return "SnormInt8"; + case ImageChannelDataType::SnormInt16: + return "SnormInt16"; + case ImageChannelDataType::UnormInt8: + return "UnormInt8"; + case ImageChannelDataType::UnormInt16: + return "UnormInt16"; + case ImageChannelDataType::UnormShort565: + return "UnormShort565"; + case ImageChannelDataType::UnormShort555: + return "UnormShort555"; + case ImageChannelDataType::UnormInt101010: + return "UnormInt101010"; + case ImageChannelDataType::SignedInt8: + return "SignedInt8"; + case ImageChannelDataType::SignedInt16: + return "SignedInt16"; + case ImageChannelDataType::SignedInt32: + return "SignedInt32"; + case ImageChannelDataType::UnsignedInt8: + return "UnsignedInt8"; + case ImageChannelDataType::UnsignedInt16: + return "UnsignedInt16"; + case ImageChannelDataType::UnsigendInt32: + return "UnsigendInt32"; + case 
ImageChannelDataType::HalfFloat: + return "HalfFloat"; + case ImageChannelDataType::Float: + return "Float"; + case ImageChannelDataType::UnormInt24: + return "UnormInt24"; + case ImageChannelDataType::UnormInt101010_2: + return "UnormInt101010_2"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +std::string SPIRV::getImageOperandName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(ImageOperand::None)) + return "None"; + if (e == static_cast(ImageOperand::Bias)) + return "Bias"; + if (e & static_cast(ImageOperand::Bias)) { + nameString += sep + "Bias"; + sep = "|"; + } + if (e == static_cast(ImageOperand::Lod)) + return "Lod"; + if (e & static_cast(ImageOperand::Lod)) { + nameString += sep + "Lod"; + sep = "|"; + } + if (e == static_cast(ImageOperand::Grad)) + return "Grad"; + if (e & static_cast(ImageOperand::Grad)) { + nameString += sep + "Grad"; + sep = "|"; + } + if (e == static_cast(ImageOperand::ConstOffset)) + return "ConstOffset"; + if (e & static_cast(ImageOperand::ConstOffset)) { + nameString += sep + "ConstOffset"; + sep = "|"; + } + if (e == static_cast(ImageOperand::Offset)) + return "Offset"; + if (e & static_cast(ImageOperand::Offset)) { + nameString += sep + "Offset"; + sep = "|"; + } + if (e == static_cast(ImageOperand::ConstOffsets)) + return "ConstOffsets"; + if (e & static_cast(ImageOperand::ConstOffsets)) { + nameString += sep + "ConstOffsets"; + sep = "|"; + } + if (e == static_cast(ImageOperand::Sample)) + return "Sample"; + if (e & static_cast(ImageOperand::Sample)) { + nameString += sep + "Sample"; + sep = "|"; + } + if (e == static_cast(ImageOperand::MinLod)) + return "MinLod"; + if (e & static_cast(ImageOperand::MinLod)) { + nameString += sep + "MinLod"; + sep = "|"; + } + if (e == static_cast(ImageOperand::MakeTexelAvailableKHR)) + return "MakeTexelAvailableKHR"; + if (e & static_cast(ImageOperand::MakeTexelAvailableKHR)) { + nameString += sep + 
"MakeTexelAvailableKHR"; + sep = "|"; + } + if (e == static_cast(ImageOperand::MakeTexelVisibleKHR)) + return "MakeTexelVisibleKHR"; + if (e & static_cast(ImageOperand::MakeTexelVisibleKHR)) { + nameString += sep + "MakeTexelVisibleKHR"; + sep = "|"; + } + if (e == static_cast(ImageOperand::NonPrivateTexelKHR)) + return "NonPrivateTexelKHR"; + if (e & static_cast(ImageOperand::NonPrivateTexelKHR)) { + nameString += sep + "NonPrivateTexelKHR"; + sep = "|"; + } + if (e == static_cast(ImageOperand::VolatileTexelKHR)) + return "VolatileTexelKHR"; + if (e & static_cast(ImageOperand::VolatileTexelKHR)) { + nameString += sep + "VolatileTexelKHR"; + sep = "|"; + } + if (e == static_cast(ImageOperand::SignExtend)) + return "SignExtend"; + if (e & static_cast(ImageOperand::SignExtend)) { + nameString += sep + "SignExtend"; + sep = "|"; + } + if (e == static_cast(ImageOperand::ZeroExtend)) + return "ZeroExtend"; + if (e & static_cast(ImageOperand::ZeroExtend)) { + nameString += sep + "ZeroExtend"; + sep = "|"; + }; + return nameString; +} + +std::string SPIRV::getFPFastMathModeName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(FPFastMathMode::None)) + return "None"; + if (e == static_cast(FPFastMathMode::NotNaN)) + return "NotNaN"; + if (e & static_cast(FPFastMathMode::NotNaN)) { + nameString += sep + "NotNaN"; + sep = "|"; + } + if (e == static_cast(FPFastMathMode::NotInf)) + return "NotInf"; + if (e & static_cast(FPFastMathMode::NotInf)) { + nameString += sep + "NotInf"; + sep = "|"; + } + if (e == static_cast(FPFastMathMode::NSZ)) + return "NSZ"; + if (e & static_cast(FPFastMathMode::NSZ)) { + nameString += sep + "NSZ"; + sep = "|"; + } + if (e == static_cast(FPFastMathMode::AllowRecip)) + return "AllowRecip"; + if (e & static_cast(FPFastMathMode::AllowRecip)) { + nameString += sep + "AllowRecip"; + sep = "|"; + } + if (e == static_cast(FPFastMathMode::Fast)) + return "Fast"; + if (e & static_cast(FPFastMathMode::Fast)) { + 
nameString += sep + "Fast"; + sep = "|"; + }; + return nameString; +} + +StringRef SPIRV::getFPRoundingModeName(FPRoundingMode e) { + switch (e) { + case FPRoundingMode::RTE: + return "RTE"; + case FPRoundingMode::RTZ: + return "RTZ"; + case FPRoundingMode::RTP: + return "RTP"; + case FPRoundingMode::RTN: + return "RTN"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getLinkageTypeName(LinkageType e) { + switch (e) { + case LinkageType::Export: + return "Export"; + case LinkageType::Import: + return "Import"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getAccessQualifierName(AccessQualifier e) { + switch (e) { + case AccessQualifier::ReadOnly: + return "ReadOnly"; + case AccessQualifier::WriteOnly: + return "WriteOnly"; + case AccessQualifier::ReadWrite: + return "ReadWrite"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef +SPIRV::getFunctionParameterAttributeName(FunctionParameterAttribute e) { + switch (e) { + case FunctionParameterAttribute::Zext: + return "Zext"; + case FunctionParameterAttribute::Sext: + return "Sext"; + case FunctionParameterAttribute::ByVal: + return "ByVal"; + case FunctionParameterAttribute::Sret: + return "Sret"; + case FunctionParameterAttribute::NoAlias: + return "NoAlias"; + case FunctionParameterAttribute::NoCapture: + return "NoCapture"; + case FunctionParameterAttribute::NoWrite: + return "NoWrite"; + case FunctionParameterAttribute::NoReadWrite: + return "NoReadWrite"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getDecorationName(Decoration e) { + switch (e) { + case Decoration::RelaxedPrecision: + return "RelaxedPrecision"; + case Decoration::SpecId: + return "SpecId"; + case Decoration::Block: + return "Block"; + case Decoration::BufferBlock: + return "BufferBlock"; + case Decoration::RowMajor: + return 
"RowMajor"; + case Decoration::ColMajor: + return "ColMajor"; + case Decoration::ArrayStride: + return "ArrayStride"; + case Decoration::MatrixStride: + return "MatrixStride"; + case Decoration::GLSLShared: + return "GLSLShared"; + case Decoration::GLSLPacked: + return "GLSLPacked"; + case Decoration::CPacked: + return "CPacked"; + case Decoration::BuiltIn: + return "BuiltIn"; + case Decoration::NoPerspective: + return "NoPerspective"; + case Decoration::Flat: + return "Flat"; + case Decoration::Patch: + return "Patch"; + case Decoration::Centroid: + return "Centroid"; + case Decoration::Sample: + return "Sample"; + case Decoration::Invariant: + return "Invariant"; + case Decoration::Restrict: + return "Restrict"; + case Decoration::Aliased: + return "Aliased"; + case Decoration::Volatile: + return "Volatile"; + case Decoration::Constant: + return "Constant"; + case Decoration::Coherent: + return "Coherent"; + case Decoration::NonWritable: + return "NonWritable"; + case Decoration::NonReadable: + return "NonReadable"; + case Decoration::Uniform: + return "Uniform"; + case Decoration::UniformId: + return "UniformId"; + case Decoration::SaturatedConversion: + return "SaturatedConversion"; + case Decoration::Stream: + return "Stream"; + case Decoration::Location: + return "Location"; + case Decoration::Component: + return "Component"; + case Decoration::Index: + return "Index"; + case Decoration::Binding: + return "Binding"; + case Decoration::DescriptorSet: + return "DescriptorSet"; + case Decoration::Offset: + return "Offset"; + case Decoration::XfbBuffer: + return "XfbBuffer"; + case Decoration::XfbStride: + return "XfbStride"; + case Decoration::FuncParamAttr: + return "FuncParamAttr"; + case Decoration::FPRoundingMode: + return "FPRoundingMode"; + case Decoration::FPFastMathMode: + return "FPFastMathMode"; + case Decoration::LinkageAttributes: + return "LinkageAttributes"; + case Decoration::NoContraction: + return "NoContraction"; + case 
Decoration::InputAttachmentIndex: + return "InputAttachmentIndex"; + case Decoration::Alignment: + return "Alignment"; + case Decoration::MaxByteOffset: + return "MaxByteOffset"; + case Decoration::AlignmentId: + return "AlignmentId"; + case Decoration::MaxByteOffsetId: + return "MaxByteOffsetId"; + case Decoration::NoSignedWrap: + return "NoSignedWrap"; + case Decoration::NoUnsignedWrap: + return "NoUnsignedWrap"; + case Decoration::ExplicitInterpAMD: + return "ExplicitInterpAMD"; + case Decoration::OverrideCoverageNV: + return "OverrideCoverageNV"; + case Decoration::PassthroughNV: + return "PassthroughNV"; + case Decoration::ViewportRelativeNV: + return "ViewportRelativeNV"; + case Decoration::SecondaryViewportRelativeNV: + return "SecondaryViewportRelativeNV"; + case Decoration::PerPrimitiveNV: + return "PerPrimitiveNV"; + case Decoration::PerViewNV: + return "PerViewNV"; + case Decoration::PerVertexNV: + return "PerVertexNV"; + case Decoration::NonUniformEXT: + return "NonUniformEXT"; + case Decoration::CountBuffer: + return "CountBuffer"; + case Decoration::UserSemantic: + return "UserSemantic"; + case Decoration::RestrictPointerEXT: + return "RestrictPointerEXT"; + case Decoration::AliasedPointerEXT: + return "AliasedPointerEXT"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getBuiltInName(BuiltIn e) { + switch (e) { + case BuiltIn::Position: + return "Position"; + case BuiltIn::PointSize: + return "PointSize"; + case BuiltIn::ClipDistance: + return "ClipDistance"; + case BuiltIn::CullDistance: + return "CullDistance"; + case BuiltIn::VertexId: + return "VertexId"; + case BuiltIn::InstanceId: + return "InstanceId"; + case BuiltIn::PrimitiveId: + return "PrimitiveId"; + case BuiltIn::InvocationId: + return "InvocationId"; + case BuiltIn::Layer: + return "Layer"; + case BuiltIn::ViewportIndex: + return "ViewportIndex"; + case BuiltIn::TessLevelOuter: + return "TessLevelOuter"; + case 
BuiltIn::TessLevelInner: + return "TessLevelInner"; + case BuiltIn::TessCoord: + return "TessCoord"; + case BuiltIn::PatchVertices: + return "PatchVertices"; + case BuiltIn::FragCoord: + return "FragCoord"; + case BuiltIn::PointCoord: + return "PointCoord"; + case BuiltIn::FrontFacing: + return "FrontFacing"; + case BuiltIn::SampleId: + return "SampleId"; + case BuiltIn::SamplePosition: + return "SamplePosition"; + case BuiltIn::SampleMask: + return "SampleMask"; + case BuiltIn::FragDepth: + return "FragDepth"; + case BuiltIn::HelperInvocation: + return "HelperInvocation"; + case BuiltIn::NumWorkgroups: + return "NumWorkgroups"; + case BuiltIn::WorkgroupSize: + return "WorkgroupSize"; + case BuiltIn::WorkgroupId: + return "WorkgroupId"; + case BuiltIn::LocalInvocationId: + return "LocalInvocationId"; + case BuiltIn::GlobalInvocationId: + return "GlobalInvocationId"; + case BuiltIn::LocalInvocationIndex: + return "LocalInvocationIndex"; + case BuiltIn::WorkDim: + return "WorkDim"; + case BuiltIn::GlobalSize: + return "GlobalSize"; + case BuiltIn::EnqueuedWorkgroupSize: + return "EnqueuedWorkgroupSize"; + case BuiltIn::GlobalOffset: + return "GlobalOffset"; + case BuiltIn::GlobalLinearId: + return "GlobalLinearId"; + case BuiltIn::SubgroupSize: + return "SubgroupSize"; + case BuiltIn::SubgroupMaxSize: + return "SubgroupMaxSize"; + case BuiltIn::NumSubgroups: + return "NumSubgroups"; + case BuiltIn::NumEnqueuedSubgroups: + return "NumEnqueuedSubgroups"; + case BuiltIn::SubgroupId: + return "SubgroupId"; + case BuiltIn::SubgroupLocalInvocationId: + return "SubgroupLocalInvocationId"; + case BuiltIn::VertexIndex: + return "VertexIndex"; + case BuiltIn::InstanceIndex: + return "InstanceIndex"; + case BuiltIn::SubgroupEqMask: + return "SubgroupEqMask"; + case BuiltIn::SubgroupGeMask: + return "SubgroupGeMask"; + case BuiltIn::SubgroupGtMask: + return "SubgroupGtMask"; + case BuiltIn::SubgroupLeMask: + return "SubgroupLeMask"; + case BuiltIn::SubgroupLtMask: + return 
"SubgroupLtMask"; + case BuiltIn::BaseVertex: + return "BaseVertex"; + case BuiltIn::BaseInstance: + return "BaseInstance"; + case BuiltIn::DrawIndex: + return "DrawIndex"; + case BuiltIn::DeviceIndex: + return "DeviceIndex"; + case BuiltIn::ViewIndex: + return "ViewIndex"; + case BuiltIn::BaryCoordNoPerspAMD: + return "BaryCoordNoPerspAMD"; + case BuiltIn::BaryCoordNoPerspCentroidAMD: + return "BaryCoordNoPerspCentroidAMD"; + case BuiltIn::BaryCoordNoPerspSampleAMD: + return "BaryCoordNoPerspSampleAMD"; + case BuiltIn::BaryCoordSmoothAMD: + return "BaryCoordSmoothAMD"; + case BuiltIn::BaryCoordSmoothCentroid: + return "BaryCoordSmoothCentroid"; + case BuiltIn::BaryCoordSmoothSample: + return "BaryCoordSmoothSample"; + case BuiltIn::BaryCoordPullModel: + return "BaryCoordPullModel"; + case BuiltIn::FragStencilRefEXT: + return "FragStencilRefEXT"; + case BuiltIn::ViewportMaskNV: + return "ViewportMaskNV"; + case BuiltIn::SecondaryPositionNV: + return "SecondaryPositionNV"; + case BuiltIn::SecondaryViewportMaskNV: + return "SecondaryViewportMaskNV"; + case BuiltIn::PositionPerViewNV: + return "PositionPerViewNV"; + case BuiltIn::ViewportMaskPerViewNV: + return "ViewportMaskPerViewNV"; + case BuiltIn::FullyCoveredEXT: + return "FullyCoveredEXT"; + case BuiltIn::TaskCountNV: + return "TaskCountNV"; + case BuiltIn::PrimitiveCountNV: + return "PrimitiveCountNV"; + case BuiltIn::PrimitiveIndicesNV: + return "PrimitiveIndicesNV"; + case BuiltIn::ClipDistancePerViewNV: + return "ClipDistancePerViewNV"; + case BuiltIn::CullDistancePerViewNV: + return "CullDistancePerViewNV"; + case BuiltIn::LayerPerViewNV: + return "LayerPerViewNV"; + case BuiltIn::MeshViewCountNV: + return "MeshViewCountNV"; + case BuiltIn::MeshViewIndices: + return "MeshViewIndices"; + case BuiltIn::BaryCoordNV: + return "BaryCoordNV"; + case BuiltIn::BaryCoordNoPerspNV: + return "BaryCoordNoPerspNV"; + case BuiltIn::FragSizeEXT: + return "FragSizeEXT"; + case BuiltIn::FragInvocationCountEXT: + return 
"FragInvocationCountEXT"; + case BuiltIn::LaunchIdNV: + return "LaunchIdNV"; + case BuiltIn::LaunchSizeNV: + return "LaunchSizeNV"; + case BuiltIn::WorldRayOriginNV: + return "WorldRayOriginNV"; + case BuiltIn::WorldRayDirectionNV: + return "WorldRayDirectionNV"; + case BuiltIn::ObjectRayOriginNV: + return "ObjectRayOriginNV"; + case BuiltIn::ObjectRayDirectionNV: + return "ObjectRayDirectionNV"; + case BuiltIn::RayTminNV: + return "RayTminNV"; + case BuiltIn::RayTmaxNV: + return "RayTmaxNV"; + case BuiltIn::InstanceCustomIndexNV: + return "InstanceCustomIndexNV"; + case BuiltIn::ObjectToWorldNV: + return "ObjectToWorldNV"; + case BuiltIn::WorldToObjectNV: + return "WorldToObjectNV"; + case BuiltIn::HitTNV: + return "HitTNV"; + case BuiltIn::HitKindNV: + return "HitKindNV"; + case BuiltIn::IncomingRayFlagsNV: + return "IncomingRayFlagsNV"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +std::string SPIRV::getSelectionControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(SelectionControl::None)) + return "None"; + if (e == static_cast(SelectionControl::Flatten)) + return "Flatten"; + if (e & static_cast(SelectionControl::Flatten)) { + nameString += sep + "Flatten"; + sep = "|"; + } + if (e == static_cast(SelectionControl::DontFlatten)) + return "DontFlatten"; + if (e & static_cast(SelectionControl::DontFlatten)) { + nameString += sep + "DontFlatten"; + sep = "|"; + }; + return nameString; +} + +std::string SPIRV::getLoopControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(LoopControl::None)) + return "None"; + if (e == static_cast(LoopControl::Unroll)) + return "Unroll"; + if (e & static_cast(LoopControl::Unroll)) { + nameString += sep + "Unroll"; + sep = "|"; + } + if (e == static_cast(LoopControl::DontUnroll)) + return "DontUnroll"; + if (e & static_cast(LoopControl::DontUnroll)) { + nameString += sep + "DontUnroll"; + sep = 
"|"; + } + if (e == static_cast(LoopControl::DependencyInfinite)) + return "DependencyInfinite"; + if (e & static_cast(LoopControl::DependencyInfinite)) { + nameString += sep + "DependencyInfinite"; + sep = "|"; + } + if (e == static_cast(LoopControl::DependencyLength)) + return "DependencyLength"; + if (e & static_cast(LoopControl::DependencyLength)) { + nameString += sep + "DependencyLength"; + sep = "|"; + } + if (e == static_cast(LoopControl::MinIterations)) + return "MinIterations"; + if (e & static_cast(LoopControl::MinIterations)) { + nameString += sep + "MinIterations"; + sep = "|"; + } + if (e == static_cast(LoopControl::MaxIterations)) + return "MaxIterations"; + if (e & static_cast(LoopControl::MaxIterations)) { + nameString += sep + "MaxIterations"; + sep = "|"; + } + if (e == static_cast(LoopControl::IterationMultiple)) + return "IterationMultiple"; + if (e & static_cast(LoopControl::IterationMultiple)) { + nameString += sep + "IterationMultiple"; + sep = "|"; + } + if (e == static_cast(LoopControl::PeelCount)) + return "PeelCount"; + if (e & static_cast(LoopControl::PeelCount)) { + nameString += sep + "PeelCount"; + sep = "|"; + } + if (e == static_cast(LoopControl::PartialCount)) + return "PartialCount"; + if (e & static_cast(LoopControl::PartialCount)) { + nameString += sep + "PartialCount"; + sep = "|"; + }; + return nameString; +} + +std::string SPIRV::getFunctionControlName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(FunctionControl::None)) + return "None"; + if (e == static_cast(FunctionControl::Inline)) + return "Inline"; + if (e & static_cast(FunctionControl::Inline)) { + nameString += sep + "Inline"; + sep = "|"; + } + if (e == static_cast(FunctionControl::DontInline)) + return "DontInline"; + if (e & static_cast(FunctionControl::DontInline)) { + nameString += sep + "DontInline"; + sep = "|"; + } + if (e == static_cast(FunctionControl::Pure)) + return "Pure"; + if (e & 
static_cast(FunctionControl::Pure)) { + nameString += sep + "Pure"; + sep = "|"; + } + if (e == static_cast(FunctionControl::Const)) + return "Const"; + if (e & static_cast(FunctionControl::Const)) { + nameString += sep + "Const"; + sep = "|"; + }; + return nameString; +} + +std::string SPIRV::getMemorySemanticsName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(MemorySemantics::None)) + return "None"; + if (e == static_cast(MemorySemantics::Acquire)) + return "Acquire"; + if (e & static_cast(MemorySemantics::Acquire)) { + nameString += sep + "Acquire"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::Release)) + return "Release"; + if (e & static_cast(MemorySemantics::Release)) { + nameString += sep + "Release"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::AcquireRelease)) + return "AcquireRelease"; + if (e & static_cast(MemorySemantics::AcquireRelease)) { + nameString += sep + "AcquireRelease"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::SequentiallyConsistent)) + return "SequentiallyConsistent"; + if (e & static_cast(MemorySemantics::SequentiallyConsistent)) { + nameString += sep + "SequentiallyConsistent"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::UniformMemory)) + return "UniformMemory"; + if (e & static_cast(MemorySemantics::UniformMemory)) { + nameString += sep + "UniformMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::SubgroupMemory)) + return "SubgroupMemory"; + if (e & static_cast(MemorySemantics::SubgroupMemory)) { + nameString += sep + "SubgroupMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::WorkgroupMemory)) + return "WorkgroupMemory"; + if (e & static_cast(MemorySemantics::WorkgroupMemory)) { + nameString += sep + "WorkgroupMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::CrossWorkgroupMemory)) + return "CrossWorkgroupMemory"; + if (e & static_cast(MemorySemantics::CrossWorkgroupMemory)) { + nameString += sep 
+ "CrossWorkgroupMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::AtomicCounterMemory)) + return "AtomicCounterMemory"; + if (e & static_cast(MemorySemantics::AtomicCounterMemory)) { + nameString += sep + "AtomicCounterMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::ImageMemory)) + return "ImageMemory"; + if (e & static_cast(MemorySemantics::ImageMemory)) { + nameString += sep + "ImageMemory"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::OutputMemoryKHR)) + return "OutputMemoryKHR"; + if (e & static_cast(MemorySemantics::OutputMemoryKHR)) { + nameString += sep + "OutputMemoryKHR"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::MakeAvailableKHR)) + return "MakeAvailableKHR"; + if (e & static_cast(MemorySemantics::MakeAvailableKHR)) { + nameString += sep + "MakeAvailableKHR"; + sep = "|"; + } + if (e == static_cast(MemorySemantics::MakeVisibleKHR)) + return "MakeVisibleKHR"; + if (e & static_cast(MemorySemantics::MakeVisibleKHR)) { + nameString += sep + "MakeVisibleKHR"; + sep = "|"; + }; + return nameString; +} + +std::string SPIRV::getMemoryOperandName(uint32_t e) { + std::string nameString = ""; + std::string sep = ""; + if (e == static_cast(MemoryOperand::None)) + return "None"; + if (e == static_cast(MemoryOperand::Volatile)) + return "Volatile"; + if (e & static_cast(MemoryOperand::Volatile)) { + nameString += sep + "Volatile"; + sep = "|"; + } + if (e == static_cast(MemoryOperand::Aligned)) + return "Aligned"; + if (e & static_cast(MemoryOperand::Aligned)) { + nameString += sep + "Aligned"; + sep = "|"; + } + if (e == static_cast(MemoryOperand::Nontemporal)) + return "Nontemporal"; + if (e & static_cast(MemoryOperand::Nontemporal)) { + nameString += sep + "Nontemporal"; + sep = "|"; + } + if (e == static_cast(MemoryOperand::MakePointerAvailableKHR)) + return "MakePointerAvailableKHR"; + if (e & static_cast(MemoryOperand::MakePointerAvailableKHR)) { + nameString += sep + "MakePointerAvailableKHR"; + sep = 
"|"; + } + if (e == static_cast(MemoryOperand::MakePointerVisibleKHR)) + return "MakePointerVisibleKHR"; + if (e & static_cast(MemoryOperand::MakePointerVisibleKHR)) { + nameString += sep + "MakePointerVisibleKHR"; + sep = "|"; + } + if (e == static_cast(MemoryOperand::NonPrivatePointerKHR)) + return "NonPrivatePointerKHR"; + if (e & static_cast(MemoryOperand::NonPrivatePointerKHR)) { + nameString += sep + "NonPrivatePointerKHR"; + sep = "|"; + }; + return nameString; +} + +StringRef SPIRV::getScopeName(Scope e) { + switch (e) { + case Scope::CrossDevice: + return "CrossDevice"; + case Scope::Device: + return "Device"; + case Scope::Workgroup: + return "Workgroup"; + case Scope::Subgroup: + return "Subgroup"; + case Scope::Invocation: + return "Invocation"; + case Scope::QueueFamilyKHR: + return "QueueFamilyKHR"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getGroupOperationName(GroupOperation e) { + switch (e) { + case GroupOperation::Reduce: + return "Reduce"; + case GroupOperation::InclusiveScan: + return "InclusiveScan"; + case GroupOperation::ExclusiveScan: + return "ExclusiveScan"; + case GroupOperation::ClusteredReduce: + return "ClusteredReduce"; + case GroupOperation::PartitionedReduceNV: + return "PartitionedReduceNV"; + case GroupOperation::PartitionedInclusiveScanNV: + return "PartitionedInclusiveScanNV"; + case GroupOperation::PartitionedExclusiveScanNV: + return "PartitionedExclusiveScanNV"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getKernelEnqueueFlagsName(KernelEnqueueFlags e) { + switch (e) { + case KernelEnqueueFlags::NoWait: + return "NoWait"; + case KernelEnqueueFlags::WaitKernel: + return "WaitKernel"; + case KernelEnqueueFlags::WaitWorkGroup: + return "WaitWorkGroup"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} + +StringRef SPIRV::getKernelProfilingInfoName(KernelProfilingInfo e) { 
+ switch (e) { + case KernelProfilingInfo::None: + return "None"; + case KernelProfilingInfo::CmdExecTime: + return "CmdExecTime"; + default: + return "UNKNOWN_ENUM"; + } + llvm_unreachable("Unexpected operand"); +} Index: llvm/lib/Target/SPIRV/SPIRV.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRV.h +++ llvm/lib/Target/SPIRV/SPIRV.h @@ -17,6 +17,12 @@ class SPIRVTargetMachine; class SPIRVRegisterBankInfo; class SPIRVSubtarget; +class InstructionSelector; + +InstructionSelector * +createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &Subtarget, + const SPIRVRegisterBankInfo &RBI); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRVCallLowering.h +++ llvm/lib/Target/SPIRV/SPIRVCallLowering.h @@ -17,12 +17,19 @@ namespace llvm { +class SPIRVGlobalRegistry; +class SPIRVSubtarget; class SPIRVTargetLowering; class SPIRVCallLowering : public CallLowering { private: + const SPIRVSubtarget &ST; + // Used to create and assign function, argument, and return type information. + SPIRVGlobalRegistry *GR; + public: - SPIRVCallLowering(const SPIRVTargetLowering &TLI); + SPIRVCallLowering(const SPIRVTargetLowering &TLI, const SPIRVSubtarget &ST, + SPIRVGlobalRegistry *GR); // Built OpReturn or OpReturnValue. 
bool lowerReturn(MachineIRBuilder &MIRBuiler, const Value *Val, Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -12,17 +12,21 @@ //===----------------------------------------------------------------------===// #include "SPIRVCallLowering.h" +#include "MCTargetDesc/SPIRVBaseInfo.h" #include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" #include "SPIRVISelLowering.h" #include "SPIRVRegisterInfo.h" #include "SPIRVSubtarget.h" +#include "SPIRVUtils.h" #include "llvm/CodeGen/FunctionLoweringInfo.h" -#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" using namespace llvm; -SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI) - : CallLowering(&TLI) {} +SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, + const SPIRVSubtarget &ST, + SPIRVGlobalRegistry *GR) + : CallLowering(&TLI), ST(ST), GR(GR) {} bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, ArrayRef VRegs, @@ -32,19 +36,39 @@ // TODO: handle the case of multiple registers. if (VRegs.size() > 1) return false; - if (Val) { - MIRBuilder.buildInstr(SPIRV::OpReturnValue).addUse(VRegs[0]); - return true; - } + if (Val) + return MIRBuilder.buildInstr(SPIRV::OpReturnValue) + .addUse(VRegs[0]) + .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); MIRBuilder.buildInstr(SPIRV::OpReturn); return true; } +// Based on the LLVM function attributes, get a SPIR-V FunctionControl. 
+static uint32_t getFunctionControl(const Function &F) { + uint32_t FuncControl = static_cast(SPIRV::FunctionControl::None); + if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) { + FuncControl |= static_cast(SPIRV::FunctionControl::Inline); + } + if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) { + FuncControl |= static_cast(SPIRV::FunctionControl::Pure); + } + if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) { + FuncControl |= static_cast(SPIRV::FunctionControl::Const); + } + if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) { + FuncControl |= static_cast(SPIRV::FunctionControl::DontInline); + } + return FuncControl; +} + bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef> VRegs, FunctionLoweringInfo &FLI) const { - auto MRI = MIRBuilder.getMRI(); + assert(GR && "Must initialize the SPIRV type registry before lowering args."); + // Assign types and names to all args, and store their types for later. SmallVector ArgTypeVRegs; if (VRegs.size() > 0) { @@ -54,21 +78,57 @@ // TODO: handle the case of multiple registers. 
if (VRegs[i].size() > 1) return false; - ArgTypeVRegs.push_back( - MRI->createGenericVirtualRegister(LLT::scalar(32))); + auto *SpirvTy = + GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder); + ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy)); + + if (Arg.hasName()) + buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); + if (Arg.getType()->isPointerTy()) { + auto DerefBytes = static_cast(Arg.getDereferenceableBytes()); + if (DerefBytes != 0) + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::MaxByteOffset, {DerefBytes}); + } + if (Arg.hasAttribute(Attribute::Alignment)) { + auto Alignment = static_cast( + Arg.getAttribute(Attribute::Alignment).getValueAsInt()); + buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, + {Alignment}); + } + if (Arg.hasAttribute(Attribute::ReadOnly)) { + auto Attr = + static_cast(SPIRV::FunctionParameterAttribute::NoWrite); + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::FuncParamAttr, {Attr}); + } + if (Arg.hasAttribute(Attribute::ZExt)) { + auto Attr = + static_cast(SPIRV::FunctionParameterAttribute::Zext); + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::FuncParamAttr, {Attr}); + } ++i; } } // Generate a SPIR-V type for the function. + auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); + auto *FTy = F.getFunctionType(); + auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder); + + // Build the OpTypeFunction declaring it. + Register ReturnTypeID = FuncTy->getOperand(1).getReg(); + uint32_t FuncControl = getFunctionControl(F); + MIRBuilder.buildInstr(SPIRV::OpFunction) .addDef(FuncVReg) - .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32))) - .addImm(0) - .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32))); + .addUse(ReturnTypeID) + .addImm(FuncControl) + .addUse(GR->getSPIRVTypeID(FuncTy)); // Add OpFunctionParameters. 
const unsigned NumArgs = ArgTypeVRegs.size(); @@ -79,6 +139,24 @@ .addDef(VRegs[i][0]) .addUse(ArgTypeVRegs[i]); } + // Name the function. + if (F.hasName()) + buildOpName(FuncVReg, F.getName(), MIRBuilder); + + // Handle entry points and function linkage. + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) + .addImm(static_cast(SPIRV::ExecutionModel::Kernel)) + .addUse(FuncVReg); + addStringImm(F.getName(), MIB); + } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || + F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { + auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import + : SPIRV::LinkageType::Export; + buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast(LnkTy)}, F.getGlobalIdentifier()); + } + return true; } @@ -91,15 +169,49 @@ Register ResVReg = Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; + // Emit a regular OpFunctionCall. If it's an externally declared function, + // be sure to emit its type and function declaration here. It will be + // hoisted globally later. + if (Info.Callee.isGlobal()) { + auto *CF = dyn_cast_or_null(Info.Callee.getGlobal()); + // TODO: support constexpr casts and indirect calls. + if (CF == nullptr) + return false; + if (CF->isDeclaration()) { + // Emit the type info and forward function declaration to the first MBB + // to ensure VReg definition dependencies are valid across all MBBs. + MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt(); + MachineBasicBlock &OldBB = MIRBuilder.getMBB(); + MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0); + MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end()); + + SmallVector, 8> VRegArgs; + SmallVector, 8> ToInsert; + for (const Argument &Arg : CF->args()) { + if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) + continue; // Don't handle zero sized types. 
+ ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister( + LLT::scalar(32))}); + } + VRegArgs.assign(ToInsert.begin(), ToInsert.end()); // Views taken once ToInsert stops growing. + // TODO: Reuse FunctionLoweringInfo. + FunctionLoweringInfo FuncInfo; + lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo); + MIRBuilder.setInsertPt(OldBB, OldII); + } + } + // Make sure there's a valid return reg, even for functions returning void. if (!ResVReg.isValid()) { ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); } + SPIRVType *RetType = + GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder); + // Emit the OpFunctionCall and its args. auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) .addDef(ResVReg) - .addUse(MIRBuilder.getMRI()->createVirtualRegister( - &SPIRV::IDRegClass)) + .addUse(GR->getSPIRVTypeID(RetType)) .add(Info.Callee); for (const auto &Arg : Info.OrigArgs) { @@ -108,5 +220,6 @@ return false; MIB.addUse(Arg.Regs[0]); } - return true; + return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); } Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -0,0 +1,174 @@ +//===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SPIRVGlobalRegistry is used to maintain rich type information required for +// SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type +// into an OpTypeXXX instruction, and map it to a virtual register. Also it +// builds and supports consistency of constants and global variables.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H + +#include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRVInstrInfo.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" + +namespace llvm { +using SPIRVType = const MachineInstr; + +class SPIRVGlobalRegistry { + // Registers holding values which have types associated with them. + // Initialized upon VReg definition in IRTranslator. + // Do not confuse this with DuplicatesTracker as DT maps Type* to + // where Reg = OpType... + // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not + // type-declaring ones) + DenseMap> VRegToTypeMap; + + DenseMap SPIRVToLLVMType; + + // Number of bits pointers and size_t integers require. + const unsigned PointerSize; + + // Add a new OpTypeXXX instruction without checking for duplicates. + SPIRVType * + createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + +public: + SPIRVGlobalRegistry(unsigned PointerSize); + + MachineFunction *CurMF; + + // Get or create a SPIR-V type corresponding the given LLVM IR type, + // and map it to the given VReg by creating an ASSIGN_TYPE instruction. + SPIRVType *assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + // In cases where the SPIR-V type is already known, this function can be + // used to map it to the given VReg via an ASSIGN_TYPE instruction. + void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, + MachineIRBuilder &MIRBuilder); + + // Either generate a new OpTypeXXX instruction or return an existing one + // corresponding to the given LLVM IR type. + // EmitIR controls if we emit GMIR or SPV constants (e.g. 
for array sizes) + // because this method may be called from InstructionSelector and we don't + // want to emit extra IR instructions there. + SPIRVType *getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + const Type *getTypeForSPIRVType(const SPIRVType *Ty) const { + auto Res = SPIRVToLLVMType.find(Ty); + assert(Res != SPIRVToLLVMType.end()); + return Res->second; + } + + // Return the SPIR-V type instruction corresponding to the given VReg, or + // nullptr if no such type instruction exists. + SPIRVType *getSPIRVTypeForVReg(Register VReg) const; + + // Whether the given VReg has a SPIR-V type mapped to it yet. + bool hasSPIRVTypeForVReg(Register VReg) const { + return getSPIRVTypeForVReg(VReg) != nullptr; + } + + // Return the VReg holding the result of the given OpTypeXXX instruction. + Register getSPIRVTypeID(const SPIRVType *SpirvType) const { + assert(SpirvType && "Attempting to get type id for nullptr type."); + return SpirvType->defs().begin()->getReg(); + } + + void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; } + + // Whether the given VReg has an OpTypeXXX instruction mapped to it with the + // given opcode (e.g. OpTypeFloat). + bool isScalarOfType(Register VReg, unsigned TypeOpcode) const; + + // Return true if the given VReg's assigned SPIR-V type is either a scalar + // matching the given opcode, or a vector with an element type matching that + // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool). + bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const; + + // For vectors or scalars of ints/floats, return the scalar type's bitwidth. + unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const; + + // For integer vectors or scalars, return whether the integers are signed. 
+ bool isScalarOrVectorSigned(const SPIRVType *Type) const; + + // Gets the storage class of the pointer type assigned to this vreg. + SPIRV::StorageClass getPointerStorageClass(Register VReg) const; + + // Return the number of bits SPIR-V pointers and size_t variables require. + unsigned getPointerSize() const { return PointerSize; } + +private: + SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder, + bool IsSigned = false); + + SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, bool EmitIR = true); + + SPIRVType *getOpTypePointer(SPIRV::StorageClass SC, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeFunction(SPIRVType *RetType, + const SmallVectorImpl &ArgTypes, + MachineIRBuilder &MIRBuilder); + SPIRVType *restOfCreateSPIRVType(Type *LLVMTy, MachineInstrBuilder MIB); + +public: + Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr, bool EmitIR = true); + Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr); + Register + buildGlobalVariable(Register Reg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); + + // Convenient helpers for getting types with check for duplicates. 
+ SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I, + const SPIRVInstrInfo &TII); + SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII); + + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); +}; +} // end namespace llvm +#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -0,0 +1,455 @@ +//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the SPIRVGlobalRegistry class, +// which is used to maintain rich type information required for SPIR-V even +// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into +// an OpTypeXXX instruction, and map it to a virtual register. Also it builds +// and supports consistency of constants and global variables.
+// +//===----------------------------------------------------------------------===// + +#include "SPIRVGlobalRegistry.h" +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" + +using namespace llvm; +SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) + : PointerSize(PointerSize) {} + +SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + + SPIRVType *SpirvType = + getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder); + return SpirvType; +} + +void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, + Register VReg, + MachineIRBuilder &MIRBuilder) { + VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType; +} + +static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { + auto &MRI = MIRBuilder.getMF().getRegInfo(); + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +static Register createTypeVReg(MachineRegisterInfo &MRI) { + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeBool) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, + MachineIRBuilder &MIRBuilder, + bool IsSigned) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width) + .addImm(IsSigned ? 
1 : 0); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto EleOpc = ElemType->getOpcode(); + assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || + EleOpc == SPIRV::OpTypeBool) && + "Invalid vector element type"); + + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addImm(NumElems); + return MIB; +} + +Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType, + bool EmitIR) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const IntegerType *LLVMIntTy; + if (SpvType) + LLVMIntTy = cast(getTypeForSPIRVType(SpvType)); + else + LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); + // Find a constant in DT or build a new one. + const auto ConstInt = + ConstantInt::get(const_cast(LLVMIntTy), Val); + unsigned BitWidth = SpvType ? 
getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); + if (EmitIR) + MIRBuilder.buildConstant(Res, *ConstInt); + else + MIRBuilder.buildInstr(SPIRV::OpConstantI) + .addDef(Res) + .addImm(ConstInt->getSExtValue()); + return Res; +} + +Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const Type *LLVMFPTy; + if (SpvType) { + LLVMFPTy = getTypeForSPIRVType(SpvType); + assert(LLVMFPTy->isFloatingPointTy()); + } else { + LLVMFPTy = Type::getFloatTy(MF.getFunction().getContext()); // default f32 + } + // Find a constant in DT or build a new one. + const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); + MIRBuilder.buildFConstant(Res, *ConstFP); + return Res; +} + +Register SPIRVGlobalRegistry::buildGlobalVariable( + Register ResVReg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, + bool IsInstSelector) { + const GlobalVariable *GVar = nullptr; + if (GV) + GVar = cast(GV); + else { + // If GV is not passed explicitly, use the name to find or construct + // the global variable. + Module *M = MIRBuilder.getMF().getFunction().getParent(); + GVar = M->getGlobalVariable(Name); + if (GVar == nullptr) { + const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+ GVar = new GlobalVariable(*M, const_cast(Ty), false, + GlobalValue::ExternalLinkage, nullptr, + Twine(Name)); + } + GV = GVar; + } + Register Reg; + auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) + .addDef(ResVReg) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(static_cast(Storage)); + + if (Init != 0) { + MIB.addUse(Init->getOperand(0).getReg()); + } + + // ISel may introduce a new register on this step, so we need to add it to + // DT and correct its type avoiding fails on the next stage. + if (IsInstSelector) { + const auto &Subtarget = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), + *Subtarget.getRegisterInfo(), + *Subtarget.getRegBankInfo()); + } + Reg = MIB->getOperand(0).getReg(); + + // Set to Reg the same type as ResVReg has. + auto MRI = MIRBuilder.getMRI(); + assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); + if (Reg != ResVReg) { + LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); + MRI->setType(Reg, RegLLTy); + assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder); + } + + // If it's a global variable with name, output OpName for it. + if (GVar && GVar->hasName()) + buildOpName(Reg, GVar->getName(), MIRBuilder); + + // Output decorations for the GV. + // TODO: maybe move to GenerateDecorations pass. 
+ if (IsConst) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); + + if (GVar && GVar->getAlign().valueOrOne().value() != 1) + buildOpDecorate( + Reg, MIRBuilder, SPIRV::Decoration::Alignment, + {static_cast(GVar->getAlign().valueOrOne().value())}); + + if (HasLinkageTy) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast(LinkageType)}, Name); + return Reg; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, + bool EmitIR) { + assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && + "Invalid array element type"); + Register NumElementsVReg = + buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(NumElementsVReg); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(static_cast(SC)) + .addUse(getSPIRVTypeID(ElemType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( + SPIRVType *RetType, const SmallVectorImpl &ArgTypes, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(RetType)); + for (const SPIRVType *ArgType : ArgTypes) + MIB.addUse(getSPIRVTypeID(ArgType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, + MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccQual, + bool EmitIR) { + if (auto IType = dyn_cast(Ty)) { + const unsigned Width = IType->getBitWidth(); + return Width == 1 ? 
getOpTypeBool(MIRBuilder) + : getOpTypeInt(Width, MIRBuilder, false); + } + if (Ty->isFloatingPointTy()) + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + if (Ty->isVoidTy()) + return getOpTypeVoid(MIRBuilder); + if (Ty->isVectorTy()) { + auto El = getOrCreateSPIRVType(cast(Ty)->getElementType(), + MIRBuilder); + return getOpTypeVector(cast(Ty)->getNumElements(), El, + MIRBuilder); + } + if (Ty->isArrayTy()) { + auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); + return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); + } + assert(!isa(Ty) && "Unsupported StructType"); + if (auto FType = dyn_cast(Ty)) { + SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); + SmallVector ParamTypes; + for (const auto &t : FType->params()) { + ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); + } + return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); + } + if (auto PType = dyn_cast(Ty)) { + Type *ElemType = PType->getPointerElementType(); + + // Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers, + // but should be treated as custom types like OpTypeImage. + assert(!isa(ElemType) && "Unsupported StructType pointer"); + + // Otherwise, treat it as a regular pointer type. 
+ auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); + SPIRVType *SpvElementType = getOrCreateSPIRVType( + ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); + return getOpTypePointer(SC, SpvElementType, MIRBuilder); + } + llvm_unreachable("Unable to convert LLVM type to SPIRVType"); +} + +SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { + auto t = VRegToTypeMap.find(CurMF); + if (t != VRegToTypeMap.end()) { + auto tt = t->second.find(VReg); + if (tt != t->second.end()) + return tt->second; + } + return nullptr; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = Type; + return SpirvType; +} + +bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOfType VReg has no type assigned"); + return Type->getOpcode() == TypeOpcode; +} + +bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); + if (Type->getOpcode() == TypeOpcode) + return true; + if (Type->getOpcode() == SPIRV::OpTypeVector) { + Register ScalarTypeVReg = Type->getOperand(1).getReg(); + SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); + assert(ScalarType && "Vector element VReg has no type assigned"); + return ScalarType->getOpcode() == TypeOpcode; + } + return false; +} + +unsigned +SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); +
} + if (Type->getOpcode() == SPIRV::OpTypeInt || + Type->getOpcode() == SPIRV::OpTypeFloat) + return Type->getOperand(1).getImm(); + if (Type->getOpcode() == SPIRV::OpTypeBool) + return 1; + llvm_unreachable("Attempting to get bit width of non-integer/float type."); +} + +bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type->getOpcode() == SPIRV::OpTypeInt) + return Type->getOperand(2).getImm() != 0; + llvm_unreachable("Attempting to get sign of non-integer type."); +} + +SPIRV::StorageClass +SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + if (Type && Type->getOpcode() == SPIRV::OpTypePointer) { + auto scOp = Type->getOperand(1).getImm(); + return static_cast(scOp); + } + llvm_unreachable("Attempting to get storage class of non-pointer type."); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy, + MachineInstrBuilder MIB) { + SPIRVType *SpirvType = MIB; + VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = LLVMTy; + return SpirvType; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( + unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { + Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(BitWidth) + .addImm(0); + return restOfCreateSPIRVType(LLVMTy, 
MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + FixedVectorType::get(const_cast(getTypeForSPIRVType(BaseType)), + NumElements), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII) { + Type *LLVMTy = FixedVectorType::get( + const_cast(getTypeForSPIRVType(BaseType)), NumElements); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(NumElements); + return restOfCreateSPIRVType(LLVMTy, MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, + MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass) { + return getOrCreateSPIRVType( + PointerType::get(const_cast(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SClass)), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SC) { + Type *LLVMTy = + PointerType::get(const_cast(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SC)); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(static_cast(SC)) + .addUse(getSPIRVTypeID(BaseType)); + return restOfCreateSPIRVType(LLVMTy, MIB); +} Index: llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp =================================================================== --- 
/dev/null +++ llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -0,0 +1,1088 @@ +//===- SPIRVInstructionSelector.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the targeting of the InstructionSelector class for +// SPIRV. +// TODO: This should be generated by TableGen. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVInstrInfo.h" +#include "SPIRVRegisterBankInfo.h" +#include "SPIRVRegisterInfo.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-isel" + +using namespace llvm; + +namespace { + +#define GET_GLOBALISEL_PREDICATE_BITSET +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATE_BITSET + +class SPIRVInstructionSelector : public InstructionSelector { + const SPIRVSubtarget &STI; + const SPIRVInstrInfo &TII; + const SPIRVRegisterInfo &TRI; + const SPIRVRegisterBankInfo &RBI; + SPIRVGlobalRegistry &GR; + MachineRegisterInfo *MRI; + +public: + SPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &ST, + const SPIRVRegisterBankInfo &RBI); + void setupMF(MachineFunction &MF, GISelKnownBits *KB, + CodeGenCoverage &CoverageInfo, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) override; + // Common selection code. Instruction-specific selection occurs in spvSelect. 
+  // Entry point invoked per instruction. Instructions that are not generic
+  // pre-ISel opcodes are treated as already selected; generic ones are
+  // forwarded to spvSelect().
+  bool select(MachineInstr &I) override;
+  static const char *getName() { return DEBUG_TYPE; }
+
+#define GET_GLOBALISEL_PREDICATES_DECL
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_PREDICATES_DECL
+
+#define GET_GLOBALISEL_TEMPORARIES_DECL
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_TEMPORARIES_DECL
+
+private:
+  // tblgen-erated 'select' implementation, used as the initial selector for
+  // the patterns that don't require complex C++.
+  bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const;
+
+  // All instruction-specific selection that didn't happen in "select()".
+  // Basically a large switch/case delegating to the other select* methods.
+  bool spvSelect(Register ResVReg, const SPIRVType *ResType,
+                 MachineInstr &I) const;
+
+  bool selectGlobalValue(Register ResVReg, MachineInstr &I,
+                         const MachineInstr *Init = nullptr) const;
+
+  // Emit a unary SPIR-V instruction "ResVReg = Opcode ResType, SrcReg".
+  bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType,
+                         MachineInstr &I, Register SrcReg,
+                         unsigned Opcode) const;
+  // Same as selectUnOpWithSrc(), taking the source from operand 1 of I.
+  bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+                  unsigned Opcode) const;
+
+  bool selectLoad(Register ResVReg, const SPIRVType *ResType,
+                  MachineInstr &I) const;
+  bool selectStore(MachineInstr &I) const;
+
+  bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
+
+  bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
+                       MachineInstr &I, unsigned NewOpcode) const;
+
+  bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I) const;
+
+  bool selectFence(MachineInstr &I) const;
+
+  bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType,
+                           MachineInstr &I) const;
+
+  bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+
+  bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
+                         MachineInstr &I) const;
+
+  // Emit the comparison instruction CmpOpc on operands 2 and 3 of I.
+  bool selectCmp(Register ResVReg, const SPIRVType *ResType, unsigned CmpOpc,
+                 MachineInstr &I) const;
+
+  bool selectICmp(Register ResVReg, const SPIRVType *ResType,
+                  MachineInstr &I) const;
+  bool selectFCmp(Register ResVReg, const SPIRVType *ResType,
+                  MachineInstr &I) const;
+
+  void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
+                   int OpIdx) const;
+  void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
+                    int OpIdx) const;
+
+  bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm,
+                   MachineInstr &I) const;
+
+  bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+                    bool IsSigned) const;
+  bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+                  bool IsSigned, unsigned Opcode) const;
+  bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
+                 bool IsSigned) const;
+
+  bool selectTrunc(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
+  bool selectIntToBool(Register IntReg, Register ResVReg,
+                       const SPIRVType *IntTy, const SPIRVType *BoolTy,
+                       MachineInstr &I) const;
+
+  bool selectOpUndef(Register ResVReg, const SPIRVType *ResType,
+                     MachineInstr &I) const;
+  bool selectIntrinsic(Register ResVReg, const SPIRVType *ResType,
+                       MachineInstr &I) const;
+
+  bool selectFrameIndex(Register ResVReg, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+
+  bool selectBranch(MachineInstr &I) const;
+  bool selectBranchCond(MachineInstr &I) const;
+
+  bool selectPhi(Register ResVReg, const SPIRVType *ResType,
+                 MachineInstr &I) const;
+
+  // Build a constant of value Val before I and return its register. When
+  // ResType is null a 32-bit integer type is used.
+  Register buildI32Constant(uint32_t Val, MachineInstr &I,
+                            const SPIRVType *ResType = nullptr) const;
+
+  Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
+  Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
+                        MachineInstr &I) const;
+};
+
+} // end anonymous namespace
+
+#define GET_GLOBALISEL_IMPL
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_IMPL
+
+SPIRVInstructionSelector::SPIRVInstructionSelector(
+    const SPIRVTargetMachine &TM, const SPIRVSubtarget &ST,
+    const SPIRVRegisterBankInfo &RBI)
+    : InstructionSelector(), STI(ST), TII(*ST.getInstrInfo()),
+      TRI(*ST.getRegisterInfo()), RBI(RBI), GR(*ST.getSPIRVGlobalRegistry()),
+#define GET_GLOBALISEL_PREDICATES_INIT
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_PREDICATES_INIT
+#define GET_GLOBALISEL_TEMPORARIES_INIT
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_TEMPORARIES_INIT
+{
+}
+
+// Cache the function's register info, point the global registry at the
+// current function, then run the generic InstructionSelector setup.
+void SPIRVInstructionSelector::setupMF(MachineFunction &MF, GISelKnownBits *KB,
+                                       CodeGenCoverage &CoverageInfo,
+                                       ProfileSummaryInfo *PSI,
+                                       BlockFrequencyInfo *BFI) {
+  MRI = &MF.getRegInfo();
+  GR.setCurrentFunc(MF);
+  InstructionSelector::setupMF(MF, KB, CoverageInfo, PSI, BFI);
+}
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+// Select a single instruction. Returns true when I was (or already had been)
+// selected, false when no selection was possible.
+bool SPIRVInstructionSelector::select(MachineInstr &I) {
+  assert(I.getParent() && "Instruction should be in a basic block!");
+  assert(I.getParent()->getParent() && "Instruction should be in a function!");
+
+  // An opcode is a plain integer, not a register.
+  const unsigned Opcode = I.getOpcode();
+  // If it's not a GMIR instruction, we've selected it already.
+  if (!isPreISelGenericOpcode(Opcode)) {
+    if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
+      auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
+      if (isTypeFoldingSupported(Def->getOpcode())) {
+        auto Res = selectImpl(I, *CoverageInfo);
+        assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
+        if (Res)
+          return Res;
+      }
+      // Fold the pseudo away: users of operand 1 now use operand 0 directly.
+      MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg());
+      I.removeFromParent();
+    } else if (I.getNumDefs() == 1) {
+      // Make all vregs 32 bits (for SPIR-V IDs).
+      MRI->setType(I.getOperand(0).getReg(), LLT::scalar(32));
+    }
+    return true;
+  }
+
+  if (I.getNumOperands() != I.getNumExplicitOperands()) {
+    LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
+    return false;
+  }
+
+  // Common code for getting return reg+type, and removing selected instr
+  // from parent occurs here. Instr-specific selection happens in spvSelect().
+  bool HasDefs = I.getNumDefs() > 0;
+  Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
+  SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
+  assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
+  if (spvSelect(ResVReg, ResType, I)) {
+    if (HasDefs) { // Make all vregs 32 bits (for SPIR-V IDs).
+      MRI->setType(ResVReg, LLT::scalar(32));
+    }
+    I.removeFromParent();
+    return true;
+  }
+  return false;
+}
+
+// Dispatch a generic instruction to its opcode-specific selection routine.
+// Returns false for opcodes without a handler.
+bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  assert(!isTypeFoldingSupported(I.getOpcode()) ||
+         I.getOpcode() == TargetOpcode::G_CONSTANT);
+  const unsigned Opcode = I.getOpcode();
+  switch (Opcode) {
+  case TargetOpcode::G_CONSTANT:
+    return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(),
+                       I);
+  case TargetOpcode::G_GLOBAL_VALUE:
+    return selectGlobalValue(ResVReg, I);
+  case TargetOpcode::G_IMPLICIT_DEF:
+    return selectOpUndef(ResVReg, ResType, I);
+
+  case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+    return selectIntrinsic(ResVReg, ResType, I);
+  case TargetOpcode::G_BITREVERSE:
+    return selectBitreverse(ResVReg, ResType, I);
+
+  case TargetOpcode::G_BUILD_VECTOR:
+    return selectConstVector(ResVReg, ResType, I);
+
+  case TargetOpcode::G_SHUFFLE_VECTOR: {
+    MachineBasicBlock &BB = *I.getParent();
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
+                   .addDef(ResVReg)
+                   .addUse(GR.getSPIRVTypeID(ResType))
+                   .addUse(I.getOperand(1).getReg())
+                   .addUse(I.getOperand(2).getReg());
+    // The shuffle mask is appended as a list of immediate component indices.
+    for (auto V : I.getOperand(3).getShuffleMask())
+      MIB.addImm(V);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
+  case TargetOpcode::G_MEMMOVE:
+  case TargetOpcode::G_MEMCPY:
+    return selectMemOperation(ResVReg, I);
+
+  case TargetOpcode::G_ICMP:
+    return selectICmp(ResVReg, ResType, I);
+  case TargetOpcode::G_FCMP:
+    return selectFCmp(ResVReg, ResType, I);
+
+  case TargetOpcode::G_FRAME_INDEX:
+    return selectFrameIndex(ResVReg, ResType, I);
+
+  case TargetOpcode::G_LOAD:
+    return selectLoad(ResVReg, ResType, I);
+  case TargetOpcode::G_STORE:
+    return selectStore(I);
+
+  case TargetOpcode::G_BR:
+    return selectBranch(I);
+  case TargetOpcode::G_BRCOND:
+    return selectBranchCond(I);
+
+  case TargetOpcode::G_PHI:
+    return selectPhi(ResVReg, ResType, I);
+
+  case TargetOpcode::G_FPTOSI:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS);
+  case TargetOpcode::G_FPTOUI:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToU);
+
+  case TargetOpcode::G_SITOFP:
+    return selectIToF(ResVReg, ResType, I, true, SPIRV::OpConvertSToF);
+  case TargetOpcode::G_UITOFP:
+    return selectIToF(ResVReg, ResType, I, false, SPIRV::OpConvertUToF);
+
+  case TargetOpcode::G_CTPOP:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitCount);
+
+  case TargetOpcode::G_SEXT:
+    return selectExt(ResVReg, ResType, I, true);
+  case TargetOpcode::G_ANYEXT:
+  case TargetOpcode::G_ZEXT:
+    return selectExt(ResVReg, ResType, I, false);
+  case TargetOpcode::G_TRUNC:
+    return selectTrunc(ResVReg, ResType, I);
+  case TargetOpcode::G_FPTRUNC:
+  case TargetOpcode::G_FPEXT:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpFConvert);
+
+  case TargetOpcode::G_PTRTOINT:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertPtrToU);
+  case TargetOpcode::G_INTTOPTR:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
+  case TargetOpcode::G_BITCAST:
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+  case TargetOpcode::G_ADDRSPACE_CAST:
+    return selectAddrSpaceCast(ResVReg, ResType, I);
+
+  case TargetOpcode::G_ATOMICRMW_OR:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicOr);
+  case TargetOpcode::G_ATOMICRMW_ADD:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicIAdd);
+  case TargetOpcode::G_ATOMICRMW_AND:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicAnd);
+  case TargetOpcode::G_ATOMICRMW_MAX:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMax);
+  case TargetOpcode::G_ATOMICRMW_MIN:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMin);
+  case TargetOpcode::G_ATOMICRMW_SUB:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicISub);
+  case TargetOpcode::G_ATOMICRMW_XOR:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicXor);
+  case TargetOpcode::G_ATOMICRMW_UMAX:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMax);
+  case TargetOpcode::G_ATOMICRMW_UMIN:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMin);
+  case TargetOpcode::G_ATOMICRMW_XCHG:
+    return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicExchange);
+  case TargetOpcode::G_ATOMIC_CMPXCHG:
+    return selectAtomicCmpXchg(ResVReg, ResType, I);
+
+  case TargetOpcode::G_FENCE:
+    return selectFence(I);
+
+  default:
+    return false;
+  }
+}
+
+// Emit "ResVReg = Opcode ResType, SrcReg" in front of I.
+bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 MachineInstr &I,
+                                                 Register SrcReg,
+                                                 unsigned Opcode) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(SrcReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Unary operation whose source is operand 1 of I.
+bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I,
+                                          unsigned Opcode) const {
+  return selectUnOpWithSrc(ResVReg, ResType, I, I.getOperand(1).getReg(),
+                           Opcode);
+}
+
+// Map an LLVM atomic ordering to SPIR-V memory semantics. Orderings without
+// an explicit case map to MemorySemantics::None.
+static SPIRV::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
+  switch (Ord) {
+  case AtomicOrdering::Acquire:
+    return SPIRV::MemorySemantics::Acquire;
+  case AtomicOrdering::Release:
+    return SPIRV::MemorySemantics::Release;
+  case AtomicOrdering::AcquireRelease:
+    return SPIRV::MemorySemantics::AcquireRelease;
+  case AtomicOrdering::SequentiallyConsistent:
+    return SPIRV::MemorySemantics::SequentiallyConsistent;
+  case AtomicOrdering::Unordered:
+  case AtomicOrdering::Monotonic:
+  case AtomicOrdering::NotAtomic:
+  default:
+    return SPIRV::MemorySemantics::None;
+  }
+}
+
+// Map an LLVM synchronization scope to a SPIR-V scope. Only the SingleThread
+// and System scopes are handled; anything else is unreachable.
+static SPIRV::Scope getScope(SyncScope::ID Ord) {
+  switch (Ord) {
+  case SyncScope::SingleThread:
+    return SPIRV::Scope::Invocation;
+  case SyncScope::System:
+    return SPIRV::Scope::Device;
+  default:
+    llvm_unreachable("Unsupported synchronization Scope ID.");
+  }
+}
+
+// Append the SPIR-V memory-operand mask derived from MemOp to MIB, followed by
+// the alignment literal when the Aligned bit is set. Nothing is appended when
+// the mask is None.
+static void addMemoryOperands(MachineMemOperand *MemOp,
+                              MachineInstrBuilder &MIB) {
+  uint32_t SpvMemOp = static_cast<uint32_t>(SPIRV::MemoryOperand::None);
+  if (MemOp->isVolatile())
+    SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Volatile);
+  if (MemOp->isNonTemporal())
+    SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Nontemporal);
+  // NOTE(review): presumably an Align value is never zero, which would make
+  // this test always true - confirm whether that is intended.
+  if (MemOp->getAlign().value())
+    SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Aligned);
+
+  if (SpvMemOp != static_cast<uint32_t>(SPIRV::MemoryOperand::None)) {
+    MIB.addImm(SpvMemOp);
+    if (SpvMemOp & static_cast<uint32_t>(SPIRV::MemoryOperand::Aligned))
+      MIB.addImm(MemOp->getAlign().value());
+  }
+}
+
+// Overload taking raw MachineMemOperand flags; only the Volatile and
+// Nontemporal bits are translated.
+static void addMemoryOperands(uint64_t Flags, MachineInstrBuilder &MIB) {
+  uint32_t SpvMemOp = static_cast<uint32_t>(SPIRV::MemoryOperand::None);
+  if (Flags & MachineMemOperand::Flags::MOVolatile)
+    SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Volatile);
+  if (Flags & MachineMemOperand::Flags::MONonTemporal)
+    SpvMemOp |= static_cast<uint32_t>(SPIRV::MemoryOperand::Nontemporal);
+
+  if (SpvMemOp != static_cast<uint32_t>(SPIRV::MemoryOperand::None))
+    MIB.addImm(SpvMemOp);
+}
+
+// Select OpLoad. For the intrinsic form every operand index is shifted by one
+// and the memory flags come from an immediate operand instead of a
+// MachineMemOperand.
+bool SPIRVInstructionSelector::selectLoad(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I) const {
+  unsigned OpOffset =
+      I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0;
+  Register Ptr = I.getOperand(1 + OpOffset).getReg();
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType))
+                 .addUse(Ptr);
+  if (!I.getNumMemOperands()) {
+    assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS);
+    addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB);
+  } else {
+    addMemoryOperands(*I.memoperands_begin(), MIB);
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+// Select OpStore; the operand layout mirrors selectLoad().
+bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const {
+  unsigned OpOffset =
+      I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0;
+  Register StoreVal = I.getOperand(0 + OpOffset).getReg();
+  Register Ptr = I.getOperand(1 + OpOffset).getReg();
+  MachineBasicBlock &BB = *I.getParent();
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore))
+                 .addUse(Ptr)
+                 .addUse(StoreVal);
+  if (!I.getNumMemOperands()) {
+    assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS);
+    addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB);
+  } else {
+    addMemoryOperands(*I.memoperands_begin(), MIB);
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+// Select OpCopyMemorySized for G_MEMCPY / G_MEMMOVE.
+bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
+                                                  MachineInstr &I) const {
+  MachineBasicBlock &BB = *I.getParent();
+  // OpCopyMemorySized has no result id: target, source and size are all
+  // inputs, so the destination pointer is added as a use, not a def.
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized))
+                 .addUse(I.getOperand(0).getReg())
+                 .addUse(I.getOperand(1).getReg())
+                 .addUse(I.getOperand(2).getReg());
+  if (I.getNumMemOperands())
+    addMemoryOperands(*I.memoperands_begin(), MIB);
+  bool Result = MIB.constrainAllUses(TII, TRI, RBI);
+  // If a distinct result register was requested, forward the destination
+  // pointer into it.
+  if (ResVReg.isValid() && ResVReg != MIB->getOperand(0).getReg())
+    BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg)
+        .addUse(MIB->getOperand(0).getReg());
+  return Result;
+}
+
+// Select a SPIR-V atomic read-modify-write instruction (NewOpcode). Scope and
+// memory semantics are materialized as constants from the memory operand.
+bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I,
+                                               unsigned NewOpcode) const {
+  assert(I.hasOneMemOperand());
+  const MachineMemOperand *MemOp = *I.memoperands_begin();
+  uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
+  Register ScopeReg = buildI32Constant(Scope, I);
+
+  Register Ptr = I.getOperand(1).getReg();
+  // TODO: Changed as it's implemented in the translator. See test/atomicrmw.ll
+  // auto ScSem =
+  // getMemSemanticsForStorageClass(GR.getPointerStorageClass(Ptr));
+  AtomicOrdering AO = MemOp->getSuccessOrdering();
+  uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
+  Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
+
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(Ptr)
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select OpMemoryBarrier for G_FENCE: operand 0 is the ordering, operand 1 the
+// synchronization scope.
+bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
+  AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
+  uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
+  Register MemSemReg = buildI32Constant(MemSem, I);
+  SyncScope::ID Ord = SyncScope::ID(I.getOperand(1).getImm());
+  uint32_t Scope = static_cast<uint32_t>(getScope(Ord));
+  Register ScopeReg = buildI32Constant(Scope, I);
+  MachineBasicBlock &BB = *I.getParent();
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMemoryBarrier))
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select OpAtomicCompareExchange. The "equal" and "unequal" semantics are
+// built from the success and failure orderings, each OR-ed with the semantics
+// of the pointer's storage class.
+bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  assert(I.hasOneMemOperand());
+  const MachineMemOperand *MemOp = *I.memoperands_begin();
+  uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
+  Register ScopeReg = buildI32Constant(Scope, I);
+
+  // NOTE(review): operands 2..4 are read as pointer / comparator / new value;
+  // confirm this matches the operand layout of the instruction that reaches
+  // this point.
+  Register Ptr = I.getOperand(2).getReg();
+  Register Cmp = I.getOperand(3).getReg();
+  Register Val = I.getOperand(4).getReg();
+
+  SPIRVType *SpvValTy = GR.getSPIRVTypeForVReg(Val);
+  SPIRV::StorageClass SC = GR.getPointerStorageClass(Ptr);
+  uint32_t ScSem = static_cast<uint32_t>(getMemSemanticsForStorageClass(SC));
+  AtomicOrdering AO = MemOp->getSuccessOrdering();
+  uint32_t MemSemEq = static_cast<uint32_t>(getMemSemantics(AO)) | ScSem;
+  Register MemSemEqReg = buildI32Constant(MemSemEq, I);
+  AtomicOrdering FO = MemOp->getFailureOrdering();
+  uint32_t MemSemNeq = static_cast<uint32_t>(getMemSemantics(FO)) | ScSem;
+  // Reuse the "equal" constant when both semantics are identical.
+  Register MemSemNeqReg =
+      MemSemEq == MemSemNeq ? MemSemEqReg : buildI32Constant(MemSemNeq, I);
+  const DebugLoc &DL = I.getDebugLoc();
+  return BuildMI(*I.getParent(), I, DL, TII.get(SPIRV::OpAtomicCompareExchange))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(SpvValTy))
+      .addUse(Ptr)
+      .addUse(ScopeReg)
+      .addUse(MemSemEqReg)
+      .addUse(MemSemNeqReg)
+      .addUse(Val)
+      .addUse(Cmp)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// True for the storage classes that can be cast to and from Generic.
+static bool isGenericCastablePtr(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Workgroup:
+  case SPIRV::StorageClass::CrossWorkgroup:
+  case SPIRV::StorageClass::Function:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// In SPIR-V address space casting can only happen to and from the Generic
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
+// pointers to and from Generic pointers. As such, we can convert e.g. from
+// Workgroup to Function by going via a Generic pointer as an intermediary. All
+// other combinations can only be done by a bitcast, and are probably not safe.
+bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  Register SrcPtr = I.getOperand(1).getReg();
+  SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
+  SPIRV::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
+  SPIRV::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
+
+  // Casting from an eligible pointer to Generic.
+  if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
+  // Casting from Generic to an eligible pointer.
+  if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
+  // Casting between 2 eligible pointers using Generic as an intermediary.
+  if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
+    Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
+        SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+    MachineBasicBlock &BB = *I.getParent();
+    const DebugLoc &DL = I.getDebugLoc();
+    bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
+                       .addDef(Tmp)
+                       .addUse(GR.getSPIRVTypeID(GenericPtrTy))
+                       .addUse(SrcPtr)
+                       .constrainAllUses(TII, TRI, RBI);
+    return Success && BuildMI(BB, I, DL, TII.get(SPIRV::OpGenericCastToPtr))
+                          .addDef(ResVReg)
+                          .addUse(GR.getSPIRVTypeID(ResType))
+                          .addUse(Tmp)
+                          .constrainAllUses(TII, TRI, RBI);
+  }
+  // TODO Should this case just be disallowed completely?
+  // We're casting 2 other arbitrary address spaces, so have to bitcast.
+  return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
+// Map a floating-point CmpInst predicate to the matching SPIR-V opcode.
+static unsigned getFCmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::FCMP_OEQ:
+    return SPIRV::OpFOrdEqual;
+  case CmpInst::FCMP_OGE:
+    return SPIRV::OpFOrdGreaterThanEqual;
+  case CmpInst::FCMP_OGT:
+    return SPIRV::OpFOrdGreaterThan;
+  case CmpInst::FCMP_OLE:
+    return SPIRV::OpFOrdLessThanEqual;
+  case CmpInst::FCMP_OLT:
+    return SPIRV::OpFOrdLessThan;
+  case CmpInst::FCMP_ONE:
+    return SPIRV::OpFOrdNotEqual;
+  case CmpInst::FCMP_ORD:
+    return SPIRV::OpOrdered;
+  case CmpInst::FCMP_UEQ:
+    return SPIRV::OpFUnordEqual;
+  case CmpInst::FCMP_UGE:
+    return SPIRV::OpFUnordGreaterThanEqual;
+  case CmpInst::FCMP_UGT:
+    return SPIRV::OpFUnordGreaterThan;
+  case CmpInst::FCMP_ULE:
+    return SPIRV::OpFUnordLessThanEqual;
+  case CmpInst::FCMP_ULT:
+    return SPIRV::OpFUnordLessThan;
+  case CmpInst::FCMP_UNE:
+    return SPIRV::OpFUnordNotEqual;
+  case CmpInst::FCMP_UNO:
+    return SPIRV::OpUnordered;
+  default:
+    llvm_unreachable("Unknown predicate type for FCmp");
+  }
+}
+
+// Map an integer CmpInst predicate to the matching SPIR-V opcode.
+static unsigned getICmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpIEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpINotEqual;
+  case CmpInst::ICMP_SGE:
+    return SPIRV::OpSGreaterThanEqual;
+  case CmpInst::ICMP_SGT:
+    return SPIRV::OpSGreaterThan;
+  case CmpInst::ICMP_SLE:
+    return SPIRV::OpSLessThanEqual;
+  case CmpInst::ICMP_SLT:
+    return SPIRV::OpSLessThan;
+  case CmpInst::ICMP_UGE:
+    return SPIRV::OpUGreaterThanEqual;
+  case CmpInst::ICMP_UGT:
+    return SPIRV::OpUGreaterThan;
+  case CmpInst::ICMP_ULE:
+    return SPIRV::OpULessThanEqual;
+  case CmpInst::ICMP_ULT:
+    return SPIRV::OpULessThan;
+  default:
+    llvm_unreachable("Unknown predicate type for ICmp");
+  }
+}
+
+// Pointer comparisons only support equality and inequality.
+static unsigned getPtrCmpOpcode(unsigned Pred) {
+  switch (static_cast<CmpInst::Predicate>(Pred)) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpPtrEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpPtrNotEqual;
+  default:
+    llvm_unreachable("Unknown predicate type for pointer comparison");
+  }
+}
+
+// Return the logical operation, or abort if none exists.
+static unsigned getBoolCmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpLogicalEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpLogicalNotEqual;
+  default:
+    llvm_unreachable("Unknown predicate type for Bool comparison");
+  }
+}
+
+// Select OpBitReverse on operand 1 of I.
+bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  MachineBasicBlock &BB = *I.getParent();
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpBitReverse))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(1).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select OpConstantComposite for a G_BUILD_VECTOR whose elements are all
+// constants (checked by assertion only).
+bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  // TODO: only const case is supported for now.
+  assert(std::all_of(
+      I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) {
+        if (MO.isDef())
+          return true;
+        if (!MO.isReg())
+          return false;
+        // Each element is expected to be defined by an ASSIGN_TYPE pseudo
+        // wrapping a G_CONSTANT or G_FCONSTANT.
+        SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg());
+        assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
+               ConstTy->getOperand(1).isReg());
+        Register ConstReg = ConstTy->getOperand(1).getReg();
+        const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
+        assert(Const);
+        return (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
+                Const->getOpcode() == TargetOpcode::G_FCONSTANT);
+      }));
+
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpConstantComposite))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
+    MIB.addUse(I.getOperand(i).getReg());
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+// Emit comparison CmpOpc on operands 2 and 3 of I.
+bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         unsigned CmpOpc,
+                                         MachineInstr &I) const {
+  Register Cmp0 = I.getOperand(2).getReg();
+  Register Cmp1 = I.getOperand(3).getReg();
+  assert(GR.getSPIRVTypeForVReg(Cmp0)->getOpcode() ==
+             GR.getSPIRVTypeForVReg(Cmp1)->getOpcode() &&
+         "CMP operands should have the same type");
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(CmpOpc))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(Cmp0)
+      .addUse(Cmp1)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select G_ICMP, choosing the pointer, logical (bool) or integer comparison
+// opcode according to the type of the first compared operand.
+bool SPIRVInstructionSelector::selectICmp(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I) const {
+  auto Pred = I.getOperand(1).getPredicate();
+  unsigned CmpOpc;
+
+  Register CmpOperand = I.getOperand(2).getReg();
+  if (GR.isScalarOfType(CmpOperand, SPIRV::OpTypePointer))
+    CmpOpc = getPtrCmpOpcode(Pred);
+  else if (GR.isScalarOrVectorOfType(CmpOperand, SPIRV::OpTypeBool))
+    CmpOpc = getBoolCmpOpcode(Pred);
+  else
+    CmpOpc = getICmpOpcode(Pred);
+  return selectCmp(ResVReg, ResType, CmpOpc, I);
+}
+
+// Render the bit pattern of a G_FCONSTANT as immediate operand(s) on MIB.
+void SPIRVInstructionSelector::renderFImm32(MachineInstrBuilder &MIB,
+                                            const MachineInstr &I,
+                                            int OpIdx) const {
+  assert(I.getOpcode() == TargetOpcode::G_FCONSTANT && OpIdx == -1 &&
+         "Expected G_FCONSTANT");
+  const ConstantFP *FPImm = I.getOperand(1).getFPImm();
+  addNumImm(FPImm->getValueAPF().bitcastToAPInt(), MIB);
+}
+
+// Render the value of a G_CONSTANT as immediate operand(s) on MIB.
+void SPIRVInstructionSelector::renderImm32(MachineInstrBuilder &MIB,
+                                           const MachineInstr &I,
+                                           int OpIdx) const {
+  assert(I.getOpcode() == TargetOpcode::G_CONSTANT && OpIdx == -1 &&
+         "Expected G_CONSTANT");
+  addNumImm(I.getOperand(1).getCImm()->getValue(), MIB);
+}
+
+// Build a constant with value Val in front of I and return its register.
+// Zero is emitted as OpConstantNull, any other value as OpConstantI. When
+// ResType is null a 32-bit integer type is used.
+Register
+SPIRVInstructionSelector::buildI32Constant(uint32_t Val, MachineInstr &I,
+                                           const SPIRVType *ResType) const {
+  const SPIRVType *SpvI32Ty =
+      ResType ? ResType : GR.getOrCreateSPIRVIntegerType(32, I, TII);
+  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  MachineInstr *MI;
+  MachineBasicBlock &BB = *I.getParent();
+  if (Val == 0)
+    MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+             .addDef(NewReg)
+             .addUse(GR.getSPIRVTypeID(SpvI32Ty));
+  else
+    MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
+             .addDef(NewReg)
+             .addUse(GR.getSPIRVTypeID(SpvI32Ty))
+             .addImm(APInt(32, Val).getZExtValue());
+  constrainSelectedInstRegOperands(*MI, TII, TRI, RBI);
+  return NewReg;
+}
+
+// Select G_FCMP via the predicate-to-opcode mapping.
+bool SPIRVInstructionSelector::selectFCmp(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I) const {
+  unsigned CmpOp = getFCmpOpcode(I.getOperand(1).getPredicate());
+  return selectCmp(ResVReg, ResType, CmpOp, I);
+}
+
+// Build a zero value of type ResType.
+Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  return buildI32Constant(0, I, ResType);
+}
+
+// Build a "one" value of type ResType: all bits set when AllOnes is true,
+// otherwise only bit 0. For vector types the scalar is replicated into an
+// OpConstantComposite.
+Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType);
+  APInt One = AllOnes ? APInt::getAllOnesValue(BitWidth)
+                      : APInt::getOneBitSet(BitWidth, 0);
+  // NOTE(review): the value passes through a uint32_t parameter, and ResType
+  // (possibly a vector type) is also used as the type of the scalar constant;
+  // confirm both are intended for wide or vector types.
+  Register OneReg = buildI32Constant(One.getZExtValue(), I, ResType);
+  if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+    const unsigned NumEles = ResType->getOperand(2).getImm();
+    Register OneVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    unsigned Opcode = SPIRV::OpConstantComposite;
+    auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+                   .addDef(OneVec)
+                   .addUse(GR.getSPIRVTypeID(ResType));
+    for (unsigned i = 0; i < NumEles; ++i) {
+      MIB.addUse(OneReg);
+    }
+    constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI);
+    return OneVec;
+  }
+  return OneReg;
+}
+
+// Emit an OpSelect that picks between a "one" and a zero constant of ResType
+// based on the boolean in operand 1 of I.
+bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
+                                            const SPIRVType *ResType,
+                                            MachineInstr &I,
+                                            bool IsSigned) const {
+  // To extend a bool, we need to use OpSelect between constants.
+  Register ZeroReg = buildZerosVal(ResType, I);
+  Register OneReg = buildOnesVal(IsSigned, ResType, I);
+  bool IsScalarBool =
+      GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool);
+  unsigned Opcode =
+      IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond;
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(1).getReg())
+      .addUse(OneReg)
+      .addUse(ZeroReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select an integer-to-float conversion (Opcode). Boolean sources are first
+// widened to an integer of the result's bit width through selectSelect().
+bool SPIRVInstructionSelector::selectIToF(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I, bool IsSigned,
+                                          unsigned Opcode) const {
+  Register SrcReg = I.getOperand(1).getReg();
+  // We can convert bool value directly to float type without OpConvert*ToF,
+  // however the translator generates OpSelect+OpConvert*ToF, so we do the same.
+  if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) {
+    unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType);
+    SPIRVType *TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII);
+    if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+      const unsigned NumElts = ResType->getOperand(2).getImm();
+      TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
+    }
+    SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    // Do not emit the conversion if its input could not be selected.
+    if (!selectSelect(SrcReg, TmpType, I, false))
+      return false;
+  }
+  return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode);
+}
+
+// Select an integer extension; boolean sources go through selectSelect().
+bool SPIRVInstructionSelector::selectExt(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I, bool IsSigned) const {
+  if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool))
+    return selectSelect(ResVReg, ResType, I, IsSigned);
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  return selectUnOp(ResVReg, ResType, I, Opcode);
+}
+
+// Convert an integer (or integer vector) in IntReg to a boolean in ResVReg.
+bool SPIRVInstructionSelector::selectIntToBool(Register IntReg,
+                                               Register ResVReg,
+                                               const SPIRVType *IntTy,
+                                               const SPIRVType *BoolTy,
+                                               MachineInstr &I) const {
+  // To truncate to a bool, we use OpBitwiseAnd 1 and OpINotEqual to zero.
+  Register BitIntReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  bool IsVectorTy = IntTy->getOpcode() == SPIRV::OpTypeVector;
+  unsigned Opcode = IsVectorTy ? SPIRV::OpBitwiseAndV : SPIRV::OpBitwiseAndS;
+  Register Zero = buildZerosVal(IntTy, I);
+  Register One = buildOnesVal(false, IntTy, I);
+  MachineBasicBlock &BB = *I.getParent();
+  // Propagate a failure to constrain the AND instead of silently dropping it.
+  bool Success = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+                     .addDef(BitIntReg)
+                     .addUse(GR.getSPIRVTypeID(IntTy))
+                     .addUse(IntReg)
+                     .addUse(One)
+                     .constrainAllUses(TII, TRI, RBI);
+  return Success &&
+         BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpINotEqual))
+             .addDef(ResVReg)
+             .addUse(GR.getSPIRVTypeID(BoolTy))
+             .addUse(BitIntReg)
+             .addUse(Zero)
+             .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select G_TRUNC. Truncation to bool uses selectIntToBool(); otherwise an
+// S/UConvert is chosen from the signedness of the result type.
+bool SPIRVInstructionSelector::selectTrunc(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+  if (GR.isScalarOrVectorOfType(ResVReg, SPIRV::OpTypeBool)) {
+    Register IntReg = I.getOperand(1).getReg();
+    const SPIRVType *ArgType = GR.getSPIRVTypeForVReg(IntReg);
+    return selectIntToBool(IntReg, ResVReg, ArgType, ResType, I);
+  }
+  bool IsSigned = GR.isScalarOrVectorSigned(ResType);
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  return selectUnOp(ResVReg, ResType, I, Opcode);
+}
+
+// Select G_CONSTANT: a null pointer becomes OpConstantNull, other values
+// become OpConstantI with the immediate appended by addNumImm().
+bool SPIRVInstructionSelector::selectConst(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           const APInt &Imm,
+                                           MachineInstr &I) const {
+  assert(ResType->getOpcode() != SPIRV::OpTypePointer || Imm.isNullValue());
+  MachineBasicBlock &BB = *I.getParent();
+  if (ResType->getOpcode() == SPIRV::OpTypePointer && Imm.isNullValue())
+    return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .constrainAllUses(TII, TRI, RBI);
+
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  // <=32-bit integers should be caught by the sdag pattern.
+  assert(Imm.getBitWidth() > 32);
+  addNumImm(Imm, MIB);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+// Select G_IMPLICIT_DEF as OpUndef.
+bool SPIRVInstructionSelector::selectOpUndef(Register ResVReg,
+                                             const SPIRVType *ResType,
+                                             MachineInstr &I) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Intrinsics are not supported yet; reaching this is a hard error.
+bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
+  llvm_unreachable("Intrinsic selection not implemented");
+}
+
+// Select G_FRAME_INDEX as an OpVariable in the Function storage class.
+bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectBranch(MachineInstr &I) const {
+  // InstructionSelector walks backwards through the instructions. We can use
+  // both a G_BR and a G_BRCOND to create an OpBranchConditional. We hit G_BR
+  // first, so can generate an OpBranchConditional here. If there is no
+  // G_BRCOND, we just use OpBranch for a regular unconditional branch.
+  const MachineInstr *PrevI = I.getPrevNode();
+  MachineBasicBlock &MBB = *I.getParent();
+  if (PrevI != nullptr && PrevI->getOpcode() == TargetOpcode::G_BRCOND)
+    return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional))
+        .addUse(PrevI->getOperand(0).getReg())
+        .addMBB(PrevI->getOperand(1).getMBB())
+        .addMBB(I.getOperand(0).getMBB())
+        .constrainAllUses(TII, TRI, RBI);
+  return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranch))
+      .addMBB(I.getOperand(0).getMBB())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectBranchCond(MachineInstr &I) const {
+  // InstructionSelector walks backwards through the instructions. For an
+  // explicit conditional branch with no fallthrough, we use both a G_BR and a
+  // G_BRCOND to create an OpBranchConditional. We should hit G_BR first, and
+  // generate the OpBranchConditional in selectBranch above.
+  //
+  // If an OpBranchConditional has been generated, we simply return, as the work
+  // is already done. If there is no OpBranchConditional, LLVM must be relying
+  // on implicit fallthrough to the next basic block, so we need to create an
+  // OpBranchConditional with an explicit "false" argument pointing to the next
+  // basic block that LLVM would fall through to.
+  const MachineInstr *NextI = I.getNextNode();
+  // Check if this has already been successfully selected.
+  if (NextI != nullptr && NextI->getOpcode() == SPIRV::OpBranchConditional)
+    return true;
+  // Must be relying on implicit block fallthrough, so generate an
+  // OpBranchConditional with the "next" basic block as the "false" target.
+  MachineBasicBlock &MBB = *I.getParent();
+  // There is nothing to fall through to from the last block of a function;
+  // report a selection failure rather than dereferencing a null block.
+  const MachineBasicBlock *FallThrough = MBB.getNextNode();
+  if (FallThrough == nullptr)
+    return false;
+  unsigned NextMBBNum = FallThrough->getNumber();
+  MachineBasicBlock *NextMBB = I.getMF()->getBlockNumbered(NextMBBNum);
+  return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional))
+      .addUse(I.getOperand(0).getReg())
+      .addMBB(I.getOperand(1).getMBB())
+      .addMBB(NextMBB)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+// Select G_PHI as OpPhi. After the result, operands come in
+// (value register, predecessor block) pairs.
+bool SPIRVInstructionSelector::selectPhi(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpPhi))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  const unsigned NumOps = I.getNumOperands();
+  assert((NumOps % 2 == 1) && "Require odd number of operands for G_PHI");
+  for (unsigned i = 1; i < NumOps; i += 2) {
+    MIB.addUse(I.getOperand(i + 0).getReg());
+    MIB.addMBB(I.getOperand(i + 1).getMBB());
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+// Select G_GLOBAL_VALUE by asking the global registry to build the variable.
+// Globals with a real initializer are skipped until Init is supplied.
+bool SPIRVInstructionSelector::selectGlobalValue(
+    Register ResVReg, MachineInstr &I, const MachineInstr *Init) const {
+  MachineIRBuilder MIRBuilder(I);
+  const GlobalValue *GV = I.getOperand(1).getGlobal();
+  SPIRVType *ResType = GR.getOrCreateSPIRVType(
+      GV->getType(), MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
+
+  std::string GlobalIdent = GV->getGlobalIdentifier();
+  // TODO: support @llvm.global.annotations.
+  auto GlobalVar = cast<GlobalVariable>(GV);
+
+  bool HasInit = GlobalVar->hasInitializer() &&
+                 !isa<UndefValue>(GlobalVar->getInitializer());
+  // Skip empty declaration for GVs with initializers till we get the decl with
+  // passed initializer.
+  if (HasInit && !Init)
+    return true;
+
+  unsigned AddrSpace = GV->getAddressSpace();
+  SPIRV::StorageClass Storage = addressSpaceToStorageClass(AddrSpace);
+  bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
+                  Storage != SPIRV::StorageClass::Function;
+  SPIRV::LinkageType LnkType =
+      (GV->isDeclaration() || GV->hasAvailableExternallyLinkage())
+          ? SPIRV::LinkageType::Import
+          : SPIRV::LinkageType::Export;
+
+  Register Reg = GR.buildGlobalVariable(ResVReg, ResType, GlobalIdent, GV,
+                                        Storage, Init, GlobalVar->isConstant(),
+                                        HasLnkTy, LnkType, MIRBuilder, true);
+  return Reg.isValid();
+}
+
+namespace llvm {
+// Factory for the SPIR-V instruction selector; the caller receives ownership
+// of the newly allocated object.
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               const SPIRVSubtarget &Subtarget,
+                               const SPIRVRegisterBankInfo &RBI) {
+  return new SPIRVInstructionSelector(TM, Subtarget, RBI);
+}
+} // namespace llvm
Index: llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -0,0 +1,36 @@
+//===- SPIRVLegalizerInfo.h --- SPIR-V Legalization Rules --------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+
+#include "SPIRVGlobalRegistry.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+
+bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+
+class LLVMContext;
+class SPIRVSubtarget;
+
+// This class provides the information for legalizing SPIR-V instructions.
+class SPIRVLegalizerInfo : public LegalizerInfo {
+  const SPIRVSubtarget *ST; // Subtarget passed to the constructor.
+  SPIRVGlobalRegistry *GR;  // Set from ST.getSPIRVGlobalRegistry().
+
+public:
+  bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override;
+  SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
Index: llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -0,0 +1,301 @@
+//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+
+using namespace llvm;
+using namespace llvm::LegalizeActions;
+using namespace llvm::LegalityPredicates;
+
+static const std::set<unsigned> TypeFoldingSupportingOpcs = {
+    TargetOpcode::G_ADD,
+    TargetOpcode::G_FADD,
+    TargetOpcode::G_SUB,
+    TargetOpcode::G_FSUB,
+    TargetOpcode::G_MUL,
+    TargetOpcode::G_FMUL,
+    TargetOpcode::G_SDIV,
+    TargetOpcode::G_UDIV,
+    TargetOpcode::G_FDIV,
+    TargetOpcode::G_SREM,
+    TargetOpcode::G_UREM,
+    TargetOpcode::G_FREM,
+    TargetOpcode::G_FNEG,
+    TargetOpcode::G_CONSTANT,
+    TargetOpcode::G_FCONSTANT,
+    TargetOpcode::G_AND,
+    TargetOpcode::G_OR,
+    TargetOpcode::G_XOR,
+    TargetOpcode::G_SHL,
+    TargetOpcode::G_ASHR,
+    TargetOpcode::G_LSHR,
+    TargetOpcode::G_SELECT,
+    TargetOpcode::G_EXTRACT_VECTOR_ELT,
+};
+
+bool isTypeFoldingSupported(unsigned Opcode) {
+  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
+}
+
+SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
+  using namespace TargetOpcode;
+
+  this->ST = &ST;
+  GR = ST.getSPIRVGlobalRegistry();
+
+  const LLT s1 = LLT::scalar(1);
+  const LLT s8 = LLT::scalar(8);
+  const LLT s16 = LLT::scalar(16);
+  const LLT s32 = LLT::scalar(32);
+  const LLT s64 = LLT::scalar(64);
+
+  const LLT v16s64 = LLT::fixed_vector(16, 64);
+  const LLT v16s32 = LLT::fixed_vector(16, 32);
+  const LLT v16s16 = LLT::fixed_vector(16, 16);
+  const LLT v16s8 = LLT::fixed_vector(16, 8);
+  const LLT v16s1 = LLT::fixed_vector(16, 1);
+
+  const LLT v8s64 = LLT::fixed_vector(8, 64);
+  const LLT v8s32 = LLT::fixed_vector(8, 32);
+  const LLT v8s16 = LLT::fixed_vector(8, 16);
+  const LLT v8s8 = LLT::fixed_vector(8, 8);
+  const LLT v8s1 = LLT::fixed_vector(8, 1);
+
+  const LLT v4s64 = LLT::fixed_vector(4, 64);
+  const LLT v4s32 = LLT::fixed_vector(4, 32);
+  const LLT v4s16 = LLT::fixed_vector(4, 16);
+  const LLT v4s8 = LLT::fixed_vector(4, 8);
+  const LLT v4s1 = LLT::fixed_vector(4, 1);
+
+  const LLT v3s64 = LLT::fixed_vector(3, 64);
+  const LLT v3s32 = LLT::fixed_vector(3, 32);
+  const LLT v3s16 = LLT::fixed_vector(3, 16);
+  const LLT v3s8 = LLT::fixed_vector(3, 8);
+  const LLT v3s1 = LLT::fixed_vector(3, 1);
+
+  const LLT v2s64 = LLT::fixed_vector(2, 64);
+  const LLT v2s32 = LLT::fixed_vector(2, 32);
+  const LLT v2s16 = LLT::fixed_vector(2, 16);
+  const LLT v2s8 = LLT::fixed_vector(2, 8);
+  const LLT v2s1 = LLT::fixed_vector(2, 1);
+
+  const unsigned PSize = ST.getPointerSize();
+  const LLT p0 = LLT::pointer(0, PSize); // Function
+  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
+  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
+  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
+  const LLT p4 = LLT::pointer(4, PSize); // Generic
+  const LLT p7 = LLT::pointer(7, PSize); // Input, as in SPIRVUtils.cpp
+
+  // TODO: remove copy-pasting here by using concatenation in some way.
+  auto allPtrsScalarsAndVectors = {
+      p0,    p1,    p2,    p3,    p4,     p7,     s1,     s8,     s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,  v3s1,   v3s8,
+      v3s16, v3s32, v3s64, v4s1,  v4s8,   v4s16,  v4s32,  v4s64,  v8s1,
+      v8s8,  v8s16, v8s32, v8s64, v16s1,  v16s8,  v16s16, v16s32, v16s64};
+
+  auto allScalarsAndVectors = {
+      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
+      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
+      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+
+  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
+                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
+                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
+                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
+
+  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
+
+  auto allIntScalars = {s8, s16, s32, s64};
+
+  auto allFloatScalarsAndVectors = {
+      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
+      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
+
+  auto allFloatAndIntScalars = allIntScalars;
+
+  auto allPtrs = {p0, p1, p2, p3, p4, p7};
+  auto allWritablePtrs = {p0, p1, p3, p4};
+
+  for (auto Opc : TypeFoldingSupportingOpcs)
+    getActionDefinitionsBuilder(Opc).custom();
+
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+
+  // TODO: add proper rules for vectors legalization.
+  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
+      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
+
+  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
+      .legalForCartesianProduct(allPtrs, allPtrs);
+
+  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
+
+  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
+      .legalForCartesianProduct(allIntScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
+      .legalFor(allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
+      allIntScalarsAndVectors, allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
+      typeInSet(0, allPtrsScalarsAndVectors),
+      typeInSet(1, allPtrsScalarsAndVectors),
+      LegalityPredicate(([=](const LegalityQuery &Query) {
+        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
+      }))));
+
+  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
+
+  getActionDefinitionsBuilder(G_INTTOPTR)
+      .legalForCartesianProduct(allPtrs, allIntScalars);
+  getActionDefinitionsBuilder(G_PTRTOINT)
+      .legalForCartesianProduct(allIntScalars, allPtrs);
+  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
+      allPtrs, allIntScalars);
+
+  // ST.canDirectlyComparePointers() for pointer args is supported in
+  // legalizeCustom().
+  getActionDefinitionsBuilder(G_ICMP).customIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors)));
+
+  getActionDefinitionsBuilder(G_FCMP).legalIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allFloatScalarsAndVectors)));
+
+  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
+                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
+                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
+                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
+      .legalForCartesianProduct(allIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
+      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
+  // TODO: add proper legalization rules.
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
+      .alwaysLegal();
+
+  // Extensions.
+  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
+      .legalForCartesianProduct(allScalarsAndVectors);
+
+  // FP conversions.
+  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
+      .legalForCartesianProduct(allFloatScalarsAndVectors);
+
+  // Pointer-handling.
+  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+
+  // Control-flow.
+  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
+
+  getActionDefinitionsBuilder({G_FPOW,
+                               G_FEXP,
+                               G_FEXP2,
+                               G_FLOG,
+                               G_FLOG2,
+                               G_FABS,
+                               G_FMINNUM,
+                               G_FMAXNUM,
+                               G_FCEIL,
+                               G_FCOS,
+                               G_FSIN,
+                               G_FSQRT,
+                               G_FFLOOR,
+                               G_FRINT,
+                               G_FNEARBYINT,
+                               G_INTRINSIC_ROUND,
+                               G_INTRINSIC_TRUNC,
+                               G_FMINIMUM,
+                               G_FMAXIMUM,
+                               G_INTRINSIC_ROUNDEVEN})
+      .legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FCOPYSIGN)
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
+      allFloatScalarsAndVectors, allIntScalarsAndVectors);
+
+  getLegacyLegalizerInfo().computeTables();
+  verify(*ST.getInstrInfo());
+}
+
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+                                LegalizerHelper &Helper,
+                                MachineRegisterInfo &MRI,
+                                SPIRVGlobalRegistry *GR) {
+  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
+  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder);
+  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
+      .addDef(ConvReg)
+      .addUse(Reg);
+  return ConvReg;
+}
+
+bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
+                                        MachineInstr &MI) const {
+  auto Opc = MI.getOpcode();
+  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  if (!isTypeFoldingSupported(Opc)) {
+    assert(Opc == TargetOpcode::G_ICMP);
+    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
+    auto &Op0 = MI.getOperand(2);
+    auto &Op1 = MI.getOperand(3);
+    Register Reg0 = Op0.getReg();
+    Register Reg1 = Op1.getReg();
+    CmpInst::Predicate Cond =
+        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
+    if ((!ST->canDirectlyComparePointers() ||
+         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
+        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
+      LLT ConvT = LLT::scalar(ST->getPointerSize());
+      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
+                                      ST->getPointerSize());
+      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
+      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
+      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
+    }
+    return true;
+  }
+  // TODO: implement legalization for other opcodes.
+  return true;
+}
Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRVSubtarget.h +++ llvm/lib/Target/SPIRV/SPIRVSubtarget.h @@ -30,7 +30,7 @@ namespace llvm { class StringRef; - +class SPIRVGlobalRegistry; class SPIRVTargetMachine; class SPIRVSubtarget : public SPIRVGenSubtargetInfo { @@ -38,6 +38,8 @@ const unsigned PointerSize; uint32_t SPIRVVersion; + std::unique_ptr GR; + SPIRVInstrInfo InstrInfo; SPIRVFrameLowering FrameLowering; SPIRVTargetLowering TLInfo; @@ -45,6 +47,8 @@ // GlobalISel related APIs. std::unique_ptr CallLoweringInfo; std::unique_ptr RegBankInfo; + std::unique_ptr Legalizer; + std::unique_ptr InstSelector; public: // This constructor initializes the data members to match that @@ -63,6 +67,8 @@ uint32_t getSPIRVVersion() const { return SPIRVVersion; }; + SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); } + const CallLowering *getCallLowering() const override { return CallLoweringInfo.get(); } @@ -71,6 +77,14 @@ return RegBankInfo.get(); } + const LegalizerInfo *getLegalizerInfo() const override { + return Legalizer.get(); + } + + InstructionSelector *getInstructionSelector() const override { + return InstSelector.get(); + } + const SPIRVInstrInfo *getInstrInfo() const override { return &InstrInfo; } const SPIRVFrameLowering *getFrameLowering() const override { Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp +++ llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp @@ -12,6 +12,8 @@ #include "SPIRVSubtarget.h" #include "SPIRV.h" +#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h" #include "SPIRVRegisterBankInfo.h" #include "SPIRVTargetMachine.h" #include "llvm/MC/TargetRegistry.h" @@ -43,10 +45,13 @@ : SPIRVGenSubtargetInfo(TT, CPU, /*TuneCPU=*/CPU, FS), PointerSize(computePointerSize(TT)), SPIRVVersion(0), InstrInfo(), FrameLowering(initSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) { - CallLoweringInfo.reset(new SPIRVCallLowering(TLInfo)); + GR.reset(new SPIRVGlobalRegistry(PointerSize)); + CallLoweringInfo.reset(new SPIRVCallLowering(TLInfo, *this, GR.get())); + Legalizer.reset(new SPIRVLegalizerInfo(*this)); auto *RBI = new SPIRVRegisterBankInfo(); RegBankInfo.reset(RBI); + InstSelector.reset(createSPIRVInstructionSelector(TM, *this, *RBI)); } SPIRVSubtarget &SPIRVSubtarget::initSubtargetDependencies(StringRef CPU, Index: llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -12,6 +12,9 @@ #include "SPIRVTargetMachine.h" #include "SPIRV.h" +#include "SPIRVCallLowering.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVLegalizerInfo.h" #include "SPIRVTargetObjectFile.h" #include "SPIRVTargetTransformInfo.h" #include "TargetInfo/SPIRVTargetInfo.h" @@ -29,11 +32,18 @@ #include "llvm/Target/TargetOptions.h" using namespace llvm; +InstructionSelector * +createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, + SPIRVSubtarget &Subtarget, + SPIRVRegisterBankInfo &RBI); extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() { // Register the target. 
RegisterTargetMachine X(getTheSPIRV32Target()); RegisterTargetMachine Y(getTheSPIRV64Target()); + + PassRegistry &PR = *PassRegistry::getPassRegistry(); + initializeGlobalISel(PR); } static std::string computeDataLayout(const Triple &TT) { @@ -155,7 +165,19 @@ return false; } +namespace { +// A custom subclass of InstructionSelect, which is mostly the same except from +// not requiring RegBankSelect to occur previously. +class SPIRVInstructionSelect : public InstructionSelect { + // We don't use register banks, so unset the requirement for them + MachineFunctionProperties getRequiredProperties() const override { + return InstructionSelect::getRequiredProperties().reset( + MachineFunctionProperties::Property::RegBankSelected); + } +}; +} // namespace + bool SPIRVPassConfig::addGlobalInstructionSelect() { - addPass(new InstructionSelect(getOptLevel())); + addPass(new SPIRVInstructionSelect()); return false; } Index: llvm/lib/Target/SPIRV/SPIRVUtils.h =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -0,0 +1,58 @@ +//===--- SPIRVUtils.h ---- SPIR-V Utility Functions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains miscellaneous utility functions. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "SPIRVInstrInfo.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/IR/IRBuilder.h"
+#include <string> // NOTE(review): header name was lost in extraction; presumably <string> - confirm.
+
+// Add the given string as a series of integer operands, inserting null
+// terminators and padding to make sure the operands all have 32-bit
+// little-endian words.
+void addStringImm(const llvm::StringRef &Str, llvm::MachineInstrBuilder &MIB);
+void addStringImm(const llvm::StringRef &Str, llvm::IRBuilder<> &B,
+                  std::vector<llvm::Value *> &Args);
+
+// Add the given numerical immediate to MIB.
+void addNumImm(const llvm::APInt &Imm, llvm::MachineInstrBuilder &MIB);
+
+// Add an OpName instruction for the given target register.
+void buildOpName(llvm::Register Target, const llvm::StringRef &Name,
+                 llvm::MachineIRBuilder &MIRBuilder);
+
+// Add an OpDecorate instruction for the given Reg.
+void buildOpDecorate(llvm::Register Reg, llvm::MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+void buildOpDecorate(llvm::Register Reg, llvm::MachineInstr &I,
+                     const llvm::SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+
+// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
+unsigned storageClassToAddressSpace(llvm::SPIRV::StorageClass SC);
+
+// Convert an LLVM IR address space to a SPIR-V storage class.
+llvm::SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace);
+
+llvm::SPIRV::MemorySemantics
+getMemSemanticsForStorageClass(llvm::SPIRV::StorageClass SC);
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
Index: llvm/lib/Target/SPIRV/SPIRVUtils.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -0,0 +1,172 @@
+//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "SPIRV.h"
+
+using namespace llvm;
+
+// The following functions are used to add these string literals as a series of
+// 32-bit integer operands with the correct format, and unpack them if necessary
+// when making string comparisons in compiler passes.
+// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
+static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
+  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
+  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
+    unsigned StrIndex = i + WordIndex;
+    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
+    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
+      CharToAdd = Str[StrIndex];
+    }
+    Word |= (static_cast<uint32_t>(CharToAdd) << (WordIndex * 8));
+  }
+  return Word;
+}
+
+// Get length including padding and null terminator.
+static size_t getPaddedLen(const StringRef &Str) {
+  const size_t Len = Str.size() + 1;
+  return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
+}
+
+void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add an operand for the 32-bits of chars or padding.
+    MIB.addImm(convertCharsToWord(Str, i));
+  }
+}
+
+void addStringImm(const StringRef &Str, IRBuilder<> &B,
+                  std::vector<Value *> &Args) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add a vector element for the 32-bits of chars or padding.
+    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
+  }
+}
+
+void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
+  const auto Bitwidth = Imm.getBitWidth();
+  switch (Bitwidth) {
+  case 1:
+    break; // Already handled.
+  case 8:
+  case 16:
+  case 32:
+    MIB.addImm(Imm.getZExtValue());
+    break;
+  case 64: {
+    uint64_t FullImm = Imm.getZExtValue();
+    uint32_t LowBits = FullImm & 0xffffffff;
+    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
+    MIB.addImm(LowBits).addImm(HighBits);
+    break;
+  }
+  default:
+    report_fatal_error("Unsupported constant bitwidth");
+  }
+}
+
+void buildOpName(Register Target, const StringRef &Name,
+                 MachineIRBuilder &MIRBuilder) {
+  if (!Name.empty()) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
+    addStringImm(Name, MIB);
+  }
+}
+
+static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
+                                  const std::vector<uint32_t> &DecArgs,
+                                  StringRef StrImm) {
+  if (!StrImm.empty())
+    addStringImm(StrImm, MIB);
+  for (const auto &DecArg : DecArgs)
+    MIB.addImm(DecArg);
+}
+
+void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  MachineBasicBlock &MBB = *I.getParent();
+  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+// TODO: maybe the following two functions should be handled in the subtarget
+// to allow for different OpenCL vs Vulkan handling.
+unsigned storageClassToAddressSpace(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Function:
+    return 0;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return 1;
+  case SPIRV::StorageClass::UniformConstant:
+    return 2;
+  case SPIRV::StorageClass::Workgroup:
+    return 3;
+  case SPIRV::StorageClass::Generic:
+    return 4;
+  case SPIRV::StorageClass::Input:
+    return 7;
+  default:
+    llvm_unreachable("Unable to get address space id");
+  }
+}
+
+SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace) {
+  switch (AddrSpace) {
+  case 0:
+    return SPIRV::StorageClass::Function;
+  case 1:
+    return SPIRV::StorageClass::CrossWorkgroup;
+  case 2:
+    return SPIRV::StorageClass::UniformConstant;
+  case 3:
+    return SPIRV::StorageClass::Workgroup;
+  case 4:
+    return SPIRV::StorageClass::Generic;
+  case 7:
+    return SPIRV::StorageClass::Input;
+  default: // AddrSpace comes from user IR, so fail loudly in release builds.
+    report_fatal_error("Unknown address space");
+  }
+}
+
+SPIRV::MemorySemantics getMemSemanticsForStorageClass(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::StorageBuffer:
+  case SPIRV::StorageClass::Uniform:
+    return SPIRV::MemorySemantics::UniformMemory;
+  case SPIRV::StorageClass::Workgroup:
+    return SPIRV::MemorySemantics::WorkgroupMemory;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
+  case SPIRV::StorageClass::AtomicCounter:
+    return SPIRV::MemorySemantics::AtomicCounterMemory;
+  case SPIRV::StorageClass::Image:
+    return SPIRV::MemorySemantics::ImageMemory;
+  default:
+    return SPIRV::MemorySemantics::None;
+  }
+}