Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1286,6 +1286,9 @@
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = SimplifyMull(I))
+    return replaceInstUsesWith(I, V);
+
   if (Instruction *R = factorizeMathWithShlOps(I, Builder))
     return R;
 
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -546,6 +546,9 @@
   /// value, or null if it didn't simplify.
   Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
 
+  /// Tries to fold a half-width multiply/add sequence into one full mul (MULL).
+  Value *SimplifyMull(BinaryOperator &I);
+
   /// Tries to simplify add operations using the definition of remainder.
   ///
   /// The definition of remainder is X % C = X - (X / C ) * C.
The add
Index: llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -852,6 +852,57 @@
   return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
 }
 
+Value *InstCombinerImpl::SimplifyMull(BinaryOperator &I) {
+  // Fold the manual "schoolbook" expansion of a multiply built from
+  // half-width pieces back into a single full-width mul.  The pattern
+  // computes only the low half of the double-wide product, so the
+  // Hi*Hi partial product never appears in it.
+  if (!I.getType()->isIntegerTy())
+    return nullptr;
+
+  unsigned BitWidth = I.getType()->getIntegerBitWidth();
+  // The fold is only valid when the width splits into two equal halves:
+  // for an odd width the discarded Hi*Hi term still reaches bit
+  // (BitWidth - 1), so the pattern is not equivalent to a full mul.
+  if (BitWidth % 2 != 0)
+    return nullptr;
+
+  Value *In0, *In1;
+  Value *M01, *M10, *M00, *Addc;
+  unsigned HalfBits = BitWidth / 2;
+  // Build the low-half mask as an APInt so widths above 64 bits do not
+  // shift past the width of a 64-bit integer (which would be UB).
+  APInt HalfMask = APInt::getLowBitsSet(BitWidth, HalfBits);
+
+  // addc = m01 + m10;
+  // ResLo = m00 + (addc << HalfBits);
+  bool IsMulLow =
+      match(&I, m_c_Add(m_Value(M00),
+                        m_Shl(m_Value(Addc), m_SpecificInt(HalfBits)))) &&
+      match(Addc, m_c_Add(m_Value(M01), m_Value(M10)));
+
+  // In0Lo = in0 & HalfMask; In0Hi = in0 >> HalfBits;
+  // In1Lo = in1 & HalfMask; In1Hi = in1 >> HalfBits;
+  // m01 = In1Hi * In0Lo; m10 = In1Lo * In0Hi; m00 = In1Lo * In0Lo;
+  // The Lo operands may appear either masked or unmasked in m01/m10;
+  // the shl by HalfBits discards the difference.
+  if (IsMulLow &&
+      match(M00, m_c_Mul(m_And(m_Value(In1), m_SpecificInt(HalfMask)),
+                         m_And(m_Value(In0), m_SpecificInt(HalfMask)))) &&
+      (match(M01, m_c_Mul(m_LShr(m_Specific(In1), m_SpecificInt(HalfBits)),
+                          m_Specific(In0))) ||
+       match(M01, m_c_Mul(m_LShr(m_Specific(In1), m_SpecificInt(HalfBits)),
+                          m_And(m_Specific(In0), m_SpecificInt(HalfMask))))) &&
+      (match(M10, m_c_Mul(m_LShr(m_Specific(In0), m_SpecificInt(HalfBits)),
+                          m_Specific(In1))) ||
+       match(M10, m_c_Mul(m_LShr(m_Specific(In0), m_SpecificInt(HalfBits)),
+                          m_And(m_Specific(In1), m_SpecificInt(HalfMask)))))) {
+    return Builder.CreateMul(In0, In1);
+  }
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
                                                         Value *LHS,
                                                         Value *RHS) {
Index: llvm/test/Transforms/InstCombine/mul.ll
===================================================================
--- llvm/test/Transforms/InstCombine/mul.ll
+++ llvm/test/Transforms/InstCombine/mul.ll
@@ -1574,3 +1574,103 @@
   %r = mul i32 %zx, -16777216 ; -1 << 24
   ret i32 %r
 }
