diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -75,7 +75,7 @@ SDValue LegalizeOp(SDValue Op); /// Assuming the node is legal, "legalize" the results. - SDValue TranslateLegalizeResults(SDValue Op, SDValue Result); + SDValue TranslateLegalizeResults(SDValue Op, SDNode *Result); /// Implements unrolling a VSETCC. SDValue UnrollVSETCC(SDValue Op); @@ -84,15 +84,16 @@ /// /// This is just a high-level routine to dispatch to specific code paths for /// operations to legalize them. - SDValue Expand(SDValue Op); + SDValue Expand(SDNode *Node, unsigned ResNo); /// Implements expansion for FP_TO_UINT; falls back to UnrollVectorOp if /// FP_TO_SINT isn't legal. - SDValue ExpandFP_TO_UINT(SDValue Op); + void ExpandFP_TO_UINT(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements expansion for UINT_TO_FLOAT; falls back to UnrollVectorOp if /// SINT_TO_FLOAT and SHR on vectors isn't legal. - SDValue ExpandUINT_TO_FLOAT(SDValue Op); + void ExpandUINT_TO_FLOAT(SDValue Op, + SmallVectorImpl<SDValue> &Results); /// Implement expansion for SIGN_EXTEND_INREG using SRL and SRA. SDValue ExpandSEXTINREG(SDValue Op); @@ -130,8 +131,8 @@ /// supported by the target.
SDValue ExpandVSELECT(SDValue Op); SDValue ExpandSELECT(SDValue Op); - std::pair<SDValue, SDValue> ExpandLoad(SDValue Op); - SDValue ExpandStore(SDValue Op); + std::pair<SDValue, SDValue> ExpandLoad(SDNode *N); + SDValue ExpandStore(SDNode *N); SDValue ExpandFNEG(SDValue Op); SDValue ExpandFSUB(SDValue Op); SDValue ExpandBITREVERSE(SDValue Op); @@ -141,31 +142,31 @@ SDValue ExpandFunnelShift(SDValue Op); SDValue ExpandROT(SDValue Op); SDValue ExpandFMINNUM_FMAXNUM(SDValue Op); - SDValue ExpandUADDSUBO(SDValue Op); - SDValue ExpandSADDSUBO(SDValue Op); - SDValue ExpandMULO(SDValue Op); + std::pair<SDValue, SDValue> ExpandUADDSUBO(SDValue Op); + std::pair<SDValue, SDValue> ExpandSADDSUBO(SDValue Op); + std::pair<SDValue, SDValue> ExpandMULO(SDValue Op); SDValue ExpandAddSubSat(SDValue Op); SDValue ExpandFixedPointMul(SDValue Op); - SDValue ExpandStrictFPOp(SDValue Op); + void ExpandStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results); - SDValue UnrollStrictFPOp(SDValue Op); + void UnrollStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements vector promotion. /// /// This is essentially just bitcasting the operands to a different type and /// bitcasting the result back to the original type. - SDValue Promote(SDValue Op); + SDValue Promote(SDNode *Node, unsigned ResNo); /// Implements [SU]INT_TO_FP vector promotion. /// /// This is a [zs]ext of the input operand to a larger integer type. - SDValue PromoteINT_TO_FP(SDValue Op); + void PromoteINT_TO_FP(SDValue Op, SmallVectorImpl<SDValue> &Results); /// Implements FP_TO_[SU]INT vector promotion of the result type. /// /// It is promoted to a larger integer type. The result is then /// truncated back to the original type.
- SDValue PromoteFP_TO_INT(SDValue Op); + void PromoteFP_TO_INT(SDValue Op, SmallVectorImpl<SDValue> &Results); public: VectorLegalizer(SelectionDAG& dag) : @@ -221,11 +222,11 @@ return Changed; } -SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDValue Result) { +SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDNode *Result) { // Generic legalization: just pass the operand through. for (unsigned i = 0, e = Op.getNode()->getNumValues(); i != e; ++i) - AddLegalizedOperand(Op.getValue(i), Result.getValue(i)); - return Result.getValue(Op.getResNo()); + AddLegalizedOperand(Op.getValue(i), SDValue(Result, i)); + return SDValue(Result, Op.getResNo()); } SDValue VectorLegalizer::LegalizeOp(SDValue Op) { @@ -234,18 +235,15 @@ DenseMap<SDValue, SDValue>::iterator I = LegalizedNodes.find(Op); if (I != LegalizedNodes.end()) return I->second; - SDNode* Node = Op.getNode(); - // Legalize the operands SmallVector<SDValue, 8> Ops; - for (const SDValue &Op : Node->op_values()) - Ops.push_back(LegalizeOp(Op)); + for (const SDValue &Oper : Op->op_values()) + Ops.push_back(LegalizeOp(Oper)); - SDValue Result = SDValue(DAG.UpdateNodeOperands(Op.getNode(), Ops), - Op.getResNo()); + SDNode *Node = DAG.UpdateNodeOperands(Op.getNode(), Ops); if (Op.getOpcode() == ISD::LOAD) { - LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); + LoadSDNode *LD = cast<LoadSDNode>(Node); ISD::LoadExtType ExtType = LD->getExtensionType(); if (LD->getMemoryVT().isVector() && ExtType != ISD::NON_EXTLOAD) { LLVM_DEBUG(dbgs() << "\nLegalizing extending vector load: "; @@ -254,22 +252,22 @@ LD->getMemoryVT())) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); case TargetLowering::Custom: - if (SDValue Lowered = TLI.LowerOperation(Result, DAG)) { + if (SDValue Lowered = TLI.LowerOperation(SDValue(Node, 0), DAG)) { assert(Lowered->getNumValues() == Op->getNumValues() && "Unexpected number of results"); - if (Lowered
!= Result) { + if (Lowered != SDValue(Node, 0)) { // Make sure the new code is also legal. Lowered = LegalizeOp(Lowered); Changed = true; } - return TranslateLegalizeResults(Op, Lowered); + return TranslateLegalizeResults(Op, Lowered.getNode()); } LLVM_FALLTHROUGH; case TargetLowering::Expand: { Changed = true; - std::pair<SDValue, SDValue> Tmp = ExpandLoad(Result); + std::pair<SDValue, SDValue> Tmp = ExpandLoad(Node); AddLegalizedOperand(Op.getValue(0), Tmp.first); AddLegalizedOperand(Op.getValue(1), Tmp.second); - return Op.getResNo() ? Tmp.first : Tmp.second; + return Op.getResNo() ? Tmp.second : Tmp.first; @@ -277,7 +275,7 @@ } } } else if (Op.getOpcode() == ISD::STORE) { - StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); + StoreSDNode *ST = cast<StoreSDNode>(Node); EVT StVT = ST->getMemoryVT(); MVT ValVT = ST->getValue().getSimpleValueType(); if (StVT.isVector() && ST->isTruncatingStore()) { @@ -286,19 +284,19 @@ switch (TLI.getTruncStoreAction(ValVT, StVT)) { default: llvm_unreachable("This action is not supported yet!"); case TargetLowering::Legal: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); case TargetLowering::Custom: { - SDValue Lowered = TLI.LowerOperation(Result, DAG); - if (Lowered != Result) { + SDValue Lowered = TLI.LowerOperation(SDValue(Node, 0), DAG); + if (Lowered != SDValue(Node, 0)) { // Make sure the new code is also legal.
Lowered = LegalizeOp(Lowered); Changed = true; } - return TranslateLegalizeResults(Op, Lowered); + return TranslateLegalizeResults(Op, Lowered.getNode()); } case TargetLowering::Expand: { Changed = true; - SDValue Chain = ExpandStore(Result); + SDValue Chain = ExpandStore(Node); AddLegalizedOperand(Op, Chain); return Chain; } @@ -309,17 +307,17 @@ bool HasVectorValueOrOp = false; for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J) HasVectorValueOrOp |= J->isVector(); - for (const SDValue &Op : Node->op_values()) - HasVectorValueOrOp |= Op.getValueType().isVector(); + for (const SDValue &Oper : Node->op_values()) + HasVectorValueOrOp |= Oper.getValueType().isVector(); if (!HasVectorValueOrOp) - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); TargetLowering::LegalizeAction Action = TargetLowering::Legal; EVT ValVT; switch (Op.getOpcode()) { default: - return TranslateLegalizeResults(Op, Result); + return TranslateLegalizeResults(Op, Node); #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" @@ -472,40 +470,53 @@ switch (Action) { default: llvm_unreachable("This action is not supported yet!"); - case TargetLowering::Promote: - Result = Promote(Op); - Changed = true; - break; + case TargetLowering::Promote: { + LLVM_DEBUG(dbgs() << "Promoting\n"); + return Promote(Node, Op.getResNo()); + } case TargetLowering::Legal: LLVM_DEBUG(dbgs() << "Legal node: nothing to do\n"); - break; + TranslateLegalizeResults(Op, Node); + return SDValue(Node, Op.getResNo()); case TargetLowering::Custom: { LLVM_DEBUG(dbgs() << "Trying custom legalization\n"); - if (SDValue Tmp1 = TLI.LowerOperation(Op, DAG)) { + if (SDValue Tmp = TLI.LowerOperation(SDValue(Node, 0), DAG)) { LLVM_DEBUG(dbgs() << "Successfully custom legalized node\n"); - Result = Tmp1; - break; + if (Tmp != SDValue(Node, 0)) { + // Make sure that the generated code is itself legal.
+ Tmp = LegalizeOp(Tmp); + Changed = true; + } + + // Tmp might point to a single result from a multi result node, in that + // case we need to use its result number. + if (Node->getNumValues() == 1) { + AddLegalizedOperand(Op, Tmp); + return Tmp; + } + + // Otherwise it should be a multi-result node with the same number of + // results. + assert(Tmp->getNumValues() == Node->getNumValues() && + "Unexpected number of results"); + return TranslateLegalizeResults(Op, Tmp.getNode()); } LLVM_DEBUG(dbgs() << "Could not custom legalize node\n"); LLVM_FALLTHROUGH; } case TargetLowering::Expand: - Result = Expand(Op); + LLVM_DEBUG(dbgs() << "Expanding\n"); + return Expand(Node, Op.getResNo()); } - // Make sure that the generated code is itself legal. - if (Result != Op) { - Result = LegalizeOp(Result); - Changed = true; - } - - // Note that LegalizeOp may be reentered even from single-use nodes, which - // means that we always must cache transformed nodes. - AddLegalizedOperand(Op, Result); - return Result; + llvm_unreachable("Unexpected legalization behavior!"); } -SDValue VectorLegalizer::Promote(SDValue Op) { +SDValue VectorLegalizer::Promote(SDNode *Node, unsigned ResNo) { + SDValue Op(Node, 0); // FIXME: Use Node throughout. + + SmallVector<SDValue, 2> Results; + // For a few operations there is a specific concept for promotion based on // the operand's type. switch (Op.getOpcode()) { @@ -514,56 +525,78 @@ case ISD::STRICT_SINT_TO_FP: case ISD::STRICT_UINT_TO_FP: // "Promote" the operation by extending the operand. - return PromoteINT_TO_FP(Op); + PromoteINT_TO_FP(Op, Results); + break; case ISD::FP_TO_UINT: case ISD::FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: case ISD::STRICT_FP_TO_SINT: // Promote the operation by extending the operand. - return PromoteFP_TO_INT(Op); + PromoteFP_TO_INT(Op, Results); + break; case ISD::FP_ROUND: case ISD::FP_EXTEND: // These operations are used to do promotion so they can't be promoted // themselves.
llvm_unreachable("Don't know how to promote this operation!"); - } + default: { + // There are currently two cases of vector promotion: + // 1) Bitcasting a vector of integers to a different type to a vector of the + // same overall length. For example, x86 promotes ISD::AND v2i32 to v1i64. + // 2) Extending a vector of floats to a vector of the same number of larger + // floats. For example, AArch64 promotes ISD::FADD on v4f16 to v4f32. + MVT VT = Op.getSimpleValueType(); + assert(Op.getNode()->getNumValues() == 1 && + "Can't promote a vector with multiple results!"); + MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT); + SDLoc dl(Op); + SmallVector<SDValue, 4> Operands(Op.getNumOperands()); + + for (unsigned j = 0; j != Op.getNumOperands(); ++j) { + if (Op.getOperand(j).getValueType().isVector()) + if (Op.getOperand(j) + .getValueType() + .getVectorElementType() + .isFloatingPoint() && + NVT.isVector() && NVT.getVectorElementType().isFloatingPoint()) + Operands[j] = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op.getOperand(j)); + else + Operands[j] = DAG.getNode(ISD::BITCAST, dl, NVT, Op.getOperand(j)); + else + Operands[j] = Op.getOperand(j); + } - // There are currently two cases of vector promotion: - // 1) Bitcasting a vector of integers to a different type to a vector of the - // same overall length. For example, x86 promotes ISD::AND v2i32 to v1i64. - // 2) Extending a vector of floats to a vector of the same number of larger - // floats. For example, AArch64 promotes ISD::FADD on v4f16 to v4f32.
- MVT VT = Op.getSimpleValueType(); - assert(Op.getNode()->getNumValues() == 1 && - "Can't promote a vector with multiple results!"); - MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT); - SDLoc dl(Op); - SmallVector<SDValue, 4> Operands(Op.getNumOperands()); + Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands, + Op.getNode()->getFlags()); - for (unsigned j = 0; j != Op.getNumOperands(); ++j) { - if (Op.getOperand(j).getValueType().isVector()) - if (Op.getOperand(j) - .getValueType() - .getVectorElementType() - .isFloatingPoint() && - NVT.isVector() && NVT.getVectorElementType().isFloatingPoint()) - Operands[j] = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Op.getOperand(j)); - else - Operands[j] = DAG.getNode(ISD::BITCAST, dl, NVT, Op.getOperand(j)); + SDValue Res; + if ((VT.isFloatingPoint() && NVT.isFloatingPoint()) || + (VT.isVector() && VT.getVectorElementType().isFloatingPoint() && + NVT.isVector() && NVT.getVectorElementType().isFloatingPoint())) + Res = DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); else - Operands[j] = Op.getOperand(j); + Res = DAG.getNode(ISD::BITCAST, dl, VT, Op); + + Results.push_back(Res); + break; + } } - Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands, Op.getNode()->getFlags()); - if ((VT.isFloatingPoint() && NVT.isFloatingPoint()) || - (VT.isVector() && VT.getVectorElementType().isFloatingPoint() && - NVT.isVector() && NVT.getVectorElementType().isFloatingPoint())) - return DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl)); - else - return DAG.getNode(ISD::BITCAST, dl, VT, Op); + assert(Results.size() == Node->getNumValues() && + "Unexpected number of results"); + + // Make sure that the generated code is itself legal.
+ for (unsigned i = 0, e = Results.size(); i != e; ++i) { + Results[i] = LegalizeOp(Results[i]); + AddLegalizedOperand(SDValue(Node, i), Results[i]); + } + + Changed = true; + return Results[ResNo]; } -SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) { +void VectorLegalizer::PromoteINT_TO_FP(SDValue Op, + SmallVectorImpl<SDValue> &Results) { // INT_TO_FP operations may require the input operand be promoted even // when the type is otherwise legal. bool IsStrict = Op->isStrictFPOpcode(); @@ -586,18 +619,24 @@ Operands[j] = Op.getOperand(j); } - if (IsStrict) - return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other}, - Operands); + if (IsStrict) { + SDValue Res = DAG.getNode(Op.getOpcode(), dl, + {Op.getValueType(), MVT::Other}, Operands); + Results.push_back(Res); + Results.push_back(Res.getValue(1)); + return; + } - return DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + SDValue Res = DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands); + Results.push_back(Res); } // For FP_TO_INT we promote the result type to a vector type with wider // elements and then truncate the result. This is different from the default // PromoteVector which uses bitcast to promote thus assumning that the // promoted vector type has the same overall size.
-SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) { +void VectorLegalizer::PromoteFP_TO_INT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { MVT VT = Op.getSimpleValueType(); MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT); bool IsStrict = Op->isStrictFPOpcode(); @@ -636,14 +675,13 @@ Promoted = DAG.getNode(NewOpc, dl, NVT, Promoted, DAG.getValueType(VT.getScalarType())); Promoted = DAG.getNode(ISD::TRUNCATE, dl, VT, Promoted); + Results.push_back(Promoted); if (IsStrict) - return DAG.getMergeValues({Promoted, Chain}, dl); - - return Promoted; + Results.push_back(Chain); } -std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) { - LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); +std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDNode *N) { + LoadSDNode *LD = cast<LoadSDNode>(N); EVT SrcVT = LD->getMemoryVT(); EVT SrcEltVT = SrcVT.getScalarType(); @@ -652,7 +690,7 @@ SDValue NewChain; SDValue Value; if (SrcVT.getVectorNumElements() > 1 && !SrcEltVT.isByteSized()) { - SDLoc dl(Op); + SDLoc dl(N); SmallVector<SDValue, 8> Vals; SmallVector<SDValue, 8> LoadChains; @@ -764,7 +802,7 @@ } NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains); - Value = DAG.getBuildVector(Op.getNode()->getValueType(0), dl, Vals); + Value = DAG.getBuildVector(N->getValueType(0), dl, Vals); } else { std::tie(Value, NewChain) = TLI.scalarizeVectorLoad(LD, DAG); } @@ -772,87 +810,131 @@ return std::make_pair(Value, NewChain); } -SDValue VectorLegalizer::ExpandStore(SDValue Op) { - StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); +SDValue VectorLegalizer::ExpandStore(SDNode *N) { + StoreSDNode *ST = cast<StoreSDNode>(N); SDValue TF = TLI.scalarizeVectorStore(ST, DAG); return TF; } -SDValue VectorLegalizer::Expand(SDValue Op) { +SDValue VectorLegalizer::Expand(SDNode *Node, unsigned ResNo) { + SDValue Op(Node, 0); // FIXME: Just pass Node to all the expanders.
+ + SmallVector<SDValue, 8> Results; switch (Op->getOpcode()) { case ISD::SIGN_EXTEND_INREG: - return ExpandSEXTINREG(Op); + Results.push_back(ExpandSEXTINREG(Op)); + break; case ISD::ANY_EXTEND_VECTOR_INREG: - return ExpandANY_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandANY_EXTEND_VECTOR_INREG(Op)); + break; case ISD::SIGN_EXTEND_VECTOR_INREG: - return ExpandSIGN_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandSIGN_EXTEND_VECTOR_INREG(Op)); + break; case ISD::ZERO_EXTEND_VECTOR_INREG: - return ExpandZERO_EXTEND_VECTOR_INREG(Op); + Results.push_back(ExpandZERO_EXTEND_VECTOR_INREG(Op)); + break; case ISD::BSWAP: - return ExpandBSWAP(Op); + Results.push_back(ExpandBSWAP(Op)); + break; case ISD::VSELECT: - return ExpandVSELECT(Op); + Results.push_back(ExpandVSELECT(Op)); + break; case ISD::SELECT: - return ExpandSELECT(Op); + Results.push_back(ExpandSELECT(Op)); + break; case ISD::FP_TO_UINT: - return ExpandFP_TO_UINT(Op); + ExpandFP_TO_UINT(Op, Results); + break; case ISD::UINT_TO_FP: - return ExpandUINT_TO_FLOAT(Op); + ExpandUINT_TO_FLOAT(Op, Results); + break; case ISD::FNEG: - return ExpandFNEG(Op); + Results.push_back(ExpandFNEG(Op)); + break; case ISD::FSUB: - return ExpandFSUB(Op); + if (SDValue Tmp = ExpandFSUB(Op)) + Results.push_back(Tmp); + break; case ISD::SETCC: - return UnrollVSETCC(Op); + Results.push_back(UnrollVSETCC(Op)); + break; case ISD::ABS: - return ExpandABS(Op); + Results.push_back(ExpandABS(Op)); + break; case ISD::BITREVERSE: - return ExpandBITREVERSE(Op); + if (SDValue Tmp = ExpandBITREVERSE(Op)) + Results.push_back(Tmp); + break; case ISD::CTPOP: - return ExpandCTPOP(Op); + Results.push_back(ExpandCTPOP(Op)); + break; case ISD::CTLZ: case ISD::CTLZ_ZERO_UNDEF: - return ExpandCTLZ(Op); + Results.push_back(ExpandCTLZ(Op)); + break; case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: - return ExpandCTTZ(Op); + Results.push_back(ExpandCTTZ(Op)); + break; case ISD::FSHL: case ISD::FSHR: - return ExpandFunnelShift(Op); +
Results.push_back(ExpandFunnelShift(Op)); + break; case ISD::ROTL: case ISD::ROTR: - return ExpandROT(Op); + Results.push_back(ExpandROT(Op)); + break; case ISD::FMINNUM: case ISD::FMAXNUM: - return ExpandFMINNUM_FMAXNUM(Op); + Results.push_back(ExpandFMINNUM_FMAXNUM(Op)); + break; case ISD::UADDO: - case ISD::USUBO: - return ExpandUADDSUBO(Op); + case ISD::USUBO: { + SDValue Result, Overflow; + std::tie(Result, Overflow) = ExpandUADDSUBO(Op); + Results.push_back(Result); + Results.push_back(Overflow); + break; + } case ISD::SADDO: - case ISD::SSUBO: - return ExpandSADDSUBO(Op); + case ISD::SSUBO: { + SDValue Result, Overflow; + std::tie(Result, Overflow) = ExpandSADDSUBO(Op); + Results.push_back(Result); + Results.push_back(Overflow); + break; + } case ISD::UMULO: - case ISD::SMULO: - return ExpandMULO(Op); + case ISD::SMULO: { + SDValue Result, Overflow; + std::tie(Result, Overflow) = ExpandMULO(Op); + Results.push_back(Result); + Results.push_back(Overflow); + break; + } case ISD::USUBSAT: case ISD::SSUBSAT: case ISD::UADDSAT: case ISD::SADDSAT: - return ExpandAddSubSat(Op); + Results.push_back(ExpandAddSubSat(Op)); + break; case ISD::SMULFIX: case ISD::UMULFIX: - return ExpandFixedPointMul(Op); + Results.push_back(ExpandFixedPointMul(Op)); + break; case ISD::SMULFIXSAT: case ISD::UMULFIXSAT: // FIXME: We do not expand SMULFIXSAT/UMULFIXSAT here yet, not sure exactly // why. Maybe it results in worse codegen compared to the unroll for some // targets? This should probably be investigated. And if we still prefer to // unroll an explanation could be helpful.
- return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + break; #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \ case ISD::STRICT_##DAGN: #include "llvm/IR/ConstrainedOps.def" - return ExpandStrictFPOp(Op); + ExpandStrictFPOp(Op, Results); + break; case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_MUL: case ISD::VECREDUCE_AND: @@ -866,10 +948,29 @@ case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: - return TLI.expandVecReduce(Op.getNode(), DAG); + Results.push_back(TLI.expandVecReduce(Op.getNode(), DAG)); + break; default: - return DAG.UnrollVectorOp(Op.getNode()); + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + break; } + + if (Results.empty()) { + TranslateLegalizeResults(Op, Node); + return SDValue(Node, ResNo); + } + + assert(Results.size() == Node->getNumValues() && + "Unexpected number of results"); + + // Make sure that the generated code is itself legal. + for (unsigned i = 0, e = Results.size(); i != e; ++i) { + Results[i] = LegalizeOp(Results[i]); + AddLegalizedOperand(SDValue(Node, i), Results[i]); + } + + Changed = true; + return Results[ResNo]; } SDValue VectorLegalizer::ExpandSELECT(SDValue Op) { @@ -1114,7 +1215,7 @@ return DAG.UnrollVectorOp(Op.getNode()); // Let LegalizeDAG handle this later. - return Op; + return SDValue(); } SDValue VectorLegalizer::ExpandVSELECT(SDValue Op) { @@ -1174,23 +1275,28 @@ return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandFP_TO_UINT(SDValue Op) { +void VectorLegalizer::ExpandFP_TO_UINT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { // Attempt to expand using TargetLowering. SDValue Result, Chain; if (TLI.expandFP_TO_UINT(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (Op->isStrictFPOpcode()) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Otherwise go ahead and unroll.
- if (Op->isStrictFPOpcode()) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (Op->isStrictFPOpcode()) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); } -SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) { +void VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op, + SmallVectorImpl<SDValue> &Results) { bool IsStrict = Op.getNode()->isStrictFPOpcode(); unsigned OpNo = IsStrict ? 1 : 0; SDValue Src = Op.getOperand(OpNo); @@ -1201,10 +1307,10 @@ SDValue Result; SDValue Chain; if (TLI.expandUINT_TO_FP(Op.getNode(), Result, Chain, DAG)) { + Results.push_back(Result); if (IsStrict) - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain); - return Result; + Results.push_back(Chain); + return; } // Make sure that the SINT_TO_FP and SRL instructions are available. @@ -1213,9 +1319,13 @@ (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, VT) == TargetLowering::Expand)) || TLI.getOperationAction(ISD::SRL, VT) == TargetLowering::Expand) { - if (IsStrict) - return UnrollStrictFPOp(Op); - return DAG.UnrollVectorOp(Op.getNode()); + if (IsStrict) { + UnrollStrictFPOp(Op, Results); + return; + } + + Results.push_back(DAG.UnrollVectorOp(Op.getNode())); + return; } unsigned BW = VT.getScalarSizeInBits(); @@ -1255,9 +1365,9 @@ DAG.getNode(ISD::STRICT_FADD, DL, {Op.getValueType(), MVT::Other}, {SDValue(fLO.getNode(), 1), fHI, fLO}); - // Relink the chain - DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), SDValue(Result.getNode(), 1)); - return Result; + Results.push_back(Result); + Results.push_back(Result.getValue(1)); + return; } // Convert hi and lo to floats @@ -1268,7 +1378,7 @@ SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, Op.getValueType(), LO); // Add the two halves - return DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO); + Results.push_back(DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO)); } SDValue VectorLegalizer::ExpandFNEG(SDValue Op) { @@ -1289,7
+1399,7 @@ EVT VT = Op.getValueType(); if (TLI.isOperationLegalOrCustom(ISD::FNEG, VT) && TLI.isOperationLegalOrCustom(ISD::FADD, VT)) - return Op; // Defer to LegalizeDAG + return SDValue(); // Defer to LegalizeDAG return DAG.UnrollVectorOp(Op.getNode()); } @@ -1340,44 +1450,23 @@ return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandUADDSUBO(SDValue Op) { +std::pair<SDValue, SDValue> VectorLegalizer::ExpandUADDSUBO(SDValue Op) { SDValue Result, Overflow; TLI.expandUADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + return std::make_pair(Result, Overflow); } -SDValue VectorLegalizer::ExpandSADDSUBO(SDValue Op) { +std::pair<SDValue, SDValue> VectorLegalizer::ExpandSADDSUBO(SDValue Op) { SDValue Result, Overflow; TLI.expandSADDSUBO(Op.getNode(), Result, Overflow, DAG); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + return std::make_pair(Result, Overflow); } -SDValue VectorLegalizer::ExpandMULO(SDValue Op) { +std::pair<SDValue, SDValue> VectorLegalizer::ExpandMULO(SDValue Op) { SDValue Result, Overflow; if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG)) std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode()); - - if (Op.getResNo() == 0) { - AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow)); - return Result; - } else { - AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result)); - return Overflow; - } + return std::make_pair(Result, Overflow); } SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) { @@ -1392,16 +1481,22 @@ return DAG.UnrollVectorOp(Op.getNode()); } -SDValue VectorLegalizer::ExpandStrictFPOp(SDValue Op) { - if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) - 
return ExpandUINT_TO_FLOAT(Op); - if (Op.getOpcode() ==
ISD::STRICT_FP_TO_UINT) - return ExpandFP_TO_UINT(Op); +void VectorLegalizer::ExpandStrictFPOp(SDValue Op, + SmallVectorImpl<SDValue> &Results) { + if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) { + ExpandUINT_TO_FLOAT(Op, Results); + return; + } + if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) { + ExpandFP_TO_UINT(Op, Results); + return; + } - return UnrollStrictFPOp(Op); + UnrollStrictFPOp(Op, Results); } -SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) { +void VectorLegalizer::UnrollStrictFPOp(SDValue Op, + SmallVectorImpl<SDValue> &Results) { EVT VT = Op.getValue(0).getValueType(); EVT EltVT = VT.getVectorElementType(); unsigned NumElems = VT.getVectorNumElements(); @@ -1458,10 +1553,8 @@ SDValue Result = DAG.getBuildVector(VT, dl, OpValues); SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OpChains); - AddLegalizedOperand(Op.getValue(0), Result); - AddLegalizedOperand(Op.getValue(1), NewChain); - - return Op.getResNo() ? NewChain : Result; + Results.push_back(Result); + Results.push_back(NewChain); } SDValue VectorLegalizer::UnrollVSETCC(SDValue Op) {