Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -810,6 +810,64 @@
   return nullptr;
 }
 
+/// Return a constant boolean vector that has true elements in all positions
+/// where the input constant integer vector has an element with the sign bit
+/// set. The x86 masked store intrinsics take integer mask vectors, so each
+/// element of \p V is expected to be a ConstantInt.
+static Constant *getNegIsTrueBoolVec(ConstantDataVector *V) {
+  SmallVector<Constant *, 32> BoolVec;
+  IntegerType *BoolTy = Type::getInt1Ty(V->getContext());
+  for (unsigned I = 0, E = V->getNumElements(); I != E; ++I) {
+    ConstantInt *Elt = cast<ConstantInt>(V->getElementAsConstant(I));
+    BoolVec.push_back(ConstantInt::get(BoolTy, Elt->isNegative()));
+  }
+  return ConstantVector::get(BoolVec);
+}
+
+// TODO: If the x86 backend knew how to convert a bool vector mask back to an
+// XMM register mask efficiently, we could transform all x86 masked intrinsics
+// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
+
+/// Simplify an x86 AVX masked store intrinsic whose mask is a constant.
+/// Returns true if \p II was replaced and erased. Returns false if \p II was
+/// left untouched, so the caller can continue with its generic handling.
+static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+  Value *Ptr = II.getOperand(0);
+  Value *Mask = II.getOperand(1);
+  Value *Vec = II.getOperand(2);
+
+  // Special case a zero mask since that's not a ConstantDataVector:
+  // this masked store instruction does nothing.
+  if (isa<ConstantAggregateZero>(Mask)) {
+    IC.eraseInstFromFunction(II);
+    return true;
+  }
+
+  auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
+  if (!ConstMask)
+    return false;
+
+  // The mask is constant. Convert this x86 intrinsic to the LLVM intrinsic
+  // to allow target-independent optimizations.
+
+  // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
+  // the LLVM intrinsic definition for the pointer argument. Keep the original
+  // address space: a bitcast cannot change it.
+  unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
+  Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec");
+
+  // Second, convert the x86 XMM integer vector mask to a vector of bools based
+  // on each element's most significant bit (the sign bit).
+  Constant *BoolMask = getNegIsTrueBoolVec(ConstMask);
+
+  IC.Builder->CreateMaskedStore(Vec, PtrCast, 1, BoolMask);
+
+  // 'Replace uses' doesn't work for stores. Erase the original masked store.
+  IC.eraseInstFromFunction(II);
+  return true;
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallSite to do the heavy
 /// lifting.
@@ -1552,6 +1610,16 @@
       return replaceInstUsesWith(*II, V);
     break;
 
+  case Intrinsic::x86_avx_maskstore_ps:
+  case Intrinsic::x86_avx_maskstore_pd:
+  case Intrinsic::x86_avx_maskstore_ps_256:
+  case Intrinsic::x86_avx_maskstore_pd_256:
+    // TODO: The AVX2 integer variants can go here too.
+    if (simplifyX86MaskedStore(*II, *this))
+      return nullptr;
+    // Not simplified: break so the code after this switch still runs.
+    break;
+
   case Intrinsic::x86_xop_vpcomb:
   case Intrinsic::x86_xop_vpcomd:
   case Intrinsic::x86_xop_vpcomq:
Index: test/Transforms/InstCombine/x86-masked-memops.ll
===================================================================
--- test/Transforms/InstCombine/x86-masked-memops.ll
+++ test/Transforms/InstCombine/x86-masked-memops.ll
@@ -0,0 +1,94 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; If the mask isn't constant, do nothing.
+
+define void @mstore(i8* %f, <4 x i32> %mask, <4 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> %mask, <4 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore(
+; CHECK-NEXT: tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> %mask, <4 x float> %v)
+; CHECK-NEXT: ret void
+}
+
+; Zero mask is a nop.
+
+define void @mstore_zeros(i8* %f, <4 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> zeroinitializer, <4 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_zeros(
+; CHECK-NEXT: ret void
+}
+
+; Only the sign bit matters.
+
+define void @mstore_fake_ones(i8* %f, <4 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> <i32 1, i32 2, i32 3, i32 2147483647>, <4 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_fake_ones(
+; CHECK-NEXT: ret void
+}
+
+; Every mask element has its sign bit set, so this is just a vector store.
+
+define void @mstore_real_ones(i8* %f, <4 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> <i32 -1, i32 -2, i32 -3, i32 -2147483648>, <4 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_real_ones(
+; CHECK-NEXT: %castvec = bitcast i8* %f to <4 x float>*
+; CHECK-NEXT: store <4 x float> %v, <4 x float>* %castvec
+; CHECK-NEXT: ret void
+}
+
+; It's a constant mask, so convert to an LLVM intrinsic. The backend should optimize further.
+
+define void @mstore_one_one(i8* %f, <4 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps(i8* %f, <4 x i32> <i32 0, i32 0, i32 0, i32 -1>, <4 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_one_one(
+; CHECK-NEXT: %castvec = bitcast i8* %f to <4 x float>*
+; CHECK-NEXT: call void @llvm.masked.store.v4f32(<4 x float> %v, <4 x float>* %castvec, i32 1, <4 x i1> <i1 false, i1 false, i1 false, i1 true>)
+; CHECK-NEXT: ret void
+}
+
+; Try doubles.
+
+define void @mstore_one_one_double(i8* %f, <2 x double> %v) {
+  tail call void @llvm.x86.avx.maskstore.pd(i8* %f, <2 x i64> <i64 -1, i64 0>, <2 x double> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_one_one_double(
+; CHECK-NEXT: %castvec = bitcast i8* %f to <2 x double>*
+; CHECK-NEXT: call void @llvm.masked.store.v2f64(<2 x double> %v, <2 x double>* %castvec, i32 1, <2 x i1> <i1 true, i1 false>)
+; CHECK-NEXT: ret void
+}
+
+define void @mstore_v8f32(i8* %f, <8 x float> %v) {
+  tail call void @llvm.x86.avx.maskstore.ps.256(i8* %f, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 -1>, <8 x float> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_v8f32(
+; CHECK-NEXT: %castvec = bitcast i8* %f to <8 x float>*
+; CHECK-NEXT: call void @llvm.masked.store.v8f32(<8 x float> %v, <8 x float>* %castvec, i32 1, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 true>)
+; CHECK-NEXT: ret void
+}
+
+define void @mstore_v4f64(i8* %f, <4 x double> %v) {
+  tail call void @llvm.x86.avx.maskstore.pd.256(i8* %f, <4 x i64> <i64 -1, i64 0, i64 1, i64 2>, <4 x double> %v)
+  ret void
+
+; CHECK-LABEL: @mstore_v4f64(
+; CHECK-NEXT: %castvec = bitcast i8* %f to <4 x double>*
+; CHECK-NEXT: call void @llvm.masked.store.v4f64(<4 x double> %v, <4 x double>* %castvec, i32 1, <4 x i1> <i1 true, i1 false, i1 false, i1 false>)
+; CHECK-NEXT: ret void
+}
+
+declare void @llvm.x86.avx.maskstore.ps(i8*, <4 x i32>, <4 x float>)
+declare void @llvm.x86.avx.maskstore.pd(i8*, <2 x i64>, <2 x double>)
+declare void @llvm.x86.avx.maskstore.ps.256(i8*, <8 x i32>, <8 x float>)
+declare void @llvm.x86.avx.maskstore.pd.256(i8*, <4 x i64>, <4 x double>)
+