diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1787,12 +1787,20 @@ } /// Returns true if Range consists of the same value repeated multiple times. -template bool is_splat(R &&Range) { +template +LLVM_DEPRECATED( + "Use 'all_equal(Range)' or '!empty(Range) && all_equal(Range)' instead.", + "all_equal") +bool is_splat(R &&Range) { return !llvm::empty(Range) && all_equal(Range); } /// Returns true if Values consists of the same value repeated multiple times. -template bool is_splat(std::initializer_list Values) { +template +LLVM_DEPRECATED( + "Use 'all_equal(Values)' or '!empty(Values) && all_equal(Values)' instead.", + "all_equal") +bool is_splat(std::initializer_list Values) { return is_splat>(std::move(Values)); } diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -4997,7 +4997,7 @@ // value type is same as the input vectors' type. if (auto *OpShuf = dyn_cast(Op0)) if (Q.isUndefValue(Op1) && RetTy == InVecTy && - is_splat(OpShuf->getShuffleMask())) + all_equal(OpShuf->getShuffleMask())) return Op0; // All remaining transformation depend on the value of the mask, which is diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -398,7 +398,7 @@ if (auto *Shuf = dyn_cast(V)) { // FIXME: We can safely allow undefs here. If Index was specified, we will // check that the mask elt is defined at the required index. - if (!is_splat(Shuf->getShuffleMask())) + if (!all_equal(Shuf->getShuffleMask())) return false; // Match any index. @@ -478,7 +478,7 @@ if (SliceFront < 0) { // Negative values (undef or other "sentinel" values) must be equal across // the entire slice. 
- if (!is_splat(MaskSlice)) + if (!all_equal(MaskSlice)) return false; ScaledMask.push_back(SliceFront); } else { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -23713,7 +23713,7 @@ // demanded elements analysis. It is further limited to not change a splat // of an inserted scalar because that may be optimized better by // load-folding or other target-specific behaviors. - if (isConstOrConstSplat(RHS) && Shuf0 && is_splat(Shuf0->getMask()) && + if (isConstOrConstSplat(RHS) && Shuf0 && all_equal(Shuf0->getMask()) && Shuf0->hasOneUse() && Shuf0->getOperand(1).isUndef() && Shuf0->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) { // binop (splat X), (splat C) --> splat (binop X, C) @@ -23722,7 +23722,7 @@ return DAG.getVectorShuffle(VT, DL, NewBinOp, DAG.getUNDEF(VT), Shuf0->getMask()); } - if (isConstOrConstSplat(LHS) && Shuf1 && is_splat(Shuf1->getMask()) && + if (isConstOrConstSplat(LHS) && Shuf1 && all_equal(Shuf1->getMask()) && Shuf1->hasOneUse() && Shuf1->getOperand(1).isUndef() && Shuf1->getOperand(0).getOpcode() != ISD::INSERT_VECTOR_ELT) { // binop (splat C), (splat X) --> splat (binop C, X) diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3287,7 +3287,7 @@ Flags.copyFMF(*FPOp); // Min/max matching is only viable if all output VTs are the same. 
- if (is_splat(ValueVTs)) { + if (all_equal(ValueVTs)) { EVT VT = ValueVTs[0]; LLVMContext &Ctx = *DAG.getContext(); auto &TLI = DAG.getTargetLoweringInfo(); diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -2061,7 +2061,7 @@ return false; if (isa(V1->getType())) - if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !is_splat(Mask)) + if ((Mask[0] != 0 && Mask[0] != UndefMaskElem) || !all_equal(Mask)) return false; return true; @@ -2152,7 +2152,7 @@ Type *ResultTy) { Type *Int32Ty = Type::getInt32Ty(ResultTy->getContext()); if (isa(ResultTy)) { - assert(is_splat(Mask) && "Unexpected shuffle"); + assert(all_equal(Mask) && "Unexpected shuffle"); Type *VecTy = VectorType::get(Int32Ty, Mask.size(), true); if (Mask[0] == 0) return Constant::getNullValue(VecTy); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -12894,7 +12894,7 @@ static bool isSplatShuffle(Value *V) { if (auto *Shuf = dyn_cast(V)) - return is_splat(Shuf->getShuffleMask()); + return all_equal(Shuf->getShuffleMask()); return false; } @@ -20831,7 +20831,7 @@ // All non aggregate members of the type must have the same type SmallVector ValueVTs; ComputeValueVTs(*this, DL, Ty, ValueVTs); - return is_splat(ValueVTs); + return all_equal(ValueVTs); } bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -726,7 +726,7 @@ InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && - 
is_splat(Shuf->getShuffleMask()) && + all_equal(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3141,7 +3141,7 @@ ArrayRef Mask; if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { // Check whether every element of Mask is the same constant - if (is_splat(Mask)) { + if (all_equal(Mask)) { auto *VecTy = cast(SrcType); auto *EltTy = cast(VecTy->getElementType()); if (C->isSplat(EltTy->getBitWidth())) { diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -3166,7 +3166,7 @@ make_filter_range(MP->operands(), ReachableOperandPred); SmallVector OperandList; llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList)); - bool Okay = is_splat(OperandList); + bool Okay = all_equal(OperandList); if (Okay) return singleReachablePHIPath(Visited, cast(OperandList[0]), Second); @@ -3261,7 +3261,7 @@ const MemoryDef *MD = cast(U); return ValueToClass.lookup(MD->getMemoryInst()); }); - assert(is_splat(PhiOpClasses) && + assert(all_equal(PhiOpClasses) && "All MemoryPhi arguments should be in the same class"); } } diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -611,28 +611,6 @@ EXPECT_TRUE(all_equal({1, 1, 1})); } -TEST(STLExtrasTest, IsSplat) { - std::vector V; - EXPECT_FALSE(is_splat(V)); - - V.push_back(1); - EXPECT_TRUE(is_splat(V)); - - V.push_back(1); - V.push_back(1); - 
EXPECT_TRUE(is_splat(V)); - - V.push_back(2); - EXPECT_FALSE(is_splat(V)); -} - -TEST(STLExtrasTest, IsSplatInitializerList) { - EXPECT_TRUE(is_splat({1})); - EXPECT_TRUE(is_splat({1, 1})); - EXPECT_FALSE(is_splat({1, 2})); - EXPECT_TRUE(is_splat({1, 1, 1})); -} - TEST(STLExtrasTest, to_address) { int *V1 = new int; EXPECT_EQ(V1, to_address(V1)); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2432,7 +2432,7 @@ // 1) all operands involved are of shaped type and // 2) the indices are not out of range. class TCopVTEtAreSameAt indices> : CPred< - "::llvm::is_splat(::llvm::map_range(" + "::llvm::all_equal(::llvm::map_range(" "::mlir::ArrayRef({" # !interleave(indices, ", ") # "}), " "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); " "}))">; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -77,8 +77,8 @@ auto getOperandElementType = [](OpOperand *operand) { return operand->get().getType().cast().getElementType(); }; - if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(), - getOperandElementType))) + if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(), + 
getOperandElementType))) return failure(); // We can only handle the case where we have int/float elements. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2871,9 +2871,9 @@ if (resultType.getNumElements() != 2) return emitOpError("expected result struct type containing two members"); - if (!llvm::is_splat({operand1().getType(), operand2().getType(), - resultType.getElementType(0), - resultType.getElementType(1)})) + if (!llvm::all_equal({operand1().getType(), operand2().getType(), + resultType.getElementType(0), + resultType.getElementType(1)})) return emitOpError( "expected all operand types and struct member types are the same"); @@ -2920,9 +2920,9 @@ if (resultType.getNumElements() != 2) return emitOpError("expected result struct type containing two members"); - if (!llvm::is_splat({operand1().getType(), operand2().getType(), - resultType.getElementType(0), - resultType.getElementType(1)})) + if (!llvm::all_equal({operand1().getType(), operand2().getType(), + resultType.getElementType(0), + resultType.getElementType(1)})) return emitOpError( "expected all operand types and struct member types are the same"); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1269,7 +1269,7 @@ // This is an elementwise op, so all transposed operands should have the // same type. We need to additionally check that all transposes uses the // same map. 
- if (!llvm::is_splat(transposeMaps)) + if (!llvm::all_equal(transposeMaps)) return rewriter.notifyMatchFailure(op, "different transpose map"); SmallVector srcValues; diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -102,7 +102,7 @@ } // CHECK-LABEL: OpJAdaptor::verify -// CHECK: ::llvm::is_splat(::llvm::map_range( +// CHECK: ::llvm::all_equal(::llvm::map_range( // CHECK-SAME: ::mlir::ArrayRef({0, 2, 3}), // CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); })) // CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type" diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -41,7 +41,7 @@ if (selectedDialect.empty()) { // If a dialect was not specified, ensure that all found defs belong to the // same dialect. - if (!llvm::is_splat(llvm::map_range( + if (!llvm::all_equal(llvm::map_range( defs, [](const auto &def) { return def.getDialect(); }))) { llvm::PrintFatalError("defs belonging to more than one dialect. Must " "select one via '--(attr|type)defs-dialect'");