diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3246,6 +3246,64 @@
   return SDValue();
 }
 
+// Tries to fold (srl (add (zext a to i2^n), (zext b to i2^n)), 2^(n-1)) ->
+// (addcarry a, b).carry, where a and b have type i2^(n-1).
+static SDValue combineSrlAdd(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+
+  if (!N1C)
+    return SDValue();
+
+  if (!N1C->getAPIntValue().isPowerOf2())
+    return SDValue();
+
+  if (N0.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue N0LHS = N0.getOperand(0);
+  SDValue N0RHS = N0.getOperand(1);
+  unsigned ShiftAmt = N1C->getZExtValue();
+
+  // Make sure that this add was zero-extended from an integer type that is
+  // `ShiftAmt` bits wide.
+  if (N0LHS.getOpcode() != ISD::ZERO_EXTEND ||
+      N0RHS.getOpcode() != ISD::ZERO_EXTEND)
+    return SDValue();
+
+  SDValue N0NoZextLHS = N0LHS.getOperand(0);
+  SDValue N0NoZextRHS = N0RHS.getOperand(0);
+
+  if (ShiftAmt != N0NoZextLHS.getValueType().getScalarSizeInBits() ||
+      ShiftAmt != N0NoZextRHS.getValueType().getScalarSizeInBits())
+    return SDValue();
+
+  if ((N0.getValueType().getScalarSizeInBits() / 2) != ShiftAmt)
+    return SDValue();
+
+  // Make sure that no user of `N0` expects the zero-extended value.
+  for (SDNode *U : N0->uses()) {
+    if (U == N)
+      continue;
+
+    if (U->getOpcode() == ISD::TRUNCATE &&
+        U->getValueType(0).getScalarSizeInBits() == ShiftAmt)
+      continue;
+
+    return SDValue();
+  }
+
+  SDValue Zero = DAG.getConstant(0, SDLoc(N0), N->getValueType(0));
+  SDValue AddC =
+      DAG.getNode(ISD::ADDCARRY, SDLoc(N0),
+                  DAG.getVTList(N0NoZextLHS.getValueType(), N->getValueType(0)),
+                  N0NoZextLHS, N0NoZextRHS, Zero);
+  DAG.ReplaceAllUsesOfValueWith(N0.getValue(0), AddC.getValue(0));
+
+  return AddC.getValue(1);
+}
+
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
   auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -3275,6 +3333,9 @@
   if (VT != MVT::i64)
     return SDValue();
 
+  if (SDValue SrlAdd = combineSrlAdd(N, DAG))
+    return SrlAdd;
+
   if (ShiftAmt < 32)
     return SDValue();
 
diff --git a/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll b/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll
@@ -0,0 +1,13 @@
+; RUN: llc -march=amdgcn -mcpu=gfx908 -verify-machineinstrs < %s | FileCheck %s
+
+; CHECK: v_add_co_u32_e32 v0, vcc, v0, v1
+; CHECK: v_add_co_u32_e32 v0, vcc, vcc_lo, v2
+define i64 @add(i32 %a, i32 %b, i64 %c) {
+entry:
+  %a.zext = zext i32 %a to i64
+  %b.zext = zext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %shr = lshr i64 %add.a.b, 32
+  %add.c.shr = add i64 %c, %shr
+  ret i64 %add.c.shr
+}