Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -1191,19 +1192,28 @@
 }
 
 // TODO, Obvious Missing Transforms:
-// * Dereferenceable address -> speculative load/select
 // * Narrow width by halfs excluding zero/undef lanes
 static Value *simplifyMaskedLoad(const IntrinsicInst &II,
                                  InstCombiner::BuilderTy &Builder) {
+  Value *LoadPtr = II.getArgOperand(0);
+  unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue();
+
   // If the mask is all ones or undefs, this is a plain vector load of the 1st
   // argument.
-  if (maskIsAllOneOrUndef(II.getArgOperand(2))) {
-    Value *LoadPtr = II.getArgOperand(0);
-    unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue();
+  if (maskIsAllOneOrUndef(II.getArgOperand(2)))
     return Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
                                      "unmaskedload");
-  }
+
+  // If we can unconditionally load from this address, replace with a
+  // load/select idiom. TODO: use DT for context sensitive query
+  if (isDereferenceableAndAlignedPointer(LoadPtr, Alignment,
+                                         II.getModule()->getDataLayout(),
+                                         &II, nullptr)) {
+    auto *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
+                                         "unmaskedload");
+    return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
+  }
 
   return nullptr;
 }
 
@@ -1239,7 +1249,6 @@
 }
 
 // TODO, Obvious Missing Transforms:
-// * Single constant active lane load -> load
 // * Dereferenceable address & few lanes -> scalarize speculative load/selects
 // * Adjacent vector addresses -> masked.load
 // * Narrow width by halfs excluding zero/undef lanes
@@ -1247,10 +1256,39 @@
 // * Vector incrementing address -> vector masked load
static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) {
   // If the mask is all zeros, return the "passthru" argument of the gather.
+  // TODO: move to instsimplify
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
-  if (ConstMask && ConstMask->isNullValue())
+  if (!ConstMask)
+    return nullptr;
+  if (ConstMask->isNullValue())
     return IC.replaceInstUsesWith(II, II.getArgOperand(3));
 
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  // If we have at most a single lane active, emit a scalar masked.load.  If
+  // that lane's mask is known active (as opposed to simply not known
+  // inactive), the masked.load combines will convert it to a plain load.
+  if (DemandedElts.isPowerOf2()) {
+    // Note: APInt indexes the bit vector from LSB to MSB, thus
+    // countTrailingZeros returns the index in the vector which is set.
+    unsigned Idx = DemandedElts.countTrailingZeros();
+    auto &B = IC.Builder;
+    auto *PtrLane = B.CreateExtractElement(II.getArgOperand(0), Idx);
+    unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue();
+    auto *MaskLane = B.CreateExtractElement(II.getArgOperand(2), Idx);
+    auto *PassThrough = II.getArgOperand(3);
+    auto *PTLane = B.CreateExtractElement(PassThrough, Idx);
+    // TODO: pull out a scalar masked load helper function.
+    auto *PTy = PointerType::get(VectorType::get(PTLane->getType(), 1),
+                                 PtrLane->getType()->getPointerAddressSpace());
+    auto *ML = B.CreateMaskedLoad(B.CreateBitCast(PtrLane, PTy),
+                                  Alignment,
+                                  B.CreateVectorSplat(1, MaskLane),
+                                  B.CreateVectorSplat(1, PTLane));
+    auto *E = B.CreateExtractElement(ML, (uint64_t)0);
+    auto *Res = B.CreateInsertElement(PassThrough, E, Idx);
+    return IC.replaceInstUsesWith(II, Res);
+  }
+
   return nullptr;
 }
 
Index: test/Transforms/InstCombine/masked_intrinsics.ll
===================================================================
--- test/Transforms/InstCombine/masked_intrinsics.ll
+++ test/Transforms/InstCombine/masked_intrinsics.ll
@@ -73,8 +73,9 @@
 ; CHECK-LABEL: @load_speculative(
 ; CHECK-NEXT:    [[PTV1:%.*]] = insertelement <2 x double> undef, double [[PT:%.*]], i64 0
 ; CHECK-NEXT:    [[PTV2:%.*]] = shufflevector <2 x double> [[PTV1]], <2 x double> undef, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[RES:%.*]] = call <2 x double> @llvm.masked.load.v2f64.p0v2f64(<2 x double>* nonnull [[PTR:%.*]], i32 4, <2 x i1> [[MASK:%.*]], <2 x double> [[PTV2]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    [[UNMASKEDLOAD:%.*]] = load <2 x double>, <2 x double>* [[PTR:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[MASK:%.*]], <2 x double> [[UNMASKEDLOAD]], <2 x double> [[PTV2]]
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
 ;
   double %pt, <2 x i1> %mask) {
   %ptv1 = insertelement <2 x double> undef, double %pt, i64 0
@@ -176,9 +177,12 @@
 define <2 x double> @gather_lane0(double* %base, double %pt)  {
 ; CHECK-LABEL: @gather_lane0(
 ; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr double, double* [[BASE:%.*]], <2 x i64> <i64 0, i64 1>
-; CHECK-NEXT:    [[PT_V2:%.*]] = insertelement <2 x double> undef, double [[PT:%.*]], i64 1
-; CHECK-NEXT:    [[RES:%.*]] = call <2 x double> @llvm.masked.gather.v2f64.v2p0f64(<2 x double*> [[PTRS]], i32 4, <2 x i1> <i1 true, i1 false>, <2 x double> [[PT_V2]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x double*> [[PTRS]] to <2 x <1 x double>*>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x <1 x double>*> [[BC]], i64 0
+; CHECK-NEXT:    [[UNMASKEDLOAD:%.*]] = load <1 x double>, <1 x double>* [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <1 x double> [[UNMASKEDLOAD]], <1 x double> undef, <2 x i32> <i32 0, i32 undef>
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[PT:%.*]], i64 1
+; CHECK-NEXT:    ret <2 x double> [[TMP3]]
 ;
   %ptrs = getelementptr double, double *%base, <2 x i64> <i64 0, i64 1>
   %pt_v1 = insertelement <2 x double> undef, double %pt, i64 0
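
For reference, below is a minimal standalone reproducer for the new masked.load -> load/select
path, modeled on the load_speculative test updated above. The function name and the
align 8 / dereferenceable(16) attributes on %ptr are illustrative assumptions rather than copies
of the test file; any pointer that isDereferenceableAndAlignedPointer() can prove safe to
speculate should behave the same way. With this patch applied, running it through
opt -instcombine -S is expected to produce the unconditional load followed by a select on
%mask, matching the new CHECK lines.

  declare <2 x double> @llvm.masked.load.v2f64.p0v2f64(<2 x double>*, i32, <2 x i1>, <2 x double>)

  define <2 x double> @speculative_load_example(<2 x double>* align 8 dereferenceable(16) %ptr, double %pt, <2 x i1> %mask) {
    ; Splat %pt to form the passthru vector.
    %ptv1 = insertelement <2 x double> undef, double %pt, i64 0
    %ptv2 = shufflevector <2 x double> %ptv1, <2 x double> undef, <2 x i32> zeroinitializer
    ; %ptr is dereferenceable for all 16 bytes and at least align 4, so the masked
    ; load can be speculated: expect 'load <2 x double>' plus 'select <2 x i1> %mask'.
    %res = call <2 x double> @llvm.masked.load.v2f64.p0v2f64(<2 x double>* %ptr, i32 4, <2 x i1> %mask, <2 x double> %ptv2)
    ret <2 x double> %res
  }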