diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1950,10 +1950,30 @@
     //    -> (ptrmask p, (and A, B))
     if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(
                        m_Value(InnerPtr), m_Value(InnerMask))))) {
+      // See if combining the two masks is free.
+      bool OkayToMerge = InnerMask->getType() == Op1->getType();
+      bool NeedsNew = false;
+      if (!OkayToMerge) {
+        if (match(InnerMask, m_ImmConstant())) {
+          InnerMask = Builder.CreateZExtOrTrunc(InnerMask, Op1->getType());
+          OkayToMerge = true;
+        } else if (match(Op1, m_ImmConstant())) {
+          Op1 = Builder.CreateZExtOrTrunc(Op1, InnerMask->getType());
+          OkayToMerge = true;
+          // Need to create a new one here, as the intrinsic id needs to change.
+          NeedsNew = true;
+        }
+      }
       if (InnerMask->getType() == Op1->getType()) {
         // TODO: If InnerMask == Op1, we could copy attributes from inner
         // callsite -> outer callsite.
         Value *NewMask = Builder.CreateAnd(Op1, InnerMask);
+        if (NeedsNew)
+          return replaceInstUsesWith(
+              *II,
+              Builder.CreateIntrinsic(InnerPtr->getType(), Intrinsic::ptrmask,
+                                      {InnerPtr, NewMask}));
+
         replaceOperand(CI, 0, InnerPtr);
         replaceOperand(CI, 1, NewMask);
         Changed = true;
diff --git a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
--- a/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
+++ b/llvm/test/Transforms/InstCombine/consecutive-ptrmask.ll
@@ -70,8 +70,8 @@ define ptr @fold_2x_type_mismatch_const0(ptr %p, i32 %m1) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const0
 ; CHECK-SAME: (ptr [[P:%.*]], i32 [[M1:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 -128)
-; CHECK-NEXT:    [[P1:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 [[M1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[M1]], -128
+; CHECK-NEXT:    [[P1:%.*]] = call align 128 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -128)
@@ -82,8 +82,8 @@ define ptr @fold_2x_type_mismatch_const1(ptr %p, i64 %m0) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const1
 ; CHECK-SAME: (ptr [[P:%.*]], i64 [[M0:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[M0]])
-; CHECK-NEXT:    [[P1:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i32(ptr [[P0]], i32 -2)
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[M0]], 4294967294
+; CHECK-NEXT:    [[P1:%.*]] = call align 2 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 [[TMP1]])
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m0)
@@ -95,8 +95,7 @@ define ptr @fold_2x_type_mismatch_const2(ptr %p) {
 ; CHECK-LABEL: define ptr @fold_2x_type_mismatch_const2
 ; CHECK-SAME: (ptr [[P:%.*]]) {
-; CHECK-NEXT:    [[P0:%.*]] = call align 4 ptr @llvm.ptrmask.p0.i32(ptr [[P]], i32 -4)
-; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P0]], i64 4294967264)
+; CHECK-NEXT:    [[P1:%.*]] = call align 32 ptr @llvm.ptrmask.p0.i64(ptr [[P]], i64 4294967264)
 ; CHECK-NEXT:    ret ptr [[P1]]
 ;
   %p0 = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 -4)
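
For reference, a minimal before/after sketch of the fold this patch enables when run through opt -passes=instcombine; the @example function below is illustrative only (it mirrors the fold_2x_type_mismatch_const0 test above) and is not part of the patch.

; Before: two consecutive ptrmask calls whose mask types differ (i64 vs i32).
declare ptr @llvm.ptrmask.p0.i64(ptr, i64)
declare ptr @llvm.ptrmask.p0.i32(ptr, i32)

define ptr @example(ptr %p, i32 %m) {
  %q = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -128)
  %r = call ptr @llvm.ptrmask.p0.i32(ptr %q, i32 %m)
  ret ptr %r
}

; After instcombine: the constant i64 mask is narrowed to i32 and folded into
; the variable mask, leaving a single ptrmask call.
define ptr @example(ptr %p, i32 %m) {
  %mask = and i32 %m, -128
  %r = call ptr @llvm.ptrmask.p0.i32(ptr %p, i32 %mask)
  ret ptr %r
}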