Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -2376,6 +2376,12 @@ unsigned Depth = 0, bool AssumeSingleUse = false) const; + /// Helper wrapper around SimplifyDemandedBits for a low bit mask + bool SimplifyDemandedLowBits(SDValue Op, unsigned BitWidth, DAGCombinerInfo &DCI) const; + + /// Helper wrapper around SimplifyDemandedBits for a high bit mask + bool SimplifyDemandedHighBits(SDValue Op, unsigned BitWidth, DAGCombinerInfo &DCI) const; + /// Determine which of the bits specified in Mask are known to be either zero /// or one and return them in the KnownZero/KnownOne bitsets. virtual void computeKnownBitsForTargetNode(const SDValue Op, Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -471,6 +471,40 @@ return true; } +bool TargetLowering::SimplifyDemandedLowBits(SDValue Op, unsigned BitWidth, + DAGCombinerInfo &DCI) const { + + SelectionDAG &DAG = DCI.DAG; + TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + EVT VT = Op.getValueType(); + APInt DemandedMask = APInt::getLowBitsSet(VT.getSizeInBits(), BitWidth); + APInt KnownZero, KnownOne; + + bool Simplified = SimplifyDemandedBits(Op, DemandedMask, KnownZero, KnownOne, + TLO); + if (Simplified) + DCI.CommitTargetLoweringOpt(TLO); + return Simplified; +} + +bool TargetLowering::SimplifyDemandedHighBits(SDValue Op, unsigned BitWidth, + DAGCombinerInfo &DCI) const { + + SelectionDAG &DAG = DCI.DAG; + TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(), + !DCI.isBeforeLegalizeOps()); + EVT VT = Op.getValueType(); + APInt DemandedMask = APInt::getHighBitsSet(VT.getSizeInBits(), BitWidth); + APInt KnownZero, KnownOne; + + bool Simplified = SimplifyDemandedBits(Op, 
DemandedMask, KnownZero, KnownOne, + TLO); + if (Simplified) + DCI.CommitTargetLoweringOpt(TLO); + return Simplified; +} + /// Look at Op. At this point, we know that only the DemandedMask bits of the /// result of Op are ever used downstream. If we can use this information to /// simplify Op, create a new simplified DAG node and return true, returning the Index: lib/Target/ARM/ARMISelDAGToDAG.cpp =================================================================== --- lib/Target/ARM/ARMISelDAGToDAG.cpp +++ lib/Target/ARM/ARMISelDAGToDAG.cpp @@ -247,8 +247,6 @@ void SelectConcatVector(SDNode *N); void SelectCMPZ(SDNode *N, bool &SwitchEQNEToPLMI); - bool trySMLAWSMULW(SDNode *N); - void SelectCMP_SWAP(SDNode *N); /// SelectInlineAsmMemoryOperand - Implement addressing mode selection for @@ -2559,141 +2557,6 @@ return false; } -static bool SearchSignedMulShort(SDValue SignExt, unsigned *Opc, SDValue &Src1, - bool Accumulate) { - // For SM*WB, we need to some form of sext. - // For SM*WT, we need to search for (sra X, 16) - // Src1 then gets set to X. - if ((SignExt.getOpcode() == ISD::SIGN_EXTEND || - SignExt.getOpcode() == ISD::SIGN_EXTEND_INREG || - SignExt.getOpcode() == ISD::AssertSext) && - SignExt.getValueType() == MVT::i32) { - - *Opc = Accumulate ? ARM::SMLAWB : ARM::SMULWB; - Src1 = SignExt.getOperand(0); - return true; - } - - if (SignExt.getOpcode() != ISD::SRA) - return false; - - ConstantSDNode *SRASrc1 = dyn_cast<ConstantSDNode>(SignExt.getOperand(1)); - if (!SRASrc1 || SRASrc1->getZExtValue() != 16) - return false; - - SDValue Op0 = SignExt.getOperand(0); - - // The sign extend operand for SM*WB could be generated by a shl and ashr. - if (Op0.getOpcode() == ISD::SHL) { - SDValue SHL = Op0; - ConstantSDNode *SHLSrc1 = dyn_cast<ConstantSDNode>(SHL.getOperand(1)); - if (!SHLSrc1 || SHLSrc1->getZExtValue() != 16) - return false; - - *Opc = Accumulate ? ARM::SMLAWB : ARM::SMULWB; - Src1 = Op0.getOperand(0); - return true; - } - *Opc = Accumulate ? 
ARM::SMLAWT : ARM::SMULWT; - Src1 = SignExt.getOperand(0); - return true; -} - -static bool SearchSignedMulLong(SDValue OR, unsigned *Opc, SDValue &Src0, - SDValue &Src1, bool Accumulate) { - // First we look for: - // (add (or (srl ?, 16), (shl ?, 16))) - if (OR.getOpcode() != ISD::OR) - return false; - - SDValue SRL = OR.getOperand(0); - SDValue SHL = OR.getOperand(1); - - if (SRL.getOpcode() != ISD::SRL || SHL.getOpcode() != ISD::SHL) { - SRL = OR.getOperand(1); - SHL = OR.getOperand(0); - if (SRL.getOpcode() != ISD::SRL || SHL.getOpcode() != ISD::SHL) - return false; - } - - ConstantSDNode *SRLSrc1 = dyn_cast<ConstantSDNode>(SRL.getOperand(1)); - ConstantSDNode *SHLSrc1 = dyn_cast<ConstantSDNode>(SHL.getOperand(1)); - if (!SRLSrc1 || !SHLSrc1 || SRLSrc1->getZExtValue() != 16 || - SHLSrc1->getZExtValue() != 16) - return false; - - // The first operands to the shifts need to be the two results from the - // same smul_lohi node. - if ((SRL.getOperand(0).getNode() != SHL.getOperand(0).getNode()) || - SRL.getOperand(0).getOpcode() != ISD::SMUL_LOHI) - return false; - - SDNode *SMULLOHI = SRL.getOperand(0).getNode(); - if (SRL.getOperand(0) != SDValue(SMULLOHI, 0) || - SHL.getOperand(0) != SDValue(SMULLOHI, 1)) - return false; - - // Now we have: - // (add (or (srl (smul_lohi ?, ?), 16), (shl (smul_lohi ?, ?), 16))) - // For SMLAW[B|T] smul_lohi will take a 32-bit and a 16-bit arguments. - // For SMLAWB the 16-bit value will signed extended somehow. - // For SMLAWT only the SRA is required. 
- - // Check both sides of SMUL_LOHI - if (SearchSignedMulShort(SMULLOHI->getOperand(0), Opc, Src1, Accumulate)) { - Src0 = SMULLOHI->getOperand(1); - } else if (SearchSignedMulShort(SMULLOHI->getOperand(1), Opc, Src1, - Accumulate)) { - Src0 = SMULLOHI->getOperand(0); - } else { - return false; - } - return true; -} - -bool ARMDAGToDAGISel::trySMLAWSMULW(SDNode *N) { - if (!Subtarget->hasV6Ops() || - (Subtarget->isThumb() && !Subtarget->hasThumb2())) - return false; - - SDLoc dl(N); - SDValue Src0 = N->getOperand(0); - SDValue Src1 = N->getOperand(1); - SDValue A, B; - unsigned Opc = 0; - - if (N->getOpcode() == ISD::ADD) { - if (Src0.getOpcode() != ISD::OR && Src1.getOpcode() != ISD::OR) - return false; - - SDValue Acc; - if (SearchSignedMulLong(Src0, &Opc, A, B, true)) { - Acc = Src1; - } else if (SearchSignedMulLong(Src1, &Opc, A, B, true)) { - Acc = Src0; - } else { - return false; - } - if (Opc == 0) - return false; - - SDValue Ops[] = { A, B, Acc, getAL(CurDAG, dl), - CurDAG->getRegister(0, MVT::i32) }; - CurDAG->SelectNodeTo(N, Opc, MVT::i32, MVT::Other, Ops); - return true; - } else if (N->getOpcode() == ISD::OR && - SearchSignedMulLong(SDValue(N, 0), &Opc, A, B, false)) { - if (Opc == 0) - return false; - - SDValue Ops[] = { A, B, getAL(CurDAG, dl), - CurDAG->getRegister(0, MVT::i32)}; - CurDAG->SelectNodeTo(N, Opc, MVT::i32, Ops); - return true; - } - return false; -} - /// We've got special pseudo-instructions for these void ARMDAGToDAGISel::SelectCMP_SWAP(SDNode *N) { unsigned Opcode; @@ -2822,11 +2685,6 @@ switch (N->getOpcode()) { default: break; - case ISD::ADD: - case ISD::OR: - if (trySMLAWSMULW(N)) - return; - break; case ISD::WRITE_REGISTER: if (tryWriteRegister(N)) return; Index: lib/Target/ARM/ARMISelLowering.h =================================================================== --- lib/Target/ARM/ARMISelLowering.h +++ lib/Target/ARM/ARMISelLowering.h @@ -175,6 +175,8 @@ VMULLs, // ...signed VMULLu, // ...unsigned + SMULWB, // Signed multiply 
word by half word, bottom + SMULWT, // Signed multiply word by half word, top UMLAL, // 64bit Unsigned Accumulate Multiply SMLAL, // 64bit Signed Accumulate Multiply UMAAL, // 64-bit Unsigned Accumulate Accumulate Multiply Index: lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- lib/Target/ARM/ARMISelLowering.cpp +++ lib/Target/ARM/ARMISelLowering.cpp @@ -1344,6 +1344,8 @@ case ARMISD::UMAAL: return "ARMISD::UMAAL"; case ARMISD::UMLAL: return "ARMISD::UMLAL"; case ARMISD::SMLAL: return "ARMISD::SMLAL"; + case ARMISD::SMULWB: return "ARMISD::SMULWB"; + case ARMISD::SMULWT: return "ARMISD::SMULWT"; case ARMISD::BUILD_VECTOR: return "ARMISD::BUILD_VECTOR"; case ARMISD::BFI: return "ARMISD::BFI"; case ARMISD::VORRIMM: return "ARMISD::VORRIMM"; @@ -1453,6 +1455,40 @@ // Lowering Code //===----------------------------------------------------------------------===// +static bool isSRL16(const SDValue &Op) { + if (Op.getOpcode() != ISD::SRL) + return false; + if (auto Const = dyn_cast<ConstantSDNode>(Op.getOperand(1))) + return Const->getZExtValue() == 16; + return false; +} + +static bool isSRA16(const SDValue &Op) { + if (Op.getOpcode() != ISD::SRA) + return false; + if (auto Const = dyn_cast<ConstantSDNode>(Op.getOperand(1))) + return Const->getZExtValue() == 16; + return false; +} + +static bool isSHL16(const SDValue &Op) { + if (Op.getOpcode() != ISD::SHL) + return false; + if (auto Const = dyn_cast<ConstantSDNode>(Op.getOperand(1))) + return Const->getZExtValue() == 16; + return false; +} + +// Check for a signed 16-bit value. We special case SRA because it makes it +// more simple when also looking for SRAs that aren't sign extending a +// smaller value. Without the check, we'd need to take extra care with +// checking order for some operations. 
+static bool isS16(const SDValue &Op, SelectionDAG &DAG) { + if (isSRA16(Op)) + return isSHL16(Op.getOperand(0)); + return DAG.ComputeNumSignBits(Op) == 17; +} + /// IntCCToARMCC - Convert a DAG integer condition code to an ARM CC static ARMCC::CondCodes IntCCToARMCC(ISD::CondCode CC) { switch (CC) { @@ -9890,6 +9926,67 @@ return SDValue(); } +// Try combining OR nodes to SMULWB, SMULWT. +static SDValue PerformORCombineToSMULWBT(SDNode *OR, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { + if (!Subtarget->hasV6Ops() || + (Subtarget->isThumb() && + (!Subtarget->hasThumb2() || !Subtarget->hasDSP()))) + return SDValue(); + + SDValue SRL = OR->getOperand(0); + SDValue SHL = OR->getOperand(1); + + if (SRL.getOpcode() != ISD::SRL || SHL.getOpcode() != ISD::SHL) { + SRL = OR->getOperand(1); + SHL = OR->getOperand(0); + } + if (!isSRL16(SRL) || !isSHL16(SHL)) + return SDValue(); + + // The first operands to the shifts need to be the two results from the + // same smul_lohi node. + if ((SRL.getOperand(0).getNode() != SHL.getOperand(0).getNode()) || + SRL.getOperand(0).getOpcode() != ISD::SMUL_LOHI) + return SDValue(); + + SDNode *SMULLOHI = SRL.getOperand(0).getNode(); + if (SRL.getOperand(0) != SDValue(SMULLOHI, 0) || + SHL.getOperand(0) != SDValue(SMULLOHI, 1)) + return SDValue(); + + // Now we have: + // (or (srl (smul_lohi ?, ?), 16), (shl (smul_lohi ?, ?), 16)) + // For SMULW[B|T] smul_lohi will take a 32-bit and a 16-bit arguments. + // For SMULWB the 16-bit value will be sign extended somehow. + // For SMULWT only the SRA is required. 
+ // Check both sides of SMUL_LOHI + SDValue OpS16 = SMULLOHI->getOperand(0); + SDValue OpS32 = SMULLOHI->getOperand(1); + + SelectionDAG &DAG = DCI.DAG; + if (!isS16(OpS16, DAG) && !isSRA16(OpS16)) { + OpS16 = OpS32; + OpS32 = SMULLOHI->getOperand(0); + } + + SDLoc dl(OR); + unsigned Opcode = 0; + if (isS16(OpS16, DAG)) + Opcode = ARMISD::SMULWB; + else if (isSRA16(OpS16)) { + Opcode = ARMISD::SMULWT; + OpS16 = OpS16->getOperand(0); + } + else + return SDValue(); + + SDValue Res = DAG.getNode(Opcode, dl, MVT::i32, OpS32, OpS16); + DAG.ReplaceAllUsesOfValueWith(SDValue(OR, 0), Res); + return SDValue(OR, 0); +} + /// PerformORCombine - Target-specific dag combine xforms for ISD::OR static SDValue PerformORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, @@ -9927,6 +10024,8 @@ // fold (or (select cc, 0, c), x) -> (select cc, x, (or, x, c)) if (SDValue Result = combineSelectAndUseCommutative(N, false, DCI)) return Result; + if (SDValue Result = PerformORCombineToSMULWBT(N, DCI, Subtarget)) + return Result; } // The code below optimizes (or (and X, Y), Z). 
@@ -11722,6 +11821,14 @@ return PerformVLDCombine(N, DCI); case ARMISD::BUILD_VECTOR: return PerformARMBUILD_VECTORCombine(N, DCI); + case ARMISD::SMULWB: + if (SimplifyDemandedLowBits(N->getOperand(1), 16, DCI)) + return SDValue(); + break; + case ARMISD::SMULWT: + if (SimplifyDemandedHighBits(N->getOperand(1), 16, DCI)) + return SDValue(); + break; case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) { Index: lib/Target/ARM/ARMInstrInfo.td =================================================================== --- lib/Target/ARM/ARMInstrInfo.td +++ lib/Target/ARM/ARMInstrInfo.td @@ -183,6 +183,9 @@ [SDNPHasChain, SDNPInGlue, SDNPOutGlue, SDNPMayStore, SDNPMayLoad]>; +def ARMsmulwb : SDNode<"ARMISD::SMULWB", SDTIntBinOp, []>; +def ARMsmulwt : SDNode<"ARMISD::SMULWT", SDTIntBinOp, []>; + //===----------------------------------------------------------------------===// // ARM Instruction Predicate Definitions. // @@ -4100,13 +4103,13 @@ def WB : AMulxyI<0b0001001, 0b01, (outs GPR:$Rd), (ins GPR:$Rn, GPR:$Rm), IIC_iMUL16, !strconcat(opc, "wb"), "\t$Rd, $Rn, $Rm", - []>, + [(set GPR:$Rd, (ARMsmulwb GPR:$Rn, GPR:$Rm))]>, Requires<[IsARM, HasV5TE]>, Sched<[WriteMUL16, ReadMUL, ReadMUL]>; def WT : AMulxyI<0b0001001, 0b11, (outs GPR:$Rd), (ins GPR:$Rn, GPR:$Rm), IIC_iMUL16, !strconcat(opc, "wt"), "\t$Rd, $Rn, $Rm", - []>, + [(set GPR:$Rd, (ARMsmulwt GPR:$Rn, GPR:$Rm))]>, Requires<[IsARM, HasV5TE]>, Sched<[WriteMUL16, ReadMUL, ReadMUL]>; } @@ -4153,14 +4156,16 @@ def WB : AMulxyIa<0b0001001, 0b00, (outs GPRnopc:$Rd), (ins GPRnopc:$Rn, GPRnopc:$Rm, GPR:$Ra), IIC_iMAC16, !strconcat(opc, "wb"), "\t$Rd, $Rn, $Rm, $Ra", - []>, + [(set GPRnopc:$Rd, + (add GPR:$Ra, (ARMsmulwb GPRnopc:$Rn, GPRnopc:$Rm)))]>, Requires<[IsARM, HasV5TE, UseMulOps]>, Sched<[WriteMAC16, ReadMUL, ReadMUL, ReadMAC]>; def WT : AMulxyIa<0b0001001, 0b10, (outs GPRnopc:$Rd), (ins GPRnopc:$Rn, GPRnopc:$Rm, GPR:$Ra), IIC_iMAC16, !strconcat(opc, "wt"), "\t$Rd, $Rn, 
$Rm, $Ra", - []>, + [(set GPRnopc:$Rd, + (add GPR:$Ra, (ARMsmulwt GPRnopc:$Rn, GPRnopc:$Rm)))]>, Requires<[IsARM, HasV5TE, UseMulOps]>, Sched<[WriteMAC16, ReadMUL, ReadMUL, ReadMAC]>; } Index: lib/Target/ARM/ARMInstrThumb2.td =================================================================== --- lib/Target/ARM/ARMInstrThumb2.td +++ lib/Target/ARM/ARMInstrThumb2.td @@ -2676,8 +2676,10 @@ def t2SMULTT : T2ThreeRegSMUL<0b001, 0b11, "smultt", [(set rGPR:$Rd, (mul (sra rGPR:$Rn, (i32 16)), (sra rGPR:$Rm, (i32 16))))]>; -def t2SMULWB : T2ThreeRegSMUL<0b011, 0b00, "smulwb", []>; -def t2SMULWT : T2ThreeRegSMUL<0b011, 0b01, "smulwt", []>; +def t2SMULWB : T2ThreeRegSMUL<0b011, 0b00, "smulwb", + [(set rGPR:$Rd, (ARMsmulwb rGPR:$Rn, rGPR:$Rm))]>; +def t2SMULWT : T2ThreeRegSMUL<0b011, 0b01, "smulwt", + [(set rGPR:$Rd, (ARMsmulwt rGPR:$Rn, rGPR:$Rm))]>; def : Thumb2DSPPat<(mul sext_16_node:$Rm, sext_16_node:$Rn), (t2SMULBB rGPR:$Rm, rGPR:$Rn)>; @@ -2712,8 +2714,10 @@ def t2SMLATT : T2FourRegSMLA<0b001, 0b11, "smlatt", [(set rGPR:$Rd, (add rGPR:$Ra, (mul (sra rGPR:$Rn, (i32 16)), (sra rGPR:$Rm, (i32 16)))))]>; -def t2SMLAWB : T2FourRegSMLA<0b011, 0b00, "smlawb", []>; -def t2SMLAWT : T2FourRegSMLA<0b011, 0b01, "smlawt", []>; +def t2SMLAWB : T2FourRegSMLA<0b011, 0b00, "smlawb", + [(set rGPR:$Rd, (add rGPR:$Ra, (ARMsmulwb rGPR:$Rn, rGPR:$Rm)))]>; +def t2SMLAWT : T2FourRegSMLA<0b011, 0b01, "smlawt", + [(set rGPR:$Rd, (add rGPR:$Ra, (ARMsmulwt rGPR:$Rn, rGPR:$Rm)))]>; def : Thumb2DSPMulPat<(add rGPR:$Ra, (mul sext_16_node:$Rn, sext_16_node:$Rm)), (t2SMLABB rGPR:$Rn, rGPR:$Rm, rGPR:$Ra)>; Index: test/CodeGen/ARM/smul.ll =================================================================== --- test/CodeGen/ARM/smul.ll +++ test/CodeGen/ARM/smul.ll @@ -262,3 +262,32 @@ %tmp5 = add i32 %a, %tmp4 ret i32 %tmp5 } + +@global_b = external global i16, align 2 + +define i32 @f22(i32 %a) { +; CHECK-LABEL: f22: +; CHECK: smulwb r0, r0, r1 +; CHECK-THUMBV6-NOT: smulwb + %b = load i16, i16* 
@global_b, align 2 + %sext = sext i16 %b to i64 + %conv = sext i32 %a to i64 + %mul = mul nsw i64 %sext, %conv + %shr37 = lshr i64 %mul, 16 + %conv4 = trunc i64 %shr37 to i32 + ret i32 %conv4 +} + +define i32 @f23(i32 %a, i32 %c) { +; CHECK-LABEL: f23: +; CHECK: smlawb r0, r0, r2, r1 +; CHECK-THUMBV6-NOT: smlawb + %b = load i16, i16* @global_b, align 2 + %sext = sext i16 %b to i64 + %conv = sext i32 %a to i64 + %mul = mul nsw i64 %sext, %conv + %shr49 = lshr i64 %mul, 16 + %conv5 = trunc i64 %shr49 to i32 + %add = add nsw i32 %conv5, %c + ret i32 %add +}