diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -15,6 +15,11 @@ class ModuleOp; template class OpPassBase; +/// Collect a set of patterns to convert vector.matrix_multiply ops to the LLVM +/// matrix intrinsics (llvm.intr.matrix.multiply). +void populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Collect a set of patterns to convert from the Vector dialect to LLVM. void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -797,12 +797,12 @@ : LLVM_OneResultOp<"intr.matrix.multiply">, Arguments<( ins LLVM_Type:$lhs, LLVM_Type:$rhs, - I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> { + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> { string llvmBuilder = [{ llvm::MatrixBuilder mb(builder); $res = mb.CreateMatrixMultiply( $lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(), - $rhs_rows.getZExtValue()); + $rhs_columns.getZExtValue()); }]; let assemblyFormat = "$lhs `,` $rhs attr-dict " "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -17,11 +17,25 @@ #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include 
"mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" namespace mlir { + +/// Returns true if `indexingMaps` is exactly the matmul access pattern +/// (m, n, k) -> [(m, k), (k, n), (m, n)]. Only the maps are checked. +inline bool isMatmul(ArrayAttr indexingMaps) { + AffineExpr m, n, k; + bindDims(indexingMaps.getContext(), m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext()); + return indexingMaps == maps; +} + /// Attribute name for the AffineArrayAttr which encodes the relationship /// between a structured op iterators' and its operands. constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; } diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -1091,7 +1091,12 @@ described above is applied to each source/result tuple element pair. It is currently assumed that this operation does not require moving data, - and that it will be canonicalized away before lowering vector operations. + and that it will be folded away before lowering vector operations. + + There is an exception to the folding expectation when targeting + llvm.intr.matrix operations. We need a type conversion back and forth from a + 2-D MLIR vector to a 1-D flattened LLVM vector, so lowering shape_cast to LLVM + is supported in that particular case, for now.
Examples: @@ -1108,6 +1113,14 @@ tuple, vector<9x2xf32>> ``` }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return source().getType().cast(); + } + VectorType getResultVectorType() { + return getResult().getType().cast(); + } + }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; } @@ -1323,4 +1336,71 @@ let assemblyFormat = "$source attr-dict `:` type($source)"; } +//===----------------------------------------------------------------------===// +// Ops used for supporting progressive lowering and conversion type changes. +//===----------------------------------------------------------------------===// + +/// Vector dialect matrix multiplication op that operates on flattened 1-D +/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR. +/// This may seem redundant with vector.contract but it serves the purposes of +/// more progressive lowering and localized type conversion on the path: +/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`. +def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect, + PredOpTrait<"lhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"rhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<( + // TODO(ntv, fhahn): tighten vector element types that make sense. + ins VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs, + VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs, + I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, + Results<( + outs VectorOfRankAndType<[1], + [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> +{ + let summary = "Vector matrix multiplication op that operates on flattened 1-D" + " MLIR vectors"; + let description = [{ + This is the counterpart of llvm.matrix.multiply in MLIR.
It serves the + purposes of more progressive lowering and localized type conversion. + + The `vector.matrix_multiply` op treats `lhs` as a matrix with `lhs_rows` + rows and `lhs_columns` columns, `rhs` as a matrix with `lhs_columns` rows + and `rhs_columns` columns, and multiplies them. The result matrix is + returned embedded in the result vector. + + One big difference is that MLIR vector ops assume row major layout but for + LLVM matrix operations: "Currently column-major layout is assumed." as per + the LLVM LangRef. + + Example: + + ``` + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xf64>, vector<48xf64>) -> vector<12xf64> + ``` + }]; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, " + "unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns", + [{ + result.addOperands({lhs, rhs}); + result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows)); + result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns)); + result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns)); + // (lhsRows x lhsColumns) * (lhsColumns x rhsColumns) produces a + // lhsRows x rhsColumns matrix, flattened into a 1-D result vector. + result.addTypes(VectorType::get(lhsRows * rhsColumns, + lhs.getType().cast().getElementType())); + }]>, + ]; + let verifier = ?; + let assemblyFormat = "$lhs `,` $rhs attr-dict " + "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; +} + #endif // VECTOR_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -500,7 +500,6 @@ IsVectorOfLengthPred, " of length " # StrJoinInt.result>; - // Any vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` // list @@ -511,6 +510,28 @@ VectorOf.description # VectorOfLength.description>; +// Whether the rank of a vector is from the given `allowedRanks` list +class IsVectorOfRankPred allowedRanks> : + And<[IsVectorTypePred, + Or().getRank() + == }] + # allowedlength>)>]>; + +// Any vector
where the rank is from the given `allowedRanks` list +class VectorOfRank allowedRanks> : Type< + IsVectorOfRankPred, + " of ranks " # StrJoinInt.result>; + +// Any vector where the rank is from the given `allowedRanks` list and the type +// is from the given `allowedTypes` list +class VectorOfRankAndType allowedRanks, + list allowedTypes> : Type< + And<[VectorOf.predicate, + VectorOfRank.predicate]>, + VectorOf.description # + VectorOfRank.description>; + def AnyVector : VectorOf<[AnyType]>; // Tensor types. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -275,6 +275,36 @@ } }; +/// Conversion pattern for a vector.matrix_multiply. +/// This is lowered directly to the proper llvm.intr.matrix.multiply. +class VectorMatmulOpConversion : public ConvertToLLVMPattern { +public: + explicit VectorMatmulOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto matmulOp = cast(op); + auto adaptor = vector::MatmulOpOperandAdaptor(operands); + // MLIR assumes row-major layout but LLVM matrix intrinsics assume + // column-major layout. Reinterpreted as column-major, the row-major + // operands A (m x k) and B (k x n) are A^T (k x m) and B^T (n x k), so + // compute C^T = B^T * A^T: swap the operands and pass shapes n, k, m.
+ auto lhs = adaptor.rhs(); + auto lhsRows = matmulOp.rhs_columns(); + auto lhsColumns = matmulOp.lhs_columns(); + auto rhs = adaptor.lhs(); + auto rhsColumns = matmulOp.lhs_rows(); + rewriter.replaceOpWithNewOp( + op, typeConverter.convertType(matmulOp.res().getType()), lhs, + rhs, lhsRows, lhsColumns, rhsColumns); + return matchSuccess(); + } +}; + class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, @@ -1145,8 +1175,15 @@ struct LowerVectorToLLVMPass : public ModulePass { void runOnModule() override; }; + } // namespace +void mlir::populateVectorToLLVMMatrixConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + patterns.insert(ctx, converter); +} + void LowerVectorToLLVMPass::runOnModule() { // Perform progressive lowering of operations on "slices" and // all contraction operations. Also applies folding and DCE. @@ -1161,6 +1198,7 @@ LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; populateVectorToLLVMConversionPatterns(converter, patterns); + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" @@ -145,16 +146,8 @@ // TODO(ntv) should be Tablegen'd from a single source that generates the
op + // itself. static bool isMatmul(linalg::GenericOp genericOp) { - auto *ctx = genericOp.getContext(); - auto m = getAffineDimExpr(0, ctx); - auto n = getAffineDimExpr(1, ctx); - auto k = getAffineDimExpr(2, ctx); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx); return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && - genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); + isMatmul(genericOp.indexing_maps()) && hasMultiplyAddBody(genericOp); } // TODO(ntv, ataei): This is in fact much more general than just vectorization @@ -172,7 +165,7 @@ return success(); auto genericOp = dyn_cast(op); - if (!genericOp || !isMatmul(genericOp)) + if (!genericOp || !::isMatmul(genericOp)) return failure(); // TODO(ntv): non-identity layout. diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -42,6 +42,11 @@ using llvm::dbgs; using mlir::functional::zipMap; +static llvm::cl::opt lowerToLLVMMatrixIntrinsics( + "vector-lower-matrix-intrinsics", + llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), + llvm::cl::init(false)); + /// Given a shape with sizes greater than 0 along all dimensions, /// returns the distance, in number of elements, between a slice in a dimension /// and the next slice in the same dimension. @@ -888,6 +893,40 @@ if (llvm::size(op.masks()) != 0) return matchFailure(); + // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in + // a new pattern.
+ if (lowerToLLVMMatrixIntrinsics && isMatmul(op.indexing_maps())) { + // Check the element type before creating any IR: returning failure + // after the rewriter has already created ops would leave them behind. + Type elementType = op.getLhsType().getElementType(); + if (!elementType.isIntOrFloat()) + return matchFailure(); + VectorType lhsType = op.lhs().getType().cast(); + VectorType rhsType = op.rhs().getType().cast(); + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create( + op.getLoc(), flattenedLHSType, op.lhs()); + auto rhs = rewriter.create( + op.getLoc(), flattenedRHSType, op.rhs()); + + unsigned lhsRows = op.getLhsType().getShape()[0]; + unsigned lhsColumns = op.getLhsType().getShape()[1]; + unsigned rhsColumns = op.getRhsType().getShape()[1]; + Value mul = rewriter.create( + op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); + mul = rewriter.create(op.getLoc(), + op.acc().getType(), mul); + // Integers of any signedness accumulate with the integer add. + if (elementType.isa<IntegerType>()) + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + return matchSuccess(); + } + // Find first batch dimension in LHS/RHS, and lower when found. std::vector> batchDimMap = op.getBatchDimMap(); if (!batchDimMap.empty()) { @@ -1171,6 +1210,75 @@ } }; +/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D +/// vectors progressively on the way to target llvm.matrix intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D +class ShapeCastOp2DDownCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = sourceVectorType.getElementType(); + Value zero = rewriter.create(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create(loc, resultVectorType, zero); + unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; + for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create(loc, op.source(), i); + desc = rewriter.create( + loc, vec, desc, + /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D +/// vectors progressively on the way from targeting llvm.matrix intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.strided_slice from 1-D + vector.insert into 2-D +class ShapeCastOp2DUpCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = sourceVectorType.getElementType(); + Value zero = rewriter.create(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create(loc, resultVectorType, zero); + unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; + for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create( + loc, op.source(), /*offsets=*/i * mostMinorVectorSize, + /*sizes=*/mostMinorVectorSize, + /*strides=*/1); + desc = rewriter.create(loc, vec, desc, i); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + } // namespace // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -1188,5 +1296,9 @@ void mlir::vector::populateVectorContractLoweringPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); + patterns.insert(context); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -701,3 +701,20 @@ // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) // CHECK: llvm.return %[[V]] : !llvm.i64 +// 4x16 16x3 4x3 +func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> { + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xf64>, vector<48xf64>) -> vector<12xf64> + return %C: vector<12xf64> +} +// CHECK-LABEL: llvm.func @matrix_ops +// CHECK-SAME: %[[A:.*]]: !llvm<"<64 x double>"> +// CHECK-SAME: %[[B:.*]]: !llvm<"<48 x double>"> +// +// MLIR is row-major but LLVM intrinsics are column major. +// Operands are swapped and shapes transposed: (3x16) * (16x4) -> 3x4.
+// CHECK: llvm.intr.matrix.multiply %[[B]], %[[A]] { +// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 3 : i32, rhs_columns = 4 : i32 +// CHECK-SAME: } : (!llvm<"<48 x double>">, !llvm<"<64 x double>">) -> !llvm<"<12 x double>"> + diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics \ +// RUN: | FileCheck %s --check-prefix=MATRIX #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -161,6 +163,18 @@ // CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> // CHECK: return %[[T45]] : vector<2x2xf32> +// MATRIX-LABEL: func @extract_contract4 +// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x2xf32>, +// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x2xf32>, +// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x2xf32> +// MATRIX-NOT: vector.shape_cast +// MATRIX: vector.matrix_multiply %{{.*}}, %{{.*}} +// MATRIX-SAME: {lhs_columns = 2 : i32, lhs_rows = 2 : i32, rhs_columns = 2 : i32} : +// MATRIX-SAME: (vector<4xf32>, vector<4xf32>) -> vector<4xf32> +// MATRIX-NOT: vector.shape_cast +// MATRIX: %[[E:.*]] = addf %[[C]], %{{.*}} : vector<2x2xf32> +// MATRIX: return %[[E]] : vector<2x2xf32> + func @extract_contract4(%arg0: vector<2x2xf32>, %arg1: vector<2x2xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> { @@ -250,3 +264,25 @@ : vector<2x3xf32>, vector<3x2xf32> into f32 return %0 : f32 } + +// Shape up and
downcasts for 2-D vectors, for supporting conversion to +// llvm.matrix operations +// CHECK-LABEL: func @shape_casts +func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { + // CHECK: %[[cst:.*]] = constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[cst22:.*]] = constant dense<0.000000e+00> : vector<2x2xf32> + // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32> + // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> + // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> + // CHECK: %[[add:.*]] = addf %[[in2]], %[[in2]] : vector<4xf32> + %r0 = addf %0, %0: vector<4xf32> + // CHECK: %[[ss0:.*]] = vector.strided_slice %[[add]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : vector<2xf32> into vector<2x2xf32> + // CHECK: %[[s2:.*]] = vector.strided_slice %[[add]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : vector<2xf32> into vector<2x2xf32> + %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> + // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> + return %r0, %1 : vector<4xf32>, vector<2x2xf32> +} diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -138,7 +138,7 @@ { // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3) %C = llvm.intr.matrix.multiply %A, %B - { lhs_rows = 4: i32, lhs_columns = 16: i32
, rhs_rows = 3: i32} : + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>"> llvm.return %C: !llvm<"<12 x float>"> } diff --git a/mlir/test/mlir-cpu-runner/test-contraction-matrix-intrinsic.mlir b/mlir/test/mlir-cpu-runner/test-contraction-matrix-intrinsic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/test-contraction-matrix-intrinsic.mlir @@ -0,0 +1,73 @@ + +// NOT FOR COMMIT: for illustration purposes only + +// RUN: mlir-opt %s -vector-lower-matrix-intrinsics=true -convert-vector-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e entry -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext -lower-matrix-intrinsics | FileCheck %s + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func @entry() { + %f0 = constant 0.0: f32 + %f1 = constant 1.0: f32 + %f2 = constant 2.0: f32 + %f3 = constant 3.0: f32 + %f4 = constant 4.0: f32 + %f5 = constant 5.0: f32 + %f6 = constant 6.0: f32 + %f7 = constant 7.0: f32 + %f8 = constant 8.0: f32 + + // Zero vectors. + %Z = vector.broadcast %f0 : f32 to vector<2x2xf32> + + // Construct test vectors. + %0 = vector.broadcast %f1 : f32 to vector<2xf32> + %a = vector.insert %f2, %0[1] : f32 into vector<2xf32> + %1 = vector.broadcast %f3 : f32 to vector<2xf32> + %b = vector.insert %f4, %1[1] : f32 into vector<2xf32> + %2 = vector.broadcast %f5 : f32 to vector<2xf32> + %c = vector.insert %f6, %2[1] : f32 into vector<2xf32> + %3 = vector.broadcast %f7 : f32 to vector<2xf32> + %d = vector.insert %f8, %3[1] : f32 into vector<2xf32> + + // Construct test matrices.
+ %4 = vector.broadcast %f0 : f32 to vector<2x2xf32> + %5 = vector.insert %a, %4[0] : vector<2xf32> into vector<2x2xf32> + %A = vector.insert %b, %5[1] : vector<2xf32> into vector<2x2xf32> + + %6 = vector.broadcast %f0 : f32 to vector<2x2xf32> + %7 = vector.insert %c, %6[0] : vector<2xf32> into vector<2x2xf32> + %B = vector.insert %d, %7[1] : vector<2xf32> into vector<2x2xf32> + + // CHECK: ( ( 1, 2 ), ( 3, 4 ) ) + vector.print %A : vector<2x2xf32> + // CHECK: ( ( 5, 6 ), ( 7, 8 ) ) + vector.print %B : vector<2x2xf32> + + // CHECK: ( 1, 2, 3, 4 ) + %aa = vector.shape_cast %A : vector<2x2xf32> to vector<4xf32> + vector.print %aa : vector<4xf32> + // CHECK: ( 5, 6, 7, 8 ) + %bb = vector.shape_cast %B : vector<2x2xf32> to vector<4xf32> + vector.print %bb : vector<4xf32> + + // CHECK: ( 19, 22, 43, 50 ) + %cc = vector.matrix_multiply %aa, %bb + {lhs_columns = 2 : i32, lhs_rows = 2 : i32, rhs_columns = 2 : i32} : + (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + vector.print %cc : vector<4xf32> + + // CHECK: ( ( 19, 22 ), ( 43, 50 ) ) + %mm1 = vector.contract #matmat_trait %A, %B, %Z + : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + vector.print %mm1 : vector<2x2xf32> + + return +}