diff --git a/llvm/include/llvm/ADT/SmallVectorExtras.h b/llvm/include/llvm/ADT/SmallVectorExtras.h
--- a/llvm/include/llvm/ADT/SmallVectorExtras.h
+++ b/llvm/include/llvm/ADT/SmallVectorExtras.h
@@ -20,6 +20,11 @@
 namespace llvm {
 
 /// Map a range to a SmallVector with element types deduced from the mapping.
+template <unsigned Size, class ContainerTy, class FuncTy>
+auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
+  return to_vector<Size>(
+      map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
+}
 template <class ContainerTy, class FuncTy>
 auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
   return to_vector(
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include <optional>
 
 namespace llvm {
@@ -226,11 +227,11 @@
   AffineMap shiftDims(unsigned shift, unsigned offset = 0) const {
     assert(offset <= getNumDims());
     return AffineMap::get(getNumDims() + shift, getNumSymbols(),
-                          llvm::to_vector<4>(llvm::map_range(
+                          llvm::map_to_vector<4>(
                               getResults(),
                               [&](AffineExpr e) {
                                 return e.shiftDims(getNumDims(), shift, offset);
-                              })),
+                              }),
                           getContext());
   }
 
@@ -238,12 +239,12 @@
   /// by symbols[offset + shift ... shift + numSymbols).
   AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const {
     return AffineMap::get(getNumDims(), getNumSymbols() + shift,
-                          llvm::to_vector<4>(llvm::map_range(
-                              getResults(),
-                              [&](AffineExpr e) {
-                                return e.shiftSymbols(getNumSymbols(), shift,
-                                                      offset);
-                              })),
+                          llvm::map_to_vector<4>(getResults(),
+                                                 [&](AffineExpr e) {
+                                                   return e.shiftSymbols(
+                                                       getNumSymbols(), shift,
+                                                       offset);
+                                                 }),
                           getContext());
   }
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -261,57 +262,56 @@
 }
 
 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](bool v) -> Attribute { return getBoolAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
-  auto attrs = llvm::to_vector<8>(
-      llvm::map_range(values, [this](int64_t v) -> Attribute {
-        return getIntegerAttr(IndexType::get(getContext()), v);
-      }));
+  auto attrs = llvm::map_to_vector<8>(values, [this](int64_t v) -> Attribute {
+    return getIntegerAttr(IndexType::get(getContext()), v);
+  });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](float v) -> Attribute { return getF32FloatAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](double v) -> Attribute { return getF64FloatAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [this](StringRef v) -> Attribute { return getStringAttr(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [](Type v) -> Attribute { return TypeAttr::get(v); });
   return getArrayAttr(attrs);
 }
 
 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
-  auto attrs = llvm::to_vector<8>(llvm::map_range(
-      values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
+  auto attrs = llvm::map_to_vector<8>(
+      values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
   return getArrayAttr(attrs);
 }
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -11,13 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/TypeUtilities.h"
-
-#include <numeric>
-
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -119,8 +118,8 @@
 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
 /// dims are equal. The element type does not matter.
 LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
-  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
-      types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }));
+  auto shapedTypes = llvm::map_to_vector<8>(
+      types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
   // Return failure if some, but not all are not shaped. Return early if none
   // are shaped also.
   if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
@@ -155,10 +154,10 @@
 
   for (unsigned i = 0; i < firstRank; ++i) {
     // Retrieve all ranked dimensions
-    auto dims = llvm::to_vector<8>(llvm::map_range(
+    auto dims = llvm::map_to_vector<8>(
         llvm::make_filter_range(
             shapes, [&](auto shape) { return shape.getRank() >= i; }),
-        [&](auto shape) { return shape.getDimSize(i); }));
+        [&](auto shape) { return shape.getDimSize(i); });
     if (verifyCompatibleDims(dims).failed())
       return failure();
   }