Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7910,6 +7910,73 @@
   return SDValue();
 }
 
+// Transform a right shift of a multiply into a multiply-high.
+// Examples:
+// (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
+// (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
+static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
+                                  const TargetLowering &TLI) {
+  assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
+         "SRL or SRA node is required here!");
+
+  // The shift amount must be a constant (or a constant splat for vectors).
+  ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
+  if (!ShiftAmtSrc)
+    return SDValue();
+
+  // The operation feeding into the shift must be a multiply.
+  SDValue ShiftOperand = N->getOperand(0);
+  if (ShiftOperand.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  // Both operands must be equivalent extend nodes.
+  SDValue LeftOp = ShiftOperand.getOperand(0);
+  SDValue RightOp = ShiftOperand.getOperand(1);
+  bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
+
+  if ((!(IsSignExt || IsZeroExt)) || LeftOp.getOpcode() != RightOp.getOpcode())
+    return SDValue();
+
+  EVT WideVT = LeftOp.getValueType();
+  // Both MUL operands share one type; assert it without an unused local.
+  assert(WideVT == RightOp.getValueType() &&
+         "Cannot have a multiply node with two different operand types.");
+
+  // Both extends must start from the same narrow type.
+  EVT NarrowVT = LeftOp.getOperand(0).getValueType();
+  if (NarrowVT != RightOp.getOperand(0).getValueType())
+    return SDValue();
+
+  // Proceed with the transformation if the wide type is twice as large
+  // as the narrow type.
+  unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
+  if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
+    return SDValue();
+
+  // The shift amount must equal the narrow width. Compare the APInt directly;
+  // getZExtValue() would assert on an amount that does not fit in 64 bits.
+  if (ShiftAmtSrc->getAPIntValue() != NarrowVTSize)
+    return SDValue();
+
+  // If the operation feeding into the MUL is a sign extend (sext),
+  // we use mulhs. Otherwise, zero extends (zext) use mulhu.
+  unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;
+
+  // Only transform into mulh if mulh for the narrow type is cheaper than
+  // a multiply followed by a shift and the target supports this opcode.
+  if (!TLI.isMulhCheaperThanMulShift(NarrowVT) ||
+      !TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
+    return SDValue();
+
+  // srl leaves the high half zero-extended, sra leaves it sign-extended.
+  SDLoc DL(N);
+  SDValue Result = DAG.getNode(MulhOpcode, DL, NarrowVT, LeftOp.getOperand(0),
+                               RightOp.getOperand(0));
+  return N->getOpcode() == ISD::SRA ? DAG.getSExtOrTrunc(Result, DL, WideVT)
+                                    : DAG.getZExtOrTrunc(Result, DL, WideVT);
+}
+
 SDValue DAGCombiner::visitSRA(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -8097,6 +8164,11 @@
   if (SDValue NewSRA = visitShiftByConstant(N))
     return NewSRA;
 
+  // Try to transform this shift into a multiply-high if
+  // it matches the appropriate pattern detected in combineShiftToMULH.
+  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
+    return MULH;
+
   return SDValue();
 }
 
@@ -8322,6 +8394,11 @@
     }
   }
 
+  // Try to transform this shift into a multiply-high if
+  // it matches the appropriate pattern detected in combineShiftToMULH.
+  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
+    return MULH;
+
   return SDValue();
 }
 
