Index: llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -134,8 +134,6 @@
 }
 
 void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
-  if (MAI->MBBsToSkip.contains(&MBB))
-    return;
   MCInst LabelInst;
   LabelInst.setOpcode(SPIRV::OpLabel);
   LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
@@ -143,6 +141,8 @@
 }
 
 void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
+  assert(!MBB.empty() && "MBB is empty!");
+
   // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so
   // OpLabel should be output after them.
   if (MBB.getNumber() == MF->front().getNumber()) {
Index: llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
+++ llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h
@@ -136,9 +136,6 @@
   // The set contains machine instructions which are necessary
   // for correct MIR but will not be emitted in function bodies.
   DenseSet<MachineInstr *> InstrsToDelete;
-  // The set contains machine basic blocks which are necessary
-  // for correct MIR but will not be emitted.
-  DenseSet<MachineBasicBlock *> MBBsToSkip;
   // The table contains global aliases of local registers for each machine
   // function. The aliases are used to substitute local registers during
   // code emission.
Index: llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -386,47 +386,6 @@
   }
 }
 
-// Find OpIEqual and OpBranchConditional instructions originating from
-// OpSwitches, mark them skipped for emission. Also mark MBB skipped if it
-// contains only these instructions.
-static void processSwitches(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
-                            MachineModuleInfo *MMI) {
-  DenseSet<Register> SwitchRegs;
-  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
-    MachineFunction *MF = MMI->getMachineFunction(*F);
-    if (!MF)
-      continue;
-    for (MachineBasicBlock &MBB : *MF)
-      for (MachineInstr &MI : MBB) {
-        if (MAI.getSkipEmission(&MI))
-          continue;
-        if (MI.getOpcode() == SPIRV::OpSwitch) {
-          assert(MI.getOperand(0).isReg());
-          SwitchRegs.insert(MI.getOperand(0).getReg());
-        }
-        if (MI.getOpcode() == SPIRV::OpISubS &&
-            SwitchRegs.contains(MI.getOperand(2).getReg())) {
-          SwitchRegs.insert(MI.getOperand(0).getReg());
-          MAI.setSkipEmission(&MI);
-        }
-        if ((MI.getOpcode() != SPIRV::OpIEqual &&
-             MI.getOpcode() != SPIRV::OpULessThanEqual) ||
-            !MI.getOperand(2).isReg() ||
-            !SwitchRegs.contains(MI.getOperand(2).getReg()))
-          continue;
-        Register CmpReg = MI.getOperand(0).getReg();
-        MachineInstr *CBr = MI.getNextNode();
-        assert(CBr && CBr->getOpcode() == SPIRV::OpBranchConditional &&
-               CBr->getOperand(0).isReg() &&
-               CBr->getOperand(0).getReg() == CmpReg);
-        MAI.setSkipEmission(&MI);
-        MAI.setSkipEmission(CBr);
-        if (&MBB.front() == &MI && &MBB.back() == CBr)
-          MAI.MBBsToSkip.insert(&MBB);
-      }
-  }
-}
-
 // RequirementHandler implementations.
 void SPIRV::RequirementHandler::getAndAddRequirements(
     SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
@@ -1020,8 +979,6 @@
 
   collectReqs(M, MAI, MMI, *ST);
 
-  processSwitches(M, MAI, MMI);
-
   // Process type/const/global var/func decl instructions, number their
   // destination registers from 0 to N, collect Extensions and Capabilities.
   processDefInstrs(M);
Index: llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
===================================================================
--- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -411,19 +411,23 @@
   //
   // Sometimes (in case of range-compare switches), additional G_SUBs
   // instructions are inserted before G_ICMPs. Those need to be additionally
-  // processed and require type assignment.
+  // processed.
   //
   // This function modifies spv_switch call's operands to include destination
   // MBBs (default and for each constant value).
-  // Note that this function does not remove G_ICMP + G_BRCOND + G_BR sequences,
-  // but they are marked by ModuleAnalysis as skipped and as a result AsmPrinter
-  // does not output them.
+  //
+  // At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
+  // G_BR sequences.
 
   MachineRegisterInfo &MRI = MF.getRegInfo();
 
-  // Collect all MIs relevant to switches across all MBBs in MF.
+  // Collect spv_switches and G_ICMPs across all MBBs in MF.
   std::vector<MachineInstr *> RelevantInsts;
 
+  // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
+  // After updating spv_switches, the instructions can be removed.
+  std::vector<MachineInstr *> PostUpdateArtifacts;
+
   // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
   // spv_switch use these registers.
   DenseSet<Register> CompareRegs;
@@ -443,23 +447,21 @@
         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
         Register Dst = MI.getOperand(0).getReg();
         CompareRegs.insert(Dst);
-        SPIRVType *Ty = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
-        insertAssignInstr(Dst, nullptr, Ty, GR, MIB, MRI);
+        PostUpdateArtifacts.push_back(&MI);
       }
 
       // G_ICMPs relating to switches.
       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
           CompareRegs.contains(MI.getOperand(2).getReg())) {
         Register Dst = MI.getOperand(0).getReg();
-        // Set type info for destination register of switch's ICMP instruction.
-        if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
-          MIB.setInsertPt(*MI.getParent(), MI);
-          Type *LLVMTy = IntegerType::get(MF.getFunction().getContext(), 1);
-          SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, MIB);
-          MRI.setRegClass(Dst, &SPIRV::IDRegClass);
-          GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
-        }
         RelevantInsts.push_back(&MI);
