diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -93,18 +93,6 @@
 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
-/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
-/// converted to `nvgpu.mma.sync`. This specific form is meant to indicate that
-/// the vector operands are organized such that the reduction dimension is
-/// contiguous.
-struct PrepareContractToGPUMMASync
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
 } // namespace nvgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -518,6 +518,51 @@
                                   vector::ContractionOp op, Value mask) const;
 };
 
+/// Canonicalization of a `vector.contract %a, %b, %c` with row-major matmul
+/// semantics to a contraction suitable for MMT lowering. The canonical form is
+/// "TNT" = A row-major, B col-major, C row-major (mk, nk, mn). This specific
+/// form is meant to indicate that the vector operands are organized such that
+/// the reduction dimension is contiguous. Example:
+/// ```
+/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+///                                   affine_map<(m, n, k) -> (n, k)>,
+///                                   affine_map<(m, n, k) -> (m, n)>],
+///                  iterator_types = ["parallel", "parallel", "reduction"],
+///                  kind = #vector.kind<add>} %arg0, %arg1, %arg2 : ...
+/// ```
+///
+/// Operands are transposed and/or swapped as needed to reach that form. When
+/// an operand to transpose is produced by `arith.extsi`/`arith.extui`, the
+/// transpose is applied to the narrow pre-extension value and the extension
+/// is re-emitted on the transposed result.
+class PrepareContractionOpForMMTLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  /// `constraint` is checked before any IR is created; the pattern only
+  /// applies to contractions for which it returns success.
+  PrepareContractionOpForMMTLowering(
+      MLIRContext *context, PatternBenefit benefit = 1,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+        filter(constraint) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Constraint that must succeed for the pattern to apply.
+  FilterConstraintType filter;
+};
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -1174,7 +1175,7 @@
     return;
   }
   patterns
-      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
+      .add<vector::PrepareContractionOpForMMTLowering, CombineTransferReadOpTranspose>(
           patterns.getContext());
 }
 
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -272,60 +272,3 @@
 
   return failure();
 }
-
-LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  Location loc = op.getLoc();
-  Value lhs = op.getLhs();
-  Value rhs = op.getRhs();
-  Value res = op.getAcc();
-
-  // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-  AffineExpr m;
-  AffineExpr n;
-  AffineExpr k;
-  bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.getIteratorTypes().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
-  if (iteratorTypes.size() != 3)
-    return failure();
-  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
-        vector::isParallelIterator(iteratorTypes[1]) &&
-        vector::isReductionIterator(iteratorTypes[2])))
-    return failure();
-
-  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
-  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
-  if (maps == canonicalForm) {
-    return failure();
-  }
-  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-    std::swap(rhs, lhs);
-    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-    std::swap(lhs, rhs);
-    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-    std::swap(lhs, rhs);
-  } else {
-    return failure();
-  }
-  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
-      op.getIteratorTypes());
-  return success();
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -2159,6 +2160,85 @@
   return result;
 }
 
+LogicalResult PrepareContractionOpForMMTLowering::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
+  // TODO: Remove native masks from contraction op?
+  if (!op.getMasks().empty())
+    return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value lhs = op.getLhs();
+  Value rhs = op.getRhs();
+  Value res = op.getAcc();
+
+  // Set up the parallel/reduction structure in right form.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m;
+  AffineExpr n;
+  AffineExpr k;
+  bindDims(rewriter.getContext(), m, n, k);
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
+  if (iteratorTypes.size() != 3 ||
+      !vector::isParallelIterator(iteratorTypes[0]) ||
+      !vector::isParallelIterator(iteratorTypes[1]) ||
+      !vector::isReductionIterator(iteratorTypes[2]))
+    return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
+
+  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+  if (maps == canonicalForm)
+    return rewriter.notifyMatchFailure(op, "already in the canonical form");
+
+  // Create a vector transpose making sure to emit zero/sign-extend at the end.
+  auto createTranspose = [&rewriter, loc](Value mat) -> Value {
+    // The extension is re-emitted on the transposed narrow value, so its
+    // result type is the transposed shape with the original wide element
+    // type; reusing `mat.getType()` would be wrong for non-square operands.
+    auto getExtType = [mat](Value trans) -> VectorType {
+      Type wideElemType = mat.getType().cast<VectorType>().getElementType();
+      return VectorType::get(trans.getType().cast<VectorType>().getShape(),
+                             wideElemType);
+    };
+    if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
+      Value trans =
+          rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
+      return rewriter.create<arith::ExtSIOp>(loc, getExtType(trans), trans);
+    }
+    if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
+      Value trans =
+          rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
+      return rewriter.create<arith::ExtUIOp>(loc, getExtType(trans), trans);
+    }
+    return rewriter.create<vector::TransposeOp>(loc, mat, perm);
+  };
+
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    rhs = createTranspose(rhs);
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    lhs = createTranspose(lhs);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    rhs = createTranspose(rhs);
+    lhs = createTranspose(lhs);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = createTranspose(rhs);
+    lhs = createTranspose(lhs);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = createTranspose(rhs);
+  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    std::swap(lhs, rhs);
+    lhs = createTranspose(lhs);
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    std::swap(lhs, rhs);
+  } else {
+    return rewriter.notifyMatchFailure(op, "unhandled contraction form");
+  }
+  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+      op.getIteratorTypes());
+  return success();
+}
 } // namespace mlir
 
 /// Progressive lowering of transfer_read.
