[X86] Accept constant operands when matching vpdpbusd

Summary of the change:
  * X86ISelLowering.cpp: the matcher no longer returns false when the
    (possibly swapped) first multiply operand is not an ISD::ZERO_EXTEND,
    and its IsFreeTruncation lambda now also accepts a constant
    BUILD_VECTOR (BuildVectorSDNode::isConstant()).
  * X86PartialReduction.cpp: the IR-level matcher likewise drops the
    "LHS must be a ZExtInst" early exit, and IsFreeTruncation now also
    accepts any llvm::Constant.
  * dpbusd_const.ll: new llc test (avx512vnni + avx512vl). Per its CHECK
    lines, in-range constants select vpdpbusd, while the *_exceed cases
    fall back to vpmaddwd.

NOTE(review): this patch was recovered from a paste that had lost its line
breaks and indentation, and that appears to have had every "<letter...>"
span stripped (spans such as "<4 x i8>" and "llc < %s" survived). Restored:
  * C++ template arguments: dyn_cast<BuildVectorSDNode>, isa<SExtInst>,
    isa<ZExtInst>, dyn_cast<CastInst>, isa<Constant>, inferred from the
    surrounding uses (BV->isConstant(), Cast->getOpcode()) and from the
    parallel ISD::SIGN_EXTEND / ISD::ZERO_EXTEND logic in the ISel hunk.
  * IR vector constants in dpbusd_const.ll. The mul_4xi4_cz and mul_4xi8_cs
    constants are read off their CHECK lines ([0,1,2,127] and [0,1,2,255]).
    The constants loaded from the constant pool (mul_4xi8_zc_exceed,
    mul_4xi8_zc, mul_4xi8_cs_exceed, mul_16xi8_zc) were not visible and
    were chosen to fit the test names. Confirm them against the original
    patch, or re-run utils/update_llc_test_checks.py after adjusting.
  * Context-line indentation follows LLVM clang-format style; check it
    against the target revision before applying.

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41808,17 +41808,14 @@
   if (Op0.getOpcode() == ISD::SIGN_EXTEND)
     std::swap(Op0, Op1);
 
-  if (Op0.getOpcode() != ISD::ZERO_EXTEND)
-    return false;
-
   auto IsFreeTruncation = [](SDValue &Op) -> bool {
     if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
          Op.getOpcode() == ISD::SIGN_EXTEND) &&
         Op.getOperand(0).getScalarValueSizeInBits() <= 8)
       return true;
 
-    // TODO: Support contant value.
-    return false;
+    auto *BV = dyn_cast<BuildVectorSDNode>(Op);
+    return (BV && BV->isConstant());
   };
 
   // (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned
diff --git a/llvm/lib/Target/X86/X86PartialReduction.cpp b/llvm/lib/Target/X86/X86PartialReduction.cpp
--- a/llvm/lib/Target/X86/X86PartialReduction.cpp
+++ b/llvm/lib/Target/X86/X86PartialReduction.cpp
@@ -76,9 +76,6 @@
   if (isa<SExtInst>(LHS))
     std::swap(LHS, RHS);
 
-  if (!isa<ZExtInst>(LHS))
-    return false;
-
   auto IsFreeTruncation = [&](Value *Op) {
     if (auto *Cast = dyn_cast<CastInst>(Op)) {
       if (Cast->getParent() == Mul->getParent() &&
@@ -87,8 +84,8 @@
           Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
         return true;
     }
-    // TODO: Support constant in ISel.
-    return false;
+
+    return isa<Constant>(Op);
   };
 
   // (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned
diff --git a/llvm/test/CodeGen/X86/dpbusd_const.ll b/llvm/test/CodeGen/X86/dpbusd_const.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/dpbusd_const.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni -mattr=+avx512vl | FileCheck %s
+
+define i32 @mul_4xi8_zc_exceed(<4 x i8> %a, i32 %c) {
+; CHECK-LABEL: mul_4xi8_zc_exceed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpmovzxbd {{.*#+}} xmm0 = xmm0[0],zero,zero,zero,xmm0[1],zero,zero,zero,xmm0[2],zero,zero,zero,xmm0[3],zero,zero,zero
+; CHECK-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = zext <4 x i8> %a to <4 x i32>
+  %1 = mul nsw <4 x i32> %0, <i32 0, i32 1, i32 2, i32 128>
+  %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_4xi8_zc(<4 x i8> %a, i32 %c) {
+; CHECK-LABEL: mul_4xi8_zc:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3,4,5,6,7]
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpdpbusd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
+; CHECK-NEXT:    vmovd %xmm1, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = zext <4 x i8> %a to <4 x i32>
+  %1 = mul nsw <4 x i32> %0, <i32 0, i32 1, i32 2, i32 127>
+  %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_4xi4_cz(<4 x i4> %a, i32 %c) {
+; CHECK-LABEL: mul_4xi4_cz:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpmovdb %xmm0, %xmm0
+; CHECK-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT:    vmovdqa {{.*#+}} xmm1 = [0,1,2,127,0,0,0,0,0,0,0,0,0,0,0,0]
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm0, %xmm1, %xmm2
+; CHECK-NEXT:    vmovd %xmm2, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = zext <4 x i4> %a to <4 x i32>
+  %1 = mul nsw <4 x i32> <i32 0, i32 1, i32 2, i32 127>, %0
+  %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_4xi8_cs(<4 x i8> %a, i32 %c) {
+; CHECK-LABEL: mul_4xi8_cs:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3,4,5,6,7]
+; CHECK-NEXT:    vmovdqa {{.*#+}} xmm1 = [0,1,2,255,0,0,0,0,0,0,0,0,0,0,0,0]
+; CHECK-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; CHECK-NEXT:    vpdpbusd %xmm0, %xmm1, %xmm2
+; CHECK-NEXT:    vmovd %xmm2, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = sext <4 x i8> %a to <4 x i32>
+  %1 = mul nsw <4 x i32> <i32 0, i32 1, i32 2, i32 255>, %0
+  %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_4xi8_cs_exceed(<4 x i8> %a, i32 %c) {
+; CHECK-LABEL: mul_4xi8_cs_exceed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpmovsxbd %xmm0, %xmm0
+; CHECK-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = sext <4 x i8> %a to <4 x i32>
+  %1 = mul nsw <4 x i32> <i32 0, i32 1, i32 2, i32 256>, %0
+  %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+define i32 @mul_16xi8_zc(<16 x i8> %a, i32 %c) {
+; CHECK-LABEL: mul_16xi8_zc:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vpdpbusd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
+; CHECK-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; CHECK-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovd %xmm0, %eax
+; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    retq
+entry:
+  %0 = zext <16 x i8> %a to <16 x i32>
+  %1 = mul nsw <16 x i32> %0, <i32 0, i32 1, i32 2, i32 64, i32 0, i32 1, i32 2, i32 64, i32 0, i32 1, i32 2, i32 64, i32 0, i32 1, i32 2, i32 64>
+  %2 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
+  %op.extra = add nsw i32 %2, %c
+  ret i32 %op.extra
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)