Index: llvm/lib/Target/SPIRV/CMakeLists.txt
===================================================================
--- llvm/lib/Target/SPIRV/CMakeLists.txt
+++ llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -15,13 +15,17 @@
 add_llvm_target(SPIRVCodeGen
   SPIRVAsmPrinter.cpp
   SPIRVCallLowering.cpp
+  SPIRVGlobalRegistry.cpp
   SPIRVInstrInfo.cpp
+  SPIRVInstructionSelector.cpp
   SPIRVISelLowering.cpp
+  SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
   SPIRVRegisterBankInfo.cpp
   SPIRVRegisterInfo.cpp
   SPIRVSubtarget.cpp
   SPIRVTargetMachine.cpp
+  SPIRVUtils.cpp
   LINK_COMPONENTS
   CodeGen
Index: llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
+++ llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_component_library(LLVMSPIRVDesc
+  SPIRVBaseInfo.cpp
   SPIRVMCAsmInfo.cpp
   SPIRVMCTargetDesc.cpp
   SPIRVTargetStreamer.cpp
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -0,0 +1,790 @@
+//===-- SPIRVBaseInfo.h - Top level definitions for SPIRV ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains small standalone helper functions and enum definitions for
+// the SPIRV target useful for the compiler back-end and the MC libraries.
+// As such, it deliberately does not include references to LLVM core
+// code gen types, passes, etc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+
+#include <string>
+
+// Macros to define an enum and the functions to return its name.
+
+#define MAKE_ENUM(Enum, Var, Val) Var = Val,
+
+#define MAKE_NAME_CASE(Enum, Var, Val) \
+  case Enum::Var: \
+    return #Var;
+
+#define MAKE_MASK_ENUM_NAME_CASE(Enum, Var, Val) \
+  if (e == Enum::Var) { \
+    return #Var; \
+  } else if ((Enum::Var != 0) && (e & Enum::Var)) { \
+    nameString += sep + #Var; \
+    sep = "|"; \
+  }
+
+#define DEF_ENUM(EnumName, DefEnumCommand) \
+  namespace EnumName { \
+  enum EnumName : uint32_t { DefEnumCommand(EnumName, MAKE_ENUM) }; \
+  }
+
+#define DEF_NAME_FUNC_HEADER(EnumName) \
+  std::string get##EnumName##Name(EnumName::EnumName e);
+
+// Use this for enums that can only take a single value
+#define DEF_NAME_FUNC_BODY(EnumName, DefEnumCommand) \
+  std::string get##EnumName##Name(EnumName::EnumName e) { \
+    switch (e) { DefEnumCommand(EnumName, MAKE_NAME_CASE) } \
+    return "UNKNOWN_ENUM"; \
+  }
+
+// Use this for bitmasks that can take multiple values e.g. DontInline|Const
+#define DEF_MASK_NAME_FUNC_BODY(EnumName, DefEnumCommand) \
+  std::string get##EnumName##Name(EnumName::EnumName e) { \
+    std::string nameString = ""; \
+    std::string sep = ""; \
+    DefEnumCommand(EnumName, MAKE_MASK_ENUM_NAME_CASE); \
+    return nameString; \
+  }
+
+#define GEN_ENUM_HEADER(EnumName) \
+  DEF_ENUM(EnumName, DEF_##EnumName) \
+  DEF_NAME_FUNC_HEADER(EnumName)
+
+// Use this for enums that can only take a single value
+#define GEN_ENUM_IMPL(EnumName) DEF_NAME_FUNC_BODY(EnumName, DEF_##EnumName)
+
+// Use this for bitmasks that can take multiple values e.g. DontInline|Const
+#define GEN_MASK_ENUM_IMPL(EnumName) \
+  DEF_MASK_NAME_FUNC_BODY(EnumName, DEF_##EnumName)
+
+//===----------------------------------------------------------------------===//
+// The actual enum definitions are added below.
+//
+// Call GEN_ENUM_HEADER here in the header.
+//
+// Call GEN_ENUM_IMPL or GEN_MASK_ENUM_IMPL in SPIRVBaseInfo.cpp, depending on
+// whether the enum can only take a single value, or whether it can be a
+// bitmask of multiple values e.g. FunctionControl which can be
+// DontInline|Const.
+//
+// Syntax for each line is:
+//   X(N, Name, IdNum) \
+//
+// Each enum def must fit on a single line, so additional macros are sometimes
+// used for defining capabilities with long names.
+//===----------------------------------------------------------------------===//
+
+#define ANUIE(Pref) Pref##ArrayNonUniformIndexingEXT
+#define SBBA(Bits) StorageBuffer##Bits##BitAccess
+#define SB8BA(Pref) Pref##StorageBuffer8BitAccess
+#define ADIE(Pref) Pref##ArrayDynamicIndexingEXT
+#define VAR_PTR_SB VariablePointersStorageBuffer
+
+#define DEF_Capability(N, X) \
+  X(N, Matrix, 0) \
+  X(N, Shader, 1) \
+  X(N, Geometry, 2) \
+  X(N, Tessellation, 3) \
+  X(N, Addresses, 4) \
+  X(N, Linkage, 5) \
+  X(N, Kernel, 6) \
+  X(N, Vector16, 7) \
+  X(N, Float16Buffer, 8) \
+  X(N, Float16, 9) \
+  X(N, Float64, 10) \
+  X(N, Int64, 11) \
+  X(N, Int64Atomics, 12) \
+  X(N, ImageBasic, 13) \
+  X(N, ImageReadWrite, 14) \
+  X(N, ImageMipmap, 15) \
+  X(N, Pipes, 17) \
+  X(N, Groups, 18) \
+  X(N, DeviceEnqueue, 19) \
+  X(N, LiteralSampler, 20) \
+  X(N, AtomicStorage, 21) \
+  X(N, Int16, 22) \
+  X(N, TessellationPointSize, 23) \
+  X(N, GeometryPointSize, 24) \
+  X(N, ImageGatherExtended, 25) \
+  X(N, StorageImageMultisample, 27) \
+  X(N, UniformBufferArrayDynamicIndexing, 28) \
+  X(N, SampledImageArrayDynamicIndexing, 29) \
+  X(N, ClipDistance, 32) \
+  X(N, CullDistance, 33) \
+  X(N, ImageCubeArray, 34) \
+  X(N, SampleRateShading, 35) \
+  X(N, ImageRect, 36) \
+  X(N, SampledRect, 37) \
+  X(N, GenericPointer, 38) \
+  X(N, Int8, 39) \
+  X(N, InputAttachment, 40) \
+  X(N, SparseResidency, 41) \
+  X(N, MinLod, 42) \
+  X(N, Sampled1D, 43) \
+  X(N, Image1D, 44) \
+  X(N, SampledCubeArray, 45) \
+  X(N, SampledBuffer, 46) \
+  X(N, ImageBuffer, 47) \
+  X(N, ImageMSArray, 48) \
+  X(N, StorageImageExtendedFormats, 49) \
+  X(N, ImageQuery, 50) \
+  X(N, DerivativeControl, 51) \
+  X(N, InterpolationFunction, 52) \
+  X(N, TransformFeedback, 53) \
+  X(N, GeometryStreams, 54) \
+  X(N, StorageImageReadWithoutFormat, 55) \
+  X(N, StorageImageWriteWithoutFormat, 56) \
+  X(N, MultiViewport, 57) \
+  X(N, SubgroupDispatch, 58) \
+  X(N, NamedBarrier, 59) \
+  X(N, PipeStorage, 60) \
+  X(N, GroupNonUniform, 61) \
+  X(N, GroupNonUniformVote, 62) \
+  X(N, GroupNonUniformArithmetic, 63) \
+  X(N, GroupNonUniformBallot, 64) \
+  X(N, GroupNonUniformShuffle, 65) \
+  X(N, GroupNonUniformShuffleRelative, 66) \
+  X(N, GroupNonUniformClustered, 67) \
+  X(N, GroupNonUniformQuad, 68) \
+  X(N, SubgroupBallotKHR, 4423) \
+  X(N, DrawParameters, 4427) \
+  X(N, SubgroupVoteKHR, 4431) \
+  X(N, SBBA(16), 4433) \
+  X(N, StorageUniform16, 4434) \
+  X(N, StoragePushConstant16, 4435) \
+  X(N, StorageInputOutput16, 4436) \
+  X(N, DeviceGroup, 4437) \
+  X(N, MultiView, 4439) \
+  X(N, VAR_PTR_SB, 4441) \
+  X(N, VariablePointers, 4442) \
+  X(N, AtomicStorageOps, 4445) \
+  X(N, SampleMaskPostDepthCoverage, 4447) \
+  X(N, StorageBuffer8BitAccess, 4448) \
+  X(N, SB8BA(UniformAnd), 4449) \
+  X(N, StoragePushConstant8, 4450) \
+  X(N, DenormPreserve, 4464) \
+  X(N, DenormFlushToZero, 4465) \
+  X(N, SignedZeroInfNanPreserve, 4466) \
+  X(N, RoundingModeRTE, 4467) \
+  X(N, RoundingModeRTZ, 4468) \
+  X(N, Float16ImageAMD, 5008) \
+  X(N, ImageGatherBiasLodAMD, 5009) \
+  X(N, FragmentMaskAMD, 5010) \
+  X(N, StencilExportEXT, 5013) \
+  X(N, ImageReadWriteLodAMD, 5015) \
+  X(N, SampleMaskOverrideCoverageNV, 5249) \
+  X(N, GeometryShaderPassthroughNV, 5251) \
+  X(N, ShaderViewportIndexLayerEXT, 5254) \
+  X(N, ShaderViewportMaskNV, 5255) \
+  X(N, ShaderStereoViewNV, 5259) \
+  X(N, PerViewAttributesNV, 5260) \
+  X(N, FragmentFullyCoveredEXT, 5265) \
+  X(N, MeshShadingNV, 5266) \
+  X(N, ShaderNonUniformEXT, 5301) \
+  X(N, RuntimeDescriptorArrayEXT, 5302) \
+  X(N, ADIE(InputAttachment), 5303) \
+  X(N, ADIE(UniformTexelBuffer), 5304) \
+  X(N, ADIE(StorageTexelBuffer), 5305) \
+  X(N, ANUIE(UniformBuffer), 5306) \
+  X(N, ANUIE(SampledImage), 5307) \
+  X(N, ANUIE(StorageBuffer), 5308) \
+  X(N, ANUIE(StorageImage), 5309) \
+  X(N, ANUIE(InputAttachment), 5310) \
+  X(N, ANUIE(UniformTexelBuffer), 5311) \
+  X(N, ANUIE(StorageTexelBuffer), 5312) \
+  X(N, RayTracingNV, 5340) \
+  X(N, SubgroupShuffleINTEL, 5568) \
+  X(N, SubgroupBufferBlockIOINTEL, 5569) \
+  X(N, SubgroupImageBlockIOINTEL, 5570) \
+  X(N, SubgroupImageMediaBlockIOINTEL, 5579) \
+  X(N, SubgroupAvcMotionEstimationINTEL, 5696) \
+  X(N, SubgroupAvcMotionEstimationIntraINTEL, 5697) \
+  X(N, SubgroupAvcMotionEstimationChromaINTEL, 5698) \
+  X(N, GroupNonUniformPartitionedNV, 5297) \
+  X(N, VulkanMemoryModelKHR, 5345) \
+  X(N, VulkanMemoryModelDeviceScopeKHR, 5346) \
+  X(N, ImageFootprintNV, 5282) \
+  X(N, FragmentBarycentricNV, 5284) \
+  X(N, ComputeDerivativeGroupQuadsNV, 5288) \
+  X(N, ComputeDerivativeGroupLinearNV, 5350) \
+  X(N, FragmentDensityEXT, 5291) \
+  X(N, PhysicalStorageBufferAddressesEXT, 5347) \
+  X(N, CooperativeMatrixNV, 5357)
+GEN_ENUM_HEADER(Capability)
+
+#define DEF_SourceLanguage(N, X) \
+  X(N, Unknown, 0) \
+  X(N, ESSL, 1) \
+  X(N, GLSL, 2) \
+  X(N, OpenCL_C, 3) \
+  X(N, OpenCL_CPP, 4) \
+  X(N, HLSL, 5)
+GEN_ENUM_HEADER(SourceLanguage)
+
+#define PSB(Suff) PhysicalStorageBuffer##Suff
+#define DEF_AddressingModel(N, X) \
+  X(N, Logical, 0) \
+  X(N, Physical32, 1) \
+  X(N, Physical64, 2) \
+  X(N, PSB(64EXT), 5348)
+GEN_ENUM_HEADER(AddressingModel)
+
+#define DEF_ExecutionModel(N, X) \
+  X(N, Vertex, 0) \
+  X(N, TessellationControl, 1) \
+  X(N, TessellationEvaluation, 2) \
+  X(N, Geometry, 3) \
+  X(N, Fragment, 4) \
+  X(N, GLCompute, 5) \
+  X(N, Kernel, 6) \
+  X(N, TaskNV, 5267) \
+  X(N, MeshNV, 5268) \
+  X(N, RayGenerationNV, 5313) \
+  X(N, IntersectionNV, 5314) \
+  X(N, AnyHitNV, 5315) \
+  X(N, ClosestHitNV, 5316) \
+  X(N, MissNV, 5317) \
+  X(N, CallableNV, 5318)
+GEN_ENUM_HEADER(ExecutionModel)
+
+#define DEF_MemoryModel(N, X) \
+  X(N, Simple, 0) \
+  X(N, GLSL450, 1) \
+  X(N, OpenCL, 2) \
+  X(N, VulkanKHR, 3)
+GEN_ENUM_HEADER(MemoryModel)
+
+#define MSNV MeshShadingNV
+#define DG1(Suff) DerivativeGroup##Suff
+#define DG2(Pref, Suff) Pref##DerivativeGroup##Suff
+#define DEF_ExecutionMode(N, X) \
+  X(N, Invocations, 0) \
+  X(N, SpacingEqual, 1) \
+  X(N, SpacingFractionalEven, 2) \
+  X(N, SpacingFractionalOdd, 3) \
+  X(N, VertexOrderCw, 4) \
+  X(N, VertexOrderCcw, 5) \
+  X(N, PixelCenterInteger, 6) \
+  X(N, OriginUpperLeft, 7) \
+  X(N, OriginLowerLeft, 8) \
+  X(N, EarlyFragmentTests, 9) \
+  X(N, PointMode, 10) \
+  X(N, Xfb, 11) \
+  X(N, DepthReplacing, 12) \
+  X(N, DepthGreater, 14) \
+  X(N, DepthLess, 15) \
+  X(N, DepthUnchanged, 16) \
+  X(N, LocalSize, 17) \
+  X(N, LocalSizeHint, 18) \
+  X(N, InputPoints, 19) \
+  X(N, InputLines, 20) \
+  X(N, InputLinesAdjacency, 21) \
+  X(N, Triangles, 22) \
+  X(N, InputTrianglesAdjacency, 23) \
+  X(N, Quads, 24) \
+  X(N, Isolines, 25) \
+  X(N, OutputVertices, 26) \
+  X(N, OutputPoints, 27) \
+  X(N, OutputLineStrip, 28) \
+  X(N, OutputTriangleStrip, 29) \
+  X(N, VecTypeHint, 30) \
+  X(N, ContractionOff, 31) \
+  X(N, Initializer, 33) \
+  X(N, Finalizer, 34) \
+  X(N, SubgroupSize, 35) \
+  X(N, SubgroupsPerWorkgroup, 36) \
+  X(N, SubgroupsPerWorkgroupId, 37) \
+  X(N, LocalSizeId, 38) \
+  X(N, LocalSizeHintId, 39) \
+  X(N, PostDepthCoverage, 4446) \
+  X(N, DenormPreserve, 4459) \
+  X(N, DenormFlushToZero, 4460) \
+  X(N, SignedZeroInfNanPreserve, 4461) \
+  X(N, RoundingModeRTE, 4462) \
+  X(N, RoundingModeRTZ, 4463) \
+  X(N, StencilRefReplacingEXT, 5027) \
+  X(N, OutputLinesNV, 5269) \
+  X(N, DG1(QuadsNV), 5289) \
+  X(N, DG1(LinearNV), 5290) \
+  X(N, OutputTrianglesNV, 5298)
+GEN_ENUM_HEADER(ExecutionMode)
+
+#define DEF_StorageClass(N, X) \
+  X(N, UniformConstant, 0) \
+  X(N, Input, 1) \
+  X(N, Uniform, 2) \
+  X(N, Output, 3) \
+  X(N, Workgroup, 4) \
+  X(N, CrossWorkgroup, 5) \
+  X(N, Private, 6) \
+  X(N, Function, 7) \
+  X(N, Generic, 8) \
+  X(N, PushConstant, 9) \
+  X(N, AtomicCounter, 10) \
+  X(N, Image, 11) \
+  X(N, StorageBuffer, 12) \
+  X(N, CallableDataNV, 5328) \
+  X(N, IncomingCallableDataNV, 5329) \
+  X(N, RayPayloadNV, 5338) \
+  X(N, HitAttributeNV, 5339) \
+  X(N, IncomingRayPayloadNV, 5342) \
+  X(N, ShaderRecordBufferNV, 5343) \
+  X(N, PSB(EXT), 5349)
+GEN_ENUM_HEADER(StorageClass)
+
+// Need to manually do the getDimName() function, as "1D" is not a valid token
+// so it can't be auto-generated with these macros
+
+#define DEF_Dim(N, X) \
+  X(N, DIM_1D, 0) \
+  X(N, DIM_2D, 1) \
+  X(N, DIM_3D, 2) \
+  X(N, DIM_Cube, 3) \
+  X(N, DIM_Rect, 4) \
+  X(N, DIM_Buffer, 5) \
+  X(N, DIM_SubpassData, 6)
+GEN_ENUM_HEADER(Dim)
+
+#define DEF_SamplerAddressingMode(N, X) \
+  X(N, None, 0) \
+  X(N, ClampToEdge, 1) \
+  X(N, Clamp, 2) \
+  X(N, Repeat, 3) \
+  X(N, RepeatMirrored, 4)
+GEN_ENUM_HEADER(SamplerAddressingMode)
+
+#define DEF_SamplerFilterMode(N, X) \
+  X(N, Nearest, 0) \
+  X(N, Linear, 1)
+GEN_ENUM_HEADER(SamplerFilterMode)
+
+#define DEF_ImageFormat(N, X) \
+  X(N, Unknown, 0) \
+  X(N, Rgba32f, 1) \
+  X(N, Rgba16f, 2) \
+  X(N, R32f, 3) \
+  X(N, Rgba8, 4) \
+  X(N, Rgba8Snorm, 5) \
+  X(N, Rg32f, 6) \
+  X(N, Rg16f, 7) \
+  X(N, R11fG11fB10f, 8) \
+  X(N, R16f, 9) \
+  X(N, Rgba16, 10) \
+  X(N, Rgb10A2, 11) \
+  X(N, Rg16, 12) \
+  X(N, Rg8, 13) \
+  X(N, R16, 14) \
+  X(N, R8, 15) \
+  X(N, Rgba16Snorm, 16) \
+  X(N, Rg16Snorm, 17) \
+  X(N, Rg8Snorm, 18) \
+  X(N, R16Snorm, 19) \
+  X(N, R8Snorm, 20) \
+  X(N, Rgba32i, 21) \
+  X(N, Rgba16i, 22) \
+  X(N, Rgba8i, 23) \
+  X(N, R32i, 24) \
+  X(N, Rg32i, 25) \
+  X(N, Rg16i, 26) \
+  X(N, Rg8i, 27) \
+  X(N, R16i, 28) \
+  X(N, R8i, 29) \
+  X(N, Rgba32ui, 30) \
+  X(N, Rgba16ui, 31) \
+  X(N, Rgba8ui, 32) \
+  X(N, R32ui, 33) \
+  X(N, Rgb10a2ui, 34) \
+  X(N, Rg32ui, 35) \
+  X(N, Rg16ui, 36) \
+  X(N, Rg8ui, 37) \
+  X(N, R16ui, 38) \
+  X(N, R8ui, 39)
+GEN_ENUM_HEADER(ImageFormat)
+
+#define DEF_ImageChannelOrder(N, X) \
+  X(N, R, 0) \
+  X(N, A, 1) \
+  X(N, RG, 2) \
+  X(N, RA, 3) \
+  X(N, RGB, 4) \
+  X(N, RGBA, 5) \
+  X(N, BGRA, 6) \
+  X(N, ARGB, 7) \
+  X(N, Intensity, 8) \
+  X(N, Luminance, 9) \
+  X(N, Rx, 10) \
+  X(N, RGx, 11) \
+  X(N, RGBx, 12) \
+  X(N, Depth, 13) \
+  X(N, DepthStencil, 14) \
+  X(N, sRGB, 15) \
+  X(N, sRGBx, 16) \
+  X(N, sRGBA, 17) \
+  X(N, sBGRA, 18) \
+  X(N, ABGR, 19)
+GEN_ENUM_HEADER(ImageChannelOrder)
+
+#define DEF_ImageChannelDataType(N, X) \
+  X(N, SnormInt8, 0) \
+  X(N, SnormInt16, 1) \
+  X(N, UnormInt8, 2) \
+  X(N, UnormInt16, 3) \
+  X(N, UnormShort565, 4) \
+  X(N, UnormShort555, 5) \
+  X(N, UnormInt101010, 6) \
+  X(N, SignedInt8, 7) \
+  X(N, SignedInt16, 8) \
+  X(N, SignedInt32, 9) \
+  X(N, UnsignedInt8, 10) \
+  X(N, UnsignedInt16, 11) \
+  X(N, UnsignedInt32, 12) \
+  X(N, HalfFloat, 13) \
+  X(N, Float, 14) \
+  X(N, UnormInt24, 15) \
+  X(N, UnormInt101010_2, 16)
+GEN_ENUM_HEADER(ImageChannelDataType)
+
+#define DEF_ImageOperand(N, X) \
+  X(N, None, 0x0) \
+  X(N, Bias, 0x1) \
+  X(N, Lod, 0x2) \
+  X(N, Grad, 0x4) \
+  X(N, ConstOffset, 0x8) \
+  X(N, Offset, 0x10) \
+  X(N, ConstOffsets, 0x20) \
+  X(N, Sample, 0x40) \
+  X(N, MinLod, 0x80) \
+  X(N, MakeTexelAvailableKHR, 0x100) \
+  X(N, MakeTexelVisibleKHR, 0x200) \
+  X(N, NonPrivateTexelKHR, 0x400) \
+  X(N, VolatileTexelKHR, 0x800) \
+  X(N, SignExtend, 0x1000) \
+  X(N, ZeroExtend, 0x2000)
+GEN_ENUM_HEADER(ImageOperand)
+
+#define DEF_FPFastMathMode(N, X) \
+  X(N, None, 0x0) \
+  X(N, NotNaN, 0x1) \
+  X(N, NotInf, 0x2) \
+  X(N, NSZ, 0x4) \
+  X(N, AllowRecip, 0x8) \
+  X(N, Fast, 0x10)
+GEN_ENUM_HEADER(FPFastMathMode)
+
+#define DEF_FPRoundingMode(N, X) \
+  X(N, RTE, 0) \
+  X(N, RTZ, 1) \
+  X(N, RTP, 2) \
+  X(N, RTN, 3)
+GEN_ENUM_HEADER(FPRoundingMode)
+
+#define DEF_LinkageType(N, X) \
+  X(N, Export, 0) \
+  X(N, Import, 1)
+GEN_ENUM_HEADER(LinkageType)
+
+#define DEF_AccessQualifier(N, X) \
+  X(N, ReadOnly, 0) \
+  X(N, WriteOnly, 1) \
+  X(N, ReadWrite, 2)
+GEN_ENUM_HEADER(AccessQualifier)
+
+#define DEF_FunctionParameterAttribute(N, X) \
+  X(N, Zext, 0) \
+  X(N, Sext, 1) \
+  X(N, ByVal, 2) \
+  X(N, Sret, 3) \
+  X(N, NoAlias, 4) \
+  X(N, NoCapture, 5) \
+  X(N, NoWrite, 6) \
+  X(N, NoReadWrite, 7)
+GEN_ENUM_HEADER(FunctionParameterAttribute)
+
+#define DEF_Decoration(N, X) \
+  X(N, RelaxedPrecision, 0) \
+  X(N, SpecId, 1) \
+  X(N, Block, 2) \
+  X(N, BufferBlock, 3) \
+  X(N, RowMajor, 4) \
+  X(N, ColMajor, 5) \
+  X(N, ArrayStride, 6) \
+  X(N, MatrixStride, 7) \
+  X(N, GLSLShared, 8) \
+  X(N, GLSLPacked, 9) \
+  X(N, CPacked, 10) \
+  X(N, BuiltIn, 11) \
+  X(N, NoPerspective, 13) \
+  X(N, Flat, 14) \
+  X(N, Patch, 15) \
+  X(N, Centroid, 16) \
+  X(N, Sample, 17) \
+  X(N, Invariant, 18) \
+  X(N, Restrict, 19) \
+  X(N, Aliased, 20) \
+  X(N, Volatile, 21) \
+  X(N, Constant, 22) \
+  X(N, Coherent, 23) \
+  X(N, NonWritable, 24) \
+  X(N, NonReadable, 25) \
+  X(N, Uniform, 26) \
+  X(N, UniformId, 27) \
+  X(N, SaturatedConversion, 28) \
+  X(N, Stream, 29) \
+  X(N, Location, 30) \
+  X(N, Component, 31) \
+  X(N, Index, 32) \
+  X(N, Binding, 33) \
+  X(N, DescriptorSet, 34) \
+  X(N, Offset, 35) \
+  X(N, XfbBuffer, 36) \
+  X(N, XfbStride, 37) \
+  X(N, FuncParamAttr, 38) \
+  X(N, FPRoundingMode, 39) \
+  X(N, FPFastMathMode, 40) \
+  X(N, LinkageAttributes, 41) \
+  X(N, NoContraction, 42) \
+  X(N, InputAttachmentIndex, 43) \
+  X(N, Alignment, 44) \
+  X(N, MaxByteOffset, 45) \
+  X(N, AlignmentId, 46) \
+  X(N, MaxByteOffsetId, 47) \
+  X(N, NoSignedWrap, 4469) \
+  X(N, NoUnsignedWrap, 4470) \
+  X(N, ExplicitInterpAMD, 4999) \
+  X(N, OverrideCoverageNV, 5248) \
+  X(N, PassthroughNV, 5250) \
+  X(N, ViewportRelativeNV, 5252) \
+  X(N, SecondaryViewportRelativeNV, 5256) \
+  X(N, PerPrimitiveNV, 5271) \
+  X(N, PerViewNV, 5272) \
+  X(N, PerVertexNV, 5273) \
+  X(N, NonUniformEXT, 5300) \
+  X(N, CounterBuffer, 5634) \
+  X(N, UserSemantic, 5635) \
+  X(N, RestrictPointerEXT, 5355) \
+  X(N, AliasedPointerEXT, 5356)
+GEN_ENUM_HEADER(Decoration)
+
+#define SBK SubgroupBallotKHR
+#define PVANV PerViewAttributesNV
+#define GNU GroupNonUniform
+#define MSNV MeshShadingNV
+#define DEF_BuiltIn(N, X) \
+  X(N, Position, 0) \
+  X(N, PointSize, 1) \
+  X(N, ClipDistance, 3) \
+  X(N, CullDistance, 4) \
+  X(N, VertexId, 5) \
+  X(N, InstanceId, 6) \
+  X(N, PrimitiveId, 7) \
+  X(N, InvocationId, 8) \
+  X(N, Layer, 9) \
+  X(N, ViewportIndex, 10) \
+  X(N, TessLevelOuter, 11) \
+  X(N, TessLevelInner, 12) \
+  X(N, TessCoord, 13) \
+  X(N, PatchVertices, 14) \
+  X(N, FragCoord, 15) \
+  X(N, PointCoord, 16) \
+  X(N, FrontFacing, 17) \
+  X(N, SampleId, 18) \
+  X(N, SamplePosition, 19) \
+  X(N, SampleMask, 20) \
+  X(N, FragDepth, 22) \
+  X(N, HelperInvocation, 23) \
+  X(N, NumWorkgroups, 24) \
+  X(N, WorkgroupSize, 25) \
+  X(N, WorkgroupId, 26) \
+  X(N, LocalInvocationId, 27) \
+  X(N, GlobalInvocationId, 28) \
+  X(N, LocalInvocationIndex, 29) \
+  X(N, WorkDim, 30) \
+  X(N, GlobalSize, 31) \
+  X(N, EnqueuedWorkgroupSize, 32) \
+  X(N, GlobalOffset, 33) \
+  X(N, GlobalLinearId, 34) \
+  X(N, SubgroupSize, 36) \
+  X(N, SubgroupMaxSize, 37) \
+  X(N, NumSubgroups, 38) \
+  X(N, NumEnqueuedSubgroups, 39) \
+  X(N, SubgroupId, 40) \
+  X(N, SubgroupLocalInvocationId, 41) \
+  X(N, VertexIndex, 42) \
+  X(N, InstanceIndex, 43) \
+  X(N, SubgroupEqMask, 4416) \
+  X(N, SubgroupGeMask, 4417) \
+  X(N, SubgroupGtMask, 4418) \
+  X(N, SubgroupLeMask, 4419) \
+  X(N, SubgroupLtMask, 4420) \
+  X(N, BaseVertex, 4424) \
+  X(N, BaseInstance, 4425) \
+  X(N, DrawIndex, 4426) \
+  X(N, DeviceIndex, 4438) \
+  X(N, ViewIndex, 4440) \
+  X(N, BaryCoordNoPerspAMD, 4492) \
+  X(N, BaryCoordNoPerspCentroidAMD, 4493) \
+  X(N, BaryCoordNoPerspSampleAMD, 4494) \
+  X(N, BaryCoordSmoothAMD, 4495) \
+  X(N, BaryCoordSmoothCentroidAMD, 4496) \
+  X(N, BaryCoordSmoothSampleAMD, 4497) \
+  X(N, BaryCoordPullModelAMD, 4498) \
+  X(N, FragStencilRefEXT, 5014) \
+  X(N, ViewportMaskNV, 5253) \
+  X(N, SecondaryPositionNV, 5257) \
+  X(N, SecondaryViewportMaskNV, 5258) \
+  X(N, PositionPerViewNV, 5261) \
+  X(N, ViewportMaskPerViewNV, 5262) \
+  X(N, FullyCoveredEXT, 5264) \
+  X(N, TaskCountNV, 5274) \
+  X(N, PrimitiveCountNV, 5275) \
+  X(N, PrimitiveIndicesNV, 5276) \
+  X(N, ClipDistancePerViewNV, 5277) \
+  X(N, CullDistancePerViewNV, 5278) \
+  X(N, LayerPerViewNV, 5279) \
+  X(N, MeshViewCountNV, 5280) \
+  X(N, MeshViewIndicesNV, 5281) \
+  X(N, BaryCoordNV, 5286) \
+  X(N, BaryCoordNoPerspNV, 5287) \
+  X(N, FragSizeEXT, 5292) \
+  X(N, FragInvocationCountEXT, 5293) \
+  X(N, LaunchIdNV, 5319) \
+  X(N, LaunchSizeNV, 5320) \
+  X(N, WorldRayOriginNV, 5321) \
+  X(N, WorldRayDirectionNV, 5322) \
+  X(N, ObjectRayOriginNV, 5323) \
+  X(N, ObjectRayDirectionNV, 5324) \
+  X(N, RayTminNV, 5325) \
+  X(N, RayTmaxNV, 5326) \
+  X(N, InstanceCustomIndexNV, 5327) \
+  X(N, ObjectToWorldNV, 5330) \
+  X(N, WorldToObjectNV, 5331) \
+  X(N, HitTNV, 5332) \
+  X(N, HitKindNV, 5333) \
+  X(N, IncomingRayFlagsNV, 5351)
+GEN_ENUM_HEADER(BuiltIn)
+
+#define DEF_SelectionControl(N, X) \
+  X(N, None, 0x0) \
+  X(N, Flatten, 0x1) \
+  X(N, DontFlatten, 0x2)
+GEN_ENUM_HEADER(SelectionControl)
+
+#define DEF_LoopControl(N, X) \
+  X(N, None, 0x0) \
+  X(N, Unroll, 0x1) \
+  X(N, DontUnroll, 0x2) \
+  X(N, DependencyInfinite, 0x4) \
+  X(N, DependencyLength, 0x8) \
+  X(N, MinIterations, 0x10) \
+  X(N, MaxIterations, 0x20) \
+  X(N, IterationMultiple, 0x40) \
+  X(N, PeelCount, 0x80) \
+  X(N, PartialCount, 0x100)
+GEN_ENUM_HEADER(LoopControl)
+
+#define DEF_FunctionControl(N, X) \
+  X(N, None, 0x0) \
+  X(N, Inline, 0x1) \
+  X(N, DontInline, 0x2) \
+  X(N, Pure, 0x4) \
+  X(N, Const, 0x8)
+GEN_ENUM_HEADER(FunctionControl)
+
+#define DEF_MemorySemantics(N, X) \
+  X(N, None, 0x0) \
+  X(N, Acquire, 0x2) \
+  X(N, Release, 0x4) \
+  X(N, AcquireRelease, 0x8) \
+  X(N, SequentiallyConsistent, 0x10) \
+  X(N, UniformMemory, 0x40) \
+  X(N, SubgroupMemory, 0x80) \
+  X(N, WorkgroupMemory, 0x100) \
+  X(N, CrossWorkgroupMemory, 0x200) \
+  X(N, AtomicCounterMemory, 0x400) \
+  X(N, ImageMemory, 0x800) \
+  X(N, OutputMemoryKHR, 0x1000) \
+  X(N, MakeAvailableKHR, 0x2000) \
+  X(N, MakeVisibleKHR, 0x4000)
+GEN_ENUM_HEADER(MemorySemantics)
+
+#define DEF_MemoryOperand(N, X) \
+  X(N, None, 0x0) \
+  X(N, Volatile, 0x1) \
+  X(N, Aligned, 0x2) \
+  X(N, Nontemporal, 0x4) \
+  X(N, MakePointerAvailableKHR, 0x8) \
+  X(N, MakePointerVisibleKHR, 0x10) \
+  X(N, NonPrivatePointerKHR, 0x20)
+GEN_ENUM_HEADER(MemoryOperand)
+
+#define DEF_Scope(N, X) \
+  X(N, CrossDevice, 0) \
+  X(N, Device, 1) \
+  X(N, Workgroup, 2) \
+  X(N, Subgroup, 3) \
+  X(N, Invocation, 4) \
+  X(N, QueueFamilyKHR, 5)
+GEN_ENUM_HEADER(Scope)
+
+#define GNUA GroupNonUniformArithmetic
+#define GNUB GroupNonUniformBallot
+#define GNUPNV GroupNonUniformPartitionedNV
+#define DEF_GroupOperation(N, X) \
+  X(N, Reduce, 0) \
+  X(N, InclusiveScan, 1) \
+  X(N, ExclusiveScan, 2) \
+  X(N, ClusteredReduce, 3) \
+  X(N, PartitionedReduceNV, 6) \
+  X(N, PartitionedInclusiveScanNV, 7) \
+  X(N, PartitionedExclusiveScanNV, 8)
+GEN_ENUM_HEADER(GroupOperation)
+
+#define DEF_KernelEnqueueFlags(N, X) \
+  X(N, NoWait, 0) \
+  X(N, WaitKernel, 1) \
+  X(N, WaitWorkGroup, 2)
+GEN_ENUM_HEADER(KernelEnqueueFlags)
+
+#define DEF_KernelProfilingInfo(N, X) \
+  X(N, None, 0x0) \
+  X(N, CmdExecTime, 0x1)
+GEN_ENUM_HEADER(KernelProfilingInfo)
+
+// Return a string representation of the operands from startIndex onwards.
+// Templated to allow both MachineInstr and MCInst to use the same logic.
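+// (Each 32-bit immediate operand packs up to four characters, least
+// significant byte first, so e.g. the string "abc" plus its null terminator
+// fits in a single word: 0x00636261.)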
+template <typename InstType>
+std::string getSPIRVStringOperand(const InstType &MI, unsigned int StartIndex) {
+  std::string s = ""; // Iteratively append to this string
+
+  const unsigned int NumOps = MI.getNumOperands();
+  bool IsFinished = false;
+  for (unsigned int i = StartIndex; i < NumOps && !IsFinished; ++i) {
+    const auto &Op = MI.getOperand(i);
+    if (!Op.isImm()) // Stop if we hit a register operand
+      break;
+    uint32_t Imm = Op.getImm(); // Each i32 word is up to 4 characters
+    for (unsigned ShiftAmount = 0; ShiftAmount < 32; ShiftAmount += 8) {
+      char c = (Imm >> ShiftAmount) & 0xff;
+      if (c == 0) { // Stop if we hit a null-terminator character
+        IsFinished = true;
+        break;
+      } else {
+        s += c; // Otherwise, append the character to the result string
+      }
+    }
+  }
+  return s;
+}
+
+#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -0,0 +1,78 @@
+//===-- SPIRVBaseInfo.cpp - Top level definitions for SPIRV ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains small standalone helper functions and enum definitions for
+// the SPIRV target useful for the compiler back-end and the MC libraries.
+// As such, it deliberately does not include references to LLVM core
+// code gen types, passes, etc.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVBaseInfo.h"
+
+// Implement getEnumName(Enum e) helper functions.
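+// As a sketch, GEN_ENUM_IMPL(LinkageType) expands to roughly:
+//   std::string getLinkageTypeName(LinkageType::LinkageType e) {
+//     switch (e) {
+//     case LinkageType::Export: return "Export";
+//     case LinkageType::Import: return "Import";
+//     }
+//     return "UNKNOWN_ENUM";
+//   }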
+GEN_ENUM_IMPL(Capability)
+GEN_ENUM_IMPL(SourceLanguage)
+GEN_ENUM_IMPL(ExecutionModel)
+GEN_ENUM_IMPL(AddressingModel)
+GEN_ENUM_IMPL(MemoryModel)
+GEN_ENUM_IMPL(ExecutionMode)
+GEN_ENUM_IMPL(StorageClass)
+
+// Dim must be implemented manually, as "1D" is not a valid C++ naming token.
+std::string getDimName(Dim::Dim dim) {
+  switch (dim) {
+  case Dim::DIM_1D:
+    return "1D";
+  case Dim::DIM_2D:
+    return "2D";
+  case Dim::DIM_3D:
+    return "3D";
+  case Dim::DIM_Cube:
+    return "Cube";
+  case Dim::DIM_Rect:
+    return "Rect";
+  case Dim::DIM_Buffer:
+    return "Buffer";
+  case Dim::DIM_SubpassData:
+    return "SubpassData";
+  default:
+    return "UNKNOWN_Dim";
+  }
+}
+
+GEN_ENUM_IMPL(SamplerAddressingMode)
+GEN_ENUM_IMPL(SamplerFilterMode)
+
+GEN_ENUM_IMPL(ImageFormat)
+GEN_ENUM_IMPL(ImageChannelOrder)
+GEN_ENUM_IMPL(ImageChannelDataType)
+GEN_MASK_ENUM_IMPL(ImageOperand)
+
+GEN_MASK_ENUM_IMPL(FPFastMathMode)
+GEN_ENUM_IMPL(FPRoundingMode)
+
+GEN_ENUM_IMPL(LinkageType)
+GEN_ENUM_IMPL(AccessQualifier)
+GEN_ENUM_IMPL(FunctionParameterAttribute)
+
+GEN_ENUM_IMPL(Decoration)
+GEN_ENUM_IMPL(BuiltIn)
+
+GEN_MASK_ENUM_IMPL(SelectionControl)
+GEN_MASK_ENUM_IMPL(LoopControl)
+GEN_MASK_ENUM_IMPL(FunctionControl)
+
+GEN_MASK_ENUM_IMPL(MemorySemantics)
+GEN_MASK_ENUM_IMPL(MemoryOperand)
+
+GEN_ENUM_IMPL(Scope)
+GEN_ENUM_IMPL(GroupOperation)
+
+GEN_ENUM_IMPL(KernelEnqueueFlags)
+GEN_MASK_ENUM_IMPL(KernelProfilingInfo)
Index: llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
===================================================================
--- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVInstPrinter.h"
+#include "SPIRVBaseInfo.h"
 #include "llvm/CodeGen/Register.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCExpr.h"
@@ -112,7 +113,26 @@
 
 void SPIRVInstPrinter::printStringImm(const MCInst *MI, unsigned OpNo,
                                       raw_ostream &O) {
-  llvm_unreachable("Unimplemented printStringImm");
+  const unsigned NumOps = MI->getNumOperands();
+  unsigned StrStartIndex = OpNo;
+  while (StrStartIndex < NumOps) {
+    if (MI->getOperand(StrStartIndex).isReg())
+      break;
+
+    std::string Str = getSPIRVStringOperand(*MI, StrStartIndex);
+    if (StrStartIndex != OpNo)
+      O << ' '; // Add a space if we're starting a new string/argument
+    O << '"';
+    for (char c : Str) {
+      if (c == '"')
+        O.write('\\'); // Escape " characters (might break for complex UTF-8)
+      O.write(c);
+    }
+    O << '"';
+
+    unsigned NumOpsInString = (Str.size() / 4) + 1;
+    StrStartIndex += NumOpsInString;
+  }
 }
 
 void SPIRVInstPrinter::printExtInst(const MCInst *MI, unsigned OpNo,
@@ -120,12 +140,17 @@
   llvm_unreachable("Unimplemented printExtInst");
 }
 
-// Methods for printing textual names of SPIR-V enums
+// Implementation of SPIR-V enum printing using definitions in SPIRVBaseInfo.h
 #define GEN_INSTR_PRINTER_IMPL(EnumName) \
-void SPIRVInstPrinter::print##EnumName(const MCInst *MI, unsigned OpNo, \
-                                       raw_ostream &O) { \
-  llvm_unreachable("Unimplemented print" #EnumName ); \
-}
+  void SPIRVInstPrinter::print##EnumName(const MCInst *MI, unsigned OpNo, \
+                                         raw_ostream &O) { \
+    if (OpNo < MI->getNumOperands()) { \
+      EnumName::EnumName e = \
+          static_cast<EnumName::EnumName>(MI->getOperand(OpNo).getImm()); \
+      O << get##EnumName##Name(e); \
+    } \
+  }
+
 GEN_INSTR_PRINTER_IMPL(Capability)
 GEN_INSTR_PRINTER_IMPL(SourceLanguage)
 GEN_INSTR_PRINTER_IMPL(ExecutionModel)
Index: llvm/lib/Target/SPIRV/SPIRV.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRV.h
+++ llvm/lib/Target/SPIRV/SPIRV.h
@@ -17,6 +17,12 @@
 class SPIRVTargetMachine;
 class SPIRVRegisterBankInfo;
 class SPIRVSubtarget;
+class InstructionSelector;
+
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               const SPIRVSubtarget &Subtarget,
+                               const SPIRVRegisterBankInfo &RBI);
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVCallLowering.h
+++ llvm/lib/Target/SPIRV/SPIRVCallLowering.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCALLLOWERING_H
 #define LLVM_LIB_TARGET_SPIRV_SPIRVCALLLOWERING_H
 
+#include "SPIRVGlobalRegistry.h"
 #include "llvm/CodeGen/GlobalISel/CallLowering.h"
 
 namespace llvm {
@@ -21,8 +22,11 @@
 
 class SPIRVCallLowering : public CallLowering {
 private:
+  // Used to create and assign function, argument, and return type information
+  SPIRVGlobalRegistry *GR;
+
 public:
-  SPIRVCallLowering(const SPIRVTargetLowering &TLI);
+  SPIRVCallLowering(const SPIRVTargetLowering &TLI, SPIRVGlobalRegistry *GR);
 
   // Built OpReturn or OpReturnValue
   bool lowerReturn(MachineIRBuilder &MIRBuiler, const Value *Val,
Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -12,18 +12,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVCallLowering.h"
+#include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
 #include "SPIRVISelLowering.h"
 #include "SPIRVRegisterInfo.h"
 #include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
-#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/Demangle/Demangle.h"
 
 using namespace llvm;
 
-SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI)
-    : CallLowering(&TLI) {}
+SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
+                                     SPIRVGlobalRegistry *GR)
+    : CallLowering(&TLI), GR(GR) {}
 
 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                     const Value *Val, ArrayRef<Register> VRegs,
@@ -31,8 +34,8 @@
                                     Register SwiftErrorVReg) const {
   assert(VRegs.size() < 2 && "All return types should use a single register");
   if (Val) {
-    MIRBuilder.buildInstr(SPIRV::OpReturnValue).addUse(VRegs[0]);
-    return true;
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpReturnValue).addUse(VRegs[0]);
+    return constrainRegOperands(MIB);
   } else {
     MIRBuilder.buildInstr(SPIRV::OpReturn);
     return true;
@@ -41,7 +44,19 @@
 
 // Based on the LLVM function attributes, get a SPIR-V FunctionControl
 static uint32_t getFunctionControl(const Function &F) {
-  uint32_t FuncControl = 0;
+  uint32_t FuncControl = FunctionControl::None;
+  if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
+    FuncControl |= FunctionControl::Inline;
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
+    FuncControl |= FunctionControl::Pure;
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
+    FuncControl |= FunctionControl::Const;
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
+    FuncControl |= FunctionControl::DontInline;
+  }
   return FuncControl;
 }
 
@@ -49,30 +64,99 @@
     const Function &F, ArrayRef<ArrayRef<Register>> VRegs,
     FunctionLoweringInfo &FLI) const {
-  auto MRI = MIRBuilder.getMRI();
+  assert(GR && "Must initialize the SPIRV type registry before lowering args.");
+
+  // Assign types and names to all args, and store their types for later
   SmallVector<Register, 4> ArgTypeVRegs;
   if (VRegs.size() > 0) {
     unsigned int i = 0;
     for (const auto &Arg : F.args()) {
       assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
-      ArgTypeVRegs.push_back(
-          MRI->createGenericVirtualRegister(LLT::scalar(32)));
+      auto *SpirvTy =
+          GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
+      ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
+
+      if (Arg.hasName())
+        buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
+      if (Arg.getType()->isPointerTy()) {
+        auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
+        if (DerefBytes != 0)
+          buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration::MaxByteOffset,
+                          {DerefBytes});
+      }
+      if (Arg.hasAttribute(Attribute::Alignment)) {
+        auto Alignment = static_cast<unsigned>(
+            Arg.getAttribute(Attribute::Alignment).getValueAsInt());
+        buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration::Alignment,
+                        {Alignment});
+      }
+      if (Arg.hasAttribute(Attribute::ReadOnly)) {
+        buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration::FuncParamAttr,
+                        {FunctionParameterAttribute::NoWrite});
+      }
+      if (Arg.hasAttribute(Attribute::ZExt)) {
+        buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration::FuncParamAttr,
+                        {FunctionParameterAttribute::Zext});
+      }
       ++i;
     }
   }
 
   // Generate a SPIR-V type for the function
+  auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
+  auto *FTy = F.getFunctionType();
+
+  // This code restores the function's arg/retvalue types for composite cases
+  // because the final types should still be aggregate, whereas they are i32
+  // during the translation to cope with aggregate flattening etc.
+  auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
+  if (NamedMD) {
+    Type *RetTy = F.getFunctionType()->getReturnType();
+    SmallVector<Type *, 4> ArgTypes;
+    auto ThisFuncMDIt =
+        std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
+          return isa<MDString>(N->getOperand(0)) &&
+                 cast<MDString>(N->getOperand(0))->getString() == F.getName();
+        });
+
+    if (ThisFuncMDIt != NamedMD->op_end()) {
+      auto *ThisFuncMD = *ThisFuncMDIt;
+      if (cast<ConstantInt>(
+              cast<ConstantAsMetadata>(
+                  cast<MDNode>(ThisFuncMD->getOperand(1))->getOperand(0))
+                  ->getValue())
+              // TODO: currently -1 indicates return value, support this type
+              // renaming for arguments as well
+              ->getSExtValue() == -1)
+        RetTy = cast<ValueAsMetadata>(
+                    cast<MDNode>(ThisFuncMD->getOperand(1))->getOperand(1))
+                    ->getType();
+    }
+
+    for (auto &Arg : F.args())
+      ArgTypes.push_back(Arg.getType());
+
+    FTy = FunctionType::get(RetTy, ArgTypes, F.isVarArg());
+  }
+
+  auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
+
+  // Extract the return type <id> from the OpTypeFunction built above
+  Register ReturnTypeID = FuncTy->getOperand(1).getReg();
   uint32_t FuncControl = getFunctionControl(F);
   MIRBuilder.buildInstr(SPIRV::OpFunction)
       .addDef(FuncVReg)
-      .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)))
+      .addUse(ReturnTypeID)
      .addImm(FuncControl)
-      .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)));
+      .addUse(GR->getSPIRVTypeID(FuncTy));
 
   // Add OpFunctionParameters
   const unsigned int NumArgs = ArgTypeVRegs.size();
@@ -83,6 +167,25 @@
         .addDef(VRegs[i][0])
         .addUse(ArgTypeVRegs[i]);
   }
+
+  // Name the function
+  if (F.hasName())
+    buildOpName(FuncVReg, F.getName(), MIRBuilder);
+
+  // Handle entry points and function linkage
+  if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
+    auto ExecModel = ExecutionModel::Kernel;
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
+                   .addImm(ExecModel)
+                   .addUse(FuncVReg);
+    addStringImm(F.getName(), MIB);
+  } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
+             F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
+    auto LnkTy = F.isDeclaration() ? LinkageType::Import : LinkageType::Export;
+    buildOpDecorate(FuncVReg, MIRBuilder, Decoration::LinkageAttributes,
+                    {LnkTy}, F.getGlobalIdentifier());
+  }
+
   return true;
 }
 
@@ -94,20 +197,52 @@
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
+  // Emit a regular OpFunctionCall
+
+  // If it's an externally declared function, be sure to emit its type and
+  // function declaration here. It will be hoisted globally later.
+  auto M = MIRBuilder.getMF().getFunction().getParent();
+  Function *Callee = M->getFunction(FuncName);
+  if (Callee && Callee->isDeclaration()) {
+    // Emit the type info and forward function declaration to the first MBB
+    // to ensure VReg definition dependencies are valid across all MBBs.
+    MachineIRBuilder FirstBlockBuilder;
+    auto &MF = MIRBuilder.getMF();
+    FirstBlockBuilder.setMF(MF);
+    FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
+
+    SmallVector<ArrayRef<Register>, 8> VRegArgs;
+    SmallVector<SmallVector<Register, 1>, 8> ToInsert;
+    for (const Argument &Arg : Callee->args()) {
+      if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
+        continue; // Don't handle zero sized types.
+      ToInsert.push_back(
+          {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
+      VRegArgs.push_back(ToInsert.back());
+    }
+
+    // TODO: Reuse FunctionLoweringInfo
+    FunctionLoweringInfo FuncInfo;
+    lowerFormalArguments(FirstBlockBuilder, *Callee, VRegArgs, FuncInfo);
+  }
+
   // Make sure there's a valid return reg, even for functions returning void
   if (!ResVReg.isValid()) {
     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
   }
+  SPIRVType *RetType =
+      GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder);
 
   // Emit the OpFunctionCall and its args
   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
                  .addDef(ResVReg)
-                 .addUse(MIRBuilder.getMRI()->createVirtualRegister(
-                     &SPIRV::IDRegClass))
+                 .addUse(GR->getSPIRVTypeID(RetType))
                  .add(Info.Callee);
 
   for (const auto &arg : Info.OrigArgs) {
     assert(arg.Regs.size() == 1 && "Call arg has multiple VRegs");
     MIB.addUse(arg.Regs[0]);
   }
-  return true;
+  return constrainRegOperands(MIB);
 }
Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -0,0 +1,168 @@
+//===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// SPIRVGlobalRegistry is used to maintain rich type information required for
+// SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type
+// into an OpTypeXXX instruction, and map it to a virtual register. It also
+// builds constants and global variables, and keeps them consistent.
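+//
+// A typical use from a lowering pass is a sketch like the following, where GR
+// is a SPIRVGlobalRegistry* and both methods are declared below:
+//   SPIRVType *SpvTy = GR->getOrCreateSPIRVType(Arg.getType(), MIRBuilder);
+//   GR->assignSPIRVTypeToVReg(SpvTy, VReg, MIRBuilder);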
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVGLOBALREGISTRY_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVGLOBALREGISTRY_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+
+namespace AQ = AccessQualifier;
+
+namespace llvm {
+using SPIRVType = const MachineInstr;
+
+class SPIRVGlobalRegistry {
+  // Registers holding values which have types associated with them.
+  // Initialized upon VReg definition in IRTranslator.
+  // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
+  // where Reg = OpType...
+  // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not
+  // type-declaring ones).
+  DenseMap<MachineFunction *, DenseMap<Register, SPIRVType *>> VRegToTypeMap;
+
+  DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
+
+  // Number of bits pointers and size_t integers require.
+  const unsigned int PointerSize;
+
+  // Add a new OpTypeXXX instruction without checking for duplicates.
+  SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
+                             AQ::AccessQualifier accessQual = AQ::ReadWrite,
+                             bool EmitIR = true);
+
+public:
+  SPIRVGlobalRegistry(unsigned int PointerSize);
+
+  MachineFunction *CurMF;
+
+  // Get or create a SPIR-V type corresponding to the given LLVM IR type,
+  // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
+  SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
+                              MachineIRBuilder &MIRBuilder,
+                              AQ::AccessQualifier AccessQual = AQ::ReadWrite);
+
+  // In cases where the SPIR-V type is already known, this function can be
+  // used to map it to the given VReg via an ASSIGN_TYPE instruction.
+  void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
+                             MachineIRBuilder &MIRBuilder);
+
+  // Either generate a new OpTypeXXX instruction or return an existing one
+  // corresponding to the given LLVM IR type.
+  // EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes),
+  // because this method may be called from InstructionSelector and we don't
+  // want to emit extra IR instructions there.
+  SPIRVType *
+  getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
+                       AQ::AccessQualifier accessQual = AQ::ReadWrite,
+                       bool EmitIR = true);
+
+  const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
+    auto Res = SPIRVToLLVMType.find(Ty);
+    assert(Res != SPIRVToLLVMType.end());
+    return Res->second;
+  }
+
+  // Return the SPIR-V type instruction corresponding to the given VReg, or
+  // nullptr if no such type instruction exists.
+  SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
+
+  // Whether the given VReg has a SPIR-V type mapped to it yet.
+  bool hasSPIRVTypeForVReg(Register VReg) const {
+    return getSPIRVTypeForVReg(VReg) != nullptr;
+  }
+
+  // Return the VReg holding the result of the given OpTypeXXX instruction.
+  Register getSPIRVTypeID(const SPIRVType *SpirvType) const {
+    assert(SpirvType && "Attempting to get type id for nullptr type.");
+    return SpirvType->defs().begin()->getReg();
+  }
+
+  void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
+
+  // Whether the given VReg has an OpTypeXXX instruction mapped to it with the
+  // given opcode (e.g. OpTypeFloat).
+  bool isScalarOfType(Register VReg, unsigned int TypeOpcode) const;
+
+  // Return true if the given VReg's assigned SPIR-V type is either a scalar
+  // matching the given opcode, or a vector with an element type matching that
+  // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
+  bool isScalarOrVectorOfType(Register VReg, unsigned int TypeOpcode) const;
+
+  // For vectors or scalars of ints or floats, return the scalar type's
+  // bitwidth.
+  unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
+
+  // For integer vectors or scalars, return whether the integers are signed.
+  bool isScalarOrVectorSigned(const SPIRVType *Type) const;
+
+  // Gets the storage class of the pointer type assigned to this vreg.
+  StorageClass::StorageClass getPointerStorageClass(Register VReg) const;
+
+  // Return the number of bits SPIR-V pointers and size_t variables require.
+  unsigned getPointerSize() const { return PointerSize; }
+
+private:
+  SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+                          bool IsSigned = false);
+
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
+                             MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
+                            MachineIRBuilder &MIRBuilder, bool EmitIR = true);
+
+  SPIRVType *getOpTypePointer(StorageClass::StorageClass SC,
+                              SPIRVType *ElemType,
+                              MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeFunction(SPIRVType *RetType,
+                               const SmallVectorImpl<SPIRVType *> &ArgTypes,
+                               MachineIRBuilder &MIRBuilder);
+
+  SPIRVType *getOpTypeByOpcode(MachineIRBuilder &MIRBuilder,
+                               unsigned int Opcode);
+
+public:
+  Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
+                            SPIRVType *SpvType = nullptr, bool EmitIR = true);
+  Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
+                           SPIRVType *SpvType = nullptr);
+  Register buildGlobalVariable(Register Reg, SPIRVType *BaseType,
+                               StringRef Name, const GlobalValue *GV,
+                               StorageClass::StorageClass Storage,
+                               const MachineInstr *Init, bool IsConst,
+                               bool HasLinkageTy,
+                               LinkageType::LinkageType LinkageType,
+                               MachineIRBuilder &MIRBuilder);
+
+  // Convenient helpers for getting types with check for duplicates.
+  SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth,
+                                         MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
+                                        unsigned NumElements,
+                                        MachineIRBuilder &MIRBuilder);
+  SPIRVType *getOrCreateSPIRVPointerType(
+      SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
+      StorageClass::StorageClass SClass = StorageClass::Function);
+};
+} // end namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVGLOBALREGISTRY_H
Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -0,0 +1,408 @@
+//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the SPIRVGlobalRegistry class,
+// which is used to maintain rich type information required for SPIR-V even
+// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
+// an OpTypeXXX instruction, and map it to a virtual register. It also builds
+// constants and global variables, and keeps them consistent.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+
+using namespace llvm;
+
+SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned int PointerSize)
+    : PointerSize(PointerSize) {}
+
+SPIRVType *
+SPIRVGlobalRegistry::assignTypeToVReg(const Type *Type, Register VReg,
+                                      MachineIRBuilder &MIRBuilder,
+                                      AQ::AccessQualifier AccessQual) {
+  SPIRVType *SpirvType = getOrCreateSPIRVType(Type, MIRBuilder, AccessQual);
+  assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder);
+  return SpirvType;
+}
+
+void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
+                                                Register VReg,
+                                                MachineIRBuilder &MIRBuilder) {
+  VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType;
+}
+
+static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
+  auto &MRI = MIRBuilder.getMF().getRegInfo();
+  auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
+  MRI.setRegClass(Res, &SPIRV::TYPERegClass);
+  return Res;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
+      .addDef(createTypeVReg(MIRBuilder));
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
+                                             MachineIRBuilder &MIRBuilder,
+                                             bool IsSigned) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addImm(Width)
+                 .addImm(IsSigned ? 1 : 0);
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                               MachineIRBuilder &MIRBuilder) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addImm(Width);
+  return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
+  return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
+      .addDef(createTypeVReg(MIRBuilder));
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
+                                                SPIRVType *ElemType,
+                                                MachineIRBuilder &MIRBuilder) {
+  using namespace SPIRV;
+  auto EleOpc = ElemType->getOpcode();
+  if (EleOpc != OpTypeInt && EleOpc != OpTypeFloat && EleOpc != OpTypeBool)
+    report_fatal_error("Invalid vector element type");
+
+  auto MIB = MIRBuilder.buildInstr(OpTypeVector)
+                 .addDef(createTypeVReg(MIRBuilder))
+                 .addUse(getSPIRVTypeID(ElemType))
+                 .addImm(NumElems);
+  return MIB;
+}
+
+Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
+                                               MachineIRBuilder &MIRBuilder,
+                                               SPIRVType *SpvType,
+                                               bool EmitIR) {
+  auto &MF = MIRBuilder.getMF();
+  Register Res;
+  IntegerType *LLVMIntTy;
+  if (SpvType) {
+    Type *LLVMTy = const_cast<Type *>(getTypeForSPIRVType(SpvType));
+    assert(LLVMTy->isIntegerTy());
+    LLVMIntTy = cast<IntegerType>(LLVMTy);
+  } else {
+    LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
+  }
+  // Find a constant in DT or build a new one.
+  const auto ConstInt = ConstantInt::get(LLVMIntTy, Val);
+  unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+  Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+  assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
+  if (EmitIR)
+    MIRBuilder.buildConstant(Res, *ConstInt);
+  else
+    MIRBuilder.buildInstr(SPIRV::OpConstantI)
+        .addDef(Res)
+        .addImm(ConstInt->getSExtValue());
+  return Res;
+}
+
+Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
+                                              MachineIRBuilder &MIRBuilder,
+                                              SPIRVType *SpvType) {
+  auto &MF = MIRBuilder.getMF();
+  Register Res;
+  Type *LLVMFPTy;
+  if (SpvType) {
+    Type *LLVMTy = const_cast<Type *>(getTypeForSPIRVType(SpvType));
+    assert(LLVMTy->isFloatingPointTy());
+    LLVMFPTy = LLVMTy;
+  } else {
+    LLVMFPTy = Type::getFloatTy(MF.getFunction().getContext());
+  }
+  // Find a constant in DT or build a new one.
+  const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
+  unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+  Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+  assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
+  MIRBuilder.buildFConstant(Res, *ConstFP);
+  return Res;
+}
+
+Register SPIRVGlobalRegistry::buildGlobalVariable(
+    Register ResVReg, SPIRVType *BaseType, StringRef Name,
+    const GlobalValue *GV, StorageClass::StorageClass Storage,
+    const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
+    LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder) {
+  const GlobalVariable *GVar = nullptr;
+  if (GV)
+    GVar = dyn_cast<const GlobalVariable>(GV);
+  else {
+    // If GV is not passed explicitly, use the name to find or construct
+    // the global variable.
+    auto *Module = MIRBuilder.getMBB().getBasicBlock()->getModule();
+    GVar = Module->getGlobalVariable(Name);
+    if (GVar == nullptr) {
+      auto Ty = getTypeForSPIRVType(BaseType); // TODO: check type
+      GVar = new GlobalVariable(
+          *const_cast<Module *>(Module), const_cast<Type *>(Ty), false,
+          GlobalValue::ExternalLinkage, nullptr, Twine(Name));
+    }
+    GV = GVar;
+  }
+  assert(GV && "Global variable is expected");
+  Register Reg;
+
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
+                 .addDef(ResVReg)
+                 .addUse(getSPIRVTypeID(BaseType))
+                 .addImm(Storage);
+
+  if (Init != 0) {
+    MIB.addUse(Init->getOperand(0).getReg());
+  }
+
+  // ISel may introduce a new register on this step, so we need to add it to
+  // DT and correct its type, avoiding failures at the next stage.
+  constrainRegOperands(MIB, CurMF);
+  Reg = MIB->getOperand(0).getReg();
+
+  // Give Reg the same type that ResVReg has.
+  auto MRI = MIRBuilder.getMRI();
+  assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
+  if (Reg != ResVReg) {
+    LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
+    MRI->setType(Reg, RegLLTy);
+    assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder);
+  }
+
+  // If it's a global variable with name, output an OpName for it.
+  if (GVar && GVar->hasName())
+    buildOpName(Reg, GVar->getName(), MIRBuilder);
+
+  // Output decorations for the GV.
+  // TODO: maybe move to GenerateDecorations pass.
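+  // (Illustrative: each buildOpDecorate call below produces an OpDecorate
+  // instruction of the form "OpDecorate <target-id> <Decoration> <extra
+  // operands>", e.g. "OpDecorate %gv Alignment 4".)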
+ if (IsConst) + buildOpDecorate(Reg, MIRBuilder, Decoration::Constant, {}); + + if (GVar && GVar->getAlign().valueOrOne().value() != 1) { + unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); + buildOpDecorate(Reg, MIRBuilder, Decoration::Alignment, {Alignment}); + } + + if (HasLinkageTy) + buildOpDecorate(Reg, MIRBuilder, Decoration::LinkageAttributes, + {LinkageType}, Name); + return Reg; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, + bool EmitIR) { + if (ElemType->getOpcode() == SPIRV::OpTypeVoid) + report_fatal_error("Invalid array element type"); + Register NumElementsVReg = + buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(NumElementsVReg); + assert(constrainRegOperands(MIB, CurMF)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(StorageClass::StorageClass SC, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(SC) + .addUse(getSPIRVTypeID(ElemType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( + SPIRVType *RetType, const SmallVectorImpl &ArgTypes, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(RetType)); + for (const auto &ArgType : ArgTypes) + MIB.addUse(getSPIRVTypeID(ArgType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, + MachineIRBuilder &MIRBuilder, + AQ::AccessQualifier AccQual, + bool EmitIR) { + if (auto IType = dyn_cast(Ty)) { + const unsigned int Width = IType->getBitWidth(); + return Width == 1 ? 
getOpTypeBool(MIRBuilder) + : getOpTypeInt(Width, MIRBuilder, false); + } else if (Ty->isFloatingPointTy()) + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + else if (Ty->isVoidTy()) + return getOpTypeVoid(MIRBuilder); + else if (Ty->isVectorTy()) { + auto El = getOrCreateSPIRVType(cast(Ty)->getElementType(), + MIRBuilder); + return getOpTypeVector(cast(Ty)->getNumElements(), El, + MIRBuilder); + } else if (Ty->isArrayTy()) { + auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); + return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); + } else if (dyn_cast(Ty)) { + llvm_unreachable("Unsupported LLVM type"); + } else if (auto FType = dyn_cast(Ty)) { + SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); + SmallVector ParamTypes; + for (const auto &t : FType->params()) { + ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); + } + return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); + } else if (const auto &PType = dyn_cast(Ty)) { + Type *ElemType = PType->getPointerElementType(); + + // Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers, + // but should be treated as custom types like OpTypeImage + if (dyn_cast(ElemType)) + llvm_unreachable("Unsupported LLVM type"); + + // Otherwise, treat it as a regular pointer type + auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); + SPIRVType *SpvElementType = + getOrCreateSPIRVType(ElemType, MIRBuilder, AQ::ReadWrite, EmitIR); + return getOpTypePointer(SC, SpvElementType, MIRBuilder); + } else + llvm_unreachable("Unable to convert LLVM type to SPIRVType"); +} + +SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { + auto t = VRegToTypeMap.find(CurMF); + if (t != VRegToTypeMap.end()) { + auto tt = t->second.find(VReg); + if (tt != t->second.end()) + return tt->second; + } + return nullptr; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + AQ::AccessQualifier AccessQual, bool EmitIR) { + Register Reg; + SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = Type; + return SpirvType; +} + +bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, + unsigned int TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOfType VReg has no type assigned"); + return Type->getOpcode() == TypeOpcode; +} + +bool SPIRVGlobalRegistry::isScalarOrVectorOfType( + Register VReg, unsigned int TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); + if (Type->getOpcode() == TypeOpcode) { + return true; + } else if (Type->getOpcode() == SPIRV::OpTypeVector) { + Register ScalarTypeVReg = Type->getOperand(1).getReg(); + SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); + return ScalarType->getOpcode() == TypeOpcode; + } + return false; +} + +unsigned +SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { + if (Type && Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type && (Type->getOpcode() == SPIRV::OpTypeInt || + Type->getOpcode() == SPIRV::OpTypeFloat)) { + return Type->getOperand(1).getImm(); + } else if (Type && Type->getOpcode() == SPIRV::OpTypeBool) { + return 1; + } + 
llvm_unreachable("Attempting to get bit width of non-integer/float type."); +} + +bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { + if (Type && Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type && Type->getOpcode() == SPIRV::OpTypeInt) { + return Type->getOperand(2).getImm() != 0; + } + llvm_unreachable("Attempting to get sign of non-integer type."); +} + +StorageClass::StorageClass +SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + if (Type && Type->getOpcode() == SPIRV::OpTypePointer) { + auto scOp = Type->getOperand(1).getImm(); + return static_cast(scOp); + } + llvm_unreachable("Attempting to get storage class of non-pointer type."); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeByOpcode(MachineIRBuilder &MIRBuilder, + unsigned int Opcode) { + Register ResVReg = createTypeVReg(MIRBuilder); + auto MIB = MIRBuilder.buildInstr(Opcode).addDef(ResVReg); + constrainRegOperands(MIB); + return MIB; +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), + MIRBuilder); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + FixedVectorType::get(const_cast(getTypeForSPIRVType(BaseType)), + NumElements), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, + StorageClass::StorageClass SClass) { + return getOrCreateSPIRVType( + PointerType::get(const_cast(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SClass)), + MIRBuilder); +} Index: llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -0,0 +1,1168 @@ +//===- SPIRVInstructionSelector.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the targeting of the InstructionSelector class for +// SPIRV. +// TODO: This should be generated by TableGen. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVInstrInfo.h"
+#include "SPIRVRegisterBankInfo.h"
+#include "SPIRVRegisterInfo.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
+#include "llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "spirv-isel"
+
+using namespace llvm;
+
+namespace {
+
+#define GET_GLOBALISEL_PREDICATE_BITSET
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_PREDICATE_BITSET
+
+class SPIRVInstructionSelector : public InstructionSelector {
+  const SPIRVSubtarget &STI;
+  const SPIRVInstrInfo &TII;
+  const SPIRVRegisterInfo &TRI;
+  const SPIRVRegisterBankInfo &RBI;
+  SPIRVGlobalRegistry &GR;
+
+public:
+  SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                           const SPIRVSubtarget &ST,
+                           const SPIRVRegisterBankInfo &RBI);
+
+  // Common selection code. Instruction-specific selection occurs in
+  // spvSelect().
+  bool select(MachineInstr &I) override;
+  static const char *getName() { return DEBUG_TYPE; }
+
+#define GET_GLOBALISEL_PREDICATES_DECL
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_PREDICATES_DECL
+
+#define GET_GLOBALISEL_TEMPORARIES_DECL
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_TEMPORARIES_DECL
+
+private:
+  // tblgen-erated 'select' implementation, used as the initial selector for
+  // the patterns that don't require complex C++.
+  bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const;
+
+  // All instruction-specific selection that didn't happen in select(). It is
+  // essentially a large switch/case that delegates to the other select*
+  // methods below.
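+  // For example, a G_ICMP reaching spvSelect() is forwarded to selectICmp(),
+  // which picks the SPIR-V comparison opcode (OpIEqual, OpULessThan, ...)
+  // from the predicate, while a G_LOAD is forwarded to selectLoad() to emit
+  // an OpLoad.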
+ bool spvSelect(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectGlobalValue(Register ResVReg, const MachineInstr &I, + MachineIRBuilder &MIRBuilder, + const MachineInstr *Init = nullptr) const; + + bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, Register SrcReg, + MachineIRBuilder &MIRBuilder, unsigned Opcode) const; + bool selectUnOp(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder, + unsigned Opcode) const; + + bool selectLoad(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectStore(const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectMemOperation(Register ResVReg, const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder, + unsigned NewOpcode) const; + + bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder, + bool WithSuccess) const; + + bool selectFence(const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectBitreverse(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectConstVector(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectCmp(Register ResVReg, const SPIRVType *ResType, + unsigned scalarTypeOpcode, unsigned comparisonOpcode, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectICmp(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + bool selectFCmp(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + + bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, + MachineIRBuilder &MIRBuilder) const; + + bool selectSelect(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, bool IsSigned, + MachineIRBuilder &MIRBuilder) const; + bool selectIToF(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, bool IsSigned, + MachineIRBuilder &MIRBuilder, unsigned Opcode) const; + bool selectExt(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, bool IsSigned, + MachineIRBuilder &MIRBuilder) const; + + bool selectTrunc(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + bool selectIntToBool(Register IntReg, Register ResVReg, + const SPIRVType *intTy, const SPIRVType *boolTy, + MachineIRBuilder &MIRBuilder) const; + + bool selectOpUndef(Register ResVReg, const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const; + bool selectIntrinsic(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectFrameIndex(Register ResVReg, const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const; + + bool selectBranch(const MachineInstr &I, MachineIRBuilder 
&MIRBuilder) const; + bool selectBranchCond(const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const; + + bool selectPhi(Register ResVReg, const SPIRVType *ResType, + const MachineInstr &I, MachineIRBuilder &MIRBuilder) const; + + Register buildI32Constant(uint32_t Val, MachineIRBuilder &MIRBuilder, + const SPIRVType *ResType = nullptr) const; + + Register buildZerosVal(const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const; + Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const; +}; + +} // end anonymous namespace + +#define GET_GLOBALISEL_IMPL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_IMPL + +SPIRVInstructionSelector::SPIRVInstructionSelector( + const SPIRVTargetMachine &TM, const SPIRVSubtarget &ST, + const SPIRVRegisterBankInfo &RBI) + : InstructionSelector(), STI(ST), TII(*ST.getInstrInfo()), + TRI(*ST.getRegisterInfo()), RBI(RBI), GR(*ST.getSPIRVGlobalRegistry()), +#define GET_GLOBALISEL_PREDICATES_INIT +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_INIT +#define GET_GLOBALISEL_TEMPORARIES_INIT +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_INIT +{ +} + +// Defined in SPIRVLegalizerInfo.cpp +extern bool isTypeFoldingSupported(unsigned Opcode); + +bool SPIRVInstructionSelector::select(MachineInstr &I) { + assert(I.getParent() && "Instruction should be in a basic block!"); + assert(I.getParent()->getParent() && "Instruction should be in a function!"); + + MachineBasicBlock &MBB = *I.getParent(); + MachineFunction &MF = *MBB.getParent(); + + MachineIRBuilder MIRBuilder; + MIRBuilder.setMF(MF); + MIRBuilder.setMBB(MBB); + MIRBuilder.setInstr(I); + + Register Opcode = I.getOpcode(); + + // If it's not a GMIR instruction, we've selected it already. + if (!isPreISelGenericOpcode(Opcode)) { + if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more + auto MRI = MIRBuilder.getMRI(); + auto *Def = MRI->getVRegDef(I.getOperand(1).getReg()); + if (isTypeFoldingSupported(Def->getOpcode())) { + auto Res = selectImpl(I, *CoverageInfo); + assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT); + if (Res) + return Res; + } + MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg()); + I.removeFromParent(); + } else if (I.getNumDefs() == 1) { // Make all vregs 32 bits (for SPIR-V IDs) + MIRBuilder.getMRI()->setType(I.getOperand(0).getReg(), LLT::scalar(32)); + } + return true; + } + + if (I.getNumOperands() != I.getNumExplicitOperands()) { + LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n"); + return false; + } + + // Common code for getting return reg+type, and removing selected instr + // from parent occurs here. Instr-specific selection happens in spvSelect(). + bool HasDefs = I.getNumDefs() > 0; + Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0); + SPIRVType *ResType = HasDefs ? 
GR.getSPIRVTypeForVReg(ResVReg) : nullptr; + assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE); + if (spvSelect(ResVReg, ResType, I, MIRBuilder)) { + if (HasDefs) { // Make all vregs 32 bits (for SPIR-V IDs) + MIRBuilder.getMRI()->setType(ResVReg, LLT::scalar(32)); + } + I.removeFromParent(); + return true; + } + return false; +} + +bool SPIRVInstructionSelector::spvSelect(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + using namespace SPIRV; + assert(!isTypeFoldingSupported(I.getOpcode()) || + I.getOpcode() == TargetOpcode::G_CONSTANT); + + const unsigned Opcode = I.getOpcode(); + switch (Opcode) { + case TargetOpcode::G_CONSTANT: + return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(), + MIRBuilder); + case TargetOpcode::G_GLOBAL_VALUE: + return selectGlobalValue(ResVReg, I, MIRBuilder); + case TargetOpcode::G_IMPLICIT_DEF: + return selectOpUndef(ResVReg, ResType, MIRBuilder); + + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return selectIntrinsic(ResVReg, ResType, I, MIRBuilder); + case TargetOpcode::G_BITREVERSE: + return selectBitreverse(ResVReg, ResType, I, MIRBuilder); + + case TargetOpcode::G_BUILD_VECTOR: + return selectConstVector(ResVReg, ResType, I, MIRBuilder); + + case TargetOpcode::G_SHUFFLE_VECTOR: { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpVectorShuffle) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + for (auto V : I.getOperand(3).getShuffleMask()) + MIB.addImm(V); + + return MIB.constrainAllUses(TII, TRI, RBI); + } + case TargetOpcode::G_MEMMOVE: + case TargetOpcode::G_MEMCPY: + return selectMemOperation(ResVReg, I, MIRBuilder); + + case TargetOpcode::G_ICMP: + return selectICmp(ResVReg, ResType, I, MIRBuilder); + case TargetOpcode::G_FCMP: + return selectFCmp(ResVReg, ResType, I, MIRBuilder); + + case TargetOpcode::G_FRAME_INDEX: + return selectFrameIndex(ResVReg, ResType, MIRBuilder); + + case TargetOpcode::G_LOAD: + return selectLoad(ResVReg, ResType, I, MIRBuilder); + case TargetOpcode::G_STORE: + return selectStore(I, MIRBuilder); + + case TargetOpcode::G_BR: + return selectBranch(I, MIRBuilder); + case TargetOpcode::G_BRCOND: + return selectBranchCond(I, MIRBuilder); + + case TargetOpcode::G_PHI: + return selectPhi(ResVReg, ResType, I, MIRBuilder); + + case TargetOpcode::G_FPTOSI: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpConvertFToS); + case TargetOpcode::G_FPTOUI: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpConvertFToU); + + case TargetOpcode::G_SITOFP: + return selectIToF(ResVReg, ResType, I, true, MIRBuilder, OpConvertSToF); + case TargetOpcode::G_UITOFP: + return selectIToF(ResVReg, ResType, I, false, MIRBuilder, OpConvertUToF); + + case TargetOpcode::G_CTPOP: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpBitCount); + + case TargetOpcode::G_SEXT: + return selectExt(ResVReg, ResType, I, true, MIRBuilder); + case TargetOpcode::G_ANYEXT: + case TargetOpcode::G_ZEXT: + return selectExt(ResVReg, ResType, I, false, MIRBuilder); + case TargetOpcode::G_TRUNC: + return selectTrunc(ResVReg, ResType, I, MIRBuilder); + case TargetOpcode::G_FPTRUNC: + case TargetOpcode::G_FPEXT: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpFConvert); + + case TargetOpcode::G_PTRTOINT: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpConvertPtrToU); + case TargetOpcode::G_INTTOPTR: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, 
OpConvertUToPtr); + case TargetOpcode::G_BITCAST: + return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpBitcast); + case TargetOpcode::G_ADDRSPACE_CAST: + return selectAddrSpaceCast(ResVReg, ResType, I, MIRBuilder); + + case TargetOpcode::G_ATOMICRMW_OR: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicOr); + case TargetOpcode::G_ATOMICRMW_ADD: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicIAdd); + case TargetOpcode::G_ATOMICRMW_AND: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicAnd); + case TargetOpcode::G_ATOMICRMW_MAX: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicSMax); + case TargetOpcode::G_ATOMICRMW_MIN: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicSMin); + case TargetOpcode::G_ATOMICRMW_SUB: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicISub); + case TargetOpcode::G_ATOMICRMW_XOR: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicXor); + case TargetOpcode::G_ATOMICRMW_UMAX: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicUMax); + case TargetOpcode::G_ATOMICRMW_UMIN: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicUMin); + case TargetOpcode::G_ATOMICRMW_XCHG: + return selectAtomicRMW(ResVReg, ResType, I, MIRBuilder, OpAtomicExchange); + + case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: + return selectAtomicCmpXchg(ResVReg, ResType, I, MIRBuilder, true); + case TargetOpcode::G_ATOMIC_CMPXCHG: + return selectAtomicCmpXchg(ResVReg, ResType, I, MIRBuilder, false); + + case TargetOpcode::G_FENCE: + return selectFence(I, MIRBuilder); + + default: + return false; + } +} + +bool SPIRVInstructionSelector::selectUnOpWithSrc( + Register ResVReg, const SPIRVType *ResType, const MachineInstr &I, + Register SrcReg, MachineIRBuilder &MIRBuilder, unsigned Opcode) const { + return MIRBuilder.buildInstr(Opcode) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectUnOp(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder, + unsigned Opcode) const { + return selectUnOpWithSrc(ResVReg, ResType, I, I.getOperand(1).getReg(), + MIRBuilder, Opcode); +} + +static MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { + switch (Ord) { + case AtomicOrdering::Acquire: + return MemorySemantics::Acquire; + case AtomicOrdering::Release: + return MemorySemantics::Release; + case AtomicOrdering::AcquireRelease: + return MemorySemantics::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return MemorySemantics::SequentiallyConsistent; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::NotAtomic: + default: + return MemorySemantics::None; + } +} + +static Scope::Scope getScope(SyncScope::ID Ord) { + switch (Ord) { + case SyncScope::SingleThread: + return Scope::Invocation; + case SyncScope::System: + return Scope::Device; + default: + llvm_unreachable("Unsupported synchronization Scope ID."); + } +} + +static void addMemoryOperands(MachineMemOperand *MemOp, + MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = MemoryOperand::None; + if (MemOp->isVolatile()) + SpvMemOp |= MemoryOperand::Volatile; + if (MemOp->isNonTemporal()) + SpvMemOp |= MemoryOperand::Nontemporal; + if (MemOp->getAlign().value()) + SpvMemOp |= MemoryOperand::Aligned; + + if (SpvMemOp != MemoryOperand::None) { + MIB.addImm(SpvMemOp); + if (SpvMemOp & 
MemoryOperand::Aligned) + MIB.addImm(MemOp->getAlign().value()); + } +} + +static void addMemoryOperands(uint64_t Flags, MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = MemoryOperand::None; + if (Flags & MachineMemOperand::Flags::MOVolatile) + SpvMemOp |= MemoryOperand::Volatile; + if (Flags & MachineMemOperand::Flags::MONonTemporal) + SpvMemOp |= MemoryOperand::Nontemporal; + + if (SpvMemOp != MemoryOperand::None) + MIB.addImm(SpvMemOp); +} + +bool SPIRVInstructionSelector::selectLoad(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0; + auto Ptr = I.getOperand(1 + OpOffset).getReg(); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpLoad) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Ptr); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + auto MemOp = *I.memoperands_begin(); + addMemoryOperands(MemOp, MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectStore(const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0; + auto StoreVal = I.getOperand(0 + OpOffset).getReg(); + auto Ptr = I.getOperand(1 + OpOffset).getReg(); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpStore).addUse(Ptr).addUse(StoreVal); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + auto MemOp = *I.memoperands_begin(); + addMemoryOperands(MemOp, MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectMemOperation( + Register ResVReg, const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpCopyMemorySized) + .addDef(I.getOperand(0).getReg()) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + if (I.getNumMemOperands()) { + auto MemOp = *I.memoperands_begin(); + addMemoryOperands(MemOp, MIB); + } + bool Result = MIB.constrainAllUses(TII, TRI, RBI); + if (ResVReg != MIB->getOperand(0).getReg()) { + MIRBuilder.buildCopy(ResVReg, MIB->getOperand(0).getReg()); + } + return Result; +} + +bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder, + unsigned NewOpcode) const { + assert(I.hasOneMemOperand()); + auto MemOp = *I.memoperands_begin(); + auto Scope = getScope(MemOp->getSyncScopeID()); + Register ScopeReg = buildI32Constant(Scope, MIRBuilder); + + auto Ptr = I.getOperand(1).getReg(); + // Changed as it's implemented in the translator. 
See test/atomicrmw.ll
+  // auto ScSem =
+  //     getMemSemanticsForStorageClass(GR.getPointerStorageClass(Ptr));
+
+  auto MemSem = getMemSemantics(MemOp->getSuccessOrdering());
+  Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, MIRBuilder);
+
+  return MIRBuilder.buildInstr(NewOpcode)
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(Ptr)
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectFence(const MachineInstr &I,
+                                           MachineIRBuilder &MIRBuilder) const {
+  auto MemSem = getMemSemantics(AtomicOrdering(I.getOperand(0).getImm()));
+  Register MemSemReg = buildI32Constant(MemSem, MIRBuilder);
+
+  auto Scope = getScope(SyncScope::ID(I.getOperand(1).getImm()));
+  Register ScopeReg = buildI32Constant(Scope, MIRBuilder);
+
+  return MIRBuilder.buildInstr(SPIRV::OpMemoryBarrier)
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   const MachineInstr &I,
+                                                   MachineIRBuilder &MIRBuilder,
+                                                   bool WithSuccess) const {
+  auto MRI = MIRBuilder.getMRI();
+  assert(I.hasOneMemOperand());
+  auto MemOp = *I.memoperands_begin();
+  auto Scope = getScope(MemOp->getSyncScopeID());
+  Register ScopeReg = buildI32Constant(Scope, MIRBuilder);
+
+  auto Ptr = I.getOperand(2).getReg();
+  auto Cmp = I.getOperand(3).getReg();
+  auto Val = I.getOperand(4).getReg();
+
+  auto SpvValTy = GR.getSPIRVTypeForVReg(Val);
+  auto ScSem = getMemSemanticsForStorageClass(GR.getPointerStorageClass(Ptr));
+
+  auto MemSemEq = getMemSemantics(MemOp->getSuccessOrdering()) | ScSem;
+  Register MemSemEqReg = buildI32Constant(MemSemEq, MIRBuilder);
+
+  auto MemSemNeq = getMemSemantics(MemOp->getFailureOrdering()) | ScSem;
+  Register MemSemNeqReg = MemSemEq == MemSemNeq
+                              ? MemSemEqReg
+                              : buildI32Constant(MemSemNeq, MIRBuilder);
+
+  auto TmpReg = WithSuccess ? MRI->createGenericVirtualRegister(LLT::scalar(32))
+                            : ResVReg;
+  bool Success = MIRBuilder.buildInstr(SPIRV::OpAtomicCompareExchange)
+                     .addDef(TmpReg)
+                     .addUse(GR.getSPIRVTypeID(SpvValTy))
+                     .addUse(Ptr)
+                     .addUse(ScopeReg)
+                     .addUse(MemSemEqReg)
+                     .addUse(MemSemNeqReg)
+                     .addUse(Val)
+                     .addUse(Cmp)
+                     .constrainAllUses(TII, TRI, RBI);
+  if (!WithSuccess) // If we just need the old Val, not {oldVal, Success}
+    return Success;
+  assert(ResType->getOpcode() == SPIRV::OpTypeStruct);
+  auto BoolReg = MRI->createGenericVirtualRegister(LLT::scalar(1));
+  Register BoolTyID = ResType->getOperand(2).getReg();
+  Success &= MIRBuilder.buildInstr(SPIRV::OpIEqual)
+                 .addDef(BoolReg)
+                 .addUse(BoolTyID)
+                 .addUse(TmpReg)
+                 .addUse(Cmp)
+                 .constrainAllUses(TII, TRI, RBI);
+  return Success && MIRBuilder.buildInstr(SPIRV::OpCompositeConstruct)
+                        .addDef(ResVReg)
+                        .addUse(GR.getSPIRVTypeID(ResType))
+                        .addUse(TmpReg)
+                        .addUse(BoolReg)
+                        .constrainAllUses(TII, TRI, RBI);
+}
+
+static bool isGenericCastablePtr(StorageClass::StorageClass Sc) {
+  switch (Sc) {
+  case StorageClass::Workgroup:
+  case StorageClass::CrossWorkgroup:
+  case StorageClass::Function:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// In SPIR-V, address space casting can only happen to and from the Generic
+// storage class, and only Workgroup, CrossWorkgroup, or Function pointers can
+// be cast to and from Generic pointers. As such, we can convert e.g. from
+// Workgroup to Function by going via a Generic pointer as an intermediary. All
+// other combinations can only be done by a bitcast, and are probably not safe.
+bool SPIRVInstructionSelector::selectAddrSpaceCast(
+    Register ResVReg, const SPIRVType *ResType, const MachineInstr &I,
+    MachineIRBuilder &MIRBuilder) const {
+  using namespace SPIRV;
+  namespace SC = StorageClass;
+  auto SrcPtr = I.getOperand(1).getReg();
+  auto SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
+  auto SrcSC = GR.getPointerStorageClass(SrcPtr);
+  auto DstSC = GR.getPointerStorageClass(ResVReg);
+
+  if (DstSC == SC::Generic && isGenericCastablePtr(SrcSC)) {
+    // We're casting from an eligible pointer to Generic.
+    return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpPtrCastToGeneric);
+  } else if (SrcSC == SC::Generic && isGenericCastablePtr(DstSC)) {
+    // We're casting from Generic to an eligible pointer.
+    return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpGenericCastToPtr);
+  } else if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
+    // We're casting between 2 eligible pointers using Generic as an
+    // intermediary.
+    auto Tmp = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
+    auto GenericPtrTy = GR.getOrCreateSPIRVPointerType(SrcPtrTy, MIRBuilder,
+                                                       StorageClass::Generic);
+    bool Success = MIRBuilder.buildInstr(OpPtrCastToGeneric)
+                       .addDef(Tmp)
+                       .addUse(GR.getSPIRVTypeID(GenericPtrTy))
+                       .addUse(SrcPtr)
+                       .constrainAllUses(TII, TRI, RBI);
+    return Success && MIRBuilder.buildInstr(OpGenericCastToPtr)
+                          .addDef(ResVReg)
+                          .addUse(GR.getSPIRVTypeID(ResType))
+                          .addUse(Tmp)
+                          .constrainAllUses(TII, TRI, RBI);
+  } else {
+    // TODO: Should this case just be disallowed completely?
+    // We're casting between 2 other arbitrary address spaces, so we have to
+    // bitcast.
+    return selectUnOp(ResVReg, ResType, I, MIRBuilder, OpBitcast);
+  }
+}
+
+static unsigned int getFCmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::FCMP_OEQ:
+    return SPIRV::OpFOrdEqual;
+  case CmpInst::FCMP_OGE:
+    return SPIRV::OpFOrdGreaterThanEqual;
+  case CmpInst::FCMP_OGT:
+    return SPIRV::OpFOrdGreaterThan;
+  case CmpInst::FCMP_OLE:
+    return SPIRV::OpFOrdLessThanEqual;
+  case CmpInst::FCMP_OLT:
+    return SPIRV::OpFOrdLessThan;
+  case CmpInst::FCMP_ONE:
+    return SPIRV::OpFOrdNotEqual;
+  case CmpInst::FCMP_ORD:
+    return SPIRV::OpOrdered;
+  case CmpInst::FCMP_UEQ:
+    return SPIRV::OpFUnordEqual;
+  case CmpInst::FCMP_UGE:
+    return SPIRV::OpFUnordGreaterThanEqual;
+  case CmpInst::FCMP_UGT:
+    return SPIRV::OpFUnordGreaterThan;
+  case CmpInst::FCMP_ULE:
+    return SPIRV::OpFUnordLessThanEqual;
+  case CmpInst::FCMP_ULT:
+    return SPIRV::OpFUnordLessThan;
+  case CmpInst::FCMP_UNE:
+    return SPIRV::OpFUnordNotEqual;
+  case CmpInst::FCMP_UNO:
+    return SPIRV::OpUnordered;
+  default:
+    report_fatal_error("Unknown predicate type for FCmp: " +
+                       CmpInst::getPredicateName(Pred));
+  }
+}
+
+static unsigned int getICmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpIEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpINotEqual;
+  case CmpInst::ICMP_SGE:
+    return SPIRV::OpSGreaterThanEqual;
+  case CmpInst::ICMP_SGT:
+    return SPIRV::OpSGreaterThan;
+  case CmpInst::ICMP_SLE:
+    return SPIRV::OpSLessThanEqual;
+  case CmpInst::ICMP_SLT:
+    return SPIRV::OpSLessThan;
+  case CmpInst::ICMP_UGE:
+    return SPIRV::OpUGreaterThanEqual;
+  case CmpInst::ICMP_UGT:
+    return SPIRV::OpUGreaterThan;
+  case CmpInst::ICMP_ULE:
+    return SPIRV::OpULessThanEqual;
+  case CmpInst::ICMP_ULT:
+    return SPIRV::OpULessThan;
+  default:
+    report_fatal_error("Unknown predicate type for ICmp: " +
+                       CmpInst::getPredicateName(Pred));
+  }
+}
+
+// Return the pointer comparison opcode for the given predicate, plus a flag
+// saying whether the operands must first be converted to integers (i.e. when
+// OpPtrEqual/OpPtrNotEqual cannot be used directly).
+static std::pair<unsigned, bool> getPtrCmpOpcode(unsigned Pred) {
+  switch (static_cast<CmpInst::Predicate>(Pred)) {
+  case CmpInst::ICMP_EQ:
+    return {SPIRV::OpPtrEqual, false};
+  case CmpInst::ICMP_NE:
+    return {SPIRV::OpPtrNotEqual, false};
+  default:
+    return {getICmpOpcode(Pred), true};
+  }
+}
+
+// Return the logical operation, or abort if none exists.
+static unsigned int getBoolCmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpLogicalEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpLogicalNotEqual;
+  default:
+    report_fatal_error("Unknown predicate type for Bool comparison: " +
+                       CmpInst::getPredicateName(Pred));
+  }
+}
+
+bool SPIRVInstructionSelector::selectBitreverse(
+    Register ResVReg, const SPIRVType *ResType, const MachineInstr &I,
+    MachineIRBuilder &MIRBuilder) const {
+  return MIRBuilder.buildInstr(SPIRV::OpBitReverse)
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(I.getOperand(1).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectConstVector(
+    Register ResVReg, const SPIRVType *ResType, const MachineInstr &I,
+    MachineIRBuilder &MIRBuilder) const {
+  // TODO: only the const case is supported for now.
+  assert(std::all_of(
+      I.operands_begin(), I.operands_end(),
+      [&MIRBuilder](const MachineOperand &MO) {
+        if (MO.isDef())
+          return true;
+        if (!MO.isReg())
+          return false;
+        auto *ConstTy = MIRBuilder.getMRI()->getVRegDef(MO.getReg());
+        assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
+               ConstTy->getOperand(1).isReg());
+        auto *Const =
+            MIRBuilder.getMRI()->getVRegDef(ConstTy->getOperand(1).getReg());
+        assert(Const);
+        return (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
+                Const->getOpcode() == TargetOpcode::G_FCONSTANT);
+      }));
+
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
+    MIB.addUse(I.getOperand(i).getReg());
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         unsigned scalarTyOpc, unsigned CmpOpc,
+                                         const MachineInstr &I,
+                                         MachineIRBuilder &MIRBuilder) const {
+  using namespace SPIRV;
+
+  Register Cmp0 = I.getOperand(2).getReg();
+  Register Cmp1 = I.getOperand(3).getReg();
+  SPIRVType *Cmp0Type = GR.getSPIRVTypeForVReg(Cmp0);
+  SPIRVType *Cmp1Type = GR.getSPIRVTypeForVReg(Cmp1);
+  assert(Cmp0Type->getOpcode() == Cmp1Type->getOpcode());
+
+  if (Cmp0Type->getOpcode() != OpTypePointer &&
+      (!GR.isScalarOrVectorOfType(Cmp0, scalarTyOpc) ||
+       !GR.isScalarOrVectorOfType(Cmp1, scalarTyOpc)))
+    llvm_unreachable("Incompatible type for comparison");
+
+  return MIRBuilder.buildInstr(CmpOpc)
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(Cmp0)
+      .addUse(Cmp1)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectICmp(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          const MachineInstr &I,
+                                          MachineIRBuilder &MIRBuilder) const {
+  auto Pred = I.getOperand(1).getPredicate();
+  unsigned CmpOpc = getICmpOpcode(Pred);
+  unsigned TypeOpc = SPIRV::OpTypeInt;
+  bool PtrToUInt = false;
+
+  Register CmpOperand = I.getOperand(2).getReg();
+  if (GR.isScalarOfType(CmpOperand, SPIRV::OpTypePointer)) {
+    if (STI.canDirectlyComparePointers()) {
+      std::tie(CmpOpc, PtrToUInt) =
getPtrCmpOpcode(Pred); + } else { + PtrToUInt = true; + } + } else if (GR.isScalarOrVectorOfType(CmpOperand, SPIRV::OpTypeBool)) { + TypeOpc = SPIRV::OpTypeBool; + CmpOpc = getBoolCmpOpcode(Pred); + } + return selectCmp(ResVReg, ResType, TypeOpc, CmpOpc, I, MIRBuilder); +} + +void SPIRVInstructionSelector::renderFImm32(MachineInstrBuilder &MIB, + const MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_FCONSTANT && OpIdx == -1 && + "Expected G_FCONSTANT"); + const ConstantFP *FPImm = I.getOperand(1).getFPImm(); + addNumImm(FPImm->getValueAPF().bitcastToAPInt(), MIB, true); +} + +void SPIRVInstructionSelector::renderImm32(MachineInstrBuilder &MIB, + const MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_CONSTANT && OpIdx == -1 && + "Expected G_CONSTANT"); + addNumImm(I.getOperand(1).getCImm()->getValue(), MIB); +} + +Register +SPIRVInstructionSelector::buildI32Constant(uint32_t Val, + MachineIRBuilder &MIRBuilder, + const SPIRVType *ResType) const { + auto MRI = MIRBuilder.getMRI(); + auto LLVMTy = + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 32); + auto SpvI32Ty = + ResType ? ResType : GR.getOrCreateSPIRVType(LLVMTy, MIRBuilder); + Register NewReg; + NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MachineInstr *MI; + if (Val == 0) + MI = MIRBuilder.buildInstr(SPIRV::OpConstantNull) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)); + else + MI = MIRBuilder.buildInstr(SPIRV::OpConstantI) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)) + .addImm(APInt(32, Val).getZExtValue()); + constrainSelectedInstRegOperands(*MI, TII, TRI, RBI); + return NewReg; +} + +bool SPIRVInstructionSelector::selectFCmp(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + unsigned int CmpOp = getFCmpOpcode(I.getOperand(1).getPredicate()); + return selectCmp(ResVReg, ResType, SPIRV::OpTypeFloat, CmpOp, I, MIRBuilder); +} + +Register +SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const { + return buildI32Constant(0, MIRBuilder, ResType); +} + +Register +SPIRVInstructionSelector::buildOnesVal(bool AllOnes, const SPIRVType *ResType, + MachineIRBuilder &MIRBuilder) const { + auto MRI = MIRBuilder.getMRI(); + unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType); + auto One = AllOnes ? APInt::getAllOnesValue(BitWidth) + : APInt::getOneBitSet(BitWidth, 0); + Register OneReg = buildI32Constant(One.getZExtValue(), MIRBuilder, ResType); + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + const unsigned NumEles = ResType->getOperand(2).getImm(); + Register OneVec = MRI->createVirtualRegister(&SPIRV::IDRegClass); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) + .addDef(OneVec) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = 0; i < NumEles; ++i) { + MIB.addUse(OneReg); + } + constrainRegOperands(MIB); + return OneVec; + } else { + return OneReg; + } +} + +bool SPIRVInstructionSelector::selectSelect( + Register ResVReg, const SPIRVType *ResType, const MachineInstr &I, + bool IsSigned, MachineIRBuilder &MIRBuilder) const { + // To extend a bool, we need to use OpSelect between constants + using namespace SPIRV; + Register ZeroReg = buildZerosVal(ResType, MIRBuilder); + Register OneReg = buildOnesVal(IsSigned, ResType, MIRBuilder); + bool IsScalarBool = GR.isScalarOfType(I.getOperand(1).getReg(), OpTypeBool); + return MIRBuilder.buildInstr(IsScalarBool ? 
OpSelectSISCond : OpSelectSIVCond) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(OneReg) + .addUse(ZeroReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectIToF(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, bool IsSigned, + MachineIRBuilder &MIRBuilder, + unsigned Opcode) const { + Register SrcReg = I.getOperand(1).getReg(); + // We can convert bool value directly to float type without OpConvert*ToF, + // however the translator generates OpSelect+OpConvert*ToF, so we do the same. + if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) { + auto BitWidth = GR.getScalarOrVectorBitWidth(ResType); + SPIRVType *TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + const unsigned NumElts = ResType->getOperand(2).getImm(); + TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, MIRBuilder); + } + auto MRI = MIRBuilder.getMRI(); + SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + selectSelect(SrcReg, TmpType, I, false, MIRBuilder); + } + return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, MIRBuilder, Opcode); +} + +bool SPIRVInstructionSelector::selectExt(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, bool IsSigned, + MachineIRBuilder &MIRBuilder) const { + using namespace SPIRV; + if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), OpTypeBool)) { + return selectSelect(ResVReg, ResType, I, IsSigned, MIRBuilder); + } else { + return selectUnOp(ResVReg, ResType, I, MIRBuilder, + IsSigned ? OpSConvert : OpUConvert); + } +} + +bool SPIRVInstructionSelector::selectIntToBool( + Register IntReg, Register ResVReg, const SPIRVType *IntTy, + const SPIRVType *BoolTy, MachineIRBuilder &MIRBuilder) const { + // To truncate to a bool, we use OpBitwiseAnd 1 and OpINotEqual to zero + auto MRI = MIRBuilder.getMRI(); + Register BitIntReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + bool IsVectorTy = IntTy->getOpcode() == SPIRV::OpTypeVector; + auto Opcode = IsVectorTy ? SPIRV::OpBitwiseAndV : SPIRV::OpBitwiseAndS; + auto Zero = buildZerosVal(IntTy, MIRBuilder); + auto One = buildOnesVal(false, IntTy, MIRBuilder); + MIRBuilder.buildInstr(Opcode) + .addDef(BitIntReg) + .addUse(GR.getSPIRVTypeID(IntTy)) + .addUse(IntReg) + .addUse(One) + .constrainAllUses(TII, TRI, RBI); + return MIRBuilder.buildInstr(SPIRV::OpINotEqual) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(BoolTy)) + .addUse(BitIntReg) + .addUse(Zero) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectTrunc(Register ResVReg, + const SPIRVType *ResType, + const MachineInstr &I, + MachineIRBuilder &MIRBuilder) const { + using namespace SPIRV; + if (GR.isScalarOrVectorOfType(ResVReg, OpTypeBool)) { + auto IntReg = I.getOperand(1).getReg(); + auto ArgType = GR.getSPIRVTypeForVReg(IntReg); + return selectIntToBool(IntReg, ResVReg, ArgType, ResType, MIRBuilder); + } else { + bool IsSigned = GR.isScalarOrVectorSigned(ResType); + return selectUnOp(ResVReg, ResType, I, MIRBuilder, + IsSigned ? 
OpSConvert : OpUConvert);
+  }
+}
+
+bool SPIRVInstructionSelector::selectConst(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           const APInt &Imm,
+                                           MachineIRBuilder &MIRBuilder) const {
+  assert(ResType->getOpcode() != SPIRV::OpTypePointer || Imm.isNullValue());
+  if (ResType->getOpcode() == SPIRV::OpTypePointer && Imm.isNullValue())
+    return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .constrainAllUses(TII, TRI, RBI);
+  else {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
+                   .addDef(ResVReg)
+                   .addUse(GR.getSPIRVTypeID(ResType));
+    // <=32-bit integers should be caught by the sdag pattern.
+    assert(Imm.getBitWidth() > 32);
+    addNumImm(Imm, MIB);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
+}
+
+bool SPIRVInstructionSelector::selectOpUndef(
+    Register ResVReg, const SPIRVType *ResType,
+    MachineIRBuilder &MIRBuilder) const {
+  return MIRBuilder.buildInstr(SPIRV::OpUndef)
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectIntrinsic(
+    Register ResVReg, const SPIRVType *ResType, const MachineInstr &I,
+    MachineIRBuilder &MIRBuilder) const {
+  llvm_unreachable("Intrinsic selection not implemented");
+}
+
+bool SPIRVInstructionSelector::selectFrameIndex(
+    Register ResVReg, const SPIRVType *ResType,
+    MachineIRBuilder &MIRBuilder) const {
+  return MIRBuilder.buildInstr(SPIRV::OpVariable)
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(StorageClass::Function)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectBranch(
+    const MachineInstr &I, MachineIRBuilder &MIRBuilder) const {
+  // InstructionSelector walks backwards through the instructions. We can use
+  // both a G_BR and a G_BRCOND to create an OpBranchConditional. We hit G_BR
+  // first, so can generate an OpBranchConditional here. If there is no
+  // G_BRCOND, we just use OpBranch for a regular unconditional branch.
+  const MachineInstr *PrevI = I.getPrevNode();
+  if (PrevI != nullptr && PrevI->getOpcode() == TargetOpcode::G_BRCOND) {
+    return MIRBuilder.buildInstr(SPIRV::OpBranchConditional)
+        .addUse(PrevI->getOperand(0).getReg())
+        .addMBB(PrevI->getOperand(1).getMBB())
+        .addMBB(I.getOperand(0).getMBB())
+        .constrainAllUses(TII, TRI, RBI);
+  } else {
+    return MIRBuilder.buildInstr(SPIRV::OpBranch)
+        .addMBB(I.getOperand(0).getMBB())
+        .constrainAllUses(TII, TRI, RBI);
+  }
+}
+
+bool SPIRVInstructionSelector::selectBranchCond(
+    const MachineInstr &I, MachineIRBuilder &MIRBuilder) const {
+  // InstructionSelector walks backwards through the instructions. For an
+  // explicit conditional branch with no fallthrough, we use both a G_BR and a
+  // G_BRCOND to create an OpBranchConditional. We should hit G_BR first, and
+  // generate the OpBranchConditional in selectBranch above.
+  //
+  // If an OpBranchConditional has been generated, we simply return, as the
+  // work is already done. If there is no OpBranchConditional, LLVM must be
+  // relying on implicit fallthrough to the next basic block, so we need to
+  // create an OpBranchConditional with an explicit "false" argument pointing
+  // to the next basic block that LLVM would fall through to.
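+  //
+  // A sketch with placeholder names: given the gMIR
+  //   G_BRCOND %cond, %bb.true
+  //   G_BR %bb.false
+  // selectBranch() will already have emitted
+  //   OpBranchConditional %cond %bb.true %bb.false
+  // so there is nothing left to do here; only when the G_BR is absent (an
+  // implicit fallthrough) does this method emit the OpBranchConditional.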
+  const MachineInstr *NextI = I.getNextNode();
+  // Check if this has already been successfully selected.
+  if (NextI != nullptr && NextI->getOpcode() == SPIRV::OpBranchConditional) {
+    return true;
+  } else {
+    // Must be relying on implicit block fallthrough, so generate an
+    // OpBranchConditional with the "next" basic block as the "false" target.
+    auto NextMBBNum = I.getParent()->getNextNode()->getNumber();
+    auto NextMBB = I.getMF()->getBlockNumbered(NextMBBNum);
+    return MIRBuilder.buildInstr(SPIRV::OpBranchConditional)
+        .addUse(I.getOperand(0).getReg())
+        .addMBB(I.getOperand(1).getMBB())
+        .addMBB(NextMBB)
+        .constrainAllUses(TII, TRI, RBI);
+  }
+}
+
+bool SPIRVInstructionSelector::selectPhi(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         const MachineInstr &I,
+                                         MachineIRBuilder &MIRBuilder) const {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpPhi)
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  const unsigned int NumOps = I.getNumOperands();
+  assert((NumOps % 2 == 1) && "Require odd number of operands for G_PHI");
+  for (unsigned int i = 1; i < NumOps; i += 2) {
+    MIB.addUse(I.getOperand(i + 0).getReg());
+    MIB.addMBB(I.getOperand(i + 1).getMBB());
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectGlobalValue(
+    Register ResVReg, const MachineInstr &I, MachineIRBuilder &MIRBuilder,
+    const MachineInstr *Init) const {
+  auto *GV = I.getOperand(1).getGlobal();
+  SPIRVType *ResType =
+      GR.getOrCreateSPIRVType(GV->getType(), MIRBuilder, AQ::ReadWrite, false);
+
+  auto GlobalIdent = GV->getGlobalIdentifier();
+  auto GlobalVar = dyn_cast<GlobalVariable>(GV);
+
+  bool HasInit = GlobalVar && GlobalVar->hasInitializer() &&
+                 !isa<UndefValue>(GlobalVar->getInitializer());
+  // Skip the empty declaration for GVs with an initializer until we get the
+  // declaration with the initializer passed in.
+  if (HasInit && !Init)
+    return true;
+
+  auto AddrSpace = GV->getAddressSpace();
+  auto Storage = addressSpaceToStorageClass(AddrSpace);
+  bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
+                  Storage != StorageClass::Function;
+  auto LnkType = (GV->isDeclaration() || GV->hasAvailableExternallyLinkage())
+                     ? LinkageType::Import
+                     : LinkageType::Export;
+
+  auto Reg = GR.buildGlobalVariable(ResVReg, ResType, GlobalIdent, GV, Storage,
+                                    Init, GlobalVar && GlobalVar->isConstant(),
+                                    HasLnkTy, LnkType, MIRBuilder);
+  return Reg.isValid();
+}
+
+namespace llvm {
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               const SPIRVSubtarget &Subtarget,
+                               const SPIRVRegisterBankInfo &RBI) {
+  return new SPIRVInstructionSelector(TM, Subtarget, RBI);
+}
+} // namespace llvm
Index: llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -0,0 +1,36 @@
+//===- SPIRVLegalizerInfo.h --- SPIR-V Legalization Rules --------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+
+#include "SPIRVGlobalRegistry.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+
+bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+
+class LLVMContext;
+class SPIRVSubtarget;
+
+// This class provides the information for legalizing SPIR-V instructions.
+class SPIRVLegalizerInfo : public LegalizerInfo {
+  const SPIRVSubtarget *ST;
+  SPIRVGlobalRegistry *GR;
+
+public:
+  bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override;
+  SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
Index: llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -0,0 +1,376 @@
+//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+
+#include <unordered_set>
+
+using namespace llvm;
+using namespace LegalizeActions;
+using namespace LegalityPredicates;
+
+static const std::unordered_set<unsigned> TypeFoldingSupportingOpcs = {
+    TargetOpcode::G_ADD,
+    TargetOpcode::G_FADD,
+    TargetOpcode::G_SUB,
+    TargetOpcode::G_FSUB,
+    TargetOpcode::G_MUL,
+    TargetOpcode::G_FMUL,
+    TargetOpcode::G_SDIV,
+    TargetOpcode::G_UDIV,
+    TargetOpcode::G_FDIV,
+    TargetOpcode::G_SREM,
+    TargetOpcode::G_UREM,
+    TargetOpcode::G_FREM,
+    TargetOpcode::G_FNEG,
+    TargetOpcode::G_CONSTANT,
+    TargetOpcode::G_FCONSTANT,
+    TargetOpcode::G_AND,
+    TargetOpcode::G_OR,
+    TargetOpcode::G_XOR,
+    TargetOpcode::G_SHL,
+    TargetOpcode::G_ASHR,
+    TargetOpcode::G_LSHR,
+    TargetOpcode::G_SELECT,
+    TargetOpcode::G_EXTRACT_VECTOR_ELT,
+};
+
+static const std::unordered_set<unsigned> &getTypeFoldingSupportingOpcs() {
+  return TypeFoldingSupportingOpcs;
+}
+
+bool isTypeFoldingSupported(unsigned Opcode) {
+  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
+}
+
+SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
+  using namespace TargetOpcode;
+
+  this->ST = &ST;
+  GR = ST.getSPIRVGlobalRegistry();
+
+  const LLT s1 = LLT::scalar(1);
+  const LLT s8 = LLT::scalar(8);
+  const LLT s16 = LLT::scalar(16);
+  const LLT s32 = LLT::scalar(32);
+  const LLT s64 = LLT::scalar(64);
+
+  const LLT v16s64 = LLT::fixed_vector(16, 64);
+  const LLT v16s32 = LLT::fixed_vector(16, 32);
+  const LLT v16s16 = LLT::fixed_vector(16, 16);
+  const LLT v16s8 = LLT::fixed_vector(16, 8);
+  const LLT v16s1 = LLT::fixed_vector(16, 1);
+
+  const LLT v8s64 = LLT::fixed_vector(8, 64);
+  const LLT v8s32 = LLT::fixed_vector(8, 32);
+  const LLT v8s16 =
LLT::fixed_vector(8, 16); + const LLT v8s8 = LLT::fixed_vector(8, 8); + const LLT v8s1 = LLT::fixed_vector(8, 1); + + const LLT v4s64 = LLT::fixed_vector(4, 64); + const LLT v4s32 = LLT::fixed_vector(4, 32); + const LLT v4s16 = LLT::fixed_vector(4, 16); + const LLT v4s8 = LLT::fixed_vector(4, 8); + const LLT v4s1 = LLT::fixed_vector(4, 1); + + const LLT v3s64 = LLT::fixed_vector(3, 64); + const LLT v3s32 = LLT::fixed_vector(3, 32); + const LLT v3s16 = LLT::fixed_vector(3, 16); + const LLT v3s8 = LLT::fixed_vector(3, 8); + const LLT v3s1 = LLT::fixed_vector(3, 1); + + const LLT v2s64 = LLT::fixed_vector(2, 64); + const LLT v2s32 = LLT::fixed_vector(2, 32); + const LLT v2s16 = LLT::fixed_vector(2, 16); + const LLT v2s8 = LLT::fixed_vector(2, 8); + const LLT v2s1 = LLT::fixed_vector(2, 1); + + const unsigned PSize = ST.getPointerSize(); + const LLT p0 = LLT::pointer(0, PSize); // Function + const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup + const LLT p2 = LLT::pointer(2, PSize); // UniformConstant + const LLT p3 = LLT::pointer(3, PSize); // Workgroup + const LLT p4 = LLT::pointer(4, PSize); // Generic + const LLT p5 = LLT::pointer(5, PSize); // Input + + // TODO: remove copy-pasting here by using concatenation in some way + auto allPtrsScalarsAndVectors = { + p0, p1, p2, p3, p4, p5, s1, s8, s16, + s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, + v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, + v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + + auto allScalarsAndVectors = { + s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, + v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, + v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + + auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, + v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, + v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, + v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; + + auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; + + auto allIntScalars = {s8, s16, s32, s64}; + + auto allFloatScalarsAndVectors = { + s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, + v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; + + auto allFloatAndIntScalars = allIntScalars; + + auto allPtrs = {p0, p1, p2, p3, p4, p5}; + auto allWritablePtrs = {p0, p1, p3, p4}; + + for (auto Opc : getTypeFoldingSupportingOpcs()) + getActionDefinitionsBuilder(Opc).custom(); + + getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); + + // TODO: add proper rules for vectors legalization + getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); + + getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) + .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); + + getActionDefinitionsBuilder(G_ADDRSPACE_CAST) + .legalForCartesianProduct(allPtrs, allPtrs); + + getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); + + getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors); + + getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); + + getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) + .legalForCartesianProduct(allIntScalarsAndVectors, + allFloatScalarsAndVectors); + + getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) + .legalForCartesianProduct(allFloatScalarsAndVectors, + allScalarsAndVectors); + + getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) + .legalFor(allIntScalarsAndVectors); + + 
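+  // Note: OpBitCount's result does not have to match the operand's bit width
+  // (it only needs to be wide enough to hold the count), hence a cartesian
+  // product of the integer types rather than a single shared type index.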
+  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
+      allIntScalarsAndVectors, allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_PHI).alwaysLegal();
+
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
+      typeInSet(0, allPtrsScalarsAndVectors),
+      typeInSet(1, allPtrsScalarsAndVectors),
+      LegalityPredicate(([=](const LegalityQuery &Query) {
+        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
+      }))));
+
+  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
+
+  getActionDefinitionsBuilder(G_INTTOPTR)
+      .legalForCartesianProduct(allPtrs, allIntScalars);
+  getActionDefinitionsBuilder(G_PTRTOINT)
+      .legalForCartesianProduct(allIntScalars, allPtrs);
+  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
+      allPtrs, allIntScalars);
+
+  getActionDefinitionsBuilder(G_ICMP).customIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors),
+          LegalityPredicate([&](const LegalityQuery &Query) {
+            LLT retTy = Query.Types[0];
+            LLT cmpTy = Query.Types[1];
+            if (retTy.isVector())
+              return cmpTy.isVector() &&
+                     retTy.getNumElements() == cmpTy.getNumElements();
+            else
+              // ST.canDirectlyComparePointers() is checked for pointer args
+              // in legalizeCustom().
+              return cmpTy.isScalar() || cmpTy.isPointer();
+          })));
+
+  getActionDefinitionsBuilder(G_FCMP).legalIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allFloatScalarsAndVectors),
+          LegalityPredicate([=](const LegalityQuery &Query) {
+            LLT retTy = Query.Types[0];
+            LLT cmpTy = Query.Types[1];
+            if (retTy.isVector()) {
+              return cmpTy.isVector() &&
+                     retTy.getNumElements() == cmpTy.getNumElements();
+            } else {
+              return cmpTy.isScalar();
+            }
+          })));
+
+  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
+                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
+                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
+                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
+      .legalForCartesianProduct(allIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
+      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
+
+  // Struct return types become a single large scalar, so cannot easily
+  // legalize.
+  getActionDefinitionsBuilder({G_ATOMIC_CMPXCHG, G_ATOMIC_CMPXCHG_WITH_SUCCESS})
+      .alwaysLegal();
+  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
+      .alwaysLegal();
+
+  // Extensions
+  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
+      .legalForCartesianProduct(allScalarsAndVectors);
+
+  // FP conversions
+  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
+      .legalForCartesianProduct(allFloatScalarsAndVectors);
+
+  // Pointer-handling
+  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+
+  // Control-flow
+  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
+
+  getActionDefinitionsBuilder({G_FPOW,
+                               G_FEXP,
+                               G_FEXP2,
+                               G_FLOG,
+                               G_FLOG2,
+                               G_FABS,
+                               G_FMINNUM,
+                               G_FMAXNUM,
+                               G_FCEIL,
+                               G_FCOS,
+                               G_FSIN,
+                               G_FSQRT,
+                               G_FFLOOR,
+                               G_FRINT,
+                               G_FNEARBYINT,
+                               G_INTRINSIC_ROUND,
+                               G_INTRINSIC_TRUNC,
+                               G_FMINIMUM,
+                               G_FMAXIMUM,
+                               G_INTRINSIC_ROUNDEVEN})
+      .legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FCOPYSIGN)
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
+      allFloatScalarsAndVectors, allIntScalarsAndVectors);
+
+  getLegacyLegalizerInfo().computeTables();
+  verify(*ST.getInstrInfo());
+}
+
+static std::pair<Register, unsigned>
+createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
const SPIRVGlobalRegistry &GR) { + auto NewT = LLT::scalar(32); + auto SpvType = GR.getSPIRVTypeForVReg(ValReg); + assert(SpvType && "VReg is expected to have SPIRV type"); + bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat; + bool IsVectorFloat = + SpvType->getOpcode() == SPIRV::OpTypeVector && + GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() == + SPIRV::OpTypeFloat; + IsFloat |= IsVectorFloat; + auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID; + auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass; + if (MRI.getType(ValReg).isPointer()) { + NewT = LLT::pointer(0, 32); + GetIdOp = SPIRV::GET_pID; + DstClass = &SPIRV::pIDRegClass; + } else if (MRI.getType(ValReg).isVector()) { + NewT = LLT::fixed_vector(2, NewT); + GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID; + DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass; + } + auto IdReg = MRI.createGenericVirtualRegister(NewT); + MRI.setRegClass(IdReg, DstClass); + return {IdReg, GetIdOp}; +} + +bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper, + MachineInstr &MI) const { + auto Opc = MI.getOpcode(); + if (!isTypeFoldingSupported(Opc)) { + assert(Opc == TargetOpcode::G_ICMP); + auto &MRI = MI.getMF()->getRegInfo(); + // Add missed SPIRV type to the VReg + // TODO: move SPIRV type detection to one place + auto ResVReg = MI.getOperand(0).getReg(); + auto ResType = GR->getSPIRVTypeForVReg(ResVReg); + if (!ResType) { + LLT Ty = MRI.getType(ResVReg); + LLT BaseTy = Ty.isVector() ? Ty.getElementType() : Ty; + Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), + BaseTy.getSizeInBits()); + if (Ty.isVector()) + LLVMTy = FixedVectorType::get(LLVMTy, Ty.getNumElements()); + auto *SpirvType = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); + GR->assignSPIRVTypeToVReg(SpirvType, ResVReg, Helper.MIRBuilder); + } + auto &Op0 = MI.getOperand(2); + auto &Op1 = MI.getOperand(3); + if (!ST->canDirectlyComparePointers() && + MRI.getType(Op0.getReg()).isPointer() && + MRI.getType(Op1.getReg()).isPointer()) { + auto ConvT = LLT::scalar(ST->getPointerSize()); + auto ConvReg0 = MRI.createGenericVirtualRegister(ConvT); + auto ConvReg1 = MRI.createGenericVirtualRegister(ConvT); + auto *SpirvType = GR->getOrCreateSPIRVType( + IntegerType::get(MI.getMF()->getFunction().getContext(), + ST->getPointerSize()), + Helper.MIRBuilder); + GR->assignSPIRVTypeToVReg(SpirvType, ConvReg0, Helper.MIRBuilder); + GR->assignSPIRVTypeToVReg(SpirvType, ConvReg1, Helper.MIRBuilder); + Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) + .addDef(ConvReg0) + .addUse(Op0.getReg()); + Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) + .addDef(ConvReg1) + .addUse(Op1.getReg()); + Op0.setReg(ConvReg0); + Op1.setReg(ConvReg1); + } + return true; + } + auto &MRI = MI.getMF()->getRegInfo(); + assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg())); + MachineInstr &AssignTypeInst = + *(MRI.use_instr_begin(MI.getOperand(0).getReg())); + auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first; + AssignTypeInst.getOperand(1).setReg(NewReg); + MI.getOperand(0).setReg(NewReg); + for (auto &Op : MI.operands()) { + if (!Op.isReg() || Op.isDef()) + continue; + auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR); + Helper.MIRBuilder.buildInstr(IdOpInfo.second) + .addDef(IdOpInfo.first) + .addUse(Op.getReg()); + Op.setReg(IdOpInfo.first); + } + return true; +} Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.h 
Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVSubtarget.h
+++ llvm/lib/Target/SPIRV/SPIRVSubtarget.h
@@ -35,7 +35,7 @@
 namespace llvm {
 class StringRef;
-
+class SPIRVGlobalRegistry;
 class SPIRVTargetMachine;
 
 class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
@@ -59,9 +59,15 @@
   // TODO Some of these fields might work without unique_ptr.
   // But they are shared with other classes, so if the SPIRVSubtarget
   // moves, not relying on unique_ptr breaks things.
+  std::unique_ptr<SPIRVGlobalRegistry> GR;
   std::unique_ptr<SPIRVCallLowering> CallLoweringInfo;
   std::unique_ptr<SPIRVRegisterBankInfo> RegBankInfo;
+  // The legalizer and instruction selector both rely on the set of available
+  // extensions, capabilities, register bank information, and so on.
+  std::unique_ptr<LegalizerInfo> Legalizer;
+  std::unique_ptr<InstructionSelector> InstSelector;
+
 private:
   // Initialise the available extensions, extended instruction sets and
   // capabilities based on the environment settings (i.e. the previous
@@ -95,10 +101,20 @@
 
   uint32_t getTargetSPIRVVersion() const { return TargetSPIRVVersion; };
 
+  SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); }
+
   const CallLowering *getCallLowering() const override {
     return CallLoweringInfo.get();
   }
+
+  InstructionSelector *getInstructionSelector() const override {
+    return InstSelector.get();
+  }
+
+  const LegalizerInfo *getLegalizerInfo() const override {
+    return Legalizer.get();
+  }
+
   const RegisterBankInfo *getRegBankInfo() const override {
     return RegBankInfo.get();
   }
Index: llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -12,6 +12,8 @@
 
 #include "SPIRVSubtarget.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
 #include "SPIRVTargetMachine.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/Host.h"
@@ -55,8 +57,13 @@
       OpenCLFullProfile(computeOpenCLFullProfile(TT)),
       OpenCLImageSupport(computeOpenCLImageSupport(TT)), InstrInfo(),
       FrameLowering(initSubtargetDependencies(CPU, FS)), TLInfo(TM, *this),
-      CallLoweringInfo(new SPIRVCallLowering(TLInfo)),
-      RegBankInfo(new SPIRVRegisterBankInfo()) {}
+      GR(new SPIRVGlobalRegistry(PointerSize)),
+      CallLoweringInfo(new SPIRVCallLowering(TLInfo, GR.get())),
+      RegBankInfo(new SPIRVRegisterBankInfo()) {
+  Legalizer.reset(new SPIRVLegalizerInfo(*this));
+  InstSelector.reset(
+      createSPIRVInstructionSelector(TM, *this, *RegBankInfo.get()));
+}
 
 SPIRVSubtarget &SPIRVSubtarget::initSubtargetDependencies(StringRef CPU,
                                                           StringRef FS) {
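One subtlety in the constructor above: data members are initialized in declaration order, so GR must be declared before CallLoweringInfo for the GR.get() argument to refer to an already-constructed registry. A minimal standalone sketch of the same pattern (all names here are illustrative, not from the patch):

#include <cassert>
#include <memory>

struct Registry {};
struct Lowering {
  explicit Lowering(Registry *R) : Reg(R) {}
  Registry *Reg;
};

struct Subtarget {
  // Declaration order drives initialization order: GR is constructed first,
  // so handing GR.get() to Lowering's constructor below is safe.
  std::unique_ptr<Registry> GR;
  std::unique_ptr<Lowering> CL;
  Subtarget() : GR(new Registry()), CL(new Lowering(GR.get())) {}
};

int main() {
  Subtarget ST;
  assert(ST.CL->Reg == ST.GR.get()); // the lowering sees the shared registry
  return 0;
}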
Index: llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -12,6 +12,9 @@
 
 #include "SPIRVTargetMachine.h"
 #include "SPIRV.h"
+#include "SPIRVCallLowering.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetObjectFile.h"
 #include "SPIRVTargetTransformInfo.h"
@@ -30,11 +33,18 @@
 #include "llvm/Target/TargetOptions.h"
 
 using namespace llvm;
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               SPIRVSubtarget &Subtarget,
+                               SPIRVRegisterBankInfo &RBI);
 
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   // Register the target.
   RegisterTargetMachine<SPIRVTargetMachine> X(getTheSPIRV32Target());
   RegisterTargetMachine<SPIRVTargetMachine> Y(getTheSPIRV64Target());
+
+  PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeGlobalISel(PR);
 }
 
 // DataLayout: little or big endian
@@ -157,7 +167,19 @@
   return false;
 }
 
+namespace {
+// A custom subclass of InstructionSelect, which is mostly the same except for
+// not requiring RegBankSelect to have run beforehand.
+class SPIRVInstructionSelect : public InstructionSelect {
+  // We don't use register banks, so unset the requirement for them.
+  MachineFunctionProperties getRequiredProperties() const override {
+    return InstructionSelect::getRequiredProperties().reset(
+        MachineFunctionProperties::Property::RegBankSelected);
+  }
+};
+} // namespace
+
 bool SPIRVPassConfig::addGlobalInstructionSelect() {
-  addPass(new InstructionSelect(getOptLevel()));
+  addPass(new SPIRVInstructionSelect());
   return false;
 }
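The SPIRVInstructionSelect override above works because MachineFunctionProperties is essentially a property bit set that the pass manager checks before running each pass. A self-contained sketch of that override pattern, with a stand-in PropSet type instead of the real LLVM classes (purely illustrative):

#include <bitset>
#include <cassert>

// Stand-in for MachineFunctionProperties: a small bit set with chainable ops.
struct PropSet {
  std::bitset<8> Bits;
  PropSet &set(unsigned P) { Bits.set(P); return *this; }
  PropSet &reset(unsigned P) { Bits.reset(P); return *this; }
};

constexpr unsigned RegBankSelected = 0;

struct BaseSelect {
  virtual ~BaseSelect() = default;
  // The base pass demands that register banks were already assigned.
  virtual PropSet getRequiredProperties() const {
    return PropSet().set(RegBankSelected);
  }
};

struct RelaxedSelect : BaseSelect {
  // Keep the base requirements, minus the register-bank one.
  PropSet getRequiredProperties() const override {
    return BaseSelect::getRequiredProperties().reset(RegBankSelected);
  }
};

int main() {
  assert(!RelaxedSelect().getRequiredProperties().Bits.test(RegBankSelected));
  return 0;
}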
Index: llvm/lib/Target/SPIRV/SPIRVUtils.h
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -0,0 +1,58 @@
+//===--- SPIRVUtils.h ---- SPIR-V Utility Functions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/IR/IRBuilder.h"
+#include <vector>
+
+// Add the given string as a series of integer operands, inserting null
+// terminators and padding so that the operands are all complete 32-bit
+// little-endian words.
+void addStringImm(const llvm::StringRef &Str, llvm::MachineInstrBuilder &MIB);
+void addStringImm(const llvm::StringRef &Str, llvm::IRBuilder<> &B,
+                  std::vector<llvm::Value *> &Args);
+
+// Add the given numerical immediate to MIB.
+void addNumImm(const llvm::APInt &Imm, llvm::MachineInstrBuilder &MIB,
+               bool IsFloat = false);
+
+// Add an OpName instruction for the given target register.
+void buildOpName(llvm::Register Target, const llvm::StringRef &Name,
+                 llvm::MachineIRBuilder &MIRBuilder);
+
+// Add an OpDecorate instruction for the given Reg.
+void buildOpDecorate(llvm::Register Reg, llvm::MachineIRBuilder &MIRBuilder,
+                     Decoration::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+
+// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
+unsigned int storageClassToAddressSpace(StorageClass::StorageClass SC);
+
+// Convert an LLVM IR address space to a SPIR-V storage class.
+StorageClass::StorageClass addressSpaceToStorageClass(unsigned int AddrSpace);
+
+// Utility method to constrain an instruction's operands to the correct
+// register classes, and return true if this worked.
+bool constrainRegOperands(llvm::MachineInstrBuilder &MIB,
+                          llvm::MachineFunction *MF = nullptr);
+
+MemorySemantics::MemorySemantics
+getMemSemanticsForStorageClass(StorageClass::StorageClass SC);
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
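Before the implementation, a concrete example of the word packing these helpers perform: the string "abc" plus its null terminator fits exactly one little-endian word, 0x00636261. A standalone sketch under that reading (the function and variable names here are mine, not the patch's):

#include <cassert>
#include <cstdint>
#include <string>

// Pack four bytes of Str starting at offset I into one little-endian 32-bit
// word, padding with NUL bytes once past the end of the string.
static uint32_t packWord(const std::string &Str, unsigned I) {
  uint32_t Word = 0;
  for (unsigned K = 0; K < 4; ++K)
    if (I + K < Str.size())
      Word |= uint32_t(uint8_t(Str[I + K])) << (8 * K);
  return Word;
}

int main() {
  assert(packWord("abc", 0) == 0x00636261u);   // 'a' lands in the low byte
  assert(packWord("abcde", 4) == 0x00000065u); // trailing 'e' plus NUL padding
  return 0;
}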
Index: llvm/lib/Target/SPIRV/SPIRVUtils.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -0,0 +1,169 @@
+//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "SPIRV.h"
+
+using namespace llvm;
+
+// The following functions are used to add string literals as a series of
+// 32-bit integer operands with the correct format, and unpack them if
+// necessary when making string comparisons in compiler passes.
+// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
+static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
+  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
+  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
+    unsigned StrIndex = i + WordIndex;
+    uint8_t CharToAdd = 0;       // Initialize the char as padding/null.
+    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
+      CharToAdd = Str[StrIndex];
+    }
+    Word |= (CharToAdd << (WordIndex * 8));
+  }
+  return Word;
+}
+
+// Get length including padding and null terminator.
+static size_t getPaddedLen(const StringRef &Str) {
+  const size_t Len = Str.size() + 1;
+  return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
+}
+
+void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add an operand for the 32-bits of chars or padding.
+    MIB.addImm(convertCharsToWord(Str, i));
+  }
+}
+
+void addStringImm(const StringRef &Str, IRBuilder<> &B,
+                  std::vector<Value *> &Args) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add a vector element for the 32-bits of chars or padding.
+    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
+  }
+}
+
+void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB, bool IsFloat) {
+  const auto Bitwidth = Imm.getBitWidth();
+  switch (Bitwidth) {
+  case 1:
+    break; // Already handled.
+  case 8:
+  case 16:
+  case 32:
+    MIB.addImm(Imm.getZExtValue());
+    break;
+  case 64: {
+    if (!IsFloat) {
+      uint64_t FullImm = Imm.getZExtValue();
+      uint32_t LowBits = FullImm & 0xffffffff;
+      uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
+      MIB.addImm(LowBits).addImm(HighBits);
+    } else
+      MIB.addImm(Imm.getZExtValue());
+    break;
+  }
+  default:
+    report_fatal_error("Unsupported constant bitwidth");
+  }
+}
+
+void buildOpName(Register Target, const StringRef &Name,
+                 MachineIRBuilder &MIRBuilder) {
+  if (!Name.empty()) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
+    addStringImm(Name, MIB);
+  }
+}
+
+void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
+                     Decoration::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate).addUse(Reg).addImm(Dec);
+  if (!StrImm.empty())
+    addStringImm(StrImm, MIB);
+  for (const auto &DecArg : DecArgs)
+    MIB.addImm(DecArg);
+}
+
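+// Example use of buildOpDecorate (the enum operand names assume the generated
+// SPIRVBaseInfo enums; BuiltIn::GlobalInvocationId is SPIR-V value 28):
+//   buildOpDecorate(Reg, MIRBuilder, Decoration::BuiltIn,
+//                   {static_cast<uint32_t>(BuiltIn::GlobalInvocationId)});
+// which emits: OpDecorate %reg BuiltIn GlobalInvocationId
+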
+// TODO: maybe the following two functions should be handled in the subtarget
+// to allow for different OpenCL vs Vulkan handling.
+unsigned int storageClassToAddressSpace(StorageClass::StorageClass SC) {
+  switch (SC) {
+  case StorageClass::Function:
+    return 0;
+  case StorageClass::CrossWorkgroup:
+    return 1;
+  case StorageClass::UniformConstant:
+    return 2;
+  case StorageClass::Workgroup:
+    return 3;
+  case StorageClass::Generic:
+    return 4;
+  case StorageClass::Input:
+    return 7;
+  default:
+    llvm_unreachable("Unable to get address space id");
+  }
+}
+
+StorageClass::StorageClass addressSpaceToStorageClass(unsigned int AddrSpace) {
+  switch (AddrSpace) {
+  case 0:
+    return StorageClass::Function;
+  case 1:
+    return StorageClass::CrossWorkgroup;
+  case 2:
+    return StorageClass::UniformConstant;
+  case 3:
+    return StorageClass::Workgroup;
+  case 4:
+    return StorageClass::Generic;
+  case 7:
+    return StorageClass::Input;
+  default:
+    llvm_unreachable("Unknown address space");
+  }
+}
+
+MemorySemantics::MemorySemantics
+getMemSemanticsForStorageClass(StorageClass::StorageClass SC) {
+  switch (SC) {
+  case StorageClass::StorageBuffer:
+  case StorageClass::Uniform:
+    return MemorySemantics::UniformMemory;
+  case StorageClass::Workgroup:
+    return MemorySemantics::WorkgroupMemory;
+  case StorageClass::CrossWorkgroup:
+    return MemorySemantics::CrossWorkgroupMemory;
+  case StorageClass::AtomicCounter:
+    return MemorySemantics::AtomicCounterMemory;
+  case StorageClass::Image:
+    return MemorySemantics::ImageMemory;
+  default:
+    return MemorySemantics::None;
+  }
+}
+
+bool constrainRegOperands(MachineInstrBuilder &MIB, MachineFunction *MF) {
+  if (!MF)
+    MF = MIB->getMF();
+  const auto &Subtarget = MF->getSubtarget();
+  const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+  const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
+  const RegisterBankInfo *RBI = Subtarget.getRegBankInfo();
+
+  return constrainSelectedInstRegOperands(*MIB, *TII, *TRI, *RBI);
+}
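Since storageClassToAddressSpace and addressSpaceToStorageClass are intended as mutual inverses over the handled address spaces (0-4 and 7), a quick standalone check of that property; this re-declares the mapping locally rather than linking against the patch, so it is purely illustrative:

#include <cassert>
#include <initializer_list>

// Local mirror of the patch's storage-class/address-space mapping.
enum class SC { Function, CrossWorkgroup, UniformConstant, Workgroup, Generic, Input };

static unsigned toAS(SC S) {
  switch (S) {
  case SC::Function:        return 0;
  case SC::CrossWorkgroup:  return 1;
  case SC::UniformConstant: return 2;
  case SC::Workgroup:       return 3;
  case SC::Generic:         return 4;
  case SC::Input:           return 7;
  }
  return ~0u;
}

static SC toSC(unsigned AS) {
  switch (AS) {
  case 0: return SC::Function;
  case 1: return SC::CrossWorkgroup;
  case 2: return SC::UniformConstant;
  case 3: return SC::Workgroup;
  case 4: return SC::Generic;
  default: return SC::Input; // address space 7 in the patch
  }
}

int main() {
  for (SC S : {SC::Function, SC::CrossWorkgroup, SC::UniformConstant,
               SC::Workgroup, SC::Generic, SC::Input})
    assert(toSC(toAS(S)) == S); // round-trip holds for every handled class
  return 0;
}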