diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -43,6 +43,17 @@
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of vector slices transformation patterns:
+///     ExtractSlicesOpLowering, InsertSlicesOpLowering
+/// Useful for clients that want to express all vector "slices"
+/// ops in terms of the more elementary vector "slice" ops. If all
+/// "produced" tuple values are "consumed" (the most common
+/// use for "slices" ops), this lowering removes all tuple-related
+/// operations as well (through DCE and folding). If tuple values
+/// "leak" in or out, however, some tuple-related ops will remain.
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
+                                          MLIRContext *context);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -13,6 +13,7 @@
 #include <type_traits>
 
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/Dialect/VectorOps/VectorTransforms.h"
 #include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -28,6 +29,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 
@@ -657,6 +659,131 @@
   }
 };
 
+/// Progressive lowering of ExtractSlicesOp to a tuple of StridedSliceOp.
+/// One:
+///   %x = vector.extract_slices %0
+/// is replaced by:
+///   %a = vector.strided_slice %0
+///   %b = vector.strided_slice %0
+///   ..
+///   %x = vector.tuple %a, %b, ..
+class ExtractSlicesOpLowering
+    : public OpRewritePattern<vector::ExtractSlicesOp> {
+public:
+  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getSourceVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> sliceDimCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // For each element in the tuple, generate the proper strided slice.
+    auto basis = computeStrides(sliceDimCounts);
+    TupleType tupleType = op.getResultTupleType();
+    int64_t tupleSize = tupleType.size();
+    SmallVector<Value, 4> tupleValues(tupleSize);
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // De-linearize w.r.t. 'basis'.
+      auto vectorOffsets = delinearize(i, basis);
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r)
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      // Insert in tuple.
+      tupleValues[i] = rewriter.create<vector::StridedSliceOp>(
+          loc, op.vector(), elementOffsets, sliceSizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
+    return matchSuccess();
+  }
+};
+
+/// Progressive lowering of InsertSlicesOp to a series of InsertStridedSliceOp.
+/// One:
+///   %x = vector.insert_slices %0
+/// is replaced by:
+///   %r0 = vector.splat 0
+///   %t1 = vector.tuple_get %0, 0
+///   %r1 = vector.insert_strided_slice %r0, %t1
+///   %t2 = vector.tuple_get %0, 1
+///   %r2 = vector.insert_strided_slice %r1, %t2
+///   ..
+///   %x  = ..
+class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
+public:
+  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getResultVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> sliceDimCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // Prepare the result: an all-zero vector of the full result type.
+    auto elemType = vectorType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+
+    // For each element in the tuple, extract the proper strided slice.
+    auto basis = computeStrides(sliceDimCounts);
+    TupleType tupleType = op.getSourceTupleType();
+    int64_t tupleSize = tupleType.size();
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // De-linearize w.r.t. 'basis'.
+      auto vectorOffsets = delinearize(i, basis);
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+      // Extract from the tuple and insert into the result.
+      auto index = rewriter.getI64IntegerAttr(i);
+      auto tupleGet = rewriter.create<vector::TupleGetOp>(
+          loc, tupleType.getType(i), op.getOperand(), index);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, tupleGet, result, elementOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -666,3 +793,8 @@ }
   patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
       context);
 }
+
+void mlir::vector::populateVectorSlicesLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+}
diff --git a/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/vector-slices-transforms.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -test-vector-slices-conversion | FileCheck %s
+
+// CHECK-LABEL: func @extract_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[SS:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: return %[[SS]]
+
+func @extract_slices(%arg0: vector<3x3xf32>) -> vector<2x2xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.tuple_get %0, 0 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  return %1 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @insert_slices(%arg0: vector<2x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<1x2xf32>, %arg3: vector<1x1xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %arg0, %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %arg1, %[[I0]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %arg2, %[[I1]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %arg3, %[[I2]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[I3]]
+
+func @insert_slices(%arg0: vector<2x2xf32>,
+                    %arg1: vector<2x1xf32>,
+                    %arg2: vector<1x2xf32>,
+                    %arg3: vector<1x1xf32>) -> vector<3x3xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
+    : vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>
+  %1 = vector.insert_slices %0, [2, 2], [1, 1]
+    : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+  return %1 : vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @extract_insert_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+// CHECK: %[[X2:.*]] = vector.strided_slice %arg0 {offsets = [2, 0], sizes = [1, 2], strides = [1, 1]}
+// CHECK: %[[X3:.*]] = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]}
+// CHECK: %[[X4:.*]] = vector.insert_strided_slice %[[X0]], %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[X5:.*]] = vector.insert_strided_slice %[[X1]], %[[X4]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[X6:.*]] = vector.insert_strided_slice %[[X2]], %[[X5]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[X7:.*]] = vector.insert_strided_slice %[[X3]], %[[X6]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[X7]]
+
+func @extract_insert_slices(%arg0: vector<3x3xf32>) -> vector<3x3xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+    : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.insert_slices %0, [2, 2], [1, 1]
+    : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+  return %1 : vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @extract_slices_tuple_leaks(%arg0: vector<4xf32>)
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]}
+// CHECK: %[[X2:.*]] = vector.tuple %[[X0]], %[[X1]]
+// CHECK: return %[[X2]]
+
+func @extract_slices_tuple_leaks(%arg0: vector<4xf32>) -> tuple<vector<2xf32>, vector<2xf32>> {
+  %0 = vector.extract_slices %arg0, [2], [1] : vector<4xf32> into tuple<vector<2xf32>, vector<2xf32>>
+  return %0 : tuple<vector<2xf32>, vector<2xf32>>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -18,6 +18,7 @@
 using namespace mlir::vector;
 
 namespace {
+
 #include "TestVectorTransformPatterns.h.inc"
 
 struct TestVectorToVectorConversion
@@ -31,8 +32,22 @@
     applyPatternsGreedily(getFunction(), patterns);
   }
 };
+
+struct TestVectorSlicesConversion
+    : public FunctionPass<TestVectorSlicesConversion> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
 static PassRegistration<TestVectorToVectorConversion> pass(
     "test-vector-to-vector-conversion",
     "Test conversion patterns between ops in the vector dialect");
+
+static PassRegistration<TestVectorSlicesConversion> slices_pass(
+    "test-vector-slices-conversion",
+    "Test conversion patterns that lower slices ops in the vector dialect");
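
To make the index arithmetic shared by ExtractSlicesOpLowering and InsertSlicesOpLowering concrete, the standalone sketch below replays it outside MLIR for the vector<3x3xf32>, sizes = [2, 2] case exercised by the tests. It is a minimal plain-C++ illustration: the local computeStrides/delinearize helpers are hypothetical stand-ins that mirror how the patch appears to use MLIR's helpers of the same names (a row-major basis over per-dimension slice counts, then digit-wise de-linearization), not MLIR's actual implementations.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for computeStrides(): row-major basis over the
    // per-dimension slice counts, so basis[r] = product of counts[r+1..].
    static std::vector<int64_t> computeStrides(const std::vector<int64_t> &counts) {
      std::vector<int64_t> basis(counts.size(), 1);
      for (int64_t r = static_cast<int64_t>(counts.size()) - 2; r >= 0; --r)
        basis[r] = basis[r + 1] * counts[r + 1];
      return basis;
    }

    // Hypothetical stand-in for delinearize(): split the linear tuple index
    // 'i' into one offset per dimension w.r.t. 'basis'.
    static std::vector<int64_t> delinearize(int64_t i,
                                            const std::vector<int64_t> &basis) {
      std::vector<int64_t> offsets(basis.size());
      for (size_t r = 0; r < basis.size(); ++r) {
        offsets[r] = i / basis[r];
        i %= basis[r];
      }
      return offsets;
    }

    int main() {
      // Mirrors the tests: source vector<3x3xf32>, slice sizes [2, 2].
      std::vector<int64_t> shape = {3, 3}, sizes = {2, 2};
      std::vector<int64_t> counts(shape.size());
      for (size_t r = 0; r < shape.size(); ++r)
        counts[r] = (shape[r] + sizes[r] - 1) / sizes[r]; // ceilDiv -> [2, 2]
      std::vector<int64_t> basis = computeStrides(counts); // [2, 1]
      int64_t tupleSize = 1;
      for (int64_t c : counts)
        tupleSize *= c; // 4 slices in the resulting tuple
      for (int64_t i = 0; i < tupleSize; ++i) {
        std::vector<int64_t> vectorOffsets = delinearize(i, basis);
        std::printf("slice %lld:", static_cast<long long>(i));
        for (size_t r = 0; r < shape.size(); ++r) {
          int64_t off = vectorOffsets[r] * sizes[r];       // element-space offset
          int64_t sz = std::min(sizes[r], shape[r] - off); // clamp boundary slice
          std::printf(" dim%zu(offset=%lld,size=%lld)", r,
                      static_cast<long long>(off), static_cast<long long>(sz));
        }
        std::printf("\n");
      }
      return 0;
    }

Running this prints the four (offset, size) pairs [0, 0]/[2, 2], [0, 2]/[2, 1], [2, 0]/[1, 2], and [2, 2]/[1, 1], matching the strided_slice attributes in the CHECK lines of @extract_insert_slices, including the clamped boundary slices.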
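One further note on the tuple-DCE claim in the new header comment: the tests exercise both sides of it. In @extract_insert_slices every tuple value produced by vector.extract_slices is consumed by vector.insert_slices, and the CHECK lines show that no vector.tuple or vector.tuple_get operations survive the rewrite, only strided_slice / insert_strided_slice chains. In @extract_slices_tuple_leaks the tuple escapes through the function return, so a vector.tuple op necessarily remains, exactly as the comment predicts for "leaking" tuple values.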