diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -1439,6 +1439,44 @@
                                          IsWrite, nullptr, UseCalls, Exp);
 }
 
+/// Insert a loop equivalent to `for (Index = 0; Index != End; ++Index)`
+/// immediately before \p SplitBefore, splitting its block. The exit test is
+/// at the bottom of the body, so the body always runs at least once and
+/// \p End must be non-zero. On return, \p BodyIP is the first non-PHI
+/// instruction of the body and \p Index is the induction variable.
+static void SplitBlockAndInsertSimpleForLoop(Value *End,
+                                             Instruction *SplitBefore,
+                                             Instruction *&BodyIP,
+                                             Value *&Index) {
+  BasicBlock *LoopPred = SplitBefore->getParent();
+  BasicBlock *LoopBody = SplitBlock(SplitBefore->getParent(), SplitBefore);
+  BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore);
+
+  auto *Ty = End->getType();
+  auto &DL = SplitBefore->getModule()->getDataLayout();
+  const unsigned Bitwidth = DL.getTypeSizeInBits(Ty);
+
+  IRBuilder<> Builder(LoopBody->getTerminator());
+  auto *IV = Builder.CreatePHI(Ty, 2, "iv");
+  // nuw holds because IV + 1 never exceeds End. NOTE(review): for types wider
+  // than i1, nsw also needs End to be at most the signed maximum of Ty; that
+  // holds for the small vscale-based lane counts ASan passes in, but exempting
+  // only i2 looks arbitrary - confirm the intended condition.
+  auto *IVNext =
+      Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
+                        /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
+  auto *IVCheck = Builder.CreateICmpEQ(IVNext, End, IV->getName() + ".check");
+  Builder.CreateCondBr(IVCheck, LoopExit, LoopBody);
+  LoopBody->getTerminator()->eraseFromParent();
+
+  // Populate the IV PHI.
+  IV->addIncoming(ConstantInt::get(Ty, 0), LoopPred);
+  IV->addIncoming(IVNext, LoopBody);
+
+  BodyIP = LoopBody->getFirstNonPHI();
+  Index = IV;
+}
+
 static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass,
                                         const DataLayout &DL, Type *IntptrTy,
                                         Value *Mask, Instruction *I,
@@ -1446,36 +1484,68 @@
                                         unsigned Granularity, Type *OpType,
                                         bool IsWrite, Value *SizeArgument,
                                         bool UseCalls, uint32_t Exp) {
-  auto *VTy = cast<FixedVectorType>(OpType);
-  uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
-  unsigned Num = VTy->getNumElements();
+  auto *VTy = cast<VectorType>(OpType);
+
+  TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType());
   auto Zero = ConstantInt::get(IntptrTy, 0);
-  for (unsigned Idx = 0; Idx < Num; ++Idx) {
-    Value *InstrumentedAddress = nullptr;
-    Instruction *InsertBefore = I;
-    if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
-      // dyn_cast as we might get UndefValue
-      if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
-        if (Masked->isZero())
-          // Mask is constant false, so no instrumentation needed.
-          continue;
-        // If we have a true or undef value, fall through to doInstrumentAddress
-        // with InsertBefore == I
+
+  // For fixed length vectors, unroll and specialize per lane rather than
+  // emitting the runtime loop below: lanes whose mask bit is a constant false
+  // need no check at all. We might want to revisit this heuristic decision.
+  if (auto *FVTy = dyn_cast<FixedVectorType>(VTy)) {
+    unsigned Num = FVTy->getNumElements();
+    for (unsigned Idx = 0; Idx < Num; ++Idx) {
+      Value *InstrumentedAddress = nullptr;
+      Instruction *InsertBefore = I;
+      if (auto *Vector = dyn_cast<ConstantVector>(Mask)) {
+        // dyn_cast as we might get UndefValue
+        if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) {
+          if (Masked->isZero())
+            // Mask is constant false, so no instrumentation needed.
+            continue;
+          // If we have a true or undef value, fall through to
+          // doInstrumentAddress with InsertBefore == I
+        }
+      } else {
+        IRBuilder<> IRB(I);
+        Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
+        Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
+        InsertBefore = ThenTerm;
       }
-    } else {
-      IRBuilder<> IRB(I);
-      Value *MaskElem = IRB.CreateExtractElement(Mask, Idx);
-      Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false);
-      InsertBefore = ThenTerm;
-    }
 
