Index: llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1412,6 +1412,55 @@
   return nullptr;
 }
 
+/// Fold icmp involving trivial vscale calculations, e.g.:
+///   icmp Pred vscale, C
+///   icmp Pred (mul vscale, X), C
+///   icmp Pred (shl vscale, X), C
+/// If the maximum value of vscale is known, we can use it to determine the
+/// result of certain comparisons at compile time.
+Instruction *InstCombinerImpl::foldICmpWithConstantVScale(ICmpInst &Cmp) {
+  CmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1);
+
+  ConstantInt *CI = dyn_cast<ConstantInt>(Op1);
+  if (!CI)
+    return nullptr;
+
+  ConstantInt *VScaleShift = nullptr;
+  ConstantInt *VScaleMul = nullptr;
+  if (Pred == ICmpInst::ICMP_UGT &&
+      (match(Op0, m_VScale(DL)) ||
+       match(Op0, m_Shl(m_VScale(DL), m_ConstantInt(VScaleShift))) ||
+       match(Op0, m_Mul(m_VScale(DL), m_ConstantInt(VScaleMul))))) {
+    // If we don't know the max value of vscale we can't fold anything.
+    if (!Cmp.getFunction()->hasFnAttribute(Attribute::VScaleRange))
+      return nullptr;
+
+    uint64_t MaxVScale = Cmp.getFunction()
+                             ->getFnAttribute(Attribute::VScaleRange)
+                             .getVScaleRangeArgs()
+                             .second;
+    if (MaxVScale == 0)
+      return nullptr;
+
+    // We know the maximum value of vscale * multiplier (or vscale << shift),
+    // so we may be able to eliminate the icmp entirely.
+    uint64_t MaxLHSVal = MaxVScale;
+    if (VScaleShift) {
+      uint64_t ShiftVal = VScaleShift->getZExtValue();
+      MaxLHSVal = MaxVScale << ShiftVal;
+    } else {
+      uint64_t MulVal = VScaleMul ? VScaleMul->getZExtValue() : 1;
+      MaxLHSVal = MaxVScale * MulVal;
+    }
+
+    if (MaxLHSVal <= CI->getZExtValue())
+      return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), 0));
+  }
+
+  return nullptr;
+}
+
 /// Fold icmp Pred X, C.
 /// TODO: This code structure does not make sense. The saturating add fold
 /// should be moved to some other helper and extended as noted below (it is also
@@ -5841,6 +5890,9 @@
   if (Instruction *Res = foldICmpWithConstant(I))
     return Res;
 
+  if (Instruction *Res = foldICmpWithConstantVScale(I))
+    return Res;
+
   if (Instruction *Res = foldICmpWithDominatingICmp(I))
     return Res;
 
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -649,6 +649,7 @@
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
   Instruction *foldICmpWithConstant(ICmpInst &Cmp);
+  Instruction *foldICmpWithConstantVScale(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
Index: llvm/test/Transforms/InstCombine/icmp-vscale.ll
===================================================================
--- llvm/test/Transforms/InstCombine/icmp-vscale.ll
+++ llvm/test/Transforms/InstCombine/icmp-vscale.ll
@@ -37,6 +37,18 @@
   ret i1 %res
 }
 
+define i1 @uge_vscale8_x_5() vscale_range(0,16) {
+; CHECK-LABEL: @uge_vscale8_x_5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %vscale = call i8 @llvm.vscale.i8()
+  %num_els = mul i8 %vscale, 5
+  %res = icmp uge i8 %num_els, 128
+  ret i1 %res
+}
+
 define i1 @ult_vscale16() vscale_range(0,16) {
 ; CHECK-LABEL: @ult_vscale16(
 ; CHECK-NEXT:  entry:
@@ -71,6 +83,52 @@
   ret i1 %res
 }
 
+; Negative tests
+
+define i1 @fail_uge_vscale64_x_32_max32() vscale_range(0,32) {
+; CHECK-LABEL: @fail_uge_vscale64_x_32_max32(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[RES:%.*]] = icmp ugt i64 [[VSCALE]], 31
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %vscale = call i64 @llvm.vscale.i64()
+  %num_els = shl i64 %vscale, 5
+  %res = icmp uge i64 %num_els, 1024
+  ret i1 %res
+}
+
+define i1 @fail_uge_vscale64_x_32_max0() vscale_range(0,0) {
+; CHECK-LABEL: @fail_uge_vscale64_x_32_max0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[NUM_ELS:%.*]] = mul i64 [[VSCALE]], 31
+; CHECK-NEXT:    [[RES:%.*]] = icmp ugt i64 [[NUM_ELS]], 12345677
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %vscale = call i64 @llvm.vscale.i64()
+  %num_els = mul i64 %vscale, 31
+  %res = icmp uge i64 %num_els, 12345678
+  ret i1 %res
+}
+
+define i1 @fail_uge_vscale64_x_val(i64 %val) vscale_range(0,32) {
+; CHECK-LABEL: @fail_uge_vscale64_x_val(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[NUM_ELS:%.*]] = mul i64 [[VSCALE]], [[VAL:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp ugt i64 [[NUM_ELS]], 1023
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+entry:
+  %vscale = call i64 @llvm.vscale.i64()
+  %num_els = mul i64 %vscale, %val
+  %res = icmp uge i64 %num_els, 1024
+  ret i1 %res
+}
+
 declare i8 @llvm.vscale.i8()
 declare i16 @llvm.vscale.i16()
 declare i32 @llvm.vscale.i32()
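Example of the new fold (illustrative only, not part of the patch; the
function name @example and the constants are hypothetical): with
vscale_range(0,16) the maximum value of vscale is 16, so (vscale << 2) can
be at most 64. Since 64 <= 100, the fold proves the compare below is always
false.

  define i1 @example() vscale_range(0,16) {
    ; vscale <= 16, so %num_els = vscale << 2 is at most 64.
    %vscale = call i64 @llvm.vscale.i64()
    %num_els = shl i64 %vscale, 2
    ; 64 <= 100, so this ugt compare can never be true.
    %res = icmp ugt i64 %num_els, 100
    ret i1 %res
  }
  declare i64 @llvm.vscale.i64()

Running instcombine on this (e.g. "opt -instcombine -S") should reduce the
body to "ret i1 false".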