diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -17,11 +17,33 @@ #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" namespace mlir { + +inline bool isRowMajorMatmul(ArrayAttr indexingMaps) { + AffineExpr m, n, k; + bindDims(indexingMaps.getContext(), m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext()); + return indexingMaps == maps; +} + +inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) { + AffineExpr m, n, k; + bindDims(indexingMaps.getContext(), m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m})); + auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext()); + return indexingMaps == maps; +} + /// Attribute name for the AffineArrayAttr which encodes the relationship /// between a structured op iterators' and its operands. 
constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1138,6 +1138,7 @@ OwningRewritePatternList patterns; populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); + populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" @@ -144,17 +145,10 @@ // TODO(ntv) should be Tablegen'd from a single source that generates the op // itself. 
-static bool isMatmul(linalg::GenericOp genericOp) { - auto *ctx = genericOp.getContext(); - auto m = getAffineDimExpr(0, ctx); - auto n = getAffineDimExpr(1, ctx); - auto k = getAffineDimExpr(2, ctx); - auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k})); - auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n})); - auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n})); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, ctx); +static bool isRowMajorMatmul(linalg::GenericOp genericOp) { return genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 && - genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); + isRowMajorMatmul(genericOp.indexing_maps()) && + hasMultiplyAddBody(genericOp); } // TODO(ntv, ataei): This is in fact much more general than just vectorization @@ -172,7 +166,7 @@ return success(); auto genericOp = dyn_cast<linalg::GenericOp>(op); - if (!genericOp || !isMatmul(genericOp)) + if (!genericOp || !::isRowMajorMatmul(genericOp)) return failure(); // TODO(ntv): non-identity layout. diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -42,6 +42,13 @@ using llvm::dbgs; using mlir::functional::zipMap; +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); + +static llvm::cl::opt<bool> lowerToLLVMMatrixIntrinsics( + "vector-lower-matrix-intrinsics", + llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), + llvm::cl::init(false), llvm::cl::cat(clOptionsCategory)); + /// Given a shape with sizes greater than 0 along all dimensions, /// returns the distance, in number of elements, between a slice in a dimension /// and the next slice in the same dimension. @@ -935,6 +942,39 @@ if (llvm::size(op.masks()) != 0) return matchFailure(); + // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in + // a new pattern. 
+ // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix + // intrinsics, use that. + if (lowerToLLVMMatrixIntrinsics && + isColumnMajorMatmul(op.indexing_maps())) { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create<vector::ShapeCastOp>( + op.getLoc(), flattenedLHSType, op.lhs()); + auto rhs = rewriter.create<vector::ShapeCastOp>( + op.getLoc(), flattenedRHSType, op.rhs()); + + unsigned lhsRows = op.getLhsType().getShape()[0]; + unsigned lhsColumns = op.getLhsType().getShape()[1]; + unsigned rhsColumns = op.getRhsType().getShape()[1]; + Value mul = rewriter.create<vector::MatmulOp>( + op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); + mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), + op.acc().getType(), mul); + Type elementType = op.getLhsType().getElementType(); + assert(elementType.isIntOrFloat()); + if (elementType.isa<IntegerType>()) + rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul); + return matchSuccess(); + } + // Find first batch dimension in LHS/RHS, and lower when found. 
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); if (!batchDimMap.empty()) { diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-contraction-conversion -vector-lower-matrix-intrinsics | FileCheck %s --check-prefix=MATRIX #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -333,3 +334,47 @@ // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> return %r0, %1 : vector<4xf32>, vector<2x2xf32> } + +// MATRIX-LABEL: func @column_major_matmul +// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32> +// MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32> +// MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32> +// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32> +// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> +// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32> +// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> +// MATRIX: %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32> +// MATRIX: %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32> +// MATRIX: %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32> +// MATRIX: %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> +// 
MATRIX: %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32> +// MATRIX: %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> +// MATRIX: %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32> +// MATRIX: %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> +// MATRIX: %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32> +// MATRIX: %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32> +// MATRIX: %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32> +// MATRIX: %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32> +// MATRIX: %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32> +// MATRIX: %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32> +// MATRIX: %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32> +// MATRIX: %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32> +#column_major_matmat_accesses = [ + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (j, i)> +] +#column_major_matmat_trait = { + indexing_maps = #column_major_matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} +func @column_major_matmul(%arg0: vector<4x3xf32>, + %arg1: vector<2x4xf32>, + %arg2: vector<3x2xf32>) -> vector<3x2xf32> { + %0 = vector.contract #column_major_matmat_trait %arg0, %arg1, %arg2 + : vector<4x3xf32>, vector<2x4xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +}