diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -864,6 +864,93 @@
   return None;
 }
 
+/// Returns the alignment that is guaranteed for (gep EltTy BasePtr Idx) for
+/// an arbitrary element index Idx. Each index step moves the pointer by the
+/// element's alloc size, so BasePtr's known alignment only carries over up
+/// to that size (e.g. an align-16 double* indexed by 1 is only align 8).
+static Align getGEPElementAlign(Value *BasePtr, Type *EltTy,
+                                const DataLayout &DL) {
+  return commonAlignment(BasePtr->getPointerAlignment(DL),
+                         DL.getTypeAllocSize(EltTy).getFixedSize());
+}
+
+/// Rewrites an SVE gather whose index vector is (sve.index IndexBase 1),
+/// i.e. lanes that address consecutive elements starting at
+/// BasePtr[IndexBase], into a generic masked load. Inactive lanes take the
+/// zeroinitializer pass-through value.
+static Optional<Instruction *> instCombineLD1GatherIndex(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  Value *Mask = II.getOperand(0);
+  Value *BasePtr = II.getOperand(1);
+  Value *Index = II.getOperand(2);
+  Type *Ty = II.getType();
+  Type *BasePtrTy = BasePtr->getType();
+  Value *PassThru = ConstantAggregateZero::get(Ty);
+
+  // Contiguous gather => masked load.
+  // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
+  // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
+  Value *IndexBase;
+  if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+                       m_Value(IndexBase), m_SpecificInt(1)))) {
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+
+    Type *EltTy = BasePtrTy->getPointerElementType();
+    Align Alignment =
+        getGEPElementAlign(BasePtr, EltTy, II.getModule()->getDataLayout());
+
+    Type *VecPtrTy = PointerType::getUnqual(Ty);
+    Value *Ptr = Builder.CreateGEP(EltTy, BasePtr, IndexBase);
+    Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+    CallInst *MaskedLoad =
+        Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
+    MaskedLoad->takeName(&II);
+    return IC.replaceInstUsesWith(II, MaskedLoad);
+  }
+
+  return None;
+}
+
+/// Rewrites an SVE scatter whose index vector is (sve.index IndexBase 1),
+/// i.e. lanes that address consecutive elements starting at
+/// BasePtr[IndexBase], into a generic masked store.
+static Optional<Instruction *> instCombineST1ScatterIndex(InstCombiner &IC,
+                                                          IntrinsicInst &II) {
+  Value *Val = II.getOperand(0);
+  Value *Mask = II.getOperand(1);
+  Value *BasePtr = II.getOperand(2);
+  Value *Index = II.getOperand(3);
+  Type *Ty = Val->getType();
+  Type *BasePtrTy = BasePtr->getType();
+
+  // Contiguous scatter => masked store.
+  // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
+  // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
+  Value *IndexBase;
+  if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
+                       m_Value(IndexBase), m_SpecificInt(1)))) {
+    IRBuilder<> Builder(II.getContext());
+    Builder.SetInsertPoint(&II);
+
+    Type *EltTy = BasePtrTy->getPointerElementType();
+    Align Alignment =
+        getGEPElementAlign(BasePtr, EltTy, II.getModule()->getDataLayout());
+
+    Value *Ptr = Builder.CreateGEP(EltTy, BasePtr, IndexBase);
+    Type *VecPtrTy = PointerType::getUnqual(Ty);
+    Ptr = Builder.CreateBitCast(Ptr, VecPtrTy);
+    (void)Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
+
+    // The scatter is void-typed, so there are no uses to replace. Erase it
+    // through InstCombine so its worklist stays consistent; II must not be
+    // touched after this point (it is freed).
+    return IC.eraseInstFromFunction(II);
+  }
+
+  return None;
+}
+
 Optional<Instruction *>
 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                      IntrinsicInst &II) const {
@@ -915,6 +1002,10 @@
   case Intrinsic::aarch64_sve_zip1:
   case Intrinsic::aarch64_sve_zip2:
     return instCombineSVEZip(IC, II);
+  case Intrinsic::aarch64_sve_ld1_gather_index:
+    return instCombineLD1GatherIndex(IC, II);
+  case Intrinsic::aarch64_sve_st1_scatter_index:
+    return instCombineST1ScatterIndex(IC, II);
   }
 
   return None;
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-gatherscatter.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-gatherscatter.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-gatherscatter.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -instcombine -dce < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+;;
+;; Gathers.
+;;
+
+define <vscale x 2 x double> @test.ld1.gather.index.nxv2f64.stride1(<vscale x 2 x i1> %pred, double* %x, i64 %base) #0 {
+; CHECK-LABEL: @test.ld1.gather.index.nxv2f64.stride1(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    [[LD:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0nxv2f64(<vscale x 2 x double>* [[TMP2]], i32 1, <vscale x 2 x i1> [[PRED:%.*]], <vscale x 2 x double> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x double> [[LD]]
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  %ld = tail call <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret <vscale x 2 x double> %ld
+}
+
+define <vscale x 2 x double> @test.ld1.gather.index.nxv2f64.stride2.negtest(<vscale x 2 x i1> %pred, double* %x, i64 %base) #0 {
+; CHECK-LABEL: @test.ld1.gather.index.nxv2f64.stride2.negtest(
+; CHECK-NEXT:    [[IDX:%.*]] = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 [[BASE:%.*]], i64 2)
+; CHECK-NEXT:    [[LD:%.*]] = tail call <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1> [[PRED:%.*]], double* [[X:%.*]], <vscale x 2 x i64> [[IDX]])
+; CHECK-NEXT:    ret <vscale x 2 x double> [[LD]]
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 2)
+  %ld = tail call <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret <vscale x 2 x double> %ld
+}
+
+define <vscale x 2 x double> @test.ld1.gather.index.nxv2f64.stride1.align8(<vscale x 2 x i1> %pred, double* align 8 %x, i64 %base) #0 {
+; CHECK-LABEL: @test.ld1.gather.index.nxv2f64.stride1.align8(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    [[LD:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0nxv2f64(<vscale x 2 x double>* [[TMP2]], i32 8, <vscale x 2 x i1> [[PRED:%.*]], <vscale x 2 x double> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x double> [[LD]]
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  %ld = tail call <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret <vscale x 2 x double> %ld
+}
+
+; %x is align 16, but (gep double %x, %base) is only guaranteed align 8.
+define <vscale x 2 x double> @test.ld1.gather.index.nxv2f64.stride1.align16(<vscale x 2 x i1> %pred, double* align 16 %x, i64 %base) #0 {
+; CHECK-LABEL: @test.ld1.gather.index.nxv2f64.stride1.align16(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    [[LD:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0nxv2f64(<vscale x 2 x double>* [[TMP2]], i32 8, <vscale x 2 x i1> [[PRED:%.*]], <vscale x 2 x double> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x double> [[LD]]
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  %ld = tail call <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret <vscale x 2 x double> %ld
+}
+
+;;
+;; Scatters.
+;;
+
+define void @test.st1.scatter.index.nxv2f64.stride1(<vscale x 2 x i1> %pred, double* %x, i64 %base, <vscale x 2 x double> %val) #0 {
+; CHECK-LABEL: @test.st1.scatter.index.nxv2f64.stride1(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0nxv2f64(<vscale x 2 x double> [[VAL:%.*]], <vscale x 2 x double>* [[TMP2]], i32 1, <vscale x 2 x i1> [[PRED:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  tail call void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double> %val, <vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret void
+}
+
+define void @test.st1.scatter.index.nxv2f64.stride2.negtest(<vscale x 2 x i1> %pred, double* %x, i64 %base, <vscale x 2 x double> %val) #0 {
+; CHECK-LABEL: @test.st1.scatter.index.nxv2f64.stride2.negtest(
+; CHECK-NEXT:    [[IDX:%.*]] = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 [[BASE:%.*]], i64 2)
+; CHECK-NEXT:    tail call void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double> [[VAL:%.*]], <vscale x 2 x i1> [[PRED:%.*]], double* [[X:%.*]], <vscale x 2 x i64> [[IDX]])
+; CHECK-NEXT:    ret void
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 2)
+  tail call void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double> %val, <vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret void
+}
+
+define void @test.st1.scatter.index.nxv2f64.stride1.align8(<vscale x 2 x i1> %pred, double* align 8 %x, i64 %base, <vscale x 2 x double> %val) #0 {
+; CHECK-LABEL: @test.st1.scatter.index.nxv2f64.stride1.align8(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0nxv2f64(<vscale x 2 x double> [[VAL:%.*]], <vscale x 2 x double>* [[TMP2]], i32 8, <vscale x 2 x i1> [[PRED:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  tail call void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double> %val, <vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret void
+}
+
+; %x is align 16, but (gep double %x, %base) is only guaranteed align 8.
+define void @test.st1.scatter.index.nxv2f64.stride1.align16(<vscale x 2 x i1> %pred, double* align 16 %x, i64 %base, <vscale x 2 x double> %val) #0 {
+; CHECK-LABEL: @test.st1.scatter.index.nxv2f64.stride1.align16(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[X:%.*]], i64 [[BASE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <vscale x 2 x double>*
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0nxv2f64(<vscale x 2 x double> [[VAL:%.*]], <vscale x 2 x double>* [[TMP2]], i32 8, <vscale x 2 x i1> [[PRED:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %idx = tail call <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64 %base, i64 1)
+  tail call void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double> %val, <vscale x 2 x i1> %pred, double* %x, <vscale x 2 x i64> %idx)
+  ret void
+}
+
+declare <vscale x 2 x i64> @llvm.aarch64.sve.index.nxv2i64(i64, i64)
+declare <vscale x 2 x double> @llvm.aarch64.sve.ld1.gather.index.nxv2f64(<vscale x 2 x i1>, double*, <vscale x 2 x i64>)
+declare void @llvm.aarch64.sve.st1.scatter.index.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*, <vscale x 2 x i64>)
+
+attributes #0 = { "target-features"="+sve" }