diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1258,6 +1258,12 @@
   return CastClass_match<OpTy, Instruction::Trunc>(Op);
 }
 
+template <typename OpTy>
+inline match_combine_or<CastClass_match<OpTy, Instruction::Trunc>, OpTy>
+m_TruncOrSelf(const OpTy &Op) {
+  return m_CombineOr(m_Trunc(Op), Op);
+}
+
 /// Matches SExt.
 template <typename OpTy>
 inline CastClass_match<OpTy, Instruction::SExt> m_SExt(const OpTy &Op) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3299,6 +3299,7 @@
 // we should move shifts to the same hand of 'and', i.e. rewrite as
 //   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
 // We are only interested in opposite logical shifts here.
+// One of the shifts can be truncated. For now, it can only be 'shl'.
 // If we can, we want to end up creating 'lshr' shift.
 static Value *
 foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
@@ -3308,18 +3309,37 @@
     return nullptr;
 
   auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
-  auto m_AnyLShr = m_LShr(m_Value(), m_Value());
 
-  // Look for an 'and' of two (opposite) logical shifts.
-  // Pick the single-use shift as XShift.
-  Instruction *XShift, *YShift;
-  if (!match(I.getOperand(0),
-             m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
-                     m_CombineAnd(m_AnyLogicalShift, m_Instruction(YShift)))))
+  // Look for an 'and' of two logical shifts, one of which may be truncated.
+  // We use m_TruncOrSelf() on the RHS to correctly handle commutative case.
+  Instruction *XShift, *MaybeTruncation, *YShift;
+  if (!match(
+          I.getOperand(0),
+          m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
+                  m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
+                                   m_AnyLogicalShift, m_Instruction(YShift))),
+                               m_Instruction(MaybeTruncation)))))
     return nullptr;
+  Instruction *UntruncatedShift = XShift;
+
+  // We potentially looked past 'trunc', but only when matching YShift,
+  // therefore YShift must have the widest type.
+  Type *WidestTy = YShift->getType();
+  assert(XShift->getType() == I.getOperand(0)->getType() &&
+         "We did not look past any shifts while matching XShift though.");
+  bool HadTrunc = WidestTy != I.getOperand(0)->getType();
+
+  if (HadTrunc) {
+    // We did indeed have a truncation. For now, let's only proceed if the 'shl'
+    // was truncated, since that does not require any extra legality checks.
+    // FIXME: trunc-of-lshr.
+    if (!match(YShift, m_Shl(m_Value(), m_Value())))
+      return nullptr;
+  }
+
   // If YShift is a 'lshr', swap the shifts around.
-  if (match(YShift, m_AnyLShr))
+  if (match(YShift, m_LShr(m_Value(), m_Value())))
     std::swap(XShift, YShift);
@@ -3328,37 +3348,54 @@
     return nullptr; // Do not care about same-direction shifts here.
 
   Value *X, *XShAmt, *Y, *YShAmt;
-  match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt)));
-  match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt)));
+  match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
+  match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));
 
   // If one of the values being shifted is a constant, then we will end with
-  // and+icmp, and shift instr will be constant-folded. If they are not,
+  // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not,
   // however, we will need to ensure that we won't increase instruction count.
   if (!isa<Constant>(X) && !isa<Constant>(Y)) {
     // At least one of the hands of the 'and' should be one-use shift.
     if (!match(I.getOperand(0),
                m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
       return nullptr;
+    if (HadTrunc) {
+      // Due to the 'trunc', we will need to widen X. For that either the old
+      // 'trunc' or the shift amt in the non-truncated shift should be one-use.
+      if (!MaybeTruncation->hasOneUse() &&
+          !UntruncatedShift->getOperand(1)->hasOneUse())
+        return nullptr;
+    }
   }
 
+  // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case let's bailout now.
+  if (XShAmt->getType() != YShAmt->getType())
+    return nullptr;
+
   // Can we fold (XShAmt+YShAmt) ?
-  Value *NewShAmt = SimplifyAddInst(XShAmt, YShAmt, /*IsNSW=*/false,
-                                    /*IsNUW=*/false, SQ.getWithInstruction(&I));
+  auto *NewShAmt = dyn_cast_or_null<Constant>(
+      SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
+                      /*isNUW=*/false, SQ.getWithInstruction(&I)));
   if (!NewShAmt)
     return nullptr;
   // Is the new shift amount smaller than the bit width?
   // FIXME: could also rely on ConstantRange.
-  unsigned BitWidth = X->getType()->getScalarSizeInBits();
-  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
-                                          APInt(BitWidth, BitWidth))))
+  if (!match(NewShAmt, m_SpecificInt_ICMP(
+                           ICmpInst::Predicate::ICMP_ULT,
+                           APInt(NewShAmt->getType()->getScalarSizeInBits(),
+                                 WidestTy->getScalarSizeInBits()))))
     return nullptr;
-  // All good, we can do this fold. The shift is the same that was for X.
+  // All good, we can do this fold.
+  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
+  X = Builder.CreateZExt(X, WidestTy);
+  // The shift is the same that was for X.
   Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
                   ? Builder.CreateLShr(X, NewShAmt)
                   : Builder.CreateShl(X, NewShAmt);
   Value *T1 = Builder.CreateAnd(T0, Y);
   return Builder.CreateICmp(I.getPredicate(), T1,
-                            Constant::getNullValue(X->getType()));
+                            Constant::getNullValue(WidestTy));
 }
 
 /// Try to fold icmp (binop), X or icmp X, (binop).
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest-with-truncation-shl.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest-with-truncation-shl.ll
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest-with-truncation-shl.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest-with-truncation-shl.ll
@@ -6,6 +6,8 @@
 ; we should move shifts to the same hand of 'and', i.e. e.g. rewrite as
 ;   icmp eq/ne (and (((x shift Q) shift K), y)), 0
 ; We are only interested in opposite logical shifts here.
+; We still can handle the case where there is a truncation between a shift
+; and an 'and', but for now only if it's 'shl' - simpler legality check.
 ;-------------------------------------------------------------------------------
 ; Basic scalar tests
@@ -13,15 +15,11 @@
 
 define i1 @t0_const_after_fold_lshr_shl_ne(i32 %x, i64 %y, i32 %len) {
 ; CHECK-LABEL: @t0_const_after_fold_lshr_shl_ne(
-; CHECK-NEXT:    [[T0:%.*]] = sub i32 32, [[LEN:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], [[T0]]
-; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN]], -1
-; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext i32 [[T2]] to i64
-; CHECK-NEXT:    [[T3:%.*]] = shl i64 [[Y:%.*]], [[T2_WIDE]]
-; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne i32 [[T4]], 0
-; CHECK-NEXT:    ret i1 [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i64 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = sub i32 32, %len
   %t1 = lshr i32 %x, %t0
@@ -40,15 +38,11 @@
 
 define <2 x i1> @t1_vec_splat(<2 x i32> %x, <2 x i64> %y, <2 x i32> %len) {
 ; CHECK-LABEL: @t1_vec_splat(
-; CHECK-NEXT:    [[T0:%.*]] = sub <2 x i32> <i32 32, i32 32>, [[LEN:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = lshr <2 x i32> [[X:%.*]], [[T0]]
-; CHECK-NEXT:    [[T2:%.*]] = add <2 x i32> [[LEN]], <i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext <2 x i32> [[T2]] to <2 x i64>
-; CHECK-NEXT:    [[T3:%.*]] = shl <2 x i64> [[Y:%.*]], [[T2_WIDE]]
-; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc <2 x i64> [[T3]] to <2 x i32>
-; CHECK-NEXT:    [[T4:%.*]] = and <2 x i32> [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne <2 x i32> [[T4]], zeroinitializer
-; CHECK-NEXT:    ret <2 x i1> [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 31, i32 31>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = and <2 x i64> [[TMP2]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <2 x i64> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[TMP4]]
 ;
   %t0 = sub <2 x i32> <i32 32, i32 32>, %len
   %t1 = lshr <2 x i32> %x, %t0
@@ -63,15 +57,11 @@
 
 define <2 x i1> @t2_vec_nonsplat(<2 x i32> %x, <2 x i64> %y, <2 x i32> %len) {
 ; CHECK-LABEL: @t2_vec_nonsplat(
-; CHECK-NEXT:    [[T0:%.*]] = sub <2 x i32> , [[LEN:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = lshr <2 x i32> [[X:%.*]], [[T0]]
-; CHECK-NEXT:    [[T2:%.*]] = add <2 x i32> [[LEN]],
-; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext <2 x i32> [[T2]] to <2 x i64>
-; CHECK-NEXT:    [[T3:%.*]] = shl <2 x i64> [[Y:%.*]], [[T2_WIDE]]
-; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc <2 x i64> [[T3]] to <2 x i32>
-; CHECK-NEXT:    [[T4:%.*]] = and <2 x i32> [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne <2 x i32> [[T4]], zeroinitializer
-; CHECK-NEXT:    ret <2 x i1> [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i32> [[X:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <2 x i64> [[TMP1]],
+; CHECK-NEXT:    [[TMP3:%.*]] = and <2 x i64> [[TMP2]], [[Y:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <2 x i64> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[TMP4]]
 ;
   %t0 = sub <2 x i32> , %len
   %t1 = lshr <2 x i32> %x, %t0
@@ -214,17 +204,17 @@
 ; CHECK-LABEL: @t6_oneuse3(
 ; CHECK-NEXT:    [[T0:%.*]] = sub i32 32, [[LEN:%.*]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], [[T0]]
 ; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN]], -1
 ; CHECK-NEXT:    call void @use32(i32 [[T2]])
 ; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext i32 [[T2]] to i64
 ; CHECK-NEXT:    call void @use64(i64 [[T2_WIDE]])
 ; CHECK-NEXT:    [[T3:%.*]] = shl i64 [[Y:%.*]], [[T2_WIDE]]
 ; CHECK-NEXT:    call void @use64(i64 [[T3]])
-; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne i32 [[T4]], 0
-; CHECK-NEXT:    ret i1 [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], [[Y]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i64 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = sub i32 32, %len
   call void @use32(i32 %t0)
@@ -244,9 +234,7 @@
 ; Ok, shift amount of non-truncated shift has no extra uses;
 define i1 @t7_oneuse4(i32 %x, i64 %y, i32 %len) {
 ; CHECK-LABEL: @t7_oneuse4(
-; CHECK-NEXT:    [[T0:%.*]] = sub i32 32, [[LEN:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], [[T0]]
-; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN]], -1
+; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN:%.*]], -1
 ; CHECK-NEXT:    call void @use32(i32 [[T2]])
 ; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext i32 [[T2]] to i64
 ; CHECK-NEXT:    call void @use64(i64 [[T2_WIDE]])
@@ -254,9 +242,11 @@
 ; CHECK-NEXT:    call void @use64(i64 [[T3]])
 ; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
 ; CHECK-NEXT:    call void @use32(i32 [[T3_TRUNC]])
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne i32 [[T4]], 0
-; CHECK-NEXT:    ret i1 [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], [[Y]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i64 [[TMP3]], 0
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %t0 = sub i32 32, %len ; no extra uses
   %t1 = lshr i32 %x, %t0 ; no extra uses
@@ -288,9 +278,9 @@
 ; CHECK-NEXT:    call void @use64(i64 [[T3]])
 ; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
 ; CHECK-NEXT:    call void @use32(i32 [[T3_TRUNC]])
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne i32 [[T4]], 0
-; CHECK-NEXT:    ret i1 [[T5]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[Y]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %t0 = sub i32 32, %len
   call void @use32(i32 %t0)
@@ -324,9 +314,7 @@
 ; CHECK-NEXT:    call void @use64(i64 [[T3]])
 ; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
 ; CHECK-NEXT:    call void @use32(i32 [[T3_TRUNC]])
-; CHECK-NEXT:    [[T4:%.*]] = and i32 [[T1]], [[T3_TRUNC]]
-; CHECK-NEXT:    [[T5:%.*]] = icmp ne i32 [[T4]], 0
-; CHECK-NEXT:    ret i1 [[T5]]
+; CHECK-NEXT:    ret i1 false
 ;
   %t0 = sub i32 32, %len
   call void @use32(i32 %t0)
@@ -413,7 +401,7 @@
 ; CHECK-LABEL: @n13_overshift(
 ; CHECK-NEXT:    [[T0:%.*]] = sub i32 32, [[LEN:%.*]]
 ; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], [[T0]]
-; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN]], 1
+; CHECK-NEXT:    [[T2:%.*]] = add i32 [[LEN]], 32
 ; CHECK-NEXT:    [[T2_WIDE:%.*]] = zext i32 [[T2]] to i64
 ; CHECK-NEXT:    [[T3:%.*]] = shl i64 [[Y:%.*]], [[T2_WIDE]]
 ; CHECK-NEXT:    [[T3_TRUNC:%.*]] = trunc i64 [[T3]] to i32
@@ -423,7 +411,7 @@
 ;
   %t0 = sub i32 32, %len
   %t1 = lshr i32 %x, %t0
-  %t2 = add i32 %len, 1 ; too much
+  %t2 = add i32 %len, 32 ; too much
   %t2_wide = zext i32 %t2 to i64
   %t3 = shl i64 %y, %t2_wide
   %t3_trunc = trunc i64 %t3 to i32
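
Illustrative note (not part of the patch): a minimal sketch of how the new m_TruncOrSelf() matcher composes with other matchers. Since it is defined as m_CombineOr(m_Trunc(Op), Op), it matches either a truncated or a bare instance of the inner pattern, which is how the fold above looks through an optional 'trunc' on one hand of the 'and'. The helper name isMaybeTruncatedShl is hypothetical, introduced here only for illustration.

  // Sketch only; assumes the m_TruncOrSelf() matcher added in this patch.
  #include "llvm/IR/PatternMatch.h"
  #include "llvm/IR/Value.h"

  using namespace llvm;
  using namespace llvm::PatternMatch;

  // Returns true if V is `shl A, B`, possibly wrapped in a `trunc`,
  // and binds the shifted value and the shift amount.
  static bool isMaybeTruncatedShl(Value *V, Value *&A, Value *&B) {
    // m_TruncOrSelf(P) first tries to match `trunc (P)` and, failing that,
    // matches P directly, so A and B are bound in both forms.
    return match(V, m_TruncOrSelf(m_Shl(m_Value(A), m_Value(B))));
  }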