diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -445,11 +445,12 @@
                                                            Value *AlignOp) {
   // Make sure the operation will be supported by the backend.
   MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
-  if (!MA || !TLI->isLegalStridedLoadStore(*DL, DataType, *MA))
+  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
+  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
     return false;
 
   // FIXME: Let the backend type legalize by splitting/widening?
-  if (!TLI->isTypeLegal(TLI->getValueType(*DL, DataType)))
+  if (!TLI->isTypeLegal(DataTypeVT))
     return false;
 
   // Pointer should be a GEP.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -666,7 +666,7 @@
 
   bool shouldRemoveExtendFromGSIndex(EVT IndexVT, EVT DataVT) const override;
 
-  bool isLegalElementTypeForRVV(Type *ScalarTy) const;
+  bool isLegalElementTypeForRVV(EVT ScalarTy) const;
 
   bool shouldConvertFpToSat(unsigned Op, EVT FPVT, EVT VT) const override;
 
@@ -706,7 +706,7 @@
 
   /// Return true if a stride load store of the given result type and
   /// alignment is legal.
-  bool isLegalStridedLoadStore(const DataLayout &DL, Type *DataType, Align Alignment) const;
+  bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const;
 
   unsigned getMaxSupportedInterleaveFactor() const override { return 8; }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2060,25 +2060,29 @@
          (VT.isFixedLengthVector() && VT.getVectorElementType() == MVT::i1);
 }
 
-bool RISCVTargetLowering::isLegalElementTypeForRVV(Type *ScalarTy) const {
-  if (ScalarTy->isPointerTy())
+bool RISCVTargetLowering::isLegalElementTypeForRVV(EVT ScalarTy) const {
+  // Extended (non-simple) EVTs have no MVT to switch on and are never legal
+  // RVV element types.
+  if (!ScalarTy.isSimple())
+    return false;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::iPTR:
     return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
-
-  if (ScalarTy->isIntegerTy(8) || ScalarTy->isIntegerTy(16) ||
-      ScalarTy->isIntegerTy(32))
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
     return true;
-
-  if (ScalarTy->isIntegerTy(64))
+  case MVT::i64:
     return Subtarget.hasVInstructionsI64();
-
-  if (ScalarTy->isHalfTy())
+  case MVT::f16:
     return Subtarget.hasVInstructionsF16();
-  if (ScalarTy->isFloatTy())
+  case MVT::f32:
     return Subtarget.hasVInstructionsF32();
-  if (ScalarTy->isDoubleTy())
+  case MVT::f64:
     return Subtarget.hasVInstructionsF64();
-
-  return false;
+  default:
+    return false;
+  }
 }
 
 unsigned RISCVTargetLowering::combineRepeatedFPDivisors() const {
@@ -11458,8 +11462,7 @@
     return SDValue();
 
   // Check that the operation is legal
-  Type *WideVecTy = EVT(WideVecVT).getTypeForEVT(*DAG.getContext());
-  if (!TLI.isLegalStridedLoadStore(DAG.getDataLayout(), WideVecTy, Align))
+  if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
   MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
@@ -15815,12 +15818,14 @@
     FixedVectorType *VTy, unsigned Factor, const DataLayout &DL) const {
   if (!Subtarget.useRVVForFixedLengthVectors())
     return false;
-  if (!isLegalElementTypeForRVV(VTy->getElementType()))
-    return false;
   EVT VT = getValueType(DL, VTy);
   // Don't lower vlseg/vsseg for fixed length vector types that can't be split.
   if (!isTypeLegal(VT))
     return false;
+
+  if (!isLegalElementTypeForRVV(VT.getScalarType()))
+    return false;
+
   // Sometimes the interleaved access pass picks up splats as interleaves of one
   // element. Don't lower these.
   if (VTy->getNumElements() < 2)
@@ -15834,22 +15839,22 @@
   return Factor * LMUL <= 8;
 }
 
-bool RISCVTargetLowering::isLegalStridedLoadStore(const DataLayout &DL,
-                                                  Type *DataType,
+bool RISCVTargetLowering::isLegalStridedLoadStore(EVT DataType,
                                                   Align Alignment) const {
   if (!Subtarget.hasVInstructions())
     return false;
 
   // Only support fixed vectors if we know the minimum vector size.
-  if (isa<FixedVectorType>(DataType) && !Subtarget.useRVVForFixedLengthVectors())
+  if (DataType.isFixedLengthVector() &&
+      !Subtarget.useRVVForFixedLengthVectors())
     return false;
 
-  Type *ScalarType = DataType->getScalarType();
+  EVT ScalarType = DataType.getScalarType();
   if (!isLegalElementTypeForRVV(ScalarType))
     return false;
 
   if (!Subtarget.enableUnalignedVectorMem() &&
-      Alignment < DL.getTypeStoreSize(ScalarType).getFixedValue())
+      Alignment < ScalarType.getStoreSize())
     return false;
 
   return true;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -179,20 +179,21 @@
       const Instruction *CxtI = nullptr);
 
   bool isElementTypeLegalForScalableVector(Type *Ty) const {
-    return TLI->isLegalElementTypeForRVV(Ty);
+    return TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty));
   }
 
   bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
     if (!ST->hasVInstructions())
       return false;
 
+    EVT DataTypeVT = TLI->getValueType(DL, DataType);
+
     // Only support fixed vectors if we know the minimum vector size.
-    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
+    if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
       return false;
 
-    auto *ElemType = DataType->getScalarType();
-    if (!ST->enableUnalignedVectorMem() &&
-        Alignment < DL.getTypeStoreSize(ElemType).getFixedValue())
+    EVT ElemType = DataTypeVT.getScalarType();
+    if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
       return false;
 
     return TLI->isLegalElementTypeForRVV(ElemType);
@@ -209,13 +210,14 @@
     if (!ST->hasVInstructions())
       return false;
 
+    EVT DataTypeVT = TLI->getValueType(DL, DataType);
+
     // Only support fixed vectors if we know the minimum vector size.
-    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
+    if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
       return false;
 
-    auto *ElemType = DataType->getScalarType();
-    if (!ST->enableUnalignedVectorMem() &&
-        Alignment < DL.getTypeStoreSize(ElemType).getFixedValue())
+    EVT ElemType = DataTypeVT.getScalarType();
+    if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
       return false;
 
     return TLI->isLegalElementTypeForRVV(ElemType);
@@ -262,7 +264,7 @@
       return true;
 
     Type *Ty = RdxDesc.getRecurrenceType();
-    if (!TLI->isLegalElementTypeForRVV(Ty))
+    if (!TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty)))
       return false;
 
     switch (RdxDesc.getRecurrenceKind()) {