diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -546,16 +546,6 @@
   }
   case Instruction::Mul: {
     if (DemandedMask.isPowerOf2()) {
-      if (I->getOperand(0) == I->getOperand(1)) {
-        // X * X is odd iff X is odd.
-        if (DemandedMask == 1)
-          return I->getOperand(0);
-
-        // 'Quadratic Reciprocity': mul(x,x) -> 0 if we're only demanding bit[1]
-        if (DemandedMask == 2)
-          return ConstantInt::getNullValue(VTy);
-      }
-
       // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
       // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
       // odd (has LSB set), then the left-shifted low bit of X is the answer.
@@ -568,6 +558,15 @@
         return InsertNewInstWith(Shl, *I);
       }
     }
+    // For a squared value "X * X", the low 2 bits are known: bit[0] equals
+    // X[0] (X * X is odd iff X is odd) and bit[1] is always 0 (every square
+    // is 0 or 1 modulo 4). So if only those bits are demanded, use "X & 1".
+    if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
+      Constant *One = ConstantInt::get(VTy, 1);
+      Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
+      return InsertNewInstWith(And1, *I);
+    }
+
     computeKnownBits(I, Known, Depth, CxtI);
     break;
   }
diff --git a/llvm/test/Transforms/InstCombine/mul-masked-bits.ll b/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
--- a/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
+++ b/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
@@ -113,8 +113,8 @@
 
 define i33 @squared_one_demanded_low_bit(i33 %x) {
 ; CHECK-LABEL: @squared_one_demanded_low_bit(
-; CHECK-NEXT:    [[AND:%.*]] = and i33 [[X:%.*]], 1
-; CHECK-NEXT:    ret i33 [[AND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i33 [[X:%.*]], 1
+; CHECK-NEXT:    ret i33 [[TMP1]]
 ;
   %mul = mul i33 %x, %x
   %and = and i33 %mul, 1
@@ -133,9 +133,8 @@
 
 define i33 @squared_demanded_2_low_bits(i33 %x) {
 ; CHECK-LABEL: @squared_demanded_2_low_bits(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i33 [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[AND:%.*]] = and i33 [[MUL]], 3
-; CHECK-NEXT:    ret i33 [[AND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i33 [[X:%.*]], 1
+; CHECK-NEXT:    ret i33 [[TMP1]]
 ;
   %mul = mul i33 %x, %x
   %and = and i33 %mul, 3
@@ -144,11 +143,24 @@
 
 define <2 x i8> @squared_demanded_2_low_bits_splat(<2 x i8> %x) {
 ; CHECK-LABEL: @squared_demanded_2_low_bits_splat(
-; CHECK-NEXT:    [[MUL:%.*]] = mul <2 x i8> [[X:%.*]], [[X]]
-; CHECK-NEXT:    [[AND:%.*]] = or <2 x i8> [[MUL]], <i8 -4, i8 -4>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i8> [[X:%.*]], <i8 1, i8 1>
+; CHECK-NEXT:    [[AND:%.*]] = or <2 x i8> [[TMP1]], <i8 -4, i8 -4>
 ; CHECK-NEXT:    ret <2 x i8> [[AND]]
 ;
   %mul = mul <2 x i8> %x, %x
   %and = or <2 x i8> %mul, <i8 -4, i8 -4>
   ret <2 x i8> %and
 }
+
+; negative test
+
+define i33 @squared_demanded_3_low_bits(i33 %x) {
+; CHECK-LABEL: @squared_demanded_3_low_bits(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i33 [[X:%.*]], [[X]]
+; CHECK-NEXT:    [[AND:%.*]] = and i33 [[MUL]], 7
+; CHECK-NEXT:    ret i33 [[AND]]
+;
+  %mul = mul i33 %x, %x
+  %and = and i33 %mul, 7
+  ret i33 %and
+}