diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1482,6 +1482,9 @@ //===----------------------------------------------------------------------===// // Ops used for supporting progressive lowering and conversion type changes. +// The Ops are typically not used directly by higher level dialects, but are +// used by intra-dialect rewriting rules to bring vector operations closer +// to the hardware ISA. //===----------------------------------------------------------------------===// /// Vector dialect matrix multiplication op that operates on flattened 1-D @@ -1510,12 +1513,20 @@ let description = [{ This is the counterpart of llvm.matrix.multiply in MLIR. It serves the purposes of more progressive lowering and localized type conversion. + Higher levels typically lower matrix multiplications into 'vector.contract' + operations. Subsequent rewriting rules progressively lower these operations + into 'vector.matrix_multiply' operations to bring the operations closer + to the hardware ISA. The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and <rhs_columns> columns, and multiplies them. The result matrix is returned embedded in the result vector. + Also see: + + http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic + Example: ```mlir @@ -1541,4 +1552,48 @@ "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; } +/// Vector dialect matrix transposition op that operates on flattened 1-D +/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR. +/// This may seem redundant with vector.transpose but it serves the purposes of +/// more progressive lowering and localized type conversion on the path: +/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
+def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
+  PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(
+      // TODO(ntv, fhahn, ajcbik): tighten vector element types that make sense.
+      ins VectorOfRankAndType<[1],
+                   [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$matrix,
+          I32Attr:$rows, I32Attr:$columns)>,
+    Results<(
+      outs VectorOfRankAndType<[1],
+                   [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> {
+  let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
+  let description = [{
+    This is the counterpart of llvm.matrix.transpose in MLIR. It serves
+    the purposes of more progressive lowering and localized type conversion.
+    Higher levels typically lower matrix transpositions into 'vector.transpose'
+    operations. Subsequent rewriting rules progressively lower these operations
+    into 'vector.flat_transpose' operations to bring the operations closer
+    to the hardware ISA.
+
+    The ‘vector.flat_transpose’ op treats the 1-D input `matrix` as
+    a 2-D matrix with <rows> rows and <columns> columns, and returns the
+    transposed matrix in flattened form in 'res'.
+
+    Also see:
+
+    http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
+
+    Example:
+
+    ```mlir
+    %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
+       : vector<16xf32> -> vector<16xf32>
+    ```
+  }];
+  let verifier = ?;
+  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -148,6 +148,27 @@
   }
 };
 
+/// Conversion pattern for a vector.flat_transpose.
+/// This is lowered directly to the proper llvm.intr.matrix.transpose.
+class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorFlatTransposeOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), + context, typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto transOp = cast(op); + auto adaptor = vector::FlatTransposeOpOperandAdaptor(operands); + rewriter.replaceOpWithNewOp( + transOp, typeConverter.convertType(transOp.res().getType()), + adaptor.matrix(), transOp.rows(), transOp.columns()); + return success(); + } +}; + class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, @@ -1157,6 +1178,7 @@ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.insert(ctx, converter); + patterns.insert(ctx, converter); } namespace { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -952,3 +952,15 @@ // CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64 // CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>"> // CHECK: llvm.return %9 : !llvm<"<8 x i1>"> + +// CHECK-LABEL: func @flat_transpose +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>"> +// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]] +// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : +// CHECK-SAME: !llvm<"<16 x float>"> into !llvm<"<16 x float>"> +// CHECK: llvm.return %[[T]] : !llvm<"<16 x float>"> +func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } + : vector<16xf32> -> 
vector<16xf32> + return %0 : vector<16xf32> +} diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1145,6 +1145,13 @@ // ----- +func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) { + // expected-error@+1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}} + %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64> +} + +// ----- + func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) { // expected-error@+1 {{expects operand to be a memref with no layout}} %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -140,7 +140,7 @@ indexing_maps = #contraction_to_scalar_accesses, iterator_types = ["reduction"] } -// CHECK-LABEL: contraction_to_scalar +// CHECK-LABEL: @contraction_to_scalar func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32 %f0 = constant 0.0: f32 @@ -172,7 +172,7 @@ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] } -// CHECK-LABEL: contraction +// CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, %arg4 : index) { @@ -196,7 +196,7 @@ return } -// CHECK-LABEL: create_vector_mask +// CHECK-LABEL: @create_vector_mask func @create_vector_mask() { // CHECK: %[[C2:.*]] = constant 2 : index %c2 = constant 2 : index @@ -208,14 +208,14 @@ return } -// CHECK-LABEL: constant_vector_mask +// CHECK-LABEL: 
@constant_vector_mask func @constant_vector_mask() { // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> %0 = vector.constant_mask [3, 2] : vector<4x3xi1> return } -// CHECK-LABEL: extract_slices +// CHECK-LABEL: @extract_slices func @extract_slices(%arg0 : vector<4x2xf32>) -> (tuple, vector<2x2xf32>>) { // CHECK: vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple, vector<2x2xf32>> @@ -227,7 +227,7 @@ return %3 : tuple, vector<2x2xf32>> } -// CHECK-LABEL: insert_slices +// CHECK-LABEL: @insert_slices func @insert_slices(%arg0 : tuple, vector<2x2xf32>>) -> (vector<4x2xf32>) { // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple, vector<2x2xf32>> into vector<4x2xf32> @@ -243,7 +243,7 @@ return } -// CHECK-LABEL: reshape +// CHECK-LABEL: @reshape func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) { // CHECK: %[[C2:.*]] = constant 2 : index %c2 = constant 2 : index @@ -260,7 +260,7 @@ return %1 : vector<2x3x4xf32> } -// CHECK-LABEL: shape_cast +// CHECK-LABEL: @shape_cast func @shape_cast(%arg0 : vector<5x1x3x2xf32>, %arg1 : tuple, vector<3x4x2xf32>>) -> (vector<15x2xf32>, tuple, vector<12x2xf32>>) { @@ -284,7 +284,7 @@ return } -// CHECK-LABEL: reduce_fp +// CHECK-LABEL: @reduce_fp func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 { // CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32 vector.reduction "add", %arg0 : vector<16xf32> into f32 @@ -302,7 +302,7 @@ return %0 : f32 } -// CHECK-LABEL: reduce_int +// CHECK-LABEL: @reduce_int func @reduce_int(%arg0: vector<16xi32>) -> i32 { // CHECK: vector.reduction "add", %{{.*}} : vector<16xi32> into i32 vector.reduction "add", %arg0 : vector<16xi32> into i32 @@ -322,14 +322,34 @@ return %0 : i32 } -// CHECK-LABEL: transpose_fp +// CHECK-LABEL: @transpose_fp func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> { + // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32> %0 = vector.transpose %arg0, [1, 0] : 
vector<3x7xf32> to vector<7x3xf32> + // CHECK: return %[[X]] : vector<7x3xf32> return %0 : vector<7x3xf32> } -// CHECK-LABEL: transpose_int +// CHECK-LABEL: @transpose_int func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> { + // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32> %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32> + // CHECK: return %[[X]] : vector<2x11x7x3xi32> return %0 : vector<2x11x7x3xi32> } + +// CHECK-LABEL: @flat_transpose_fp +func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> { + // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> + %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf32> + // CHECK: return %[[X]] : vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: @flat_transpose_int +func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> { + // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32> + %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 } : vector<16xi32> -> vector<16xi32> + // CHECK: return %[[X]] : vector<16xi32> + return %0 : vector<16xi32> +}