diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -241,8 +241,8 @@ /// Populate `patterns` with the following patterns. /// -/// [VectorInsertStridedSliceOpDifferentRankRewritePattern] -/// ======================================================= +/// [DecomposeDifferentRankInsertStridedSlice] +/// ========================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have different ranks. /// @@ -257,8 +257,19 @@ /// 2. k-D -> (n-1)-D InsertStridedSlice op /// 3. InsertOp that is the reverse of 1. /// -/// [VectorInsertStridedSliceOpSameRankRewritePattern] -/// ================================================== +/// [DecomposeNDExtractStridedSlice] +/// ================================ +/// An n-D ExtractStridedSliceOp is rewritten to ExtractOp/ExtractElementOp + +/// lower rank ExtractStridedSliceOp + InsertOp/InsertElementOp. +void populateVectorInsertExtractStridedSliceDecompositionPatterns( + RewritePatternSet &patterns); + +/// Populate `patterns` with the following patterns. +/// +/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns(); +/// +/// [ConvertSameRankInsertStridedSliceIntoShuffle] +/// ============================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have the same rank. For each outermost index in the slice: /// begin end stride @@ -268,12 +279,9 @@ /// 3. the destination subvector is inserted back in the proper place /// 3. InsertOp that is the reverse of 1. /// -/// [VectorExtractStridedSliceOpRewritePattern] -/// =========================================== -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. 
single offset extract as a direct vector::ShuffleOp. -/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + -/// InsertOp/InsertElementOp for the n-D case. +/// [Convert1DExtractStridedSliceIntoShuffle] +/// ========================================= +/// A 1-D ExtractStridedSliceOp (single offset) is lowered to a ShuffleOp. void populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -45,14 +45,14 @@ /// When ranks are different, InsertStridedSlice needs to extract a properly /// ranked vector from the destination vector into which to insert. This pattern /// only takes care of this extraction part and forwards the rest to -/// [VectorInsertStridedSliceOpSameRankRewritePattern]. +/// [ConvertSameRankInsertStridedSliceIntoShuffle]. /// /// For a k-D source and n-D destination vector (k < n), we emit: /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to /// insert the k-D source. /// 2. k-D -> (n-1)-D InsertStridedSlice op /// 3. InsertOp that is the reverse of 1. -class VectorInsertStridedSliceOpDifferentRankRewritePattern +class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -102,7 +102,7 @@ /// 2. InsertStridedSlice (k-1)-D into (n-1)-D /// 3. the destination subvector is inserted back in the proper place /// 3. InsertOp that is the reverse of 1. 
-class VectorInsertStridedSliceOpSameRankRewritePattern +class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -193,11 +193,50 @@ } }; -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. single offset extract as a direct vector::ShuffleOp. -/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + -/// InsertOp/InsertElementOp for the n-D case. -class VectorExtractStridedSliceOpRewritePattern +/// RewritePattern for ExtractStridedSliceOp where source and destination +/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp. +class Convert1DExtractStridedSliceIntoShuffle + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + + assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); + + int64_t offset = + op.getOffsets().getValue().front().cast().getInt(); + int64_t size = + op.getSizes().getValue().front().cast().getInt(); + int64_t stride = + op.getStrides().getValue().front().cast().getInt(); + + auto loc = op.getLoc(); + auto elemType = dstType.getElementType(); + assert(elemType.isSignlessIntOrIndexOrFloat()); + + // Single offset can be more efficiently shuffled. + if (op.getOffsets().getValue().size() != 1) + return failure(); + + SmallVector offsets; + offsets.reserve(size); + for (int64_t off = offset, e = offset + size * stride; off < e; + off += stride) + offsets.push_back(off); + rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), + op.getVector(), + rewriter.getI64ArrayAttr(offsets)); + return success(); + } +}; + +/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D. +/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower +/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. 
+class DecomposeNDExtractStridedSlice : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -225,18 +264,10 @@ auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); - // Single offset can be more efficiently shuffled. - if (op.getOffsets().getValue().size() == 1) { - SmallVector offsets; - offsets.reserve(size); - for (int64_t off = offset, e = offset + size * stride; off < e; - off += stride) - offsets.push_back(off); - rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), - op.getVector(), - rewriter.getI64ArrayAttr(offsets)); - return success(); - } + // Single offset can be more efficiently shuffled. It's handled in + // Convert1DExtractStridedSliceIntoShuffle. + if (op.getOffsets().getValue().size() == 1) + return failure(); // Extract/insert on a lower ranked extract strided slice op. Value zero = rewriter.create( @@ -256,11 +287,16 @@ } }; +void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns); + patterns.add(patterns.getContext()); }