Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1789,17 +1789,63 @@
     break;
   }

-  // X86 scalar intrinsics simplified with SimplifyDemandedVectorElts.
   case Intrinsic::x86_avx512_mask_add_ss_round:
   case Intrinsic::x86_avx512_mask_div_ss_round:
   case Intrinsic::x86_avx512_mask_mul_ss_round:
   case Intrinsic::x86_avx512_mask_sub_ss_round:
-  case Intrinsic::x86_avx512_mask_max_ss_round:
-  case Intrinsic::x86_avx512_mask_min_ss_round:
   case Intrinsic::x86_avx512_mask_add_sd_round:
   case Intrinsic::x86_avx512_mask_div_sd_round:
   case Intrinsic::x86_avx512_mask_mul_sd_round:
   case Intrinsic::x86_avx512_mask_sub_sd_round:
+    // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular
+    // IR operations.
+    if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) {
+      if (R->getValue() == 4) {
+        // Only do this if the mask bit is 1 so that we don't need a select.
+        // TODO: Improve this to handle masking cases. Isel doesn't fold
+        // the mask correctly right now.
+        if (auto *M = dyn_cast<ConstantInt>(II->getArgOperand(3))) {
+          if (M->getValue()[0]) {
+            // Extract the element as scalars.
+            Value *Arg0 = II->getArgOperand(0);
+            Value *Arg1 = II->getArgOperand(1);
+            Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0);
+            Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0);
+
+            Value *V;
+            switch (II->getIntrinsicID()) {
+            default: llvm_unreachable("Case stmts out of sync!");
+            case Intrinsic::x86_avx512_mask_add_ss_round:
+            case Intrinsic::x86_avx512_mask_add_sd_round:
+              V = Builder->CreateFAdd(LHS, RHS);
+              break;
+            case Intrinsic::x86_avx512_mask_sub_ss_round:
+            case Intrinsic::x86_avx512_mask_sub_sd_round:
+              V = Builder->CreateFSub(LHS, RHS);
+              break;
+            case Intrinsic::x86_avx512_mask_mul_ss_round:
+            case Intrinsic::x86_avx512_mask_mul_sd_round:
+              V = Builder->CreateFMul(LHS, RHS);
+              break;
+            case Intrinsic::x86_avx512_mask_div_ss_round:
+            case Intrinsic::x86_avx512_mask_div_sd_round:
+              V = Builder->CreateFDiv(LHS, RHS);
+              break;
+            }
+
+            // Insert the result back into the original argument 0.
+            V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0);
+
+            return replaceInstUsesWith(*II, V);
+          }
+        }
+      }
+    }
+    LLVM_FALLTHROUGH;
+
+  // X86 scalar intrinsics simplified with SimplifyDemandedVectorElts.
+  case Intrinsic::x86_avx512_mask_max_ss_round:
+  case Intrinsic::x86_avx512_mask_min_ss_round:
   case Intrinsic::x86_avx512_mask_max_sd_round:
   case Intrinsic::x86_avx512_mask_min_sd_round:
   case Intrinsic::x86_avx512_mask_vfmadd_ss:
Index: llvm/trunk/test/Transforms/InstCombine/x86-avx512.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/x86-avx512.ll
+++ llvm/trunk/test/Transforms/InstCombine/x86-avx512.ll
@@ -6,8 +6,11 @@

 define <4 x float> @test_add_ss(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_add_ss(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> %a, float [[TMP3]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP4]]
 ;
   %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -16,6 +19,18 @@
   ret <4 x float> %4
 }

+define <4 x float> @test_add_ss_round(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @test_add_ss_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
+  %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
+  %3 = insertelement <4 x float> %2, float 3.000000e+00, i32 3
+  %4 = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %3, <4 x float> undef, i8 -1, i32 8)
+  ret <4 x float> %4
+}
+
 define <4 x float> @test_add_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_add_ss_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
@@ -49,14 +64,27 @@

 define <2 x double> @test_add_sd(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @test_add_sd(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x double> %a, double [[TMP3]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP4]]
 ;
   %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 4)
   ret <2 x double> %2
 }

+define <2 x double> @test_add_sd_round(<2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: @test_add_sd_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
+  %2 = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 8)
+  ret <2 x double> %2
+}
+
 define <2 x double> @test_add_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_add_sd_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
@@ -84,8 +112,11 @@

 define <4 x float> @test_sub_ss(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_sub_ss(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.sub.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fsub float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> %a, float [[TMP3]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP4]]
 ;
   %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -94,6 +125,18 @@
   ret <4 x float> %4
 }

+define <4 x float> @test_sub_ss_round(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @test_sub_ss_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.sub.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
+  %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
+  %3 = insertelement <4 x float> %2, float 3.000000e+00, i32 3
+  %4 = tail call <4 x float> @llvm.x86.avx512.mask.sub.ss.round(<4 x float> %a, <4 x float> %3, <4 x float> undef, i8 -1, i32 8)
+  ret <4 x float> %4
+}
+
 define <4 x float> @test_sub_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_sub_ss_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.sub.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
@@ -127,14 +170,27 @@

 define <2 x double> @test_sub_sd(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @test_sub_sd(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fsub double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x double> %a, double [[TMP3]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP4]]
 ;
   %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 4)
   ret <2 x double> %2
 }

+define <2 x double> @test_sub_sd_round(<2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: @test_sub_sd_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
+  %2 = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 8)
+  ret <2 x double> %2
+}
+
 define <2 x double> @test_sub_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_sub_sd_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
@@ -162,8 +218,11 @@

 define <4 x float> @test_mul_ss(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_mul_ss(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.mul.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> %a, float [[TMP3]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP4]]
 ;
   %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -172,6 +231,18 @@
   ret <4 x float> %4
 }

+define <4 x float> @test_mul_ss_round(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @test_mul_ss_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.mul.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
+  %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
+  %3 = insertelement <4 x float> %2, float 3.000000e+00, i32 3
+  %4 = tail call <4 x float> @llvm.x86.avx512.mask.mul.ss.round(<4 x float> %a, <4 x float> %3, <4 x float> undef, i8 -1, i32 8)
+  ret <4 x float> %4
+}
+
 define <4 x float> @test_mul_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_mul_ss_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.mul.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
@@ -205,14 +276,27 @@

 define <2 x double> @test_mul_sd(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @test_mul_sd(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x double> %a, double [[TMP3]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP4]]
 ;
   %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 4)
   ret <2 x double> %2
 }

+define <2 x double> @test_mul_sd_round(<2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: @test_mul_sd_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
+  %2 = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 8)
+  ret <2 x double> %2
+}
+
 define <2 x double> @test_mul_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_mul_sd_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
@@ -240,8 +324,11 @@

 define <4 x float> @test_div_ss(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @test_div_ss(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.div.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fdiv float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x float> %a, float [[TMP3]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP4]]
 ;
   %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -250,6 +337,18 @@
   ret <4 x float> %4
 }

+define <4 x float> @test_div_ss_round(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @test_div_ss_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.div.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %1 = insertelement <4 x float> %b, float 1.000000e+00, i32 1
+  %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
+  %3 = insertelement <4 x float> %2, float 3.000000e+00, i32 3
+  %4 = tail call <4 x float> @llvm.x86.avx512.mask.div.ss.round(<4 x float> %a, <4 x float> %3, <4 x float> undef, i8 -1, i32 8)
+  ret <4 x float> %4
+}
+
 define <4 x float> @test_div_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_div_ss_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.div.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
@@ -283,14 +382,27 @@

 define <2 x double> @test_div_sd(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @test_div_sd(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fdiv double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x double> %a, double [[TMP3]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP4]]
 ;
   %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 4)
   ret <2 x double> %2
 }

+define <2 x double> @test_div_sd_round(<2 x double> %a, <2 x double> %b) {
+; CHECK-LABEL: @test_div_sd_round(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> undef, i8 -1, i32 8)
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %1 = insertelement <2 x double> %b, double 1.000000e+00, i32 1
+  %2 = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %1, <2 x double> undef, i8 -1, i32 8)
+  ret <2 x double> %2
+}
+
 define <2 x double> @test_div_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_div_sd_mask(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
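For reference: the TODO in the InstCombine change above only fires when the mask operand is a constant with bit 0 set, because a general mask would require a select between the computed scalar and element 0 of the passthru operand. A minimal, hypothetical IR sketch of what that masked form would have to expand to, assuming the usual mask/passthru semantics of these intrinsics (the function and value names below are illustrative and not part of the patch):

define <4 x float> @masked_add_ss_sketch(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
  ; Scalar operands come from element 0 of each source vector.
  %a0 = extractelement <4 x float> %a, i64 0
  %b0 = extractelement <4 x float> %b, i64 0
  %c0 = extractelement <4 x float> %c, i64 0
  %sum = fadd float %a0, %b0
  ; Bit 0 of the mask chooses between the computed result and the passthru lane.
  %bit = trunc i8 %mask to i1
  %sel = select i1 %bit, float %sum, float %c0
  ; The upper elements are taken unchanged from the first source operand.
  %res = insertelement <4 x float> %a, float %sel, i64 0
  ret <4 x float> %res
}

Restricting the fold to an all-ones (or at least bit-0-set constant) mask avoids emitting this select, which is what the *_mask test cases above check is left untouched.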