diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s --linalg-matmul-to-vector | FileCheck %s
+
+// Tiling (768x264x768, then 8x12x16), promotion and vectorization are expected
+// to lower the innermost matmul tile to a vector.contract on 8x16 * 16x12.
+func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+  return
+}
+
+// CHECK-LABEL: func @matmul_perm
+//       CHECK: vector.contract
+//  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
--- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
+++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
@@ -5,3 +5,7 @@
 set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
 mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
+mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
@@ -0,0 +1,48 @@
+//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the
Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarative patterns for the -linalg-matmul-to-vector test pass: two levels
+// of matmul tiling, subview promotion and vectorization to vector.contract.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+include "mlir/Dialect/VectorOps/VectorTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling (with loop permutation) and subview promotion patterns.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
+          [(Constraint>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
+          [(Constraint>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (PromoteSubviewsLinalgOp),
+          [(Constraint>),
+           (Constraint>)]>;
+
+//===----------------------------------------------------------------------===//
+// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(MatmulOp:$op $_, $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint,
+                PreconditionVectorizeLinalgOp]>>)]>;
+def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint,
+                PreconditionVectorizeLinalgOp]>>)]>;
+
+#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestInlining.cpp
+  TestLinalgMatmulToVector.cpp
   TestLinalgTransforms.cpp
   TestLiveness.cpp
   TestLoopMapping.cpp
@@ -25,6 +26,7 @@
 add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
 add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
 add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen)
+add_dependencies(MLIRTestTransforms MLIRTestLinalgMatmulToVectorPatternsIncGen)
 target_link_libraries(MLIRTestTransforms
   MLIRAffineOps
   MLIRAnalysis
diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
@@ -0,0 +1,147 @@
+//===- TestLinalgMatmulToVector.cpp - Test matmul-to-vector lowering ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/Dialect/VectorOps/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::vector;
+
+namespace {
+#include "TestLinalgMatmulToVectorPatterns.h.inc"
+
+/// Canonicalize a memref_cast whose source has a static shape (and, if it has
+/// a layout map, a contiguous offset-0 layout), such as either:
+///
+/// ```mlir
+/// ... = memref_cast ... : memref<8x16xf32> to memref<?x?xf32>
+/// // or
+/// ... = memref_cast ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// to memref<?x?xf32>
+/// ```
+///
+/// into a cast to the static, identity-layout source type:
+///
+/// ```mlir
+/// ... = memref_cast ... : ... to memref<8x16xf32>
+/// ```
+///
+/// Every use of the cast must pass the user filter in matchAndRewrite, because
+/// the rewritten cast changes the type those users observe.
+class MemRefCastDynamicRewrite final : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(MemRefCastOp castOp,
+                                     PatternRewriter &rewriter) const override {
+    MemRefType sourceType = castOp.source().getType().dyn_cast();
+    MemRefType resultType = castOp.getType().dyn_cast();
+
+    // If we don't have MemRefType as source and destination, bail out. This
+    // must come before any query on either type, since a failed dyn_cast
+    // yields a null type.
+    if (!sourceType || !resultType)
+      return matchFailure();
+
+    AffineMap sourceMap = (sourceType.getAffineMaps().empty())
+                              ? AffineMap()
+                              : sourceType.getAffineMaps().front();
+    // Same shape, element type and memory space as the source; no layout.
+    MemRefType desiredResultType =
+        MemRefType::get(sourceType.getShape(), sourceType.getElementType(), {},
+                        sourceType.getMemorySpace());
+
+    // If we're already in canonical form all is good.
+    if (resultType == desiredResultType)
+      return matchFailure();
+
+    // If resultType has a map, it needs to be the same as the source type to
+    // canonicalize.
+    if (!resultType.getAffineMaps().empty() && sourceType != resultType)
+      return matchFailure();
+
+    // Inspect uses and bail out if we hit an op we do not know about.
+    // TODO(ntv, rriddle): this may be a good use case for an OpInterface.
+    for (auto &u : castOp.getResult().getUses()) {
+      if (isa(u.getOwner()))
+        continue;
+      return matchFailure();
+    }
+
+    // Ensure that:
+    //   1. source is static
+    //   2. source and target have the same rank (will be extended when needed)
+    //   3. if result is partially static, ensure sizes match.
+    if (!sourceType.hasStaticShape() ||
+        sourceType.getRank() != resultType.getRank())
+      return matchFailure();
+    for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+      auto sourceSize = std::get<0>(it);
+      auto resultSize = std::get<1>(it);
+      if (ShapedType::isDynamic(resultSize))
+        continue;
+      if (sourceSize != resultSize)
+        return matchFailure();
+    }
+
+    // Dropping the source layout map is only sound if it denotes the identity
+    // layout: a strided map with offset 0 and contiguous row-major strides
+    // (each stride equals the product of the inner dimension sizes). Maps
+    // that are not strided at all are rejected instead of asserted on.
+    if (sourceMap) {
+      int64_t offset;
+      SmallVector strides;
+      if (failed(getStridesAndOffset(sourceType, strides, offset)))
+        return matchFailure();
+      auto stridedMap =
+          makeStridedLinearLayoutMap(strides, offset, sourceMap.getContext());
+      if (sourceMap != stridedMap || offset != 0)
+        return matchFailure();
+      int64_t expectedStride = 1;
+      for (int64_t dim = sourceType.getRank() - 1; dim >= 0; --dim) {
+        if (strides[dim] != expectedStride)
+          return matchFailure();
+        expectedStride *= sourceType.getShape()[dim];
+      }
+    }
+
+    rewriter.replaceOpWithNewOp(castOp, castOp.source(),
+                                desiredResultType);
+
+    return matchSuccess();
+  }
+};
+
+/// Test pass that greedily applies the tablegen'd patterns
+/// (populateWithGenerated), the vector-to-vector canonicalization and
+/// transformation patterns, and the pattern registered via patterns.insert.
+struct DeclarativeTransforms : public FunctionPass {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    populateWithGenerated(context, &patterns);
+    populateVectorToVectorCanonicalizationPatterns(patterns, context);
+    populateVectorToVectorTransformationPatterns(patterns, context);
+    patterns.insert(context);
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+} // end anonymous namespace
+
+static PassRegistration
+    pass("linalg-matmul-to-vector",
+         "Test declarative transform patterns for matmul 3-D tiling + promotion"
+         " + vectorization");