Index: llvm/include/llvm/Analysis/VectorUtils.h
===================================================================
--- llvm/include/llvm/Analysis/VectorUtils.h
+++ llvm/include/llvm/Analysis/VectorUtils.h
@@ -358,7 +358,7 @@
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Return true if each element of the vector value \p V is poisoned or equal to
 /// every other non-poisoned element. If an index element is specified, either
Index: llvm/lib/Analysis/VectorUtils.cpp
===================================================================
--- llvm/lib/Analysis/VectorUtils.cpp
+++ llvm/lib/Analysis/VectorUtils.cpp
@@ -342,7 +342,7 @@
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
 /// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+Value *llvm::getSplatValue(const Value *V) {
   if (isa<VectorType>(V->getType()))
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue();
Index: llvm/lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -5314,88 +5314,112 @@
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP || !GEP->hasIndices())
+  // FIXME: Support scalable vectors.
+  if (isa<ScalableVectorType>(Ptr->getType()))
     return false;
 
-  // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
-  // FIXME: We should support this by sinking the GEP.
-  if (MemoryInst->getParent() != GEP->getParent())
-    return false;
-
-  SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
+  Value *NewAddr;
 
-  bool RewriteGEP = false;
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    // Don't optimize GEPs that don't have indices.
+    if (!GEP->hasIndices())
+      return false;
 
-  if (Ops[0]->getType()->isVectorTy()) {
-    Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
-    if (!Ops[0])
+    // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+    // FIXME: We should support this by sinking the GEP.
+    if (MemoryInst->getParent() != GEP->getParent())
      return false;
-    RewriteGEP = true;
-  }
 
-  unsigned FinalIndex = Ops.size() - 1;
+    SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
 
-  // Ensure all but the last index is 0.
-  // FIXME: This isn't strictly required. All that's required is that they are
-  // all scalars or splats.
-  for (unsigned i = 1; i < FinalIndex; ++i) {
-    auto *C = dyn_cast<Constant>(Ops[i]);
-    if (!C)
-      return false;
-    if (isa<VectorType>(C->getType()))
-      C = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(C);
-    if (!CI || !CI->isZero())
-      return false;
-    // Scalarize the index if needed.
-    Ops[i] = CI;
-  }
-
-  // Try to scalarize the final index.
-  if (Ops[FinalIndex]->getType()->isVectorTy()) {
-    if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
-      auto *C = dyn_cast<ConstantInt>(V);
-      // Don't scalarize all zeros vector.
-      if (!C || !C->isZero()) {
-        Ops[FinalIndex] = V;
-        RewriteGEP = true;
-      }
+    bool RewriteGEP = false;
+
+    if (Ops[0]->getType()->isVectorTy()) {
+      Ops[0] = getSplatValue(Ops[0]);
+      if (!Ops[0])
+        return false;
+      RewriteGEP = true;
     }
-  }
 
-  // If we made any changes or the we have extra operands, we need to generate
-  // new instructions.
-  if (!RewriteGEP && Ops.size() == 2)
-    return false;
+    unsigned FinalIndex = Ops.size() - 1;
 
-  unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+    // Ensure all but the last index is 0.
+    // FIXME: This isn't strictly required. All that's required is that they are
+    // all scalars or splats.
+    for (unsigned i = 1; i < FinalIndex; ++i) {
+      auto *C = dyn_cast<Constant>(Ops[i]);
+      if (!C)
+        return false;
+      if (isa<VectorType>(C->getType()))
+        C = C->getSplatValue();
+      auto *CI = dyn_cast_or_null<ConstantInt>(C);
+      if (!CI || !CI->isZero())
+        return false;
+      // Scalarize the index if needed.
+      Ops[i] = CI;
+    }
+
+    // Try to scalarize the final index.
+    if (Ops[FinalIndex]->getType()->isVectorTy()) {
+      if (Value *V = getSplatValue(Ops[FinalIndex])) {
+        auto *C = dyn_cast<ConstantInt>(V);
+        // Don't scalarize all zeros vector.
+        if (!C || !C->isZero()) {
+          Ops[FinalIndex] = V;
+          RewriteGEP = true;
+        }
+      }
+    }
 
-  IRBuilder<> Builder(MemoryInst);
+    // If we made any changes or we have extra operands, we need to generate
+    // new instructions.
+    if (!RewriteGEP && Ops.size() == 2)
+      return false;
 
-  Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
 
-  Value *NewAddr;
+    IRBuilder<> Builder(MemoryInst);
 
-  // If the final index isn't a vector, emit a scalar GEP containing all ops
-  // and a vector GEP with all zeroes final index.
-  if (!Ops[FinalIndex]->getType()->isVectorTy()) {
-    NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
-    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
-    NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
-  } else {
-    Value *Base = Ops[0];
-    Value *Index = Ops[FinalIndex];
+    Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
 
-    // Create a scalar GEP if there are more than 2 operands.
-    if (Ops.size() != 2) {
-      // Replace the last index with 0.
-      Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
-      Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+    // If the final index isn't a vector, emit a scalar GEP containing all ops
+    // and a vector GEP with all zeroes final index.
+    if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+      NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
+      auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+      NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
+    } else {
+      Value *Base = Ops[0];
+      Value *Index = Ops[FinalIndex];
+
+      // Create a scalar GEP if there are more than 2 operands.
+      if (Ops.size() != 2) {
+        // Replace the last index with 0.
+        Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
+        Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+      }
+
+      // Now create the GEP with scalar pointer and vector index.
+      NewAddr = Builder.CreateGEP(Base, Index);
    }
+  } else if (!isa<Constant>(Ptr)) {
+    // Not a GEP, maybe it's a splat and we can create a GEP to enable
+    // SelectionDAGBuilder to use it as a uniform base.
+    Value *V = getSplatValue(Ptr);
+    if (!V)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
 
-    // Now create the GEP with scalar pointer and vector index.
-    NewAddr = Builder.CreateGEP(Base, Index);
+    // Emit a vector GEP with a scalar pointer and all 0s vector index.
+    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+  } else {
+    // Constant, SelectionDAGBuilder knows to check if it's a splat.
+    return false;
   }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
Index: llvm/test/CodeGen/X86/masked_gather_scatter.ll
===================================================================
--- llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -3328,14 +3328,13 @@
 ;
 ; KNL_64-LABEL: splat_ptr:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
 ; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_64-NEXT:    vmovq %rdi, %xmm0
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %ymm0
-; KNL_64-NEXT:    vpscatterqd %ymm1, (,%zmm0) {%k1}
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
 ; KNL_64-NEXT:    vzeroupper
 ; KNL_64-NEXT:    retq
 ;
@@ -3346,8 +3345,9 @@
 ; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
 ; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
 ; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
-; KNL_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; KNL_32-NEXT:    vpscatterdd %zmm1, (,%zmm0) {%k1}
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
 ; KNL_32-NEXT:    vzeroupper
 ; KNL_32-NEXT:    retl
 ;
@@ -3355,17 +3355,17 @@
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX-NEXT:    vpmovd2m %xmm0, %k1
-; SKX-NEXT:    vpbroadcastq %rdi, %ymm0
-; SKX-NEXT:    vpscatterqd %xmm1, (,%ymm0) {%k1}
-; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: splat_ptr:
 ; SKX_32:       # %bb.0:
 ; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
 ; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
-; SKX_32-NEXT:    vpbroadcastd {{[0-9]+}}(%esp), %xmm0
-; SKX_32-NEXT:    vpscatterdd %xmm1, (,%xmm0) {%k1}
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
 ; SKX_32-NEXT:    retl
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
Index: llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
===================================================================
--- llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,9 +87,8 @@
 define void @splat_ptr(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr(
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
   %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
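
Reviewer note (illustrative, not part of the diff): the new non-GEP path in optimizeGatherScatterInst handles a pointer operand that is a splat of a scalar pointer. Instead of leaving the insertelement/shufflevector splat for SelectionDAG, CodeGenPrepare now emits a vector GEP with the scalar pointer as base and an all-zeroes vector index, which SelectionDAGBuilder already recognizes as a uniform base. A minimal before/after sketch, using the IR from the updated gather-scatter-opt.ll test above (value numbering is illustrative):

  ; before CodeGenPrepare
  %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
  %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
  call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> %val, <4 x i32*> %2, i32 4, <4 x i1> %mask)

  ; after CodeGenPrepare
  %1 = getelementptr i32, i32* %ptr, <4 x i64> zeroinitializer
  call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> %val, <4 x i32*> %1, i32 4, <4 x i1> %mask)

On X86 this lets the scatter use the scalar pointer directly as the base register with a zero index, as seen in the updated masked_gather_scatter.ll checks (e.g. vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}).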