Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -443,13 +443,81 @@
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
+/// Rotate left/right may occur in a wider type than necessary because of type
+/// promotion rules. Try to narrow all of the component instructions.
+Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+  assert((isa<VectorType>(Trunc.getSrcTy()) ||
+          shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
+         "Don't narrow to an illegal scalar type");
+
+  // First, find an or'd pair of opposite shifts with the same shifted operand:
+  // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
+  Value *Or0, *Or1;
+  if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *ShVal, *ShAmt0, *ShAmt1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // The shift amounts must add up to the narrow bit width.
+  Value *ShAmt;
+  bool SubIsOnLHS;
+  Type *DestTy = Trunc.getType();
+  unsigned NarrowWidth = DestTy->getScalarSizeInBits();
+  if (match(ShAmt0,
+            m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) {
+    ShAmt = ShAmt1;
+    SubIsOnLHS = true;
+  } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth),
+                                          m_Specific(ShAmt0))))) {
+    ShAmt = ShAmt0;
+    SubIsOnLHS = false;
+  } else {
+    return nullptr;
+  }
+
+  // The shifted value must have high zeros in the wide type. Typically, this
+  // will be a zext, but it could also be the result of an 'and' or 'shift'.
+  unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
+  APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
+  if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+    return nullptr;
+
+  // We have an unnecessarily wide rotate!
+  // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+  // Narrow it down to eliminate the zext/trunc:
+  // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1')
+  Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+  Value *NegShAmt = Builder.CreateNeg(NarrowShAmt);
+
+  // Mask both shift amounts to ensure there's no UB from oversized shifts.
+  Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1);
+  Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC);
+  Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC);
+
+  // Truncate the original value and use narrow ops.
+  Value *X = Builder.CreateTrunc(ShVal, DestTy);
+  Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt;
+  Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt;
+  Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0);
+  Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1);
+  return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1);
+}
+
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
 /// truncate ahead of binary operators.
 /// TODO: Transforms for truncated shifts should be moved into here.
 Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
   Type *SrcTy = Trunc.getSrcTy();
   Type *DestTy = Trunc.getType();
-  if (isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
+  if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
     return nullptr;
 
   BinaryOperator *BinOp;
@@ -485,6 +553,9 @@
   default: break;
   }
 
+  if (Instruction *NarrowOr = narrowRotate(Trunc))
+    return NarrowOr;
+
   return nullptr;
 }
 
Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -440,6 +440,7 @@
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
   Instruction *narrowBinOp(TruncInst &Trunc);
+  Instruction *narrowRotate(TruncInst &Trunc);
   Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
 
   /// Determine if a pair of casts can be replaced by a single cast.
Index: llvm/trunk/test/Transforms/InstCombine/rotate.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/rotate.ll
+++ llvm/trunk/test/Transforms/InstCombine/rotate.ll
@@ -0,0 +1,123 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
+
+; These are UB-free rotate left/right patterns that are narrowed to a smaller bitwidth.
+; See PR34046 and PR16726 for motivating examples:
+; https://bugs.llvm.org/show_bug.cgi?id=34046
+; https://bugs.llvm.org/show_bug.cgi?id=16726
+
+define i16 @rotate_left_16bit(i16 %v, i32 %shift) {
+; CHECK-LABEL: @rotate_left_16bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shift to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = and i16 [[TMP1]], 15
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i16 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and i16 [[TMP3]], 15
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i16 %v, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i16 %v, [[TMP2]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i16 [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret i16 [[CONV2]]
+;
+  %and = and i32 %shift, 15
+  %conv = zext i16 %v to i32
+  %shl = shl i32 %conv, %and
+  %sub = sub i32 16, %and
+  %shr = lshr i32 %conv, %sub
+  %or = or i32 %shr, %shl
+  %conv2 = trunc i32 %or to i16
+  ret i16 %conv2
+}
+
+; Commute the 'or' operands and try a vector type.
+
+define <2 x i16> @rotate_left_commute_16bit_vec(<2 x i16> %v, <2 x i32> %shift) {
+; CHECK-LABEL: @rotate_left_commute_16bit_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> %shift to <2 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i16> [[TMP1]], <i16 15, i16 15>
+; CHECK-NEXT:    [[TMP3:%.*]] = sub <2 x i16> zeroinitializer, [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and <2 x i16> [[TMP3]], <i16 15, i16 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = shl <2 x i16> %v, [[TMP2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr <2 x i16> %v, [[TMP4]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or <2 x i16> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret <2 x i16> [[CONV2]]
+;
+  %and = and <2 x i32> %shift, <i32 15, i32 15>
+  %conv = zext <2 x i16> %v to <2 x i32>
+  %shl = shl <2 x i32> %conv, %and
+  %sub = sub <2 x i32> <i32 16, i32 16>, %and
+  %shr = lshr <2 x i32> %conv, %sub
+  %or = or <2 x i32> %shl, %shr
+  %conv2 = trunc <2 x i32> %or to <2 x i16>
+  ret <2 x i16> %conv2
+}
+
+; Change the size, rotation direction (the subtract is on the left-shift), and mask op.
+
+define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
+; CHECK-LABEL: @rotate_right_8bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 %shift to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i3 0, %shift
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i3 [[TMP2]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i8 %v, [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 %v, [[TMP1]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i8 [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    ret i8 [[CONV2]]
+;
+  %and = zext i3 %shift to i32
+  %conv = zext i8 %v to i32
+  %shr = lshr i32 %conv, %and
+  %sub = sub i32 8, %and
+  %shl = shl i32 %conv, %sub
+  %or = or i32 %shl, %shr
+  %conv2 = trunc i32 %or to i8
+  ret i8 %conv2
+}
+
+; The shifted value does not need to be a zexted value; here it is masked.
+; The shift mask could be less than the bitwidth, but this is still ok.
+
+define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
+; CHECK-LABEL: @rotate_right_commute_8bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shift to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = sub nsw i8 0, [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and i8 [[TMP3]], 7
+; CHECK-NEXT:    [[TMP5:%.*]] = trunc i32 %v to i8
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i8 [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = shl i8 [[TMP5]], [[TMP4]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i8 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    ret i8 [[CONV2]]
+;
+  %and = and i32 %shift, 3
+  %conv = and i32 %v, 255
+  %shr = lshr i32 %conv, %and
+  %sub = sub i32 8, %and
+  %shl = shl i32 %conv, %sub
+  %or = or i32 %shr, %shl
+  %conv2 = trunc i32 %or to i8
+  ret i8 %conv2
+}
+
+; If the original source does not mask the shift amount,
+; we still do the transform by adding masks to make it safe.
+
+define i8 @rotate8_not_safe(i8 %v, i32 %shamt) {
+; CHECK-LABEL: @rotate8_not_safe(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shamt to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i8 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP1]], 7
+; CHECK-NEXT:    [[TMP4:%.*]] = and i8 [[TMP2]], 7
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 %v, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i8 %v, [[TMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = or i8 [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %conv = zext i8 %v to i32
+  %sub = sub i32 8, %shamt
+  %shr = lshr i32 %conv, %sub
+  %shl = shl i32 %conv, %shamt
+  %or = or i32 %shr, %shl
+  %ret = trunc i32 %or to i8
+  ret i8 %ret
+}
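
Note (illustrative, not part of the patch): the IR shape matched by narrowRotate typically comes from ordinary C/C++ rotate code, where integer promotion widens a 16-bit rotate to 32-bit operations. The sketch below is a hypothetical source-level example that compiles to the same zext/shl/lshr/or/trunc chain as @rotate_left_16bit above; the function name and masking style are assumptions for illustration only.

    #include <cstdint>

    // 'v' is promoted to a 32-bit int for the shifts, so the emitted IR is a
    // zext to i32, the two shifts and the or in i32, and a trunc back to i16;
    // narrowRotate shrinks that whole chain to i16 operations.
    uint16_t rotate_left_16(uint16_t v, uint32_t shift) {
      shift &= 15;                          // keep the shift amount in range
      return (v << shift) | (v >> (16 - shift));
    }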