This pattern supports lowering of
diff --git a/mlir/test/Dialect/Vector/vector-contract-matmul-canonicalization.mlir b/mlir/test/Dialect/Vector/vector-contract-matmul-canonicalization.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-matmul-canonicalization.mlir
@@ -0,0 +1,218 @@
+// RUN: mlir-opt %s -test-vector-contraction-prepare-for-mmt-lowering | FileCheck %s
+
+// CHECK-LABEL: func.func @not_matmul
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4xf32>, [[ARG1:%.+]]: vector<4xf32>, [[ARG2:%.+]]: f32)
+// CHECK-NEXT: vector.contract
+// CHECK-NEXT: return
+func.func @not_matmul(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> ()>],
+                        iterator_types = ["reduction"],
+                        kind = #vector.kind<add>} %arg0, %arg1, %arg2 :
+                        vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
+
+// This contraction is already in the canonical form.
+// CHECK-LABEL: func.func @matmul_mk_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extsi [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extsi [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extsi_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extsi %arg0: vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extsi %arg1: vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// Check that non-square shapes are also handled.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x16xi32>, [[ARG1:%.+]]: vector<16x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi32> to vector<4x16xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG0]], [[TRANS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x16xi32(%arg0: vector<4x16xi32>, %arg1: vector<16x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// Check that an extension of a non-square operand is re-emitted with the
+// transposed shape.
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x16xi8_extsi_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x16xi8>, [[ARG1:%.+]]: vector<16x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extsi [[ARG0]] : vector<4x16xi8> to vector<4x16xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<16x4xi8> to vector<4x16xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extsi [[TRANS]] : vector<4x16xi8> to vector<4x16xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x16xi8_extsi_i32(%arg0: vector<4x16xi8>, %arg1: vector<16x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extsi %arg0: vector<4x16xi8> to vector<4x16xi32>
+  %rhs = arith.extsi %arg1: vector<16x4xi8> to vector<16x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_mn_4x4xi8_extui_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[LHS:%.+]] = arith.extui [[ARG0]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-NEXT: [[RHS:%.+]] = arith.extui [[TRANS]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_mn_4x4xi8_extui_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extui %arg0: vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extui %arg1: vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[TRANS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[TRANS]], [[ARG1]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi8>, [[ARG1:%.+]]: vector<4x4xi8>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHST:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[LHS:%.+]] = arith.extsi [[LHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-DAG: [[RHST:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi8> to vector<4x4xi8>
+// CHECK-DAG: [[RHS:%.+]] = arith.extui [[RHST]] : vector<4x4xi8> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[LHS]], [[RHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_mn_4x4xi8_mixed_ext_i32(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs = arith.extsi %arg0 : vector<4x4xi8> to vector<4x4xi32>
+  %rhs = arith.extui %arg1 : vector<4x4xi8> to vector<4x4xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d0, d1)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %lhs, %rhs, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_mk_kn_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[RHS:%.+]] = vector.transpose [[ARG1]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[RHS]], [[ARG0]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_mk_kn_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// CHECK-LABEL: func.func @matmul_km_nk_nm_4x4xi32
+// CHECK-SAME: ([[ARG0:%.+]]: vector<4x4xi32>, [[ARG1:%.+]]: vector<4x4xi32>, [[ARG2:%.+]]: vector<4x4xi32>)
+// CHECK-DAG: [[LHS:%.+]] = vector.transpose [[ARG0]], [1, 0] : vector<4x4xi32> to vector<4x4xi32>
+// CHECK-NEXT: [[RES:%.+]] = vector.contract {{.+}} [[ARG1]], [[LHS]], [[ARG2]]
+// CHECK-NEXT: return [[RES]]
+func.func @matmul_km_nk_nm_4x4xi32(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi32>) -> vector<4x4xi32> {
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                           affine_map<(d0, d1, d2) -> (d1, d0)>],
+                          iterator_types = ["parallel", "parallel", "reduction"],
+                          kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -11,6 +11,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -199,6 +200,35 @@
   }
 };
 
+/// Greedily applies the PrepareContractionOpForMMTLowering pattern to each
+/// function; used by vector-contract-matmul-canonicalization.mlir.
+struct TestVectorContractionPrepareForMMTLowering
+    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContractionPrepareForMMTLowering)
+
+  StringRef getArgument() const final {
+    return "test-vector-contraction-prepare-for-mmt-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test vector.contraction matmul canonicalization for MMT lowering.";
+  }
+  TestVectorContractionPrepareForMMTLowering() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, arith::ArithDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<vector::PrepareContractionOpForMMTLowering>(ctx);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransposeLowering
     : public PassWrapper<TestVectorTransposeLowering,
                          OperationPass<func::FuncOp>> {
@@ -892,6 +922,8 @@
 
   PassRegistration<TestVectorContractionLowering>();
 
+  PassRegistration<TestVectorContractionPrepareForMMTLowering>();
+
   PassRegistration<TestVectorTransposeLowering>();
 
   PassRegistration<TestVectorUnrollingPatterns>();