diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1125,6 +1125,28 @@
     /// provided shape is `std::nullopt`, the current shape of the type is used.
     VectorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                          Type elementType) const;
+
+    /// Return a new vector type which results from dropping the first \p n
+    /// dimensions of this vector type. Requires \p n <= rank.
+    VectorType dropFrontDims(size_t n) const;
+
+    /// Return a new vector type which results from dropping the last \p n
+    /// dimensions of this vector type. Requires \p n <= rank.
+    VectorType dropBackDims(size_t n) const;
+
+    /// Return a new vector type which results from dropping the first
+    /// dimension of this vector type. Equivalent to `dropFrontDims(1)`.
+    VectorType dropFrontDim() const { return dropFrontDims(1); }
+
+    /// Return a new vector type which results from dropping the last
+    /// dimension of this vector type. Equivalent to `dropBackDims(1)`.
+    VectorType dropBackDim() const { return dropBackDims(1); }
+
+    /// Return a new vector type which results from dropping the first \p n
+    /// dimensions and keeping the \p m dimensions that follow them. Both the
+    /// shape and the scalable-dimension flags are sliced. Requires
+    /// \p n + \p m <= rank.
+    VectorType sliceDims(size_t n, size_t m) const;
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -265,6 +265,25 @@
                          getScalableDims());
 }
 
+VectorType VectorType::dropFrontDims(size_t n) const {
+  assert(size_t(getRank()) >= n && "Dropping more dims than exist");
+  return sliceDims(n, size_t(getRank()) - n);
+}
+
+VectorType VectorType::dropBackDims(size_t n) const {
+  assert(size_t(getRank()) >= n && "Dropping more dims than exist");
+  return sliceDims(0, size_t(getRank()) - n);
+}
+
+VectorType VectorType::sliceDims(size_t n, size_t m) const {
+  // Check bounds here so misuse gets a VectorType-specific message; written
+  // as `m <= rank - n` rather than `n + m <= rank` so it cannot wrap around.
+  assert(n <= size_t(getRank()) && m <= size_t(getRank()) - n &&
+         "Slicing dims out of range");
+  return VectorType::get(getShape().slice(n, m), getElementType(),
+                         getScalableDims().slice(n, m));
+}
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//