Index: lib/Target/ARM/ARMISelDAGToDAG.cpp
===================================================================
--- lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -2998,13 +2998,26 @@
     // Other cases are autogenerated.
     break;
   }
-  case ARMISD::WLS: {
-    SDValue Ops[] = { N->getOperand(1),   // Loop count
-                      N->getOperand(2),   // Exit target
+  case ARMISD::WLS:
+  case ARMISD::LE: {
+    SDValue Ops[] = { N->getOperand(1),   // Loop counter
+                      N->getOperand(2),   // Branch target
                       N->getOperand(0) };
-    SDNode *LoopStart =
-      CurDAG->getMachineNode(ARM::t2WhileLoopStart, dl, MVT::Other, Ops);
-    ReplaceUses(N, LoopStart);
+    unsigned Opc = N->getOpcode() == ARMISD::WLS ?
+      ARM::t2WhileLoopStart : ARM::t2LoopEnd;
+    SDNode *New = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops);
+    ReplaceUses(N, New);
+    CurDAG->RemoveDeadNode(N);
+    return;
+  }
+  case ARMISD::LoopDec: {
+    SDValue Ops[] = { N->getOperand(1),   // Loop counter
+                      N->getOperand(2),   // Decrement size
+                      N->getOperand(0) }; // Chain
+    SDNode *Dec =
+      CurDAG->getMachineNode(ARM::t2LoopDec, dl,
+                             CurDAG->getVTList(MVT::i32, MVT::Other), Ops);
+    ReplaceUses(N, Dec);
     CurDAG->RemoveDeadNode(N);
     return;
   }
@@ -3035,36 +3048,6 @@
     unsigned CC = (unsigned) cast<ConstantSDNode>(N2)->getZExtValue();
 
     if (InFlag.getOpcode() == ARMISD::CMPZ) {
-      if (InFlag.getOperand(0).getOpcode() == ISD::INTRINSIC_W_CHAIN) {
-        SDValue Int = InFlag.getOperand(0);
-        uint64_t ID = cast<ConstantSDNode>(Int->getOperand(1))->getZExtValue();
-
-        // Handle low-overhead loops.
-        if (ID == Intrinsic::loop_decrement_reg) {
-          SDValue Elements = Int.getOperand(2);
-          SDValue Size = CurDAG->getTargetConstant(
-            cast<ConstantSDNode>(Int.getOperand(3))->getZExtValue(), dl,
-                                 MVT::i32);
-
-          SDValue Args[] = { Elements, Size, Int.getOperand(0) };
-          SDNode *LoopDec =
-            CurDAG->getMachineNode(ARM::t2LoopDec, dl,
-                                   CurDAG->getVTList(MVT::i32, MVT::Other),
-                                   Args);
-          ReplaceUses(Int.getNode(), LoopDec);
-
-          SDValue EndArgs[] = { SDValue(LoopDec, 0), N1, Chain };
-          SDNode *LoopEnd =
-            CurDAG->getMachineNode(ARM::t2LoopEnd, dl, MVT::Other, EndArgs);
-
-          ReplaceUses(N, LoopEnd);
-          CurDAG->RemoveDeadNode(N);
-          CurDAG->RemoveDeadNode(InFlag.getNode());
-          CurDAG->RemoveDeadNode(Int.getNode());
-          return;
-        }
-      }
-
       bool SwitchEQNEToPLMI;
       SelectCMPZ(InFlag.getNode(), SwitchEQNEToPLMI);
       InFlag = N->getOperand(4);
Index: lib/Target/ARM/ARMISelLowering.h
===================================================================
--- lib/Target/ARM/ARMISelLowering.h
+++ lib/Target/ARM/ARMISelLowering.h
@@ -126,6 +126,8 @@
       WIN__DBZCHK,  // Windows' divide by zero check
 
       WLS,          // Low-overhead loops, While Loop Start
+      LoopDec,      // Really a part of LE, performs a sub.
+      LE,           // Low-overhead loops, Loop End
 
       VCEQ,         // Vector compare equal.
       VCEQZ,        // Vector compare equal to zero.
Index: lib/Target/ARM/ARMISelLowering.cpp
===================================================================
--- lib/Target/ARM/ARMISelLowering.cpp
+++ lib/Target/ARM/ARMISelLowering.cpp
@@ -652,8 +652,10 @@
     addMVEVectorTypes(Subtarget->hasMVEFloatOps());
 
   // Combine low-overhead loop intrinsics so that we can lower i1 types.
-  if (Subtarget->hasLOB())
+  if (Subtarget->hasLOB()) {
     setTargetDAGCombine(ISD::BRCOND);
+    setTargetDAGCombine(ISD::BR_CC);
+  }
 
   if (Subtarget->hasNEON()) {
     addDRTypeForNEON(MVT::v2f32);
@@ -1568,6 +1570,8 @@
   case ARMISD::VST3LN_UPD:    return "ARMISD::VST3LN_UPD";
   case ARMISD::VST4LN_UPD:    return "ARMISD::VST4LN_UPD";
   case ARMISD::WLS:           return "ARMISD::WLS";
+  case ARMISD::LE:            return "ARMISD::LE";
+  case ARMISD::LoopDec:       return "ARMISD::LoopDec";
   }
   return nullptr;
 }
@@ -12984,43 +12988,159 @@
   return V;
 }
 
+// Given N, the value controlling the conditional branch, search for the loop
+// intrinsic, returning it, along with how the value is used. We need to handle
+// patterns such as the following:
+// (brcond (xor (setcc (loop.decrement), 0, ne), 1), exit)
+// (brcond (setcc (loop.decrement), 0, eq), exit)
+// (brcond (setcc (loop.decrement), 0, ne), header)
+// CC and Imm are overwritten by every setcc walked through (its condition
+// code and its 0/1 constant operand), and Negate is flipped once for every
+// (xor x, 1) wrapper. An empty SDValue is returned when the chain does not
+// end in test.set.loop.iterations or loop.decrement.reg.
+static SDValue IsLoopIntrinsic(SDValue N, ISD::CondCode &CC, int &Imm,
+                               bool &Negate) {
+  switch (N->getOpcode()) {
+  default:
+    break;
+  case ISD::XOR: {
+    if (!isa<ConstantSDNode>(N.getOperand(1)))
+      return SDValue();
+    if (!cast<ConstantSDNode>(N.getOperand(1))->isOne())
+      return SDValue();
+    Negate = !Negate;
+    return IsLoopIntrinsic(N.getOperand(0), CC, Imm, Negate);
+  }
+  case ISD::SETCC: {
+    auto *Const = dyn_cast<ConstantSDNode>(N.getOperand(1));
+    if (!Const)
+      return SDValue();
+    if (Const->isNullValue())
+      Imm = 0;
+    else if (Const->isOne())
+      Imm = 1;
+    else
+      return SDValue();
+    CC = cast<CondCodeSDNode>(N.getOperand(2))->get();
+    return IsLoopIntrinsic(N->getOperand(0), CC, Imm, Negate);
+  }
+  case ISD::INTRINSIC_W_CHAIN: {
+    unsigned IntOp = cast<ConstantSDNode>(N.getOperand(1))->getZExtValue();
+    if (IntOp != Intrinsic::test_set_loop_iterations &&
+        IntOp != Intrinsic::loop_decrement_reg)
+      return SDValue();
+    return N;
+  }
+  }
+  return SDValue();
+}
+
 static SDValue PerformHWLoopCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     const ARMSubtarget *ST) {
-  // Look for (brcond (xor test.set.loop.iterations, -1)
-  SDValue CC = N->getOperand(1);
-  unsigned Opc = CC->getOpcode();
-  SDValue Int;
 
-  if ((Opc == ISD::XOR || Opc == ISD::SETCC) &&
-      (CC->getOperand(0)->getOpcode() == ISD::INTRINSIC_W_CHAIN)) {
+  // The hwloop intrinsics that we're interested in are used for control-flow,
+  // either for entering or exiting the loop:
+  // - test.set.loop.iterations will test whether its operand is zero. If it
+  //   is zero, the following branch should not enter the loop.
+  // - loop.decrement.reg also tests whether its operand is zero. If it is
+  //   zero, the following branch should not branch back to the beginning of
+  //   the loop.
+  // So here, we need to check how the brcond is using the result of each
+  // of the intrinsics to ensure that we're branching to the right place at the
+  // right time.
+
+  ISD::CondCode CC;
+  SDValue Cond;
+  int Imm = 1;
+  bool Negate = false;
+  SDValue Chain = N->getOperand(0);
+  SDValue Dest;
 
-    assert((isa<ConstantSDNode>(CC->getOperand(1)) &&
-            cast<ConstantSDNode>(CC->getOperand(1))->isOne()) &&
-            "Expected to compare against 1");
+  if (N->getOpcode() == ISD::BRCOND) {
+    CC = ISD::SETEQ;
+    Cond = N->getOperand(1);
+    Dest = N->getOperand(2);
+  } else {
+    CC = cast<CondCodeSDNode>(N->getOperand(1))->get();
+    Cond = N->getOperand(2);
+    Dest = N->getOperand(4);
+    if (auto *Const = dyn_cast<ConstantSDNode>(N->getOperand(3))) {
+      if (!Const->isOne() && !Const->isNullValue())
+        return SDValue();
+      Imm = Const->getZExtValue();
+    } else
+      return SDValue();
+  }
 
-    Int = CC->getOperand(0);
-  } else if (CC->getOpcode() == ISD::INTRINSIC_W_CHAIN)
-    Int = CC;
-  else
+  SDValue Int = IsLoopIntrinsic(Cond, CC, Imm, Negate);
+  if (!Int)
     return SDValue();
 
-  unsigned IntOp = cast<ConstantSDNode>(Int.getOperand(1))->getZExtValue();
-  if (IntOp != Intrinsic::test_set_loop_iterations)
-    return SDValue();
+  assert((CC == ISD::SETEQ || CC == ISD::SETNE) &&
+         "unexpected condition code");
+
+  if (Negate)
+    CC = CC == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ;
 
   SDLoc dl(Int);
-  SDValue Chain = N->getOperand(0);
+  SelectionDAG &DAG = DCI.DAG;
   SDValue Elements = Int.getOperand(2);
-  SDValue ExitBlock = N->getOperand(2);
+  unsigned IntOp = cast<ConstantSDNode>(Int->getOperand(1))->getZExtValue();
+  assert((N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BR)
+          && "expected single br user");
+  SDNode *Br = *N->use_begin();
+  SDValue OtherTarget = Br->getOperand(1);
+
+  // Update the unconditional branch to branch to the given Dest.
+  auto UpdateUncondBr = [](SDNode *Br, SDValue Dest, SelectionDAG &DAG) {
+    SDValue NewBrOps[] = { Br->getOperand(0), Dest };
+    SDValue NewBr = DAG.getNode(ISD::BR, SDLoc(Br), MVT::Other, NewBrOps);
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Br, 0), NewBr);
+  };
 
