diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1930,6 +1930,10 @@ break; } case Intrinsic::ptrmask: { + KnownBits Known(DL.getPointerTypeSizeInBits(II->getType())); + if (SimplifyDemandedInstructionBits(*II, Known)) + return II; + Value *Op0 = II->getArgOperand(0); Value *Op1 = II->getArgOperand(1); // Fail loudly in case this is ever changed. @@ -1952,7 +1956,6 @@ } } bool Changed = false; - KnownBits Known = computeKnownBits(II, /*Depth*/ 0, II); // See if we can deduce non-null. if (!CI.hasRetAttr(Attribute::NonNull) && (Known.isNonZero() || diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -498,6 +498,9 @@ /// Tries to simplify operands to an integer instruction based on its /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); + /// As above, but the caller supplies \p Known: its bit width sets the width + /// of the all-ones demanded mask, and it receives the computed known bits. + bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -48,15 +48,19 @@ return true; } +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth.
+static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { + if (unsigned BitWidth = Ty->getScalarSizeInBits()) + return BitWidth; + return DL.getPointerTypeSizeInBits(Ty); +} /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if /// the instruction has any properties that allow us to simplify its operands. -bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { - unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); - KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnes(BitWidth)); - +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known) { + APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); if (!V) return false; @@ -65,6 +69,13 @@ return true; } +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { + KnownBits Known(getBitWidth(Inst.getType(), DL)); + return SimplifyDemandedInstructionBits(Inst, Known); +} + /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. @@ -898,6 +909,49 @@ } break; } + case Intrinsic::ptrmask: { + // Fail loudly in case this is ever changed. + // Likely not much needs to be changed here to support vector types. + assert(!I->getOperand(0)->getType()->isVectorTy() && + !I->getOperand(1)->getType()->isVectorTy() && + "These simplifications were written at a time when ptrmask did " + "not support vector types and may not work for vectors"); + + unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits(); + RHSKnown = KnownBits(MaskWidth); + // If either the LHS or the RHS bit is zero, the result bit is zero.
+ if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) || + SimplifyDemandedBits( + I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth), + RHSKnown, Depth + 1)) + return I; + + RHSKnown = RHSKnown.zextOrTrunc(BitWidth); + assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); + assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); + + Known = LHSKnown & RHSKnown; + + // If the client only demands bits known to be zero, zero out the mask + // rather than returning null: a null constant would drop the pointer + // operand's provenance, which llvm.ptrmask(p, 0) retains. + if (DemandedMask.isSubsetOf(Known.Zero) && + !match(I->getOperand(1), m_Zero())) + return replaceOperand( + *I, 1, Constant::getNullValue(I->getOperand(1)->getType())); + + // The mask is a no-op on the demanded bits: return the pointer operand. + if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) + return I->getOperand(0); + + // If the RHS is a constant, see if we can simplify it. + if (ShrinkDemandedConstant( + I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) + return I; + + break; + } + case Intrinsic::fshr: case Intrinsic::fshl: { const APInt *SA; diff --git a/llvm/test/Transforms/InstCombine/align-addr.ll b/llvm/test/Transforms/InstCombine/align-addr.ll --- a/llvm/test/Transforms/InstCombine/align-addr.ll +++ b/llvm/test/Transforms/InstCombine/align-addr.ll @@ -223,7 +223,7 @@ ; than the pointer size.
define <16 x i8> @ptrmask_align8_ptr_align1_bigmask(ptr align 1 %ptr) { ; CHECK-LABEL: @ptrmask_align8_ptr_align1_bigmask( -; CHECK-NEXT: [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 -8) +; CHECK-NEXT: [[ALIGNED:%.*]] = call align 8 ptr @llvm.ptrmask.p0.i128(ptr [[PTR:%.*]], i128 18446744073709551608) ; CHECK-NEXT: [[LOAD:%.*]] = load <16 x i8>, ptr [[ALIGNED]], align 8 ; CHECK-NEXT: ret <16 x i8> [[LOAD]] ; diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll --- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll +++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll @@ -96,7 +96,7 @@ ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2 ; CHECK-SAME: (ptr [[P:%.*]]) { ; CHECK-NEXT: [[P0:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4) -; CHECK-NEXT: [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 -31) +; CHECK-NEXT: [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 4294967264) ; CHECK-NEXT: ret ptr [[P1]] ; %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4) diff --git a/llvm/test/Transforms/InstCombine/ptrmask.ll b/llvm/test/Transforms/InstCombine/ptrmask.ll --- a/llvm/test/Transforms/InstCombine/ptrmask.ll +++ b/llvm/test/Transforms/InstCombine/ptrmask.ll @@ -66,7 +66,7 @@ ; CHECK-SAME: (ptr [[P:%.*]]) { ; CHECK-NEXT: [[PM0:%.*]] = call align 64 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -64) ; CHECK-NEXT: [[PGEP:%.*]] = getelementptr i8, ptr [[PM0]], i64 33 -; CHECK-NEXT: [[R:%.*]] = call nonnull align 32 ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -16) +; CHECK-NEXT: [[R:%.*]] = call nonnull align 32 ptr @llvm.ptrmask.p0.i64(ptr [[PGEP]], i64 -32) ; CHECK-NEXT: ret ptr [[R]] ; %pm0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -64)