diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -78,6 +78,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" } // namespace linalg + } // namespace mlir #endif // MLIR_DIALECT_LINALG_LINALGOPS_H_ diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s -linalg-matmul-to-vector -canonicalize -linalg-matmul-to-vector | FileCheck %s + +// TODO(ntv, rriddle) Need ViewOp canonicalizer to enable vectorization +// without phase ordering. Unfortunately there is some pattern interaction +// because ViewOp and MemRefCastOp canonicalization order need to happen +// after AffineApplyOp canonicalization to get fully static shapes. +// Unfortunately bumping up the priority of AffineApplyOp canonicalization +// does not help and making ViewOp canonicalization know about AffineApplyOp +// does not work either due to cyclic dependences between dialects. +// So for now, we give up and we still have phase ordering issues.
+ +func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { + linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} : + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]> + return +} + +// CHECK-LABEL:func @matmul_perm +// CHECK: vector.contract +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -5,3 +5,7 @@ set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) + +set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td) +mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td @@ -0,0 +1,43 @@ +//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the pattern definition file for the declarative Linalg matmul to +// vector test pass (tiling + promotion + vectorization). +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS +#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS + +include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" +include "mlir/Dialect/VectorOps/VectorTransformPatterns.td" + +//===----------------------------------------------------------------------===// +// Linalg tiling, permutation and subview promotion patterns. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp), + [(Constraint>), + (Constraint>)]>; + +//===----------------------------------------------------------------------===// +// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===// +def : Pattern<(MatmulOp:$op $_, $_, $_), + [(VectorizeLinalgOp)], + [(Constraint, + PreconditionVectorizeLinalgOp]>>)]>; + +#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ TestLoopFusion.cpp TestGpuMemoryPromotion.cpp TestInlining.cpp + TestLinalgMatmulToVector.cpp TestLinalgTransforms.cpp TestLiveness.cpp TestLoopMapping.cpp @@ -25,6 +26,7 @@ add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen) add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen) add_dependencies(MLIRTestTransforms MLIRTestVectorTransformPatternsIncGen) +add_dependencies(MLIRTestTransforms MLIRTestLinalgMatmulToVectorPatternsIncGen) target_link_libraries(MLIRTestTransforms MLIRAffineOps MLIRAnalysis diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp @@ -0,0 +1,61 @@ +//===- TestLinalgMatmulToVector.cpp - Test Linalg matmul to vector --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/Dialect/VectorOps/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::vector;
+
+namespace {
+#include "TestLinalgMatmulToVectorPatterns.h.inc"
+/// Applies the generated Linalg matmul patterns and vector patterns greedily.
+struct DeclarativeTransforms : public FunctionPass<DeclarativeTransforms> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    populateWithGenerated(context, &patterns);
+    populateVectorToVectorCanonicalizationPatterns(patterns, context);
+    populateVectorToVectorTransformationPatterns(patterns, context);
+    // TODO(ntv, rriddle) Need ViewOp canonicalizer to enable vectorization
+    // without phase ordering. Unfortunately there is some pattern interaction
+    // because ViewOp and MemRefCastOp canonicalization order need to happen
+    // after AffineApplyOp canonicalization to get fully static shapes.
+    // Unfortunately bumping up the priority of AffineApplyOp canonicalization
+    // does not help and making ViewOp canonicalization know about AffineApplyOp
+    // does not work either due to cyclic dependences between dialects.
+    // So for now, we give up and we still have phase ordering issues.
+    // AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+    // AllocOp::getCanonicalizationPatterns(patterns, context);
+    // MemRefCastOp::getCanonicalizationPatterns(patterns, context);
+    // SubViewOp::getCanonicalizationPatterns(patterns, context);
+    // ViewOp::getCanonicalizationPatterns(patterns, context);
+    // Need MemRefCastDynamicFolder to cleanup after ViewOp canonicalizer and
+    // enable vectorization.
+    //
+    // TODO(ntv): once phase ordering issues are solved, we should expose Linalg
+    // canonicalization patterns and call them here.
+    // populateLinalgCanonicalizationPatterns(patterns, context);
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+} // end anonymous namespace
+
+static PassRegistration<DeclarativeTransforms>
+    pass("linalg-matmul-to-vector",
+         "Test declarative transform patterns for matmul 3-D tiling + promotion"
+         " + vectorization");