Index: lib/CodeGen/CGExpr.cpp =================================================================== --- lib/CodeGen/CGExpr.cpp +++ lib/CodeGen/CGExpr.cpp @@ -916,7 +916,8 @@ /// EmitPointerWithAlignment - Given an expression of pointer type, try to /// derive a more accurate bound on the alignment of the pointer. Address CodeGenFunction::EmitPointerWithAlignment(const Expr *E, - LValueBaseInfo *BaseInfo) { + LValueBaseInfo *BaseInfo, + TBAAAccessInfo *TBAAInfo) { // We allow this with ObjC object pointers because of fragile ABIs. assert(E->getType()->isPointerType() || E->getType()->isObjCObjectPointerType()); @@ -936,20 +937,30 @@ if (PtrTy->getPointeeType()->isVoidType()) break; - LValueBaseInfo InnerInfo; - Address Addr = EmitPointerWithAlignment(CE->getSubExpr(), &InnerInfo); - if (BaseInfo) *BaseInfo = InnerInfo; - - // If this is an explicit bitcast, and the source l-value is - // opaque, honor the alignment of the casted-to type. - if (isa<ExplicitCastExpr>(CE) && - InnerInfo.getAlignmentSource() != AlignmentSource::Decl) { - LValueBaseInfo ExpInfo; + LValueBaseInfo InnerBaseInfo; + TBAAAccessInfo InnerTBAAInfo; + Address Addr = EmitPointerWithAlignment(CE->getSubExpr(), + &InnerBaseInfo, + &InnerTBAAInfo); + if (BaseInfo) *BaseInfo = InnerBaseInfo; + if (TBAAInfo) *TBAAInfo = InnerTBAAInfo; + + if (isa<ExplicitCastExpr>(CE)) { + LValueBaseInfo TargetTypeBaseInfo; + TBAAAccessInfo TargetTypeTBAAInfo; CharUnits Align = getNaturalPointeeTypeAlignment(E->getType(), - &ExpInfo); - if (BaseInfo) - BaseInfo->mergeForCast(ExpInfo); - Addr = Address(Addr.getPointer(), Align); + &TargetTypeBaseInfo, + &TargetTypeTBAAInfo); + if (TBAAInfo) + *TBAAInfo = CGM.mergeTBAAInfoForCast(*TBAAInfo, + TargetTypeTBAAInfo); + // If the source l-value is opaque, honor the alignment of the + // casted-to type.
+ if (InnerBaseInfo.getAlignmentSource() != AlignmentSource::Decl) { + if (BaseInfo) + BaseInfo->mergeForCast(TargetTypeBaseInfo); + Addr = Address(Addr.getPointer(), Align); + } } if (SanOpts.has(SanitizerKind::CFIUnrelatedCast) && @@ -969,12 +980,13 @@ // Array-to-pointer decay. case CK_ArrayToPointerDecay: - return EmitArrayToPointerDecay(CE->getSubExpr(), BaseInfo); + return EmitArrayToPointerDecay(CE->getSubExpr(), BaseInfo, TBAAInfo); // Derived-to-base conversions. case CK_UncheckedDerivedToBase: case CK_DerivedToBase: { - Address Addr = EmitPointerWithAlignment(CE->getSubExpr(), BaseInfo); + Address Addr = EmitPointerWithAlignment(CE->getSubExpr(), BaseInfo, + TBAAInfo); auto Derived = CE->getSubExpr()->getType()->getPointeeCXXRecordDecl(); return GetAddressOfBaseClass(Addr, Derived, CE->path_begin(), CE->path_end(), @@ -994,6 +1006,7 @@ if (UO->getOpcode() == UO_AddrOf) { LValue LV = EmitLValue(UO->getSubExpr()); if (BaseInfo) *BaseInfo = LV.getBaseInfo(); + if (TBAAInfo) *TBAAInfo = LV.getTBAAInfo(); return LV.getAddress(); } } @@ -1001,7 +1014,8 @@ // TODO: conditional operators, comma. // Otherwise, use the alignment of the type. 
- CharUnits Align = getNaturalPointeeTypeAlignment(E->getType(), BaseInfo); + CharUnits Align = getNaturalPointeeTypeAlignment(E->getType(), BaseInfo, + TBAAInfo); return Address(EmitScalarExpr(E), Align); } @@ -2447,8 +2461,10 @@ assert(!T.isNull() && "CodeGenFunction::EmitUnaryOpLValue: Illegal type"); LValueBaseInfo BaseInfo; - Address Addr = EmitPointerWithAlignment(E->getSubExpr(), &BaseInfo); - LValue LV = MakeAddrLValue(Addr, T, BaseInfo, CGM.getTBAAAccessInfo(T)); + TBAAAccessInfo TBAAInfo; + Address Addr = EmitPointerWithAlignment(E->getSubExpr(), &BaseInfo, + &TBAAInfo); + LValue LV = MakeAddrLValue(Addr, T, BaseInfo, TBAAInfo); LV.getQuals().setAddressSpace(ExprTy.getAddressSpace()); // We should not generate __weak write barrier on indirect reference @@ -3048,7 +3064,8 @@ } Address CodeGenFunction::EmitArrayToPointerDecay(const Expr *E, - LValueBaseInfo *BaseInfo) { + LValueBaseInfo *BaseInfo, + TBAAAccessInfo *TBAAInfo) { assert(E->getType()->isArrayType() && "Array to pointer decay must have array source type!"); @@ -3056,6 +3073,7 @@ LValue LV = EmitLValue(E); Address Addr = LV.getAddress(); if (BaseInfo) *BaseInfo = LV.getBaseInfo(); + if (TBAAInfo) *TBAAInfo = LV.getTBAAInfo(); // If the array type was an incomplete type, we need to make sure // the decay ends up being the right type. @@ -3216,13 +3234,14 @@ } LValueBaseInfo BaseInfo; + TBAAAccessInfo TBAAInfo; Address Addr = Address::invalid(); if (const VariableArrayType *vla = getContext().getAsVariableArrayType(E->getType())) { // The base must be a pointer, which is not an aggregate. Emit // it. It needs to be emitted first in case it's what captures // the VLA bounds. - Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo); + Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo, &TBAAInfo); auto *Idx = EmitIdxAfterBase(/*Promote*/true); // The element count here is the total number of non-VLA elements. 
@@ -3246,7 +3265,7 @@ // Indexing over an interface, as in "NSString *P; P[4];" // Emit the base pointer. - Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo); + Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo, &TBAAInfo); auto *Idx = EmitIdxAfterBase(/*Promote*/true); CharUnits InterfaceSize = getContext().getTypeSizeInChars(OIT); @@ -3294,19 +3313,17 @@ E->getType(), !getLangOpts().isSignedOverflowDefined(), SignedIndices, E->getExprLoc()); BaseInfo = ArrayLV.getBaseInfo(); + TBAAInfo = CGM.getTBAAAccessInfo(E->getType()); } else { // The base must be a pointer; emit it with an estimate of its alignment. - Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo); + Addr = EmitPointerWithAlignment(E->getBase(), &BaseInfo, &TBAAInfo); auto *Idx = EmitIdxAfterBase(/*Promote*/true); Addr = emitArraySubscriptGEP(*this, Addr, Idx, E->getType(), !getLangOpts().isSignedOverflowDefined(), SignedIndices, E->getExprLoc()); } - LValue LV = MakeAddrLValue(Addr, E->getType(), BaseInfo, - CGM.getTBAAAccessInfo(E->getType())); - - // TODO: Preserve/extend path TBAA metadata? + LValue LV = MakeAddrLValue(Addr, E->getType(), BaseInfo, TBAAInfo); if (getLangOpts().ObjC1 && getLangOpts().getGC() != LangOptions::NonGC) { @@ -3321,8 +3338,6 @@ TBAAAccessInfo &TBAAInfo, QualType BaseTy, QualType ElTy, bool IsLowerBound) { - TBAAInfo = CGF.CGM.getTBAAAccessInfo(ElTy); - LValue BaseLVal; if (auto *ASE = dyn_cast<OMPArraySectionExpr>(Base->IgnoreParenImpCasts())) { BaseLVal = CGF.EmitOMPArraySectionExpr(ASE, IsLowerBound); @@ -3352,7 +3367,7 @@ BaseInfo.mergeForCast(TypeInfo); return Address(CGF.Builder.CreateLoad(BaseLVal.getAddress()), Align); } - return CGF.EmitPointerWithAlignment(Base, &BaseInfo); + return CGF.EmitPointerWithAlignment(Base, &BaseInfo, &TBAAInfo); } LValue CodeGenFunction::EmitOMPArraySectionExpr(const OMPArraySectionExpr *E, @@ -3518,10 +3533,10 @@ // If it is a pointer to a vector, emit the address and form an lvalue with // it.
LValueBaseInfo BaseInfo; - Address Ptr = EmitPointerWithAlignment(E->getBase(), &BaseInfo); + TBAAAccessInfo TBAAInfo; + Address Ptr = EmitPointerWithAlignment(E->getBase(), &BaseInfo, &TBAAInfo); const PointerType *PT = E->getBase()->getType()->getAs<PointerType>(); - Base = MakeAddrLValue(Ptr, PT->getPointeeType(), BaseInfo, - CGM.getTBAAAccessInfo(PT->getPointeeType())); + Base = MakeAddrLValue(Ptr, PT->getPointeeType(), BaseInfo, TBAAInfo); Base.getQuals().removeObjCGCAttr(); } else if (E->getBase()->isGLValue()) { // Otherwise, if the base is an lvalue ( as in the case of foo.x.x), @@ -3577,7 +3592,8 @@ LValue BaseLV; if (E->isArrow()) { LValueBaseInfo BaseInfo; - Address Addr = EmitPointerWithAlignment(BaseExpr, &BaseInfo); + TBAAAccessInfo TBAAInfo; + Address Addr = EmitPointerWithAlignment(BaseExpr, &BaseInfo, &TBAAInfo); QualType PtrTy = BaseExpr->getType()->getPointeeType(); SanitizerSet SkippedChecks; bool IsBaseCXXThis = IsWrappedCXXThis(BaseExpr); @@ -3587,8 +3603,7 @@ SkippedChecks.set(SanitizerKind::Null, true); EmitTypeCheck(TCK_MemberAccess, E->getExprLoc(), Addr.getPointer(), PtrTy, /*Alignment=*/CharUnits::Zero(), SkippedChecks); - BaseLV = MakeAddrLValue(Addr, PtrTy, BaseInfo, - CGM.getTBAAAccessInfo(PtrTy)); + BaseLV = MakeAddrLValue(Addr, PtrTy, BaseInfo, TBAAInfo); } else BaseLV = EmitCheckedLValue(BaseExpr, TCK_MemberAccess); Index: lib/CodeGen/CodeGenFunction.h =================================================================== --- lib/CodeGen/CodeGenFunction.h +++ lib/CodeGen/CodeGenFunction.h @@ -1942,7 +1942,8 @@ TBAAAccessInfo *TBAAInfo = nullptr, bool forPointeeType = false); CharUnits getNaturalPointeeTypeAlignment(QualType T, - LValueBaseInfo *BaseInfo = nullptr); + LValueBaseInfo *BaseInfo = nullptr, + TBAAAccessInfo *TBAAInfo = nullptr); Address EmitLoadOfReference(Address Ref, const ReferenceType *RefTy, LValueBaseInfo *BaseInfo = nullptr, @@ -3188,7 +3189,8 @@ RValue EmitRValueForField(LValue LV, const FieldDecl *FD, SourceLocation Loc);
Address EmitArrayToPointerDecay(const Expr *Array, - LValueBaseInfo *BaseInfo = nullptr); + LValueBaseInfo *BaseInfo = nullptr, + TBAAAccessInfo *TBAAInfo = nullptr); class ConstantEmission { llvm::PointerIntPair<llvm::Constant*, 1, bool> ValueAndIsReference; @@ -3910,7 +3912,8 @@ /// reasonable to just ignore the returned alignment when it isn't from an /// explicit source. Address EmitPointerWithAlignment(const Expr *Addr, - LValueBaseInfo *BaseInfo = nullptr); + LValueBaseInfo *BaseInfo = nullptr, + TBAAAccessInfo *TBAAInfo = nullptr); void EmitSanitizerStatReport(llvm::SanitizerStatKind SSK); Index: lib/CodeGen/CodeGenFunction.cpp =================================================================== --- lib/CodeGen/CodeGenFunction.cpp +++ lib/CodeGen/CodeGenFunction.cpp @@ -118,9 +118,9 @@ } CharUnits CodeGenFunction::getNaturalPointeeTypeAlignment(QualType T, - LValueBaseInfo *BaseInfo) { - return getNaturalTypeAlignment(T->getPointeeType(), BaseInfo, - /* TBAAInfo= */ nullptr, + LValueBaseInfo *BaseInfo, + TBAAAccessInfo *TBAAInfo) { + return getNaturalTypeAlignment(T->getPointeeType(), BaseInfo, TBAAInfo, /* forPointeeType= */ true); } Index: lib/CodeGen/CodeGenModule.h =================================================================== --- lib/CodeGen/CodeGenModule.h +++ lib/CodeGen/CodeGenModule.h @@ -677,6 +677,11 @@ /// may-alias accesses. TBAAAccessInfo getTBAAMayAliasAccessInfo(); + /// mergeTBAAInfoForCast - Get merged TBAA information for the purposes of + /// type casts.
+ TBAAAccessInfo mergeTBAAInfoForCast(TBAAAccessInfo SourceInfo, + TBAAAccessInfo TargetInfo); + bool isTypeConstant(QualType QTy, bool ExcludeCtorDtor); bool isPaddedAtomicType(QualType type); Index: lib/CodeGen/CodeGenModule.cpp =================================================================== --- lib/CodeGen/CodeGenModule.cpp +++ lib/CodeGen/CodeGenModule.cpp @@ -612,6 +612,13 @@ return TBAA->getMayAliasAccessInfo(); } +TBAAAccessInfo CodeGenModule::mergeTBAAInfoForCast(TBAAAccessInfo SourceInfo, + TBAAAccessInfo TargetInfo) { + if (!TBAA) + return TBAAAccessInfo(); + return TBAA->mergeTBAAInfoForCast(SourceInfo, TargetInfo); +} + void CodeGenModule::DecorateInstructionWithTBAA(llvm::Instruction *Inst, TBAAAccessInfo TBAAInfo) { if (llvm::MDNode *Tag = getTBAAAccessTagInfo(TBAAInfo)) Index: lib/CodeGen/CodeGenTBAA.h =================================================================== --- lib/CodeGen/CodeGenTBAA.h +++ lib/CodeGen/CodeGenTBAA.h @@ -47,6 +47,12 @@ : TBAAAccessInfo(/* AccessType= */ nullptr) {} + bool operator==(const TBAAAccessInfo &Other) const { + return BaseType == Other.BaseType && + AccessType == Other.AccessType && + Offset == Other.Offset; + } + /// BaseType - The base/leading access type. May be null if this access /// descriptor represents an access that is not considered to be an access /// to an aggregate or union member. @@ -136,6 +142,11 @@ /// getMayAliasAccessInfo - Get TBAA information that represents may-alias /// accesses. TBAAAccessInfo getMayAliasAccessInfo(); + + /// mergeTBAAInfoForCast - Get merged TBAA information for the purpose of + /// type casts. 
+ TBAAAccessInfo mergeTBAAInfoForCast(TBAAAccessInfo SourceInfo, + TBAAAccessInfo TargetInfo); }; } // end namespace CodeGen Index: lib/CodeGen/CodeGenTBAA.cpp =================================================================== --- lib/CodeGen/CodeGenTBAA.cpp +++ lib/CodeGen/CodeGenTBAA.cpp @@ -309,3 +309,11 @@ TBAAAccessInfo CodeGenTBAA::getMayAliasAccessInfo() { return TBAAAccessInfo(getChar()); } + +TBAAAccessInfo CodeGenTBAA::mergeTBAAInfoForCast(TBAAAccessInfo SourceInfo, + TBAAAccessInfo TargetInfo) { + TBAAAccessInfo MayAliasInfo = getMayAliasAccessInfo(); + if (SourceInfo == MayAliasInfo || TargetInfo == MayAliasInfo) + return MayAliasInfo; + return TargetInfo; +} Index: test/CodeGen/tbaa-cast.cpp =================================================================== --- test/CodeGen/tbaa-cast.cpp +++ test/CodeGen/tbaa-cast.cpp @@ -0,0 +1,23 @@ +// RUN: %clang_cc1 -triple x86_64-linux -O1 -disable-llvm-passes %s \ +// RUN: -emit-llvm -o - | FileCheck %s +// +// Check that we generate correct TBAA information for lvalues constructed +// with use of casts. + +struct V { + unsigned n; +}; + +struct S { + char bytes[4]; +}; + +void foo(S *p) { +// CHECK-LABEL: _Z3fooP1S +// CHECK: store i32 5, {{.*}}, !tbaa [[TAG_V_n:!.*]] + ((V*)p->bytes)->n = 5; +} + +// CHECK-DAG: [[TAG_V_n]] = !{[[TYPE_V:!.*]], [[TYPE_int:!.*]], i64 0} +// CHECK-DAG: [[TYPE_V]] = !{!"_ZTS1V", !{{.*}}, i64 0} +// CHECK-DAG: [[TYPE_int]] = !{!"int", !{{.*}}, i64 0}