diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1649,15 +1649,17 @@ /// Returns true if \p V is a bitwise not operation. Assumes that an all ones /// constant is canonicalized to be operand 1. -bool isBitwiseNot(SDValue V); +bool isBitwiseNot(SDValue V, bool AllowUndefs = false); /// Returns the SDNode if it is a constant splat BuildVector or constant int. -ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false); +ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false, + bool AllowTruncation = false); /// Returns the SDNode if it is a demanded constant splat BuildVector or /// constant int. ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts, - bool AllowUndefs = false); + bool AllowUndefs = false, + bool AllowTruncation = false); /// Returns the SDNode if it is a constant splat BuildVector or constant float. 
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8640,14 +8640,18 @@ return V; } -bool llvm::isBitwiseNot(SDValue V) { +bool llvm::isBitwiseNot(SDValue V, bool AllowUndefs) { if (V.getOpcode() != ISD::XOR) return false; - ConstantSDNode *C = isConstOrConstSplat(peekThroughBitcasts(V.getOperand(1))); - return C && C->isAllOnesValue(); + V = peekThroughBitcasts(V.getOperand(1)); + unsigned NumBits = V.getScalarValueSizeInBits(); + ConstantSDNode *C = + isConstOrConstSplat(V, AllowUndefs, /*AllowTruncation*/ true); + return C && (C->getAPIntValue().countTrailingOnes() >= NumBits); } -ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { +ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs, + bool AllowTruncation) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) return CN; @@ -8655,17 +8659,23 @@ BitVector UndefElements; ConstantSDNode *CN = BV->getConstantSplatNode(&UndefElements); - // BuildVectors can truncate their operands. Ignore that case here. - if (CN && (UndefElements.none() || AllowUndefs) && - CN->getValueType(0) == N.getValueType().getScalarType()) - return CN; + // BuildVectors can truncate their operands. Ignore that case here unless + // AllowTruncation is set. 
+ if (CN && (UndefElements.none() || AllowUndefs)) { + EVT CVT = CN->getValueType(0); + EVT NSVT = N.getValueType().getScalarType(); + assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension"); + if (AllowTruncation || (CVT == NSVT)) + return CN; + } } return nullptr; } ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, - bool AllowUndefs) { + bool AllowUndefs, + bool AllowTruncation) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) return CN; @@ -8673,10 +8683,15 @@ BitVector UndefElements; ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements); - // BuildVectors can truncate their operands. Ignore that case here. - if (CN && (UndefElements.none() || AllowUndefs) && - CN->getValueType(0) == N.getValueType().getScalarType()) - return CN; + // BuildVectors can truncate their operands. Ignore that case here unless + // AllowTruncation is set. + if (CN && (UndefElements.none() || AllowUndefs)) { + EVT CVT = CN->getValueType(0); + EVT NSVT = N.getValueType().getScalarType(); + assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension"); + if (AllowTruncation || (CVT == NSVT)) + return CN; + } } return nullptr; diff --git a/llvm/test/CodeGen/AArch64/sat-add.ll b/llvm/test/CodeGen/AArch64/sat-add.ll --- a/llvm/test/CodeGen/AArch64/sat-add.ll +++ b/llvm/test/CodeGen/AArch64/sat-add.ll @@ -364,8 +364,7 @@ ; CHECK-NEXT: movi v1.16b, #42 ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v1.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %c = icmp ugt <16 x i8> %x, %a @@ -380,8 +379,7 @@ ; CHECK-NEXT: movi v2.16b, #213 ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v2.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %c = 
icmp ugt <16 x i8> %x, @@ -409,8 +407,7 @@ ; CHECK-NEXT: movi v1.8h, #42 ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v1.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %c = icmp ugt <8 x i16> %x, %a @@ -425,8 +422,7 @@ ; CHECK-NEXT: mvni v2.8h, #42 ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v2.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %c = icmp ugt <8 x i16> %x, @@ -545,8 +541,7 @@ ; CHECK: // %bb.0: ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v1.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %y %c = icmp ugt <16 x i8> %x, %a @@ -560,8 +555,7 @@ ; CHECK-NEXT: mvn v2.16b, v1.16b ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v2.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %noty = xor <16 x i8> %y, %a = add <16 x i8> %x, %y @@ -589,8 +583,7 @@ ; CHECK: // %bb.0: ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v1.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %y %c = icmp ugt <8 x i16> %x, %a @@ -604,8 +597,7 @@ ; CHECK-NEXT: mvn v2.16b, v1.16b ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v2.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %noty = xor <8 x i16> %y, %a = add <8 x i16> %x, %y