diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
@@ -307,19 +307,23 @@
   void emitSwitchCase(SwitchCG::CaseBlock &CB, MachineBasicBlock *SwitchBB,
                        MachineIRBuilder &MIB);
 
-  bool lowerJumpTableWorkItem(SwitchCG::SwitchWorkListItem W,
-                              MachineBasicBlock *SwitchMBB,
-                              MachineBasicBlock *CurMBB,
-                              MachineBasicBlock *DefaultMBB,
-                              MachineIRBuilder &MIB,
-                              MachineFunction::iterator BBI,
-                              BranchProbability UnhandledProbs,
-                              SwitchCG::CaseClusterIt I,
-                              MachineBasicBlock *Fallthrough,
-                              bool FallthroughUnreachable);
-
-  bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I,
-                                Value *Cond,
+  /// Generate code for the BitTest header block, which precedes each sequence
+  /// of BitTestCases.
+  void emitBitTestHeader(SwitchCG::BitTestBlock &BTB,
+                         MachineBasicBlock *SwitchMBB);
+  /// Generate code to produce one "bit test" for a given BitTestCase \p B.
+  void emitBitTestCase(SwitchCG::BitTestBlock &BB, MachineBasicBlock *NextMBB,
+                       BranchProbability BranchProbToNext, Register Reg,
+                       SwitchCG::BitTestCase &B, MachineBasicBlock *SwitchBB);
+
+  bool lowerJumpTableWorkItem(
+      SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
+      MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
+      MachineIRBuilder &MIB, MachineFunction::iterator BBI,
+      BranchProbability UnhandledProbs, SwitchCG::CaseClusterIt I,
+      MachineBasicBlock *Fallthrough, bool FallthroughUnreachable);
+
+  bool lowerSwitchRangeWorkItem(SwitchCG::CaseClusterIt I, Value *Cond,
                                 MachineBasicBlock *Fallthrough,
                                 bool FallthroughUnreachable,
                                 BranchProbability UnhandledProbs,
@@ -327,6 +331,14 @@
                                 MachineIRBuilder &MIB,
                                 MachineBasicBlock *SwitchMBB);
 
+  bool lowerBitTestWorkItem(
+      SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
+      MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
+      MachineIRBuilder &MIB, MachineFunction::iterator BBI,
+      BranchProbability DefaultProb, BranchProbability UnhandledProbs,
+      SwitchCG::CaseClusterIt I, MachineBasicBlock *Fallthrough,
+      bool FallthroughUnreachable);
+
   bool lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W, Value *Cond,
                            MachineBasicBlock *SwitchMBB,
                            MachineBasicBlock *DefaultMBB,
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -729,7 +729,7 @@
   /// depend on bit 0 (for now).
   ///
   /// \return The newly created instruction.
-  MachineInstrBuilder buildBrCond(Register Tst, MachineBasicBlock &Dest);
+  MachineInstrBuilder buildBrCond(const SrcOp &Tst, MachineBasicBlock &Dest);
 
   /// Build and insert G_BRINDIRECT \p Tgt
   ///
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -440,6 +440,7 @@
   }
 
   SL->findJumpTables(Clusters, &SI, DefaultMBB, nullptr, nullptr);
+  SL->findBitTestClusters(Clusters, &SI);
 
   LLVM_DEBUG({
     dbgs() << "Case clusters: ";
@@ -717,6 +718,160 @@
   return true;
 }
 
+void IRTranslator::emitBitTestHeader(SwitchCG::BitTestBlock &B,
+                                     MachineBasicBlock *SwitchBB) {
+  MachineIRBuilder &MIB = *CurBuilder;
+  MIB.setMBB(*SwitchBB);
+
+  // Subtract the minimum value.
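+  // Rebasing the value to start at zero lets each case map onto a bit in a
+  // mask, so testing membership in the cluster becomes one shift-and-AND.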
+  Register SwitchOpReg = getOrCreateVReg(*B.SValue);
+
+  LLT SwitchOpTy = MRI->getType(SwitchOpReg);
+  Register MinValReg = MIB.buildConstant(SwitchOpTy, B.First).getReg(0);
+  auto RangeSub = MIB.buildSub(SwitchOpTy, SwitchOpReg, MinValReg);
+
+  // Ensure that the type will fit the mask value.
+  LLT MaskTy = SwitchOpTy;
+  for (unsigned I = 0, E = B.Cases.size(); I != E; ++I) {
+    if (!isUIntN(SwitchOpTy.getSizeInBits(), B.Cases[I].Mask)) {
+      // Switch case ranges are encoded into a series of masks.
+      // Just use a 64-bit type; it's guaranteed to fit.
+      MaskTy = LLT::scalar(64);
+      break;
+    }
+  }
+  Register SubReg = RangeSub.getReg(0);
+  if (SwitchOpTy != MaskTy)
+    SubReg = MIB.buildZExtOrTrunc(MaskTy, SubReg).getReg(0);
+
+  B.RegVT = getMVTForLLT(MaskTy);
+  B.Reg = SubReg;
+
+  MachineBasicBlock *MBB = B.Cases[0].ThisBB;
+
+  if (!B.OmitRangeCheck)
+    addSuccessorWithProb(SwitchBB, B.Default, B.DefaultProb);
+  addSuccessorWithProb(SwitchBB, MBB, B.Prob);
+
+  SwitchBB->normalizeSuccProbs();
+
+  if (!B.OmitRangeCheck) {
+    // Conditional branch to the default block.
+    auto RangeCst = MIB.buildConstant(SwitchOpTy, B.Range);
+    auto RangeCmp = MIB.buildICmp(CmpInst::Predicate::ICMP_UGT, LLT::scalar(1),
+                                  RangeSub, RangeCst);
+    MIB.buildBrCond(RangeCmp.getReg(0), *B.Default);
+  }
+
+  // Avoid emitting unnecessary branches to the next block.
+  if (MBB != SwitchBB->getNextNode())
+    MIB.buildBr(*MBB);
+}
+
+void IRTranslator::emitBitTestCase(SwitchCG::BitTestBlock &BB,
+                                   MachineBasicBlock *NextMBB,
+                                   BranchProbability BranchProbToNext,
+                                   Register Reg, SwitchCG::BitTestCase &B,
+                                   MachineBasicBlock *SwitchBB) {
+  MachineIRBuilder &MIB = *CurBuilder;
+  MIB.setMBB(*SwitchBB);
+
+  LLT SwitchTy = getLLTForMVT(BB.RegVT);
+  Register Cmp;
+  unsigned PopCount = countPopulation(B.Mask);
+  if (PopCount == 1) {
+    // Testing for a single bit; just compare the shift count with what it
+    // would need to be to shift a 1 bit in that position.
+    auto MaskTrailingZeros =
+        MIB.buildConstant(SwitchTy, countTrailingZeros(B.Mask));
+    Cmp =
+        MIB.buildICmp(ICmpInst::ICMP_EQ, LLT::scalar(1), Reg, MaskTrailingZeros)
+            .getReg(0);
+  } else if (PopCount == BB.Range) {
+    // There is only one zero bit in the range, test for it directly.
+    auto MaskTrailingOnes =
+        MIB.buildConstant(SwitchTy, countTrailingOnes(B.Mask));
+    Cmp = MIB.buildICmp(CmpInst::ICMP_NE, LLT::scalar(1), Reg, MaskTrailingOnes)
+              .getReg(0);
+  } else {
+    // Make desired shift.
+    auto CstOne = MIB.buildConstant(SwitchTy, 1);
+    auto SwitchVal = MIB.buildShl(SwitchTy, CstOne, Reg);
+
+    // Emit bit tests and jumps.
+    auto CstMask = MIB.buildConstant(SwitchTy, B.Mask);
+    auto AndOp = MIB.buildAnd(SwitchTy, SwitchVal, CstMask);
+    auto CstZero = MIB.buildConstant(SwitchTy, 0);
+    Cmp = MIB.buildICmp(CmpInst::ICMP_NE, LLT::scalar(1), AndOp, CstZero)
+              .getReg(0);
+  }
+
+  // The branch probability from SwitchBB to B.TargetBB is B.ExtraProb.
+  addSuccessorWithProb(SwitchBB, B.TargetBB, B.ExtraProb);
+  // The branch probability from SwitchBB to NextMBB is BranchProbToNext.
+  addSuccessorWithProb(SwitchBB, NextMBB, BranchProbToNext);
+  // The sum of B.ExtraProb and BranchProbToNext is not guaranteed to be one,
+  // as they are relative probabilities (and thus work more like weights), so
+  // normalize them to make their sum one.
+  SwitchBB->normalizeSuccProbs();
+
+  // Record the fact that the IR edge from the header to the bit test target
+  // will go through our new block. Needed for PHIs to have nodes added.
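+  // (The recorded machine-level predecessor is what PHI translation consults
+  // later when it adds incoming values for this edge.)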
+  addMachineCFGPred({BB.Parent->getBasicBlock(), B.TargetBB->getBasicBlock()},
+                    SwitchBB);
+
+  MIB.buildBrCond(Cmp, *B.TargetBB);
+
+  // Avoid emitting unnecessary branches to the next block.
+  if (NextMBB != SwitchBB->getNextNode())
+    MIB.buildBr(*NextMBB);
+}
+
+bool IRTranslator::lowerBitTestWorkItem(
+    SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB,
+    MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB,
+    MachineIRBuilder &MIB, MachineFunction::iterator BBI,
+    BranchProbability DefaultProb, BranchProbability UnhandledProbs,
+    SwitchCG::CaseClusterIt I, MachineBasicBlock *Fallthrough,
+    bool FallthroughUnreachable) {
+  using namespace SwitchCG;
+  MachineFunction *CurMF = SwitchMBB->getParent();
+  // FIXME: Optimize away range check based on pivot comparisons.
+  BitTestBlock *BTB = &SL->BitTestCases[I->BTCasesIndex];
+  // The bit test blocks haven't been inserted yet; insert them here.
+  for (BitTestCase &BTC : BTB->Cases)
+    CurMF->insert(BBI, BTC.ThisBB);
+
+  // Fill in fields of the BitTestBlock.
+  BTB->Parent = CurMBB;
+  BTB->Default = Fallthrough;
+
+  BTB->DefaultProb = UnhandledProbs;
+  // If the cases in the bit test don't form a contiguous range, we evenly
+  // distribute the probability on the edge to Fallthrough between the two
+  // successors of CurMBB.
+  if (!BTB->ContiguousRange) {
+    BTB->Prob += DefaultProb / 2;
+    BTB->DefaultProb -= DefaultProb / 2;
+  }
+
+  if (FallthroughUnreachable) {
+    // Skip the range check if the fallthrough block is unreachable.
+    BTB->OmitRangeCheck = true;
+  }
+
+  // If we're in the right place, emit the bit test header right now.
+  if (CurMBB == SwitchMBB) {
+    emitBitTestHeader(*BTB, SwitchMBB);
+    BTB->Emitted = true;
+  }
+  return true;
+}
+
 bool IRTranslator::lowerSwitchWorkItem(SwitchCG::SwitchWorkListItem W,
                                        Value *Cond,
                                        MachineBasicBlock *SwitchMBB,
                                        MachineBasicBlock *DefaultMBB,
@@ -777,9 +932,15 @@
 
     switch (I->Kind) {
     case CC_BitTests: {
-      LLVM_DEBUG(dbgs() << "Switch to bit test optimization unimplemented");
-      return false; // Bit tests currently unimplemented.
+      if (!lowerBitTestWorkItem(W, SwitchMBB, CurMBB, DefaultMBB, MIB, BBI,
+                                DefaultProb, UnhandledProbs, I, Fallthrough,
+                                FallthroughUnreachable)) {
+        LLVM_DEBUG(dbgs() << "Failed to lower bit test for switch");
+        return false;
+      }
+      break;
     }
+
     case CC_JumpTable: {
       if (!lowerJumpTableWorkItem(W, SwitchMBB, CurMBB, DefaultMBB, MIB, BBI,
                                   UnhandledProbs, I, Fallthrough,
@@ -2309,6 +2470,60 @@
 }
 
 void IRTranslator::finalizeBasicBlock() {
+  for (auto &BTB : SL->BitTestCases) {
+    // Emit header first, if it wasn't already emitted.
+    if (!BTB.Emitted)
+      emitBitTestHeader(BTB, BTB.Parent);
+
+    BranchProbability UnhandledProb = BTB.Prob;
+    for (unsigned j = 0, ej = BTB.Cases.size(); j != ej; ++j) {
+      UnhandledProb -= BTB.Cases[j].ExtraProb;
+      // Set the current basic block to the MBB we wish to insert the code into.
+      MachineBasicBlock *MBB = BTB.Cases[j].ThisBB;
+      // If all cases cover a contiguous range, it is not necessary to jump to
+      // the default block after the last bit test fails. This is because the
+      // range check during bit test header creation has guaranteed that every
+      // case here doesn't go outside the range. In this case, there is no need
+      // to perform the last bit test, as it will always be true. Instead, make
+      // the second-to-last bit-test fall through to the target of the last bit
+      // test, and delete the last bit test.
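+      // For example, with contiguous cases {4,5}->A and {6,7}->B, the header
+      // has already rejected values outside [4, 7], so once the test for
+      // {4,5} fails the value must be 6 or 7 and we can branch straight to B.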
+
+      MachineBasicBlock *NextMBB;
+      if (BTB.ContiguousRange && j + 2 == ej) {
+        // Second-to-last bit-test with contiguous range: fall through to the
+        // target of the final bit test.
+        NextMBB = BTB.Cases[j + 1].TargetBB;
+      } else if (j + 1 == ej) {
+        // For the last bit test, fall through to Default.
+        NextMBB = BTB.Default;
+      } else {
+        // Otherwise, fall through to the next bit test.
+        NextMBB = BTB.Cases[j + 1].ThisBB;
+      }
+
+      emitBitTestCase(BTB, NextMBB, UnhandledProb, BTB.Reg, BTB.Cases[j], MBB);
+
+      // FIXME: Delete this block below?
+      if (BTB.ContiguousRange && j + 2 == ej) {
+        // Since we're not going to use the final bit test, remove it.
+        BTB.Cases.pop_back();
+        break;
+      }
+    }
+    // This is the "default" BB. There are two jumps to it: from the "header"
+    // BB and from the last "case" BB, unless the latter was skipped.
+    CFGEdge HeaderToDefaultEdge = {BTB.Parent->getBasicBlock(),
+                                   BTB.Default->getBasicBlock()};
+    addMachineCFGPred(HeaderToDefaultEdge, BTB.Parent);
+    if (!BTB.ContiguousRange) {
+      addMachineCFGPred(HeaderToDefaultEdge, BTB.Cases.back().ThisBB);
+    }
+  }
+  SL->BitTestCases.clear();
+
   for (auto &JTCase : SL->JTCases) {
     // Emit header first, if it wasn't already emitted.
     if (!JTCase.first.Emitted)
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -312,11 +312,14 @@
   return buildFConstant(Res, *CFP);
 }
 
-MachineInstrBuilder MachineIRBuilder::buildBrCond(Register Tst,
+MachineInstrBuilder MachineIRBuilder::buildBrCond(const SrcOp &Tst,
                                                   MachineBasicBlock &Dest) {
-  assert(getMRI()->getType(Tst).isScalar() && "invalid operand type");
+  assert(Tst.getLLTTy(*getMRI()).isScalar() && "invalid operand type");
 
-  return buildInstr(TargetOpcode::G_BRCOND).addUse(Tst).addMBB(&Dest);
+  auto MIB = buildInstr(TargetOpcode::G_BRCOND);
+  Tst.addSrcToMIB(MIB);
+  MIB.addMBB(&Dest);
+  return MIB;
 }
 
 MachineInstrBuilder
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-bittest.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-bittest.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-bittest.ll
@@ -0,0 +1,132 @@
+; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+; RUN: llc -mtriple aarch64 -aarch64-enable-atomic-cfg-tidy=0 -stop-after=irtranslator -global-isel -verify-machineinstrs %s -o - 2>&1 | FileCheck %s
+
+define i32 @test_bittest(i16 %p) {
+  ; CHECK-LABEL: name: test_bittest
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK:   successors: %bb.4(0x40000000), %bb.5(0x40000000)
+  ; CHECK:   liveins: $w0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+  ; CHECK:   [[TRUNC:%[0-9]+]]:_(s16) = G_TRUNC [[COPY]](s32)
+  ; CHECK:   [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 114
+  ; CHECK:   [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 42
+  ; CHECK:   [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+  ; CHECK:   [[ZEXT:%[0-9]+]]:_(s32) = G_ZEXT [[TRUNC]](s16)
+  ; CHECK:   [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+  ; CHECK:   [[SUB:%[0-9]+]]:_(s32) = G_SUB [[ZEXT]], [[C3]]
+  ; CHECK:   [[ZEXT1:%[0-9]+]]:_(s64) = G_ZEXT [[SUB]](s32)
+  ; CHECK:   [[C4:%[0-9]+]]:_(s32) = G_CONSTANT i32 59
+  ; CHECK:   [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(ugt), [[SUB]](s32), [[C4]]
+  ; CHECK:   G_BRCOND [[ICMP]](s1), %bb.4
+  ; CHECK:   G_BR %bb.5
+  ; CHECK: bb.4 (%ir-block.0):
+  ; CHECK:   successors: %bb.3(0x40000000), %bb.2(0x40000000)
+  ; CHECK:   [[ICMP1:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[ZEXT]](s32), [[C]]
+  ; CHECK:   G_BRCOND [[ICMP1]](s1), %bb.3
+  ; CHECK:   G_BR %bb.2
+  ; CHECK: bb.5 (%ir-block.0):
+  ; CHECK:   successors: %bb.3(0x40000000), %bb.4(0x40000000)
+  ; CHECK:   [[C5:%[0-9]+]]:_(s64) = G_CONSTANT i64 1
+  ; CHECK:   [[SHL:%[0-9]+]]:_(s64) = G_SHL [[C5]], [[ZEXT1]](s64)
+  ; CHECK:   [[C6:%[0-9]+]]:_(s64) = G_CONSTANT i64 866239240827043840
+  ; CHECK:   [[AND:%[0-9]+]]:_(s64) = G_AND [[SHL]], [[C6]]
+  ; CHECK:   [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+  ; CHECK:   [[ICMP2:%[0-9]+]]:_(s1) = G_ICMP intpred(ne), [[AND]](s64), [[C7]]
+  ; CHECK:   G_BRCOND [[ICMP2]](s1), %bb.3
+  ; CHECK:   G_BR %bb.4
+  ; CHECK: bb.2.sw.epilog:
+  ; CHECK:   $w0 = COPY [[C2]](s32)
+  ; CHECK:   RET_ReallyLR implicit $w0
+  ; CHECK: bb.3.cb1:
+  ; CHECK:   $w0 = COPY [[C1]](s32)
+  ; CHECK:   RET_ReallyLR implicit $w0
+  switch i16 %p, label %sw.epilog [
+    i16 58, label %cb1
+    i16 59, label %cb1
+    i16 47, label %cb1
+    i16 48, label %cb1
+    i16 50, label %cb1
+    i16 114, label %cb1
+  ]
+sw.epilog:
+  ret i32 0
+
+cb1:
+  ret i32 42
+}
+
+
+declare void @callee()
+
+define void @test_bittest_2_bt(i32 %p) {
+  ; CHECK-LABEL: name: test_bittest_2_bt
+  ; CHECK: bb.1.entry:
+  ; CHECK:   successors: %bb.5(0x40000000), %bb.6(0x40000000)
+  ; CHECK:   liveins: $w0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(s32) = COPY $w0
+  ; CHECK:   [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 176
+  ; CHECK:   [[SUB:%[0-9]+]]:_(s32) = G_SUB [[COPY]], [[C]]
+  ; CHECK:   [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 15
+  ; CHECK:   [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(ugt), [[SUB]](s32), [[C1]]
+  ; CHECK:   G_BRCOND [[ICMP]](s1), %bb.5
+  ; CHECK:   G_BR %bb.6
+  ; CHECK: bb.5.entry:
+  ; CHECK:   successors: %bb.4(0x40000000), %bb.7(0x40000000)
+  ; CHECK:   [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+  ; CHECK:   [[SUB1:%[0-9]+]]:_(s32) = G_SUB [[COPY]], [[C2]]
+  ; CHECK:   [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[SUB1]](s32)
+  ; CHECK:   [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 38
+  ; CHECK:   [[ICMP1:%[0-9]+]]:_(s1) = G_ICMP intpred(ugt), [[SUB1]](s32), [[C3]]
+  ; CHECK:   G_BRCOND [[ICMP1]](s1), %bb.4
+  ; CHECK:   G_BR %bb.7
+  ; CHECK: bb.6.entry:
+  ; CHECK:   successors: %bb.2(0x40000000), %bb.5(0x40000000)
+  ; CHECK:   [[C4:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
+  ; CHECK:   [[SHL:%[0-9]+]]:_(s32) = G_SHL [[C4]], [[SUB]](s32)
+  ; CHECK:   [[C5:%[0-9]+]]:_(s32) = G_CONSTANT i32 57351
+  ; CHECK:   [[AND:%[0-9]+]]:_(s32) = G_AND [[SHL]], [[C5]]
+  ; CHECK:   [[C6:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+  ; CHECK:   [[ICMP2:%[0-9]+]]:_(s1) = G_ICMP intpred(ne), [[AND]](s32), [[C6]]
+  ; CHECK:   G_BRCOND [[ICMP2]](s1), %bb.2
+  ; CHECK:   G_BR %bb.5
+  ; CHECK: bb.7.entry:
+  ; CHECK:   successors: %bb.3(0x40000000), %bb.4(0x40000000)
+  ; CHECK:   [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 1
+  ; CHECK:   [[SHL1:%[0-9]+]]:_(s64) = G_SHL [[C7]], [[ZEXT]](s64)
+  ; CHECK:   [[C8:%[0-9]+]]:_(s64) = G_CONSTANT i64 365072220160
+  ; CHECK:   [[AND1:%[0-9]+]]:_(s64) = G_AND [[SHL1]], [[C8]]
+  ; CHECK:   [[C9:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+  ; CHECK:   [[ICMP3:%[0-9]+]]:_(s1) = G_ICMP intpred(ne), [[AND1]](s64), [[C9]]
+  ; CHECK:   G_BRCOND [[ICMP3]](s1), %bb.3
+  ; CHECK:   G_BR %bb.4
+  ; CHECK: bb.2.sw.bb37:
+  ; CHECK:   TCRETURNdi @callee, 0, csr_aarch64_aapcs, implicit $sp
+  ; CHECK: bb.3.sw.bb55:
+  ; CHECK:   TCRETURNdi @callee, 0, csr_aarch64_aapcs, implicit $sp
+  ; CHECK: bb.4.sw.default:
+  ; CHECK:   RET_ReallyLR
+entry:
+  switch i32 %p, label %sw.default [
+    i32 32, label %sw.bb55
+    i32 34, label %sw.bb55
+    i32 36, label %sw.bb55
+    i32 191, label %sw.bb37
+    i32 190, label %sw.bb37
+    i32 189, label %sw.bb37
+    i32 178, label %sw.bb37
+    i32 177, label %sw.bb37
+    i32 176, label %sw.bb37
+    i32 38, label %sw.bb55
+  ]
+
+sw.bb37: ; preds = %entry, %entry, %entry, %entry, %entry, %entry
+  tail call void @callee()
+  ret void
+
+sw.bb55: ; preds = %entry, %entry, %entry, %entry
+  tail call void @callee()
+  ret void
+
+sw.default: ; preds = %entry
+  ret void
+}