[SPIR-V] Give newly created virtual registers a register class

This patch touches four files under llvm/lib/Target/SPIRV/:

  * SPIRVBuiltins.cpp
      - After several createGenericVirtualRegister() calls, set the new
        register's class to SPIRV::IDRegClass. In some places this happens
        only when the register has no class yet, checked via
        getRegClassOrNull().
      - Rename the local `defaultReg` to `DefaultReg` (LLVM naming style).
      - Move the body of the `case SPIRV::OpBuildNDRange:` branch into a new
        static helper buildNDRange(Call, MIRBuilder, GR). Compared with the
        old inline code, the helper also:
          sets IDRegClass on GWSPtr when it has no class;
          caches MIRBuilder.getMF() in a local `MF`;
          emits SPIRV::OpBuildNDRange directly instead of the switch's
          `Opcode`.
      - In the loop that builds spv_gep intrinsics, replace
        MIRBuilder.getMRI()-> with MRI->.
      - For a non-void return type, set IDRegClass on ReturnRegister when it
        has no class.
  * SPIRVCallLowering.cpp
      - Hoist `MachineRegisterInfo *MRI = MIRBuilder.getMRI();`.
      - Set IDRegClass on call-argument vregs that have no class.
      - Set IDRegClass on the 32-bit scalar vregs created for each
        non-zero-sized callee argument.
  * SPIRVGlobalRegistry.cpp
      - Set IDRegClass right after each createGenericVirtualRegister() in
        the hunks shown. These cover the integer, FP, composite and
        null-pointer constant paths, plus the register recorded with
        DT.add(UV, ...).
  * SPIRVPreLegalizer.cpp
      - Before MRI.replaceRegWith(), copy the replaced register's class to
        the replacement register if the replacement has none.
      - When creating NewReg for the type-assignment instruction, give both
        NewReg and Reg IDRegClass if Reg has no class.
      - Remove the previous unconditional
        MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass).

NOTE(review):
  * The text below has had its line breaks collapsed (several diff lines per
    physical line). It will not apply with patch or git apply as-is;
    regenerate it from the original source before applying.
  * Some context lines seem to have lost template arguments, e.g.
    `SmallVector ArgVRegs;` and `ArrayRef{Reg}`. They were presumably
    stripped as HTML tags; confirm against the original patch.
  * The spv_gep-loop hunk now uses `MRI->`. This assumes an `MRI` local is
    declared earlier in that function, which the hunk context does not
    show; confirm.
  * Behavior change in SPIRVPreLegalizer.cpp: a Reg that already has a
    class is no longer reset to ANYIDRegClass. Confirm this is intended.

