Index: llvm/include/llvm/Analysis/VectorUtils.h
===================================================================
--- llvm/include/llvm/Analysis/VectorUtils.h
+++ llvm/include/llvm/Analysis/VectorUtils.h
@@ -354,7 +354,7 @@
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Return true if each element of the vector value \p V is poisoned or equal to
 /// every other non-poisoned element. If an index element is specified, either
Index: llvm/lib/Analysis/VectorUtils.cpp
===================================================================
--- llvm/lib/Analysis/VectorUtils.cpp
+++ llvm/lib/Analysis/VectorUtils.cpp
@@ -342,7 +342,7 @@
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
 /// of instructions that broadcasts a scalar at element 0.
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+llvm::Value *llvm::getSplatValue(const Value *V) {
   if (isa<VectorType>(V->getType()))
     if (auto *C = dyn_cast<Constant>(V))
      return C->getSplatValue();
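Note: dropping const from getSplatValue's return type lets callers feed the
splatted scalar directly into new instructions; CodeGenPrepare below previously
had to const_cast the result. As a reminder of the two forms the function
matches, here is a minimal IR sketch (value names are hypothetical, not part of
this patch): for case (1) it returns the element constant, for case (2) the
scalar broadcast from lane 0.

  ; Case (1): splat constant vector -- getSplatValue returns i32 7.
  ;   <4 x i32> <i32 7, i32 7, i32 7, i32 7>
  ; Case (2): broadcast sequence -- getSplatValue returns %s.
  %ins   = insertelement <4 x i32> undef, i32 %s, i32 0
  %splat = shufflevector <4 x i32> %ins, <4 x i32> undef, <4 x i32> zeroinitializer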
Index: llvm/lib/CodeGen/CodeGenPrepare.cpp
===================================================================
--- llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -5309,88 +5309,108 @@
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP || !GEP->hasIndices())
-    return false;
+  Value *NewAddr;
 
-  // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
-  // FIXME: We should support this by sinking the GEP.
-  if (MemoryInst->getParent() != GEP->getParent())
-    return false;
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    // Don't optimize GEPs that don't have indices.
+    if (!GEP->hasIndices())
+      return false;
 
-  SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
+    // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+    // FIXME: We should support this by sinking the GEP.
+    if (MemoryInst->getParent() != GEP->getParent())
+      return false;
 
-  bool RewriteGEP = false;
+    SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
 
-  if (Ops[0]->getType()->isVectorTy()) {
-    Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
-    if (!Ops[0])
-      return false;
-    RewriteGEP = true;
-  }
+    bool RewriteGEP = false;
+
+    if (Ops[0]->getType()->isVectorTy()) {
+      Ops[0] = getSplatValue(Ops[0]);
+      if (!Ops[0])
+        return false;
+      RewriteGEP = true;
+    }
 
-  unsigned FinalIndex = Ops.size() - 1;
+    unsigned FinalIndex = Ops.size() - 1;
 
-  // Ensure all but the last index is 0.
-  // FIXME: This isn't strictly required. All that's required is that they are
-  // all scalars or splats.
-  for (unsigned i = 1; i < FinalIndex; ++i) {
-    auto *C = dyn_cast<Constant>(Ops[i]);
-    if (!C)
-      return false;
-    if (isa<VectorType>(C->getType()))
-      C = C->getSplatValue();
-    auto *CI = dyn_cast_or_null<ConstantInt>(C);
-    if (!CI || !CI->isZero())
-      return false;
-    // Scalarize the index if needed.
-    Ops[i] = CI;
-  }
-
-  // Try to scalarize the final index.
-  if (Ops[FinalIndex]->getType()->isVectorTy()) {
-    if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
-      auto *C = dyn_cast<ConstantInt>(V);
-      // Don't scalarize all zeros vector.
-      if (!C || !C->isZero()) {
-        Ops[FinalIndex] = V;
-        RewriteGEP = true;
-      }
-    }
-  }
+    // Ensure all but the last index are 0.
+    // FIXME: This isn't strictly required. All that's required is that they
+    // are all scalars or splats.
+    for (unsigned i = 1; i < FinalIndex; ++i) {
+      auto *C = dyn_cast<Constant>(Ops[i]);
+      if (!C)
+        return false;
+      if (isa<VectorType>(C->getType()))
+        C = C->getSplatValue();
+      auto *CI = dyn_cast_or_null<ConstantInt>(C);
+      if (!CI || !CI->isZero())
+        return false;
+      // Scalarize the index if needed.
+      Ops[i] = CI;
+    }
+
+    // Try to scalarize the final index.
+    if (Ops[FinalIndex]->getType()->isVectorTy()) {
+      if (Value *V = getSplatValue(Ops[FinalIndex])) {
+        auto *C = dyn_cast<ConstantInt>(V);
+        // Don't scalarize an all-zeros vector.
+        if (!C || !C->isZero()) {
+          Ops[FinalIndex] = V;
+          RewriteGEP = true;
+        }
+      }
+    }
 
-  // If we made any changes or the we have extra operands, we need to generate
-  // new instructions.
-  if (!RewriteGEP && Ops.size() == 2)
-    return false;
-
-  unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
-
-  IRBuilder<> Builder(MemoryInst);
-
-  Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
-
-  Value *NewAddr;
+    // If we made any changes or we have extra operands, we need to generate
+    // new instructions.
+    if (!RewriteGEP && Ops.size() == 2)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
+
+    Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
 
-  // If the final index isn't a vector, emit a scalar GEP containing all ops
-  // and a vector GEP with all zeroes final index.
-  if (!Ops[FinalIndex]->getType()->isVectorTy()) {
-    NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
-    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
-    NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
-  } else {
-    Value *Base = Ops[0];
-    Value *Index = Ops[FinalIndex];
-
-    // Create a scalar GEP if there are more than 2 operands.
-    if (Ops.size() != 2) {
-      // Replace the last index with 0.
-      Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
-      Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
-    }
-
-    // Now create the GEP with scalar pointer and vector index.
-    NewAddr = Builder.CreateGEP(Base, Index);
-  }
+    // If the final index isn't a vector, emit a scalar GEP containing all ops
+    // and a vector GEP with an all-zeroes final index.
+    if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+      NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
+      auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+      NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
+    } else {
+      Value *Base = Ops[0];
+      Value *Index = Ops[FinalIndex];
+
+      // Create a scalar GEP if there are more than 2 operands.
+      if (Ops.size() != 2) {
+        // Replace the last index with 0.
+        Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
+        Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
+      }
+
+      // Now create the GEP with scalar pointer and vector index.
+      NewAddr = Builder.CreateGEP(Base, Index);
+    }
+  } else if (!isa<Constant>(Ptr)) {
+    // Not a GEP; maybe it's a splat and we can create a GEP to enable
+    // SelectionDAGBuilder to use it as a uniform base.
+    Value *V = getSplatValue(Ptr);
+    if (!V)
+      return false;
+
+    unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
+
+    IRBuilder<> Builder(MemoryInst);
+
+    // Emit a vector GEP with a scalar pointer and an all-zeroes vector index.
+    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
+    auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
+    NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
+  } else {
+    // Ptr is a constant; SelectionDAGBuilder knows to check if it's a splat.
+    return false;
+  }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
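To illustrate the rewrite this function performs, here is a hedged before/after
IR sketch (value names are hypothetical; it assumes the x86-64 datalayout,
where the index type for i32* is i64). A gather/scatter address built as a
vector GEP over a splat base %p with a uniform scalar index %i becomes a scalar
GEP carrying the index plus a vector GEP with an all-zeroes index, the shape
SelectionDAGBuilder recognizes as a uniform base. The new else-if branch emits
the same zero-index GEP for a bare splat pointer, which the tests below verify.

  ; Before: vector GEP over a splat of %p, with scalar index %i.
  %b0 = insertelement <4 x i32*> undef, i32* %p, i32 0
  %b  = shufflevector <4 x i32*> %b0, <4 x i32*> undef, <4 x i32> zeroinitializer
  %g  = getelementptr i32, <4 x i32*> %b, i64 %i

  ; After: the scalar GEP carries the index; the vector GEP's index is all zeroes.
  %s  = getelementptr i32, i32* %p, i64 %i
  %g2 = getelementptr i32, i32* %s, <4 x i64> zeroinitializer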
Index: llvm/test/CodeGen/X86/masked_gather_scatter.ll
===================================================================
--- llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -3319,3 +3319,57 @@
   call void @llvm.masked.scatter.v16i32.v16p0i32(<16 x i32> %src0, <16 x i32*> %gep, i32 4, <16 x i1> %mask)
   ret void
 }
+
+define void @splat_ptr(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
+; CHECK-LABEL: @splat_ptr(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    ret void
+;
+; KNL_64-LABEL: splat_ptr:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
+; KNL_64-NEXT:    vpslld $31, %xmm0, %xmm0
+; KNL_64-NEXT:    vptestmd %zmm0, %zmm0, %k0
+; KNL_64-NEXT:    kshiftlw $12, %k0, %k0
+; KNL_64-NEXT:    kshiftrw $12, %k0, %k1
+; KNL_64-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_64-NEXT:    vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
+; KNL_64-NEXT:    vzeroupper
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: splat_ptr:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
+; KNL_32-NEXT:    vpslld $31, %xmm0, %xmm0
+; KNL_32-NEXT:    vptestmd %zmm0, %zmm0, %k0
+; KNL_32-NEXT:    kshiftlw $12, %k0, %k0
+; KNL_32-NEXT:    kshiftrw $12, %k0, %k1
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; KNL_32-NEXT:    vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
+; KNL_32-NEXT:    vzeroupper
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: splat_ptr:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vpslld $31, %xmm0, %xmm0
+; SKX-NEXT:    vpmovd2m %xmm0, %k1
+; SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: splat_ptr:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    vpslld $31, %xmm0, %xmm0
+; SKX_32-NEXT:    vpmovd2m %xmm0, %k1
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; SKX_32-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
+; SKX_32-NEXT:    retl
+  %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
+  %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> %val, <4 x i32*> %2, i32 4, <4 x i1> %mask)
+  ret void
+}
+
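Note that every scatter sequence above uses the uniform-base addressing mode:
the pointer lands in a scalar base register (%rdi on 64-bit, %eax loaded from
the stack on 32-bit) and the index vector is zeroed with vpxor, rather than the
pointer being broadcast into a vector of addresses.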
Index: llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
===================================================================
--- llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -85,4 +85,17 @@
   ret <4 x i32> %4
 }
 
+define void @splat_ptr(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
+; CHECK-LABEL: @splat_ptr(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
+; CHECK-NEXT:    ret void
+;
+  %1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
+  %2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> %val, <4 x i32*> %2, i32 4, <4 x i1> %mask)
+  ret void
+}
+
 declare <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>, <4 x i32>)
+declare void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32>, <4 x i32*>, i32, <4 x i1>)