+
+define i64 @mul64_low(i64 noundef %in0, i64 noundef %in1) {
+; CHECK-LABEL: @mul64_low(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+  %In0Lo = and i64 %in0, 4294967295
+  %In0Hi = lshr i64 %in0, 32
+  %In1Lo = and i64 %in1, 4294967295
+  %In1Hi = lshr i64 %in1, 32
+  %m10 = mul i64 %In1Hi, %In0Lo
+  %m01 = mul i64 %In1Lo, %In0Hi
+  %m00 = mul i64 %In1Lo, %In0Lo
+  %addc = add i64 %m10, %m01
+  %shl = shl i64 %addc, 32
+  %addc9 = add i64 %shl, %m00
+  ret i64 %addc9
+}
+
+define i32 @mul32_low_one_extra_user(i32 noundef %in0, i32 noundef %in1) {
+; CHECK-LABEL: @mul32_low_one_extra_user(
+; CHECK-NEXT:    [[IN0LO:%.*]] = and i32 [[IN0:%.*]], 65535
+; CHECK-NEXT:    [[IN0HI:%.*]] = lshr i32 [[IN0]], 16
+; CHECK-NEXT:    [[IN1LO:%.*]] = and i32 [[IN1:%.*]], 65535
+; CHECK-NEXT:    [[IN1HI:%.*]] = lshr i32 [[IN1]], 16
+; CHECK-NEXT:    [[M10:%.*]] = mul nuw i32 [[IN1HI]], [[IN0LO]]
+; CHECK-NEXT:    [[M01:%.*]] = mul nuw i32 [[IN1LO]], [[IN0HI]]
+; CHECK-NEXT:    [[ADDC:%.*]] = add i32 [[M10]], [[M01]]
+; CHECK-NEXT:    call void @use32(i32 [[ADDC]])
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[IN0]], [[IN1]]
+; CHECK-NEXT:    ret i32 [[TMP1]]
+;
+  %In0Lo = and i32 %in0, 65535
+  %In0Hi = lshr i32 %in0, 16
+  %In1Lo = and i32 %in1, 65535
+  %In1Hi = lshr i32 %in1, 16
+  %m10 = mul i32 %In1Hi, %In0Lo
+  %m01 = mul i32 %In1Lo, %In0Hi
+  %m00 = mul i32 %In1Lo, %In0Lo
+  %addc = add i32 %m10, %m01
+  call void @use32(i32 %addc)
+  %shl = shl i32 %addc, 16
+  %addc9 = add i32 %shl, %m00
+  ret i32 %addc9
+}
+
+define i32 @mul32_low(i32 noundef %in0, i32 noundef %in1) {
+; CHECK-LABEL: @mul32_low(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i32 [[TMP1]]
+;
+  %In0Lo = and i32 %in0, 65535
+  %In0Hi = lshr i32 %in0, 16
+  %In1Lo = and i32 %in1, 65535
+  %In1Hi = lshr i32 %in1, 16
+  %m10 = mul i32 %In1Hi, %In0Lo
+  %m01 = mul i32 %In1Lo, %In0Hi
+  %m00 = mul i32 %In1Lo, %In0Lo
+  %addc = add i32 %m10, %m01
+  %shl = shl i32 %addc, 16
+  %addc9 = add i32 %shl, %m00
+  ret i32 %addc9
+}
+
+define i16 @mul16_low(i16 %in0, i16 %in1) {
+; CHECK-LABEL: @mul16_low(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i16 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i16 [[TMP1]]
+;
+  %In0Lo = and i16 %in0, 255
+  %In0Hi = lshr i16 %in0, 8
+  %In1Lo = and i16 %in1, 255
+  %In1Hi = lshr i16 %in1, 8
+  %m10 = mul i16 %In1Hi, %In0Lo
+  %m01 = mul i16 %In1Lo, %In0Hi
+  %m00 = mul i16 %In1Lo, %In0Lo
+  %addc = add i16 %m10, %m01
+  %shl = shl i16 %addc, 8
+  %addc9 = add i16 %shl, %m00
+  ret i16 %addc9
+}
+
+; https://alive2.llvm.org/ce/z/2BqKLt
+define i8 @mul8_low(i8 %in0, i8 %in1) {
+; CHECK-LABEL: @mul8_low(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i8 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i8 [[TMP1]]
+;
+  %In0Lo = and i8 %in0, 15
+  %In0Hi = lshr i8 %in0, 4
+  %In1Lo = and i8 %in1, 15
+  %In1Hi = lshr i8 %in1, 4
+  %m10 = mul i8 %In1Hi, %In0Lo
+  %m01 = mul i8 %In1Lo, %In0Hi
+  %m00 = mul i8 %In1Lo, %In0Lo
+  %addc = add i8 %m10, %m01
+  %shl = shl i8 %addc, 4
+  %addc9 = add i8 %shl, %m00
+  ret i8 %addc9
+}