diff --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h
--- a/llvm/include/llvm/IR/Metadata.h
+++ b/llvm/include/llvm/IR/Metadata.h
@@ -1121,6 +1121,9 @@
 
   /// Check whether MDNode is a vtable access.
   bool isTBAAVtableAccess() const;
+  /// Check whether MDNode is a pointer access, i.e. its TBAA access type is
+  /// "vtable pointer" or "any pointer".
+  bool isTBAAPointerAccess() const;
 
   /// Methods for metadata merging.
   static MDNode *concatenate(MDNode *A, MDNode *B);
diff --git a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
--- a/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
+++ b/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp
@@ -451,26 +451,42 @@
   return AAResultBase::getModRefInfo(Call1, Call2, AAQI);
 }
 
-bool MDNode::isTBAAVtableAccess() const {
-  if (!isStructPathTBAA(this)) {
-    if (getNumOperands() < 1)
+/// Return true if the TBAA access tag \p Node names one of \p Descs as its
+/// type. For the old scalar TBAA format the type name is operand 0 of the
+/// tag; for struct-path TBAA it is the identifier of the tag's access type.
+/// \p Descs is taken as a non-owning ArrayRef so callers can pass a braced
+/// list without building (and copying) a container on every query.
+static bool isAccessWithDesc(const MDNode *Node, ArrayRef<StringRef> Descs) {
+  if (!isStructPathTBAA(Node)) {
+    if (Node->getNumOperands() < 1)
       return false;
-    if (MDString *Tag1 = dyn_cast<MDString>(getOperand(0))) {
-      if (Tag1->getString() == "vtable pointer")
-        return true;
+    if (MDString *Tag1 = dyn_cast<MDString>(Node->getOperand(0))) {
+      for (StringRef Desc : Descs)
+        if (Tag1->getString() == Desc)
+          return true;
     }
     return false;
   }
 
   // For struct-path aware TBAA, we use the access type of the tag.
-  TBAAStructTagNode Tag(this);
+  TBAAStructTagNode Tag(Node);
   TBAAStructTypeNode AccessType(Tag.getAccessType());
-  if(auto *Id = dyn_cast<MDString>(AccessType.getId()))
-    if (Id->getString() == "vtable pointer")
-      return true;
+  if (auto *Id = dyn_cast<MDString>(AccessType.getId())) {
+    for (StringRef Desc : Descs)
+      if (Id->getString() == Desc)
+        return true;
+  }
   return false;
 }
 
+bool MDNode::isTBAAVtableAccess() const {
+  return isAccessWithDesc(this, {"vtable pointer"});
+}
+
+bool MDNode::isTBAAPointerAccess() const {
+  // Vtable pointers are pointers too, so they count as pointer accesses.
+  return isAccessWithDesc(this, {"vtable pointer", "any pointer"});
+}
+
 static bool matchAccessTags(const MDNode *A, const MDNode *B,
                             const MDNode **GenericTag = nullptr);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -151,16 +151,11 @@
   if (*CopyDstAlign < Size || *CopySrcAlign < Size)
     return nullptr;
 
-  // Use an integer load+store unless we can find something better.
   unsigned SrcAddrSp =
     cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace();
   unsigned DstAddrSp =
     cast<PointerType>(MI->getArgOperand(0)->getType())->getAddressSpace();
 
-  IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3);
-  Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
-  Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
-
   // If the memcpy has metadata describing the members, see if we can get the
   // TBAA tag describing our copy.
   MDNode *CopyMD = nullptr;
@@ -178,9 +173,24 @@
       CopyMD = cast<MDNode>(M->getOperand(2));
   }
 
+  Type *DataType;
+  if (CopyMD && CopyMD->isTBAAPointerAccess() && SrcAddrSp == DstAddrSp &&
+      Size * 8 == DL.getPointerSizeInBits(SrcAddrSp)) {
+    // The TBAA tag says the copied value is a pointer, so copy it with a
+    // pointer-typed load/store. An integer copy would force later users of
+    // the value through redundant inttoptr/ptrtoint casts.
+    DataType = Type::getInt8PtrTy(MI->getContext(), SrcAddrSp);
+  } else {
+    // Use an integer load+store unless we can find something better.
+    DataType = IntegerType::get(MI->getContext(), Size << 3);
+  }
+
+  Type *NewSrcPtrTy = PointerType::get(DataType, SrcAddrSp);
+  Type *NewDstPtrTy = PointerType::get(DataType, DstAddrSp);
+
   Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy);
   Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
-  LoadInst *L = Builder.CreateLoad(IntType, Src);
+  LoadInst *L = Builder.CreateLoad(DataType, Src);
   // Alignment from the mem intrinsic will be better, so use it.
   L->setAlignment(*CopySrcAlign);
   if (CopyMD)
diff --git a/llvm/test/Transforms/InstCombine/memcpy-tbaa.ll b/llvm/test/Transforms/InstCombine/memcpy-tbaa.ll
--- a/llvm/test/Transforms/InstCombine/memcpy-tbaa.ll
+++ b/llvm/test/Transforms/InstCombine/memcpy-tbaa.ll
@@ -7,10 +7,10 @@
 define void @f(%struct.T* %s, %struct.T* %t) {
 ; CHECK-LABEL: @f(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast %struct.T* [[T:%.*]] to i64*
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast %struct.T* [[S:%.*]] to i64*
-; CHECK-NEXT:    [[TMP2:%.*]] = load i64, i64* [[TMP0]], align 8, !tbaa [[TBAA0:![0-9]+]]
-; CHECK-NEXT:    store i64 [[TMP2]], i64* [[TMP1]], align 8, !tbaa [[TBAA0]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast %struct.T* [[T:%.*]] to i8**
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast %struct.T* [[S:%.*]] to i8**
+; CHECK-NEXT:    [[TMP2:%.*]] = load i8*, i8** [[TMP0]], align 8, !tbaa [[TBAA0:![0-9]+]]
+; CHECK-NEXT:    store i8* [[TMP2]], i8** [[TMP1]], align 8, !tbaa [[TBAA0]]
 ; CHECK-NEXT:    ret void
 ;
 ; CHECK32-LABEL: @f(