+        PostUpdateArtifacts.push_back(&MI);
+        MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
+        assert(CBr->getOpcode() == SPIRV::G_BRCOND);
+        PostUpdateArtifacts.push_back(CBr);
+        MachineInstr *Br = CBr->getNextNode();
+        assert(Br && Br->getOpcode() == SPIRV::G_BR);
+        PostUpdateArtifacts.push_back(Br);
       }
     }
   }
@@ -503,6 +505,11 @@
       // Map switch case Value to target MBB.
       ValuesToMBBs[Value] = MBB;
 
+      // Add target MBB as successor to the switch's MBB. The original G_BRCOND
+      // may already have made it a successor, so avoid duplicate CFG edges.
+      if (!Switch->getParent()->isSuccessor(MBB))
+        Switch->getParent()->addSuccessor(MBB);
+
       // The next MI is always G_BR to either the next case or the default.
       MachineInstr *NextMI = CBr->getNextNode();
       assert(NextMI->getOpcode() == SPIRV::G_BR &&
@@ -512,8 +519,12 @@
       // register.
       if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
           (NextMBB->front().getOperand(2).isReg() &&
-           NextMBB->front().getOperand(2).getReg() != CompareReg))
+           NextMBB->front().getOperand(2).getReg() != CompareReg)) {
+        // Set default MBB and add it as successor to the switch's MBB (once).
         DefaultMBB = NextMBB;
+        if (!Switch->getParent()->isSuccessor(DefaultMBB))
+          Switch->getParent()->addSuccessor(DefaultMBB);
+      }
     }
 
     // Modify considered spv_switch operands using collected Values and
@@ -540,6 +551,24 @@
       Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
     }
   }
+
+  for (MachineInstr *MI : PostUpdateArtifacts) {
+    MachineBasicBlock *ParentMBB = MI->getParent();
+    MI->eraseFromParent();
+    // If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
+    // can be safely assumed, there are no breaks or phis directing into this
+    // MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
+    // erased top-down.
+    if (ParentMBB->empty()) {
+      while (!ParentMBB->pred_empty())
+        (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);
+
+      while (!ParentMBB->succ_empty())
+        ParentMBB->removeSuccessor(ParentMBB->succ_begin());
+
+      ParentMBB->eraseFromParent();
+    }
+  }
 }
 
 bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
Index: llvm/test/CodeGen/SPIRV/branching/OpSwitchBranches.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/SPIRV/branching/OpSwitchBranches.ll
@@ -0,0 +1,41 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+
+define i32 @test_switch_branches(i32 %a) {
+entry:
+  %alloc = alloca i32
+; CHECK-SPIRV: OpSwitch %[[#]] %[[#DEFAULT:]] 1 %[[#CASE1:]] 2 %[[#CASE2:]] 3 %[[#CASE3:]]
+  switch i32 %a, label %default [
+    i32 1, label %case1
+    i32 2, label %case2
+    i32 3, label %case3
+  ]
+
+; CHECK-SPIRV: %[[#CASE1]] = OpLabel
+case1:
+  store i32 1, ptr %alloc
+; CHECK-SPIRV: OpBranch %[[#END:]]
+  br label %end
+
+; CHECK-SPIRV: %[[#CASE2]] = OpLabel
+case2:
+  store i32 2, ptr %alloc
+; CHECK-SPIRV: OpBranch %[[#END]]
+  br label %end
+
+; CHECK-SPIRV: %[[#CASE3]] = OpLabel
+case3:
+  store i32 3, ptr %alloc
+; CHECK-SPIRV: OpBranch %[[#END]]
+  br label %end
+
+; CHECK-SPIRV: %[[#DEFAULT]] = OpLabel
+default:
+  store i32 0, ptr %alloc
+; CHECK-SPIRV: OpBranch %[[#END]]
+  br label %end
+
+; CHECK-SPIRV: %[[#END]] = OpLabel
+end:
+  %result = load i32, ptr %alloc
+  ret i32 %result
+}