Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1140,7 +1140,46 @@ return true; } -static Instruction *foldSelectShuffles(ShuffleVectorInst &Shuf) { +/// The opcode and operands of an alternate (non-canonical) form of a binary +/// operator, as produced by getAlternateBinop() below. +struct BinopElts { + BinaryOperator::BinaryOps Opcode; + Value *Op0; + Value *Op1; +}; + +/// Binops may be transformed into binops with different opcodes and operands. +/// Reverse the usual canonicalization to enable folds with the non-canonical +/// form of the binop. If possible, return the new binop's elements (Op0 is +/// BO's operand 0, Op1 is a constant); otherwise Opcode is 0 (invalid). +static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { + switch (BO->getOpcode()) { + case Instruction::Shl: { + // shl X, C --> mul X, 1 << C + Constant *C; + if (match(BO->getOperand(1), m_Constant(C))) { + C = ConstantExpr::getShl(ConstantInt::get(BO->getType(), 1), C); + return { Instruction::Mul, BO->getOperand(0), C }; + } + break; + } + case Instruction::Or: { + // or X, C --> add X, C (when X and C have no common bits set) + const APInt *C; + if (match(BO->getOperand(1), m_APInt(C)) && + MaskedValueIsZero(BO->getOperand(0), *C, DL)) { + return { Instruction::Add, BO->getOperand(0), BO->getOperand(1) }; + } + break; + } + default: + break; + } + return { static_cast<BinaryOperator::BinaryOps>(0), nullptr, nullptr }; +} + +static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, + const DataLayout &DL) { // Folds under here require the equivalent of a vector select. if (!Shuf.isSelect()) return nullptr; @@ -1164,10 +1203,36 @@ else return nullptr; - // TODO: There are potential folds where the opcodes do not match (mul+shl). - if (B0->getOpcode() != B1->getOpcode()) + // We need matching binops to fold the lanes together.
+ BinaryOperator::BinaryOps Opc0 = B0->getOpcode(); + BinaryOperator::BinaryOps Opc1 = B1->getOpcode(); + bool DropNSW = false; + if (ConstantsAreOp1 && Opc0 != Opc1) { + // Alternate binops replace a constant operand 1, so require that form. + // TODO: "nsw" is dropped if a shift becomes a multiply (it may be wrong + // for a shift amount of BitWidth - 1); per-element checks could keep it. + if (Opc0 == Instruction::Shl || Opc1 == Instruction::Shl) + DropNSW = true; + BinopElts AltB0 = getAlternateBinop(B0, DL); + if (AltB0.Opcode && isa<Constant>(AltB0.Op1)) { + Opc0 = AltB0.Opcode; + C0 = cast<Constant>(AltB0.Op1); + } else { + // Try again with B1. + BinopElts AltB1 = getAlternateBinop(B1, DL); + if (AltB1.Opcode && isa<Constant>(AltB1.Op1)) { + Opc1 = AltB1.Opcode; + C1 = cast<Constant>(AltB1.Op1); + } + } + } + + if (Opc0 != Opc1) return nullptr; + // The opcodes must be the same. Use a new name to make that clear. + BinaryOperator::BinaryOps BOpc = Opc0; + // Remove a binop and the shuffle by rearranging the constant: // shuffle (op X, C0), (op X, C1), M --> op X, C' // shuffle (op C0, X), (op C1, X), M --> op C', X @@ -1179,13 +1244,14 @@ if (B0->isIntDivRem()) NewC = getSafeVectorConstantForIntDivRem(NewC); - BinaryOperator::BinaryOps Opc = B0->getOpcode(); - Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(Opc, X, NewC) : - BinaryOperator::Create(Opc, NewC, X); + Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(BOpc, X, NewC) : + BinaryOperator::Create(BOpc, NewC, X); // Flags are intersected from the 2 source binops.
NewBO->copyIRFlags(B0); NewBO->andIRFlags(B1); + if (DropNSW) + NewBO->setHasNoSignedWrap(false); return NewBO; } @@ -1199,7 +1265,7 @@ LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); - if (Instruction *I = foldSelectShuffles(SVI)) + if (Instruction *I = foldSelectShuffle(SVI, DL)) return I; bool MadeChange = false; Index: test/Transforms/InstCombine/shuffle_select.ll =================================================================== --- test/Transforms/InstCombine/shuffle_select.ll +++ test/Transforms/InstCombine/shuffle_select.ll @@ -239,14 +239,11 @@ ret <4 x double> %t3 } -; FIXME: ; Shift-left with constant shift amount can be converted to mul to enable the fold. define <4 x i32> @mul_shl(<4 x i32> %v0) { ; CHECK-LABEL: @mul_shl( -; CHECK-NEXT: [[T1:%.*]] = mul nuw <4 x i32> [[V0:%.*]], -; CHECK-NEXT: [[T2:%.*]] = shl nuw <4 x i32> [[V0]], -; CHECK-NEXT: [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> +; CHECK-NEXT: [[T3:%.*]] = mul nuw <4 x i32> [[V0:%.*]], ; CHECK-NEXT: ret <4 x i32> [[T3]] ; %t1 = mul nuw <4 x i32> %v0, @@ -257,9 +254,7 @@ define <4 x i32> @shl_mul(<4 x i32> %v0) { ; CHECK-LABEL: @shl_mul( -; CHECK-NEXT: [[T1:%.*]] = shl nsw <4 x i32> [[V0:%.*]], -; CHECK-NEXT: [[T2:%.*]] = mul nsw <4 x i32> [[V0]], -; CHECK-NEXT: [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> +; CHECK-NEXT: [[T3:%.*]] = mul <4 x i32> [[V0:%.*]], ; CHECK-NEXT: ret <4 x i32> [[T3]] ; %t1 = shl nsw <4 x i32> %v0, @@ -273,8 +268,7 @@ define <4 x i32> @mul_is_nop_shl(<4 x i32> %v0) { ; CHECK-LABEL: @mul_is_nop_shl( -; CHECK-NEXT: [[T2:%.*]] = shl <4 x i32> [[V0:%.*]], -; CHECK-NEXT: [[T3:%.*]] = shufflevector <4 x i32> [[V0]], <4 x i32> [[T2]], <4 x i32> +; CHECK-NEXT: [[T3:%.*]] = shl <4 x i32> [[V0:%.*]], ; CHECK-NEXT: ret <4 x i32> [[T3]] ; %t1 = mul <4 x i32> %v0, @@ -283,6 +277,8 @@ ret <4 x i32> %t3 } +; Negative test: shl -> mul needs a constant shift amount (operand 1).
+ define <4 x i32> @shl_mul_not_constant_shift_amount(<4 x i32> %v0) { ; CHECK-LABEL: @shl_mul_not_constant_shift_amount( ; CHECK-NEXT: [[T1:%.*]] = shl <4 x i32> , [[V0:%.*]] @@ -303,9 +299,7 @@ define <4 x i32> @add_or(<4 x i32> %v) { ; CHECK-LABEL: @add_or( ; CHECK-NEXT: [[V0:%.*]] = shl <4 x i32> [[V:%.*]], -; CHECK-NEXT: [[T1:%.*]] = add <4 x i32> [[V0]], -; CHECK-NEXT: [[T2:%.*]] = or <4 x i32> [[V0]], -; CHECK-NEXT: [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> +; CHECK-NEXT: [[T3:%.*]] = add <4 x i32> [[V0]], ; CHECK-NEXT: ret <4 x i32> [[T3]] ; %v0 = shl <4 x i32> %v, ; clear the bottom bits @@ -320,9 +314,7 @@ define <4 x i8> @or_add(<4 x i8> %v) { ; CHECK-LABEL: @or_add( ; CHECK-NEXT: [[V0:%.*]] = lshr <4 x i8> [[V:%.*]], -; CHECK-NEXT: [[T1:%.*]] = or <4 x i8> [[V0]], -; CHECK-NEXT: [[T2:%.*]] = add nuw nsw <4 x i8> [[V0]], -; CHECK-NEXT: [[T3:%.*]] = shufflevector <4 x i8> [[T1]], <4 x i8> [[T2]], <4 x i32> +; CHECK-NEXT: [[T3:%.*]] = add nsw <4 x i8> [[V0]], ; CHECK-NEXT: ret <4 x i8> [[T3]] ; %v0 = lshr <4 x i8> %v, ; clear the top bits