diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -33,6 +33,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/KnownBits.h"
@@ -45,6 +46,12 @@
 
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+static cl::opt<bool>
+    AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
+                     cl::desc("Allow the formation of VW_W operations (e.g., "
+                              "VWADD_W) with splat constants"),
+                     cl::init(false));
+
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -8204,228 +8211,548 @@
   return SDValue();
 }
 
-// Try to form vwadd(u).wv/wx or vwsub(u).wv/wx. It might later be optimized to
-// vwadd(u).vv/vx or vwsub(u).vv/vx.
-static SDValue combineADDSUB_VLToVWADDSUB_VL(SDNode *N, SelectionDAG &DAG,
-                                             bool Commute = false) {
-  assert((N->getOpcode() == RISCVISD::ADD_VL ||
-          N->getOpcode() == RISCVISD::SUB_VL) &&
-         "Unexpected opcode");
-  bool IsAdd = N->getOpcode() == RISCVISD::ADD_VL;
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  if (Commute)
-    std::swap(Op0, Op1);
+namespace {
+// Forward declaration of the structure holding the necessary information to
+// apply a combine.
+struct CombineResult;
 
-  MVT VT = N->getSimpleValueType(0);
+/// Helper class for folding sign/zero extensions.
+/// In particular, this class is used for the following combines:
+/// add_vl -> vwadd(u) | vwadd(u)_w
+/// sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul_vl -> vwmul(u) | vwmul_su
+///
+/// An object of this class represents an operand of the operation we want to
+/// combine.
+/// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
+/// NodeExtensionHelper for `a` and one for `b`.
+///
+/// This class abstracts away how the extension is materialized and
+/// how its Mask, VL, and number of users affect the combines.
+///
+/// In particular:
+/// - VWADD_W is conceptually == add(op0, sext(op1))
+/// - VWADDU_W == add(op0, zext(op1))
+/// - VWSUB_W == sub(op0, sext(op1))
+/// - VWSUBU_W == sub(op0, zext(op1))
+///
+/// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
+/// zext|sext(smaller_value).
+struct NodeExtensionHelper {
+  /// Records if this operand is like being zero extended.
+  bool SupportsZExt;
+  /// Records if this operand is like being sign extended.
+  /// Note: SupportsZExt and SupportsSExt are not mutually exclusive. For
+  /// instance, a splat constant (e.g., 3) would support being both sign and
+  /// zero extended.
+  bool SupportsSExt;
+  /// This boolean captures whether we care if this operand would still be
+  /// around after the folding happens.
+  bool EnforceOneUse;
+  /// Records if this operand's mask needs to match the mask of the operation
+  /// that it will fold into.
+  bool CheckMask;
+  /// Value of the Mask for this operand.
+  /// It may be SDValue().
+  SDValue Mask;
+  /// Value of the vector length operand.
+  /// It may be SDValue().
+  SDValue VL;
+  /// Original value that this NodeExtensionHelper represents.
+  SDValue OrigOperand;
+
+  /// Get the value feeding the extension or the value itself.
+  /// E.g., for zext(a), this would return a.
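+  /// Note that a splat (VMV_V_X_VL) is considered its own source; when such
+  /// an operand is narrowed, getOrCreateExtendedOp re-splats the scalar at
+  /// the narrow type instead of introducing an extension.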
+  SDValue getSource() const {
+    switch (OrigOperand.getOpcode()) {
+    case RISCVISD::VSEXT_VL:
+    case RISCVISD::VZEXT_VL:
+      return OrigOperand.getOperand(0);
+    default:
+      return OrigOperand;
+    }
+  }
+
+  /// Check if this instance represents a splat.
+  bool isSplat() const {
+    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+  }
+
+  /// Get or create a value that can feed \p Root with the given \p ExtOpc.
+  /// If \p ExtOpc is None, this returns the source of this operand.
+  /// \see ::getSource().
+  SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
+                                Optional<unsigned> ExtOpc) const {
+    SDValue Source = getSource();
+    if (!ExtOpc)
+      return Source;
+
+    MVT NarrowVT = getNarrowType(Root);
+    // If we need an extension, we should be changing the type.
+    assert(Source.getValueType() != NarrowVT && "Needless extension");
+    SDLoc DL(Root);
+    auto [Mask, VL] = getMaskAndVL(Root);
+    switch (OrigOperand.getOpcode()) {
+    case RISCVISD::VSEXT_VL:
+    case RISCVISD::VZEXT_VL:
+      return DAG.getNode(*ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    case RISCVISD::VMV_V_X_VL:
+      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
+                         DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+    default:
+      // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
+      // and that operand should already have the right NarrowVT so no
+      // extension should be required at this point.
+      llvm_unreachable("Unsupported opcode");
+    }
+  }
 
-  // Determine the narrow size for a widening add/sub.
-  unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-  MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                  VT.getVectorElementCount());
+  /// Helper function to get the narrow type for \p Root.
+  /// The narrow type is the type of \p Root where we divided the size of each
+  /// element by 2. E.g., if Root's type is <2 x i16>, the narrow type is
+  /// <2 x i8>.
+  /// \pre The size of the type of the elements of Root must be a multiple of
+  /// 2 and be at least 16.
+  static MVT getNarrowType(const SDNode *Root) {
+    MVT VT = Root->getSimpleValueType(0);
 
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
+    // Determine the narrow size.
+    unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
+    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
+                                    VT.getVectorElementCount());
+    return NarrowVT;
+  }
 
-  SDLoc DL(N);
+  /// Return the opcode required to materialize the folding of the sign
+  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+  /// both operands of \p Opcode.
+  /// Put differently, get the opcode to materialize:
+  /// - IsSExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  /// - IsSExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
+  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+    switch (Opcode) {
+    case RISCVISD::ADD_VL:
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+    case RISCVISD::MUL_VL:
+      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+    case RISCVISD::SUB_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
 
-  // If the RHS is a sext or zext, we can form a widening op.
-  if ((Op1.getOpcode() == RISCVISD::VZEXT_VL ||
-       Op1.getOpcode() == RISCVISD::VSEXT_VL) &&
-      Op1.hasOneUse() && Op1.getOperand(1) == Mask && Op1.getOperand(2) == VL) {
-    unsigned ExtOpc = Op1.getOpcode();
-    Op1 = Op1.getOperand(0);
-    // Re-introduce narrower extends if needed.
-    if (Op1.getValueType() != NarrowVT)
-      Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
-
-    unsigned WOpc;
-    if (ExtOpc == RISCVISD::VSEXT_VL)
-      WOpc = IsAdd ? RISCVISD::VWADD_W_VL : RISCVISD::VWSUB_W_VL;
-    else
-      WOpc = IsAdd ? RISCVISD::VWADDU_W_VL : RISCVISD::VWSUBU_W_VL;
+  /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
+  /// newOpcode(a, b).
+  static unsigned getSUOpcode(unsigned Opcode) {
+    assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
+    return RISCVISD::VWMULSU_VL;
+  }
 
-    return DAG.getNode(WOpc, DL, VT, Op0, Op1, Merge, Mask, VL);
+  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
+  /// newOpcode(a, b).
+  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+    switch (Opcode) {
+    case RISCVISD::ADD_VL:
+      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+    case RISCVISD::SUB_VL:
+      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
   }
 
-  // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
-  // sext/zext?
+  using CombineToTry = std::function<Optional<CombineResult>(
+      SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
+      const NodeExtensionHelper & /*RHS*/)>;
 
-  return SDValue();
-}
+  /// Check if this node needs to be fully folded or extended for all users.
+  bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
-// Try to convert vwadd(u).wv/wx or vwsub(u).wv/wx to vwadd(u).vv/vx or
-// vwsub(u).vv/vx.
-static SDValue combineVWADD_W_VL_VWSUB_W_VL(SDNode *N, SelectionDAG &DAG) {
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
+  /// Helper method to set the various fields of this struct based on the
+  /// type of \p Root.
+  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
+    SupportsZExt = false;
+    SupportsSExt = false;
+    EnforceOneUse = true;
+    CheckMask = true;
+    switch (OrigOperand.getOpcode()) {
+    case RISCVISD::VZEXT_VL:
+      SupportsZExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
+    case RISCVISD::VSEXT_VL:
+      SupportsSExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
+    case RISCVISD::VMV_V_X_VL: {
+      // Historically, we didn't care about splat values not disappearing
+      // during combines.
+      EnforceOneUse = false;
+      CheckMask = false;
+      VL = OrigOperand.getOperand(2);
 
-  MVT VT = N->getSimpleValueType(0);
-  MVT NarrowVT = Op1.getSimpleValueType();
-  unsigned NarrowSize = NarrowVT.getScalarSizeInBits();
+      // The operand is a splat of a scalar.
 
-  unsigned VOpc;
-  switch (N->getOpcode()) {
-  default: llvm_unreachable("Unexpected opcode");
-  case RISCVISD::VWADD_W_VL:  VOpc = RISCVISD::VWADD_VL;  break;
-  case RISCVISD::VWSUB_W_VL:  VOpc = RISCVISD::VWSUB_VL;  break;
-  case RISCVISD::VWADDU_W_VL: VOpc = RISCVISD::VWADDU_VL; break;
-  case RISCVISD::VWSUBU_W_VL: VOpc = RISCVISD::VWSUBU_VL; break;
-  }
+      // The passthru must be undef for tail agnostic.
+      if (!OrigOperand.getOperand(0).isUndef())
+        break;
 
-  bool IsSigned = N->getOpcode() == RISCVISD::VWADD_W_VL ||
-                  N->getOpcode() == RISCVISD::VWSUB_W_VL;
+      // Get the scalar value.
+      SDValue Op = OrigOperand.getOperand(1);
+
+      // See if we have enough sign bits or zero bits in the scalar to use a
+      // widening opcode by splatting to smaller element size.
+      MVT VT = Root->getSimpleValueType(0);
+      unsigned EltBits = VT.getScalarSizeInBits();
+      unsigned ScalarBits = Op.getValueSizeInBits();
+      // Make sure we're getting all element bits from the scalar register.
+      // FIXME: Support implicit sign extension of vmv.v.x?
+      if (ScalarBits < EltBits)
+        break;
 
-  SDLoc DL(N);
+      unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+      // If the narrow type cannot be expressed with a legal VMV,
+      // this is not a valid candidate.
+      if (NarrowSize < 8)
+        break;
 
-  // If the LHS is a sext or zext, we can narrow this op to the same size as
-  // the RHS.
-  if (((Op0.getOpcode() == RISCVISD::VZEXT_VL && !IsSigned) ||
-       (Op0.getOpcode() == RISCVISD::VSEXT_VL && IsSigned)) &&
-      Op0.hasOneUse() && Op0.getOperand(1) == Mask && Op0.getOperand(2) == VL) {
-    unsigned ExtOpc = Op0.getOpcode();
-    Op0 = Op0.getOperand(0);
-    // Re-introduce narrower extends if needed.
-    if (Op0.getValueType() != NarrowVT)
-      Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
-    return DAG.getNode(VOpc, DL, VT, Op0, Op1, Merge, Mask, VL);
-  }
-
-  bool IsAdd = N->getOpcode() == RISCVISD::VWADD_W_VL ||
-               N->getOpcode() == RISCVISD::VWADDU_W_VL;
-
-  // Look for splats on the left hand side of a vwadd(u).wv. We might be able
-  // to commute and use a vwadd(u).vx instead.
-  if (IsAdd && Op0.getOpcode() == RISCVISD::VMV_V_X_VL &&
-      Op0.getOperand(0).isUndef() && Op0.getOperand(2) == VL) {
-    Op0 = Op0.getOperand(1);
-
-    // See if have enough sign bits or zero bits in the scalar to use a
-    // widening add/sub by splatting to smaller element size.
-    unsigned EltBits = VT.getScalarSizeInBits();
-    unsigned ScalarBits = Op0.getValueSizeInBits();
-    // Make sure we're getting all element bits from the scalar register.
-    // FIXME: Support implicit sign extension of vmv.v.x?
-    if (ScalarBits < EltBits)
-      return SDValue();
+      if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+        SupportsSExt = true;
+      if (DAG.MaskedValueIsZero(Op,
+                                APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+        SupportsZExt = true;
+      break;
+    }
+    default:
+      break;
+    }
+  }
 
-    if (IsSigned) {
-      if (DAG.ComputeMaxSignificantBits(Op0) > NarrowSize)
-        return SDValue();
-    } else {
-      if (!DAG.MaskedValueIsZero(Op0,
-                                 APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
-        return SDValue();
+  /// Check if \p Root supports any extension folding combines.
+  static bool isSupportedRoot(const SDNode *Root) {
+    switch (Root->getOpcode()) {
+    case RISCVISD::ADD_VL:
+    case RISCVISD::MUL_VL:
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::SUB_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      return true;
+    default:
+      return false;
     }
+  }
 
-    Op0 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
-                      DAG.getUNDEF(NarrowVT), Op0, VL);
-    return DAG.getNode(VOpc, DL, VT, Op1, Op0, Merge, Mask, VL);
+  /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
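+  /// E.g., for Root == (add_vl (vsext_vl A), B), OperandIdx == 0 builds the
+  /// helper describing (vsext_vl A), and OperandIdx == 1 builds the one
+  /// describing B.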
+  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
+    assert(isSupportedRoot(Root) && "Trying to build a helper with an "
+                                    "unsupported root");
+    assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
+    OrigOperand = Root->getOperand(OperandIdx);
+
+    unsigned Opc = Root->getOpcode();
+    switch (Opc) {
+    // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
+    // (LHS, S|ZEXT(RHS)).
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      if (OperandIdx == 1) {
+        SupportsZExt =
+            Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
+        SupportsSExt = !SupportsZExt;
+        std::tie(Mask, VL) = getMaskAndVL(Root);
+        // There's no existing extension here, so we don't have to worry about
+        // making sure it gets removed.
+        EnforceOneUse = false;
+        break;
+      }
+      [[fallthrough]];
+    default:
+      fillUpExtensionSupport(Root, DAG);
+      break;
+    }
   }
 
-  return SDValue();
-}
+  /// Check if this operand is compatible with the given vector length \p VL.
+  bool isVLCompatible(SDValue VL) const { return this->VL && this->VL == VL; }
 
-// Try to form VWMUL, VWMULU or VWMULSU.
-// TODO: Support VWMULSU.vx with a sign extend Op and a splat of scalar Op.
-static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
-                                       bool Commute) {
-  assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode");
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  if (Commute)
-    std::swap(Op0, Op1);
-
-  bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
-  bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
-  bool IsVWMULSU = IsSignExt && Op1.getOpcode() == RISCVISD::VZEXT_VL;
-  if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse())
-    return SDValue();
+  /// Check if this operand is compatible with the given \p Mask.
+  bool isMaskCompatible(SDValue Mask) const {
+    return !CheckMask || (this->Mask && this->Mask == Mask);
+  }
 
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
+  /// Helper function to get the Mask and VL from \p Root.
+  static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
+    assert(isSupportedRoot(Root) && "Unexpected root");
+    return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+  }
 
-  // Make sure the mask and VL match.
-  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)
-    return SDValue();
+  /// Check if the Mask and VL of this operand are compatible with \p Root.
+  bool areVLAndMaskCompatible(const SDNode *Root) const {
+    auto [Mask, VL] = getMaskAndVL(Root);
+    return isMaskCompatible(Mask) && isVLCompatible(VL);
+  }
 
-  MVT VT = N->getSimpleValueType(0);
+  /// Helper function to check if \p N is commutative with respect to the
+  /// foldings that are supported by this class.
+  static bool isCommutative(const SDNode *N) {
+    switch (N->getOpcode()) {
+    case RISCVISD::ADD_VL:
+    case RISCVISD::MUL_VL:
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return true;
+    case RISCVISD::SUB_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      return false;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
 
-  // Determine the narrow size for a widening multiply.
-  unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-  MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                  VT.getVectorElementCount());
+  /// Get a list of combines to try for folding extensions in \p Root.
+  /// Note that each returned CombineToTry function doesn't actually modify
+  /// anything. Instead they produce an optional CombineResult that, if not
+  /// None, needs to be materialized for the combine to be applied.
+  /// \see CombineResult::materialize.
+  /// If the related CombineToTry function returns None, that means the combine
+  /// didn't match.
+  static SmallVector<CombineToTry> getSupportedFoldings(const SDNode *Root);
+};
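+
+// Illustrative flow: for N = (mul_vl (vsext_vl a), (vzext_vl b)), the LHS
+// helper reports SupportsSExt and the RHS helper reports SupportsZExt, so the
+// canFoldToVW_SU strategy below matches and produces a CombineResult that
+// materializes (vwmulsu_vl a, b).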
 
-  SDLoc DL(N);
+/// Helper structure that holds all the necessary information to materialize a
+/// combine that does some extension folding.
+struct CombineResult {
+  /// Opcode to be generated when materializing the combine.
+  unsigned TargetOpcode;
+  /// Extension opcode to be applied to the source of LHS when materializing
+  /// TargetOpcode.
+  /// \see NodeExtensionHelper::getSource().
+  Optional<unsigned> LHSExtOpc;
+  /// Extension opcode to be applied to the source of RHS when materializing
+  /// TargetOpcode.
+  Optional<unsigned> RHSExtOpc;
+  /// Root of the combine.
+  SDNode *Root;
+  /// LHS of the TargetOpcode.
+  const NodeExtensionHelper &LHS;
+  /// RHS of the TargetOpcode.
+  const NodeExtensionHelper &RHS;
+
+  CombineResult(unsigned TargetOpcode, SDNode *Root,
+                const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
+                const NodeExtensionHelper &RHS, Optional<bool> SExtRHS)
+      : TargetOpcode(TargetOpcode), Root(Root), LHS(LHS), RHS(RHS) {
+    MVT NarrowVT = NodeExtensionHelper::getNarrowType(Root);
+    if (SExtLHS && LHS.getSource().getValueType() != NarrowVT)
+      LHSExtOpc = *SExtLHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+    if (SExtRHS && RHS.getSource().getValueType() != NarrowVT)
+      RHSExtOpc = *SExtRHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+  }
+
+  /// Return a value that uses TargetOpcode and that can be used to replace
+  /// Root.
+  /// The actual replacement is *not* done by this method.
+  SDValue materialize(SelectionDAG &DAG) const {
+    SDValue Mask, VL, Merge;
+    std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
+    Merge = Root->getOperand(2);
+    return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
+                       LHS.getOrCreateExtendedOp(Root, DAG, LHSExtOpc),
+                       RHS.getOrCreateExtendedOp(Root, DAG, RHSExtOpc), Merge,
+                       Mask, VL);
+  }
+};
 
-  // See if the other operand is the same opcode.
-  if (IsVWMULSU || Op0.getOpcode() == Op1.getOpcode()) {
-    if (!Op1.hasOneUse())
-      return SDValue();
+/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
+/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
+/// are zext) and LHS and RHS can be folded into Root.
+/// AllowSExt and AllowZExt define which form `ext` can take in this pattern.
+///
+/// \note If the pattern can match with both zext and sext, the returned
+/// CombineResult will feature the zext result.
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
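+/// E.g., (add_vl (vzext_vl a), (vzext_vl b)) matches when AllowZExt is true
+/// and yields a CombineResult that materializes (vwaddu_vl a, b).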
+static Optional<CombineResult>
+canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
+                                 const NodeExtensionHelper &RHS, bool AllowSExt,
+                                 bool AllowZExt) {
+  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+    return None;
+  if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
+    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+                             Root->getOpcode(), /*IsSExt=*/false),
+                         Root, LHS, /*SExtLHS=*/false, RHS,
+                         /*SExtRHS=*/false);
+  if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
+    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+                             Root->getOpcode(), /*IsSExt=*/true),
+                         Root, LHS, /*SExtLHS=*/true, RHS,
+                         /*SExtRHS=*/true);
+  return None;
+}
 
-  // Make sure the mask and VL match.
-  if (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
-    return SDValue();
+/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
+/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
+/// are zext) and LHS and RHS can be folded into Root.
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
+static Optional<CombineResult>
+canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
+                             const NodeExtensionHelper &RHS) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
+                                          /*AllowZExt=*/true);
+}
 
-    Op1 = Op1.getOperand(0);
-  } else if (Op1.getOpcode() == RISCVISD::VMV_V_X_VL) {
-    // The operand is a splat of a scalar.
+/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
+static Optional<CombineResult> canFoldToVW_W(SDNode *Root,
+                                             const NodeExtensionHelper &LHS,
+                                             const NodeExtensionHelper &RHS) {
+  if (!RHS.areVLAndMaskCompatible(Root))
+    return None;
 
-    // The pasthru must be undef for tail agnostic
-    if (!Op1.getOperand(0).isUndef())
-      return SDValue();
-    // The VL must be the same.
-    if (Op1.getOperand(2) != VL)
-      return SDValue();
+  // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
+  // sext/zext?
+  // Control this behavior behind an option (AllowSplatInVW_W) for testing
+  // purposes.
+  if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
+    return CombineResult(
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
+        Root, LHS, /*SExtLHS=*/None, RHS, /*SExtRHS=*/false);
+  if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
+    return CombineResult(
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
+        Root, LHS, /*SExtLHS=*/None, RHS, /*SExtRHS=*/true);
+  return None;
+}
 
-    // Get the scalar value.
-    Op1 = Op1.getOperand(1);
+/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
+static Optional<CombineResult>
+canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                    const NodeExtensionHelper &RHS) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
+                                          /*AllowZExt=*/false);
+}
 
-    // See if have enough sign bits or zero bits in the scalar to use a
-    // widening multiply by splatting to smaller element size.
-    unsigned EltBits = VT.getScalarSizeInBits();
-    unsigned ScalarBits = Op1.getValueSizeInBits();
-    // Make sure we're getting all element bits from the scalar register.
-    // FIXME: Support implicit sign extension of vmv.v.x?
-    if (ScalarBits < EltBits)
-      return SDValue();
+/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
+static Optional<CombineResult>
+canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                    const NodeExtensionHelper &RHS) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+                                          /*AllowZExt=*/true);
+}
 
-    // If the LHS is a sign extend, try to use vwmul.
-    if (IsSignExt && DAG.ComputeMaxSignificantBits(Op1) <= NarrowSize) {
-      // Can use vwmul.
-    } else if (DAG.MaskedValueIsZero(
-                   Op1, APInt::getBitsSetFrom(ScalarBits, NarrowSize))) {
-      // Scalar is zero extended, if the vector is sign extended we can use
-      // vwmulsu. If the vector is zero extended we can use vwmulu.
-      IsVWMULSU = IsSignExt;
-    } else
-      return SDValue();
+/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
+///
+/// \returns None if the pattern doesn't match or a CombineResult that can be
+/// used to apply the pattern.
+static Optional<CombineResult> canFoldToVW_SU(SDNode *Root,
+                                              const NodeExtensionHelper &LHS,
+                                              const NodeExtensionHelper &RHS) {
+  if (!LHS.SupportsSExt || !RHS.SupportsZExt)
+    return None;
+  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+    return None;
+  return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
+                       Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+}
 
+SmallVector<NodeExtensionHelper::CombineToTry>
+NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
+  SmallVector<CombineToTry> Strategies;
+  switch (Root->getOpcode()) {
+  case RISCVISD::ADD_VL:
+  case RISCVISD::SUB_VL:
+    // add|sub -> vwadd(u)|vwsub(u)
+    Strategies.push_back(canFoldToVWWithSameExtension);
+    // add|sub -> vwadd(u)_w|vwsub(u)_w
+    Strategies.push_back(canFoldToVW_W);
+    break;
+  case RISCVISD::MUL_VL:
+    // mul -> vwmul(u)
+    Strategies.push_back(canFoldToVWWithSameExtension);
+    // mul -> vwmulsu
+    Strategies.push_back(canFoldToVW_SU);
+    break;
+  case RISCVISD::VWADD_W_VL:
+  case RISCVISD::VWSUB_W_VL:
+    // vwadd_w|vwsub_w -> vwadd|vwsub
+    Strategies.push_back(canFoldToVWWithSEXT);
+    break;
+  case RISCVISD::VWADDU_W_VL:
+  case RISCVISD::VWSUBU_W_VL:
+    // vwaddu_w|vwsubu_w -> vwaddu|vwsubu
+    Strategies.push_back(canFoldToVWWithZEXT);
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+  return Strategies;
+}
+} // End anonymous namespace.
+
+/// Combine a binary operation to its equivalent VW or VW_W form.
+/// The supported combines are:
+/// add_vl -> vwadd(u) | vwadd(u)_w
+/// sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul_vl -> vwmul(u) | vwmul_su
+/// vwadd_w(u) -> vwadd(u)
+/// vwsub_w(u) -> vwsub(u)
+static SDValue
+combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+
+  assert(NodeExtensionHelper::isSupportedRoot(N) &&
+         "Shouldn't have called this method");
+
+  NodeExtensionHelper LHS(N, 0, DAG);
+  NodeExtensionHelper RHS(N, 1, DAG);
+
+  if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
     return SDValue();
-  Op0 = Op0.getOperand(0);
+  if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
+    return SDValue();
 
-  // Re-introduce narrower extends if needed.
-  unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-  if (Op0.getValueType() != NarrowVT)
-    Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
-  // vwmulsu requires second operand to be zero extended.
-  ExtOpc = IsVWMULSU ? RISCVISD::VZEXT_VL : ExtOpc;
-  if (Op1.getValueType() != NarrowVT)
-    Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
+  SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
+      NodeExtensionHelper::getSupportedFoldings(N);
 
-  unsigned WMulOpc = RISCVISD::VWMULSU_VL;
-  if (!IsVWMULSU)
-    WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
-  return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Merge, Mask, VL);
+  assert(!FoldingStrategies.empty() && "Nothing to be folded");
+  for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
+       ++Attempt) {
+    for (NodeExtensionHelper::CombineToTry FoldingStrategy :
+         FoldingStrategies) {
+      Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+      if (Res)
+        return Res->materialize(DAG);
+    }
+    std::swap(LHS, RHS);
+  }
+  return SDValue();
 }
 
 // Fold
@@ -9232,21 +9559,13 @@
     break;
   }
   case RISCVISD::ADD_VL:
-    if (SDValue V = combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ false))
-      return V;
-    return combineADDSUB_VLToVWADDSUB_VL(N, DAG, /*Commute*/ true);
   case RISCVISD::SUB_VL:
-    return combineADDSUB_VLToVWADDSUB_VL(N, DAG);
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
  case RISCVISD::VWSUB_W_VL:
   case RISCVISD::VWSUBU_W_VL:
-    return combineVWADD_W_VL_VWSUB_W_VL(N, DAG);
   case RISCVISD::MUL_VL:
-    if (SDValue V = combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ false))
-      return V;
-    // Mul is commutative.
-    return combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ true);
+    return combineBinOp_VLToVWBinOp_VL(N, DCI);
   case RISCVISD::VFMADD_VL:
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFMSUB_VL:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
@@ -18,6 +18,32 @@
   ret <2 x i16> %e
 }
 
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8>* %z) {
+; CHECK-LABEL: vwmul_v2i16_multiple_users:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vsext.vf2 v11, v8
+; CHECK-NEXT:    vsext.vf2 v8, v9
+; CHECK-NEXT:    vsext.vf2 v9, v10
+; CHECK-NEXT:    vmul.vv v8, v11, v8
+; CHECK-NEXT:    vmul.vv v9, v11, v9
+; CHECK-NEXT:    vor.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = mul <2 x i16> %c, %d2
+  %g = or <2 x i16> %e, %f
+  ret <2 x i16> %g
+}
+
 define <4 x i16> @vwmul_v4i16(<4 x i8>* %x, <4 x i8>* %y) {
 ; CHECK-LABEL: vwmul_v4i16:
 ; CHECK:       # %bb.0:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmulsu.ll
@@ -701,11 +701,10 @@
 define <8 x i16> @vwmulsu_vx_v8i16_i8_swap(<8 x i8>* %x, i8* %y) {
 ; CHECK-LABEL: vwmulsu_vx_v8i16_i8_swap:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    lb a0, 0(a1)
-; CHECK-NEXT:    vzext.vf2 v9, v8
-; CHECK-NEXT:    vmul.vx v8, v9, a0
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
+; CHECK-NEXT:    vle8.v v9, (a0)
+; CHECK-NEXT:    vlse8.v v10, (a1), zero
+; CHECK-NEXT:    vwmulsu.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <8 x i8>, <8 x i8>* %x
   %b = load i8, i8* %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsub.ll
@@ -647,12 +647,10 @@
 define <8 x i16> @vwsub_vx_v8i16_i8(<8 x i8>* %x, i8* %y) {
 ; CHECK-LABEL: vwsub_vx_v8i16_i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT:    lb a1, 0(a1)
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
 ; CHECK-NEXT:    vle8.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e8, mf2, ta, mu
-; CHECK-NEXT:    vwsub.wv v8, v8, v9
+; CHECK-NEXT:    vlse8.v v10, (a1), zero
+; CHECK-NEXT:    vwsub.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <8 x i8>, <8 x i8>* %x
   %b = load i8, i8* %y
@@ -684,12 +682,11 @@
 define <4 x i32> @vwsub_vx_v4i32_i8(<4 x i16>* %x, i8* %y) {
 ; CHECK-LABEL: vwsub_vx_v4i32_i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
 ; CHECK-NEXT:    lb a1, 0(a1)
 ; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vwsub.wv v8, v8, v9
+; CHECK-NEXT:    vmv.v.x v10, a1
+; CHECK-NEXT:    vwsub.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i8, i8* %y
@@ -704,12 +701,10 @@
 define <4 x i32> @vwsub_vx_v4i32_i16(<4 x i16>* %x, i16* %y) {
 ; CHECK-LABEL: vwsub_vx_v4i32_i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
-; CHECK-NEXT:    lh a1, 0(a1)
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
 ; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vwsub.wv v8, v8, v9
+; CHECK-NEXT:    vlse16.v v10, (a1), zero
+; CHECK-NEXT:    vwsub.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i16, i16* %y
@@ -756,12 +751,11 @@
 ;
 ; RV64-LABEL: vwsub_vx_v2i64_i8:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    lb a1, 0(a1)
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsub.wv v8, v8, v9
+; RV64-NEXT:    vmv.v.x v10, a1
+; RV64-NEXT:    vwsub.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i8, i8* %y
@@ -791,12 +785,11 @@
 ;
 ; RV64-LABEL: vwsub_vx_v2i64_i16:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    lh a1, 0(a1)
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsub.wv v8, v8, v9
+; RV64-NEXT:    vmv.v.x v10, a1
+; RV64-NEXT:    vwsub.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i16, i16* %y
@@ -826,12 +819,10 @@
 ;
 ; RV64-LABEL: vwsub_vx_v2i64_i32:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
-; RV64-NEXT:    lw a1, 0(a1)
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsub.wv v8, v8, v9
+; RV64-NEXT:    vlse32.v v10, (a1), zero
+; RV64-NEXT:    vwsub.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i32, i32* %y
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll
@@ -647,12 +647,10 @@
 define <8 x i16> @vwsubu_vx_v8i16_i8(<8 x i8>* %x, i8* %y) {
 ; CHECK-LABEL: vwsubu_vx_v8i16_i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT:    lbu a1, 0(a1)
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, mu
 ; CHECK-NEXT:    vle8.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e8, mf2, ta, mu
-; CHECK-NEXT:    vwsubu.wv v8, v8, v9
+; CHECK-NEXT:    vlse8.v v10, (a1), zero
+; CHECK-NEXT:    vwsubu.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <8 x i8>, <8 x i8>* %x
   %b = load i8, i8* %y
@@ -684,12 +682,11 @@
 define <4 x i32> @vwsubu_vx_v4i32_i8(<4 x i16>* %x, i8* %y) {
 ; CHECK-LABEL: vwsubu_vx_v4i32_i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
 ; CHECK-NEXT:    lbu a1, 0(a1)
 ; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vwsubu.wv v8, v8, v9
+; CHECK-NEXT:    vmv.v.x v10, a1
+; CHECK-NEXT:    vwsubu.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i8, i8* %y
@@ -704,12 +701,10 @@
 define <4 x i32> @vwsubu_vx_v4i32_i16(<4 x i16>* %x, i16* %y) {
 ; CHECK-LABEL: vwsubu_vx_v4i32_i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
-; CHECK-NEXT:    lhu a1, 0(a1)
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
 ; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vmv.v.x v8, a1
-; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vwsubu.wv v8, v8, v9
+; CHECK-NEXT:    vlse16.v v10, (a1), zero
+; CHECK-NEXT:    vwsubu.vv v8, v10, v9
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, <4 x i16>* %x
   %b = load i16, i16* %y
@@ -755,12 +750,11 @@
 ;
 ; RV64-LABEL: vwsubu_vx_v2i64_i8:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    lbu a1, 0(a1)
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsubu.wv v8, v8, v9
+; RV64-NEXT:    vmv.v.x v10, a1
+; RV64-NEXT:    vwsubu.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i8, i8* %y
@@ -789,12 +783,11 @@
 ;
 ; RV64-LABEL: vwsubu_vx_v2i64_i16:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    lhu a1, 0(a1)
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsubu.wv v8, v8, v9
+; RV64-NEXT:    vmv.v.x v10, a1
+; RV64-NEXT:    vwsubu.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i16, i16* %y
@@ -823,12 +816,10 @@
 ;
 ; RV64-LABEL: vwsubu_vx_v2i64_i32:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, mu
-; RV64-NEXT:    lwu a1, 0(a1)
+; RV64-NEXT:    vsetivli zero, 2, e32, mf2, ta, mu
 ; RV64-NEXT:    vle32.v v9, (a0)
-; RV64-NEXT:    vmv.v.x v8, a1
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, mu
-; RV64-NEXT:    vwsubu.wv v8, v8, v9
+; RV64-NEXT:    vlse32.v v10, (a1), zero
+; RV64-NEXT:    vwsubu.vv v8, v10, v9
 ; RV64-NEXT:    ret
   %a = load <2 x i32>, <2 x i32>* %x
   %b = load i32, i32* %y
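
As a worked illustration of the strategy table (a hypothetical example, not part of the patch; function and value names are invented): for the IR below, the root is RISCVISD::ADD_VL, both NodeExtensionHelper operands report SupportsSExt, and canFoldToVWWithSameExtension returns a CombineResult whose materialize() builds RISCVISD::VWADD_VL, i.e. a vwadd.vv over the narrow <4 x i16> operands.

define <4 x i32> @vwadd_vv_example(<4 x i16>* %x, <4 x i16>* %y) {
  %a = load <4 x i16>, <4 x i16>* %x
  %b = load <4 x i16>, <4 x i16>* %y
  %sa = sext <4 x i16> %a to <4 x i32>   ; LHS helper: SupportsSExt
  %sb = sext <4 x i16> %b to <4 x i32>   ; RHS helper: SupportsSExt
  %c = add <4 x i32> %sa, %sb            ; ADD_VL root -> VWADD_VL
  ret <4 x i32> %c
}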