Index: clang/include/clang/Basic/Sanitizers.def =================================================================== --- clang/include/clang/Basic/Sanitizers.def +++ clang/include/clang/Basic/Sanitizers.def @@ -104,12 +104,13 @@ SANITIZER("cfi-cast-strict", CFICastStrict) SANITIZER("cfi-derived-cast", CFIDerivedCast) SANITIZER("cfi-icall", CFIICall) +SANITIZER("cfi-mfcall", CFIMFCall) SANITIZER("cfi-unrelated-cast", CFIUnrelatedCast) SANITIZER("cfi-nvcall", CFINVCall) SANITIZER("cfi-vcall", CFIVCall) SANITIZER_GROUP("cfi", CFI, - CFIDerivedCast | CFIICall | CFIUnrelatedCast | CFINVCall | - CFIVCall) + CFIDerivedCast | CFIICall | CFIMFCall | CFIUnrelatedCast | + CFINVCall | CFIVCall) // Safe Stack SANITIZER("safe-stack", SafeStack) Index: clang/lib/CodeGen/CGVTables.cpp =================================================================== --- clang/lib/CodeGen/CGVTables.cpp +++ clang/lib/CodeGen/CGVTables.cpp @@ -1012,30 +1012,29 @@ CharUnits PointerWidth = Context.toCharUnitsFromBits(Context.getTargetInfo().getPointerWidth(0)); - typedef std::pair TypeMetadata; - std::vector TypeMetadatas; - // Create type metadata for each address point. + typedef std::pair AddressPoint; + std::vector AddressPoints; for (auto &&AP : VTLayout.getAddressPoints()) - TypeMetadatas.push_back(std::make_pair( + AddressPoints.push_back(std::make_pair( AP.first.getBase(), VTLayout.getVTableOffset(AP.second.VTableIndex) + AP.second.AddressPointIndex)); - // Sort the type metadata for determinism. - llvm::sort(TypeMetadatas.begin(), TypeMetadatas.end(), - [this](const TypeMetadata &M1, const TypeMetadata &M2) { - if (&M1 == &M2) + // Sort the address points for determinism. + llvm::sort(AddressPoints.begin(), AddressPoints.end(), + [this](const AddressPoint &AP1, const AddressPoint &AP2) { + if (&AP1 == &AP2) return false; std::string S1; llvm::raw_string_ostream O1(S1); getCXXABI().getMangleContext().mangleTypeName( - QualType(M1.first->getTypeForDecl(), 0), O1); + QualType(AP1.first->getTypeForDecl(), 0), O1); O1.flush(); std::string S2; llvm::raw_string_ostream O2(S2); getCXXABI().getMangleContext().mangleTypeName( - QualType(M2.first->getTypeForDecl(), 0), O2); + QualType(AP2.first->getTypeForDecl(), 0), O2); O2.flush(); if (S1 < S2) @@ -1043,10 +1042,26 @@ if (S1 != S2) return false; - return M1.second < M2.second; + return AP1.second < AP2.second; }); - for (auto TypeMetadata : TypeMetadatas) - AddVTableTypeMetadata(VTable, PointerWidth * TypeMetadata.second, - TypeMetadata.first); + ArrayRef Comps = VTLayout.vtable_components(); + for (auto AP : AddressPoints) { + // Create type metadata for the address point. + AddVTableTypeMetadata(VTable, PointerWidth * AP.second, AP.first); + + // The class associated with each address point could also potentially be + // used for indirect calls via a member function pointer, so we need to + // annotate the address of each function pointer with the appropriate member + // function pointer type. + for (unsigned I = 0; I != Comps.size(); ++I) { + if (Comps[I].getKind() != VTableComponent::CK_FunctionPointer) + continue; + llvm::Metadata *MD = CreateMetadataIdentifierForVirtualMemPtrType( + Context.getMemberPointerType( + Comps[I].getFunctionDecl()->getType(), + Context.getRecordType(AP.first).getTypePtr())); + VTable->addTypeMetadata((PointerWidth * I).getQuantity(), MD); + } + } } Index: clang/lib/CodeGen/CodeGenModule.h =================================================================== --- clang/lib/CodeGen/CodeGenModule.h +++ clang/lib/CodeGen/CodeGenModule.h @@ -503,6 +503,7 @@ /// MDNodes. typedef llvm::DenseMap MetadataTypeMap; MetadataTypeMap MetadataIdMap; + MetadataTypeMap VirtualMetadataIdMap; MetadataTypeMap GeneralizedMetadataIdMap; public: @@ -1229,6 +1230,10 @@ /// internal identifiers). llvm::Metadata *CreateMetadataIdentifierForType(QualType T); + /// Create a metadata identifier that is intended to be used to check virtual + /// calls via a member function pointer. + llvm::Metadata *CreateMetadataIdentifierForVirtualMemPtrType(QualType T); + /// Create a metadata identifier for the generalization of the given type. /// This may either be an MDString (for external identifiers) or a distinct /// unnamed MDNode (for internal identifiers). @@ -1244,6 +1249,9 @@ void AddVTableTypeMetadata(llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD); + std::vector + getMostBaseClasses(const CXXRecordDecl *RD); + /// Get the declaration of std::terminate for the platform. llvm::Constant *getTerminateFn(); @@ -1405,6 +1413,9 @@ void ConstructDefaultFnAttrList(StringRef Name, bool HasOptnone, bool AttrOnCallSite, llvm::AttrBuilder &FuncAttrs); + + llvm::Metadata *CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix); }; } // end namespace CodeGen Index: clang/lib/CodeGen/CodeGenModule.cpp =================================================================== --- clang/lib/CodeGen/CodeGenModule.cpp +++ clang/lib/CodeGen/CodeGenModule.cpp @@ -1378,14 +1378,45 @@ GV->setLinkage(llvm::GlobalValue::ExternalWeakLinkage); } +std::vector +CodeGenModule::getMostBaseClasses(const CXXRecordDecl *RD) { + llvm::SetVector MostBases; + + std::function CollectMostBases; + CollectMostBases = [&](const CXXRecordDecl *RD) { + if (RD->getNumBases() == 0) + MostBases.insert(RD); + for (const CXXBaseSpecifier &B : RD->bases()) + CollectMostBases(B.getType()->getAsCXXRecordDecl()); + }; + CollectMostBases(RD); + return MostBases.takeVector(); +} + void CodeGenModule::CreateFunctionTypeMetadata(const FunctionDecl *FD, llvm::Function *F) { - // Only if we are checking indirect calls. - if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) + auto *MD = dyn_cast(FD); + if (MD && !MD->isStatic()) { + if (!LangOpts.Sanitize.has(SanitizerKind::CFIMFCall)) + return; + + // Only functions whose address can be taken with a member function pointer + // need type metadata. + if (MD->isVirtual() || isa(MD) || + isa(MD)) + return; + + for (const CXXRecordDecl *Base : getMostBaseClasses(MD->getParent())) { + llvm::Metadata *Id = + CreateMetadataIdentifierForType(Context.getMemberPointerType( + FD->getType(), Context.getRecordType(Base).getTypePtr())); + F->addTypeMetadata(0, Id); + } return; + } - // Non-static class methods are handled via vtable pointer checks elsewhere. - if (isa(FD) && !cast(FD)->isStatic()) + // Only if we are checking indirect calls. + if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall)) return; // Additionally, if building with cross-DSO support... @@ -1396,13 +1427,13 @@ return; } - llvm::Metadata *MD = CreateMetadataIdentifierForType(FD->getType()); - F->addTypeMetadata(0, MD); + llvm::Metadata *Id = CreateMetadataIdentifierForType(FD->getType()); + F->addTypeMetadata(0, Id); F->addTypeMetadata(0, CreateMetadataIdentifierGeneralized(FD->getType())); // Emit a hash-based bit set entry for cross-DSO calls. if (CodeGenOpts.SanitizeCfiCrossDso) - if (auto CrossDsoTypeId = CreateCrossDsoCfiTypeId(MD)) + if (auto CrossDsoTypeId = CreateCrossDsoCfiTypeId(Id)) F->addTypeMetadata(0, llvm::ConstantAsMetadata::get(CrossDsoTypeId)); } @@ -4928,8 +4959,10 @@ } } -llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { - llvm::Metadata *&InternalId = MetadataIdMap[T.getCanonicalType()]; +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierImpl(QualType T, MetadataTypeMap &Map, + StringRef Suffix) { + llvm::Metadata *&InternalId = Map[T.getCanonicalType()]; if (InternalId) return InternalId; @@ -4937,6 +4970,7 @@ std::string OutName; llvm::raw_string_ostream Out(OutName); getCXXABI().getMangleContext().mangleTypeName(T, Out); + Out << Suffix; InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); } else { @@ -4947,6 +4981,15 @@ return InternalId; } +llvm::Metadata *CodeGenModule::CreateMetadataIdentifierForType(QualType T) { + return CreateMetadataIdentifierImpl(T, MetadataIdMap, ""); +} + +llvm::Metadata * +CodeGenModule::CreateMetadataIdentifierForVirtualMemPtrType(QualType T) { + return CreateMetadataIdentifierImpl(T, VirtualMetadataIdMap, ".virtual"); +} + // Generalize pointer types to a void pointer with the qualifiers of the // originally pointed-to type, e.g. 'const char *' and 'char * const *' // generalize to 'const void *' while 'char *' and 'const char **' generalize to @@ -4980,25 +5023,8 @@ } llvm::Metadata *CodeGenModule::CreateMetadataIdentifierGeneralized(QualType T) { - T = GeneralizeFunctionType(getContext(), T); - - llvm::Metadata *&InternalId = GeneralizedMetadataIdMap[T.getCanonicalType()]; - if (InternalId) - return InternalId; - - if (isExternallyVisible(T->getLinkage())) { - std::string OutName; - llvm::raw_string_ostream Out(OutName); - getCXXABI().getMangleContext().mangleTypeName(T, Out); - Out << ".generalized"; - - InternalId = llvm::MDString::get(getLLVMContext(), Out.str()); - } else { - InternalId = llvm::MDNode::getDistinct(getLLVMContext(), - llvm::ArrayRef()); - } - - return InternalId; + return CreateMetadataIdentifierImpl(GeneralizeFunctionType(getContext(), T), + GeneralizedMetadataIdMap, ".generalized"); } /// Returns whether this module needs the "all-vtables" type identifier. Index: clang/lib/CodeGen/ItaniumCXXABI.cpp =================================================================== --- clang/lib/CodeGen/ItaniumCXXABI.cpp +++ clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -621,10 +621,26 @@ VTableOffset = Builder.CreateTrunc(VTableOffset, CGF.Int32Ty); VTableOffset = Builder.CreateZExt(VTableOffset, CGM.PtrDiffTy); } + // Compute the address of the virtual function pointer. VTable = Builder.CreateGEP(VTable, VTableOffset); + VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); + + // Check the address of the function pointer if CFI on member function + // pointers is enabled. + if (CGF.SanOpts.has(SanitizerKind::CFIMFCall)) { + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForVirtualMemPtrType(QualType(MPT, 0)); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *CastedVTable = Builder.CreateBitCast(VTable, CGF.Int8PtrTy); + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {CastedVTable, TypeId}); + + CGF.EmitTrapCheck(TypeTest); + FnVirtual = Builder.GetInsertBlock(); + } // Load the virtual function to call. - VTable = Builder.CreateBitCast(VTable, FTy->getPointerTo()->getPointerTo()); llvm::Value *VirtualFn = Builder.CreateAlignedLoad(VTable, CGF.getPointerAlign(), "memptr.virtualfn"); @@ -636,6 +652,27 @@ llvm::Value *NonVirtualFn = Builder.CreateIntToPtr(FnAsInt, FTy->getPointerTo(), "memptr.nonvirtualfn"); + // Check the function pointer if CFI on member function pointers is enabled. + if (CGF.SanOpts.has(SanitizerKind::CFIMFCall)) { + llvm::Value *Bit = Builder.getFalse(); + llvm::Value *CastedNonVirtualFn = + Builder.CreateBitCast(NonVirtualFn, CGF.Int8PtrTy); + for (const CXXRecordDecl *Base : + CGM.getMostBaseClasses(MPT->getClass()->getAsCXXRecordDecl())) { + llvm::Metadata *MD = + CGM.CreateMetadataIdentifierForType(getContext().getMemberPointerType( + MPT->getPointeeType(), + getContext().getRecordType(Base).getTypePtr())); + llvm::Value *TypeId = llvm::MetadataAsValue::get(CGF.getLLVMContext(), MD); + + llvm::Value *TypeTest = Builder.CreateCall( + CGM.getIntrinsic(llvm::Intrinsic::type_test), {CastedNonVirtualFn, TypeId}); + Bit = Builder.CreateOr(Bit, TypeTest); + } + CGF.EmitTrapCheck(Bit); + FnNonVirtual = Builder.GetInsertBlock(); + } + // We're done. CGF.EmitBlock(FnEnd); llvm::PHINode *CalleePtr = Builder.CreatePHI(FTy->getPointerTo(), 2); Index: clang/lib/Driver/SanitizerArgs.cpp =================================================================== --- clang/lib/Driver/SanitizerArgs.cpp +++ clang/lib/Driver/SanitizerArgs.cpp @@ -44,7 +44,8 @@ TrappingSupported = (Undefined & ~Vptr) | UnsignedIntegerOverflow | Nullability | LocalBounds | CFI, TrappingDefault = CFI, - CFIClasses = CFIVCall | CFINVCall | CFIDerivedCast | CFIUnrelatedCast, + CFIClasses = + CFIVCall | CFINVCall | CFIMFCall | CFIDerivedCast | CFIUnrelatedCast, CompatibleWithMinimalRuntime = TrappingSupported, };