Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -58,8 +58,8 @@
   };
 
   bool isWideningInstruction(Type *DstTy, unsigned Opcode,
-                             ArrayRef<Type *> SrcTys,
-                             ArrayRef<const Value *> Args);
+                             ArrayRef<const Value *> Args,
+                             Type *SrcOverrideTy = nullptr);
 
   // A helper function called by 'getVectorInstrCost'.
   //
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1945,9 +1945,8 @@
 }
 
 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<Type *> SrcTys,
-                                           ArrayRef<const Value *> Args) {
-
+                                           ArrayRef<const Value *> Args,
+                                           Type *SrcOverrideTy) {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -1955,12 +1954,14 @@
                            cast<VectorType>(DstTy)->getElementCount());
   };
 
-  // Exit early if DstTy is not a vector type whose elements are at least
-  // 16-bits wide. SVE doesn't generally have the same set of instructions to
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
   // perform an extend with the add/sub/mul. There are SMULLB style
   // instructions, but they operate on top/bottom, requiring some sort of lane
   // interleaving to be used with zext/sext.
-  if (!useNeonVector(DstTy) || DstTy->getScalarSizeInBits() < 16)
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
   // Determine if the operation has a widening variant. We consider both the
@@ -1970,42 +1971,58 @@
   // TODO: Add additional widening operations (e.g., shl, etc.) once we
   //       verify that their extending operands are eliminated during code
   //       generation.
+  Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
-  case Instruction::Mul: // SMULL(2), UMULL(2)
+    // The second operand needs to be an extend
+    if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
+      if (!SrcTy)
+        SrcTy =
+            toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
+    } else
+      return false;
+    break;
+  case Instruction::Mul: { // SMULL(2), UMULL(2)
+    // Both operands need to be extends of the same type.
+    if (isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) {
+      if (!SrcTy)
+        SrcTy =
+            toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
+    } else if (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1])) {
+      if (!SrcTy)
+        SrcTy =
+            toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
+    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
+      // If one of the operands is a Zext and the other has enough zero bits to
+      // be treated as unsigned, we can still generate a umull, meaning the zext
+      // is free.
+      KnownBits Known =
+          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+      if (Args[0]->getType()->getScalarSizeInBits() -
+              Known.Zero.countLeadingOnes() >
+          DstTy->getScalarSizeInBits() / 2)
+        return false;
+      if (!SrcTy)
+        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
+                                           DstTy->getScalarSizeInBits() / 2));
+    } else
+      return false;
     break;
+  }
   default:
     return false;
   }
 
-  // To be a widening instruction (either the "wide" or "long" versions), the
-  // second operand must be a sign- or zero extend.
-  if (Args.size() != 2 ||
-      (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])))
-    return false;
-  auto *Extend = cast<CastInst>(Args[1]);
-  auto *Arg0 = dyn_cast<CastInst>(Args[0]);
-
-  // A mul only has a mull version (not like addw). Both operands need to be
-  // extending and the same type.
-  if (Opcode == Instruction::Mul &&
-      (!Arg0 || Arg0->getOpcode() != Extend->getOpcode() ||
-       (SrcTys.size() == 2 && SrcTys[0] != SrcTys[1])))
-    return false;
-
   // Legalize the destination type and ensure it can be used in a widening
   // operation.
   auto DstTyL = getTypeLegalizationCost(DstTy);
-  unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
-  if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
+  if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
     return false;
 
   // Legalize the source type and ensure it can be used in a widening
   // operation.
-  Type *SrcTy =
-      SrcTys.size() > 0 ? SrcTys.back() : toVectorTy(Extend->getSrcTy());
-
+  assert(SrcTy && "Expected some SrcTy");
   auto SrcTyL = getTypeLegalizationCost(SrcTy);
   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
@@ -2019,7 +2036,7 @@
 
   // Return true if the legalized types have the same number of vector elements
   // and the destination element type size is twice that of the source type.
-  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
+  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
@@ -2034,31 +2051,17 @@
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    SmallVector<Type *> SrcTys;
-    for (const Value *Op : Operands) {
-      auto *Cast = dyn_cast<CastInst>(Op);
-      if (!Cast)
-        continue;
-      // Use provided Src type for I and other casts that have the same source
-      // type.
-      if (Op == I || cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
-        SrcTys.push_back(Src);
-      else
-        SrcTys.push_back(Cast->getSrcTy());
-    }
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), SrcTys, Operands)) {
-      // If the cast is the second operand, it is free. We will generate either
-      // a "wide" or "long" version of the widening instruction.
-      if (I == SingleUser->getOperand(1))
-        return 0;
-      // If the cast is not the second operand, it will be free if it looks the
-      // same as the second operand. In this case, we will generate a "long"
-      // version of the widening instruction.
-      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
-        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
-            (Src == Cast->getSrcTy() ||
-             cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy()))
+    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+      // For adds, only count the second operand as free if both operands are
+      // extends but not the same operation (i.e. both operands are not free in
+      // add(sext, zext)).
+      if (SingleUser->getOpcode() == Instruction::Add) {
+        if (I == SingleUser->getOperand(1) ||
+            (isa<CastInst>(SingleUser->getOperand(1)) &&
+             cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
           return 0;
+      } else // Others are free so long as isWideningInstruction returned true.
+        return 0;
     }
   }
 
@@ -2693,7 +2696,7 @@
     // LT.first = 2 the cost is 28. If both operands are extensions it will not
     // need to scalarize so the cost can be cheaper (smull or umull).
     // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, {}, Args))
+    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
       return LT.first;
     return LT.first * 14;
   case ISD::ADD:
Index: llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
===================================================================
--- llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
+++ llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
@@ -2092,9 +2092,9 @@
 ; CHECK-LABEL: 'extmul_const'
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %sl1_8_16 = sext <8 x i8> %i8 to <8 x i16>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %asl_8_16 = mul <8 x i16> %sl1_8_16,
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %azl_8_16 = mul <8 x i16> %zl1_8_16,
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %and = and <8 x i16> %sl1_8_16,
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %aal_8_16 = mul <8 x i16> %zl1_8_16b, %and
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
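Note (not part of the patch): a minimal IR sketch of the pattern the new KnownBits check in isWideningInstruction makes cheaper. The function name and the splat constant 10 are illustrative, not taken from the test file; the costs can be inspected with something like opt -passes="print<cost-model>" -disable-output -mtriple=aarch64--linux-gnu.

  define <8 x i16> @umull_by_small_const(<8 x i8> %a) {
    ; The zext is now costed as free: the constant's known zero bits show every
    ; element fits in the low 8 bits, so the multiply can be lowered as a umull.
    %z = zext <8 x i8> %a to <8 x i16>
    %m = mul <8 x i16> %z, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
    ret <8 x i16> %m
  }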