Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -142,6 +142,7 @@
   SDValue ExpandROT(SDValue Op);
   SDValue ExpandFMINNUM_FMAXNUM(SDValue Op);
   SDValue ExpandAddSubSat(SDValue Op);
+  SDValue ExpandFixedPointMul(SDValue Op);
   SDValue ExpandStrictFPOp(SDValue Op);
 
   /// Implements vector promotion.
@@ -783,6 +784,8 @@
   case ISD::UADDSAT:
   case ISD::SADDSAT:
     return ExpandAddSubSat(Op);
+  case ISD::SMULFIX:
+    return ExpandFixedPointMul(Op);
   case ISD::STRICT_FADD:
   case ISD::STRICT_FSUB:
   case ISD::STRICT_FMUL:
@@ -1218,6 +1221,13 @@
   return DAG.UnrollVectorOp(Op.getNode());
 }
 
+SDValue VectorLegalizer::ExpandFixedPointMul(SDValue Op) {
+  if (SDValue Expanded =
+          TLI.getExpandedFixedPointMultiplication(Op.getNode(), DAG))
+    return Expanded;
+  return DAG.UnrollVectorOp(Op.getNode());
+}
+
 SDValue VectorLegalizer::ExpandStrictFPOp(SDValue Op) {
   EVT VT = Op.getValueType();
   EVT EltVT = VT.getVectorElementType();
Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5358,12 +5358,21 @@ TargetLowering::getExpandedFixedPointMultiplication(SDNode *Node,
                                                     SelectionDAG &DAG) const {
   assert(Node->getOpcode() == ISD::SMULFIX &&
          "Expected opcode to be SMULFIX.");
-  assert(Node->getNumOperands() == 3 &&
-         "Expected signed fixed point multiplication to have 3 operands.");
 
   SDLoc dl(Node);
   SDValue LHS = Node->getOperand(0);
   SDValue RHS = Node->getOperand(1);
+  EVT VT = LHS.getValueType();
+  unsigned Scale = Node->getConstantOperandVal(2);
+
+  // [us]mul.fix(a, b, 0) -> mul(a, b)
+  if (!Scale && isOperationLegalOrCustom(ISD::MUL, VT))
+    return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
+
+  if (VT.isVector())
+    // Do not scalarize here.
+    return SDValue();
+
   assert(LHS.getValueType().isScalarInteger() &&
          "Expected operands to be integers. Vector of int arguments should "
          "already be unrolled.");
@@ -5372,15 +5381,9 @@
          "already be unrolled.");
   assert(LHS.getValueType() == RHS.getValueType() &&
          "Expected both operands to be the same type");
-
-  unsigned Scale = Node->getConstantOperandVal(2);
-  EVT VT = LHS.getValueType();
   assert(Scale < VT.getScalarSizeInBits() &&
          "Expected scale to be less than the number of bits.");
 
-  if (!Scale)
-    return DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
-
   // Get the upper and lower bits of the result.
   SDValue Lo, Hi;
   if (isOperationLegalOrCustom(ISD::SMUL_LOHI, VT)) {
Index: llvm/test/CodeGen/X86/smul_fix.ll
===================================================================
--- llvm/test/CodeGen/X86/smul_fix.ll
+++ llvm/test/CodeGen/X86/smul_fix.ll
@@ -295,32 +295,13 @@
 define <4 x i32> @vec2(<4 x i32> %x, <4 x i32> %y) nounwind {
 ; X64-LABEL: vec2:
 ; X64:       # %bb.0:
-; X64-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[3,1,2,3]
-; X64-NEXT:    movd %xmm2, %eax
-; X64-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[3,1,2,3]
-; X64-NEXT:    movd %xmm2, %ecx
-; X64-NEXT:    imull %eax, %ecx
-; X64-NEXT:    movd %ecx, %xmm2
-; X64-NEXT:    pshufd {{.*#+}} xmm3 = xmm1[2,3,0,1]
-; X64-NEXT:    movd %xmm3, %eax
-; X64-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[2,3,0,1]
-; X64-NEXT:    movd %xmm3, %ecx
-; X64-NEXT:    imull %eax, %ecx
-; X64-NEXT:    movd %ecx, %xmm3
-; X64-NEXT:    punpckldq {{.*#+}} xmm3 = xmm3[0],xmm2[0],xmm3[1],xmm2[1]
-; X64-NEXT:    movd %xmm1, %eax
-; X64-NEXT:    movd %xmm0, %ecx
-; X64-NEXT:    imull %eax, %ecx
-; X64-NEXT:    movd %ecx, %xmm2
-; X64-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,2,3]
-; X64-NEXT:    movd %xmm1, %eax
-; X64-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[1,1,2,3]
-; X64-NEXT:    movd %xmm0, %ecx
-; X64-NEXT:    imull %eax, %ecx
-; X64-NEXT:    movd %ecx, %xmm0
-; X64-NEXT:    punpckldq {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1]
-; X64-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; X64-NEXT:    movdqa %xmm2, %xmm0
+; X64-NEXT:    pshufd {{.*#+}} xmm2 = xmm0[1,1,3,3]
+; X64-NEXT:    pmuludq %xmm1, %xmm0
+; X64-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; X64-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
+; X64-NEXT:    pmuludq %xmm2, %xmm1
+; X64-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
+; X64-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
 ; X64-NEXT:    retq
 ;
 ; X86-LABEL: vec2:
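
Note: the IR body of @vec2 is not included in the hunk above. As a rough sketch (the exact body is an assumption, since only the check lines appear in this diff), the kind of test the updated X64 checks correspond to is an llvm.smul.fix call with a scale of 0, which the new expansion turns into a single vector multiply instead of scalarizing:

declare <4 x i32> @llvm.smul.fix.v4i32(<4 x i32>, <4 x i32>, i32)

; A scale of 0 means no fixed-point shift, so the SMULFIX node can be
; legalized as a plain ISD::MUL on the whole vector.
define <4 x i32> @vec2(<4 x i32> %x, <4 x i32> %y) nounwind {
  %tmp = call <4 x i32> @llvm.smul.fix.v4i32(<4 x i32> %x, <4 x i32> %y, i32 0)
  ret <4 x i32> %tmp
}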