diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -56,6 +56,7 @@ /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; + bool vectorUseFlatTranspose = false; }; /// Collect a set of transformation patterns that are related to contracting diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1206,6 +1206,7 @@ } }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasFolder = 1; } def Vector_TypeCastOp : diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1667,6 +1667,19 @@ return success(); } +OpFoldResult ShapeCastOp::fold(ArrayRef operands) { + // Nop shape cast. + if (source().getType() == result().getType()) + return source(); + + // Canceling shape casts. 
+ if (auto otherOp = source().getDefiningOp()) + if (result().getType() == otherOp.source().getType()) + return otherOp.source(); + + return {}; +} + //===----------------------------------------------------------------------===// // TypeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1186,6 +1186,11 @@ public: using OpRewritePattern::OpRewritePattern; + TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -1197,6 +1202,21 @@ for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); + // Handle a true 2-D matrix transpose differently when requested. + if (vectorTransformsOptions.vectorUseFlatTranspose && + resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + rewriter.create(loc, flattenedType, op.vector()); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = rewriter.create( + loc, flattenedType, matrix, rows, columns); + rewriter.replaceOpWithNewOp(op, resType, trans); + return success(); + } + // Generate fully unrolled extract/insert ops. Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); @@ -1230,6 +1250,9 @@ } return result; } + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; }; /// Progressive lowering of OuterProductOp. 
@@ -1829,9 +1852,9 @@ ConstantMaskOpLowering, OuterProductOpLowering, ShapeCastOp2DDownCastRewritePattern, - ShapeCastOp2DUpCastRewritePattern, - TransposeOpLowering>(context); - patterns.insert(context); + patterns.insert(parameters, context); // clang-format on diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -319,6 +319,26 @@ return %0 : vector<3x2xf32> } + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + // Shape up and downcasts for 2-D vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure + +// Tests for lowering 2-D vector.transpose into vector.flat_transpose. +// +// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and +// ShapeCastOp2DUpCastRewritePattern too early in +// the greedy rewriting patterns misses opportunities +// to fold shape casts! 
+ +// No shape cast folding expected. +// +// CHECK-LABEL: func @transpose44_44( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// Folds preceding shape cast as expected, +// no following shape cast folding expected. +// +// CHECK-LABEL: func @transpose16_44( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// No preceding shape cast folding expected, +// but FAILS to fold following cast. 
+// +// CHECK-LABEL: func @transpose44_16( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Folds preceding shape cast as expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose16_16( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// +func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32> + return %2 : vector<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -51,6 +51,10 @@ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option flatTranspose{ + *this, "vector-flat-transpose", + llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_tranpose"), + llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -70,7 +74,7 @@ VectorContractLowering lowering = VectorContractLowering::FMA; if (lowerToLLVMMatrixIntrinsics) lowering = VectorContractLowering::Matmul; - VectorTransformsOptions options{lowering}; + VectorTransformsOptions 
options{lowering, flatTranspose}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }