diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp --- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" @@ -141,7 +142,7 @@ SmallVectorImpl &All, const DataLayout &DL); bool addrPointsToConstantData(Value *Addr); - int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL); + int getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr, const DataLayout &DL); void InsertRuntimeIgnores(Function &F); Type *IntptrTy; @@ -614,6 +615,7 @@ const bool IsWrite = isa(*II.Inst); Value *Addr = IsWrite ? cast(II.Inst)->getPointerOperand() : cast(II.Inst)->getPointerOperand(); + Type *OrigTy = getLoadStoreType(II.Inst); // swifterror memory addresses are mem2reg promoted by instruction selection. // As such they cannot have regular uses like an instrumentation function and @@ -621,7 +623,7 @@ if (Addr->isSwiftError()) return false; - int Idx = getMemoryAccessFuncIndex(Addr, DL); + int Idx = getMemoryAccessFuncIndex(OrigTy, Addr, DL); if (Idx < 0) return false; if (IsWrite && isVtableAccess(II.Inst)) { @@ -658,7 +660,6 @@ : cast(II.Inst)->isVolatile()); assert((!IsVolatile || !IsCompoundRW) && "Compound volatile invalid!"); - Type *OrigTy = cast(Addr->getType())->getElementType(); const uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); FunctionCallee OnAccessFunc = nullptr; if (Alignment == 0 || Alignment >= 8 || (Alignment % (TypeSize / 8)) == 0) { @@ -742,7 +743,8 @@ IRBuilder<> IRB(I); if (LoadInst *LI = dyn_cast(I)) { Value *Addr = LI->getPointerOperand(); - int Idx = getMemoryAccessFuncIndex(Addr, DL); + Type *OrigTy = LI->getType(); + int Idx = getMemoryAccessFuncIndex(OrigTy, Addr, DL); if (Idx < 0) return false; const unsigned ByteSize = 1U << Idx; @@ -751,13 +753,13 @@ Type *PtrTy = Ty->getPointerTo(); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), createOrdering(&IRB, LI->getOrdering())}; - Type *OrigTy = cast(Addr->getType())->getElementType(); Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args); Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy); I->replaceAllUsesWith(Cast); } else if (StoreInst *SI = dyn_cast(I)) { Value *Addr = SI->getPointerOperand(); - int Idx = getMemoryAccessFuncIndex(Addr, DL); + int Idx = + getMemoryAccessFuncIndex(SI->getValueOperand()->getType(), Addr, DL); if (Idx < 0) return false; const unsigned ByteSize = 1U << Idx; @@ -771,7 +773,8 @@ ReplaceInstWithInst(I, C); } else if (AtomicRMWInst *RMWI = dyn_cast(I)) { Value *Addr = RMWI->getPointerOperand(); - int Idx = getMemoryAccessFuncIndex(Addr, DL); + int Idx = + getMemoryAccessFuncIndex(RMWI->getValOperand()->getType(), Addr, DL); if (Idx < 0) return false; FunctionCallee F = TsanAtomicRMW[RMWI->getOperation()][Idx]; @@ -788,7 +791,8 @@ ReplaceInstWithInst(I, C); } else if (AtomicCmpXchgInst *CASI = dyn_cast(I)) { Value *Addr = CASI->getPointerOperand(); - int Idx = getMemoryAccessFuncIndex(Addr, DL); + Type *OrigOldValTy = CASI->getNewValOperand()->getType(); + int Idx = getMemoryAccessFuncIndex(OrigOldValTy, Addr, DL); if (Idx < 0) return false; const unsigned ByteSize = 1U << Idx; @@ -807,7 +811,6 @@ CallInst *C = IRB.CreateCall(TsanAtomicCAS[Idx], Args); Value *Success = IRB.CreateICmpEQ(C, CmpOperand); Value *OldVal = C; - Type *OrigOldValTy = CASI->getNewValOperand()->getType(); if (Ty != OrigOldValTy) { // The value is a pointer, so we need to cast the return value. OldVal = IRB.CreateIntToPtr(C, OrigOldValTy); @@ -830,10 +833,8 @@ return true; } -int ThreadSanitizer::getMemoryAccessFuncIndex(Value *Addr, +int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr, const DataLayout &DL) { - Type *OrigPtrTy = Addr->getType(); - Type *OrigTy = cast(OrigPtrTy)->getElementType(); assert(OrigTy->isSized()); uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); if (TypeSize != 8 && TypeSize != 16 &&