diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -553,12 +553,17 @@
 
   // Match the shift amount operands for a funnel/rotate pattern. This always
   // matches a subtraction on the R operand.
-  auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
     // The shift amounts may add up to the narrow bit width:
     // (shl ShVal0, L) | (lshr ShVal1, Width - L)
     if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
       return L;
 
+    // The following patterns currently only work for rotation patterns.
+    // TODO: Add more general funnel-shift compatible patterns.
+    if (ShVal0 != ShVal1)
+      return nullptr;
+
     // The shift amount may be masked with negation:
     // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
     Value *X;
@@ -575,11 +580,6 @@
     return nullptr;
   };
 
-  // TODO: Add support for funnel shifts (ShVal0 != ShVal1).
-  if (ShVal0 != ShVal1)
-    return nullptr;
-  Value *ShVal = ShVal0;
-
   Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
   bool IsFshl = true; // Sub on LSHR.
   if (!ShAmt) {
@@ -593,18 +593,22 @@
   // will be a zext, but it could also be the result of an 'and' or 'shift'.
   unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
   APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
-  if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+  if (!MaskedValueIsZero(ShVal0, HiBitMask, 0, &Trunc) ||
+      !MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
     return nullptr;
 
   // We have an unnecessarily wide rotate!
-  // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+  // trunc (or (lshr ShVal0, ShAmt), (shl ShVal1, BitWidth - ShAmt))
   // Narrow the inputs and convert to funnel shift intrinsic:
   // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
   Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
-  Value *X = Builder.CreateTrunc(ShVal, DestTy);
+  Value *X, *Y;
+  X = Y = Builder.CreateTrunc(ShVal0, DestTy);
+  if (ShVal0 != ShVal1)
+    Y = Builder.CreateTrunc(ShVal1, DestTy);
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
-  return IntrinsicInst::Create(F, { X, X, NarrowShAmt });
+  return IntrinsicInst::Create(F, {X, Y, NarrowShAmt});
 }
 
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
diff --git a/llvm/test/Transforms/InstCombine/funnel.ll b/llvm/test/Transforms/InstCombine/funnel.ll
--- a/llvm/test/Transforms/InstCombine/funnel.ll
+++ b/llvm/test/Transforms/InstCombine/funnel.ll
@@ -205,14 +205,8 @@
 
 define i16 @fshl_16bit(i16 %x, i16 %y, i32 %shift) {
 ; CHECK-LABEL: @fshl_16bit(
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[SHIFT:%.*]], 15
-; CHECK-NEXT:    [[CONVX:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 16, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext i16 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i16
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i16
+; CHECK-NEXT:    [[CONV2:%.*]] = call i16 @llvm.fshl.i16(i16 [[X:%.*]], i16 [[Y:%.*]], i16 [[TMP1]])
 ; CHECK-NEXT:    ret i16 [[CONV2]]
 ;
   %and = and i32 %shift, 15
@@ -230,14 +224,8 @@
 
 define <2 x i16> @fshl_commute_16bit_vec(<2 x i16> %x, <2 x i16> %y, <2 x i32> %shift) {
 ; CHECK-LABEL: @fshl_commute_16bit_vec(
-; CHECK-NEXT:    [[AND:%.*]] = and <2 x i32> [[SHIFT:%.*]], <i32 15, i32 15>
-; CHECK-NEXT:    [[CONVX:%.*]] = zext <2 x i16> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw <2 x i32> <i32 16, i32 16>, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext <2 x i16> [[Y:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i32> [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i32> [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc <2 x i32> [[OR]] to <2 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[SHIFT:%.*]] to <2 x i16>
+; CHECK-NEXT:    [[CONV2:%.*]] = call <2 x i16> @llvm.fshl.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[Y:%.*]], <2 x i16> [[TMP1]])
 ; CHECK-NEXT:    ret <2 x i16> [[CONV2]]
 ;
   %and = and <2 x i32> %shift, <i32 15, i32 15>
@@ -255,14 +243,8 @@
 
 define i8 @fshr_8bit(i8 %x, i8 %y, i3 %shift) {
 ; CHECK-LABEL: @fshr_8bit(
-; CHECK-NEXT:    [[AND:%.*]] = zext i3 [[SHIFT:%.*]] to i32
-; CHECK-NEXT:    [[CONVX:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 8, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[Y:%.*]], i8 [[X:%.*]], i8 [[TMP1]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = zext i3 %shift to i32
@@ -281,14 +263,11 @@
 
 define i8 @fshr_commute_8bit(i32 %x, i32 %y, i32 %shift) {
 ; CHECK-LABEL: @fshr_commute_8bit(
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[SHIFT:%.*]], 3
-; CHECK-NEXT:    [[CONVX:%.*]] = and i32 [[X:%.*]], 255
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 8, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = and i32 [[Y:%.*]], 255
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
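
Note (not part of the patch): the early bail-out added inside matchShiftAmount means the masked/negated shift-amount forms are still narrowed only when both shifted values are the same value, i.e. a rotate. The sketch below is a hypothetical LLVM IR example (the function name and values are illustrative and do not come from funnel.ll) of a funnel-shift variant of that form which remains un-narrowed after this change:

; Funnel-shift variant of the masked-negation pattern described in the
; matchShiftAmount comment: ShVal0 (%x) != ShVal1 (%y), so the new early
; bail-out fires and no llvm.fshl intrinsic is formed for this input.
define i16 @masked_neg_funnel_not_narrowed(i16 %x, i16 %y, i32 %shift) {
  %lamt = and i32 %shift, 15     ; L = shift & (Width - 1)
  %neg = sub i32 0, %shift
  %ramt = and i32 %neg, 15       ; R = (-shift) & (Width - 1)
  %convx = zext i16 %x to i32
  %convy = zext i16 %y to i32
  %shl = shl i32 %convx, %lamt
  %shr = lshr i32 %convy, %ramt
  %or = or i32 %shl, %shr
  %conv = trunc i32 %or to i16
  ret i16 %conv
}

When the two shifted values differ, this form does not reduce directly to llvm.fshl: if %shift is a multiple of 16, both masked amounts are zero and the or produces %x | %y, whereas fshl(%x, %y, 0) is %x. That appears to be why the new comment restricts these patterns to rotations for now.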