Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1328,11 +1328,13 @@ return nullptr; } -static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - // Canonicalize SExt or Not to the LHS - if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value()))) { +static Instruction *foldBoolSextMaskToSelect(BinaryOperator &And, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = And.getOperand(0), *Op1 = And.getOperand(1); + + // Canonicalize SExt, Not, or BitCast to the LHS. + if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value())) || + match(Op1, m_BitCast(m_Value()))) { std::swap(Op0, Op1); } @@ -1350,7 +1352,29 @@ Value *Zero = Constant::getNullValue(Op0->getType()); return SelectInst::Create(X, Zero, Op1); } - + + // Look through a bitcast of the first 'and' operand and try the same folds as + // above. Bitcast the other operand of the 'and' and bitcast the new select, + // so the types all line up. We're assuming that bitcasts are free and, as + // above, a select is cheaper than the combination of sext + and. + // + // Fold (and (bc (sext bool X)), Op1) --> (bc (select X, (bc Op1), 0)) + // Fold (and (bc ~(sext bool X)), Op1) --> (bc (select X, 0, (bc Op1))) + + Instruction *SextOrNot; + if (match(Op0, m_BitCast(m_Instruction(SextOrNot))) && + match(SextOrNot, + m_CombineOr(m_SExt(m_Value(X)), m_Not(m_SExt(m_Value(X))))) && + X->getType()->getScalarType()->isIntegerTy(1)) { + Type *SrcType = SextOrNot->getType(); + Value *Zero = Constant::getNullValue(SrcType); + Value *Op1BC = Builder.CreateBitCast(Op1, SrcType); + bool IsSext = SextOrNot->getOpcode() == Instruction::SExt; + Value *Select = IsSext ? 
Builder.CreateSelect(X, Op1BC, Zero) + : Builder.CreateSelect(X, Zero, Op1BC); + return CastInst::Create(Instruction::BitCast, Select, Op0->getType()); + } + return nullptr; } @@ -1612,7 +1636,7 @@ } } - if (Instruction *Select = foldBoolSextMaskToSelect(I)) + if (Instruction *Select = foldBoolSextMaskToSelect(I, *Builder)) return Select; return Changed ? &I : nullptr; Index: test/Transforms/InstCombine/logical-select.ll =================================================================== --- test/Transforms/InstCombine/logical-select.ll +++ test/Transforms/InstCombine/logical-select.ll @@ -83,9 +83,9 @@ define <2 x i64> @vecBitcastOp0(<4 x i1> %cmp, <2 x i64> %a) { ; CHECK-LABEL: @vecBitcastOp0( -; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> %cmp to <4 x i32> -; CHECK-NEXT: [[BC:%.*]] = bitcast <4 x i32> [[SEXT]] to <2 x i64> -; CHECK-NEXT: [[AND:%.*]] = and <2 x i64> [[BC]], %a +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> %a to <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> %cmp, <4 x i32> [[TMP1]], <4 x i32> zeroinitializer +; CHECK-NEXT: [[AND:%.*]] = bitcast <4 x i32> [[TMP2]] to <2 x i64> ; CHECK-NEXT: ret <2 x i64> [[AND]] ; %sext = sext <4 x i1> %cmp to <4 x i32> @@ -100,9 +100,9 @@ define <2 x i64> @vecBitcastOp1(<4 x i1> %cmp, <2 x i64> %a) { ; CHECK-LABEL: @vecBitcastOp1( ; CHECK-NEXT: [[A2:%.*]] = shl <2 x i64> %a, -; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> %cmp to <4 x i32> -; CHECK-NEXT: [[BC:%.*]] = bitcast <4 x i32> [[SEXT]] to <2 x i64> -; CHECK-NEXT: [[AND:%.*]] = and <2 x i64> [[A2]], [[BC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[A2]] to <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> %cmp, <4 x i32> [[TMP1]], <4 x i32> zeroinitializer +; CHECK-NEXT: [[AND:%.*]] = bitcast <4 x i32> [[TMP2]] to <2 x i64> ; CHECK-NEXT: ret <2 x i64> [[AND]] ; %a2 = add <2 x i64> %a, %a @@ -116,10 +116,9 @@ define <2 x i64> @vecBitcastNotOp0(<4 x i1> %cmp, <2 x i64> %a) { ; CHECK-LABEL: @vecBitcastNotOp0( -; CHECK-NEXT: [[SEXT:%.*]] = 
sext <4 x i1> %cmp to <4 x i32> -; CHECK-NEXT: [[NEG:%.*]] = xor <4 x i32> [[SEXT]], -; CHECK-NEXT: [[BC:%.*]] = bitcast <4 x i32> [[NEG]] to <2 x i64> -; CHECK-NEXT: [[AND:%.*]] = and <2 x i64> [[BC]], %a +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> %a to <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> %cmp, <4 x i32> zeroinitializer, <4 x i32> [[TMP1]] +; CHECK-NEXT: [[AND:%.*]] = bitcast <4 x i32> [[TMP2]] to <2 x i64> ; CHECK-NEXT: ret <2 x i64> [[AND]] ; %sext = sext <4 x i1> %cmp to <4 x i32> @@ -135,10 +134,9 @@ define <2 x i64> @vecBitcastNotOp1(<4 x i1> %cmp, <2 x i64> %a) { ; CHECK-LABEL: @vecBitcastNotOp1( ; CHECK-NEXT: [[A2:%.*]] = shl <2 x i64> %a, -; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> %cmp to <4 x i32> -; CHECK-NEXT: [[NEG:%.*]] = xor <4 x i32> [[SEXT]], -; CHECK-NEXT: [[BC:%.*]] = bitcast <4 x i32> [[NEG]] to <2 x i64> -; CHECK-NEXT: [[AND:%.*]] = and <2 x i64> [[A2]], [[BC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[A2]] to <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> %cmp, <4 x i32> zeroinitializer, <4 x i32> [[TMP1]] +; CHECK-NEXT: [[AND:%.*]] = bitcast <4 x i32> [[TMP2]] to <2 x i64> ; CHECK-NEXT: ret <2 x i64> [[AND]] ; %a2 = add <2 x i64> %a, %a @@ -153,10 +151,9 @@ define <2 x i64> @vecBitcastSext(<4 x i1> %cmp, <2 x i64> %a) { ; CHECK-LABEL: @vecBitcastSext( -; CHECK-NEXT: [[SEXT:%.*]] = sext <4 x i1> %cmp to <4 x i32> -; CHECK-NEXT: [[NEG1:%.*]] = xor <4 x i32> [[SEXT]], -; CHECK-NEXT: [[NEG:%.*]] = bitcast <4 x i32> [[NEG:%.*]]1 to <2 x i64> -; CHECK-NEXT: [[AND:%.*]] = and <2 x i64> [[NEG]], %a +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> %a to <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> %cmp, <4 x i32> zeroinitializer, <4 x i32> [[TMP1]] +; CHECK-NEXT: [[AND:%.*]] = bitcast <4 x i32> [[TMP2]] to <2 x i64> ; CHECK-NEXT: ret <2 x i64> [[AND]] ; %sext = sext <4 x i1> %cmp to <4 x i32>