-  // TODO: Once we start supporting tail predication, we can add another
-  // operand to WLS for the number of elements processed in a vector loop.
+  if (IntOp == Intrinsic::test_set_loop_iterations) {
+    SDValue Res;
+    // We expect this 'instruction' to branch when the counter is zero.
+    if ((CC == ISD::SETEQ && Imm == 0) ||
+        (CC == ISD::SETNE && Imm == 1)) {
+      SDValue Ops[] = { Chain, Elements, Dest };
+      Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
+    } else if ((CC == ISD::SETNE && Imm == 0) ||
+               (CC == ISD::SETEQ && Imm == 1)) {
+      // The logic is the reverse of what we need for WLS, so find the other
+      // basic block target: the target of the following br.
+      UpdateUncondBr(Br, Dest, DAG);
+
+      SDValue Ops[] = { Chain, Elements, OtherTarget };
+      Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
+    } else
+      llvm_unreachable("unhandled condition");
+    DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0));
+    return Res;
+  } else {
+    SDValue Size = DAG.getTargetConstant(
+      cast<ConstantSDNode>(Int.getOperand(3))->getZExtValue(), dl, MVT::i32);
+    SDValue Args[] = { Int.getOperand(0), Elements, Size, };
+    SDValue LoopDec = DAG.getNode(ARMISD::LoopDec, dl,
+                                  DAG.getVTList(MVT::i32, MVT::Other), Args);
+    DAG.ReplaceAllUsesWith(Int.getNode(), LoopDec.getNode());
 
-  SDValue Ops[] = { Chain, Elements, ExitBlock };
-  SDValue Res = DCI.DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
-  DCI.DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0));
-  return Res;
+    // We expect this instruction to branch when the count is not zero.
+    SDValue Target = (CC == ISD::SETNE && Imm == 0) ? Dest : OtherTarget;
+
+    // Update the unconditional branch to target the loop preheader if we've
+    // found the condition has been reversed.
+    if (Target == OtherTarget)
+      UpdateUncondBr(Br, Dest, DAG);
+
+    Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
+                        SDValue(LoopDec.getNode(), 1), Chain);
+
+    SDValue EndArgs[] = { Chain, SDValue(LoopDec.getNode(), 0), Target };
+    return DAG.getNode(ARMISD::LE, dl, MVT::Other, EndArgs);
+  }
+  return SDValue();
 }
 
 /// PerformBRCONDCombine - Target-specific DAG combining for ARMISD::BRCOND.
