diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -285,6 +285,17 @@
 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
+/// ops into a chain of Extract ops to extract each element from the source, and
+/// then a chain of Insert ops to insert them into the result vector.
+///
+/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
+/// which `controlFn` returns true. Otherwise it runs on all ops.
+void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -231,6 +232,53 @@
   }
 };
 
+/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
+/// to extract each element from the source, and then a chain of Insert ops
+/// to insert them into the result vector.
+class Convert1DExtractStridedSliceIntoExtractInsertChain final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  Convert1DExtractStridedSliceIntoExtractInsertChain(
+      MLIRContext *context,
+      std::function<bool(ExtractStridedSliceOp)> controlFn,
+      PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (controlFn && !controlFn(op))
+      return failure();
+
+    // Only handle 1-D slices (a single offset/size/stride triple).
+    if (op.getOffsets().getValue().size() != 1)
+      return failure();
+
+    int64_t offset =
+        op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size =
+        op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> elements;
+    elements.reserve(size);
+    for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
+      elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(op.getType()));
+    for (int64_t i = 0; i < size; ++i)
+      result = rewriter.create<InsertOp>(loc, elements[i], result, i);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  std::function<bool(ExtractStridedSliceOp)> controlFn;
+};
+
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
@@ -285,14 +333,22 @@
   }
 };
 
-void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
+void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(ExtractStridedSliceOp)> controlFn,
+    PatternBenefit benefit) {
+  patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
+      patterns.getContext(), std::move(controlFn), benefit);
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
-void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+void vector::populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
                                                                benefit);
diff --git a/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir b/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-extract-strided-slice-lowering.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -split-input-file -test-vector-extract-strided-slice-lowering %s | FileCheck %s
+
+// CHECK-LABEL: func.func @extract_strided_slice_1D
+// CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @extract_strided_slice_1D(%input: vector<8xf16>) -> vector<4xf16> {
+  %0 = vector.extract_strided_slice %input {offsets = [1], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  return %0: vector<4xf16>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16>
+// CHECK: %[[E0:.+]] = vector.extract %[[INPUT]][1] : vector<8xf16>
+// CHECK: %[[E1:.+]] = vector.extract %[[INPUT]][2] : vector<8xf16>
+// CHECK: %[[E2:.+]] = vector.extract %[[INPUT]][3] : vector<8xf16>
+// CHECK: %[[E3:.+]] = vector.extract %[[INPUT]][4] : vector<8xf16>
+// CHECK: %[[I0:.+]] = vector.insert %[[E0]], %[[INIT]] [0] : f16 into vector<4xf16>
+// CHECK: %[[I1:.+]] = vector.insert %[[E1]], %[[I0]] [1] : f16 into vector<4xf16>
+// CHECK: %[[I2:.+]] = vector.insert %[[E2]], %[[I1]] [2] : f16 into vector<4xf16>
+// CHECK: %[[I3:.+]] = vector.insert %[[E3]], %[[I2]] [3] : f16 into vector<4xf16>
+// CHECK: return %[[I3]]
+
+
+// -----
+
+// CHECK-LABEL: func.func @extract_strided_slice_2D
+func.func @extract_strided_slice_2D(%input: vector<1x8xf16>) -> vector<1x4xf16> {
+  // CHECK: vector.extract_strided_slice
+  %0 = vector.extract_strided_slice %input {offsets = [0, 1], sizes = [1, 4], strides = [1, 1]} : vector<1x8xf16> to vector<1x4xf16>
+  return %0: vector<1x4xf16>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -785,6 +786,26 @@
   }
 };
 
+struct TestVectorExtractStridedSliceLowering
+    : public PassWrapper<TestVectorExtractStridedSliceLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorExtractStridedSliceLowering)
+
+  StringRef getArgument() const final {
+    return "test-vector-extract-strided-slice-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that converts vector.extract_strided_slice "
+           "into a chain of vector.extract and vector.insert ops";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -819,6 +840,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestVectorExtractStridedSliceLowering>();
 }
 } // namespace test
 } // namespace mlir