diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -53,7 +53,9 @@ /// Collect a set of transformation patterns that are related to contracting /// or expanding vector operations: /// ContractionOpLowering, -/// ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern +/// ShapeCastOp2DDownCastRewritePattern, +/// ShapeCastOp2DUpCastRewritePattern, +/// TransposeOpLowering, /// OuterproductOpLowering /// These transformation express higher level vector ops in terms of more /// elementary extraction, insertion, reduction, product, and broadcast ops. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -88,7 +88,7 @@ iterator in the iterator type list, to each dimension of an N-D vector. Examples: - + ```mlir // Simple dot product (K = 0). #contraction_accesses = [ affine_map<(i) -> (i)>, @@ -139,6 +139,7 @@ %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> + ``` }]; let builders = [OpBuilder< "Builder *builder, OperationState &result, Value lhs, Value rhs, " @@ -203,7 +204,7 @@ http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics Examples: - ``` + ```mlir %1 = vector.reduction "add", %0 : vector<16xf32> into f32 %3 = vector.reduction "xor", %2 : vector<4xi32> into i32 @@ -247,7 +248,7 @@ shaped vector with the same element type is always legal.
Examples: - ``` + ```mlir %0 = constant 0.0 : f32 %1 = vector.broadcast %0 : f32 to vector<16xf32> %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32> @@ -290,7 +291,7 @@ above, all mask values are in the range [0,s_1+t_1) Examples: - ``` + ```mlir %0 = vector.shuffle %a, %b[0, 3] : vector<2xf32>, vector<2xf32> ; yields vector<2xf32> %1 = vector.shuffle %c, %b[0, 1, 2] @@ -332,7 +333,7 @@ https://llvm.org/docs/LangRef.html#extractelement-instruction Example: - ``` + ```mlir %c = constant 15 : i32 %1 = vector.extractelement %0[%c : i32]: vector<16xf32> ``` @@ -360,7 +361,7 @@ the proper position. Degenerates to an element type in the 0-D case. Examples: - ``` + ```mlir %1 = vector.extract %0[3]: vector<4x8x16xf32> %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32> ``` @@ -396,7 +397,7 @@ Currently, only unit strides are supported. Examples: - ``` + ```mlir %0 = vector.transfer_read ...: vector<4x2xf32> %1 = vector.extract_slices %0, [2, 2], [1, 1] @@ -448,8 +449,7 @@ to the `llvm.fma.*` intrinsic. Example: - - ``` + ```mlir %3 = vector.fma %0, %1, %2: vector<8x16xf32> ``` }]; @@ -483,7 +483,7 @@ https://llvm.org/docs/LangRef.html#insertelement-instruction Example: - ``` + ```mlir %c = constant 15 : i32 %f = constant 0.0f : f32 %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> @@ -516,7 +516,7 @@ position. Degenerates to a scalar source type when n = 0. Examples: - ``` + ```mlir %2 = vector.insert %0, %1[3]: vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[3, 3, 3]: @@ -559,7 +559,7 @@ Currently, only unit strides are supported. Examples: - ``` + ```mlir %0 = vector.extract_slices %0, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>> @@ -617,7 +617,7 @@ the proper location as specified by the offsets. Examples: - ``` + ```mlir %2 = vector.insert_strided_slice %0, %1 {offsets = [0, 0, 2], strides = [1, 1]}: vector<2x4xf32> into vector<16x4x8xf32> @@ -659,8 +659,7 @@ lower to actual `fma` instructions on x86.
Examples: - - ``` + ```mlir %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> return %2: vector<4x8xf32> @@ -709,8 +708,8 @@ In the examples below, valid data elements are represented by an alphabetic character, and undefined data elements are represented by '-'. - Example - + Example: + ```mlir vector<1x8xf32> with valid data shape [6], fixed vector sizes [8] input: [a, b, c, d, e, f] @@ -719,8 +718,9 @@ vector layout: [a, b, c, d, e, f, -, -] - Example - + ``` + Example: + ```mlir vector<2x8xf32> with valid data shape [10], fixed vector sizes [8] input: [a, b, c, d, e, f, g, h, i, j] @@ -729,9 +729,9 @@ vector layout: [[a, b, c, d, e, f, g, h], [i, j, -, -, -, -, -, -]] - - Example - + ``` + Example: + ```mlir vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes [2, 3] @@ -750,9 +750,9 @@ [-, -, -]] [[n, o, -], [-, -, -]]]] - - Example - + ``` + Example: + ```mlir %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4] : vector<3x2x4xf32> to vector<2x3x4xf32> @@ -776,6 +776,7 @@ [[j, k, l, m], [n, o, p, q], [r, -, -, -]]] + ``` }]; let extraClassDeclaration = [{ @@ -828,7 +829,7 @@ `offsets` and ending at `offsets + sizes`. Examples: - ``` + ```mlir %1 = vector.strided_slice %0 {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}: vector<4x8x16xf32> to vector<2x4x16xf32> @@ -947,13 +948,12 @@ implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`. 
Syntax - ``` + ```mlir operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list `{` attribute-entry `} :` memref-type `,` vector-type ``` Examples: - ```mlir // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> // and pad with %f0 to handle the boundary case: @@ -1028,7 +1028,7 @@ Syntax: - ``` + ```mlir operation ::= `vector.transfer_write` ssa-use-list `{` attribute-entry `} : ` vector-type ', ' memref-type ' ``` @@ -1139,7 +1139,7 @@ Syntax: - ``` + ```mlir operation ::= `vector.type_cast` ssa-use : memref-type to memref-type ``` @@ -1183,8 +1183,10 @@ define a hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). - Example: create a constant vector mask of size 4x3xi1 with elements in range - 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + Example: + create a constant vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + ``` %1 = vector.constant_mask [3, 2] : vector<4x3xi1> @@ -1196,6 +1198,7 @@ rows 1 | 1 1 0 2 | 1 1 0 3 | 0 0 0 + ``` }]; let extraClassDeclaration = [{ @@ -1217,8 +1220,10 @@ hyper-rectangular region within which elements values are set to 1 (otherwise element values are set to 0). - Example: create a vector mask of size 4x3xi1 where elements in range - 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + Example: + create a vector mask of size 4x3xi1 where elements in range + 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0). + ``` %1 = vector.create_mask %c3, %c2 : vector<4x3xi1> @@ -1230,6 +1235,7 @@ rows 1 | 1 1 0 2 | 1 1 0 3 | 0 0 0 + ``` }]; let hasCanonicalizer = 1; @@ -1248,9 +1254,8 @@ transformation and should be removed before lowering to lower-level dialects. - Examples: - ``` + ```mlir %0 = vector.transfer_read ... : vector<2x2xf32> %1 = vector.transfer_read ... : vector<2x1xf32> %2 = vector.transfer_read ...
: vector<2x2xf32> @@ -1280,20 +1285,21 @@ Takes a n-D vector and returns the transposed n-D vector defined by the permutation of ranks in the n-sized integer array attribute. In the operation - - %1 = vector.tranpose %0, [i_1, .., i_n] - : vector<d_1 x .. x d_n x f32> - to vector<d_trans[0] x .. x d_trans[n-1] x f32> - + ```mlir + %1 = vector.transpose %0, [i_1, .., i_n] + : vector<d_1 x .. x d_n x f32> + to vector<d_trans[0] x .. x d_trans[n-1] x f32> + ``` the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1]. Example: - - %1 = vector.tranpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> + ```mlir + %1 = vector.transpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> [ [a, b, c], [ [a, d], [d, e, f] ] -> [b, e], [c, f] ] + ``` }]; let extraClassDeclaration = [{ VectorType getVectorType() { @@ -1321,7 +1327,7 @@ dialects. Examples: - ``` + ```mlir %4 = vector.tuple %0, %1, %2, %3 : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>> @@ -1351,7 +1357,7 @@ format (for testing and debugging). No return value. Examples: - ``` + ```mlir %0 = constant 0.0 : f32 %1 = vector.broadcast %0 : f32 to vector<4xf32> vector.print %1 : vector<4xf32> @@ -1414,7 +1420,7 @@ Example: - ``` + ```mlir %C = vector.matrix_multiply %A, %B { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64> diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -866,6 +866,67 @@ -/// Progressive lowering of OuterProductOp. -/// One: +/// Progressive lowering of TransposeOp. +/// One: +/// %x = vector.transpose %y, [1, 0] +/// is replaced by: +/// %z = constant dense<0.000000e+00> +/// %0 = vector.extract %y[0, 0] +/// %1 = vector.insert %0, %z [0, 0] +/// .. +/// %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType resType = op.getResultType();
+
+    // Set up convenience transposition table.
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+    // Generate fully unrolled extract/insert ops into an all-zero vector.
+    Value result = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
+    SmallVector<int64_t, 4> lhs(transp.size(), 0);
+    SmallVector<int64_t, 4> rhs(transp.size(), 0);
+    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
+                                         op.vector(), result, rewriter));
+    return success();
+  }
+
+private:
+  // Recursively enumerates every index of the result vector, filling `lhs`
+  // (index into the result) and `rhs` (index into the input) one rank at a
+  // time so that rhs[transp[pos]] == lhs[pos]. Once all ranks are exhausted,
+  // emits one element extract/insert pair. Returns the updated result.
+  Value expandIndices(Location loc, VectorType resType, int64_t pos,
+                      ArrayRef<int64_t> transp,
+                      MutableArrayRef<int64_t> lhs,
+                      MutableArrayRef<int64_t> rhs, Value input, Value result,
+                      PatternRewriter &rewriter) const {
+    if (pos >= resType.getRank()) {
+      auto ridx = rewriter.getI64ArrayAttr(rhs);
+      auto lidx = rewriter.getI64ArrayAttr(lhs);
+      Type eltType = resType.getElementType();
+      Value elt = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
+      return rewriter.create<vector::InsertOp>(loc, resType, elt, result, lidx);
+    }
+    for (int64_t d = 0, size = resType.getDimSize(pos); d < size; ++d) {
+      lhs[pos] = d;
+      rhs[transp[pos]] = d;
+      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
+                             result, rewriter);
+    }
+    return result;
+  }
+};
+
+/// Progressive lowering of OuterProductOp.
+/// One:
 ///   %x = vector.outerproduct %lhs, %rhs, %acc
 /// is replaced by:
 ///   %z = zero-result
@@ -1353,7 +1414,7 @@
     OwningRewritePatternList &patterns, MLIRContext *context,
     VectorTransformsOptions parameters) {
   patterns.insert<ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
-      context);
+                  ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
+                  OuterProductOpLowering>(context);
   patterns.insert<ContractionOpLowering>(parameters, context);
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -296,6 +296,52 @@
   return %0: vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @transpose23
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
+// CHECK: return %[[T11]] : vector<3x2xf32>
+
+func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+// Permutation [2, 0, 1] is not its own inverse: result element [i, j, k]
+// must be read from input element [j, k, i].
+// CHECK-LABEL: func @transpose201
+// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x1x2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x2x3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0, 0] : f32 into vector<3x1x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x2x3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<3x1x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x2x3xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0, 0] : f32 into vector<3x1x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x2x3xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 1] : f32 into vector<3x1x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 0, 2] : vector<1x2x3xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0, 0] : f32 into vector<3x1x2xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 1, 2] : vector<1x2x3xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 0, 1] : f32 into vector<3x1x2xf32>
+// CHECK: return %[[T11]] : vector<3x1x2xf32>
+
+func @transpose201(%arg0: vector<1x2x3xf32>) -> vector<3x1x2xf32> {
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<1x2x3xf32> to vector<3x1x2xf32>
+  return %0 : vector<3x1x2xf32>
+}
+
 // Shape up and downcasts for 2-D vectors, for supporting conversion to
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts