Index: include/llvm/Analysis/VectorUtils.h
===================================================================
--- include/llvm/Analysis/VectorUtils.h
+++ include/llvm/Analysis/VectorUtils.h
@@ -75,7 +75,7 @@
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
 /// Compute a map of integer instructions to their minimum legal type
 /// size.
Index: lib/Analysis/VectorUtils.cpp
===================================================================
--- lib/Analysis/VectorUtils.cpp
+++ lib/Analysis/VectorUtils.cpp
@@ -302,7 +302,7 @@
 /// the input value is (1) a splat constants vector or (2) a sequence
 /// of instructions that broadcast a single value into a vector.
 ///
-const llvm::Value *llvm::getSplatValue(const Value *V) {
+llvm::Value *llvm::getSplatValue(const Value *V) {
 
   if (auto *C = dyn_cast<Constant>(V))
     if (isa<VectorType>(V->getType()))
Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallSite.h"
@@ -1162,14 +1163,27 @@
   return nullptr;
 }
 
-static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+/// Return a mask with one bit per lane of \p CV. A lane's bit is cleared only
+/// when its mask element is a null (false) value, i.e. the lane is definitely
+/// inactive; all other lanes (true, undef, ...) are conservatively kept.
+/// TODO: Surely we have something more general for understanding masks?
+static APInt possiblyDemandedEltsInMask(const ConstantVector *CV) {
+  const unsigned VWidth = CV->getType()->getNumElements();
+  APInt DemandedElts = APInt::getAllOnesValue(VWidth);
+  for (unsigned i = 0; i != VWidth; ++i)
+    if (CV->getAggregateElement(i)->isNullValue())
+      DemandedElts.clearBit(i);
+  return DemandedElts;
+}
+
+Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (!ConstMask)
     return nullptr;
 
   // If the mask is all zeros, this instruction does nothing.
   if (ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
@@ -1178,6 +1192,20 @@
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts.
+  // The per-lane analysis only understands ConstantVector masks.
+  auto *CV = dyn_cast<ConstantVector>(ConstMask);
+  if (!CV)
+    return nullptr;
+
+  APInt DemandedElts = possiblyDemandedEltsInMask(CV);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getArgOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setArgOperand(0, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
@@ -1224,11 +1252,32 @@
   return cast<Instruction>(Result);
 }
 
-static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) {
+Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
   // If the mask is all zeros, a scatter does nothing.
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (ConstMask && ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
+
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts.
+  // The per-lane analysis only understands ConstantVector masks.
+  auto *CV = dyn_cast_or_null<ConstantVector>(ConstMask);
+  if (!CV)
+    return nullptr;
+
+  // Inactive lanes use neither their value (op 0) nor their pointer (op 1).
+  APInt DemandedElts = possiblyDemandedEltsInMask(CV);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getArgOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setArgOperand(0, V);
+    return &II;
+  }
+  if (Value *V = SimplifyDemandedVectorElts(II.getArgOperand(1),
+                                            DemandedElts, UndefElts)) {
+    II.setArgOperand(1, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
@@ -1904,11 +1953,11 @@
       return replaceInstUsesWith(CI, SimplifiedMaskedOp);
     break;
   case Intrinsic::masked_store:
-    return simplifyMaskedStore(*II, *this);
+    return simplifyMaskedStore(*II);
   case Intrinsic::masked_gather:
     return simplifyMaskedGather(*II, *this);
   case Intrinsic::masked_scatter:
-    return simplifyMaskedScatter(*II, *this);
+    return simplifyMaskedScatter(*II);
   case Intrinsic::launder_invariant_group:
   case Intrinsic::strip_invariant_group:
     if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this))
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -473,6 +473,11 @@
   Instruction *transformCallThroughTrampoline(CallSite CS,
                                               IntrinsicInst *Tramp);
 
+  /// Try to simplify an llvm.masked.store call by exploiting a constant mask.
+  Instruction *simplifyMaskedStore(IntrinsicInst &II);
+  /// Try to simplify an llvm.masked.scatter call by exploiting a constant mask.
+  Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+
   /// Transform (zext icmp) to bitwise / integer operations in order to
   /// eliminate it.
   ///
Index: test/Transforms/InstCombine/masked_intrinsics.ll
===================================================================
--- test/Transforms/InstCombine/masked_intrinsics.ll
+++ test/Transforms/InstCombine/masked_intrinsics.ll
@@ -48,6 +48,18 @@
 ; CHECK-NEXT: ret void
 }
 
+define void @store_demandedelts(<2 x double>* %ptr, double %val) {
+  %valvec1 = insertelement <2 x double> undef, double %val, i32 0
+  %valvec2 = insertelement <2 x double> %valvec1, double %val, i32 1
+  call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %valvec2, <2 x double>* %ptr, i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+
+; CHECK-LABEL: @store_demandedelts(
+; CHECK-NEXT: %valvec2 = insertelement <2 x double> undef, double %val, i32 0
+; CHECK-NEXT: call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %valvec2, <2 x double>* %ptr, i32 4, <2 x i1> <i1 true, i1 false>)
+; CHECK-NEXT: ret void
+}
+
 define <2 x double> @gather_zeromask(<2 x double*> %ptrs, <2 x double> %passthru) {
   %res = call <2 x double> @llvm.masked.gather.v2f64.v2p0f64(<2 x double*> %ptrs, i32 5, <2 x i1> zeroinitializer, <2 x double> %passthru)
   ret <2 x double> %res
@@ -57,10 +69,24 @@
 }
 
 define void @scatter_zeromask(<2 x double*> %ptrs, <2 x double> %val) {
-  call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %val, <2 x double*> %ptrs, i32 6, <2 x i1> zeroinitializer)
+  call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %val, <2 x double*> %ptrs, i32 8, <2 x i1> zeroinitializer)
   ret void
 
 ; CHECK-LABEL: @scatter_zeromask(
 ; CHECK-NEXT: ret void
 }
 
+define void @scatter_demandedelts(double* %ptr, double %val) {
+  %ptrs = getelementptr double, double* %ptr, <2 x i64> <i64 0, i64 1>
+  %valvec1 = insertelement <2 x double> undef, double %val, i32 0
+  %valvec2 = insertelement <2 x double> %valvec1, double %val, i32 1
+  call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
+  ret void
+
+; CHECK-LABEL: @scatter_demandedelts(
+; CHECK-NEXT: %ptrs = getelementptr double, double* %ptr, <2 x i64> <i64 0, i64 1>
+; CHECK-NEXT: %valvec2 = insertelement <2 x double> undef, double %val, i32 0
+; CHECK-NEXT: call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
+; CHECK-NEXT: ret void
+}
+