diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -3681,11 +3681,13 @@
 
       if (VT == MVT::i32) {
         MulReg = emitSMULL_rr(MVT::i64, LHSReg, RHSReg);
-        unsigned ShiftReg = emitLSR_ri(MVT::i64, MVT::i64, MulReg, 32);
-        MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
-        ShiftReg = fastEmitInst_extractsubreg(VT, ShiftReg, AArch64::sub_32);
-        emitSubs_rs(VT, ShiftReg, MulReg, AArch64_AM::ASR, 31,
-                    /*WantResult=*/false);
+        unsigned MulSubReg =
+            fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
+        // cmp xreg, wreg, sxtw
+        emitAddSub_rx(/*UseAdd=*/false, MVT::i64, MulReg, MulSubReg,
+                      AArch64_AM::SXTW, /*ShiftImm=*/0, /*SetFlags=*/true,
+                      /*WantResult=*/false);
+        MulReg = MulSubReg;
       } else {
         assert(VT == MVT::i64 && "Unexpected value type.");
         // LHSReg and RHSReg cannot be killed by this Mul, since they are
@@ -3709,8 +3711,11 @@
 
       if (VT == MVT::i32) {
         MulReg = emitUMULL_rr(MVT::i64, LHSReg, RHSReg);
-        emitSubs_rs(MVT::i64, AArch64::XZR, MulReg, AArch64_AM::LSR, 32,
-                    /*WantResult=*/false);
+        // tst xreg, #0xffffffff00000000
+        BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+                TII.get(AArch64::ANDSXri), AArch64::XZR)
+            .addReg(MulReg)
+            .addImm(AArch64_AM::encodeLogicalImmediate(0xFFFFFFFF00000000, 64));
        MulReg = fastEmitInst_extractsubreg(VT, MulReg, AArch64::sub_32);
       } else {
         assert(VT == MVT::i64 && "Unexpected value type.");
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2995,50 +2995,25 @@
     CC = AArch64CC::NE;
     bool IsSigned = Op.getOpcode() == ISD::SMULO;
     if (Op.getValueType() == MVT::i32) {
+      // Extend to 64-bits, then perform a 64-bit multiply.
       unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
-      // For a 32 bit multiply with overflow check we want the instruction
-      // selector to generate a widening multiply (SMADDL/UMADDL). For that we
-      // need to generate the following pattern:
-      // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b))
       LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
       RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
       SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
-      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul,
-                                DAG.getConstant(0, DL, MVT::i64));
-      // On AArch64 the upper 32 bits are always zero extended for a 32 bit
-      // operation. We need to clear out the upper 32 bits, because we used a
-      // widening multiply that wrote all 64 bits. In the end this should be a
-      // noop.
-      Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add);
+      Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
+
+      // Check that the result fits into a 32-bit integer.
+      SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
       if (IsSigned) {
-        // The signed overflow check requires more than just a simple check for
-        // any bit set in the upper 32 bits of the result. These bits could be
-        // just the sign bits of a negative number. To perform the overflow
-        // check we have to arithmetic shift right the 32nd bit of the result by
-        // 31 bits. Then we compare the result to the upper 32 bits.
-        SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add,
-                                        DAG.getConstant(32, DL, MVT::i64));
-        UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits);
-        SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value,
-                                        DAG.getConstant(31, DL, MVT::i64));
-        // It is important that LowerBits is last, otherwise the arithmetic
-        // shift will not be folded into the compare (SUBS).
-        SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32);
-        Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
-                       .getValue(1);
+        // cmp xreg, wreg, sxtw
+        SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
+        Overflow =
+            DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
       } else {
-        // The overflow check for unsigned multiply is easy. We only need to
-        // check if any of the upper 32 bits are set. This can be done with a
-        // CMP (shifted register). For that we need to generate the following
-        // pattern:
-        // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32)
-        SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul,
-                                        DAG.getConstant(32, DL, MVT::i64));
-        SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+        // tst xreg, #0xffffffff00000000
+        SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
         Overflow =
-            DAG.getNode(AArch64ISD::SUBS, DL, VTs,
-                        DAG.getConstant(0, DL, MVT::i64),
-                        UpperBits).getValue(1);
+            DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
       }
       break;
     }
diff --git a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
--- a/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-xaluo.ll
@@ -202,8 +202,7 @@
 entry:
 ; CHECK-LABEL: smulo.i32
 ; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT: cset {{w[0-9]+}}, ne
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -242,7 +241,7 @@
 entry:
 ; CHECK-LABEL: umulo.i32
 ; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT: cset {{w[0-9]+}}, ne
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -460,8 +459,7 @@
 entry:
 ; CHECK-LABEL: smulo.select.i32
 ; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT: csel w0, w0, w1, ne
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -473,8 +471,7 @@
 entry:
 ; CHECK-LABEL: smulo.not.i32
 ; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x[[MREG]], #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT: cset w0, eq
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -512,7 +509,7 @@
 entry:
 ; CHECK-LABEL: umulo.select.i32
 ; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT: csel w0, w0, w1, ne
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -524,7 +521,7 @@
 entry:
 ; CHECK-LABEL: umulo.not.i32
 ; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT: cset w0, eq
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %obit = extractvalue {i32, i1} %t, 1
@@ -700,8 +697,7 @@
 entry:
 ; CHECK-LABEL: smulo.br.i32
 ; CHECK: smull x[[MREG:[0-9]+]], w0, w1
-; CHECK-NEXT: lsr x[[SREG:[0-9]+]], x8, #32
-; CHECK-NEXT: cmp w[[SREG]], w[[MREG]], asr #31
+; CHECK-NEXT: cmp x[[MREG]], w[[MREG]], sxtw
 ; CHECK-NEXT: b.eq
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
@@ -755,7 +751,7 @@
 entry:
 ; CHECK-LABEL: umulo.br.i32
 ; CHECK: umull [[MREG:x[0-9]+]], w0, w1
-; CHECK-NEXT: cmp xzr, [[MREG]], lsr #32
+; CHECK-NEXT: tst [[MREG]], #0xffffffff00000000
 ; CHECK-NEXT: b.eq
   %t = call {i32, i1} @llvm.umul.with.overflow.i32(i32 %v1, i32 %v2)
   %val = extractvalue {i32, i1} %t, 0
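For reference, the value-level conditions that the new instruction sequences test can be sketched in standalone C++ (the helper names smulo32/umulo32 below are illustrative only and do not appear in the patch): the signed 32-bit product overflows exactly when the 64-bit smull result differs from the sign-extension of its own low 32 bits, which is what "cmp xreg, wreg, sxtw" checks, and the unsigned product overflows exactly when any of the upper 32 bits of the umull result are set, which is what "tst xreg, #0xffffffff00000000" checks.

#include <cassert>
#include <cstdint>

// Signed check: overflow iff the 64-bit product is not equal to the
// sign-extension of its low 32 bits (cmp xreg, wreg, sxtw; overflow on NE).
bool smulo32(int32_t A, int32_t B) {
  int64_t Mul = static_cast<int64_t>(A) * static_cast<int64_t>(B);
  return Mul != static_cast<int64_t>(static_cast<int32_t>(Mul));
}

// Unsigned check: overflow iff any of the upper 32 bits of the 64-bit
// product are set (tst xreg, #0xffffffff00000000; overflow on NE).
bool umulo32(uint32_t A, uint32_t B) {
  uint64_t Mul = static_cast<uint64_t>(A) * static_cast<uint64_t>(B);
  return (Mul & 0xFFFFFFFF00000000ULL) != 0;
}

int main() {
  assert(!smulo32(46341, 46340)); // 2147441940 fits in i32
  assert(smulo32(46341, 46341));  // 2147488281 does not fit in i32
  assert(!umulo32(65536, 65535)); // 4294901760 fits in u32
  assert(umulo32(65536, 65536));  // 4294967296 does not fit in u32
  return 0;
}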