Index: lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -675,6 +675,43 @@
   return nullptr;
 }
 
+// Fold a call to llvm.masked.load with a constant mask: an all-zeros mask
+// selects only the passthru operand, and an all-ones mask is equivalent to a
+// normal (aligned) vector load. Returns the replacement value, or nullptr if
+// no simplification applies.
+static Value *simplifyMaskedLoad(const IntrinsicInst &II,
+                                 InstCombiner::BuilderTy &Builder) {
+  // Operands are: masked.load(ptr, alignment, mask, passthru).
+  assert(II.getNumArgOperands() == 4 &&
+         "Wrong masked intrinsic format");
+  assert(isa<PointerType>(II.getArgOperand(0)->getType()) &&
+         "Wrong type for 1st arg");
+  assert(isa<ConstantInt>(II.getArgOperand(1)) &&
+         "Wrong type for 2nd arg");
+  assert(isa<VectorType>(II.getArgOperand(2)->getType()) &&
+         "Wrong type for 3rd arg");
+  assert(II.getArgOperand(3)->getType() == II.getType() &&
+         "Wrong type for 4th arg");
+
+  Value *LoadPtr = II.getArgOperand(0);
+  unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue();
+  Value *Mask = II.getArgOperand(2);
+  Value *PassThru = II.getArgOperand(3);
+
+  // If the mask is all zeros, the "passthru" argument is the result.
+  if (isa<ConstantAggregateZero>(Mask))
+    return PassThru;
+
+  // If the mask is all ones, this is a plain vector load of the first argument.
+  // Note that the mask type is <N x i1> which is not a ConstantDataVector.
+  if (auto *ConstMask = dyn_cast<Constant>(Mask))
+    if (ConstMask->isAllOnesValue())
+      return Builder.CreateAlignedLoad(LoadPtr, Alignment, "unmaskedload");
+
+  // TODO: A mask with only one set bit can be reduced to a scalar load and
+  // insertelement into the passthru vector.
+
+  return nullptr;
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallSite to do the heavy
 /// lifting.
@@ -799,6 +836,16 @@
     break;
   }
 
+  case Intrinsic::masked_load:
+    if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, *Builder))
+      return ReplaceInstUsesWith(CI, SimplifiedMaskedOp);
+    break;
+
+  // TODO: Handle the other masked ops.
+  // case Intrinsic::masked_store:
+  // case Intrinsic::masked_gather:
+  // case Intrinsic::masked_scatter:
+
   case Intrinsic::powi:
     if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
       // powi(x, 0) -> 1.0
Index: test/Transforms/InstCombine/masked_intrinsics.ll
===================================================================
--- test/Transforms/InstCombine/masked_intrinsics.ll
+++ test/Transforms/InstCombine/masked_intrinsics.ll
@@ -2,15 +2,13 @@
 
 declare <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptrs, i32, <2 x i1> %mask, <2 x double> %src0)
 
-; FIXME: All of these could be simplified.
 define <2 x double> @load_zeromask(<2 x double>* %ptr, <2 x double> %passthru)  {
   %res = call <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptr, i32 1, <2 x i1> zeroinitializer, <2 x double> %passthru)
   ret <2 x double> %res
 
 ; CHECK-LABEL: @load_zeromask(
-; CHECK-NEXT:  %res = call <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptr, i32 1, <2 x i1> zeroinitializer, <2 x double> %passthru)
-; CHECK-NEXT ret <2 x double> %res
+; CHECK-NEXT:  ret <2 x double> %passthru
 }
 
 define <2 x double> @load_onemask(<2 x double>* %ptr, <2 x double> %passthru)  {
   %res = call <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptr, i32 2, <2 x i1> <i1 1, i1 1>, <2 x double> %passthru)
@@ -18,10 +16,12 @@
   ret <2 x double> %res
 
 ; CHECK-LABEL: @load_onemask(
-; CHECK-NEXT:  %res = call <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptr, i32 2, <2 x i1> <i1 1, i1 1>, <2 x double> %passthru)
-; CHECK-NEXT ret <2 x double> %res
+; CHECK-NEXT:  %unmaskedload = load <2 x double>, <2 x double>* %ptr, align 2
+; CHECK-NEXT:  ret <2 x double> %unmaskedload
 }
 
+; FIXME: These could be simplified.
+
 define <2 x double> @load_onesetbitmask1(<2 x double>* %ptr, <2 x double> %passthru) {
   %res = call <2 x double> @llvm.masked.load.v2f64(<2 x double>* %ptr, i32 3, <2 x i1> <i1 1, i1 0>, <2 x double> %passthru)
   ret <2 x double> %res