diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -912,15 +912,18 @@ } break; } - case ISD::BITCAST: - // Just drop bitcasts between scalable vectors. - if (VT.isScalableVector() && - Node->getOperand(0).getSimpleValueType().isScalableVector()) { + case ISD::BITCAST: { + MVT SrcVT = Node->getOperand(0).getSimpleValueType(); + // Just drop bitcasts between vectors if both are fixed or both are + // scalable. + if ((VT.isScalableVector() && SrcVT.isScalableVector()) || + (VT.isFixedLengthVector() && SrcVT.isFixedLengthVector())) { ReplaceUses(SDValue(Node, 0), Node->getOperand(0)); CurDAG->RemoveDeadNode(Node); return; } break; + } case ISD::INSERT_SUBVECTOR: { SDValue V = Node->getOperand(0); SDValue SubV = Node->getOperand(1); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -571,6 +571,8 @@ setOperationAction(ISD::ANY_EXTEND, VT, Custom); setOperationAction(ISD::SIGN_EXTEND, VT, Custom); setOperationAction(ISD::ZERO_EXTEND, VT, Custom); + + setOperationAction(ISD::BITCAST, VT, Custom); } for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) { @@ -602,6 +604,8 @@ setCondCodeAction(CC, VT, Expand); setOperationAction(ISD::VSELECT, VT, Custom); + + setOperationAction(ISD::BITCAST, VT, Custom); } } } @@ -1099,11 +1103,18 @@ case ISD::SRL_PARTS: return lowerShiftRightParts(Op, DAG, false); case ISD::BITCAST: { + SDValue Op0 = Op.getOperand(0); + // We can handle fixed length vector bitcasts with a simple replacement + // in isel. + if (Op.getValueType().isFixedLengthVector()) { + if (Op0.getValueType().isFixedLengthVector()) + return Op; + return SDValue(); + } assert(((Subtarget.is64Bit() && Subtarget.hasStdExtF()) || Subtarget.hasStdExtZfh()) && "Unexpected custom legalisation"); SDLoc DL(Op); - SDValue Op0 = Op.getOperand(0); if (Op.getValueType() == MVT::f16 && Subtarget.hasStdExtZfh()) { if (Op0.getValueType() != MVT::i16) return SDValue();