This is an archive of the discontinued LLVM Phabricator instance.

[PowerPC] use lfiwax/lfiwzx for scalar_to_vector + load at PWR7
Abandoned · Public

Authored by shchenz on Jul 20 2021, 3:49 AM.

Details

Reviewers
nemanjai
jsji
qiucf
Group Reviewers
Restricted Project
Summary

We already handle scalar_to_vector + load well for PWR8 and above: we use lfiwax/lfiwzx for that pattern, so the type change does not need a round trip through memory.

This patch adds a similar code-gen improvement for PWR7.

We still miss the i32->i64 sign-extend and zero-extend cases: the PowerPC backend fails to recognize build_vector t1, t1 as scalar_to_vector + vector_shuffle<0, 0>, so those cases are not hit by this patch. We will handle them in a later patch.
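
For reference, a minimal IR reproducer of the unhandled case (hypothetical, not taken from the test suite):

; An i32 load sign-extended to i64 and inserted into both lanes of a v2i64.
; SelectionDAG builds this as build_vector t1, t1, which the backend does not
; fold back into scalar_to_vector + vector_shuffle<0, 0>.
define <2 x i64> @sext_splat(i32* %p) {
  %s = load i32, i32* %p
  %e = sext i32 %s to i64
  %v0 = insertelement <2 x i64> undef, i64 %e, i64 0
  %v1 = insertelement <2 x i64> %v0, i64 %e, i64 1
  ret <2 x i64> %v1
}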

Diff Detail

Event Timeline

shchenz created this revision. Jul 20 2021, 3:49 AM
shchenz requested review of this revision. Jul 20 2021, 3:49 AM
Herald added a project: Restricted Project. Jul 20 2021, 3:49 AM
nemanjai requested changes to this revision. Jul 20 2021, 6:46 PM

I don't think this is actually correct.

llvm/lib/Target/PowerPC/PPCISelLowering.cpp
10519

Is this expected to have further uses in the future? Defining a lambda with all the implementation in it and then simply calling it once seems like a very strange idiom.

10540

I think this is incorrect for little endian systems since on LE, LFIW[AZ]X will put the value into an element that is not element zero.

llvm/test/CodeGen/PowerPC/load-and-splat.ll
62

This looks wrong. It will splat word zero (which is just zero) rather than word 1, which it is supposed to splat (i.e., it should match the P8 codegen).
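
For illustration, a sketch of the intended P8-style codegen (register choices assumed, not copied from the actual test output):

lfiwzx 0, 0, 3     # the loaded word lands in word element 1 of vs0
xxspltw 34, 0, 1   # broadcast word 1 (the loaded value) to all lanes

Splatting word 0 instead would broadcast the extension bits, i.e. zeros.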

This revision now requires changes to proceed. Jul 20 2021, 6:46 PM

It would seem that a simpler approach would be to just use the LD_SPLAT node (and add a ZExt/SExt version of it). Something like this (note: the patch is untested but should fix the FIXME you added to the test case):

diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index 37358176f35e..000243c59401 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -1707,6 +1707,8 @@ const char *PPCTargetLowering::getTargetNodeName(unsigned Opcode) const {
   case PPCISD::EXTRACT_VSX_REG: return "PPCISD::EXTRACT_VSX_REG";
   case PPCISD::XXMFACC:         return "PPCISD::XXMFACC";
   case PPCISD::LD_SPLAT:        return "PPCISD::LD_SPLAT";
+  case PPCISD::ZEXT_LD_SPLAT:   return "PPCISD::ZEXT_LD_SPLAT";
+  case PPCISD::SEXT_LD_SPLAT:   return "PPCISD::SEXT_LD_SPLAT";
   case PPCISD::FNMSUB:          return "PPCISD::FNMSUB";
   case PPCISD::STRICT_FADDRTZ:
     return "PPCISD::STRICT_FADDRTZ";
@@ -9133,13 +9135,33 @@ SDValue PPCTargetLowering::LowerBUILD_VECTOR(SDValue Op,
     bool IsPermutedLoad = false;
     const SDValue *InputLoad =
         getNormalLoadInput(Op.getOperand(0), IsPermutedLoad);
+    const SDNode *InputNode = Op.getOperand(0).getNode();
+    bool ZExt = false, SExt = false;
+    auto NewOpcode = PPCISD::LD_SPLAT;
+    // Handle zero/sign extended loads.
+    if (!InputLoad && ISD::isUNINDEXEDLoad(InputNode) &&
+        Op.getValueType() == MVT::v2i64 &&
+        cast<LoadSDNode>(Op.getOperand(0))->getMemoryVT() == MVT::i32) {
+      InputLoad = &Op.getOperand(0);
+      if (ISD::isZEXTLoad(InputNode)) {
+        ZExt = true;
+        NewOpcode = PPCISD::ZEXT_LD_SPLAT;
+      }
+      else if (ISD::isSEXTLoad(InputNode)) {
+        SExt = true;
+        NewOpcode = PPCISD::SEXT_LD_SPLAT;
+      }
+      else
+        InputLoad = nullptr;
+    }
+
     // Handle load-and-splat patterns as we have instructions that will do this
     // in one go.
     if (InputLoad && DAG.isSplatValue(Op, true)) {
       LoadSDNode *LD = cast<LoadSDNode>(*InputLoad);
 
       // We have handling for 4 and 8 byte elements.
-      unsigned ElementSize = LD->getMemoryVT().getScalarSizeInBits();
+      unsigned ElementSize = LD->getMemoryVT().getScalarSizeInBits() * (ZExt || SExt ? 2 : 1);
 
       // Checking for a single use of this load, we have to check for vector
       // width (128 bits) / ElementSize uses (since each operand of the
@@ -9150,15 +9172,14 @@ SDValue PPCTargetLowering::LowerBUILD_VECTOR(SDValue Op,
           NumUsesOfInputLD--;
       assert(NumUsesOfInputLD > 0 && "No uses of input LD of a build_vector?");
       if (InputLoad->getNode()->hasNUsesOfValue(NumUsesOfInputLD, 0) &&
-          ((Subtarget.hasVSX() && ElementSize == 64) ||
-           (Subtarget.hasP9Vector() && ElementSize == 32))) {
+          Subtarget.hasVSX() && (ElementSize == 64 || ElementSize == 32)) {
         SDValue Ops[] = {
           LD->getChain(),    // Chain
           LD->getBasePtr(),  // Ptr
           DAG.getValueType(Op.getValueType()) // VT
         };
         SDValue LdSplt = DAG.getMemIntrinsicNode(
-            PPCISD::LD_SPLAT, dl, DAG.getVTList(Op.getValueType(), MVT::Other),
+            NewOpcode, dl, DAG.getVTList(Op.getValueType(), MVT::Other),
             Ops, LD->getMemoryVT(), LD->getMemOperand());
         // Replace all uses of the output chain of the original load with the
         // output chain of the new load.
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h
index 87579bad118f..2452e12fe926 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.h
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h
@@ -554,6 +554,14 @@ namespace llvm {
     /// instructions such as LXVDSX, LXVWSX.
     LD_SPLAT,
 
+    /// VSRC, CHAIN = ZEXT_LD_SPLAT, CHAIN, Ptr - a splatting load memory
+    /// that zero-extends.
+    ZEXT_LD_SPLAT,
+
+    /// VSRC, CHAIN = SEXT_LD_SPLAT, CHAIN, Ptr - a splatting load memory
+    /// that sign-extends.
+    SEXT_LD_SPLAT,
+
     /// CHAIN = STXVD2X CHAIN, VSRC, Ptr - Occurs only for little endian.
     /// Maps directly to an stxvd2x instruction that will be preceded by
     /// an xxswapd.
diff --git a/llvm/lib/Target/PowerPC/PPCInstrVSX.td b/llvm/lib/Target/PowerPC/PPCInstrVSX.td
index a13eb2b6e109..616cbfc7cd55 100644
--- a/llvm/lib/Target/PowerPC/PPCInstrVSX.td
+++ b/llvm/lib/Target/PowerPC/PPCInstrVSX.td
@@ -138,6 +138,10 @@ def PPCldvsxlh : SDNode<"PPCISD::LD_VSX_LH", SDT_PPCldvsxlh,
                         [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 def PPCldsplat : SDNode<"PPCISD::LD_SPLAT", SDT_PPCldsplat,
                         [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def PPCzextldsplat : SDNode<"PPCISD::ZEXT_LD_SPLAT", SDT_PPCldsplat,
+                        [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def PPCsextldsplat : SDNode<"PPCISD::SEXT_LD_SPLAT", SDT_PPCldsplat,
+                        [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 def PPCSToV : SDNode<"PPCISD::SCALAR_TO_VECTOR_PERMUTED",
                      SDTypeProfile<1, 1, []>, []>;
 
@@ -2827,6 +2831,14 @@ def : Pat<(v2f64 (PPCldsplat ForceXForm:$A)),
           (v2f64 (LXVDSX ForceXForm:$A))>;
 def : Pat<(v2i64 (PPCldsplat ForceXForm:$A)),
           (v2i64 (LXVDSX ForceXForm:$A))>;
+def : Pat<(v2i64 (PPCzextldsplat ForceXForm:$A)),
+          (v2i64 (XXPERMDIs (LFIWZX ForceXForm:$A), 0))>;
+def : Pat<(v2i64 (PPCsextldsplat ForceXForm:$A)),
+          (v2i64 (XXPERMDIs (LFIWAX ForceXForm:$A), 0))>;
+def : Pat<(v4f32 (PPCldsplat ForceXForm:$A)),
+          (v4f32 (XXSPLTW (SUBREG_TO_REG (i64 1), (LFIWZX ForceXForm:$A), sub_64), 1))>;
+def : Pat<(v4i32 (PPCldsplat ForceXForm:$A)),
+          (v4i32 (XXSPLTW (SUBREG_TO_REG (i64 1), (LFIWZX ForceXForm:$A), sub_64), 1))>;
 
 // Build vectors of floating point converted to i64.
 def : Pat<(v2i64 (build_vector FltToLong.A, FltToLong.A)),
diff --git a/llvm/lib/Target/PowerPC/PPCMIPeephole.cpp b/llvm/lib/Target/PowerPC/PPCMIPeephole.cpp
index 4bbb6ed85a6c..16e711cafee5 100644
--- a/llvm/lib/Target/PowerPC/PPCMIPeephole.cpp
+++ b/llvm/lib/Target/PowerPC/PPCMIPeephole.cpp
@@ -603,14 +603,23 @@ bool PPCMIPeephole::simplifyCode(void) {
             ToErase = &MI;
             Simplified = true;
           }
-        } else if ((Immed == 0 || Immed == 3) && DefOpc == PPC::XXPERMDIs &&
+        } else if ((Immed == 0 || Immed == 3 || Immed == 2) && DefOpc == PPC::XXPERMDIs &&
                    (DefMI->getOperand(2).getImm() == 0 ||
                     DefMI->getOperand(2).getImm() == 3)) {
+          ToErase = &MI;
+          Simplified = true;
+          // Swap of a splat, convert to copy.
+          if (Immed == 2) {
+            LLVM_DEBUG(dbgs() << "Optimizing swap(splat) => copy(splat): ");
+            LLVM_DEBUG(MI.dump());
+            BuildMI(MBB, &MI, MI.getDebugLoc(), TII->get(PPC::COPY),
+                    MI.getOperand(0).getReg())
+              .add(MI.getOperand(1));
+            break;
+          }
           // Splat fed by another splat - switch the output of the first
           // and remove the second.
           DefMI->getOperand(0).setReg(MI.getOperand(0).getReg());
-          ToErase = &MI;
-          Simplified = true;
           LLVM_DEBUG(dbgs() << "Removing redundant splat: ");
           LLVM_DEBUG(MI.dump());
         }
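
With the patterns above, a sign-extending i32->i64 splat load would presumably lower to something like this sketch (register choices assumed):

lfiwax 0, 0, 3        # sign-extending word load; result in doubleword 0 of vs0
xxpermdi 34, 0, 0, 0  # splat doubleword 0 into both v2i64 lanes

The PPCMIPeephole change additionally folds a swap (xxpermdi with immediate 2) that is fed by such a splat into a plain copy, since swapping a splat is a no-op.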
lkail added a subscriber: lkail. Jul 20 2021, 10:27 PM
lkail added inline comments.
llvm/lib/Target/PowerPC/PPCISelLowering.cpp
10522

Is !Subtarget.hasVSX() a typo? It looks inconsistent with the comment.

Thanks for your comments @nemanjai.
I posted a new patch https://reviews.llvm.org/D106555 based on your code for the splat loads. That patch handles more types.

I will leave this patch as an improvement for scalar_to_vector without a splat user. I am not sure the unadjusted_lxvwsx improvement is a motivating case; I don't see that improvement in https://reviews.llvm.org/D106555.

shchenz added a comment (edited). Jul 25 2021, 11:54 PM

After more investigation, I think we don't need this patch.

scalar_to_vector (scalar_load) puts the scalar load result in the first element of the vector result and leaves the other elements of the vector result undef.

So it is always correct to convert scalar_to_vector (scalar_load) into a splat_load (all vector elements the same). We already improved the code gen for splat_load in https://reviews.llvm.org/D106555. I verified all the types in this patch (v4i32/v4f32/v2i64_signext/v2i64_zeroext); they are handled by https://reviews.llvm.org/D106555, which already recognizes all of them as splat loads.
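
As a concrete illustration (a hypothetical example of the usual IR form behind scalar_to_vector):

; Only element 0 is defined; elements 1-3 are undef.
define <4 x i32> @s2v(i32* %p) {
  %s = load i32, i32* %p
  %v = insertelement <4 x i32> undef, i32 %s, i64 0
  ret <4 x i32> %v
}

Because the remaining lanes are undef, replacing this with a splat load (all four lanes equal to %s) only refines the undef lanes, so it is always a valid lowering.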

The only unexpected case is unadjusted_lxvwsx in llvm/test/CodeGen/PowerPC/load-and-splat.ll. That case has the following pattern:

t7: v4i8,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64

This will be legalized as:

Legalizing node: t7: v4i8,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64
Analyzing result type: v4i8
Widen node result 0: t7: v4i8,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64

Creating new node: t15: i32,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64
Creating new node: t16: v4i32 = scalar_to_vector t15
Creating new node: t17: v16i8 = bitcast t16

Since the scalar_to_vector is not generated from a BUILD_VECTOR, it cannot be recognized as a splat load.

Even for the above case, I think we should fix it when we legalize the v4i8 type. Since we already added several flavors of LD_SPLAT node, type legalization should now generate an LD_SPLAT node instead of scalar_to_vector(load).
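
For reference, a reconstruction (assumed, not copied from load-and-splat.ll) of IR that produces the v4i8 load above:

; A 4 x i8 (32-bit) load splatted across a v16i8. Type legalization widens the
; v4i8 load into load + scalar_to_vector + bitcast as shown in the log above,
; bypassing the BUILD_VECTOR-based splat recognition.
define <16 x i8> @unadjusted_lxvwsx(<4 x i8>* %p) {
  %v = load <4 x i8>, <4 x i8>* %p
  %s = shufflevector <4 x i8> %v, <4 x i8> undef,
         <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3,
                     i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
  ret <16 x i8> %s
}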

FYI @nemanjai: if you think this is still needed, please let me know. Thanks.

shchenz abandoned this revision. Jul 25 2021, 11:59 PM