@@ -13254,7 +13374,8 @@
   case ISD::OR:         return PerformORCombine(N, DCI, Subtarget);
   case ISD::XOR:        return PerformXORCombine(N, DCI, Subtarget);
   case ISD::AND:        return PerformANDCombine(N, DCI, Subtarget);
-  case ISD::BRCOND:     return PerformHWLoopCombine(N, DCI, Subtarget);
+  case ISD::BRCOND:
+  case ISD::BR_CC:      return PerformHWLoopCombine(N, DCI, Subtarget);
   case ARMISD::ADDC:
   case ARMISD::SUBC:    return PerformAddcSubcCombine(N, DCI, Subtarget);
   case ARMISD::SUBE:    return PerformAddeSubeCombine(N, DCI, Subtarget);
Index: lib/Target/ARM/ARMInstrInfo.td
===================================================================
--- lib/Target/ARM/ARMInstrInfo.td
+++ lib/Target/ARM/ARMInstrInfo.td
@@ -108,8 +108,8 @@
 
 // TODO Add another operand for 'Size' so that we can re-use this node when we
 // start supporting *TP versions.
-def SDT_ARMWhileLoop : SDTypeProfile<0, 2, [SDTCisVT<0, i32>,
-                                            SDTCisVT<1, OtherVT>]>;
+def SDT_ARMLoLoop : SDTypeProfile<0, 2, [SDTCisVT<0, i32>,
+                                         SDTCisVT<1, OtherVT>]>;
 
 def ARMSmlald        : SDNode<"ARMISD::SMLALD", SDT_LongMac>;
 def ARMSmlaldx       : SDNode<"ARMISD::SMLALDX", SDT_LongMac>;
@@ -254,8 +254,9 @@
 def ARMvmvnImm   : SDNode<"ARMISD::VMVNIMM", SDTARMVMOVIMM>;
 def ARMvmovFPImm : SDNode<"ARMISD::VMOVFPIMM", SDTARMVMOVIMM>;
 
-def ARMWLS       : SDNode<"ARMISD::WLS", SDT_ARMWhileLoop,
-                          [SDNPHasChain]>;
+def ARMWLS       : SDNode<"ARMISD::WLS", SDT_ARMLoLoop, [SDNPHasChain]>;
+def ARMLE        : SDNode<"ARMISD::LE", SDT_ARMLoLoop, [SDNPHasChain]>;
+def ARMLoopDec   : SDNode<"ARMISD::LoopDec", SDTIntBinOp, [SDNPHasChain]>;
 
 //===----------------------------------------------------------------------===//
 // ARM Flag Definitions.
