diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -303,7 +303,6 @@
   // Generate a SPIR-V type for the function.
   auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
-  MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
   if (F.isDeclaration())
     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
@@ -313,22 +312,28 @@
 
   // Build the OpTypeFunction declaring it.
   uint32_t FuncControl = getFunctionControl(F);
-  MIRBuilder.buildInstr(SPIRV::OpFunction)
-      .addDef(FuncVReg)
-      .addUse(GR->getSPIRVTypeID(RetTy))
-      .addImm(FuncControl)
-      .addUse(GR->getSPIRVTypeID(FuncTy));
+  auto FuncMIB = MIRBuilder.buildInstr(SPIRV::OpFunction)
+                     .addDef(FuncVReg)
+                     .addUse(GR->getSPIRVTypeID(RetTy))
+                     .addImm(FuncControl)
+                     .addUse(GR->getSPIRVTypeID(FuncTy));
+  const auto &ST = GR->CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*FuncMIB.getInstr(), *ST.getInstrInfo(),
+                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
 
   // Add OpFunctionParameters.
   int i = 0;
   for (const auto &Arg : F.args()) {
     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
-    MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
-    MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
-        .addDef(VRegs[i][0])
-        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
+    auto ParamMIB = MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
+                        .addDef(VRegs[i][0])
+                        .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
     if (F.isDeclaration())
       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
+    // Reuse the subtarget fetched above; it does not change per argument.
+    constrainSelectedInstRegOperands(*ParamMIB.getInstr(), *ST.getInstrInfo(),
+                                     *ST.getRegisterInfo(),
+                                     *ST.getRegBankInfo());
     i++;
   }
   // Name the function.
@@ -412,7 +417,6 @@
     if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
       continue; // Don't handle zero sized types.
     Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
-    MRI->setRegClass(Reg, &SPIRV::IDRegClass);
     ToInsert.push_back({Reg});
     VRegArgs.push_back(ToInsert.back());
   }
@@ -423,7 +427,8 @@
 
   // Make sure there's a valid return reg, even for functions returning void.
   if (!ResVReg.isValid())
-    ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+    ResVReg =
+        MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32));
   SPIRVType *RetType =
       GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -733,6 +733,10 @@
     return nullptr;
   TypesInProcessing.insert(Ty);
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
+  const auto &ST = CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*const_cast<MachineInstr *>(SpirvType),
+                                   *ST.getInstrInfo(), *ST.getRegisterInfo(),
+                                   *ST.getRegBankInfo());
   TypesInProcessing.erase(Ty);
   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
   SPIRVToLLVMType[SpirvType] = Ty;
@@ -978,6 +982,10 @@
   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
   SPIRVToLLVMType[SpirvType] = LLVMTy;
   DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
+  const auto &ST = CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*const_cast<MachineInstr *>(SpirvType),
+                                   *ST.getInstrInfo(), *ST.getRegisterInfo(),
+                                   *ST.getRegBankInfo());
   return SpirvType;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -85,9 +85,6 @@
     Register Reg = MI->getOperand(2).getReg();
     if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
       Reg = RegsAlreadyAddedToDT[MI];
-    auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
-    if (!MRI.getRegClassOrNull(Reg) && RC)
-      MRI.setRegClass(Reg, RC);
     MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
     MI->eraseFromParent();
   }
@@ -182,8 +179,6 @@
       }
       if (SpirvTy)
         GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
-      if (!MRI.getRegClassOrNull(Reg))
-        MRI.setRegClass(Reg, &SPIRV::IDRegClass);
     }
   }
   return SpirvTy;
@@ -204,12 +199,6 @@
                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                       : Def->getParent()->end()));
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
-    MRI.setRegClass(NewReg, RC);
-  } else {
-    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
-    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
-  }
   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
@@ -218,12 +207,15 @@
   // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to keep
   // the flags after instruction selection.
   const uint32_t Flags = Def->getFlags();
-  MIB.buildInstr(SPIRV::ASSIGN_TYPE)
-      .addDef(Reg)
-      .addUse(NewReg)
-      .addUse(GR->getSPIRVTypeID(SpirvTy))
-      .setMIFlags(Flags);
+  auto B = MIB.buildInstr(SPIRV::ASSIGN_TYPE)
+               .addDef(Reg)
+               .addUse(NewReg)
+               .addUse(GR->getSPIRVTypeID(SpirvTy))
+               .setMIFlags(Flags);
   Def->getOperand(0).setReg(NewReg);
+  const auto &ST = GR->CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(),
+                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
   return NewReg;
 }
 } // namespace llvm
@@ -318,18 +310,14 @@
           SPIRV::OpTypeFloat;
   IsFloat |= IsVectorFloat;
   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
-  auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
   if (MRI.getType(ValReg).isPointer()) {
     NewT = LLT::pointer(0, 32);
     GetIdOp = SPIRV::GET_pID;
-    DstClass = &SPIRV::pIDRegClass;
   } else if (MRI.getType(ValReg).isVector()) {
     NewT = LLT::fixed_vector(2, NewT);
     GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
-    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
   }
   Register IdReg = MRI.createGenericVirtualRegister(NewT);
-  MRI.setRegClass(IdReg, DstClass);
   return {IdReg, GetIdOp};
 }
 
@@ -342,15 +330,21 @@
   auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
   AssignTypeInst.getOperand(1).setReg(NewReg);
   MI.getOperand(0).setReg(NewReg);
-  MIB.setInsertPt(*MI.getParent(),
-                  (MI.getNextNode() ? MI.getNextNode()->getIterator()
-                                    : MI.getParent()->end()));
+  // Insert the GET_*ID instructions before MI so that every new id vreg is
+  // defined before MI reads it.
+  MIB.setInsertPt(*MI.getParent(), MI.getIterator());
+  const auto &ST = GR->CurMF->getSubtarget();
   for (auto &Op : MI.operands()) {
     if (!Op.isReg() || Op.isDef())
       continue;
     auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
-    MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
+    auto B = MIB.buildInstr(IdOpInfo.second)
+                 .addDef(IdOpInfo.first)
+                 .addUse(Op.getReg());
     Op.setReg(IdOpInfo.first);
+    constrainSelectedInstRegOperands(*B.getInstr(), *ST.getInstrInfo(),
+                                     *ST.getRegisterInfo(),
+                                     *ST.getRegBankInfo());
   }
 }
 
@@ -379,8 +373,6 @@
       if (!isTypeFoldingSupported(Opcode))
         continue;
       Register DstReg = MI.getOperand(0).getReg();
-      if (MRI.getType(DstReg).isVector())
-        MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
       // Don't need to reset type of register holding constant and used in
       // G_ADDRSPACE_CAST, since it braaks legalizer.
       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {