diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2312,42 +2312,43 @@ return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } -/// Try to reduce a rotate pattern that includes a compare and select into a -/// funnel shift intrinsic. Example: +/// Try to reduce a funnel/rotate pattern that includes a compare and select +/// into a funnel shift intrinsic. Example: /// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) /// --> call llvm.fshl.i32(a, a, b) -static Instruction *foldSelectRotate(SelectInst &Sel, - InstCombiner::BuilderTy &Builder) { - // The false value of the select must be a rotate of the true value. +/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c))) +/// --> call llvm.fshl.i32(a, b, c) +/// fshr32(a, b, c) --> (c == 0 ? b : ((a << (32 - c)) | (b >> c))) +/// --> call llvm.fshr.i32(a, b, c) +static Instruction *foldSelectFunnelShift(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + // This must be a power-of-2 type for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) + return nullptr; + BinaryOperator *Or0, *Or1; if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) return nullptr; - Value *TVal = Sel.getTrueValue(); - Value *SA0, *SA1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), + Value *SV0, *SV1, *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0), m_ZExtOrSelf(m_Value(SA0))))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), + !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1), m_ZExtOrSelf(m_Value(SA1))))) || Or0->getOpcode() == Or1->getOpcode()) return nullptr; - // Canonicalize to or(shl(TVal, SA0), lshr(TVal, SA1)). + // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)). 
if (Or0->getOpcode() == BinaryOperator::LShr) { std::swap(Or0, Or1); + std::swap(SV0, SV1); std::swap(SA0, SA1); } assert(Or0->getOpcode() == BinaryOperator::Shl && Or1->getOpcode() == BinaryOperator::LShr && "Illegal or(shift,shift) pair"); - // We should now have this pattern: - // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) - // This must be a power-of-2 rotate for a bitmasking transform to be valid. - unsigned Width = Sel.getType()->getScalarSizeInBits(); - if (!isPowerOf2_32(Width)) - return nullptr; - // Check the shift amounts to see if they are an opposite pair. Value *ShAmt; if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0))))) @@ -2357,6 +2358,15 @@ else return nullptr; + // We should now have this pattern: + // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1)) + // The false value of the select must be a funnel-shift of the true value: + // IsFshl -> TVal must be SV0, else TVal must be SV1. + bool IsFshl = (ShAmt == SA0); + Value *TVal = Sel.getTrueValue(); + if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1)) + return nullptr; + // Finally, see if the select is filtering out a shift-by-zero. Value *Cond = Sel.getCondition(); ICmpInst::Predicate Pred; @@ -2364,13 +2374,12 @@ Pred != ICmpInst::ICMP_EQ) return nullptr; - // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way. // Convert to funnel shift intrinsic. - bool IsFshl = (ShAmt == SA0); Intrinsic::ID IID = IsFshl ? 
Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); ShAmt = Builder.CreateZExt(ShAmt, Sel.getType()); - return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); + return IntrinsicInst::Create(F, { SV0, SV1, ShAmt }); } static Instruction *foldSelectToCopysign(SelectInst &Sel, @@ -3014,8 +3023,8 @@ if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) return Select; - if (Instruction *Rot = foldSelectRotate(SI, Builder)) - return Rot; + if (Instruction *Res = foldSelectFunnelShift(SI, Builder)) + return Res; if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) return Copysign; diff --git a/llvm/test/Transforms/InstCombine/funnel.ll b/llvm/test/Transforms/InstCombine/funnel.ll --- a/llvm/test/Transforms/InstCombine/funnel.ll +++ b/llvm/test/Transforms/InstCombine/funnel.ll @@ -285,12 +285,7 @@ define i8 @fshr_select(i8 %x, i8 %y, i8 %shamt) { ; CHECK-LABEL: @fshr_select( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SHAMT:%.*]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub i8 8, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr i8 [[Y:%.*]], [[SHAMT]] -; CHECK-NEXT: [[SHL:%.*]] = shl i8 [[X:%.*]], [[SUB]] -; CHECK-NEXT: [[OR:%.*]] = or i8 [[SHL]], [[SHR]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[OR]] +; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.fshr.i8(i8 [[X:%.*]], i8 [[Y:%.*]], i8 [[SHAMT:%.*]]) ; CHECK-NEXT: ret i8 [[R]] ; %cmp = icmp eq i8 %shamt, 0 @@ -306,12 +301,7 @@ define i16 @fshl_select(i16 %x, i16 %y, i16 %shamt) { ; CHECK-LABEL: @fshl_select( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[SHAMT:%.*]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub i16 16, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr i16 [[Y:%.*]], [[SUB]] -; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[X:%.*]], [[SHAMT]] -; CHECK-NEXT: [[OR:%.*]] = or i16 [[SHR]], [[SHL]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i16 [[X]], i16 [[OR]] +; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.fshl.i16(i16 [[X:%.*]], i16 [[Y:%.*]], i16 
[[SHAMT:%.*]]) ; CHECK-NEXT: ret i16 [[R]] ; %cmp = icmp eq i16 %shamt, 0 @@ -327,12 +317,7 @@ define <2 x i64> @fshl_select_vector(<2 x i64> %x, <2 x i64> %y, <2 x i64> %shamt) { ; CHECK-LABEL: @fshl_select_vector( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[SHAMT:%.*]], zeroinitializer -; CHECK-NEXT: [[SUB:%.*]] = sub <2 x i64> <i64 64, i64 64>, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr <2 x i64> [[X:%.*]], [[SUB]] -; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i64> [[Y:%.*]], [[SHAMT]] -; CHECK-NEXT: [[OR:%.*]] = or <2 x i64> [[SHL]], [[SHR]] -; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[CMP]], <2 x i64> [[Y]], <2 x i64> [[OR]] +; CHECK-NEXT: [[R:%.*]] = call <2 x i64> @llvm.fshl.v2i64(<2 x i64> [[Y:%.*]], <2 x i64> [[X:%.*]], <2 x i64> [[SHAMT:%.*]]) ; CHECK-NEXT: ret <2 x i64> [[R]] ; %cmp = icmp eq <2 x i64> %shamt, zeroinitializer