Index: test/CodeGen/Thumb2/LowOverheadLoops/branch-targets.ll
===================================================================
--- /dev/null
+++ test/CodeGen/Thumb2/LowOverheadLoops/branch-targets.ll
@@ -0,0 +1,166 @@
+; RUN: llc -mtriple=thumbv8.1m.main -O0 -mattr=+lob -disable-arm-loloops=false -stop-before=arm-low-overhead-loops %s -o - | FileCheck %s --check-prefix=CHECK-MID
+; RUN: llc -mtriple=thumbv8.1m.main -O0 -mattr=+lob -disable-arm-loloops=false -verify-machineinstrs %s -o - | FileCheck %s --check-prefix=CHECK-END
+
+; Test that the branch targets are correct after isel, even though the loop
+; will sometimes be reverted anyway.
+
+; CHECK-MID: check_loop_dec_brcond_combine
+; CHECK-MID: bb.2.for.body:
+; CHECK-MID: renamable $lr = t2LoopDec killed renamable $lr, 1
+; CHECK-MID: t2LoopEnd killed renamable $lr, %bb.3
+; CHECK-MID: bb.3.for.header:
+; CHECK-MID: tB %bb.2
+
+; CHECK-END: .LBB0_1:
+; CHECK-END: b .LBB0_3
+; CHECK-END: .LBB0_2:
+; CHECK-END: sub.w lr, lr, #1
+; CHECK-END: cmp.w lr, #0
+; CHECK-END: bne.w .LBB0_3
+; CHECK-END: b .LBB0_4
+; CHECK-END: .LBB0_3:
+; CHECK-END: b .LBB0_2
+define void @check_loop_dec_brcond_combine(i32* nocapture %a, i32* nocapture readonly %b, i32* nocapture readonly %c, i32 %N) {
+entry:
+  call void @llvm.set.loop.iterations.i32(i32 %N)
+  br label %for.body.preheader
+
+for.body.preheader:
+  %scevgep = getelementptr i32, i32* %a, i32 -1
+  %scevgep4 = getelementptr i32, i32* %c, i32 -1
+  %scevgep8 = getelementptr i32, i32* %b, i32 -1
+  br label %for.header
+
+for.body:
+  %scevgep11 = getelementptr i32, i32* %lsr.iv9, i32 1
+  %ld1 = load i32, i32* %scevgep11, align 4
+  %scevgep7 = getelementptr i32, i32* %lsr.iv5, i32 1
+  %ld2 = load i32, i32* %scevgep7, align 4
+  %mul = mul nsw i32 %ld2, %ld1
+  %scevgep3 = getelementptr i32, i32* %lsr.iv1, i32 1
+  store i32 %mul, i32* %scevgep3, align 4
+  %scevgep2 = getelementptr i32, i32* %lsr.iv1, i32 1
+  %scevgep6 = getelementptr i32, i32* %lsr.iv5, i32 1
+  %scevgep10 = getelementptr i32, i32* %lsr.iv9, i32 1
+  %count.next = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %count, i32 1)
+  %cmp = icmp ne i32 %count.next, 0
+  br i1 %cmp, label %for.header, label %for.cond.cleanup
+
+for.header:
+  %lsr.iv9 = phi i32* [ %scevgep8, %for.body.preheader ], [ %scevgep10, %for.body ]
+  %lsr.iv5 = phi i32* [ %scevgep4, %for.body.preheader ], [ %scevgep6, %for.body ]
+  %lsr.iv1 = phi i32* [ %scevgep, %for.body.preheader ], [ %scevgep2, %for.body ]
+  %count = phi i32 [ %N, %for.body.preheader ], [ %count.next, %for.body ]
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+}
+
+; CHECK-MID: check_negated_xor_wls
+; CHECK-MID: t2WhileLoopStart killed renamable $r2, %bb.3
+; CHECK-MID: tB %bb.1
+; CHECK-MID: bb.1.while.body.preheader:
+; CHECK-MID: $lr = t2LoopDec killed renamable $lr, 1
+; CHECK-MID: t2LoopEnd killed renamable $lr, %bb.2
+; CHECK-MID: tB %bb.3
+; CHECK-MID: bb.3.while.end:
+define void @check_negated_xor_wls(i16* nocapture %a, i16* nocapture readonly %b, i32 %N) {
+entry:
+  %wls = call i1 @llvm.test.set.loop.iterations.i32(i32 %N)
+  %xor = xor i1 %wls, 1
+  br i1 %xor, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+  %a.addr.06 = phi i16* [ %incdec.ptr1, %while.body ], [ %a, %while.body.preheader ]
+  %b.addr.05 = phi i16* [ %incdec.ptr, %while.body ], [ %b, %while.body.preheader ]
+  %count = phi i32 [ %N, %while.body.preheader ], [ %count.next, %while.body ]
+  %incdec.ptr = getelementptr inbounds i16, i16* %b.addr.05, i32 1
+  %ld.b = load i16, i16* %b.addr.05, align 2
+  %incdec.ptr1 = getelementptr inbounds i16, i16* %a.addr.06, i32 1
+  store i16 %ld.b, i16* %a.addr.06, align 2
+  %count.next = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %count, i32 1)
+  %cmp = icmp ne i32 %count.next, 0
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:
+  ret void
+}
+
+; CHECK-MID: check_negated_cmp_wls
+; CHECK-MID: t2WhileLoopStart killed renamable $r2, %bb.3
+; CHECK-MID: tB %bb.1
+; CHECK-MID: bb.1.while.body.preheader:
+; CHECK-MID: $lr = t2LoopDec killed renamable $lr, 1
+; CHECK-MID: t2LoopEnd killed renamable $lr, %bb.2
+; CHECK-MID: tB %bb.3
+; CHECK-MID: bb.3.while.end:
+define void @check_negated_cmp_wls(i16* nocapture %a, i16* nocapture readonly %b, i32 %N) {
+entry:
+  %wls = call i1 @llvm.test.set.loop.iterations.i32(i32 %N)
+  %cmp = icmp ne i1 %wls, 1
+  br i1 %cmp, label %while.end, label %while.body.preheader
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+  %a.addr.06 = phi i16* [ %incdec.ptr1, %while.body ], [ %a, %while.body.preheader ]
+  %b.addr.05 = phi i16* [ %incdec.ptr, %while.body ], [ %b, %while.body.preheader ]
+  %count = phi i32 [ %N, %while.body.preheader ], [ %count.next, %while.body ]
+  %incdec.ptr = getelementptr inbounds i16, i16* %b.addr.05, i32 1
+  %ld.b = load i16, i16* %b.addr.05, align 2
+  %incdec.ptr1 = getelementptr inbounds i16, i16* %a.addr.06, i32 1
+  store i16 %ld.b, i16* %a.addr.06, align 2
+  %count.next = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %count, i32 1)
+  %cmp.1 = icmp ne i32 %count.next, 0
+  br i1 %cmp.1, label %while.body, label %while.end
+
+while.end:
+  ret void
+}
+
+; CHECK-MID: check_negated_reordered_wls
+; CHECK-MID: bb.1.while.body.preheader:
+; CHECK-MID: tB %bb.2
+; CHECK-MID: bb.2.while.body:
+; CHECK-MID: t2LoopDec killed renamable $lr, 1
+; CHECK-MID: t2LoopEnd killed renamable $lr, %bb.2
+; CHECK-MID: tB %bb.4
+; CHECK-MID: bb.3.while:
+; CHECK-MID: t2WhileLoopStart {{.*}}, %bb.4
+; CHECK-MID: bb.4.while.end
+define void @check_negated_reordered_wls(i16* nocapture %a, i16* nocapture readonly %b, i32 %N) {
+entry:
+  br label %while
+
+while.body.preheader:
+  br label %while.body
+
+while.body:
+  %a.addr.06 = phi i16* [ %incdec.ptr1, %while.body ], [ %a, %while.body.preheader ]
+  %b.addr.05 = phi i16* [ %incdec.ptr, %while.body ], [ %b, %while.body.preheader ]
+  %count = phi i32 [ %N, %while.body.preheader ], [ %count.next, %while.body ]
+  %incdec.ptr = getelementptr inbounds i16, i16* %b.addr.05, i32 1
+  %ld.b = load i16, i16* %b.addr.05, align 2
+  %incdec.ptr1 = getelementptr inbounds i16, i16* %a.addr.06, i32 1
+  store i16 %ld.b, i16* %a.addr.06, align 2
+  %count.next = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %count, i32 1)
+  %cmp = icmp ne i32 %count.next, 0
+  br i1 %cmp, label %while.body, label %while.end
+
+while:
+  %wls = call i1 @llvm.test.set.loop.iterations.i32(i32 %N)
+  %xor = xor i1 %wls, 1
+  br i1 %xor, label %while.end, label %while.body.preheader
+
+while.end:
+  ret void
+}
+
+declare void @llvm.set.loop.iterations.i32(i32)
+declare i1 @llvm.test.set.loop.iterations.i32(i32)
+declare i32 @llvm.loop.decrement.reg.i32.i32.i32(i32, i32)