diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -53,7 +53,9 @@ /// Collect a set of transformation patterns that are related to contracting /// or expanding vector operations: /// ContractionOpLowering, -/// ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern +/// ShapeCastOp2DDownCastRewritePattern, +/// ShapeCastOp2DUpCastRewritePattern +/// TransposeOpLowering /// OuterproductOpLowering /// These transformation express higher level vector ops in terms of more /// elementary extraction, insertion, reduction, product, and broadcast ops. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -88,7 +88,7 @@ iterator in the iterator type list, to each dimension of an N-D vector. Examples: - + ``` // Simple dot product (K = 0). #contraction_accesses = [ affine_map<(i) -> (i)>, @@ -139,6 +139,7 @@ %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> + ``` }]; let builders = [OpBuilder< "Builder *builder, OperationState &result, Value lhs, Value rhs, " @@ -448,7 +449,6 @@ to the `llvm.fma.*` intrinsic. Example: - ``` %3 = vector.fma %0, %1, %2: vector<8x16xf32> ``` @@ -659,7 +659,6 @@ lower to actual `fma` instructions on x86. Examples: - ``` %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> return %2: vector<4x8xf32> @@ -709,8 +708,8 @@ In the examples below, valid data elements are represented by an alphabetic character, and undefined data elements are represented by '-'. 
- Example - + Example: + ``` vector<1x8xf32> with valid data shape [6], fixed vector sizes [8] input: [a, b, c, d, e, f] @@ -719,8 +718,9 @@ vector layout: [a, b, c, d, e, f, -, -] - Example - + ``` + Example: + ``` vector<2x8xf32> with valid data shape [10], fixed vector sizes [8] input: [a, b, c, d, e, f, g, h, i, j] @@ -729,9 +729,9 @@ vector layout: [[a, b, c, d, e, f, g, h], [i, j, -, -, -, -, -, -]] - - Example - + ``` + Example: + ``` vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes [2, 3] @@ -750,9 +750,9 @@ [-, -, -]] [[n, o, -], [-, -, -]]]] - - Example - + ``` + Example: + ``` %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4] : vector<3x2x4xf32> to vector<2x3x4xf32> @@ -776,6 +776,7 @@ [[j, k, l, m], [n, o, p, q], [r, -, -, -]]] + ``` }]; let extraClassDeclaration = [{ @@ -953,7 +954,6 @@ ``` Examples: - ```mlir // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> // and pad with %f0 to handle the boundary case: @@ -1183,8 +1183,10 @@ define a hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). - Example: create a constant vector mask of size 4x3xi1 with elements in range - 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + Example: + ``` + create a constant vector mask of size 4x3xi1 with elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). %1 = vector.constant_mask [3, 2] : vector<4x3xi1> @@ -1196,6 +1198,7 @@ rows 1 | 1 1 0 2 | 1 1 0 3 | 0 0 0 + ``` }]; let extraClassDeclaration = [{ @@ -1217,8 +1220,10 @@ hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). - Example: create a vector mask of size 4x3xi1 where elements in range - 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + Example: + ``` + create a vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). 
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1> @@ -1230,6 +1235,7 @@ rows 1 | 1 1 0 2 | 1 1 0 3 | 0 0 0 + ``` }]; let hasCanonicalizer = 1; @@ -1248,7 +1254,6 @@ transformation and should be removed before lowering to lower-level dialects. - Examples: ``` %0 = vector.transfer_read ... : vector<2x2xf32> @@ -1280,20 +1285,21 @@ Takes a n-D vector and returns the transposed n-D vector defined by the permutation of ranks in the n-sized integer array attribute. In the operation - - %1 = vector.tranpose %0, [i_1, .., i_n] - : vector - to vector - + ```mlir + %1 = vector.transpose %0, [i_1, .., i_n] + : vector + to vector + ``` the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1]. Example: - + ``` %1 = vector.tranpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> [ [a, b, c], [ [a, d], [d, e, f] ] -> [b, e], [c, f] ] + ``` }]; let extraClassDeclaration = [{ VectorType getVectorType() { diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -864,6 +864,67 @@ } }; +/// Progressive lowering of TransposeOp. +/// One: +/// %x = vector.transpose %y, [1, 0] +/// is replaced by: +/// %z = constant dense<0.000000e+00> +/// %0 = vector.extract %y[0, 0] +/// %1 = vector.insert %0, %z [0, 0] +/// .. +/// %x = vector.insert .., .. [.., ..] +class TransposeOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + VectorType resType = op.getResultType(); + Type eltType = resType.getElementType(); + + // Set up convenience transposition table. + SmallVector transp; + for (auto attr : op.transp()) + transp.push_back(attr.cast().getInt()); + + // Generate fully unrolled extract/insert ops.
+ Value zero = rewriter.create(loc, eltType, + rewriter.getZeroAttr(eltType)); + Value result = rewriter.create(loc, resType, zero); + SmallVector lhs(transp.size(), 0); + SmallVector rhs(transp.size(), 0); + rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs, + op.vector(), result, rewriter)); + return success(); + } + +private: + // Builds the indices arrays for the lhs and rhs. Generates the extract/insert + // operation when all ranks are exhausted. + Value expandIndices(Location loc, VectorType resType, int64_t pos, + SmallVector &transp, + SmallVector &lhs, + SmallVector &rhs, Value input, Value result, + PatternRewriter &rewriter) const { + if (pos >= resType.getRank()) { + auto ridx = rewriter.getI64ArrayAttr(rhs); + auto lidx = rewriter.getI64ArrayAttr(lhs); + Type eltType = resType.getElementType(); + Value e = rewriter.create(loc, eltType, input, ridx); + return rewriter.create(loc, resType, e, result, lidx); + } + for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) { + lhs[pos] = d; + rhs[transp[pos]] = d; + result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input, + result, rewriter); + } + return result; + } +}; + /// Progressive lowering of OuterProductOp.
/// One: /// %x = vector.outerproduct %lhs, %rhs, %acc @@ -1353,7 +1414,7 @@ OwningRewritePatternList &patterns, MLIRContext *context, VectorTransformsOptions parameters) { patterns.insert( - context); + ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering, + OuterProductOpLowering>(context); patterns.insert(parameters, context); } diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -296,6 +296,28 @@ return %0: vector<2x3xf32> } +// CHECK-LABEL: func @transpose23 +// CHECK-SAME: %[[A:.*]]: vector<2x3xf32> +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> +// CHECK: return %[[T11]] : vector<3x2xf32> + +func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} + // Shape up and downcasts for 2-D 
vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts