Index: ../../llvm/lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- ../../llvm/lib/Target/ARM/ARMISelLowering.cpp +++ ../../llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -7943,10 +7943,12 @@ TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *Subtarget) { - if (Subtarget->isThumb1Only()) return SDValue(); + if (Subtarget->isThumb1Only()) + return SDValue(); // Only perform the checks after legalize when the pattern is available. - if (DCI.isBeforeLegalize()) return SDValue(); + if (DCI.isBeforeLegalize()) + return SDValue(); // Look for multiply add opportunities. // The pattern is a ISD::UMUL_LOHI followed by two add nodes, where @@ -7982,11 +7984,12 @@ if (AddcOp0->getOpcode() != ISD::UMUL_LOHI && AddcOp0->getOpcode() != ISD::SMUL_LOHI && AddcOp1->getOpcode() != ISD::UMUL_LOHI && - AddcOp1->getOpcode() != ISD::SMUL_LOHI) + AddcOp1->getOpcode() != ISD::SMUL_LOHI && + AddcOp0->getOpcode() != ISD::MUL && AddcOp1->getOpcode() != ISD::MUL) return SDValue(); // Look for the glued ADDE. - SDNode* AddeNode = AddcNode->getGluedUser(); + SDNode *AddeNode = AddcNode->getGluedUser(); if (!AddeNode) return SDValue(); @@ -8006,35 +8009,61 @@ if (AddeOp0.getNode() == AddeOp1.getNode()) return SDValue(); - // Find the MUL_LOHI node walking up ADDE's operands. - bool IsLeftOperandMUL = false; - SDValue MULOp = findMUL_LOHI(AddeOp0); - if (MULOp == SDValue()) - MULOp = findMUL_LOHI(AddeOp1); - else - IsLeftOperandMUL = true; - if (MULOp == SDValue()) - return SDValue(); + // Figure out the high and low input values to the MLAL node. + SDValue *HiAdd = nullptr; + SDValue *LoMul = nullptr; + SDValue *LowAdd = nullptr; + unsigned Opc, FinalOpc; + SDValue MULOp; + if (AddcOp0->getOpcode() == ISD::MUL || AddcOp1->getOpcode() == ISD::MUL) { + bool IsLeftOperandCopyFromReg = false; - // Figure out the right opcode. - unsigned Opc = MULOp->getOpcode(); - unsigned FinalOpc = (Opc == ISD::SMUL_LOHI) ? 
ARMISD::SMLAL : ARMISD::UMLAL; + // Proceed only if some ADDE operand is SRA, CopyFromReg or Constant. + if (AddeOp0->getOpcode() != ISD::SRA && AddeOp1->getOpcode() != ISD::SRA && + AddeOp0->getOpcode() != ISD::CopyFromReg && + AddeOp1->getOpcode() != ISD::CopyFromReg && + AddeOp0->getOpcode() != ISD::Constant && + AddeOp1->getOpcode() != ISD::Constant) + return SDValue(); - // Figure out the high and low input values to the MLAL node. - SDValue* HiAdd = nullptr; - SDValue* LoMul = nullptr; - SDValue* LowAdd = nullptr; + if (AddeOp0->getOpcode() == ISD::CopyFromReg) { + IsLeftOperandCopyFromReg = true; + } - // Ensure that ADDE is from high result of ISD::SMUL_LOHI. - if ((AddeOp0 != MULOp.getValue(1)) && (AddeOp1 != MULOp.getValue(1))) - return SDValue(); + // Figure out the right opcode. + SDValue MULOp = (AddcOp0->getOpcode() == ISD::MUL) ? AddcOp0 : AddcOp1; + EVT VT = MULOp.getValueType(); + unsigned VTBits = VT.getSizeInBits(); + FinalOpc = (VTBits & 0x8000) ? ARMISD::SMLAL : ARMISD::UMLAL; + if (IsLeftOperandCopyFromReg) + HiAdd = &AddeOp0; + else + HiAdd = &AddeOp1; + } else { + // Find the MUL_LOHI node walking up ADDE's operands. + bool IsLeftOperandMUL = false; + MULOp = findMUL_LOHI(AddeOp0); + if (MULOp == SDValue()) + MULOp = findMUL_LOHI(AddeOp1); + else + IsLeftOperandMUL = true; + if (MULOp == SDValue()) + return SDValue(); - if (IsLeftOperandMUL) - HiAdd = &AddeOp1; - else - HiAdd = &AddeOp0; + // Figure out the right opcode. + Opc = MULOp->getOpcode(); + FinalOpc = (Opc == ISD::SMUL_LOHI) ? ARMISD::SMLAL : ARMISD::UMLAL; + // Ensure that ADDE is from high result of ISD::SMUL_LOHI. + if ((AddeOp0 != MULOp.getValue(1)) && (AddeOp1 != MULOp.getValue(1))) + return SDValue(); + if (IsLeftOperandMUL) + HiAdd = &AddeOp1; + else + HiAdd = &AddeOp0; + } + // Ensure that LoMul and LowAdd are taken from correct ISD::SMUL_LOHI node // whose low result is fed to the ADDC we are checking. 
@@ -8060,7 +8089,7 @@ Ops.push_back(*LowAdd); Ops.push_back(*HiAdd); - SDValue MLALNode = DAG.getNode(FinalOpc, SDLoc(AddcNode), + SDValue MLALNode = DAG.getNode(FinalOpc, SDLoc(AddcNode), DAG.getVTList(MVT::i32, MVT::i32), Ops); // Replace the ADDs' nodes uses by the MLA node's values. @@ -8078,11 +8107,10 @@ /// PerformADDCCombine - Target-specific dag combine transform from /// ISD::ADDC, ISD::ADDE, and ISD::MUL_LOHI to MLAL. static SDValue PerformADDCCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - const ARMSubtarget *Subtarget) { + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { return AddCombineTo64bitMLAL(N, DCI, Subtarget); - } /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with @@ -8089,9 +8117,10 @@ /// operands N0 and N1. This is a helper for PerformADDCombine that is /// called with the default operands, and if that fails, with commuted /// operands. -static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, - TargetLowering::DAGCombinerInfo &DCI, - const ARMSubtarget *Subtarget){ +static SDValue +PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { // Attempt to create vpaddl for this add. 
SDValue Result = AddCombineToVPADDL(N, N0, N1, DCI, Subtarget); Index: ../../llvm/test/CodeGen/ARM/longMAC.ll =================================================================== --- ../../llvm/test/CodeGen/ARM/longMAC.ll +++ ../../llvm/test/CodeGen/ARM/longMAC.ll @@ -116,3 +116,29 @@ ret i64 %add } +define i64 @MACLongTest9(i64 %a, i8* nocapture readonly %b, i8* nocapture readonly %c) { +;CHECK-LABEL: MACLongTest9: +;CHECK: umlal + %1 = load i8, i8* %b + %conv = zext i8 %1 to i32 + %2 = load i8, i8* %c + %conv1 = zext i8 %2 to i32 + %mul = mul nuw nsw i32 %conv1, %conv + %conv2.3 = zext i32 %mul to i64 + %add = add nsw i64 %conv2.3, %a + ret i64 %add +} + +define i64 @MACLongTest10(i64 %a, i8* nocapture readonly %b, i16* nocapture readonly %c) { +;CHECK-LABEL: MACLongTest10: +;CHECK: umlal + %1 = load i8, i8* %b + %conv = zext i8 %1 to i32 + %2 = load i16, i16* %c + %conv1 = zext i16 %2 to i32 + %mul = mul nuw nsw i32 %conv1, %conv + %conv2.3 = zext i32 %mul to i64 + %add = add i64 %conv2.3, %a + ret i64 %add +} +