diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h @@ -0,0 +1,58 @@ +//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_ +#define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_ + +namespace mlir { +class RewritePatternSet; + +namespace vector { + +/// Populate `patterns` with the following patterns. +/// +/// [VectorInsertStridedSliceOpDifferentRankRewritePattern] +/// ======================================================= +/// RewritePattern for InsertStridedSliceOp where source and destination vectors +/// have different ranks. +/// +/// When ranks are different, InsertStridedSlice needs to extract a properly +/// ranked vector from the destination vector into which to insert. This pattern +/// only takes care of this extraction part and forwards the rest to +/// [VectorInsertStridedSliceOpSameRankRewritePattern]. +/// +/// For a k-D source and n-D destination vector (k < n), we emit: +/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to +/// insert the k-D source. +/// 2. k-D -> (n-1)-D InsertStridedSlice op +/// 3. InsertOp that is the reverse of 1. +/// +/// [VectorInsertStridedSliceOpSameRankRewritePattern] +/// ================================================== +/// RewritePattern for InsertStridedSliceOp where source and destination vectors +/// have the same rank. For each outermost index in the slice: +/// begin end stride +/// [offset : offset+size*stride : stride] +/// 1. 
ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. +/// 2. InsertStridedSlice (k-1)-D into (n-1)-D +/// 3. InsertOp (the reverse of 1.) so that the destination subvector is +/// inserted back in the proper place. +/// +/// [VectorExtractStridedSliceOpRewritePattern] +/// =========================================== +/// Progressive lowering of ExtractStridedSliceOp to either: +/// 1. single offset extract as a direct vector::ShuffleOp. +/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + +/// InsertOp/InsertElementOp for the n-D case. +void populateVectorInsertExtractStridedSliceTransforms( + RewritePatternSet &patterns); + +} // namespace vector +} // namespace mlir + +#endif // DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_ diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -24,13 +24,6 @@ class IfOp; } // namespace scf -/// Collect a set of patterns to convert from the Vector dialect to itself. -/// Should be merged with populateVectorToSCFLoweringPattern. -void populateVectorToVectorConversionPatterns( - MLIRContext *context, RewritePatternSet &patterns, - ArrayRef coarseVectorShape = {}, - ArrayRef fineVectorShape = {}); - namespace vector { /// Options that control the vector unrolling. 
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_ #define MLIR_DIALECT_VECTOR_VECTORUTILS_H_ +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" @@ -184,6 +185,11 @@ bool checkSameValueWAW(vector::TransferWriteOp write, vector::TransferWriteOp priorWrite); +// Helper that returns a subset of `arrayAttr` as a vector of int64_t. +SmallVector getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront = 0, + unsigned dropBack = 0); + namespace matcher { /// Matches vector.transfer_read, vector.transfer_write and ops that return a diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/MathExtras.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" @@ -52,17 +53,6 @@ rewriter.getI64ArrayAttr(pos)); } -// Helper that picks the proper sequence for inserting. -static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, - Value into, int64_t offset) { - auto vectorType = into.getType().cast(); - if (vectorType.getRank() > 1) - return rewriter.create(loc, from, into, offset); - return rewriter.create( - loc, vectorType, from, into, - rewriter.create(loc, offset)); -} - // Helper that picks the proper sequence for extracting. 
static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, @@ -79,32 +69,6 @@ rewriter.getI64ArrayAttr(pos)); } -// Helper that picks the proper sequence for extracting. -static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, - int64_t offset) { - auto vectorType = vector.getType().cast(); - if (vectorType.getRank() > 1) - return rewriter.create(loc, vector, offset); - return rewriter.create( - loc, vectorType.getElementType(), vector, - rewriter.create(loc, offset)); -} - -// Helper that returns a subset of `arrayAttr` as a vector of int64_t. -// TODO: Better support for attribute subtype forwarding + slicing. -static SmallVector getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront = 0, - unsigned dropBack = 0) { - assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); - auto range = arrayAttr.getAsRange(); - SmallVector res; - res.reserve(arrayAttr.size() - dropFront - dropBack); - for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; - it != eit; ++it) - res.push_back((*it).getValue().getSExtValue()); - return res; -} - // Helper that returns data layout alignment of a memref. LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align) { @@ -813,132 +777,6 @@ } }; -// When ranks are different, InsertStridedSlice needs to extract a properly -// ranked vector from the destination vector into which to insert. This pattern -// only takes care of this part and forwards the rest of the conversion to -// another pattern that converts InsertStridedSlice for operands of the same -// rank. -// -// RewritePattern for InsertStridedSliceOp where source and destination vectors -// have different ranks. In this case: -// 1. the proper subvector is extracted from the destination vector -// 2. a new InsertStridedSlice op is created to insert the source in the -// destination subvector -// 3. 
the destination subvector is inserted back in the proper place -// 4. the op is replaced by the result of step 3. -// The new InsertStridedSlice from step 2. will be picked up by a -// `VectorInsertStridedSliceOpSameRankRewritePattern`. -class VectorInsertStridedSliceOpDifferentRankRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto srcType = op.getSourceVectorType(); - auto dstType = op.getDestVectorType(); - - if (op.offsets().getValue().empty()) - return failure(); - - auto loc = op.getLoc(); - int64_t rankDiff = dstType.getRank() - srcType.getRank(); - assert(rankDiff >= 0); - if (rankDiff == 0) - return failure(); - - int64_t rankRest = dstType.getRank() - rankDiff; - // Extract / insert the subvector of matching rank and InsertStridedSlice - // on it. - Value extracted = - rewriter.create(loc, op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); - // A different pattern will kick in for InsertStridedSlice with matching - // ranks. - auto stridedSliceInnerOp = rewriter.create( - loc, op.source(), extracted, - getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), - getI64SubArray(op.strides(), /*dropFront=*/0)); - rewriter.replaceOpWithNewOp( - op, stridedSliceInnerOp.getResult(), op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); - return success(); - } -}; - -// RewritePattern for InsertStridedSliceOp where source and destination vectors -// have the same rank. In this case, we reduce -// 1. the proper subvector is extracted from the destination vector -// 2. a new InsertStridedSlice op is created to insert the source in the -// destination subvector -// 3. the destination subvector is inserted back in the proper place -// 4. the op is replaced by the result of step 3. -// The new InsertStridedSlice from step 2. 
will be picked up by a -// `VectorInsertStridedSliceOpSameRankRewritePattern`. -class VectorInsertStridedSliceOpSameRankRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - void initialize() { - // This pattern creates recursive InsertStridedSliceOp, but the recursion is - // bounded as the rank is strictly decreasing. - setHasBoundedRewriteRecursion(); - } - - LogicalResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto srcType = op.getSourceVectorType(); - auto dstType = op.getDestVectorType(); - - if (op.offsets().getValue().empty()) - return failure(); - - int64_t rankDiff = dstType.getRank() - srcType.getRank(); - assert(rankDiff >= 0); - if (rankDiff != 0) - return failure(); - - if (srcType == dstType) { - rewriter.replaceOp(op, op.source()); - return success(); - } - - int64_t offset = - op.offsets().getValue().front().cast().getInt(); - int64_t size = srcType.getShape().front(); - int64_t stride = - op.strides().getValue().front().cast().getInt(); - - auto loc = op.getLoc(); - Value res = op.dest(); - // For each slice of the source vector along the most major dimension. - for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; - off += stride, ++idx) { - // 1. extract the proper subvector (or element) from source - Value extractedSource = extractOne(rewriter, loc, op.source(), idx); - if (extractedSource.getType().isa()) { - // 2. If we have a vector, extract the proper subvector from destination - // Otherwise we are at the element level and no need to recurse. - Value extractedDest = extractOne(rewriter, loc, op.dest(), off); - // 3. Reduce the problem to lowering a new InsertStridedSlice op with - // smaller rank. - extractedSource = rewriter.create( - loc, extractedSource, extractedDest, - getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - } - // 4. 
Insert the extractedSource into the res vector. - res = insertOne(rewriter, loc, extractedSource, res, off); - } - - rewriter.replaceOp(op, res); - return success(); - } -}; - /// Returns the strides if the memory underlying `memRefType` has a contiguous /// static layout. static llvm::Optional> @@ -1189,67 +1027,6 @@ } }; -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. express single offset extract as a direct shuffle. -/// 2. extract + lower rank strided_slice + insert for the n-D case. -class VectorExtractStridedSliceOpConversion - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - void initialize() { - // This pattern creates recursive ExtractStridedSliceOp, but the recursion - // is bounded as the rank is strictly decreasing. - setHasBoundedRewriteRecursion(); - } - - LogicalResult matchAndRewrite(ExtractStridedSliceOp op, - PatternRewriter &rewriter) const override { - auto dstType = op.getType(); - - assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); - - int64_t offset = - op.offsets().getValue().front().cast().getInt(); - int64_t size = op.sizes().getValue().front().cast().getInt(); - int64_t stride = - op.strides().getValue().front().cast().getInt(); - - auto loc = op.getLoc(); - auto elemType = dstType.getElementType(); - assert(elemType.isSignlessIntOrIndexOrFloat()); - - // Single offset can be more efficiently shuffled. - if (op.offsets().getValue().size() == 1) { - SmallVector offsets; - offsets.reserve(size); - for (int64_t off = offset, e = offset + size * stride; off < e; - off += stride) - offsets.push_back(off); - rewriter.replaceOpWithNewOp(op, dstType, op.vector(), - op.vector(), - rewriter.getI64ArrayAttr(offsets)); - return success(); - } - - // Extract/insert on a lower ranked extract strided slice op. 
- Value zero = rewriter.create( - loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = rewriter.create(loc, dstType, zero); - for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; - off += stride, ++idx) { - Value one = extractOne(rewriter, loc, op.vector(), off); - Value extracted = rewriter.create( - loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.sizes(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - res = insertOne(rewriter, loc, extracted, res, idx); - } - rewriter.replaceOp(op, res); - return success(); - } -}; - } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. @@ -1257,10 +1034,8 @@ LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.add(ctx); + patterns.add(ctx); + populateVectorInsertExtractStridedSliceTransforms(patterns); patterns.add(converter, reassociateFPReductions); patterns .add(); + if (vectorType.getRank() > 1) + return rewriter.create(loc, from, into, offset); + return rewriter.create( + loc, vectorType, from, into, + rewriter.create(loc, offset)); +} + +// Helper that picks the proper sequence for extracting. +static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, + int64_t offset) { + auto vectorType = vector.getType().cast(); + if (vectorType.getRank() > 1) + return rewriter.create(loc, vector, offset); + return rewriter.create( + loc, vectorType.getElementType(), vector, + rewriter.create(loc, offset)); +} + +/// RewritePattern for InsertStridedSliceOp where source and destination vectors +/// have different ranks. +/// +/// When ranks are different, InsertStridedSlice needs to extract a properly +/// ranked vector from the destination vector into which to insert. 
This pattern +/// only takes care of this extraction part and forwards the rest to +/// [VectorInsertStridedSliceOpSameRankRewritePattern]. +/// +/// For a k-D source and n-D destination vector (k < n), we emit: +/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to +/// insert the k-D source. +/// 2. k-D -> (n-1)-D InsertStridedSlice op +/// 3. InsertOp that is the reverse of 1. +class VectorInsertStridedSliceOpDifferentRankRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto srcType = op.getSourceVectorType(); + auto dstType = op.getDestVectorType(); + + if (op.offsets().getValue().empty()) + return failure(); + + auto loc = op.getLoc(); + int64_t rankDiff = dstType.getRank() - srcType.getRank(); + assert(rankDiff >= 0); + if (rankDiff == 0) + return failure(); + + int64_t rankRest = dstType.getRank() - rankDiff; + // Extract / insert the subvector of matching rank and InsertStridedSlice + // on it. + Value extracted = + rewriter.create(loc, op.dest(), + getI64SubArray(op.offsets(), /*dropFront=*/0, + /*dropBack=*/rankRest)); + + // A different pattern will kick in for InsertStridedSlice with matching + // ranks. + auto stridedSliceInnerOp = rewriter.create( + loc, op.source(), extracted, + getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), + getI64SubArray(op.strides(), /*dropFront=*/0)); + + rewriter.replaceOpWithNewOp( + op, stridedSliceInnerOp.getResult(), op.dest(), + getI64SubArray(op.offsets(), /*dropFront=*/0, + /*dropBack=*/rankRest)); + return success(); + } +}; + +/// RewritePattern for InsertStridedSliceOp where source and destination vectors +/// have the same rank. For each outermost index in the slice: +/// begin end stride +/// [offset : offset+size*stride : stride] +/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. +/// 2. 
InsertStridedSlice (k-1)-D into (n-1)-D +/// 3. InsertOp (the reverse of 1.) so that the destination subvector is +/// inserted back in the proper place. +class VectorInsertStridedSliceOpSameRankRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + void initialize() { + // This pattern creates recursive InsertStridedSliceOp, but the recursion is + // bounded as the rank is strictly decreasing. + setHasBoundedRewriteRecursion(); + } + + LogicalResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto srcType = op.getSourceVectorType(); + auto dstType = op.getDestVectorType(); + + if (op.offsets().getValue().empty()) + return failure(); + + int64_t rankDiff = dstType.getRank() - srcType.getRank(); + assert(rankDiff >= 0); + if (rankDiff != 0) + return failure(); + + if (srcType == dstType) { + rewriter.replaceOp(op, op.source()); + return success(); + } + + int64_t offset = + op.offsets().getValue().front().cast().getInt(); + int64_t size = srcType.getShape().front(); + int64_t stride = + op.strides().getValue().front().cast().getInt(); + + auto loc = op.getLoc(); + Value res = op.dest(); + // For each slice of the source vector along the most major dimension. + for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; + off += stride, ++idx) { + // 1. extract the proper subvector (or element) from source + Value extractedSource = extractOne(rewriter, loc, op.source(), idx); + if (extractedSource.getType().isa()) { + // 2. If we have a vector, extract the proper subvector from destination + // Otherwise we are at the element level and no need to recurse. + Value extractedDest = extractOne(rewriter, loc, op.dest(), off); + // 3. Reduce the problem to lowering a new InsertStridedSlice op with + // smaller rank. 
+ extractedSource = rewriter.create( + loc, extractedSource, extractedDest, + getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); + } + // 4. Insert the extractedSource into the res vector. + res = insertOne(rewriter, loc, extractedSource, res, off); + } + + rewriter.replaceOp(op, res); + return success(); + } +}; + +/// Progressive lowering of ExtractStridedSliceOp to either: +/// 1. single offset extract as a direct vector::ShuffleOp. +/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + +/// InsertOp/InsertElementOp for the n-D case. +class VectorExtractStridedSliceOpRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + void initialize() { + // This pattern creates recursive ExtractStridedSliceOp, but the recursion + // is bounded as the rank is strictly decreasing. + setHasBoundedRewriteRecursion(); + } + + LogicalResult matchAndRewrite(ExtractStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + + assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); + + int64_t offset = + op.offsets().getValue().front().cast().getInt(); + int64_t size = op.sizes().getValue().front().cast().getInt(); + int64_t stride = + op.strides().getValue().front().cast().getInt(); + + auto loc = op.getLoc(); + auto elemType = dstType.getElementType(); + assert(elemType.isSignlessIntOrIndexOrFloat()); + + // Single offset can be more efficiently shuffled. + if (op.offsets().getValue().size() == 1) { + SmallVector offsets; + offsets.reserve(size); + for (int64_t off = offset, e = offset + size * stride; off < e; + off += stride) + offsets.push_back(off); + rewriter.replaceOpWithNewOp(op, dstType, op.vector(), + op.vector(), + rewriter.getI64ArrayAttr(offsets)); + return success(); + } + + // Extract/insert on a lower ranked extract strided slice op. 
+ Value zero = rewriter.create( + loc, elemType, rewriter.getZeroAttr(elemType)); + Value res = rewriter.create(loc, dstType, zero); + for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; + off += stride, ++idx) { + Value one = extractOne(rewriter, loc, op.vector(), off); + Value extracted = rewriter.create( + loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.sizes(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); + res = insertOne(rewriter, loc, extracted, res, idx); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + +/// Populate `patterns` with the insert/extract strided slice rewrite patterns. +void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2204,20 +2204,6 @@ } }; -// Helper that returns a subset of `arrayAttr` as a vector of int64_t. -static SmallVector getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront = 0, - unsigned dropBack = 0) { - assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); - auto range = arrayAttr.getAsRange(); - SmallVector res; - res.reserve(arrayAttr.size() - dropFront - dropBack); - for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; - it != eit; ++it) - res.push_back((*it).getValue().getSExtValue()); - return res; -} - // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to // BroadcastOp(ExtractStrideSliceOp). 
class StridedSliceBroadcast final diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1034,10 +1034,11 @@ }; /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D -/// vectors progressively on the way from targeting llvm.matrix intrinsics. +/// vectors progressively. /// This iterates over the most major dimension of the 2-D vector and performs /// rewrites into: -/// vector.strided_slice from 1-D + vector.insert into 2-D +/// vector.extract_strided_slice from 1-D + vector.insert into 2-D +/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. class ShapeCastOp2DUpCastRewritePattern : public OpRewritePattern { public: diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -362,3 +362,16 @@ priorWrite.getVectorType() == write.getVectorType() && priorWrite.permutation_map() == write.permutation_map(); } + +SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront, + unsigned dropBack) { + assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); + auto range = arrayAttr.getAsRange(); + SmallVector res; + res.reserve(arrayAttr.size() - dropFront - dropBack); + for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; + it != eit; ++it) + res.push_back((*it).getValue().getSExtValue()); + return res; +}