Index: llvm/lib/Target/SPIRV/CMakeLists.txt =================================================================== --- llvm/lib/Target/SPIRV/CMakeLists.txt +++ llvm/lib/Target/SPIRV/CMakeLists.txt @@ -24,8 +24,10 @@ SPIRVInstructionSelector.cpp SPIRVISelLowering.cpp SPIRVLegalizerInfo.cpp + SPIRVLowerConstExpr.cpp SPIRVMCInstLower.cpp SPIRVModuleAnalysis.cpp + SPIRVOCLRegularizer.cpp SPIRVPreLegalizer.cpp SPIRVPrepareFunctions.cpp SPIRVRegisterBankInfo.cpp Index: llvm/lib/Target/SPIRV/SPIRV.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRV.h +++ llvm/lib/Target/SPIRV/SPIRV.h @@ -20,6 +20,8 @@ class RegisterBankInfo; ModulePass *createSPIRVPrepareFunctionsPass(); +FunctionPass *createSPIRVOCLRegularizerPass(); +ModulePass *createSPIRVLowerConstExprLegacyPass(); FunctionPass *createSPIRVPreLegalizerPass(); FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); InstructionSelector * Index: llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -809,7 +809,7 @@ } // These queries ask for a single size_t result for a given dimension index, e.g -// size_t get_global_id(uintt dimindex). In SPIR-V, the builtins corresonding to +// size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresponding to // these values are all vec3 types, so we need to extract the correct index or // return defaultVal (0 or 1 depending on the query). We also handle extending // or tuncating in case size_t does not match the expected result type's @@ -1655,16 +1655,15 @@ static const SPIRV::DemangledType *findBuiltinType(StringRef Name) { if (Name.startswith("opencl.")) return SPIRV::lookupBuiltinType(Name); - if (Name.startswith("spirv.")) { - // Some SPIR-V builtin types have a complex list of parameters as part of - // their name (e.g. spirv.Image._void_1_0_0_0_0_0_0). 
Those parameters often - // are numeric literals which cannot be easily represented by TableGen - // records and should be parsed instead. - unsigned BaseTypeNameLength = - Name.contains('_') ? Name.find('_') - 1 : Name.size(); - return SPIRV::lookupBuiltinType(Name.substr(0, BaseTypeNameLength).str()); - } - return nullptr; + if (!Name.startswith("spirv.")) + return nullptr; + // Some SPIR-V builtin types have a complex list of parameters as part of + // their name (e.g. spirv.Image._void_1_0_0_0_0_0_0). Those parameters often + // are numeric literals which cannot be easily represented by TableGen + // records and should be parsed instead. + unsigned BaseTypeNameLength = + Name.contains('_') ? Name.find('_') - 1 : Name.size(); + return SPIRV::lookupBuiltinType(Name.substr(0, BaseTypeNameLength).str()); } static std::unique_ptr @@ -1674,37 +1673,36 @@ const SPIRV::ImageType *Record = SPIRV::lookupImageType(Name); return std::unique_ptr(new SPIRV::ImageType(*Record)); } - if (Name.startswith("spirv.")) { - // Parse the literals of SPIR-V image builtin parameters. The name should - // have the following format: - // spirv.Image._Type_Dim_Depth_Arrayed_MS_Sampled_ImageFormat_AccessQualifier - // e.g. 
%spirv.Image._void_1_0_0_0_0_0_0 - StringRef TypeParametersString = Name.substr(strlen("spirv.Image.")); - SmallVector TypeParameters; - SplitString(TypeParametersString, TypeParameters, "_"); - assert(TypeParameters.size() == 8 && - "Wrong number of literals in SPIR-V builtin image type"); - - StringRef SampledType = TypeParameters[0]; - unsigned Dim, Depth, Arrayed, Multisampled, Sampled, Format, AccessQual; - bool AreParameterLiteralsValid = - !(TypeParameters[1].getAsInteger(10, Dim) || - TypeParameters[2].getAsInteger(10, Depth) || - TypeParameters[3].getAsInteger(10, Arrayed) || - TypeParameters[4].getAsInteger(10, Multisampled) || - TypeParameters[5].getAsInteger(10, Sampled) || - TypeParameters[6].getAsInteger(10, Format) || - TypeParameters[7].getAsInteger(10, AccessQual)); - assert(AreParameterLiteralsValid && - "Invalid format of SPIR-V image type parameter literals."); - - return std::unique_ptr(new SPIRV::ImageType{ - Name, SampledType, SPIRV::AccessQualifier::AccessQualifier(AccessQual), - SPIRV::Dim::Dim(Dim), static_cast(Arrayed), - static_cast(Depth), static_cast(Multisampled), - static_cast(Sampled), SPIRV::ImageFormat::ImageFormat(Format)}); - } - llvm_unreachable("Unknown builtin image type name/literal"); + if (!Name.startswith("spirv.")) + llvm_unreachable("Unknown builtin image type name/literal"); + // Parse the literals of SPIR-V image builtin parameters. The name should + // have the following format: + // spirv.Image._Type_Dim_Depth_Arrayed_MS_Sampled_ImageFormat_AccessQualifier + // e.g. 
%spirv.Image._void_1_0_0_0_0_0_0 + StringRef TypeParametersString = Name.substr(strlen("spirv.Image.")); + SmallVector TypeParameters; + SplitString(TypeParametersString, TypeParameters, "_"); + assert(TypeParameters.size() == 8 && + "Wrong number of literals in SPIR-V builtin image type"); + + StringRef SampledType = TypeParameters[0]; + unsigned Dim, Depth, Arrayed, Multisampled, Sampled, Format, AccessQual; + bool AreParameterLiteralsValid = + !(TypeParameters[1].getAsInteger(10, Dim) || + TypeParameters[2].getAsInteger(10, Depth) || + TypeParameters[3].getAsInteger(10, Arrayed) || + TypeParameters[4].getAsInteger(10, Multisampled) || + TypeParameters[5].getAsInteger(10, Sampled) || + TypeParameters[6].getAsInteger(10, Format) || + TypeParameters[7].getAsInteger(10, AccessQual)); + assert(AreParameterLiteralsValid && + "Invalid format of SPIR-V image type parameter literals."); + + return std::unique_ptr(new SPIRV::ImageType{ + Name, SampledType, SPIRV::AccessQualifier::AccessQualifier(AccessQual), + SPIRV::Dim::Dim(Dim), static_cast(Arrayed), + static_cast(Depth), static_cast(Multisampled), + static_cast(Sampled), SPIRV::ImageFormat::ImageFormat(Format)}); } static std::unique_ptr @@ -1714,46 +1712,46 @@ const SPIRV::PipeType *Record = SPIRV::lookupPipeType(Name); return std::unique_ptr(new SPIRV::PipeType(*Record)); } - if (Name.startswith("spirv.")) { - // Parse the access qualifier literal in the name of the SPIR-V pipe type. - // The name should have the following format: - // spirv.Pipe._AccessQualifier - // e.g. 
%spirv.Pipe._1 - if (Name.endswith("_0")) - return std::unique_ptr( - new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadOnly}); - if (Name.endswith("_1")) - return std::unique_ptr( - new SPIRV::PipeType{Name, SPIRV::AccessQualifier::WriteOnly}); - if (Name.endswith("_2")) - return std::unique_ptr( - new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadWrite}); - llvm_unreachable("Unknown pipe type access qualifier literal"); - } - llvm_unreachable("Unknown builtin pipe type name/literal"); + if (!Name.startswith("spirv.")) + llvm_unreachable("Unknown builtin pipe type name/literal"); + // Parse the access qualifier literal in the name of the SPIR-V pipe type. + // The name should have the following format: + // spirv.Pipe._AccessQualifier + // e.g. %spirv.Pipe._1 + if (Name.endswith("_0")) + return std::unique_ptr( + new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadOnly}); + if (Name.endswith("_1")) + return std::unique_ptr( + new SPIRV::PipeType{Name, SPIRV::AccessQualifier::WriteOnly}); + if (Name.endswith("_2")) + return std::unique_ptr( + new SPIRV::PipeType{Name, SPIRV::AccessQualifier::ReadWrite}); + llvm_unreachable("Unknown pipe type access qualifier literal"); } //===----------------------------------------------------------------------===// // Implementation functions for builtin types. //===----------------------------------------------------------------------===// -SPIRVType *getNonParametrizedType(const StructType *OpaqueType, - const SPIRV::DemangledType *TypeRecord, - MachineIRBuilder &MIRBuilder, - SPIRVGlobalRegistry *GR) { +static SPIRVType *getNonParametrizedType(const StructType *OpaqueType, + const SPIRV::DemangledType *TypeRecord, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { unsigned Opcode = TypeRecord->Opcode; // Create or get an existing type from GlobalRegistry. 
return GR->getOrCreateOpTypeByOpcode(OpaqueType, MIRBuilder, Opcode); } -SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder, - SPIRVGlobalRegistry *GR) { +static SPIRVType *getSamplerType(MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { // Create or get an existing type from GlobalRegistry. return GR->getOrCreateOpTypeSampler(MIRBuilder); } -SPIRVType *getPipeType(const StructType *OpaqueType, - MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { +static SPIRVType *getPipeType(const StructType *OpaqueType, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { // Lookup pipe type lowering details in TableGen records or parse the // name/literal for details. std::unique_ptr Record = @@ -1762,9 +1760,10 @@ return GR->getOrCreateOpTypePipe(MIRBuilder, Record.get()->Qualifier); } -SPIRVType *getImageType(const StructType *OpaqueType, - SPIRV::AccessQualifier::AccessQualifier AccessQual, - MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { +static SPIRVType * +getImageType(const StructType *OpaqueType, + SPIRV::AccessQualifier::AccessQualifier AccessQual, + MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { // Lookup image type lowering details in TableGen records or parse the // name/literal for details. 
std::unique_ptr Record = @@ -1781,9 +1780,9 @@ : Record.get()->Qualifier); } -SPIRVType *getSampledImageType(const StructType *OpaqueType, - MachineIRBuilder &MIRBuilder, - SPIRVGlobalRegistry *GR) { +static SPIRVType *getSampledImageType(const StructType *OpaqueType, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { StringRef TypeParametersString = OpaqueType->getName().substr(strlen("spirv.SampledImage.")); LLVMContext &Context = MIRBuilder.getMF().getFunction().getContext(); Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -115,6 +115,117 @@ return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); } +static MDString *getKernelArgAttribute(const Function &KernelFunction, + unsigned ArgIdx, + const StringRef AttributeName) { + assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL && + "Kernel attributes are only attached/belong to kernel functions."); + + // Lookup the argument attribute in metadata attached to the kernel function. + MDNode *Node = KernelFunction.getMetadata(AttributeName); + if (Node && ArgIdx < Node->getNumOperands()) + return cast(Node->getOperand(ArgIdx)); + + // Sometimes metadata containing kernel attributes is not attached to the + // function, but can be found in the named module-level metadata instead. For + // example: + // !opencl.kernels = !{!0} + // !0 = !{void ()* @someKernelFunction, !1, ...} + // !1 = !{!"kernel_arg_addr_space", ...} + // + // In this case the actual index of searched argument attribute is ArgIdx + 1, + // since the first metadata node operand is occupied by attribute name + // ("kernel_arg_addr_space" in the example above). 
+ unsigned MDArgIdx = ArgIdx + 1; + NamedMDNode *OpenCLKernelsMD = + KernelFunction.getParent()->getNamedMetadata("opencl.kernels"); + if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0) + return nullptr; + + MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0); + // KernelToMDNodeList contains kernel function declarations followed by + // corresponding MDNodes for each attribute. Search only MDNodes "belonging" + // to the currently lowered kernel function. + bool FoundLoweredKernelFunction = false; + for (const MDOperand &Operand : KernelToMDNodeList->operands()) { + ValueAsMetadata *MaybeValue = dyn_cast(Operand); + if (MaybeValue && + dyn_cast_or_null(MaybeValue->getValue())->getName() == + KernelFunction.getName()) { + FoundLoweredKernelFunction = true; + continue; + } + if (MaybeValue && FoundLoweredKernelFunction) + return nullptr; + + MDNode *MaybeNode = dyn_cast(Operand); + if (FoundLoweredKernelFunction && MaybeNode && + cast(MaybeNode->getOperand(0))->getString() == + AttributeName && + MDArgIdx < MaybeNode->getNumOperands()) + return cast(MaybeNode->getOperand(MDArgIdx)); + } + + return nullptr; +} + +static SPIRV::AccessQualifier::AccessQualifier +getArgAccessQual(const Function &F, unsigned ArgIdx) { + SPIRV::AccessQualifier::AccessQualifier AQ = + SPIRV::AccessQualifier::ReadWrite; + + if (F.getCallingConv() != CallingConv::SPIR_KERNEL) + return AQ; + + MDString *ArgAttribute = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual"); + if (ArgAttribute) { + StringRef AQString = ArgAttribute->getString(); + if (AQString.compare("read_only") == 0) + AQ = SPIRV::AccessQualifier::ReadOnly; + else if (AQString.compare("write_only") == 0) + AQ = SPIRV::AccessQualifier::WriteOnly; + } + + return AQ; +} + +static std::vector +getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) { + MDString *ArgAttribute = + getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual"); + if (ArgAttribute) { + StringRef 
TypeQual = ArgAttribute->getString(); + if (TypeQual.compare("volatile") == 0) + return {SPIRV::Decoration::Volatile}; + } + + return {}; +} + +static Type *getArgType(const Function &F, unsigned ArgIdx) { + Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx); + + if (F.getCallingConv() == CallingConv::SPIR_KERNEL && + !isSpecialOpaqueType(OriginalArgType)) { + + MDString *MDKernelArgType = + getKernelArgAttribute(F, ArgIdx, "kernel_arg_type"); + if (MDKernelArgType && MDKernelArgType->getString().endswith("_t")) { + std::string KernelArgTypeStr = + "opencl." + MDKernelArgType->getString().str(); + + Type *ExistingOpaqueType = + StructType::getTypeByName(F.getContext(), KernelArgTypeStr); + return ExistingOpaqueType + ? ExistingOpaqueType + : StructType::create(F.getContext(), KernelArgTypeStr); + } + } + + return OriginalArgType; +} + bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef> VRegs, @@ -132,18 +243,10 @@ // TODO: handle the case of multiple registers. 
if (VRegs[i].size() > 1) return false; - Type *ArgTy = FTy->getParamType(i); - SPIRV::AccessQualifier::AccessQualifier AQ = - SPIRV::AccessQualifier::ReadWrite; - MDNode *Node = F.getMetadata("kernel_arg_access_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef AQString = cast(Node->getOperand(i))->getString(); - if (AQString.compare("read_only") == 0) - AQ = SPIRV::AccessQualifier::ReadOnly; - else if (AQString.compare("write_only") == 0) - AQ = SPIRV::AccessQualifier::WriteOnly; - } - auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ); + SPIRV::AccessQualifier::AccessQualifier ArgAccessQual = + getArgAccessQual(F, i); + auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0], + MIRBuilder, ArgAccessQual); ArgTypeVRegs.push_back(SpirvTy); if (Arg.hasName()) @@ -178,14 +281,15 @@ buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::FuncParamAttr, {Attr}); } - Node = F.getMetadata("kernel_arg_type_qual"); - if (Node && i < Node->getNumOperands()) { - StringRef TypeQual = cast(Node->getOperand(i))->getString(); - if (TypeQual.compare("volatile") == 0) - buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile, - {}); + + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + std::vector ArgTypeQualDecs = + getKernelArgTypeQual(F, i); + for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs) + buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {}); } - Node = F.getMetadata("spirv.ParameterDecorations"); + + MDNode *Node = F.getMetadata("spirv.ParameterDecorations"); if (Node && i < Node->getNumOperands() && isa(Node->getOperand(i))) { MDNode *MD = cast(Node->getOperand(i)); @@ -286,7 +390,7 @@ Register ResVReg = Info.OrigRet.Regs.empty() ? 
Register(0) : Info.OrigRet.Regs[0]; std::string FuncName = Info.Callee.getGlobal()->getGlobalIdentifier(); - std::string DemangledName = mayBeOclOrSpirvBuiltin(FuncName); + std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName); const auto *ST = static_cast(&MF.getSubtarget()); // TODO: check that it's OCL builtin, then apply OpenCL_std. if (!DemangledName.empty() && CF && CF->isDeclaration() && Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -544,26 +544,6 @@ return MIB; } -static bool isOpenCLBuiltinType(const StructType *SType) { - return SType->isOpaque() && SType->hasName() && - SType->getName().startswith("opencl."); -} - -static bool isSPIRVBuiltinType(const StructType *SType) { - return SType->isOpaque() && SType->hasName() && - SType->getName().startswith("spirv."); -} - -static bool isSpecialType(const Type *Ty) { - if (auto PType = dyn_cast(Ty)) { - if (!PType->isOpaque()) - Ty = PType->getNonOpaquePointerElementType(); - } - if (auto SType = dyn_cast(Ty)) - return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType); - return false; -} - SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccQual) { @@ -574,7 +554,7 @@ Ty = PType->getNonOpaquePointerElementType(); } auto SType = cast(Ty); - assert(isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType)); + assert(isSpecialOpaqueType(SType) && "Not a special opaque builtin type"); return SPIRV::lowerBuiltinType(SType, AccQual, MIRBuilder, this); } @@ -639,7 +619,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { - if (isSpecialType(Ty)) + if (isSpecialOpaqueType(Ty)) return getOrCreateSpecialType(Ty, MIRBuilder, 
AccQual); auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); auto t = TypeToSPIRVTypeMap.find(Ty); @@ -725,7 +705,7 @@ // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type // will be added later. For special types it is already added to DT. if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && - !isSpecialType(Ty)) + !isSpecialOpaqueType(Ty)) DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); return SpirvType; @@ -745,7 +725,7 @@ const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { Register Reg = DT.find(Ty, &MIRBuilder.getMF()); - if (Reg.isValid() && !isSpecialType(Ty)) + if (Reg.isValid() && !isSpecialOpaqueType(Ty)) return getSPIRVTypeForVReg(Reg); TypesInProcessing.clear(); SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); Index: llvm/lib/Target/SPIRV/SPIRVISelLowering.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRVISelLowering.h +++ llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -41,6 +41,9 @@ EVT VT) const override; MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override; + bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, + MachineFunction &MF, + unsigned Intrinsic) const override; }; } // namespace llvm Index: llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -12,6 +12,7 @@ #include "SPIRVISelLowering.h" #include "SPIRV.h" +#include "llvm/IR/IntrinsicsSPIRV.h" #define DEBUG_TYPE "spirv-lower" @@ -43,3 +44,31 @@ } return getRegisterType(Context, VT); } + +bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, + const CallInst &I, + MachineFunction &MF, + unsigned Intrinsic) const { + unsigned AlignIdx = 3; + switch (Intrinsic) { + 
case Intrinsic::spv_load: + AlignIdx = 2; + LLVM_FALLTHROUGH; + case Intrinsic::spv_store: { + if (I.getNumOperands() >= AlignIdx + 1) { + auto *AlignOp = cast(I.getOperand(AlignIdx)); + Info.align = Align(AlignOp->getZExtValue()); + } + Info.flags = static_cast( + cast(I.getOperand(AlignIdx - 1))->getZExtValue()); + Info.memVT = MVT::i64; + // TODO: take into account opaque pointers (don't use getElementType). + // MVT::getVT(PtrTy->getElementType()); + return true; + break; + } + default: + break; + } + return false; +} Index: llvm/lib/Target/SPIRV/SPIRVInstrFormats.td =================================================================== --- llvm/lib/Target/SPIRV/SPIRVInstrFormats.td +++ llvm/lib/Target/SPIRV/SPIRVInstrFormats.td @@ -28,4 +28,5 @@ // Pseudo instructions class Pseudo : Op<0, outs, ins, ""> { let isPseudo = 1; + let hasSideEffects = 0; } Index: llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -145,6 +145,9 @@ getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); + getActionDefinitionsBuilder(G_MEMSET).legalIf( + all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars))); + getActionDefinitionsBuilder(G_ADDRSPACE_CAST) .legalForCartesianProduct(allPtrs, allPtrs); @@ -223,8 +226,8 @@ // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); - // Control-flow. - getActionDefinitionsBuilder(G_BRCOND).legalFor({s1}); + // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 
+ getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); getActionDefinitionsBuilder({G_FPOW, G_FEXP, Index: llvm/lib/Target/SPIRV/SPIRVLowerConstExpr.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVLowerConstExpr.cpp @@ -0,0 +1,192 @@ +//===-- SPIRVLowerConstExpr.cpp - Regularize LLVM IR for SPIR-V --- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements regularization of LLVM module for SPIR-V. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVTargetMachine.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" + +#include + +#define DEBUG_TYPE "spv-lower-const-expr" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVLowerConstExprLegacyPass(PassRegistry &); +} // namespace llvm + +class SPIRVLowerConstExprBase { +public: + SPIRVLowerConstExprBase() : M(nullptr), Ctx(nullptr) {} + bool runLowerConstExpr(Module &M); + void visit(Module *M); + +private: + Module *M; + LLVMContext *Ctx; +}; + +class SPIRVLowerConstExprPass + : public llvm::PassInfoMixin, + public SPIRVLowerConstExprBase { +public: + llvm::PreservedAnalyses run(llvm::Module &M, + llvm::ModuleAnalysisManager &MAM) { + return runLowerConstExpr(M) ? 
llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all(); + } +}; + +namespace { +class SPIRVLowerConstExprLegacy : public ModulePass, + public SPIRVLowerConstExprBase { +public: + static char ID; + SPIRVLowerConstExprLegacy() : ModulePass(ID) { + initializeSPIRVLowerConstExprLegacyPass(*PassRegistry::getPassRegistry()); + } + bool runOnModule(Module &M) override { return runLowerConstExpr(M); } +}; +} // namespace + +char SPIRVLowerConstExprLegacy::ID = 0; + +INITIALIZE_PASS(SPIRVLowerConstExprLegacy, DEBUG_TYPE, + "Regularize LLVM IR for SPIR-V", false, false) + +bool SPIRVLowerConstExprBase::runLowerConstExpr(Module &Module) { + M = &Module; + Ctx = &M->getContext(); + + LLVM_DEBUG(dbgs() << "Enter SPIRVLowerConstExpr:\n"); + visit(M); + + return true; +} + +/// Since SPIR-V cannot represent constant expression, constant expressions +/// in LLVM IR need to be lowered to instructions. For each function, +/// the constant expressions used by instructions of the function are replaced +/// by instructions placed in the entry block since it dominates all other BBs. +/// Each constant expression only needs to be lowered once in each function +/// and all uses of it by instructions in that function are replaced by +/// one instruction. +/// TODO: remove redundant instructions for common subexpression. +void SPIRVLowerConstExprBase::visit(Module *M) { + for (Function &F : M->functions()) { + std::list WorkList; + for (auto &II : instructions(F)) + WorkList.push_back(&II); + + auto FBegin = F.begin(); + while (!WorkList.empty()) { + Instruction *II = WorkList.front(); + + auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { + if (isa(V)) + return V; + auto *CE = cast(V); + LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); + auto ReplInst = CE->getAsInstruction(); + auto InsPoint = II->getParent() == &*FBegin ? 
II : &FBegin->back(); + ReplInst->insertBefore(InsPoint); + LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); + std::vector Users; + // Do not replace use during iteration of use. Do it in another loop. + for (auto U : CE->users()) { + LLVM_DEBUG(dbgs() + << "[lowerConstantExpressions] Use: " << *U << '\n'); + auto InstUser = dyn_cast(U); + // Only replace users in scope of current function. + if (InstUser && InstUser->getParent()->getParent() == &F) + Users.push_back(InstUser); + } + for (auto &User : Users) { + if (ReplInst->getParent() == User->getParent() && User->comesBefore(ReplInst)) + ReplInst->moveBefore(User); + User->replaceUsesOfWith(CE, ReplInst); + } + return ReplInst; + }; + + WorkList.pop_front(); + auto LowerConstantVec = [&II, &LowerOp, &WorkList, + &M](ConstantVector *Vec, + unsigned NumOfOp) -> Value * { + if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { + return isa(V) || isa(V); + })) { + // Expand a vector of constexprs and construct it back with + // series of insertelement instructions. + std::list OpList; + std::transform(Vec->op_begin(), Vec->op_end(), + std::back_inserter(OpList), + [LowerOp](Value *V) { return LowerOp(V); }); + Value *Repl = nullptr; + unsigned Idx = 0; + auto *PhiII = dyn_cast(II); + Instruction *InsPoint = + PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; + std::list ReplList; + for (auto V : OpList) { + if (auto *Inst = dyn_cast(V)) + ReplList.push_back(Inst); + Repl = InsertElementInst::Create( + (Repl ? 
Repl : UndefValue::get(Vec->getType())), V, + ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx++), "", + InsPoint); + } + WorkList.splice(WorkList.begin(), ReplList); + return Repl; + } + return nullptr; + }; + + for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { + auto *Op = II->getOperand(OI); + if (auto *Vec = dyn_cast(Op)) { + Value *ReplInst = LowerConstantVec(Vec, OI); + if (ReplInst) + II->replaceUsesOfWith(Op, ReplInst); + } else if (auto CE = dyn_cast(Op)) { + WorkList.push_front(cast(LowerOp(CE))); + } else if (auto MDAsVal = dyn_cast(Op)) { + auto ConstMD = dyn_cast(MDAsVal->getMetadata()); + if (!ConstMD) + continue; + Constant *C = ConstMD->getValue(); + Value *ReplInst = nullptr; + if (auto *Vec = dyn_cast(C)) + ReplInst = LowerConstantVec(Vec, OI); + if (auto *CE = dyn_cast(C)) + ReplInst = LowerOp(CE); + if (!ReplInst) + continue; + Metadata *RepMD = ValueAsMetadata::get(ReplInst); + Value *RepMDVal = MetadataAsValue::get(M->getContext(), RepMD); + II->setOperand(OI, RepMDVal); + WorkList.push_front(cast(ReplInst)); + } + } + } + } +} + +ModulePass *llvm::createSPIRVLowerConstExprLegacyPass() { + return new SPIRVLowerConstExprLegacy(); +} Index: llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -751,6 +751,7 @@ break; case SPIRV::OpTypeDeviceEvent: case SPIRV::OpTypeQueue: + case SPIRV::OpBuildNDRange: Reqs.addCapability(SPIRV::Capability::DeviceEnqueue); break; case SPIRV::OpDecorate: Index: llvm/lib/Target/SPIRV/SPIRVOCLRegularizer.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/SPIRV/SPIRVOCLRegularizer.cpp @@ -0,0 +1,137 @@ +//===-- SPIRVOCLRegularizer.cpp - regularize OpenCL builtins ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass fixes calls to OCL builtins that accept vector arguments and one +// of them is actually a scalar splat. The prototype of this pass was taken +// from SPIRV-LLVM translator. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/Demangle/Demangle.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/Transforms/Utils/Cloning.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVOCLRegularizerPass(PassRegistry &); +} + +namespace { +struct SPIRVOCLRegularizer : public FunctionPass, + InstVisitor { + DenseMap Old2NewFuncs; + +public: + static char ID; + SPIRVOCLRegularizer() : FunctionPass(ID) { + initializeSPIRVOCLRegularizerPass(*PassRegistry::getPassRegistry()); + } + bool runOnFunction(Function &F) override; + StringRef getPassName() const override { return "SPIRV OCL Regularizer"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + FunctionPass::getAnalysisUsage(AU); + } + void visitCallInst(CallInst &CI); + +private: + void visitCallScalToVec(CallInst *CI, StringRef MangledName, + StringRef DemangledName); +}; +} // namespace + +char SPIRVOCLRegularizer::ID = 0; + +INITIALIZE_PASS(SPIRVOCLRegularizer, "spirv-ocl-regularizer", + "SPIRV OpenCL Regularizer", false, false) + +void SPIRVOCLRegularizer::visitCallInst(CallInst &CI) { + auto F = CI.getCalledFunction(); + if (!F) + return; + + auto MangledName = F->getName(); + size_t n; + int status; + char *NameStr = itaniumDemangle(F->getName().data(), nullptr, &n, &status); + StringRef DemangledName(NameStr); + + // TODO: add support for other builtins. 
+ if (DemangledName.startswith("fmin") || DemangledName.startswith("fmax") || + DemangledName.startswith("min") || DemangledName.startswith("max")) + visitCallScalToVec(&CI, MangledName, DemangledName); + free(NameStr); +} + +void SPIRVOCLRegularizer::visitCallScalToVec(CallInst *CI, + StringRef MangledName, + StringRef DemangledName) { + // Check if all arguments have the same type - it's simple case. + auto Uniform = true; + Type *Arg0Ty = CI->getOperand(0)->getType(); + auto IsArg0Vector = isa(Arg0Ty); + for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I) + Uniform = isa(CI->getOperand(I)->getType()) == IsArg0Vector; + if (Uniform) + return; + + auto *OldF = CI->getCalledFunction(); + Function *NewF = nullptr; + if (!Old2NewFuncs.count(OldF)) { + AttributeList Attrs = CI->getCalledFunction()->getAttributes(); + SmallVector ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty}; + auto *NewFTy = + FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg()); + NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(), + *OldF->getParent()); + ValueToValueMapTy VMap; + auto NewFArgIt = NewF->arg_begin(); + for (auto &Arg : OldF->args()) { + auto ArgName = Arg.getName(); + NewFArgIt->setName(ArgName); + VMap[&Arg] = &(*NewFArgIt++); + } + SmallVector Returns; + CloneFunctionInto(NewF, OldF, VMap, + CloneFunctionChangeType::LocalChangesOnly, Returns); + NewF->setAttributes(Attrs); + Old2NewFuncs[OldF] = NewF; + } else { + NewF = Old2NewFuncs[OldF]; + } + assert(NewF); + + auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); + UndefValue *UndefVal = UndefValue::get(Arg0Ty); + Instruction *Inst = InsertElementInst::Create(UndefVal, CI->getOperand(1), ConstInt, "", CI); + ElementCount VecElemCount = cast(Arg0Ty)->getElementCount(); + Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt); + Value *NewVec = new ShuffleVectorInst(Inst, UndefVal, ConstVec, "", CI); + CI->setOperand(1, NewVec); + 
CI->replaceUsesOfWith(OldF, NewF); + CI->mutateFunctionType(NewF->getFunctionType()); +} + +bool SPIRVOCLRegularizer::runOnFunction(Function &F) { + visit(F); + for (auto &OldNew : Old2NewFuncs) { + Function *OldF = OldNew.first; + Function *NewF = OldNew.second; + NewF->takeName(OldF); + OldF->eraseFromParent(); + } + return true; +} + +FunctionPass *llvm::createSPIRVOCLRegularizerPass() { + return new SPIRVOCLRegularizer(); +} Index: llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -369,11 +369,19 @@ if (MI.getOpcode() != SPIRV::ASSIGN_TYPE) continue; Register SrcReg = MI.getOperand(1).getReg(); - if (!isTypeFoldingSupported(MRI.getVRegDef(SrcReg)->getOpcode())) + unsigned Opcode = MRI.getVRegDef(SrcReg)->getOpcode(); + if (!isTypeFoldingSupported(Opcode)) continue; Register DstReg = MI.getOperand(0).getReg(); if (MRI.getType(DstReg).isVector()) MRI.setRegClass(DstReg, &SPIRV::IDRegClass); + // Don't need to reset type of register holding constant and used in + // G_ADDRSPACE_CAST, since it breaks the legalizer.
+ if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) { + MachineInstr &UseMI = *MRI.use_instr_begin(DstReg); + if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) + continue; + } MRI.setType(DstReg, LLT::scalar(32)); } } Index: llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -18,6 +18,7 @@ #include "SPIRV.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" +#include "llvm/CodeGen/IntrinsicLowering.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -141,6 +142,69 @@ return NewF; } +static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) { + // @llvm.memset.* intrinsic calls with constant value and length arguments + // are emulated via "storing" a constant array to the destination. For other + // cases we wrap the intrinsic in a @spirv.llvm_memset_* function and expand + // the intrinsic to a loop via expandMemSetAsLoop(). + if (auto *MSI = dyn_cast(Intrinsic)) + if (isa(MSI->getValue()) && isa(MSI->getLength())) + return; // It is handled later using OpCopyMemorySized. + + std::string FuncName = lowerLLVMIntrinsicName(Intrinsic); + if (Intrinsic->isVolatile()) + FuncName += ".volatile"; + // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* + Function *F = M->getFunction(FuncName); + if (F) { + Intrinsic->setCalledFunction(F); + return; + } + // TODO copy arguments attributes: nocapture writeonly.
+ FunctionCallee FC = + M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType()); + auto IntrinsicID = Intrinsic->getIntrinsicID(); + Intrinsic->setCalledFunction(FC); + + F = dyn_cast(FC.getCallee()); + assert(F && "Callee must be a function"); + + switch (IntrinsicID) { + case Intrinsic::memset: { + auto *MSI = static_cast(Intrinsic); + Argument *Dest = F->getArg(0); + Argument *Val = F->getArg(1); + Argument *Len = F->getArg(2); + Argument *IsVolatile = F->getArg(3); + Dest->setName("dest"); + Val->setName("val"); + Len->setName("len"); + IsVolatile->setName("isvolatile"); + BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); + IRBuilder<> IRB(EntryBB); + auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(), + MSI->isVolatile()); + IRB.CreateRetVoid(); + expandMemSetAsLoop(cast(MemSet)); + MemSet->eraseFromParent(); + break; + } + case Intrinsic::bswap: { + BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F); + IRBuilder<> IRB(EntryBB); + auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(), + F->getArg(0)); + IRB.CreateRet(BSwap); + IntrinsicLowering IL(M->getDataLayout()); + IL.LowerIntrinsicCall(BSwap); + break; + } + default: + break; + } + return; +} + static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) { // Get a separate function - otherwise, we'd have to rework the CFG of the // current one. 
Then simply replace the intrinsic uses with a call to the new @@ -248,8 +312,11 @@ if (!CF || !CF->isIntrinsic()) continue; auto *II = cast(Call); - if (II->getIntrinsicID() == Intrinsic::fshl || - II->getIntrinsicID() == Intrinsic::fshr) + if (II->getIntrinsicID() == Intrinsic::memset || + II->getIntrinsicID() == Intrinsic::bswap) + lowerIntrinsicToFunction(M, II); + else if (II->getIntrinsicID() == Intrinsic::fshl || + II->getIntrinsicID() == Intrinsic::fshr) lowerFunnelShifts(M, II); else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) lowerUMulWithOverflow(M, II); Index: llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -70,7 +70,7 @@ : LLVMTargetMachine(T, computeDataLayout(TT), TT, CPU, FS, Options, getEffectiveRelocModel(RM), getEffectiveCodeModel(CM, CodeModel::Small), OL), - TLOF(std::make_unique()), + TLOF(std::make_unique()), Subtarget(TT, CPU.str(), FS.str(), *this) { initAsmInfo(); setGlobalISel(true); @@ -142,6 +142,8 @@ void SPIRVPassConfig::addIRPasses() { TargetPassConfig::addIRPasses(); + addPass(createSPIRVLowerConstExprLegacyPass()); + addPass(createSPIRVOCLRegularizerPass()); addPass(createSPIRVPrepareFunctionsPass()); } @@ -159,13 +161,13 @@ addPass(createSPIRVPreLegalizerPass()); } -// Use a default legalizer. +// Use the default legalizer. bool SPIRVPassConfig::addLegalizeMachineIR() { addPass(new Legalizer()); return false; } -// Do not add a RegBankSelect pass, as we only ever need virtual registers. +// Do not add the RegBankSelect pass, as we only ever need virtual registers. bool SPIRVPassConfig::addRegBankSelect() { disablePass(&RegBankSelect::ID); return false; @@ -183,6 +185,7 @@ }; } // namespace +// Add the custom SPIRVInstructionSelect from above. 
bool SPIRVPassConfig::addGlobalInstructionSelect() { addPass(new SPIRVInstructionSelect()); return false; Index: llvm/lib/Target/SPIRV/SPIRVUtils.h =================================================================== --- llvm/lib/Target/SPIRV/SPIRVUtils.h +++ llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -84,8 +84,11 @@ // Get type of i-th operand of the metadata node. Type *getMDOperandAsType(const MDNode *N, unsigned I); -// Return a demangled name with arg type info by itaniumDemangle(). -// If the parser fails, return only function name. -std::string mayBeOclOrSpirvBuiltin(StringRef Name); +// If OpenCL or SPIR-V builtin function name is recognized, return a demangled +// name, otherwise return an empty string. +std::string getOclOrSpirvBuiltinDemangledName(StringRef Name); + +// Check if given LLVM type is a special opaque builtin type. +bool isSpecialOpaqueType(const Type *Ty); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H Index: llvm/lib/Target/SPIRV/SPIRVUtils.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -289,7 +289,7 @@ Name == "__translate_sampler_initializer"; } -std::string mayBeOclOrSpirvBuiltin(StringRef Name) { +std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); bool IsNonMangledSPIRV = Name.startswith("__spirv_"); bool IsMangled = Name.startswith("_Z"); @@ -331,4 +331,24 @@ .getAsInteger(10, Len); return Name.substr(Start, Len).str(); } + +static bool isOpenCLBuiltinType(const StructType *SType) { + return SType->isOpaque() && SType->hasName() && + SType->getName().startswith("opencl."); +} + +static bool isSPIRVBuiltinType(const StructType *SType) { + return SType->isOpaque() && SType->hasName() && + SType->getName().startswith("spirv."); +} + +bool isSpecialOpaqueType(const Type *Ty) { + if (auto PType = dyn_cast(Ty)) { + if (!PType->isOpaque()) + Ty = 
PType->getNonOpaquePointerElementType(); + } + if (auto SType = dyn_cast(Ty)) + return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType); + return false; +} } // namespace llvm