Index: llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -545,10 +545,20 @@ break; } case Instruction::Mul: { - // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. - // If we demand exactly one bit N and we have "X * (C' << N)" where C' is - // odd (has LSB set), then the left-shifted low bit of X is the answer. if (DemandedMask.isPowerOf2()) { + if (I->getOperand(0) == I->getOperand(1)) { + // X * X is odd iff X is odd: bit[0] of X * X equals bit[0] of X. + if (DemandedMask == 1) + return I->getOperand(0); + + // Squares are 0 or 1 (mod 4), so bit[1] of X * X is always 0. + if (DemandedMask == 2) + return ConstantInt::getNullValue(VTy); + } + + // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. + // If we demand exactly one bit N and we have "X * (C' << N)" where C' is + // odd (has LSB set), then the left-shifted low bit of X is the answer. 
unsigned CTZ = DemandedMask.countTrailingZeros(); const APInt *C; if (match(I->getOperand(1), m_APInt(C)) && @@ -557,9 +567,6 @@ Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); return InsertNewInstWith(Shl, *I); } - // 'Quadratic Reciprocity': mul(x,x) -> 0 if we're only demanding bit[1] - if (DemandedMask == 2 && I->getOperand(0) == I->getOperand(1)) - return ConstantInt::getNullValue(VTy); } computeKnownBits(I, Known, Depth, CxtI); break; Index: llvm/test/Transforms/InstCombine/mul-masked-bits.ll =================================================================== --- llvm/test/Transforms/InstCombine/mul-masked-bits.ll +++ llvm/test/Transforms/InstCombine/mul-masked-bits.ll @@ -87,9 +87,8 @@ define i33 @squared_one_demanded_low_bit(i33 %x) { ; CHECK-LABEL: @squared_one_demanded_low_bit( -; CHECK-NEXT: [[MUL:%.*]] = mul i33 [[X:%.*]], [[X]] -; CHECK-NEXT: [[AND:%.*]] = and i33 [[MUL]], 1 -; CHECK-NEXT: ret i33 [[AND]] +; CHECK-NEXT: [[TMP1:%.*]] = and i33 [[X:%.*]], 1 +; CHECK-NEXT: ret i33 [[TMP1]] ; %mul = mul i33 %x, %x %and = and i33 %mul, 1 @@ -98,8 +97,7 @@ define <2 x i8> @squared_one_demanded_low_bit_splat(<2 x i8> %x) { ; CHECK-LABEL: @squared_one_demanded_low_bit_splat( -; CHECK-NEXT: [[MUL:%.*]] = mul <2 x i8> [[X:%.*]], [[X]] -; CHECK-NEXT: [[AND:%.*]] = or <2 x i8> [[MUL]], +; CHECK-NEXT: [[AND:%.*]] = or <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[AND]] ; %mul = mul <2 x i8> %x, %x