diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -242,19 +242,30 @@ const auto *DeadII = dyn_cast<IntrinsicInst>(DeadI); if (KillingII == nullptr || DeadII == nullptr) return OW_Unknown; - if (KillingII->getIntrinsicID() != Intrinsic::masked_store || - DeadII->getIntrinsicID() != Intrinsic::masked_store) + if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID()) return OW_Unknown; - // Pointers. - Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts(); - Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts(); - if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr)) - return OW_Unknown; - // Masks. - // TODO: check that KillingII's mask is a superset of the DeadII's mask. - if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) - return OW_Unknown; - return OW_Complete; + if (KillingII->getIntrinsicID() == Intrinsic::masked_store) { + // Type size. + VectorType *KillingTy = + cast<VectorType>(KillingII->getArgOperand(0)->getType()); + VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType()); + if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits()) + return OW_Unknown; + // Element count. + if (KillingTy->getElementCount() != DeadTy->getElementCount()) + return OW_Unknown; + // Pointers. + Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts(); + Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts(); + if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr)) + return OW_Unknown; + // Masks. + // TODO: check that KillingII's mask is a superset of the DeadII's mask. 
+ if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) + return OW_Unknown; + return OW_Complete; + } + return OW_Unknown; } /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely diff --git a/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll b/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll --- a/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll +++ b/llvm/test/Transforms/DeadStoreElimination/masked-dead-store.ll @@ -59,7 +59,8 @@ define dllexport i32 @f1(<4 x i32>* %a, <4 x i8> %v1, <4 x i32> %v2) { ; CHECK-LABEL: @f1( -; CHECK-NEXT: [[PTR:%.*]] = bitcast <4 x i32>* [[A:%.*]] to <4 x i8>* +; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V2:%.*]], <4 x i32>* [[A:%.*]], i32 1, <4 x i1> ) +; CHECK-NEXT: [[PTR:%.*]] = bitcast <4 x i32>* [[A]] to <4 x i8>* ; CHECK-NEXT: call void @llvm.masked.store.v4i8.p0v4i8(<4 x i8> [[V1:%.*]], <4 x i8>* [[PTR]], i32 1, <4 x i1> ) ; CHECK-NEXT: ret i32 0 ; @@ -69,6 +70,19 @@ ret i32 0 } +define dllexport i32 @f2(<4 x i32>* %a, <4 x i8> %v1, <4 x i32> %v2, <4 x i1> %mask) { +; CHECK-LABEL: @f2( +; CHECK-NEXT: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V2:%.*]], <4 x i32>* [[A:%.*]], i32 1, <4 x i1> [[MASK:%.*]]) +; CHECK-NEXT: [[PTR:%.*]] = bitcast <4 x i32>* [[A]] to <4 x i8>* +; CHECK-NEXT: call void @llvm.masked.store.v4i8.p0v4i8(<4 x i8> [[V1:%.*]], <4 x i8>* [[PTR]], i32 1, <4 x i1> [[MASK]]) +; CHECK-NEXT: ret i32 0 +; + tail call void @llvm.masked.store.v4i32.p0(<4 x i32> %v2, <4 x i32>* %a, i32 1, <4 x i1> %mask) + %ptr = bitcast <4 x i32>* %a to <4 x i8>* + tail call void @llvm.masked.store.v4i8.p0(<4 x i8> %v1, <4 x i8>* %ptr, i32 1, <4 x i1> %mask) + ret i32 0 +} + declare void @llvm.masked.store.v4i8.p0(<4 x i8>, <4 x i8>*, i32, <4 x i1>) declare void @llvm.masked.store.v4i32.p0(<4 x i32>, <4 x i32>*, i32, <4 x i1>)