Index: llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -291,6 +291,7 @@ Register ResultRegister = MIRBuilder.getMRI()->createGenericVirtualRegister(Type); + MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF()); return std::make_tuple(ResultRegister, BoolType); } @@ -591,6 +592,8 @@ MRI->setType(Expected, DesiredLLT); Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT) : Call->ReturnRegister; + if (!MRI->getRegClassOrNull(Tmp)) + MRI->setRegClass(Tmp, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF()); SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder); @@ -936,16 +939,17 @@ // If it's out of range (max dimension is 3), we can just return the constant // default value (0 or 1 depending on which query function). if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) { - Register defaultReg = Call->ReturnRegister; + Register DefaultReg = Call->ReturnRegister; if (PointerSize != ResultWidth) { - defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize)); - GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg, + DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize)); + MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass); + GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg, MIRBuilder.getMF()); - ToTruncate = defaultReg; + ToTruncate = DefaultReg; } auto NewRegister = GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType); - MIRBuilder.buildCopy(defaultReg, NewRegister); + MIRBuilder.buildCopy(DefaultReg, NewRegister); } else { // If it could be in range, we need to load from the given builtin. 
auto Vec3Ty = GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder); @@ -956,6 +960,7 @@ Register Extracted = Call->ReturnRegister; if (!IsConstantIndex || PointerSize != ResultWidth) { Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize)); + MRI->setRegClass(Extracted, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF()); } // Use Intrinsic::spv_extractelt so dynamic vs static extraction is @@ -974,6 +979,7 @@ Register CompareRegister = MRI->createGenericVirtualRegister(LLT::scalar(1)); + MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF()); // Use G_ICMP to check if idxVReg < 3. @@ -990,6 +996,7 @@ if (PointerSize != ResultWidth) { SelectionResult = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize)); + MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult, MIRBuilder.getMF()); } @@ -1125,6 +1132,7 @@ if (NumExpectedRetComponents != NumActualRetComponents) { QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister( LLT::fixed_vector(NumActualRetComponents, 32)); + MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass); SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder); QueryResultType = GR->getOrCreateSPIRVVectorType( IntTy, NumActualRetComponents, MIRBuilder); @@ -1274,6 +1282,7 @@ } LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType)); Register TempRegister = MRI->createGenericVirtualRegister(LLType); + MRI->setRegClass(TempRegister, &SPIRV::IDRegClass); GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF()); MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod) @@ -1431,6 +1440,68 @@ } } +static bool buildNDRange(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + SPIRVType *PtrType = 
GR->getSPIRVTypeForVReg(Call->Arguments[0]); + assert(PtrType->getOpcode() == SPIRV::OpTypePointer && + PtrType->getOperand(2).isReg()); + Register TypeReg = PtrType->getOperand(2).getReg(); + SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg); + MachineFunction &MF = MIRBuilder.getMF(); + Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF); + // Skip the first arg, it's the destination pointer. OpBuildNDRange takes + // three other arguments, so pass zero constant on absence. + unsigned NumArgs = Call->Arguments.size(); + assert(NumArgs >= 2); + Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2]; + Register LocalWorkSize = + NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3]; + Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1]; + if (NumArgs < 4) { + Register Const; + SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize); + if (SpvTy->getOpcode() == SPIRV::OpTypePointer) { + MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize); + assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) && + DefInstr->getOperand(3).isReg()); + Register GWSPtr = DefInstr->getOperand(3).getReg(); + if (!MRI->getRegClassOrNull(GWSPtr)) + MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass); + // TODO: Maybe simplify generation of the type of the fields. + unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2; + unsigned BitWidth = GR->getPointerSize() == 64 ? 
64 : 32; + Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth); + Type *FieldTy = ArrayType::get(BaseTy, Size); + SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder); + GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass); + GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF); + MIRBuilder.buildInstr(SPIRV::OpLoad) + .addDef(GlobalWorkSize) + .addUse(GR->getSPIRVTypeID(SpvFieldTy)) + .addUse(GWSPtr); + Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy); + } else { + Const = GR->buildConstantInt(0, MIRBuilder, SpvTy); + } + if (!LocalWorkSize.isValid()) + LocalWorkSize = Const; + if (!GlobalWorkOffset.isValid()) + GlobalWorkOffset = Const; + } + MIRBuilder.buildInstr(SPIRV::OpBuildNDRange) + .addDef(TmpReg) + .addUse(TypeReg) + .addUse(GlobalWorkSize) + .addUse(LocalWorkSize) + .addUse(GlobalWorkOffset); + return MIRBuilder.buildInstr(SPIRV::OpStore) + .addUse(Call->Arguments[0]) + .addUse(TmpReg); +} + static MachineInstr *getBlockStructInstr(Register ParamReg, MachineRegisterInfo *MRI) { // We expect the following sequence of instructions: @@ -1538,9 +1609,8 @@ const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType( Int32Ty, MIRBuilder, SPIRV::StorageClass::Function); for (unsigned I = 0; I < LocalSizeNum; ++I) { - Register Reg = - MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); - MIRBuilder.getMRI()->setType(Reg, LLType); + Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + MRI->setType(Reg, LLType); GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF()); auto GEPInst = MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef{Reg}, true); @@ -1625,64 +1695,8 @@ .addUse(Call->Arguments[0]) .addUse(Call->Arguments[1]) .addUse(Call->Arguments[2]); - case SPIRV::OpBuildNDRange: { - MachineRegisterInfo *MRI = MIRBuilder.getMRI(); - SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]); - assert(PtrType->getOpcode() == SPIRV::OpTypePointer && 
- PtrType->getOperand(2).isReg()); - Register TypeReg = PtrType->getOperand(2).getReg(); - SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg); - Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); - GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF()); - // Skip the first arg, it's the destination pointer. OpBuildNDRange takes - // three other arguments, so pass zero constant on absence. - unsigned NumArgs = Call->Arguments.size(); - assert(NumArgs >= 2); - Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2]; - Register LocalWorkSize = - NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3]; - Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1]; - if (NumArgs < 4) { - Register Const; - SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize); - if (SpvTy->getOpcode() == SPIRV::OpTypePointer) { - MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize); - assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) && - DefInstr->getOperand(3).isReg()); - Register GWSPtr = DefInstr->getOperand(3).getReg(); - // TODO: Maybe simplify generation of the type of the fields. - unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2; - unsigned BitWidth = GR->getPointerSize() == 64 ? 
64 : 32; - Type *BaseTy = IntegerType::get( - MIRBuilder.getMF().getFunction().getContext(), BitWidth); - Type *FieldTy = ArrayType::get(BaseTy, Size); - SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder); - GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass); - GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, - MIRBuilder.getMF()); - MIRBuilder.buildInstr(SPIRV::OpLoad) - .addDef(GlobalWorkSize) - .addUse(GR->getSPIRVTypeID(SpvFieldTy)) - .addUse(GWSPtr); - Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy); - } else { - Const = GR->buildConstantInt(0, MIRBuilder, SpvTy); - } - if (!LocalWorkSize.isValid()) - LocalWorkSize = Const; - if (!GlobalWorkOffset.isValid()) - GlobalWorkOffset = Const; - } - MIRBuilder.buildInstr(Opcode) - .addDef(TmpReg) - .addUse(TypeReg) - .addUse(GlobalWorkSize) - .addUse(LocalWorkSize) - .addUse(GlobalWorkOffset); - return MIRBuilder.buildInstr(SPIRV::OpStore) - .addUse(Call->Arguments[0]) - .addUse(TmpReg); - } + case SPIRV::OpBuildNDRange: + return buildNDRange(Call, MIRBuilder, GR); case SPIRV::OpEnqueueKernel: return buildEnqueueKernel(Call, MIRBuilder, GR); default: @@ -1846,6 +1860,8 @@ SPIRVType *ReturnType = nullptr; if (OrigRetTy && !OrigRetTy->isVoidTy()) { ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder); + if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister)) + MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass); } else if (OrigRetTy && OrigRetTy->isVoidTy()) { ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass); MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32)); Index: llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -374,6 +374,7 @@ FTy = getOriginalFunctionType(*CF); } + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); Register ResVReg = 
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; std::string FuncName = Info.Callee.getGlobal()->getName().str(); @@ -388,6 +389,8 @@ SmallVector ArgVRegs; for (auto Arg : Info.OrigArgs) { assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs"); + if (!MRI->getRegClassOrNull(Arg.Regs[0])) + MRI->setRegClass(Arg.Regs[0], &SPIRV::IDRegClass); ArgVRegs.push_back(Arg.Regs[0]); SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder); GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF()); @@ -410,8 +413,9 @@ for (const Argument &Arg : CF->args()) { if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) continue; // Don't handle zero sized types. - ToInsert.push_back( - {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))}); + Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MRI->setRegClass(Reg, &SPIRV::IDRegClass); + ToInsert.push_back({Reg}); VRegArgs.push_back(ToInsert.back()); } // TODO: Reuse FunctionLoweringInfo Index: llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp =================================================================== --- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -133,6 +133,7 @@ unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; LLT LLTy = LLT::scalar(32); Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); + CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); if (MIRBuilder) assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); else @@ -192,6 +193,7 @@ unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; LLT LLTy = LLT::scalar(EmitIR ? 
BitWidth : 32); Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); + MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); DT.add(ConstInt, &MIRBuilder.getMF(), Res); @@ -237,6 +239,7 @@ if (!Res.isValid()) { unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); DT.add(ConstFP, &MF, Res); MIRBuilder.buildFConstant(Res, *ConstFP); @@ -262,6 +265,7 @@ LLT LLTy = LLT::scalar(32); Register SpvVecConst = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); + CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); DT.add(CA, CurMF, SpvVecConst); MachineInstrBuilder MIB; @@ -333,6 +337,7 @@ LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32); Register SpvVecConst = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); + CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); DT.add(CA, CurMF, SpvVecConst); if (EmitIR) { @@ -401,6 +406,7 @@ if (!Res.isValid()) { LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); + CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); assignSPIRVTypeToVReg(SpvType, Res, *CurMF); MIRBuilder.buildInstr(SPIRV::OpConstantNull) .addDef(Res) @@ -1081,6 +1087,7 @@ return Res; LLT LLTy = LLT::scalar(32); Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); + CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); assignSPIRVTypeToVReg(SpvType, Res, *CurMF); DT.add(UV, CurMF, Res); Index: llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp =================================================================== --- 
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -85,6 +85,9 @@ Register Reg = MI->getOperand(2).getReg(); if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end()) Reg = RegsAlreadyAddedToDT[MI]; + auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg()); + if (!MRI.getRegClassOrNull(Reg) && RC) + MRI.setRegClass(Reg, RC); MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg); MI->eraseFromParent(); } @@ -201,8 +204,12 @@ (Def->getNextNode() ? Def->getNextNode()->getIterator() : Def->getParent()->end())); Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); - if (auto *RC = MRI.getRegClassOrNull(Reg)) + if (auto *RC = MRI.getRegClassOrNull(Reg)) { MRI.setRegClass(NewReg, RC); + } else { + MRI.setRegClass(NewReg, &SPIRV::IDRegClass); + MRI.setRegClass(Reg, &SPIRV::IDRegClass); + } SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB); GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF()); // This is to make it convenient for Legalizer to get the SPIRVType @@ -217,7 +224,6 @@ .addUse(GR->getSPIRVTypeID(SpirvTy)) .setMIFlags(Flags); Def->getOperand(0).setReg(NewReg); - MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass); return NewReg; } } // namespace llvm