Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1267,6 +1267,47 @@
   return NewShl;
 }
 
+/// Try to fold a decomposed long-multiply sequence into a single mul (MULL).
+static Instruction *SimplifyMull(BinaryOperator &I) {
+  if (!I.getType()->isIntegerTy())
+    return nullptr;
+
+  unsigned BitWidth = I.getType()->getIntegerBitWidth();
+  // Skip odd bitwidths and types wider than 128 bits.
+  if ((BitWidth & 0x1) || (BitWidth > 128))
+    return nullptr;
+
+  unsigned HalfBits = BitWidth >> 1;
+  APInt Max = APInt::getMaxValue(HalfBits);
+  uint64_t HalfMask = Max.getZExtValue();
+
+  // ResLo = (CrossSum << HalfBits) + (YLo * XLo)
+  Value *XLo, *YLo;
+  Value *CrossSum;
+  if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
+                         m_Mul(m_Value(YLo), m_Value(XLo)))))
+    return nullptr;
+
+  // XLo = X & HalfMask
+  // YLo = Y & HalfMask
+  Value *X, *Y;
+  if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) ||
+      !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask))))
+    return nullptr;
+
+  // CrossSum = (X' * (Y >> HalfBits)) + (Y' * (X >> HalfBits))
+  // X' can be either X or XLo in the pattern (and the same for Y').
+  if (match(CrossSum,
+            m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)),
+                            m_CombineOr(m_Specific(X), m_Specific(XLo))),
+
+                    m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)),
+                            m_CombineOr(m_Specific(Y), m_Specific(YLo))))))
+    return BinaryOperator::CreateMul(X, Y);
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1286,6 +1327,9 @@
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
+  if (Instruction *R = SimplifyMull(I))
+    return R;
+
   if (Instruction *R = factorizeMathWithShlOps(I, Builder))
     return R;
 
Index: llvm/test/Transforms/InstCombine/mul.ll
===================================================================
--- llvm/test/Transforms/InstCombine/mul.ll
+++ llvm/test/Transforms/InstCombine/mul.ll
@@ -1574,3 +1574,177 @@
   %r = mul i32 %zx, -16777216 ; -1 << 24
   ret i32 %r
 }
+
+define i64 @mul64_low(i64 noundef %in0, i64 noundef %in1) {
+; CHECK-LABEL: @mul64_low(
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i64 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i64 [[ADDC9]]
+;
+  %In0Lo = and i64 %in0, 4294967295
+  %In0Hi = lshr i64 %in0, 32
+  %In1Lo = and i64 %in1, 4294967295
+  %In1Hi = lshr i64 %in1, 32
+  %m10 = mul i64 %In1Hi, %In0Lo
+  %m01 = mul i64 %In1Lo, %In0Hi
+  %m00 = mul i64 %In1Lo, %In0Lo
+  %addc = add i64 %m10, %m01
+  %shl = shl i64 %addc, 32
+  %addc9 = add i64 %shl, %m00
+  ret i64 %addc9
+}
+
+define i32 @mul32_low_one_extra_user(i32 noundef %in0, i32 noundef %in1) {
+; CHECK-LABEL: @mul32_low_one_extra_user(
+; CHECK-NEXT:    [[IN0LO:%.*]] = and i32 [[IN0:%.*]], 65535
+; CHECK-NEXT:    [[IN0HI:%.*]] = lshr i32 [[IN0]], 16
+; CHECK-NEXT:    [[IN1LO:%.*]] = and i32 [[IN1:%.*]], 65535
+; CHECK-NEXT:    [[IN1HI:%.*]] = lshr i32 [[IN1]], 16
+; CHECK-NEXT:    [[M10:%.*]] = mul nuw i32 [[IN1HI]], [[IN0LO]]
+; CHECK-NEXT:    [[M01:%.*]] = mul nuw i32 [[IN1LO]], [[IN0HI]]
+; CHECK-NEXT:    [[ADDC:%.*]] = add i32 [[M10]], [[M01]]
+; CHECK-NEXT:    call void @use32(i32 [[ADDC]])
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i32 [[IN0]], [[IN1]]
+; CHECK-NEXT:    ret i32 [[ADDC9]]
+;
+  %In0Lo = and i32 %in0, 65535
+  %In0Hi = lshr i32 %in0, 16
+  %In1Lo = and i32 %in1, 65535
+  %In1Hi = lshr i32 %in1, 16
+  %m10 = mul i32 %In1Hi, %In0Lo
+  %m01 = mul i32 %In1Lo, %In0Hi
+  %m00 = mul i32 %In1Lo, %In0Lo
+  %addc = add i32 %m10, %m01
+  call void @use32(i32 %addc)
+  %shl = shl i32 %addc, 16
+  %addc9 = add i32 %shl, %m00
+  ret i32 %addc9
+}
+
+define i32 @mul32_low(i32 noundef %in0, i32 noundef %in1) {
+; CHECK-LABEL: @mul32_low(
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i32 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i32 [[ADDC9]]
+;
+  %In0Lo = and i32 %in0, 65535
+  %In0Hi = lshr i32 %in0, 16
+  %In1Lo = and i32 %in1, 65535
+  %In1Hi = lshr i32 %in1, 16
+  %m10 = mul i32 %In1Hi, %In0Lo
+  %m01 = mul i32 %In1Lo, %In0Hi
+  %m00 = mul i32 %In1Lo, %In0Lo
+  %addc = add i32 %m10, %m01
+  %shl = shl i32 %addc, 16
+  %addc9 = add i32 %shl, %m00
+  ret i32 %addc9
+}
+
+define i16 @mul16_low(i16 %in0, i16 %in1) {
+; CHECK-LABEL: @mul16_low(
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i16 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i16 [[ADDC9]]
+;
+  %In0Lo = and i16 %in0, 255
+  %In0Hi = lshr i16 %in0, 8
+  %In1Lo = and i16 %in1, 255
+  %In1Hi = lshr i16 %in1, 8
+  %m10 = mul i16 %In1Hi, %In0Lo
+  %m01 = mul i16 %In1Lo, %In0Hi
+  %m00 = mul i16 %In1Lo, %In0Lo
+  %addc = add i16 %m10, %m01
+  %shl = shl i16 %addc, 8
+  %addc9 = add i16 %shl, %m00
+  ret i16 %addc9
+}
+
+; https://alive2.llvm.org/ce/z/2BqKLt
+define i8 @mul8_low(i8 %in0, i8 %in1) {
+; CHECK-LABEL: @mul8_low(
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i8 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i8 [[ADDC9]]
+;
+  %In0Lo = and i8 %in0, 15
+  %In0Hi = lshr i8 %in0, 4
+  %In1Lo = and i8 %in1, 15
+  %In1Hi = lshr i8 %in1, 4
+  %m10 = mul i8 %In1Hi, %In0Lo
+  %m01 = mul i8 %In1Lo, %In0Hi
+  %m00 = mul i8 %In1Lo, %In0Lo
+  %addc = add i8 %m10, %m01
+  %shl = shl i8 %addc, 4
+  %addc9 = add i8 %shl, %m00
+  ret i8 %addc9
+}
+
+; Negative case: skip odd bitwidth types.
+define i9 @mul9_low(i9 %in0, i9 %in1) {
+; CHECK-LABEL: @mul9_low(
+; CHECK-NEXT:    [[IN0LO:%.*]] = and i9 [[IN0:%.*]], 15
+; CHECK-NEXT:    [[IN0HI:%.*]] = lshr i9 [[IN0]], 4
+; CHECK-NEXT:    [[IN1LO:%.*]] = and i9 [[IN1:%.*]], 15
+; CHECK-NEXT:    [[IN1HI:%.*]] = lshr i9 [[IN1]], 4
+; CHECK-NEXT:    [[M10:%.*]] = mul nuw i9 [[IN1HI]], [[IN0LO]]
+; CHECK-NEXT:    [[M01:%.*]] = mul nuw i9 [[IN1LO]], [[IN0HI]]
+; CHECK-NEXT:    [[M00:%.*]] = mul nuw nsw i9 [[IN1LO]], [[IN0LO]]
+; CHECK-NEXT:    [[ADDC:%.*]] = add i9 [[M10]], [[M01]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i9 [[ADDC]], 4
+; CHECK-NEXT:    [[ADDC9:%.*]] = add i9 [[SHL]], [[M00]]
+; CHECK-NEXT:    ret i9 [[ADDC9]]
+;
+  %In0Lo = and i9 %in0, 15
+  %In0Hi = lshr i9 %in0, 4
+  %In1Lo = and i9 %in1, 15
+  %In1Hi = lshr i9 %in1, 4
+  %m10 = mul i9 %In1Hi, %In0Lo
+  %m01 = mul i9 %In1Lo, %In0Hi
+  %m00 = mul i9 %In1Lo, %In0Lo
+  %addc = add i9 %m10, %m01
+  %shl = shl i9 %addc, 4
+  %addc9 = add i9 %shl, %m00
+  ret i9 %addc9
+}
+
+; Negative case: skip vector types.
+define <2 x i8> @mul_v2i8_low(<2 x i8> %in0, <2 x i8> %in1) {
+; CHECK-LABEL: @mul_v2i8_low(
+; CHECK-NEXT:    [[IN0LO:%.*]] = and <2 x i8> [[IN0:%.*]], <i8 15, i8 15>
+; CHECK-NEXT:    [[IN0HI:%.*]] = lshr <2 x i8> [[IN0]], <i8 4, i8 4>
+; CHECK-NEXT:    [[IN1LO:%.*]] = and <2 x i8> [[IN1:%.*]], <i8 15, i8 15>
+; CHECK-NEXT:    [[IN1HI:%.*]] = lshr <2 x i8> [[IN1]], <i8 4, i8 4>
+; CHECK-NEXT:    [[M10:%.*]] = mul <2 x i8> [[IN1HI]], [[IN0]]
+; CHECK-NEXT:    [[M01:%.*]] = mul <2 x i8> [[IN0HI]], [[IN1]]
+; CHECK-NEXT:    [[M00:%.*]] = mul nuw <2 x i8> [[IN1LO]], [[IN0LO]]
+; CHECK-NEXT:    [[ADDC:%.*]] = add <2 x i8> [[M10]], [[M01]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i8> [[ADDC]], <i8 4, i8 4>
+; CHECK-NEXT:    [[ADDC9:%.*]] = add <2 x i8> [[SHL]], [[M00]]
+; CHECK-NEXT:    ret <2 x i8> [[ADDC9]]
+;
+  %In0Lo = and <2 x i8> %in0, <i8 15, i8 15>
+  %In0Hi = lshr <2 x i8> %in0, <i8 4, i8 4>
+  %In1Lo = and <2 x i8> %in1, <i8 15, i8 15>
+  %In1Hi = lshr <2 x i8> %in1, <i8 4, i8 4>
+  %m10 = mul <2 x i8> %In1Hi, %In0Lo
+  %m01 = mul <2 x i8> %In1Lo, %In0Hi
+  %m00 = mul <2 x i8> %In1Lo, %In0Lo
+  %addc = add <2 x i8> %m10, %m01
+  %shl = shl <2 x i8> %addc, <i8 4, i8 4>
+  %addc9 = add <2 x i8> %shl, %m00
+  ret <2 x i8> %addc9
+}
+
+define i128 @mul128_low(i128 %in0, i128 %in1) {
+; CHECK-LABEL: @mul128_low(
+; CHECK-NEXT:    [[ADDC9:%.*]] = mul i128 [[IN0:%.*]], [[IN1:%.*]]
+; CHECK-NEXT:    ret i128 [[ADDC9]]
+;
+  %In0Lo = and i128 %in0, 18446744073709551615
+  %In0Hi = lshr i128 %in0, 64
+  %In1Lo = and i128 %in1, 18446744073709551615
+  %In1Hi = lshr i128 %in1, 64
+  %m10 = mul i128 %In1Hi, %In0Lo
+  %m01 = mul i128 %In1Lo, %In0Hi
+  %m00 = mul i128 %In1Lo, %In0Lo
+  %addc = add i128 %m10, %m01
+  %shl = shl i128 %addc, 64
+  %addc9 = add i128 %shl, %m00
+  ret i128 %addc9
+}
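Note (not part of the patch): the fold recognizes the IR produced by the common hand-written "low half of a long multiplication" idiom, as in the minimal C sketch below, which mirrors the @mul64_low test. Writing in0 = in0_hi*2^32 + in0_lo and in1 = in1_hi*2^32 + in1_lo, the full product is in0_hi*in1_hi*2^64 + (in0_hi*in1_lo + in0_lo*in1_hi)*2^32 + in0_lo*in1_lo; modulo 2^64 the first term vanishes, which is why the matched pattern is equivalent to a single i64 mul.

#include <stdint.h>

/* Illustrative sketch only: the low 64 bits of in0 * in1 computed from
 * 32-bit halves. The new SimplifyMull fold collapses the IR generated for
 * this body back into a single i64 mul. */
uint64_t mul64_low(uint64_t in0, uint64_t in1) {
  uint64_t in0_lo = in0 & 0xffffffffu, in0_hi = in0 >> 32;
  uint64_t in1_lo = in1 & 0xffffffffu, in1_hi = in1 >> 32;
  uint64_t m10 = in1_hi * in0_lo;   /* cross product: high(in1) * low(in0) */
  uint64_t m01 = in1_lo * in0_hi;   /* cross product: low(in1) * high(in0) */
  uint64_t m00 = in1_lo * in0_lo;   /* low(in1) * low(in0) */
  return ((m10 + m01) << 32) + m00; /* == in0 * in1 (mod 2^64) */
}

The same identity at half the respective bitwidth is what the i8, i16, i32, and i128 tests exercise.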