Page MenuHomePhabricator

No OneTemporary

File Metadata

Created
Fri, Jan 24, 4:27 PM
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index 7234d46..e6c97fd 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -43,6 +43,17 @@ void populateVectorToVectorCanonicalizationPatterns(
void populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
+/// Collect a set of vector slices transformation patterns:
+/// ExtractSlicesOpLowering, InsertSlicesOpLowering
+/// Useful for clients that want to express all vector "slices"
+/// ops in terms of more elementary vector "slice" ops. If all
+/// "produced" tuple values are "consumed" (the most common
+/// use for "slices" ops), this lowering removes all tuple related
+/// operations as well (through DCE and folding). If tuple values
+/// "leak" coming in, however, some tuple related ops will remain.
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 00ed27a..1cc9419 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -13,6 +13,7 @@
#include <type_traits>
#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/Dialect/VectorOps/VectorTransforms.h"
#include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -28,6 +29,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/CommandLine.h"
@@ -657,6 +659,131 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
}
};
+/// Progressive lowering of ExtractSlicesOp to tuple of StridedSliceOp.
+/// One:
+/// %x = vector.extract_slices %0
+/// is replaced by:
+/// %a = vector.strided_slice %0
+/// %b = vector.strided_slice %0
+/// ..
+/// %x = vector.tuple %a, %b, ..
+class ExtractSlicesOpLowering
+ : public OpRewritePattern<vector::ExtractSlicesOp> {
+public:
+ using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
+
+ // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+ PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ VectorType vectorType = op.getSourceVectorType();
+ int64_t rank = vectorType.getRank();
+ auto shape = vectorType.getShape();
+
+ SmallVector<int64_t, 4> sizes;
+ op.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ op.getStrides(strides); // all-ones at the moment
+
+ // Compute the number of slices in each dimension (ceilDiv: a boundary slice counts too).
+ SmallVector<int64_t, 4> sliceDimCounts(rank);
+ for (int64_t r = 0; r < rank; ++r)
+ sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+ // For each tuple element, generate the strided slice of the source vector it covers.
+ auto basis = computeStrides(sliceDimCounts);
+ TupleType tupleType = op.getResultTupleType();
+ int64_t tupleSize = tupleType.size();
+ SmallVector<Value, 4> tupleValues(tupleSize);
+ for (int64_t i = 0; i < tupleSize; ++i) {
+ // De-linearize tuple index 'i' w.r.t. 'basis' into per-dimension slice coordinates.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto elementOffsets = mlir::functional::zipMap(
+ [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+ // Compute the size of each slice; slices at the upper boundary may be smaller than 'sizes'.
+ SmallVector<int64_t, 4> sliceSizes(rank);
+ for (int64_t r = 0; r < rank; ++r)
+ sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+ // Materialize the slice and record it as tuple element 'i'.
+ tupleValues[i] = rewriter.create<vector::StridedSliceOp>(
+ loc, op.vector(), elementOffsets, sliceSizes, strides);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
+ return matchSuccess();
+ }
+};
+
+/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
+/// One:
+/// %x = vector.insert_slices %0
+/// is replaced by:
+/// %r0 = vector.splat 0
+/// %t1 = vector.tuple_get %0, 0
+/// %r1 = vector.insert_strided_slice %r0, %t1
+/// %t2 = vector.tuple_get %0, 1
+/// %r2 = vector.insert_strided_slice %r1, %t2
+/// ..
+/// %x = ..
+class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
+public:
+ using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
+
+ // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+ PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ VectorType vectorType = op.getResultVectorType();
+ int64_t rank = vectorType.getRank();
+
+ SmallVector<int64_t, 4> sizes;
+ op.getSizes(sizes);
+ SmallVector<int64_t, 4> strides;
+ op.getStrides(strides); // all-ones at the moment
+
+ // Compute the number of slices in each dimension (ceilDiv: a boundary slice counts too).
+ auto shape = vectorType.getShape();
+ SmallVector<int64_t, 4> sliceDimCounts(rank);
+ for (int64_t r = 0; r < rank; ++r)
+ sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+ // Prepare the result: an all-zero vector of the full result type into
+ // which each tuple element is inserted.
+ auto elemType = vectorType.getElementType();
+ Value zero = rewriter.create<ConstantOp>(loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+
+ // For each element in the tuple, insert the proper strided slice.
+ auto basis = computeStrides(sliceDimCounts);
+ TupleType tupleType = op.getSourceTupleType();
+ int64_t tupleSize = tupleType.size();
+ for (int64_t i = 0; i < tupleSize; ++i) {
+ // De-linearize tuple index 'i' w.r.t. 'basis' into per-dimension slice coordinates.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto elementOffsets = mlir::functional::zipMap(
+ [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+ // Extract tuple element 'i' and insert it into the result. Unlike the
+ // extract lowering, no slice sizes are needed here: each slice's size
+ // is implied by the type of the corresponding tuple element.
+ auto index = rewriter.getI64IntegerAttr(i);
+ auto tupleGet = rewriter.create<vector::TupleGetOp>(
+ loc, tupleType.getType(i), op.getOperand(), index);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, tupleGet, result, elementOffsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
+ return matchSuccess();
+ }
+};
+
} // namespace
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -666,3 +793,8 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
context);
}
+
+// Implements the declaration added to VectorOps.h: registers the progressive
+// lowering patterns for vector.extract_slices and vector.insert_slices.
+void mlir::vector::populateVectorSlicesLoweringPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+}
diff --git a/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir
new file mode 100644
index 0000000..8936865
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -test-vector-slices-conversion | FileCheck %s
+
+// Only tuple element 0 is consumed, so after lowering and folding a single
+// strided_slice remains and all tuple-related ops are removed.
+// CHECK-LABEL: func @extract_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[SS:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: return %[[SS]]
+
+func @extract_slices(%arg0: vector<3x3xf32>) -> vector<2x2xf32> {
+ %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+ : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+ %1 = vector.tuple_get %0, 0 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+ return %1 : vector<2x2xf32>
+}
+
+// Insert lowering starts from a zero splat (folded to a dense constant) and
+// chains one insert_strided_slice per tuple element.
+// CHECK-LABEL: func @insert_slices(%arg0: vector<2x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<1x2xf32>, %arg3: vector<1x1xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %arg0, %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %arg1, %[[I0]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %arg2, %[[I1]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %arg3, %[[I2]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[I3]]
+
+func @insert_slices(%arg0: vector<2x2xf32>,
+ %arg1: vector<2x1xf32>,
+ %arg2: vector<1x2xf32>,
+ %arg3: vector<1x1xf32>) -> vector<3x3xf32> {
+ %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
+ : vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>
+ %1 = vector.insert_slices %0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+ return %1 : vector<3x3xf32>
+}
+
+// Round trip: extract_slices feeding insert_slices; tuple/tuple_get fold away,
+// leaving only strided_slice and insert_strided_slice ops.
+// CHECK-LABEL: func @extract_insert_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+// CHECK: %[[X2:.*]] = vector.strided_slice %arg0 {offsets = [2, 0], sizes = [1, 2], strides = [1, 1]}
+// CHECK: %[[X3:.*]] = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]}
+// CHECK: %[[X4:.*]] = vector.insert_strided_slice %[[X0]], %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[X5:.*]] = vector.insert_strided_slice %[[X1]], %[[X4]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[X6:.*]] = vector.insert_strided_slice %[[X2]], %[[X5]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[X7:.*]] = vector.insert_strided_slice %[[X3]], %[[X6]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[X7]]
+
+func @extract_insert_slices(%arg0: vector<3x3xf32>) -> vector<3x3xf32> {
+ %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+ : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+ %1 = vector.insert_slices %0, [2, 2], [1, 1]
+ : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+ return %1 : vector<3x3xf32>
+}
+
+// The tuple value escapes via the return, so the vector.tuple op must remain
+// after lowering (the "leak" case described in the pattern documentation).
+// CHECK-LABEL: func @extract_slices_tuple_leaks(%arg0: vector<4xf32>)
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]}
+// CHECK: %[[X2:.*]] = vector.tuple %[[X0]], %[[X1]]
+// CHECK: return %[[X2]]
+
+func @extract_slices_tuple_leaks(%arg0: vector<4xf32>) -> tuple<vector<2xf32>, vector<2xf32>> {
+ %0 = vector.extract_slices %arg0, [2], [1] : vector<4xf32> into tuple<vector<2xf32>, vector<2xf32>>
+ return %0 : tuple<vector<2xf32>, vector<2xf32>>
+}
+
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 664d49a..6f529fd 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -18,6 +18,7 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
+
#include "TestVectorTransformPatterns.h.inc"
struct TestVectorToVectorConversion
@@ -31,8 +32,22 @@ struct TestVectorToVectorConversion
applyPatternsGreedily(getFunction(), patterns);
}
};
+
+/// Test pass that greedily applies the vector "slices" lowering patterns
+/// (via populateVectorSlicesLoweringPatterns) to each function.
+struct TestVectorSlicesConversion
+ : public FunctionPass<TestVectorSlicesConversion> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ populateVectorSlicesLoweringPatterns(patterns, &getContext());
+ applyPatternsGreedily(getFunction(), patterns);
+ }
+};
+
} // end anonymous namespace
static PassRegistration<TestVectorToVectorConversion>
pass("test-vector-to-vector-conversion",
"Test conversion patterns between ops in the vector dialect");
+
+static PassRegistration<TestVectorSlicesConversion> slices_pass(
+ "test-vector-slices-conversion",
+ "Test conversion patterns that lower slices ops in the vector dialect");

Event Timeline