diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -882,6 +882,8 @@ ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); + assert(Real->getType() == Imag->getType() && + "Real and imaginary parts should not have different types"); if (NodePtr CN = getContainingComposite(Real, Imag)) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return CN; @@ -1463,6 +1465,8 @@ auto *Real = OperationInstruction[i]; auto *Imag = OperationInstruction[j]; + if (Real->getType() != Imag->getType()) + continue; // Only same-typed reductions can form a complex pair. RealPHI = ReductionInfo[Real].first; ImagPHI = ReductionInfo[Imag].first; diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll --- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll +++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll @@ -310,7 +310,91 @@ ret %"class.std::complex" %.fca.0.1.insert } +; A complex<double> reduction and an integer reduction in the same loop; the integer reduction must not be paired with a floating point one: +; complex<double> *a = ...; +; int *s = ...; +; +; for (int i = 0; i < N; ++i) { +; sum += a[i]; +; int_sum += s[i]; +; } +; *outs = int_sum; +define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalias nocapture noundef readnone %c, [2 x double] %d.coerce, ptr nocapture noundef readonly %s, ptr nocapture noundef writeonly %outs) local_unnamed_addr #0 { +; CHECK-LABEL: reduction_mix: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: cntd x9 +; CHECK-NEXT: mov w11, #100 // =0x64 +; CHECK-NEXT: neg x10, x9 +; CHECK-NEXT: mov x8, xzr +; CHECK-NEXT: and x10, x10, x11 +; CHECK-NEXT: mov z0.d, #0 // =0x0 +; CHECK-NEXT: rdvl x11, #2 +; CHECK-NEXT: zip2 z1.d, z0.d, 
z0.d +; CHECK-NEXT: zip1 z2.d, z0.d, z0.d +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: .LBB3_1: // %vector.body +; CHECK-NEXT: // =>This Inner Loop Header: Depth=1 +; CHECK-NEXT: ld1w { z3.d }, p0/z, [x3, x8, lsl #2] +; CHECK-NEXT: ld1d { z4.d }, p0/z, [x0] +; CHECK-NEXT: ld1d { z5.d }, p0/z, [x0, #1, mul vl] +; CHECK-NEXT: add x8, x8, x9 +; CHECK-NEXT: add x0, x0, x11 +; CHECK-NEXT: cmp x10, x8 +; CHECK-NEXT: add z0.d, z3.d, z0.d +; CHECK-NEXT: fadd z2.d, z4.d, z2.d +; CHECK-NEXT: fadd z1.d, z5.d, z1.d +; CHECK-NEXT: b.ne .LBB3_1 +; CHECK-NEXT: // %bb.2: // %middle.block +; CHECK-NEXT: uzp1 z3.d, z2.d, z1.d +; CHECK-NEXT: uzp2 z1.d, z2.d, z1.d +; CHECK-NEXT: uaddv d2, p0, z0.d +; CHECK-NEXT: faddv d0, p0, z1.d +; CHECK-NEXT: fmov x8, d2 +; CHECK-NEXT: faddv d1, p0, z3.d +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0 +; CHECK-NEXT: // kill: def $d1 killed $d1 killed $z1 +; CHECK-NEXT: str w8, [x4] +; CHECK-NEXT: ret +entry: + %0 = tail call i64 @llvm.vscale.i64() + %1 = shl nuw nsw i64 %0, 1 + %n.mod.vf = urem i64 100, %1 + %n.vec = sub nuw nsw i64 100, %n.mod.vf + %2 = tail call i64 @llvm.vscale.i64() + %3 = shl nuw nsw i64 %2, 1 + br label %vector.body + +vector.body: ; preds = %vector.body, %entry + %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ] + %vec.phi = phi [ zeroinitializer, %entry ], [ %5, %vector.body ] + %vec.phi13 = phi [ zeroinitializer, %entry ], [ %9, %vector.body ] + %vec.phi14 = phi [ zeroinitializer, %entry ], [ %10, %vector.body ] + %4 = getelementptr inbounds i32, ptr %s, i64 %index + %wide.load = load , ptr %4, align 4 + %5 = add %wide.load, %vec.phi + %6 = getelementptr inbounds %"class.std::complex", ptr %a, i64 %index + %wide.vec = load , ptr %6, align 8 + %strided.vec = tail call { , } @llvm.experimental.vector.deinterleave2.nxv4f64( %wide.vec) + %7 = extractvalue { , } %strided.vec, 0 + %8 = extractvalue { , } %strided.vec, 1 + %9 = fadd fast %7, %vec.phi13 + %10 = fadd fast %8, %vec.phi14 + %index.next = add nuw i64 
%index, %3 + %11 = icmp eq i64 %index.next, %n.vec + br i1 %11, label %middle.block, label %vector.body + +middle.block: ; preds = %vector.body + %12 = tail call fast double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, %10) + %13 = tail call fast double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, %9) + %14 = tail call i32 @llvm.vector.reduce.add.nxv2i32( %5) + store i32 %14, ptr %outs, align 4 + %.fca.0.0.insert = insertvalue %"class.std::complex" poison, double %12, 0, 0 + %.fca.0.1.insert = insertvalue %"class.std::complex" %.fca.0.0.insert, double %13, 0, 1 + ret %"class.std::complex" %.fca.0.1.insert +} + declare i64 @llvm.vscale.i64() declare { , } @llvm.experimental.vector.deinterleave2.nxv4f64() declare double @llvm.vector.reduce.fadd.nxv2f64(double, ) +declare i32 @llvm.vector.reduce.add.nxv2i32()