[InstCombine] Fold shl (zext (i1 X)), C1 --> select X, (1 << C1), 0

A zero-extended i1 is either 0 or 1, so shifting it left by a constant C1
produces either 0 or (1 << C1). Emit that directly as a select on X.
The fold also covers vectors of i1: ConstantInt::get(Ty, 1) splats the
constant 1 across the vector result type, and it is then shifted
element-wise by C1.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -668,6 +668,12 @@
     // (X * C2) << C1 --> X * (C2 << C1)
     if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
       return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+
+    // shl (zext (i1 X)), C1 --> select (X, 1 << C1, 0)
+    if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+      auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
+      return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+    }
   }
 
   return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/and.ll b/llvm/test/Transforms/InstCombine/and.ll
--- a/llvm/test/Transforms/InstCombine/and.ll
+++ b/llvm/test/Transforms/InstCombine/and.ll
@@ -346,8 +346,7 @@
 
 define i32 @test31(i1 %X) {
 ; CHECK-LABEL: @test31(
-; CHECK-NEXT:    [[Y:%.*]] = zext i1 %X to i32
-; CHECK-NEXT:    [[Z:%.*]] = shl nuw nsw i32 [[Y]], 4
+; CHECK-NEXT:    [[Z:%.*]] = select i1 %X, i32 16, i32 0
 ; CHECK-NEXT:    ret i32 [[Z]]
 ;
   %Y = zext i1 %X to i32
diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -1192,4 +1192,34 @@
 }
 
+define i32 @test_shl_zext_bool(i1 %t) {
+; CHECK-LABEL: @test_shl_zext_bool(
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 %t, i32 4, i32 0
+; CHECK-NEXT:    ret i32 [[SEL]]
+;
+  %ext = zext i1 %t to i32
+  %shl = shl i32 %ext, 2
+  ret i32 %shl
+}
+
+define <2 x i32> @test_shl_zext_bool_splat(<2 x i1> %t) {
+; CHECK-LABEL: @test_shl_zext_bool_splat(
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> %t, <2 x i32> <i32 4, i32 4>, <2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <2 x i32> [[SEL]]
+;
+  %ext = zext <2 x i1> %t to <2 x i32>
+  %shl = shl <2 x i32> %ext, <i32 2, i32 2>
+  ret <2 x i32> %shl
+}
+
+define <2 x i32> @test_shl_zext_bool_vec(<2 x i1> %t) {
+; CHECK-LABEL: @test_shl_zext_bool_vec(
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> %t, <2 x i32> <i32 4, i32 8>, <2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <2 x i32> [[SEL]]
+;
+  %ext = zext <2 x i1> %t to <2 x i32>
+  %shl = shl <2 x i32> %ext, <i32 2, i32 3>
+  ret <2 x i32> %shl
+}
+
 define <2 x i64> @test_64_splat_vec(<2 x i32> %t) {
 ; CHECK-LABEL: @test_64_splat_vec(