diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -39,6 +39,7 @@
 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
 STATISTIC(NumScalarBO, "Number of scalar binops formed");
 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
+STATISTIC(NumTruncElim, "Number of truncates eliminated");
 
 static cl::opt<bool> DisableVectorCombine(
     "disable-vector-combine", cl::init(false), cl::Hidden,
@@ -79,6 +80,9 @@
   bool foldBitcastShuf(Instruction &I);
   bool scalarizeBinopOrCmp(Instruction &I);
   bool foldExtractedCmps(Instruction &I);
+  /// Try to shorten zexts if that allows eliminating a trunc and the shorter
+  /// zexts are free.
+  bool foldTruncBinopExt(Instruction &I);
 };
 
 static void replaceValue(Value &Old, Value &New) {
@@ -604,6 +608,59 @@
   return true;
 }
 
+bool VectorCombine::foldTruncBinopExt(Instruction &I) {
+  Value *V1;
+  Value *V2;
+  if (!I.getType()->isVectorTy())
+    return false;
+
+  BinaryOperator *BO;
+  if (!match(&I, m_Trunc(m_OneUse(m_CombineAnd(
+                     m_BinOp(m_OneUse(m_Value(V1)), m_OneUse(m_Value(V2))),
+                     m_BinOp(BO))))))
+    return false;
+
+  // Limit transform to known-safe opcodes.
+  // TODO: Add more safe opcodes.
+  if (BO->getOpcode() != Instruction::Add &&
+      BO->getOpcode() != Instruction::Sub)
+    return false;
+
+  // Limit to ZExt operands that share a source type, which must be narrower
+  // than the truncated result type.
+  auto *ZExt1 = dyn_cast<ZExtInst>(V1);
+  auto *ZExt2 = dyn_cast<ZExtInst>(V2);
+  if (!ZExt1 || !ZExt2)
+    return false;
+  if (ZExt1->getOperand(0)->getType() != ZExt2->getOperand(0)->getType())
+    return false;
+  if (I.getType()->getScalarSizeInBits() <=
+      ZExt1->getOperand(0)->getType()->getScalarSizeInBits())
+    return false;
+
+  // Check if the new ZExts can be done for free by folding them into their
+  // user.
+  if (TTI.getCastInstrCost(ZExt1->getOpcode(), I.getType(), V1->getType(),
+                           TTI::TCK_RecipThroughput, ZExt1) != 0)
+    return false;
+
+  auto *ShortZExt1 = Builder.CreateZExt(ZExt1->getOperand(0), I.getType());
+  auto *ShortZExt2 = Builder.CreateZExt(ZExt2->getOperand(0), I.getType());
+  Value *ShortBinOp =
+      Builder.CreateBinOp(BO->getOpcode(), ShortZExt1, ShortZExt2);
+  // The builder may constant-fold; only propagate the wrap flags if a new
+  // binary operator was actually created.
+  if (auto *ShortBO = dyn_cast<BinaryOperator>(ShortBinOp)) {
+    if (BO->hasNoUnsignedWrap())
+      ShortBO->setHasNoUnsignedWrap();
+    if (BO->hasNoSignedWrap())
+      ShortBO->setHasNoSignedWrap();
+  }
+  ++NumTruncElim;
+  I.replaceAllUsesWith(ShortBinOp);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -627,6 +684,7 @@
       MadeChange |= foldBitcastShuf(I);
       MadeChange |= scalarizeBinopOrCmp(I);
       MadeChange |= foldExtractedCmps(I);
+      MadeChange |= foldTruncBinopExt(I);
     }
   }
 
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shorten-extend-if-free.ll b/llvm/test/Transforms/VectorCombine/AArch64/shorten-extend-if-free.ll
--- a/llvm/test/Transforms/VectorCombine/AArch64/shorten-extend-if-free.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shorten-extend-if-free.ll
@@ -4,11 +4,10 @@
 ; Zext can be folded with sub for free.
 define <8 x i32> @shorten_zext_sub(<8 x i16> %v.1, <8 x i16> %v.2) {
 ; CHECK-LABEL: @shorten_zext_sub(
-; CHECK-NEXT:    [[EXT_1:%.*]] = zext <8 x i16> [[V_1:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[EXT_2:%.*]] = zext <8 x i16> [[V_2:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[SUB:%.*]] = sub nsw <8 x i64> [[EXT_1]], [[EXT_2]]
-; CHECK-NEXT:    [[T:%.*]] = trunc <8 x i64> [[SUB]] to <8 x i32>
-; CHECK-NEXT:    ret <8 x i32> [[T]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i16> [[V_1:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <8 x i16> [[V_2:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = sub nsw <8 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i32> [[TMP3]]
 ;
   %ext.1 = zext <8 x i16> %v.1 to <8 x i64>
   %ext.2 = zext <8 x i16> %v.2 to <8 x i64>
@@ -20,11 +19,10 @@
 ; Zext can be folded with add for free.
 define <8 x i32> @shorten_zext_add(<8 x i16> %v.1, <8 x i16> %v.2) {
 ; CHECK-LABEL: @shorten_zext_add(
-; CHECK-NEXT:    [[EXT_1:%.*]] = zext <8 x i16> [[V_1:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[EXT_2:%.*]] = zext <8 x i16> [[V_2:%.*]] to <8 x i64>
-; CHECK-NEXT:    [[SUB:%.*]] = add nsw <8 x i64> [[EXT_1]], [[EXT_2]]
-; CHECK-NEXT:    [[T:%.*]] = trunc <8 x i64> [[SUB]] to <8 x i32>
-; CHECK-NEXT:    ret <8 x i32> [[T]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <8 x i16> [[V_1:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <8 x i16> [[V_2:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = add nsw <8 x i32> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i32> [[TMP3]]
 ;
   %ext.1 = zext <8 x i16> %v.1 to <8 x i64>
   %ext.2 = zext <8 x i16> %v.2 to <8 x i64>
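
For readers skimming the diff, the following is the shape of the rewrite in isolation: a minimal sketch, where @example_before, @example_after, and the value names are placeholders rather than part of the patch. The zexts shrink to the trunc's destination type, which makes the trunc dead; on AArch64 the <8 x i16> to <8 x i32> extends can fold into widening instructions such as uaddl/usubl, which is the zero-cost condition the getCastInstrCost check enforces.

; Placeholder IR before the fold: the sub is performed at <8 x i64> and
; truncated afterwards.
define <8 x i32> @example_before(<8 x i16> %a, <8 x i16> %b) {
  %ext.a = zext <8 x i16> %a to <8 x i64>
  %ext.b = zext <8 x i16> %b to <8 x i64>
  %sub = sub nsw <8 x i64> %ext.a, %ext.b
  %t = trunc <8 x i64> %sub to <8 x i32>
  ret <8 x i32> %t
}

; Placeholder IR after the fold: the zexts target <8 x i32> directly and the
; trunc is gone. In this case nsw still holds, since subtracting two
; zero-extended i16 values cannot overflow a 32-bit signed subtraction.
define <8 x i32> @example_after(<8 x i16> %a, <8 x i16> %b) {
  %ext.a = zext <8 x i16> %a to <8 x i32>
  %ext.b = zext <8 x i16> %b to <8 x i32>
  %sub = sub nsw <8 x i32> %ext.a, %ext.b
  ret <8 x i32> %sub
}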