diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -308,24 +308,36 @@
     return false;
   }
 
-  bool isLegalNTStore(Type *DataType, Align Alignment) {
+  // Common legality check shared by isLegalNTStore and isLegalNTLoad.
+  bool isLegalNTStoreLoad(Type *DataType, Align Alignment) {
     // NOTE: The logic below is mostly geared towards LV, which calls it with
     // vectors with 2 elements. We might want to improve that, if other
     // users show up.
-    // Nontemporal vector stores can be directly lowered to STNP, if the vector
-    // can be halved so that each half fits into a register. That's the case if
-    // the element type fits into a register and the number of elements is a
-    // power of 2 > 1.
-    if (auto *DataTypeVTy = dyn_cast<VectorType>(DataType)) {
-      unsigned NumElements =
-          cast<FixedVectorType>(DataTypeVTy)->getNumElements();
-      unsigned EltSize = DataTypeVTy->getElementType()->getScalarSizeInBits();
+    // Nontemporal vector loads/stores can be directly lowered to LDNP/STNP, if
+    // the vector can be halved so that each half fits into a register. That's
+    // the case if the element type fits into a register and the number of
+    // elements is a power of 2 > 1.
+    if (auto *DataTypeTy = dyn_cast<FixedVectorType>(DataType)) {
+      unsigned NumElements = DataTypeTy->getNumElements();
+      unsigned EltSize = DataTypeTy->getElementType()->getScalarSizeInBits();
       return NumElements > 1 && isPowerOf2_64(NumElements) && EltSize >= 8 &&
              EltSize <= 128 && isPowerOf2_64(EltSize);
     }
     return BaseT::isLegalNTStore(DataType, Alignment);
   }
 
+  bool isLegalNTStore(Type *DataType, Align Alignment) {
+    return isLegalNTStoreLoad(DataType, Alignment);
+  }
+
+  bool isLegalNTLoad(Type *DataType, Align Alignment) {
+    // Only supports little-endian targets. Non-fixed-vector types go to the
+    // generic load hook; isLegalNTStoreLoad would use the store hook instead.
+    if (ST->isLittleEndian() && isa<FixedVectorType>(DataType))
+      return isLegalNTStoreLoad(DataType, Alignment);
+    return BaseT::isLegalNTLoad(DataType, Alignment);
+  }
+
   bool enableOrderedReductions() const { return true; }
 
   InstructionCost getInterleavedMemoryOpCost(
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/nontemporal-load-store.ll b/llvm/test/Transforms/LoopVectorize/AArch64/nontemporal-load-store.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/nontemporal-load-store.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/nontemporal-load-store.ll
@@ -258,8 +258,7 @@
 
 define i4 @test_i4_load(i4* %ddst) {
 ; CHECK-LABEL: define i4 @test_i4_load
-; CHECK-LABEL: vector.body:
-; CHECK: [[LOAD:%.*]] = load i4, i4* {{.*}}, align 1, !nontemporal !0
+; CHECK-NOT: vector.body:
 ; CHECk: ret i4 %{{.*}}
 ;
 entry:
@@ -281,7 +280,8 @@
 
 define i8 @test_load_i8(i8* %ddst) {
 ; CHECK-LABEL: @test_load_i8(
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x i8>, <4 x i8>* {{.*}}, align 1, !nontemporal !0
 ; CHECk: ret i8 %{{.*}}
 ;
 entry:
@@ -303,7 +303,8 @@
 
 define half @test_half_load(half* %ddst) {
 ; CHECK-LABEL: @test_half_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x half>, <4 x half>* {{.*}}, align 2, !nontemporal !0
 ; CHECk: ret half %{{.*}}
 ;
 entry:
@@ -325,7 +326,8 @@
 
 define i16 @test_i16_load(i16* %ddst) {
 ; CHECK-LABEL: @test_i16_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x i16>, <4 x i16>* {{.*}}, align 2, !nontemporal !0
 ; CHECk: ret i16 %{{.*}}
 ;
 entry:
@@ -347,7 +349,8 @@
 
 define i32 @test_i32_load(i32* %ddst) {
 ; CHECK-LABEL: @test_i32_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x i32>, <4 x i32>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i32 %{{.*}}
 ;
 entry:
@@ -413,7 +416,8 @@
 
 define i64 @test_i64_load(i64* %ddst) {
 ; CHECK-LABEL: @test_i64_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x i64>, <4 x i64>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i64 %{{.*}}
 ;
 entry:
@@ -435,7 +439,8 @@
 
 define double @test_double_load(double* %ddst) {
 ; CHECK-LABEL: @test_double_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x double>, <4 x double>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret double %{{.*}}
 ;
 entry:
@@ -457,7 +462,8 @@
 
 define i128 @test_i128_load(i128* %ddst) {
 ; CHECK-LABEL: @test_i128_load
-; CHECK-NOT: vector.body:
+; CHECK-LABEL: vector.body:
+; CHECK: load <4 x i128>, <4 x i128>* {{.*}}, align 4, !nontemporal !0
 ; CHECk: ret i128 %{{.*}}
 ;
 entry: