diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -75,7 +75,17 @@ SDValue LegalizeOp(SDValue Op); /// Assuming the node is legal, "legalize" the results. - SDValue TranslateLegalizeResults(SDValue Op, SDValue Result); + SDValue TranslateLegalizeResults(SDValue Op, SDNode *Result); + + /// Make sure Results are legal and update the translation cache. + SDValue RecursivelyLegalizeResults(SDValue Op, + MutableArrayRef<SDValue> Results); + + /// Wrapper to interface LowerOperation with a vector of Results. + /// Returns false if the target wants to use default expansion. Otherwise + /// returns true. If return is true and the Results are empty, then the + /// target wants to keep the input node as is. + bool LowerOperationWrapper(SDNode *N, SmallVectorImpl<SDValue> &Results); /// Implements unrolling a VSETCC. SDValue UnrollVSETCC(SDValue Op); @@ -84,15 +94,15 @@ /// /// This is just a high-level routine to dispatch to specific code paths for /// operations to legalize them. - SDValue Expand(SDValue Op); + void Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results); /// Implements expansion for FP_TO_UINT; falls back to UnrollVectorOp if /// FP_TO_SINT isn't legal. - SDValue ExpandFP_TO_UINT(SDValue Op); + void ExpandFP_TO_UINT(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements expansion for UINT_TO_FLOAT; falls back to UnrollVectorOp if /// SINT_TO_FLOAT and SHR on vectors isn't legal. - SDValue ExpandUINT_TO_FLOAT(SDValue Op); + void ExpandUINT_TO_FLOAT(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implement expansion for SIGN_EXTEND_INREG using SRL and SRA. SDValue ExpandSEXTINREG(SDValue Op); @@ -130,8 +140,8 @@ /// supported by the target.
SDValue ExpandVSELECT(SDValue Op); SDValue ExpandSELECT(SDValue Op); - std::pair<SDValue, SDValue> ExpandLoad(SDValue Op); - SDValue ExpandStore(SDValue Op); + std::pair<SDValue, SDValue> ExpandLoad(SDNode *N); + SDValue ExpandStore(SDNode *N); SDValue ExpandFNEG(SDValue Op); SDValue ExpandFSUB(SDValue Op); SDValue ExpandBITREVERSE(SDValue Op); @@ -141,32 +151,33 @@ SDValue ExpandFunnelShift(SDValue Op); SDValue ExpandROT(SDValue Op); SDValue ExpandFMINNUM_FMAXNUM(SDValue Op); - SDValue ExpandUADDSUBO(SDValue Op); - SDValue ExpandSADDSUBO(SDValue Op); - SDValue ExpandMULO(SDValue Op); + void ExpandUADDSUBO(SDValue Op, SmallVectorImpl<SDValue> &Results); + void ExpandSADDSUBO(SDValue Op, SmallVectorImpl<SDValue> &Results); + void ExpandMULO(SDValue Op, SmallVectorImpl<SDValue> &Results); SDValue ExpandAddSubSat(SDValue Op); SDValue ExpandFixedPointMul(SDValue Op); SDValue ExpandFixedPointDiv(SDValue Op); SDValue ExpandStrictFPOp(SDValue Op); + void ExpandStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results); - SDValue UnrollStrictFPOp(SDValue Op); + void UnrollStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements vector promotion. /// /// This is essentially just bitcasting the operands to a different type and /// bitcasting the result back to the original type. - SDValue Promote(SDValue Op); + void Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results); /// Implements [SU]INT_TO_FP vector promotion. /// /// This is a [zs]ext of the input operand to a larger integer type. - SDValue PromoteINT_TO_FP(SDValue Op); + void PromoteINT_TO_FP(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements FP_TO_[SU]INT vector promotion of the result type. /// /// It is promoted to a larger integer type. The result is then /// truncated back to the original type.
- SDValue PromoteFP_TO_INT(SDValue Op); + void PromoteFP_TO_INT(SDValue Op, SmallVectorImpl<SDValue> &Results); public: VectorLegalizer(SelectionDAG& dag) : @@ -222,11 +233,27 @@ return Changed; } -SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDValue Result) { +SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDNode *Result) { + assert(Op->getNumValues() == Result->getNumValues() && + "Unexpected number of results"); // Generic legalization: just pass the operand through. - for (unsigned i = 0, e = Op.getNode()->getNumValues(); i != e; ++i) - AddLegalizedOperand(Op.getValue(i), Result.getValue(i)); - return Result.getValue(Op.getResNo()); + for (unsigned i = 0, e = Op->getNumValues(); i != e; ++i) + AddLegalizedOperand(Op.getValue(i), SDValue(Result, i)); + return SDValue(Result, Op.getResNo()); +} + +SDValue +VectorLegalizer::RecursivelyLegalizeResults(SDValue Op, + MutableArrayRef<SDValue> Results) { + assert(Results.size() == Op->getNumValues() && + "Unexpected number of results"); + // Make sure that the generated code is itself legal.
+ for (unsigned i = 0, e = Results.size(); i != e; ++i) { + Results[i] = LegalizeOp(Results[i]); + AddLegalizedOperand(Op.getValue(i), Results[i]); + } + + return Results[Op.getResNo()]; } SDValue VectorLegalizer::LegalizeOp(SDValue Op) { @@ -235,18 +262,15 @@ DenseMap<SDValue, SDValue>::iterator I = LegalizedNodes.find(Op); if (I != LegalizedNodes.end()) return I->second; - SDNode* Node = Op.getNode(); - // Legalize the operands SmallVector<SDValue, 8> Ops; - for (const SDValue &Op : Node->op_values()) - Ops.push_back(LegalizeOp(Op)); + for (const SDValue &Oper : Op->op_values()) + Ops.push_back(LegalizeOp(Oper)); - SDValue Result = SDValue(DAG.UpdateNodeOperands(Op.getNode(), Ops), - Op.getResNo()); + SDNode *Node = DAG.UpdateNodeOperands(Op.getNode(), Ops); if (Op.getOpcode() == ISD::LOAD) { - LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); + LoadSDNode *LD = cast<LoadSDNode>(Node); ISD::LoadExtType ExtType = LD->getExtensionType(); if (LD->getMemoryVT().isVector() && ExtType != ISD::NON_EXTLOAD) { LLVM_DEBUG(dbgs() << "\nLegalizing extending vector load: "; @@ -255,22 +279,21 @@ LD->getMemoryVT())) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); - case TargetLowering::Custom: - if (SDValue Lowered = TLI.LowerOperation(Result, DAG)) { - assert(Lowered->getNumValues() == Op->getNumValues() && - "Unexpected number of results"); - if (Lowered != Result) { - // Make sure the new code is also legal.
- Lowered = LegalizeOp(Lowered); - Changed = true; - } - return TranslateLegalizeResults(Op, Lowered); + return TranslateLegalizeResults(Op, Node); + case TargetLowering::Custom: { + SmallVector<SDValue, 2> ResultVals; + if (LowerOperationWrapper(Node, ResultVals)) { + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); + + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); } LLVM_FALLTHROUGH; + } case TargetLowering::Expand: { Changed = true; - std::pair<SDValue, SDValue> Tmp = ExpandLoad(Result); + std::pair<SDValue, SDValue> Tmp = ExpandLoad(Node); AddLegalizedOperand(Op.getValue(0), Tmp.first); AddLegalizedOperand(Op.getValue(1), Tmp.second); - return Op.getResNo() ? Tmp.first : Tmp.second; + return Op.getResNo() ? Tmp.second : Tmp.first; @@ -278,7 +301,7 @@ } } } else if (Op.getOpcode() == ISD::STORE) { - StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); + StoreSDNode *ST = cast<StoreSDNode>(Node); EVT StVT = ST->getMemoryVT(); MVT ValVT = ST->getValue().getSimpleValueType(); if (StVT.isVector() && ST->isTruncatingStore()) { @@ -287,19 +310,21 @@ switch (TLI.getTruncStoreAction(ValVT, StVT)) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); case TargetLowering::Custom: { - SDValue Lowered = TLI.LowerOperation(Result, DAG); - if (Lowered != Result) { - // Make sure the new code is also legal.
- Lowered = LegalizeOp(Lowered); + SmallVector<SDValue, 2> ResultVals; + if (LowerOperationWrapper(Node, ResultVals)) { + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); } - return TranslateLegalizeResults(Op, Lowered); + LLVM_FALLTHROUGH; } case TargetLowering::Expand: { Changed = true; - SDValue Chain = ExpandStore(Result); + SDValue Chain = ExpandStore(Node); AddLegalizedOperand(Op, Chain); return Chain; } @@ -310,17 +335,17 @@ bool HasVectorValueOrOp = false; for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J) HasVectorValueOrOp |= J->isVector(); - for (const SDValue &Op : Node->op_values()) - HasVectorValueOrOp |= Op.getValueType().isVector(); + for (const SDValue &Oper : Node->op_values()) + HasVectorValueOrOp |= Oper.getValueType().isVector(); if (!HasVectorValueOrOp) - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); TargetLowering::LegalizeAction Action = TargetLowering::Legal; EVT ValVT; switch (Op.getOpcode()) { default: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" @@ -473,42 +498,70 @@ LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG)); + SmallVector<SDValue, 2> ResultVals; switch (Action) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Promote: - Result = Promote(Op); - Changed = true; + LLVM_DEBUG(dbgs() << "Promoting\n"); + Promote(Node, ResultVals); + assert(!ResultVals.empty() && "No results for promotion?"); break; case TargetLowering::Legal: LLVM_DEBUG(dbgs() << "Legal node: nothing to do\n"); break; - case TargetLowering::Custom: { + case TargetLowering::Custom: LLVM_DEBUG(dbgs() << "Trying custom legalization\n"); - if (SDValue Tmp1 = TLI.LowerOperation(Op, DAG)) { - LLVM_DEBUG(dbgs() <<
"Successfully custom legalized node\n"); - Result = Tmp1; + if (LowerOperationWrapper(Node, ResultVals)) break; - } LLVM_DEBUG(dbgs() << "Could not custom legalize node\n"); LLVM_FALLTHROUGH; - } case TargetLowering::Expand: - Result = Expand(Op); + LLVM_DEBUG(dbgs() << "Expanding\n"); + Expand(Node, ResultVals); + break; } - // Make sure that the generated code is itself legal. - if (Result != Op) { - Result = LegalizeOp(Result); - Changed = true; + if (ResultVals.empty()) + return TranslateLegalizeResults(Op, Node); + + Changed = true; + return RecursivelyLegalizeResults(Op, ResultVals); +} + +// FIXME: This is very similar to the X86 override of +// TargetLowering::LowerOperationWrapper. Can we merge them somehow? +bool VectorLegalizer::LowerOperationWrapper(SDNode *Node, + SmallVectorImpl<SDValue> &Results) { + SDValue Res = TLI.LowerOperation(SDValue(Node, 0), DAG); + + if (!Res.getNode()) + return false; + + if (Res == SDValue(Node, 0)) + return true; + + // If the original node has one result, take the return value from + // LowerOperation as is. It might not be result number 0. + if (Node->getNumValues() == 1) { + Results.push_back(Res); + return true; } - // Note that LegalizeOp may be reentered even from single-use nodes, which - // means that we always must cache transformed nodes. - AddLegalizedOperand(Op, Result); - return Result; + // If the original node has multiple results, then the return node should + // have the same number of results. + assert((Node->getNumValues() == Res->getNumValues()) && + "Lowering returned the wrong number of results!"); + + // Place the new result values in the same order as the original node's results. + for (unsigned I = 0, E = Node->getNumValues(); I != E; ++I) + Results.push_back(Res.getValue(I)); + + return true; } -SDValue VectorLegalizer::Promote(SDValue Op) { +void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) { + SDValue Op(Node, 0); // FIXME: Use Node throughout.
+ // For a few operations there is a specific concept for promotion based on // the operand's type. switch (Op.getOpcode()) { @@ -517,13 +570,15 @@ case ISD::STRICT_SINT_TO_FP: case ISD::STRICT_UINT_TO_FP: // "Promote" the operation by extending the operand. - return PromoteINT_TO_FP(Op); + PromoteINT_TO_FP(Op, Results); + return; case ISD::FP_TO_UINT: case ISD::FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: case ISD::STRICT_FP_TO_SINT: // Promote the operation by extending the operand. - return PromoteFP_TO_INT(Op); + PromoteFP_TO_INT(Op, Results); + return; case ISD::FP_ROUND: case ISD::FP_EXTEND: // These operations are used to do promotion so they can't be promoted @@ -558,15 +613,20 @@ } Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands, Op.getNode()->getFlags()); + + SDValue Res; if ((VT.isFloatingPoint() && NVT.isFloatingPoint()) || (VT.isVector() && VT.getVectorElementType().isFloatingPoint() && NVT.isVector() && NVT.getVectorElementType().isFloatingPoint())) - return DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); + Res = DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); else - return DAG.getNode(ISD::BITCAST, dl, VT, Op); + Res = DAG.getNode(ISD::BITCAST, dl, VT, Op); + + Results.push_back(Res); } -SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) { +void VectorLegalizer::PromoteINT_TO_FP(SDValue Op, + SmallVectorImpl<SDValue> &Results) { // INT_TO_FP operations may require the input operand be promoted even // when the type is otherwise legal.
bool IsStrict = Op->isStrictFPOpcode(); @@ -589,18 +649,24 @@ Operands[j] = Op.getOperand(j); } - if (IsStrict) - return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other}, - Operands); + if (IsStrict) { + SDValue Res = DAG.getNode(Op.getOpcode(), dl, + {Op.getValueType(), MVT::Other}, Operands); + Results.push_back(Res); + Results.push_back(Res.getValue(1)); + return; + } - return DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + SDValue Res = DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + Results.push_back(Res); } // For FP_TO_INT we promote the result type to a vector type with wider // elements and then truncate the result. This is different from the default // PromoteVector which uses bitcast to promote thus assumning that the // promoted vector type has the same overall size. -SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) { +void VectorLegalizer::PromoteFP_TO_INT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { MVT VT = Op.getSimpleValueType(); MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT); bool IsStrict = Op->isStrictFPOpcode(); @@ -639,14 +705,13 @@ Promoted = DAG.getNode(NewOpc, dl, NVT, Promoted, DAG.getValueType(VT.getScalarType())); Promoted = DAG.getNode(ISD::TRUNCATE, dl, VT, Promoted); + Results.push_back(Promoted); if (IsStrict) - return DAG.getMergeValues({Promoted, Chain}, dl); - - return Promoted; + Results.push_back(Chain); } -std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) { - LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); +std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDNode *N) { + LoadSDNode *LD = cast<LoadSDNode>(N); EVT SrcVT = LD->getMemoryVT(); EVT SrcEltVT = SrcVT.getScalarType(); @@ -655,7 +720,7 @@ SDValue NewChain; SDValue Value; if (SrcVT.getVectorNumElements() > 1 && !SrcEltVT.isByteSized()) { - SDLoc dl(Op); + SDLoc dl(N); SmallVector<SDValue, 8> Vals; SmallVector<SDValue, 8> LoadChains; @@ -767,7 +832,7 @@ } NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains); - Value =
DAG.getBuildVector(Op.getNode()->getValueType(0), dl, Vals); + Value = DAG.getBuildVector(N->getValueType(0), dl, Vals); } else { std::tie(Value, NewChain) = TLI.scalarizeVectorLoad(LD, DAG); } @@ -775,90 +840,122 @@ return std::make_pair(Value, NewChain); } -SDValue VectorLegalizer::ExpandStore(SDValue Op) { - StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); +SDValue VectorLegalizer::ExpandStore(SDNode *N) { + StoreSDNode *ST = cast<StoreSDNode>(N); SDValue TF = TLI.scalarizeVectorStore(ST, DAG); return TF; } -SDValue VectorLegalizer::Expand(SDValue Op) { +void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { + SDValue Op(Node, 0); // FIXME: Just pass Node to all the expanders. + switch (Op->getOpcode()) { case ISD::SIGN_EXTEND_INREG: - return ExpandSEXTINREG(Op); + Results.push_back(ExpandSEXTINREG(Op)); + return; case ISD::ANY_EXTEND_VECTOR_INREG: - return ExpandANY_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandANY_EXTEND_VECTOR_INREG(Op)); + return; case ISD::SIGN_EXTEND_VECTOR_INREG: - return ExpandSIGN_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandSIGN_EXTEND_VECTOR_INREG(Op)); + return; case ISD::ZERO_EXTEND_VECTOR_INREG: - return ExpandZERO_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandZERO_EXTEND_VECTOR_INREG(Op)); + return; case ISD::BSWAP: - return ExpandBSWAP(Op); + Results.push_back(ExpandBSWAP(Op)); + return; case ISD::VSELECT: - return ExpandVSELECT(Op); + Results.push_back(ExpandVSELECT(Op)); + return; case ISD::SELECT: - return ExpandSELECT(Op); + Results.push_back(ExpandSELECT(Op)); + return; case ISD::FP_TO_UINT: - return ExpandFP_TO_UINT(Op); + ExpandFP_TO_UINT(Op, Results); + return; case ISD::UINT_TO_FP: - return ExpandUINT_TO_FLOAT(Op); + ExpandUINT_TO_FLOAT(Op, Results); + return; case ISD::FNEG: - return ExpandFNEG(Op); + Results.push_back(ExpandFNEG(Op)); + return; case ISD::FSUB: - return ExpandFSUB(Op); + if (SDValue Tmp = ExpandFSUB(Op)) + Results.push_back(Tmp); + return; case ISD::SETCC: - return UnrollVSETCC(Op); +
Results.push_back(UnrollVSETCC(Op)); + return; case ISD::ABS: - return ExpandABS(Op); + Results.push_back(ExpandABS(Op)); + return; case ISD::BITREVERSE: - return ExpandBITREVERSE(Op); + if (SDValue Tmp = ExpandBITREVERSE(Op)) + Results.push_back(Tmp); + return; case ISD::CTPOP: - return ExpandCTPOP(Op); + Results.push_back(ExpandCTPOP(Op)); + return; case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: - return ExpandCTLZ(Op); + Results.push_back(ExpandCTLZ(Op)); + return; case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: - return ExpandCTTZ(Op); + Results.push_back(ExpandCTTZ(Op)); + return; case ISD::FSHL: case ISD::FSHR: - return ExpandFunnelShift(Op); + Results.push_back(ExpandFunnelShift(Op)); + return; case ISD::ROTL: case ISD::ROTR: - return ExpandROT(Op); + Results.push_back(ExpandROT(Op)); + return; case ISD::FMINNUM: case ISD::FMAXNUM: - return ExpandFMINNUM_FMAXNUM(Op); + Results.push_back(ExpandFMINNUM_FMAXNUM(Op)); + return; case ISD::UADDO: case ISD::USUBO: - return ExpandUADDSUBO(Op); + ExpandUADDSUBO(Op, Results); + return; case ISD::SADDO: case ISD::SSUBO: - return ExpandSADDSUBO(Op); + ExpandSADDSUBO(Op, Results); + return; case ISD::UMULO: case ISD::SMULO: - return ExpandMULO(Op); + ExpandMULO(Op, Results); + return; case ISD::USUBSAT: case ISD::SSUBSAT: case ISD::UADDSAT: case ISD::SADDSAT: - return ExpandAddSubSat(Op); + Results.push_back(ExpandAddSubSat(Op)); + return; case ISD::SMULFIX: case ISD::UMULFIX: - return ExpandFixedPointMul(Op); + Results.push_back(ExpandFixedPointMul(Op)); + return; case ISD::SMULFIXSAT: case ISD::UMULFIXSAT: // FIXME: We do not expand SMULFIXSAT/UMULFIXSAT here yet, not sure exactly // why. Maybe it results in worse codegen compared to the unroll for some // targets? This should probably be investigated. And if we still prefer to // unroll an explanation could be helpful.
- return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; case ISD::SDIVFIX: case ISD::UDIVFIX: - return ExpandFixedPointDiv(Op); + Results.push_back(ExpandFixedPointDiv(Op)); + return; #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" - return ExpandStrictFPOp(Op); + ExpandStrictFPOp(Op, Results); + return; case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_MUL: case ISD::VECREDUCE_AND: @@ -872,9 +969,11 @@ case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: - return TLI.expandVecReduce(Op.getNode(), DAG); + Results.push_back(TLI.expandVecReduce(Op.getNode(), DAG)); + return; default: - return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; } } @@ -1120,7 +1219,7 @@ return DAG.UnrollVectorOp(Op.getNode()); // Let LegalizeDAG handle this later. - return Op; + return SDValue(); } SDValue VectorLegalizer::ExpandVSELECT(SDValue Op) { @@ -1180,23 +1279,28 @@ return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandFP_TO_UINT(SDValue Op) { +void VectorLegalizer::ExpandFP_TO_UINT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { // Attempt to expand using TargetLowering. SDValue Result, Chain; if (TLI.expandFP_TO_UINT(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (Op->isStrictFPOpcode()) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Otherwise go ahead and unroll.
- if (Op->isStrictFPOpcode()) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (Op->isStrictFPOpcode()) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); } -SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { +void VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { bool IsStrict = Op.getNode()->isStrictFPOpcode(); unsigned OpNo = IsStrict ? 1 : 0; SDValue Src = Op.getOperand(OpNo); @@ -1207,10 +1311,10 @@ SDValue Result; SDValue Chain; if (TLI.expandUINT_TO_FP(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (IsStrict) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Make sure that the SINT_TO_FP and SRL instructions are available. @@ -1219,9 +1323,13 @@ (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, VT) == TargetLowering::Expand)) || TLI.getOperationAction(ISD::SRL, VT) == TargetLowering::Expand) { - if (IsStrict) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (IsStrict) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; } unsigned BW = VT.getScalarSizeInBits(); @@ -1261,9 +1369,9 @@ DAG.getNode(ISD::STRICT_FADD, DL, {Op.getValueType(), MVT::Other}, {SDValue(fLO.getNode(), 1), fHI, fLO}); - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), SDValue(Result.getNode(), 1)); - return Result; + Results.push_back(Result); + Results.push_back(Result.getValue(1)); + return; } // Convert hi and lo to floats @@ -1274,7 +1382,7 @@ SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, Op.getValueType(), LO); // Add the two halves - return DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO); + Results.push_back(DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO)); } SDValue VectorLegalizer::ExpandFNEG(SDValue Op) { @@ -1295,7
+1403,7 @@ EVT VT = Op.getValueType(); if (TLI.isOperationLegalOrCustom(ISD::FNEG, VT) && TLI.isOperationLegalOrCustom(ISD::FADD, VT)) - return Op; // Defer to LegalizeDAG + return SDValue(); // Defer to LegalizeDAG return DAG.UnrollVectorOp(Op.getNode()); } @@ -1346,44 +1454,30 @@ return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandUADDSUBO(SDValue Op) { +void VectorLegalizer::ExpandUADDSUBO(SDValue Op, + SmallVectorImpl<SDValue> &Results) { SDValue Result, Overflow; TLI.expandUADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } -SDValue VectorLegalizer::ExpandSADDSUBO(SDValue Op) { +void VectorLegalizer::ExpandSADDSUBO(SDValue Op, + SmallVectorImpl<SDValue> &Results) { SDValue Result, Overflow; TLI.expandSADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } -SDValue VectorLegalizer::ExpandMULO(SDValue Op) { +void VectorLegalizer::ExpandMULO(SDValue Op, + SmallVectorImpl<SDValue> &Results) { SDValue Result, Overflow; if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG)) std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode()); - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + Results.push_back(Result); + Results.push_back(Overflow); } SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) { @@ -1406,16 +1500,22 @@ return DAG.UnrollVectorOp(N); } -SDValue VectorLegalizer::ExpandStrictFPOp(SDValue Op) {
- if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) - return ExpandUINT_TO_FLOAT(Op); - if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) - return ExpandFP_TO_UINT(Op); +void VectorLegalizer::ExpandStrictFPOp(SDValue Op, + SmallVectorImpl<SDValue> &Results) { + if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) { + ExpandUINT_TO_FLOAT(Op, Results); + return; + } + if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) { + ExpandFP_TO_UINT(Op, Results); + return; + } - return UnrollStrictFPOp(Op); + UnrollStrictFPOp(Op, Results); } -SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) { +void VectorLegalizer::UnrollStrictFPOp(SDValue Op, + SmallVectorImpl<SDValue> &Results) { EVT VT = Op.getValue(0).getValueType(); EVT EltVT = VT.getVectorElementType(); unsigned NumElems = VT.getVectorNumElements(); @@ -1472,10 +1572,8 @@ SDValue Result = DAG.getBuildVector(VT, dl, OpValues); SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OpChains); - AddLegalizedOperand(Op.getValue(0), Result); - AddLegalizedOperand(Op.getValue(1), NewChain); - - return Op.getResNo() ? NewChain : Result; + Results.push_back(Result); + Results.push_back(NewChain); } SDValue VectorLegalizer::UnrollVSETCC(SDValue Op) { diff --git a/llvm/test/CodeGen/X86/avx512-cmp.ll b/llvm/test/CodeGen/X86/avx512-cmp.ll --- a/llvm/test/CodeGen/X86/avx512-cmp.ll +++ b/llvm/test/CodeGen/X86/avx512-cmp.ll @@ -181,3 +181,39 @@ if.end.i: ret i32 6 } + +; This test previously caused an infinite loop in legalize vector ops due to +; CSE triggering on the call to UpdateNodeOperands and the resulting node not +; being passed to LowerOperation. The add is needed to force the zext into a +; sext on that path. The shuffle keeps the zext alive. The xor somehow +; influences the zext to be visited before the sext exposing the CSE opportunity +; for the sext since zext of setcc is custom legalized to a sext and shift.
+define <8 x i32> @legalize_loop(<8 x double> %arg) { +; KNL-LABEL: legalize_loop: +; KNL: ## %bb.0: +; KNL-NEXT: vxorpd %xmm1, %xmm1, %xmm1 +; KNL-NEXT: vcmpnltpd %zmm0, %zmm1, %k1 +; KNL-NEXT: vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z} +; KNL-NEXT: vpsrld $31, %ymm0, %ymm1 +; KNL-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4] +; KNL-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1] +; KNL-NEXT: vpsubd %ymm0, %ymm1, %ymm0 +; KNL-NEXT: retq +; +; SKX-LABEL: legalize_loop: +; SKX: ## %bb.0: +; SKX-NEXT: vxorpd %xmm1, %xmm1, %xmm1 +; SKX-NEXT: vcmpnltpd %zmm0, %zmm1, %k0 +; SKX-NEXT: vpmovm2d %k0, %ymm0 +; SKX-NEXT: vpsrld $31, %ymm0, %ymm1 +; SKX-NEXT: vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4] +; SKX-NEXT: vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1] +; SKX-NEXT: vpsubd %ymm0, %ymm1, %ymm0 +; SKX-NEXT: retq + %tmp = fcmp ogt <8 x double> %arg, zeroinitializer + %tmp1 = xor <8 x i1> %tmp, <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true> + %tmp2 = zext <8 x i1> %tmp1 to <8 x i32> + %tmp3 = shufflevector <8 x i32> %tmp2, <8 x i32> undef, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0> + %tmp4 = add <8 x i32> %tmp2, %tmp3 + ret <8 x i32> %tmp4 +}