diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3246,6 +3246,71 @@
   return SDValue();
 }
 
+// Tries to perform (srl (add (zext a, i2^n), (zext b, i2^n)), 2^(n-1)) ->
+// (uaddo a, b).overflow where a and b have type i2^(n-1)
+static SDValue combineSrlAdd(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  if (N0.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  if (!N1C)
+    return SDValue();
+
+  if (!N1C->getAPIntValue().isPowerOf2())
+    return SDValue();
+
+  SDValue N0LHS = N0.getOperand(0);
+  SDValue N0RHS = N0.getOperand(1);
+  unsigned ShiftAmt = N1C->getZExtValue();
+
+  // Make sure that this add was zero-extended from an int type of `ShiftAmt`
+  // width
+  if (N0LHS.getOpcode() != ISD::ZERO_EXTEND ||
+      N0RHS.getOpcode() != ISD::ZERO_EXTEND)
+    return SDValue();
+
+  SDValue N0NoZextLHS = N0LHS.getOperand(0);
+  SDValue N0NoZextRHS = N0RHS.getOperand(0);
+
+  if ((ShiftAmt != N0NoZextLHS.getValueType().getScalarSizeInBits()) ||
+      (ShiftAmt != N0NoZextRHS.getValueType().getScalarSizeInBits()))
+    return SDValue();
+
+  if ((N0.getValueType().getScalarSizeInBits() / 2) != ShiftAmt)
+    return SDValue();
+
+  SmallVector<SDNode *> Truncs;
+  // Make sure `N0` is only used by `N` and `ShiftAmt`-truncates
+  for (SDNode *U : N0->uses()) {
+    if (U == N)
+      continue;
+
+    if (U->getOpcode() == ISD::TRUNCATE &&
+        U->getValueType(0).getScalarSizeInBits() == ShiftAmt) {
+      Truncs.push_back(U);
+      continue;
+    }
+
+    return SDValue();
+  }
+
+  SDValue UAddO = DAG.getNode(
+      ISD::UADDO, SDLoc(N0), DAG.getVTList(N0NoZextLHS.getValueType(), MVT::i1),
+      N0NoZextLHS, N0NoZextRHS);
+
+  // Replace the `N0` truncate users with `UAddO` since `UAddO` performs the
+  // truncated version of the addition performed by `N0`
+  for (SDNode *Trunc : Truncs) {
+    DAG.ReplaceAllUsesWith(SDValue(Trunc, 0), UAddO.getValue(0));
+  }
+
+  DAG.ReplaceAllUsesOfValueWith(N0.getValue(0), UAddO.getValue(0));
+  return DAG.getZExtOrTrunc(UAddO.getValue(1), SDLoc(N), N->getValueType(0));
+}
+
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
   auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -3275,6 +3340,9 @@
   if (VT != MVT::i64)
     return SDValue();
 
+  if (SDValue SrlAdd = combineSrlAdd(N, DAG))
+    return SrlAdd;
+
   if (ShiftAmt < 32)
     return SDValue();
 
diff --git a/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll b/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/combine-srl-add.ll
@@ -0,0 +1,73 @@
+; RUN: llc -march=amdgcn -mcpu=gfx908 -verify-machineinstrs < %s | FileCheck %s
+
+; CHECK: v_add_co_u32_e32 v0, vcc, v0, v1
+; CHECK-NEXT: v_cndmask_b32_e64 v0, 0, 1, vcc
+; CHECK-NEXT: v_add_co_u32_e32 v0, vcc, v2, v0
+; CHECK-NEXT: v_addc_co_u32_e32 v1, vcc, 0, v3, vcc
+define i64 @basic(i32 %a, i32 %b, i64 %c) {
+entry:
+  %a.zext = zext i32 %a to i64
+  %b.zext = zext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %shr = lshr i64 %add.a.b, 32
+  %add.c.shr = add i64 %c, %shr
+  ret i64 %add.c.shr
+}
+
+; CHECK: v_add_co_u32_e32 v0, vcc, v0, v1
+; CHECK-NEXT: v_addc_co_u32_e32 v0, vcc, 0, v0, vcc
+define i32 @truncUse(i32 %a, i32 %b) {
+entry:
+  %a.zext = zext i32 %a to i64
+  %b.zext = zext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %add.a.b.trunc = trunc i64 %add.a.b to i32
+  %shr = lshr i64 %add.a.b, 32
+  %shr.trunc = trunc i64 %shr to i32
+  %ret = add i32 %add.a.b.trunc, %shr.trunc
+  ret i32 %ret
+}
+
+; CHECK: v_ashrrev_i32_e32 [[ashrV0:v[0-9]+]], 31, v0
+; CHECK-NEXT: v_ashrrev_i32_e32 [[ashrV1:v[0-9]+]], 31, v1
+; CHECK-NEXT: v_add_co_u32_e32 v0, vcc, v0, v1
+; CHECK-NEXT: v_addc_co_u32_e32 v0, vcc, [[ashrV0]], [[ashrV1]], vcc
+; CHECK-NEXT: v_add_co_u32_e32 v0, vcc, v2, v0
+; CHECK-NEXT: v_addc_co_u32_e32 v1, vcc, 0, v3, vcc
+define i64 @sext(i32 %a, i32 %b, i64 %c) {
+entry:
+  %a.zext = sext i32 %a to i64
+  %b.zext = sext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %shr = lshr i64 %add.a.b, 32
+  %add.c.shr = add i64 %c, %shr
+  ret i64 %add.c.shr
+}
+
+; CHECK: v_lshrrev_b64 v[0:1], 31, v[0:1]
+define i64 @shr31(i32 %a, i32 %b, i64 %c) {
+entry:
+  %a.zext = zext i32 %a to i64
+  %b.zext = zext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %shr = lshr i64 %add.a.b, 31
+  %add.c.shr = add i64 %c, %shr
+  ret i64 %add.c.shr
+}
+
+; CHECK: v_add_co_u32_e32 v0, vcc, v0, v1
+; CHECK-NEXT: v_addc_co_u32_e64 v1, {{.+}}, 0, 0, vcc
+; CHECK-NEXT: v_add_co_u32_e32 v2, vcc, v2, v1
+; CHECK-NEXT: v_addc_co_u32_e32 v3, vcc, 0, v3, vcc
+; CHECK-NEXT: v_add_co_u32_e32 v0, vcc, v2, v0
+; CHECK-NEXT: v_addc_co_u32_e32 v1, vcc, v3, v1, vcc
+define i64 @badUse(i32 %a, i32 %b, i64 %c) {
+entry:
+  %a.zext = zext i32 %a to i64
+  %b.zext = zext i32 %b to i64
+  %add.a.b = add i64 %a.zext, %b.zext
+  %shr = lshr i64 %add.a.b, 32
+  %add.c.shr = add i64 %c, %shr
+  %ret = add i64 %add.c.shr, %add.a.b
+  ret i64 %ret
+}
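For context only (not part of the patch): a minimal IR-level sketch of what the combine targets in the `basic` case, assuming `llvm.uadd.with.overflow.i32` as the source-level analogue of the ISD::UADDO node the DAG combine emits. The function name `basic_equivalent` is hypothetical.

; Shifting the zero-extended 64-bit add right by 32 leaves only the carry of
; the 32-bit unsigned add, i.e. the overflow bit of llvm.uadd.with.overflow.i32,
; zero-extended back to i64.
define i64 @basic_equivalent(i32 %a, i32 %b, i64 %c) {
entry:
  %uaddo = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %a, i32 %b)
  %carry = extractvalue { i32, i1 } %uaddo, 1
  %carry.zext = zext i1 %carry to i64
  %add.c.shr = add i64 %c, %carry.zext
  ret i64 %add.c.shr
}

declare { i32, i1 } @llvm.uadd.with.overflow.i32(i32, i32)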