Index: llvm/test/CodeGen/PowerPC/combine-to-mulh-shift-amount.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/PowerPC/combine-to-mulh-shift-amount.ll
@@ -0,0 +1,116 @@
+; RUN: llc -verify-machineinstrs -mtriple=powerpc64le-unknown-linux-gnu \
+; RUN:   -mcpu=pwr9 -ppc-asm-full-reg-names -ppc-vsr-nums-as-vr < %s | \
+; RUN:   FileCheck %s
+
+; Negative tests: a right shift of a widened multiply may only be folded into
+; a single multiply-high when the shift amount equals the bit width of the
+; narrow (pre-extension) operand type.
+
+; For i32 operands widened to i64 that width is 32, so a shift by 33 must
+; keep the full mulld and must not be turned into mulhw/mulhwu.
+; For i64 operands widened to i128 that width is 64, so a shift by 63 still
+; needs the low half of the product (mulld) alongside the high half.
+
+define i32 @test_mulhw(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhw:
+; CHECK: mulld
+; CHECK-NOT: mulhw
+; CHECK: blr
+  %ext.a = sext i32 %a to i64
+  %ext.b = sext i32 %b to i64
+  %mul = mul i64 %ext.a, %ext.b
+  %shr = lshr i64 %mul, 33
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define i32 @test_mulhu(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhu:
+; CHECK: mulld
+; CHECK-NOT: mulhwu
+; CHECK: blr
+  %ext.a = zext i32 %a to i64
+  %ext.b = zext i32 %b to i64
+  %mul = mul i64 %ext.a, %ext.b
+  %shr = lshr i64 %mul, 33
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define i64 @test_mulhd(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhd:
+; CHECK: mulhd
+; CHECK: mulld
+; CHECK: blr
+  %ext.a = sext i64 %a to i128
+  %ext.b = sext i64 %b to i128
+  %mul = mul i128 %ext.a, %ext.b
+  %shr = lshr i128 %mul, 63
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+define i64 @test_mulhdu(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhdu:
+; CHECK: mulhdu
+; CHECK: mulld
+; CHECK: blr
+  %ext.a = zext i64 %a to i128
+  %ext.b = zext i64 %b to i128
+  %mul = mul i128 %ext.a, %ext.b
+  %shr = lshr i128 %mul, 63
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+define signext i32 @test_mulhw_signext(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhw_signext:
+; CHECK: mulld
+; CHECK-NOT: mulhw
+; CHECK: blr
+  %ext.a = sext i32 %a to i64
+  %ext.b = sext i32 %b to i64
+  %mul = mul i64 %ext.a, %ext.b
+  %shr = lshr i64 %mul, 33
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define zeroext i32 @test_mulhu_zeroext(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhu_zeroext:
+; CHECK: mulld
+; CHECK-NOT: mulhwu
+; CHECK: blr
+  %ext.a = zext i32 %a to i64
+  %ext.b = zext i32 %b to i64
+  %mul = mul i64 %ext.a, %ext.b
+  %shr = lshr i64 %mul, 33
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define signext i64 @test_mulhd_signext(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhd_signext:
+; CHECK: mulhd
+; CHECK: mulld
+; CHECK: blr
+  %ext.a = sext i64 %a to i128
+  %ext.b = sext i64 %b to i128
+  %mul = mul i128 %ext.a, %ext.b
+  %shr = lshr i128 %mul, 63
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+define zeroext i64 @test_mulhdu_zeroext(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhdu_zeroext:
+; CHECK: mulhdu
+; CHECK: mulld
+; CHECK: blr
+  %ext.a = zext i64 %a to i128
+  %ext.b = zext i64 %b to i128
+  %mul = mul i128 %ext.a, %ext.b
+  %shr = lshr i128 %mul, 63
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
Index: llvm/test/CodeGen/PowerPC/mul-high.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/PowerPC/mul-high.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -verify-machineinstrs -mtriple=powerpc64le-unknown-linux-gnu \
+; RUN:   -mcpu=pwr9 -ppc-asm-full-reg-names -ppc-vsr-nums-as-vr < %s | \
+; RUN:   FileCheck %s
+
+; This test case tests multiply high for i32 and i64. When the values are
+; sign-extended, mulh[d|w] is emitted. When values are zero-extended,
+; mulh[d|w]u is emitted instead.
+
+; The primary goal is transforming the pattern:
+; (shift (mul (ext $a, <wide_type>), (ext $b, <wide_type>)), <narrow_width>)
+; into (mulhs $a, $b) for sign extend, and (mulhu $a, $b) for zero extend,
+; provided that the mulh operation is legal for <narrow_type>.
+; The shift operation can be either the srl or sra operations.
+
+; When no attribute is present on i32, the shift operation is srl.
+define i32 @test_mulhw(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhw:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhw r3, r3, r4
+; CHECK-NEXT:    clrldi r3, r3, 32
+; CHECK-NEXT:    blr
+  %1 = sext i32 %a to i64
+  %2 = sext i32 %b to i64
+  %mul = mul i64 %1, %2
+  %shr = lshr i64 %mul, 32
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define i32 @test_mulhu(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhwu r3, r3, r4
+; CHECK-NEXT:    clrldi r3, r3, 32
+; CHECK-NEXT:    blr
+  %1 = zext i32 %a to i64
+  %2 = zext i32 %b to i64
+  %mul = mul i64 %1, %2
+  %shr = lshr i64 %mul, 32
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define i64 @test_mulhd(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhd:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhd r3, r3, r4
+; CHECK-NEXT:    blr
+  %1 = sext i64 %a to i128
+  %2 = sext i64 %b to i128
+  %mul = mul i128 %1, %2
+  %shr = lshr i128 %mul, 64
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+define i64 @test_mulhdu(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhdu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhdu r3, r3, r4
+; CHECK-NEXT:    blr
+  %1 = zext i64 %a to i128
+  %2 = zext i64 %b to i128
+  %mul = mul i128 %1, %2
+  %shr = lshr i128 %mul, 64
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+; When the signext attribute is present on i32, the shift operation is sra.
+; We are actually transforming (sra (mul sext_in_reg, sext_in_reg)) into mulh.
+define signext i32 @test_mulhw_signext(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhw_signext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhw r3, r3, r4
+; CHECK-NEXT:    extsw r3, r3
+; CHECK-NEXT:    blr
+  %1 = sext i32 %a to i64
+  %2 = sext i32 %b to i64
+  %mul = mul i64 %1, %2
+  %shr = lshr i64 %mul, 32
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define zeroext i32 @test_mulhu_zeroext(i32 %a, i32 %b) {
+; CHECK-LABEL: test_mulhu_zeroext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhwu r3, r3, r4
+; CHECK-NEXT:    clrldi r3, r3, 32
+; CHECK-NEXT:    blr
+  %1 = zext i32 %a to i64
+  %2 = zext i32 %b to i64
+  %mul = mul i64 %1, %2
+  %shr = lshr i64 %mul, 32
+  %tr = trunc i64 %shr to i32
+  ret i32 %tr
+}
+
+define signext i64 @test_mulhd_signext(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhd_signext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhd r3, r3, r4
+; CHECK-NEXT:    blr
+  %1 = sext i64 %a to i128
+  %2 = sext i64 %b to i128
+  %mul = mul i128 %1, %2
+  %shr = lshr i128 %mul, 64
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}
+
+define zeroext i64 @test_mulhdu_zeroext(i64 %a, i64 %b) {
+; CHECK-LABEL: test_mulhdu_zeroext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    mulhdu r3, r3, r4
+; CHECK-NEXT:    blr
+  %1 = zext i64 %a to i128
+  %2 = zext i64 %b to i128
+  %mul = mul i128 %1, %2
+  %shr = lshr i128 %mul, 64
+  %tr = trunc i128 %shr to i64
+  ret i64 %tr
+}