Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -1213,7 +1214,6 @@
 }
 
 // TODO, Obvious Missing Transforms:
-// * Single constant active lane load -> load
 // * Dereferenceable address & few lanes -> scalarize speculative load/selects
 // * Adjacent vector addresses -> masked.load
 // * Narrow width by halfs excluding zero/undef lanes
@@ -1222,9 +1222,37 @@
 static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) {
   // If the mask is all zeros, return the "passthru" argument of the gather.
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
-  if (ConstMask && ConstMask->isNullValue())
+  if (!ConstMask)
+    return nullptr;
+  if (ConstMask->isNullValue())
     return IC.replaceInstUsesWith(II, II.getArgOperand(3));
 
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  // If we have at most a single lane active, emit a scalar masked.load. If
+  // the mask is known to be active (as opposed to simply not known inactive),
+  // then the masked.load combines will convert it to a simple load.
+  if (DemandedElts.isPowerOf2()) {
+    // Note: APInt indexes the bit vector from LSB to MSB, thus
+    // countTrailingZeros returns the index in the vector which is set.
+    unsigned Idx = DemandedElts.countTrailingZeros();
+    auto &B = IC.Builder;
+    auto *PtrLane = B.CreateExtractElement(II.getArgOperand(0), Idx);
+    unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue();
+    auto *MaskLane = B.CreateExtractElement(II.getArgOperand(2), Idx);
+    auto *PassThrough = II.getArgOperand(3);
+    auto *PTLane = B.CreateExtractElement(PassThrough, Idx);
+    // TODO: pull out a scalar masked load helper function.
+    auto *PTy = PointerType::get(VectorType::get(PTLane->getType(), 1),
+                                 PtrLane->getType()->getPointerAddressSpace());
+    auto *ML = B.CreateMaskedLoad(B.CreateBitCast(PtrLane, PTy),
+                                  Alignment,
+                                  B.CreateVectorSplat(1, MaskLane),
+                                  B.CreateVectorSplat(1, PTLane));
+    auto *E = B.CreateExtractElement(ML, (uint64_t)0);
+    auto *Res = B.CreateInsertElement(PassThrough, E, Idx);
+    return IC.replaceInstUsesWith(II, Res);
+  }
+
   return nullptr;
 }
Index: test/Transforms/InstCombine/masked_intrinsics.ll
===================================================================
--- test/Transforms/InstCombine/masked_intrinsics.ll
+++ test/Transforms/InstCombine/masked_intrinsics.ll
@@ -206,8 +206,12 @@
 ; CHECK-NEXT: [[PTRS:%.*]] = getelementptr double, double* [[BASE:%.*]], <4 x i64>
 ; CHECK-NEXT: [[PT_V1:%.*]] = insertelement <4 x double> undef, double [[PT:%.*]], i64 0
 ; CHECK-NEXT: [[PT_V2:%.*]] = shufflevector <4 x double> [[PT_V1]], <4 x double> undef, <4 x i32>
-; CHECK-NEXT: [[RES:%.*]] = call <4 x double> @llvm.masked.gather.v4f64.v4p0f64(<4 x double*> [[PTRS]], i32 4, <4 x i1> , <4 x double> [[PT_V2]])
-; CHECK-NEXT: ret <4 x double> [[RES]]
+; CHECK-NEXT: [[BC:%.*]] = bitcast <4 x double*> [[PTRS]] to <4 x <1 x double>*>
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x <1 x double>*> [[BC]], i64 2
+; CHECK-NEXT: [[UNMASKEDLOAD:%.*]] = load <1 x double>, <1 x double>* [[TMP1]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <1 x double> [[UNMASKEDLOAD]], <1 x double> undef, <4 x i32>
+; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x double> [[PT_V2]], <4 x double> [[TMP2]], <4 x i32>
+; CHECK-NEXT: ret <4 x double> [[TMP3]]
 ;
   %ptrs = getelementptr double, double *%base, <4 x i64>
   %pt_v1 = insertelement <4 x double> undef, double %pt, i64 0
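For reviewers, a rough IR sketch of what the new simplifyMaskedGather path builds for a v4f64 gather whose constant mask enables only lane 2 with alignment 4 (matching the test); the value names %ptrs and %passthru are illustrative, not taken from the test file. Because the extracted mask bit is the constant true, the existing masked.load combines then fold the masked.load into the plain <1 x double> load seen in the updated CHECK lines.

  ; scalarize the single (possibly) active lane, lane 2
  %ptr     = extractelement <4 x double*> %ptrs, i64 2
  %pt      = extractelement <4 x double> %passthru, i64 2
  %pt.vec  = insertelement <1 x double> undef, double %pt, i64 0
  %ptrcast = bitcast double* %ptr to <1 x double>*
  ; mask is known active here, so later combines turn this into a plain load
  %ml      = call <1 x double> @llvm.masked.load.v1f64.p0v1f64(<1 x double>* %ptrcast, i32 4, <1 x i1> <i1 true>, <1 x double> %pt.vec)
  %elt     = extractelement <1 x double> %ml, i64 0
  %res     = insertelement <4 x double> %passthru, double %elt, i64 2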