diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -53,9 +53,19 @@ /// Lower to `vector.outerproduct`. OuterProduct = 2, }; +/// Enum to control the lowering of `vector.transpose` operations. +enum class VectorTransposeLowering { + /// Lower transpose into element-wise extracts and inserts. + EltWise = 0, + /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix + /// intrinsics. + Flat = 1, +}; /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt) { vectorContractLowering = opt; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1206,6 +1206,7 @@ } }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasFolder = 1; } def Vector_TypeCastOp : diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1667,6 +1667,19 @@ return success(); } +OpFoldResult ShapeCastOp::fold(ArrayRef operands) { + // Nop shape cast. + if (source().getType() == result().getType()) + return source(); + + // Canceling shape casts. 
+ if (auto otherOp = source().getDefiningOp()) + if (result().getType() == otherOp.source().getType()) + return otherOp.source(); + + return {}; +} + //===----------------------------------------------------------------------===// // TypeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1186,6 +1186,11 @@ public: using OpRewritePattern::OpRewritePattern; + TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -1197,6 +1202,22 @@ for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); + // Handle a true 2-D matrix transpose differently when requested. + if (vectorTransformsOptions.vectorTransposeLowering == + vector::VectorTransposeLowering::Flat && + resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + rewriter.create(loc, flattenedType, op.vector()); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = rewriter.create( + loc, flattenedType, matrix, rows, columns); + rewriter.replaceOpWithNewOp(op, resType, trans); + return success(); + } + // Generate fully unrolled extract/insert ops. Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); @@ -1230,6 +1251,9 @@ } return result; } + + /// Options to control the vector patterns. 
+ vector::VectorTransformsOptions vectorTransformsOptions; }; /// Progressive lowering of OuterProductOp. @@ -1829,9 +1853,9 @@ ConstantMaskOpLowering, OuterProductOpLowering, ShapeCastOp2DDownCastRewritePattern, - ShapeCastOp2DUpCastRewritePattern, - TransposeOpLowering>(context); - patterns.insert(context); + patterns.insert(parameters, context); // clang-format on diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -319,6 +319,26 @@ return %0 : vector<3x2xf32> } + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + // Shape up and downcasts for 2-D vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure + +// Tests for lowering 2-D vector.transpose into vector.flat_transpose. 
+// +// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and +// ShapeCastOp2DUpCastRewritePattern too early in +// the greedy rewriting patterns misses opportunities +// to fold shape casts! + +// No shape cast folding expected. +// +// CHECK-LABEL: func @transpose44_44( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// Folds preceding shape cast as expected, +// no following shape cast folding expected. +// +// CHECK-LABEL: func @transpose16_44( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// No preceding shape cast folding expected, +// but FAILS to fold following cast. 
+// +// CHECK-LABEL: func @transpose44_16( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Folds preceding shape cast as expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose16_16( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// +func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32> + return %2 : vector<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -47,10 +47,14 @@ TestVectorContractionConversion(const TestVectorContractionConversion &pass) { } - Option lowerToLLVMMatrixIntrinsics{ + Option lowerToFlatMatrix{ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option lowerToFlatTranspose{ + *this, "vector-flat-transpose", + llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), + llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -67,10 +71,14 @@ return; } - VectorContractLowering lowering = 
VectorContractLowering::FMA; - if (lowerToLLVMMatrixIntrinsics) - lowering = VectorContractLowering::Matmul; - VectorTransformsOptions options{lowering}; + VectorContractLowering contractLowering = VectorContractLowering::FMA; + if (lowerToFlatMatrix) + contractLowering = VectorContractLowering::Matmul; + VectorTransposeLowering transposeLowering = + VectorTransposeLowering::EltWise; + if (lowerToFlatTranspose) + transposeLowering = VectorTransposeLowering::Flat; + VectorTransformsOptions options{contractLowering, transposeLowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }