diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h @@ -29,6 +29,8 @@ /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix /// intrinsics. Flat = 1, + /// Lower 2-D transpose to a `vector.shuffle` of the 1-D `vector.shape_cast` of the source. + Shuffle = 2, }; /// Enum to control the lowering of `vector.multi_reduction` operations. enum class VectorMultiReductionLowering { diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -686,6 +686,12 @@ for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); + if (vectorTransformOptions.vectorTransposeLowering == + vector::VectorTransposeLowering::Shuffle && + resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) + return rewriter.notifyMatchFailure( + op, "Options specifies lowering to shuffle"); // Leave 2-D [1, 0] transposes to TransposeOp2DToShuffleLowering. + // Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering == vector::VectorTransposeLowering::Flat && @@ -740,6 +746,61 @@ vector::VectorTransformsOptions vectorTransformOptions; }; +/// Rewrite a 2-D vector.transpose with permutation [1, 0] as a sequence of: +/// vector.shape_cast 2D -> 1D (row-major flattening of the m x n source) +/// vector.shuffle (result position j * m + i takes source element i * n + j) +/// vector.shape_cast 1D -> 2D (reshape to the n x m result type) +class TransposeOp2DToShuffleLowering + : public OpRewritePattern<vector::TransposeOp> { +public: + using OpRewritePattern::OpRewritePattern; + + TransposeOp2DToShuffleLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context) + : OpRewritePattern<vector::TransposeOp>(context), + vectorTransformOptions(vectorTransformOptions) {} + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + if (vectorTransformOptions.vectorTransposeLowering != + VectorTransposeLowering::Shuffle) // Cheapest check first. + return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle"); + + VectorType srcType = op.getVectorType(); + if (srcType.getRank() != 2) + return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); + + SmallVector<int64_t, 4> transp; + for (auto attr : op.transp()) + transp.push_back(attr.cast<IntegerAttr>().getInt()); + if (transp[0] != 1 || transp[1] != 0) // Only the [1, 0] swap qualifies. + return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); + + int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); + Value casted = rewriter.create<vector::ShapeCastOp>( + loc, VectorType::get({m * n}, srcType.getElementType()), op.vector()); + SmallVector<int64_t> mask; + mask.reserve(m * n); + for (int64_t j = 0; j < n; ++j) + for (int64_t i = 0; i < m; ++i) + mask.push_back(i * n + j); + + Value shuffled = // Every mask index is < m * n, so only the first operand is read. + rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask); + rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(), + shuffled); + + return success(); + } + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; +}; + /// Progressive lowering of OuterProductOp.
/// One: /// %x = vector.outerproduct %lhs, %rhs, %acc @@ -3648,7 +3709,8 @@ void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransformsOptions options) { - patterns.add(options, patterns.getContext()); + patterns.add( + options, patterns.getContext()); } void mlir::vector::populateVectorReductionToContractPatterns( diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s + +// CHECK-LABEL: func @transpose +func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { + // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> + // Input 2x4 (row-major): [[0 1 2 3], [4 5 6 7]] + // Output 4x2: [[0 4], [1 5], [2 6], [3 7]] + // Flattened, output position k reads flattened input position mask[k], + // so the expected shuffle mask is [0, 4, 1, 5, 2, 6, 3, 7]. + // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -116,6 +116,10 @@ *this, "vector-flat-transpose", llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), llvm::cl::init(false)}; + Option lowerToShuffleTranspose{ + *this, "vector-shuffle-transpose", + llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), + llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -165,12 +169,15 @@ VectorTransposeLowering::EltWise; if (lowerToFlatTranspose) transposeLowering =
VectorTransposeLowering::Flat; + if (lowerToShuffleTranspose) // Takes precedence over vector-flat-transpose when both are set. + transposeLowering = VectorTransposeLowering::Shuffle; VectorTransformsOptions options{ contractLowering, vectorMultiReductionLowering, transposeLowering}; populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, options); populateVectorMaskOpLoweringPatterns(patterns); - populateVectorShapeCastLoweringPatterns(patterns); + if (!lowerToShuffleTranspose) // Keep the emitted vector.shape_cast ops for FileCheck (presumably these patterns would lower them away). + populateVectorShapeCastLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); }