Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -115,6 +115,139 @@
   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
 }
 
+/// Returns the value of the kernel argument attribute \p AttributeName (for
+/// example "kernel_arg_access_qual") of argument \p ArgIdx of
+/// \p KernelFunction, or nullptr if it is not found or is not an MDString.
+static MDString *getKernelArgAttribute(const Function &KernelFunction,
+                                       unsigned ArgIdx,
+                                       const StringRef AttributeName) {
+  assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
+         "Kernel attributes are attached/belong only to kernel functions");
+
+  // Lookup the argument attribute in metadata attached to the kernel function.
+  MDNode *Node = KernelFunction.getMetadata(AttributeName);
+  if (Node && ArgIdx < Node->getNumOperands())
+    return dyn_cast_or_null<MDString>(Node->getOperand(ArgIdx).get());
+
+  // Sometimes metadata containing kernel attributes is not attached to the
+  // function, but can be found in the named module-level metadata instead.
+  // For example:
+  //   !opencl.kernels = !{!0}
+  //   !0 = !{void ()* @someKernelFunction, !1, ...}
+  //   !1 = !{!"kernel_arg_addr_space", ...}
+  // In this case the actual index of searched argument attribute is ArgIdx + 1,
+  // since the first metadata node operand is occupied by attribute name
+  // ("kernel_arg_addr_space" in the example above).
+  unsigned MDArgIdx = ArgIdx + 1;
+  NamedMDNode *OpenCLKernelsMD =
+      KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
+  if (!OpenCLKernelsMD)
+    return nullptr;
+
+  // Each operand of !opencl.kernels is a node listing kernel functions, each
+  // followed by the MDNodes of its attributes. Modules with several kernels
+  // typically use one such node per kernel, so scan all of them, but search
+  // only MDNodes "belonging" to the currently lowered kernel function.
+  for (const MDNode *KernelToMDNodeList : OpenCLKernelsMD->operands()) {
+    if (!KernelToMDNodeList)
+      continue;
+    bool FoundLoweredKernelFunction = false;
+    for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+      const Metadata *MD = Operand.get();
+      if (const auto *MaybeValue = dyn_cast_or_null<ValueAsMetadata>(MD)) {
+        // A value following the lowered kernel starts the next kernel's list.
+        if (FoundLoweredKernelFunction)
+          return nullptr;
+        const auto *MaybeFunction =
+            dyn_cast_or_null<Function>(MaybeValue->getValue());
+        FoundLoweredKernelFunction =
+            MaybeFunction &&
+            MaybeFunction->getName() == KernelFunction.getName();
+        continue;
+      }
+
+      const auto *MaybeNode = dyn_cast_or_null<MDNode>(MD);
+      // MDArgIdx >= 1, so this also guarantees the name operand 0 exists.
+      if (!FoundLoweredKernelFunction || !MaybeNode ||
+          MDArgIdx >= MaybeNode->getNumOperands())
+        continue;
+      const auto *MaybeName =
+          dyn_cast_or_null<MDString>(MaybeNode->getOperand(0).get());
+      if (MaybeName && MaybeName->getString() == AttributeName)
+        return dyn_cast_or_null<MDString>(
+            MaybeNode->getOperand(MDArgIdx).get());
+    }
+    // The lowered kernel's list ended without the requested attribute.
+    if (FoundLoweredKernelFunction)
+      return nullptr;
+  }
+  return nullptr;
+}
+
+/// Returns the access qualifier of argument \p ArgIdx of \p F as given by its
+/// "kernel_arg_access_qual" attribute. Defaults to ReadWrite for non-kernel
+/// functions and when the attribute is missing or unrecognized.
+static SPIRV::AccessQualifier::AccessQualifier
+getArgAccessQual(const Function &F, unsigned ArgIdx) {
+  if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
+    return SPIRV::AccessQualifier::ReadWrite;
+
+  MDString *ArgAttribute =
+      getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+  if (!ArgAttribute)
+    return SPIRV::AccessQualifier::ReadWrite;
+
+  if (ArgAttribute->getString().compare("read_only") == 0)
+    return SPIRV::AccessQualifier::ReadOnly;
+  if (ArgAttribute->getString().compare("write_only") == 0)
+    return SPIRV::AccessQualifier::WriteOnly;
+  return SPIRV::AccessQualifier::ReadWrite;
+}
+
+/// Returns the decorations implied by the "kernel_arg_type_qual" attribute of
+/// kernel argument \p ArgIdx. The attribute is expected to hold a
+/// space-separated list of qualifiers (e.g. "const volatile"), so every
+/// qualifier is checked, not only the whole string.
+static std::vector<SPIRV::Decoration::Decoration>
+getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
+  MDString *ArgAttribute =
+      getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+  if (!ArgAttribute)
+    return {};
+
+  StringRef Quals = ArgAttribute->getString();
+  while (!Quals.empty()) {
+    std::pair<StringRef, StringRef> QualAndRest = Quals.split(' ');
+    if (QualAndRest.first == "volatile")
+      return {SPIRV::Decoration::Volatile};
+    Quals = QualAndRest.second;
+  }
+  return {};
+}
+
+/// Returns the type to use for argument \p ArgIdx of \p F. For kernel
+/// arguments that are not special opaque types and whose "kernel_arg_type"
+/// attribute ends with "_t", this is the named struct "opencl.<type name>",
+/// which is created if it does not exist yet.
+static Type *getArgType(const Function &F, unsigned ArgIdx) {
+  Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
+      isSpecialOpaqueType(OriginalArgType))
+    return OriginalArgType;
+
+  MDString *MDKernelArgType =
+      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+  if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
+    return OriginalArgType;
+
+  std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
+  Type *ExistingOpaqueType =
+      StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
+  return ExistingOpaqueType
+             ? ExistingOpaqueType
+             : StructType::create(F.getContext(), KernelArgTypeStr);
+}
+
 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                              const Function &F,
                                              ArrayRef<ArrayRef<Register>> VRegs,
@@ -132,18 +265,10 @@
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      Type *ArgTy = FTy->getParamType(i);
-      SPIRV::AccessQualifier::AccessQualifier AQ =
-          SPIRV::AccessQualifier::ReadWrite;
-      MDNode *Node = F.getMetadata("kernel_arg_access_qual");
-      if (Node && i < Node->getNumOperands()) {
-        StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
-        if (AQString.compare("read_only") == 0)
-          AQ = SPIRV::AccessQualifier::ReadOnly;
-        else if (AQString.compare("write_only") == 0)
-          AQ = SPIRV::AccessQualifier::WriteOnly;
-      }
-      auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
+      SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
+          getArgAccessQual(F, i);
+      auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
+                                           MIRBuilder, ArgAccessQual);
       ArgTypeVRegs.push_back(SpirvTy);
 
       if (Arg.hasName())
@@ -178,14 +303,15 @@
         buildOpDecorate(VRegs[i][0], MIRBuilder,
                         SPIRV::Decoration::FuncParamAttr, {Attr});
       }
-      Node = F.getMetadata("kernel_arg_type_qual");
-      if (Node && i < Node->getNumOperands()) {
-        StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
-        if (TypeQual.compare("volatile") == 0)
-          buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
-                          {});
+
+      if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
+        std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
+            getKernelArgTypeQual(F, i);
+        for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
+          buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
       }
-      Node = F.getMetadata("spirv.ParameterDecorations");
+
+      MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
       if (Node && i < Node->getNumOperands() &&
           isa<MDNode>(Node->getOperand(i))) {
         MDNode *MD = cast<MDNode>(Node->getOperand(i));