diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16854,6 +16854,44 @@
   return SDValue();
 }
 
+// ((X << C) - Y) + Z --> (Z - Y) + (X << C)
+static SDValue performAddCombineSubShift(SDNode *N, SDValue SUB, SDValue Z,
+                                         SelectionDAG &DAG) {
+  // Matches a one-use shift of a value by a constant amount.
+  auto IsOneUseShiftC = [&](SDValue Shift) {
+    if (!Shift.hasOneUse())
+      return false;
+
+    // TODO: support SRL and SRA also
+    if (Shift.getOpcode() != ISD::SHL)
+      return false;
+
+    if (!isa<ConstantSDNode>(Shift.getOperand(1)))
+      return false;
+    return true;
+  };
+
+  // DAGCombiner will revert the combination when Z is constant, which would
+  // cause an infinite loop, so don't enable the combination when Z is a
+  // constant. If Z is itself a one-use shift-by-constant we also can't do
+  // the optimization: the combine would keep firing on its own output.
+  if (isa<ConstantSDNode>(Z) || IsOneUseShiftC(Z))
+    return SDValue();
+
+  if (SUB.getOpcode() != ISD::SUB || !SUB.hasOneUse())
+    return SDValue();
+
+  SDValue Shift = SUB.getOperand(0);
+  if (!IsOneUseShiftC(Shift))
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // Rewrite (shift - Y) + Z as (Z - Y) + shift so the shift can fold into
+  // the AArch64 add-with-shifted-operand form.
+  SDValue Y = SUB.getOperand(1);
+  SDValue NewSub = DAG.getNode(ISD::SUB, DL, VT, Z, Y);
+  return DAG.getNode(ISD::ADD, DL, VT, NewSub, Shift);
+}
+
 static SDValue performAddCombineForShiftedOperands(SDNode *N,
                                                    SelectionDAG &DAG) {
   // NOTE: Swapping LHS and RHS is not done for SUB, since SUB is not
@@ -16871,6 +16909,11 @@
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
+  if (SDValue Val = performAddCombineSubShift(N, LHS, RHS, DAG))
+    return Val;
+  if (SDValue Val = performAddCombineSubShift(N, RHS, LHS, DAG))
+    return Val;
+
   uint64_t LHSImm = 0, RHSImm = 0;
   // If both operand are shifted by imm and shift amount is not greater than 4
   // for one operand, swap LHS and RHS to put operand with smaller shift amount
diff --git a/llvm/test/CodeGen/AArch64/addsub.ll b/llvm/test/CodeGen/AArch64/addsub.ll
--- a/llvm/test/CodeGen/AArch64/addsub.ll
+++ b/llvm/test/CodeGen/AArch64/addsub.ll
@@ -694,12 +694,12 @@
   ret i32 undef
 }
 
+; ((X << C) - Y) + Z --> (Z - Y) + (X << C)
 define i32 @commute_subop0(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: commute_subop0:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lsl w8, w0, #3
-; CHECK-NEXT:    sub w8, w8, w1
-; CHECK-NEXT:    add w0, w8, w2
+; CHECK-NEXT:    sub w8, w2, w1
+; CHECK-NEXT:    add w0, w8, w0, lsl #3
 ; CHECK-NEXT:    ret
   %shl = shl i32 %x, 3
   %sub = sub i32 %shl, %y
@@ -707,12 +707,40 @@
   ret i32 %add
 }
 
