Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1821,6 +1821,19 @@ } } + EVT DstTy = TLI->getValueType(DL, Dst); + EVT SrcTy = TLI->getValueType(DL, Src); + + // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal, + // but we also want to include the TTI::CastContextHint::Masked case too. + if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) && + (CCH == TTI::CastContextHint::Masked) && ST->hasSVE()) { + unsigned LType = + ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD); + if (TLI->isLoadExtLegal(LType, DstTy, SrcTy)) + return 0; + } + // TODO: Allow non-throughput costs that aren't binary. auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost { if (CostKind != TTI::TCK_RecipThroughput) @@ -1828,9 +1841,6 @@ return Cost; }; - EVT SrcTy = TLI->getValueType(DL, Src); - EVT DstTy = TLI->getValueType(DL, Dst); - if (!SrcTy.isSimple() || !DstTy.isSimple()) return AdjustCost( BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); Index: llvm/test/Analysis/CostModel/AArch64/masked_ldst.ll =================================================================== --- llvm/test/Analysis/CostModel/AArch64/masked_ldst.ll +++ llvm/test/Analysis/CostModel/AArch64/masked_ldst.ll @@ -110,29 +110,30 @@ define void @scalable_ext_loads() { ; CHECK-LABEL: 'scalable_ext_loads' ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv8i8 = call @llvm.masked.load.nxv8i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv8i8to16 = zext %load.nxv8i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv8i8to16 = zext %load.nxv8i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv4i8 = call @llvm.masked.load.nxv4i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv4i8to32 = zext %load.nxv4i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv4i8to32 = zext %load.nxv4i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv2i8 = call @llvm.masked.load.nxv2i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv2i8to64 = zext %load.nxv2i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv2i8to64 = zext %load.nxv2i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv4i16 = call @llvm.masked.load.nxv4i16.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv4i16to32 = zext %load.nxv4i16 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv4i16to32 = zext %load.nxv4i16 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv2i16 = call @llvm.masked.load.nxv2i16.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv2i16to64 = zext %load.nxv2i16 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv2i16to64 = zext %load.nxv2i16 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load.nxv2i32 = call @llvm.masked.load.nxv2i32.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %zext.nxv2i32to64 = zext %load.nxv2i32 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %zext.nxv2i32to64 = zext %load.nxv2i32 to + ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv8i8 = call @llvm.masked.load.nxv8i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv8i8to16 = sext %load2.nxv8i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv8i8to16 = sext %load2.nxv8i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv4i8 = call @llvm.masked.load.nxv4i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv4i8to32 = sext %load2.nxv4i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv4i8to32 = sext %load2.nxv4i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv2i8 = call @llvm.masked.load.nxv2i8.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv2i8to64 = sext %load2.nxv2i8 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv2i8to64 = sext %load2.nxv2i8 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv4i16 = call @llvm.masked.load.nxv4i16.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv4i16to32 = sext %load2.nxv4i16 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv4i16to32 = sext %load2.nxv4i16 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv2i16 = call @llvm.masked.load.nxv2i16.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv2i16to64 = sext %load2.nxv2i16 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv2i16to64 = sext %load2.nxv2i16 to ; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %load2.nxv2i32 = call @llvm.masked.load.nxv2i32.p0(ptr undef, i32 8, undef, undef) -; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: %sext.nxv2i32to64 = sext %load2.nxv2i32 to +; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: %sext.nxv2i32to64 = sext %load2.nxv2i32 to %load.nxv8i8 = call @llvm.masked.load.nxv8i8.p0(ptr undef, i32 8, undef, undef) %zext.nxv8i8to16 = zext %load.nxv8i8 to