-    IRBuilder<> IRB(InsertBefore);
-    InstrumentedAddress =
-        IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
-    doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
-                        Granularity, TypeSize::Fixed(ElemTypeSize), IsWrite,
-                        SizeArgument, UseCalls, Exp);
+      IRBuilder<> IRB(InsertBefore);
+      InstrumentedAddress =
+          IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)});
+      doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment,
+                          Granularity, ElemTypeSize, IsWrite, SizeArgument,
+                          UseCalls, Exp);
+    }
+    return;
   }
+
+  // Scalable vectors: the lane count (vscale * MinNumElem) is only known at
+  // run time, so loop over every lane and instrument the lanes whose mask bit
+  // is set. The count is always non-zero, which the loop helper requires.
+  assert(isa<ScalableVectorType>(VTy) &&
+         "generalize if reused for fixed length");
+  IRBuilder<> IRB(I);
+  Constant *MinNumElem =
+      ConstantInt::get(IntptrTy, VTy->getElementCount().getKnownMinValue());
+  Value *NumElements = IRB.CreateVScale(MinNumElem);
+
+  Instruction *BodyIP;
+  Value *Index;
+  SplitBlockAndInsertSimpleForLoop(NumElements, I, BodyIP, Index);
+
+  IRB.SetInsertPoint(BodyIP);
+  Value *MaskElem = IRB.CreateExtractElement(Mask, Index);
+  Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, BodyIP, false);
+  IRB.SetInsertPoint(ThenTerm);
+
+  Value *InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index});
+  doInstrumentAddress(Pass, I, ThenTerm, InstrumentedAddress, Alignment,
+                      Granularity, ElemTypeSize, IsWrite, SizeArgument,
+                      UseCalls, Exp);
 }
 
 void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis,
diff --git a/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll b/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
--- a/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/asan-masked-load-store.ll
@@ -308,3 +308,68 @@
   %res2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x float> %arg)
   ret <4 x float> %res2
 }
+
+;; Scalable vector tests
+;; ---------------------------
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float>, ptr, i32, <vscale x 4 x i1>)
+
+define <vscale x 4 x float> @scalable.load.nxv4f32(ptr %p, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.load.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_load4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> undef)
+; CHECK-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+; DISABLED-LABEL: @scalable.load.nxv4f32(
+; DISABLED-NEXT:    [[RES:%.*]] = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> undef)
+; DISABLED-NEXT:    ret <vscale x 4 x float> [[RES]]
+;
+  %res = tail call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %p, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %res
+}
+
+define void @scalable.store.nxv4f32(ptr %p, <vscale x 4 x float> %arg, <vscale x 4 x i1> %mask) sanitize_address {
+; CHECK-LABEL: @scalable.store.nxv4f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    br label [[DOTSPLIT:%.*]]
+; CHECK:       .split:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[TMP0:%.*]] ], [ [[IV_NEXT:%.*]], [[TMP7:%.*]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i1> [[MASK:%.*]], i64 [[IV]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[TMP4:%.*]], label [[TMP7]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <vscale x 4 x float>, ptr [[P:%.*]], i64 0, i64 [[IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = ptrtoint ptr [[TMP5]] to i64
+; CHECK-NEXT:    call void @__asan_store4(i64 [[TMP6]])
+; CHECK-NEXT:    br label [[TMP7]]
+; CHECK:       7:
+; CHECK-NEXT:    [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
+; CHECK-NEXT:    [[IV_CHECK:%.*]] = icmp eq i64 [[IV_NEXT]], [[TMP2]]
+; CHECK-NEXT:    br i1 [[IV_CHECK]], label [[DOTSPLIT_SPLIT:%.*]], label [[DOTSPLIT]]
+; CHECK:       .split.split:
+; CHECK-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P]], i32 4, <vscale x 4 x i1> [[MASK]])
+; CHECK-NEXT:    ret void
+;
+; DISABLED-LABEL: @scalable.store.nxv4f32(
+; DISABLED-NEXT:    tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[ARG:%.*]], ptr [[P:%.*]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
+; DISABLED-NEXT:    ret void
+;
+  tail call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %arg, ptr %p, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}