+; TODO: ((X >> C) - Y) + Z --> (Z - Y) + (X >> C); SRL not combined yet
+define i32 @commute_subop0_lshr(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: commute_subop0_lshr:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsr w8, w0, #3
+; CHECK-NEXT:    sub w8, w8, w1
+; CHECK-NEXT:    add w0, w8, w2
+; CHECK-NEXT:    ret
+  %lshr = lshr i32 %x, 3
+  %sub = sub i32 %lshr, %y
+  %add = add i32 %sub, %z
+  ret i32 %add
+}
+
+; TODO: ((X >> C) - Y) + Z --> (Z - Y) + (X >> C); SRA not combined yet
+define i32 @commute_subop0_ashr(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: commute_subop0_ashr:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    asr w8, w0, #3
+; CHECK-NEXT:    sub w8, w8, w1
+; CHECK-NEXT:    add w0, w8, w2
+; CHECK-NEXT:    ret
+  %ashr = ashr i32 %x, 3
+  %sub = sub i32 %ashr, %y
+  %add = add i32 %sub, %z
+  ret i32 %add
+}
+
+; Z + ((X << C) - Y) --> (Z - Y) + (X << C)
 define i32 @commute_subop0_cadd(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: commute_subop0_cadd:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lsl w8, w0, #3
-; CHECK-NEXT:    sub w8, w8, w1
-; CHECK-NEXT:    add w0, w2, w8
+; CHECK-NEXT:    sub w8, w2, w1
+; CHECK-NEXT:    add w0, w8, w0, lsl #3
 ; CHECK-NEXT:    ret
   %shl = shl i32 %x, 3
   %sub = sub i32 %shl, %y
@@ -720,14 +748,61 @@
   ret i32 %add
 }
 
+; Y + ((X << C) - X) --> (Y - X) + (X << C)
 define i32 @commute_subop0_mul(i32 %x, i32 %y) {
 ; CHECK-LABEL: commute_subop0_mul:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    lsl w8, w0, #3
-; CHECK-NEXT:    sub w8, w8, w0
-; CHECK-NEXT:    add w0, w8, w1
+; CHECK-NEXT:    sub w8, w1, w0
+; CHECK-NEXT:    add w0, w8, w0, lsl #3
 ; CHECK-NEXT:    ret
   %mul = mul i32 %x, 7
   %add = add i32 %mul, %y
   ret i32 %add
 }
+
+; negative case for ((X << C) - Y) + Z --> (Z - Y) + (X << C)
+; Z can't be a constant, to avoid an infinite combine loop
+define i32 @commute_subop0_zconst(i32 %x, i32 %y) {
+; CHECK-LABEL: commute_subop0_zconst:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl w8, w0, #3
+; CHECK-NEXT:    sub w8, w8, w1
+; CHECK-NEXT:    add w0, w8, #1
+; CHECK-NEXT:    ret
+  %shl = shl i32 %x, 3
+  %sub = sub i32 %shl, %y
+  %add = add i32 %sub, 1
+  ret i32 %add
+}
+
+; negative case for ((X << C) - Y) + Z --> (Z - Y) + (X << C)
+; Z can't be a one-use shift-by-constant either, to avoid an infinite loop
+define i32 @commute_subop0_zshiftc_oneuse(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: commute_subop0_zshiftc_oneuse:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl w8, w0, #3
+; CHECK-NEXT:    sub w8, w8, w1
+; CHECK-NEXT:    add w0, w8, w2, lsl #2
+; CHECK-NEXT:    ret
+  %xshl = shl i32 %x, 3
+  %sub = sub i32 %xshl, %y
+  %zshl = shl i32 %z, 2
+  %add = add i32 %sub, %zshl
+  ret i32 %add
+}
+
+define i32 @commute_subop0_zshiftc(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: commute_subop0_zshiftc:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl w8, w2, #2
+; CHECK-NEXT:    sub w9, w8, w1
+; CHECK-NEXT:    add w9, w9, w0, lsl #3
+; CHECK-NEXT:    eor w0, w8, w9
+; CHECK-NEXT:    ret
+  %xshl = shl i32 %x, 3
+  %sub = sub i32 %xshl, %y
+  %zshl = shl i32 %z, 2
+  %add = add i32 %sub, %zshl
+  %r = xor i32 %zshl, %add
+  ret i32 %r
+}