diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2071,22 +2071,22 @@
   return LastInst;
 }
 
-/// Transform UB-safe variants of bitwise rotate to the funnel shift intrinsic.
-static Instruction *matchRotate(Instruction &Or) {
+/// Match UB-safe variants of the funnel shift intrinsic.
+static Instruction *matchFunnelShift(Instruction &Or) {
   // TODO: Can we reduce the code duplication between this and the related
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();
 
   // First, find an or'd pair of opposite shifts with the same shifted operand:
-  // or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)
+  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
   BinaryOperator *Or0, *Or1;
   if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
       !match(Or.getOperand(1), m_BinOp(Or1)))
     return nullptr;
 
-  Value *ShVal, *ShAmt0, *ShAmt1;
-  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
-      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))))
     return nullptr;
 
   BinaryOperator::BinaryOps ShiftOpcode0 = Or0->getOpcode();
@@ -2094,9 +2094,9 @@
   if (ShiftOpcode0 == ShiftOpcode1)
     return nullptr;
 
-  // Match the shift amount operands for a rotate pattern. This always matches
-  // a subtraction on the R operand.
-  auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+  // Match the shift amount operands for a funnel shift pattern. This always
+  // matches a subtraction on the R operand.
+  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
     // Check for constant shift amounts that sum to the bitwidth.
     // TODO: Support non-uniform shift amounts.
     const APInt *LC, *RC;
@@ -2104,6 +2104,11 @@
       if (LC->ult(Width) && RC->ult(Width) && (*LC + *RC) == Width)
         return ConstantInt::get(L->getType(), *LC);
 
+    // For non-constant cases, the following patterns currently only support
+    // rotation patterns.
+    if (ShVal0 != ShVal1)
+      return nullptr;
+
     // For non-constant cases we don't support non-pow2 shift masks.
     // TODO: Is it worth matching urem as well?
     if (!isPowerOf2_32(Width))
@@ -2140,7 +2145,8 @@
       (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl);
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt });
+  return IntrinsicInst::Create(
+      F, {IsFshl ? ShVal0 : ShVal1, IsFshl ? ShVal1 : ShVal0, ShAmt});
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
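
With constant shift amounts, the matcher above no longer requires the two shifted values to be the same, so mixed-operand patterns now fold to a funnel shift. A minimal sketch of an input the updated code accepts (the function name is illustrative, not one of the tests below):

    define i32 @fshl_sketch(i32 %x, i32 %y) {
      ; or(shl(x, 11), lshr(y, 21)) with 11 + 21 == 32 (the bitwidth)
      ; folds to: call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 11)
      %shl = shl i32 %x, 11
      %shr = lshr i32 %y, 21
      %r = or i32 %shl, %shr
      ret i32 %r
    }

For variable shift amounts, the new ShVal0 != ShVal1 guard keeps the match restricted to rotations, where both shifts operate on the same value.
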
@@ -2593,8 +2599,8 @@
   if (Instruction *BSwap = matchBSwap(I))
     return BSwap;
 
-  if (Instruction *Rotate = matchRotate(I))
-    return Rotate;
+  if (Instruction *Funnel = matchFunnelShift(I))
+    return Funnel;
 
   if (Instruction *Concat = matchOrConcat(I, Builder))
     return replaceInstUsesWith(I, Concat);
diff --git a/llvm/test/Transforms/InstCombine/funnel.ll b/llvm/test/Transforms/InstCombine/funnel.ll
--- a/llvm/test/Transforms/InstCombine/funnel.ll
+++ b/llvm/test/Transforms/InstCombine/funnel.ll
@@ -3,16 +3,14 @@
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
 
-; TODO: Canonicalize or(shl,lshr) by constant to funnel shift intrinsics.
+; Canonicalize or(shl,lshr) by constant to funnel shift intrinsics.
 ; This should help cost modeling for vectorization, inlining, etc.
 ; If a target does not have a fshl instruction, the expansion will
 ; be exactly these same 3 basic ops (shl/lshr/or).
 
 define i32 @fshl_i32_constant(i32 %x, i32 %y) {
 ; CHECK-LABEL: @fshl_i32_constant(
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X:%.*]], 11
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[Y:%.*]], 21
-; CHECK-NEXT:    [[R:%.*]] = or i32 [[SHR]], [[SHL]]
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 11)
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %shl = shl i32 %x, 11
@@ -23,9 +21,7 @@
 
 define i42 @fshr_i42_constant(i42 %x, i42 %y) {
 ; CHECK-LABEL: @fshr_i42_constant(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i42 [[X:%.*]], 31
-; CHECK-NEXT:    [[SHL:%.*]] = shl i42 [[Y:%.*]], 11
-; CHECK-NEXT:    [[R:%.*]] = or i42 [[SHR]], [[SHL]]
+; CHECK-NEXT:    [[R:%.*]] = call i42 @llvm.fshl.i42(i42 [[Y:%.*]], i42 [[X:%.*]], i42 11)
 ; CHECK-NEXT:    ret i42 [[R]]
 ;
   %shr = lshr i42 %x, 31
@@ -34,13 +30,11 @@
   ret i42 %r
 }
 
-; TODO: Vector types are allowed.
+; Vector types are allowed.
 
 define <2 x i16> @fshl_v2i16_constant_splat(<2 x i16> %x, <2 x i16> %y) {
 ; CHECK-LABEL: @fshl_v2i16_constant_splat(
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i16> [[X:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i16> [[Y:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[SHL]], [[SHR]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i16> @llvm.fshl.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[Y:%.*]], <2 x i16> <i16 8, i16 8>)
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
   %shl = shl <2 x i16> %x, <i16 8, i16 8>
@@ -51,9 +45,7 @@
 
 define <2 x i16> @fshl_v2i16_constant_splat_undef0(<2 x i16> %x, <2 x i16> %y) {
 ; CHECK-LABEL: @fshl_v2i16_constant_splat_undef0(
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i16> [[X:%.*]], <i16 undef, i16 8>
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i16> [[Y:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[SHL]], [[SHR]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i16> @llvm.fshl.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[Y:%.*]], <2 x i16> <i16 8, i16 8>)
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
   %shl = shl <2 x i16> %x, <i16 undef, i16 8>
@@ -64,9 +56,7 @@
 
 define <2 x i16> @fshl_v2i16_constant_splat_undef1(<2 x i16> %x, <2 x i16> %y) {
 ; CHECK-LABEL: @fshl_v2i16_constant_splat_undef1(
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i16> [[X:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i16> [[Y:%.*]], <i16 undef, i16 8>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[SHL]], [[SHR]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i16> @llvm.fshl.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[Y:%.*]], <2 x i16> <i16 8, i16 8>)
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
   %shl = shl <2 x i16> %x, <i16 8, i16 8>
@@ -75,13 +65,11 @@
   ret <2 x i16> %r
 }
 
-; TODO: Non-power-of-2 vector types are allowed.
+; Non-power-of-2 vector types are allowed.
 
 define <2 x i17> @fshr_v2i17_constant_splat(<2 x i17> %x, <2 x i17> %y) {
 ; CHECK-LABEL: @fshr_v2i17_constant_splat(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i17> [[X:%.*]], <i17 12, i17 12>
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i17> [[Y:%.*]], <i17 5, i17 5>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i17> [[SHR]], [[SHL]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i17> @llvm.fshl.v2i17(<2 x i17> [[Y:%.*]], <2 x i17> [[X:%.*]], <2 x i17> <i17 5, i17 5>)
 ; CHECK-NEXT:    ret <2 x i17> [[R]]
 ;
   %shr = lshr <2 x i17> %x, <i17 12, i17 12>
@@ -92,9 +80,7 @@
 
 define <2 x i17> @fshr_v2i17_constant_splat_undef0(<2 x i17> %x, <2 x i17> %y) {
 ; CHECK-LABEL: @fshr_v2i17_constant_splat_undef0(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i17> [[X:%.*]], <i17 12, i17 undef>
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i17> [[Y:%.*]], <i17 undef, i17 5>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i17> [[SHR]], [[SHL]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i17> @llvm.fshl.v2i17(<2 x i17> [[Y:%.*]], <2 x i17> [[X:%.*]], <2 x i17> <i17 5, i17 5>)
 ; CHECK-NEXT:    ret <2 x i17> [[R]]
 ;
   %shr = lshr <2 x i17> %x, <i17 12, i17 undef>
@@ -105,9 +91,7 @@
 
 define <2 x i17> @fshr_v2i17_constant_splat_undef1(<2 x i17> %x, <2 x i17> %y) {
 ; CHECK-LABEL: @fshr_v2i17_constant_splat_undef1(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i17> [[X:%.*]], <i17 undef, i17 12>
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i17> [[Y:%.*]], <i17 5, i17 undef>
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i17> [[SHR]], [[SHL]]
+; CHECK-NEXT:    [[R:%.*]] = call <2 x i17> @llvm.fshl.v2i17(<2 x i17> [[Y:%.*]], <2 x i17> [[X:%.*]], <2 x i17> <i17 5, i17 5>)
 ; CHECK-NEXT:    ret <2 x i17> [[R]]
 ;
   %shr = lshr <2 x i17> %x, <i17 undef, i17 12>
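
Note that the fshr-style tests expect fshl calls: for a shift amount C that is non-zero modulo the bitwidth BW, fshr(a, b, C) is equivalent to fshl(a, b, BW - C), so the constant-amount form is canonicalized to fshl with the complementary amount. A sketch of the i42 case (illustrative function name):

    define i42 @fshr_canonicalized(i42 %x, i42 %y) {
      ; (x >> 31) | (y << 11) is fshr(y, x, 31); since 42 - 31 == 11, this
      ; folds to call i42 @llvm.fshl.i42(i42 %y, i42 %x, i42 11), matching
      ; the fshr_i42_constant CHECK line above.
      %shr = lshr i42 %x, 31
      %shl = shl i42 %y, 11
      %r = or i42 %shr, %shl
      ret i42 %r
    }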