Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9491,6 +9491,115 @@
   return SDValue();
 }
 
+// Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test
+// as well as whether the test should be inverted. This code is required to
+// catch these cases (as opposed to standard dag combines) because
+// AArch64ISD::TBZ is matched during legalization.
+//
+// On entry, Bit is the bit of Op being tested. On return, Bit is the bit of
+// the returned operand that carries the same value, and Invert has been
+// toggled once for every xor that flips that bit. Op itself is returned when
+// no simplification applies.
+static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
+                                 SelectionDAG &DAG) {
+  // Only look through single-use nodes: a node with other users stays live
+  // regardless, so testing through it would not remove it.
+  if (!Op->hasOneUse())
+    return Op;
+
+  // We don't handle undef/constant-fold cases here, as they should have
+  // already been taken care of (e.g. and of 0, test of undefined shifted
+  // bits, etc.)
+
+  // (tbz (trunc x), b) -> (tbz x, b)
+  // This case is just here to enable more of the below cases to be caught.
+  if (Op->getOpcode() == ISD::TRUNCATE &&
+      Bit < Op->getValueType(0).getSizeInBits())
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+
+  if (Op->getNumOperands() != 2)
+    return Op;
+
+  auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+  if (!C)
+    return Op;
+
+  // Provided the incoming Bit is a valid index into Op (as it is for a
+  // well-formed tbz/tbnz), every update below keeps it below the value
+  // width, so the single-bit extractions from the 64-bit constant are safe.
+  const uint64_t CVal = C->getZExtValue();
+  const unsigned Size = Op->getValueType(0).getSizeInBits();
+
+  switch (Op->getOpcode()) {
+  default:
+    return Op;
+
+  // (tbz (and x, m), b) -> (tbz x, b) if bit b of m is set.
+  case ISD::AND:
+    if ((CVal >> Bit) & 1)
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    return Op;
+
+  // (tbz (shl x, c), b) -> (tbz x, b-c) if c <= b.
+  case ISD::SHL:
+    if (CVal <= Bit && (Bit - CVal) < Size) {
+      Bit = Bit - CVal;
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (shr x, c), b) -> (tbz x, b+c) if b+c is still a bit of x.
+  // The bound is checked without computing b+c so that an out-of-range shift
+  // amount cannot wrap around to a small, bogus bit index.
+  case ISD::SRA:
+  case ISD::SRL:
+    if (CVal < Size && Bit < Size - CVal) {
+      Bit = Bit + CVal;
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (xor x, m), b) -> (tbnz x, b) if bit b of m is set, otherwise
+  // (tbz x, b); e.g. (tbz (xor x, -1), b) -> (tbnz x, b).
+  case ISD::XOR:
+    if ((CVal >> Bit) & 1)
+      Invert = !Invert;
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+  }
+}
+
+// Optimize test single bit zero/non-zero and branch.
+//
+// Looks through the single-use and/shl/srl/sra/xor/trunc chain feeding the
+// tested value (see getTestBitOperand) and, if anything was bypassed, rebuilds
+// the branch on the simplified value, swapping TBZ <-> TBNZ when the tested
+// bit was inverted along the way.
+static SDValue performTBZCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 SelectionDAG &DAG) {
+  // Operand 1 is the tested value and operand 2 the bit index; operands 0
+  // and 3 are carried over unchanged into the rebuilt node.
+  unsigned Bit = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue();
+  bool Invert = false;
+  SDValue TestSrc = N->getOperand(1);
+  SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
+
+  // Nothing was looked through; leave the node alone.
+  if (TestSrc == NewTestSrc)
+    return SDValue();
+
+  unsigned NewOpc = N->getOpcode();
+  if (Invert) {
+    assert((NewOpc == AArch64ISD::TBZ || NewOpc == AArch64ISD::TBNZ) &&
+           "Unexpected opcode in TBZ combine");
+    NewOpc = NewOpc == AArch64ISD::TBZ ? AArch64ISD::TBNZ : AArch64ISD::TBZ;
+  }
+
+  SDLoc DL(N);
+  return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
+                     DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
+}
+
 // vselect (v1i1 setcc) ->
 // vselect (v1iXX setcc) (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
@@ -9642,6 +9751,9 @@
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
+  case AArch64ISD::TBNZ:
+  case AArch64ISD::TBZ:
+    return performTBZCombine(N, DCI, DAG);
   case AArch64ISD::CSEL:
     return performCONDCombine(N, DCI, DAG, 2, 3);
   case AArch64ISD::DUP:
Index: test/CodeGen/AArch64/tbz-tbnz.ll
===================================================================
--- test/CodeGen/AArch64/tbz-tbnz.ll
+++ test/CodeGen/AArch64/tbz-tbnz.ll
@@ -256,3 +256,53 @@
 if.end:
   ret void
 }
+
+define void @test14(i1 %cond) {
+; CHECK-LABEL: @test14
+  br i1 %cond, label %if.end, label %if.then
+
+; CHECK-NOT: and
+; CHECK: tbnz w0, #0
+
+if.then:
+  call void @t()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test15(i1 %cond) {
+; CHECK-LABEL: @test15
+  %cond1 = xor i1 %cond, -1
+  br i1 %cond1, label %if.then, label %if.end
+
+; CHECK-NOT: movn
+; CHECK: tbnz w0, #0
+
+if.then:
+  call void @t()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define void @test16(i64 %in) {
+; CHECK-LABEL: @test16
+  %shr = lshr i64 %in, 3
+  %trunc = trunc i64 %shr to i32
+  %and = and i32 %trunc, 1
+  %cond = icmp eq i32 %and, 0
+  br i1 %cond, label %then, label %end
+
+; CHECK-NOT: ubfx
+; CHECK: tbnz w0, #3
+
+then:
+  call void @t()
+  br label %end
+
+end:
+  ret void
+}