diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -40,17 +40,22 @@ using folded_affine_min = folded::ValueBuilder; using folded_linalg_range = folded::ValueBuilder; +using folded_std_dim = folded::ValueBuilder; +using folded_std_subview = folded::ValueBuilder; +using folded_std_view = folded::ValueBuilder; #define DEBUG_TYPE "linalg-promotion" -static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) { +static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers, + OperationFolder *folder) { auto *ctx = size.getContext(); auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size.getDefiningOp())) return std_alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); - Value mul = std_muli(std_constant_index(width), size); + Value mul = + folded_std_muli(folder, folded_std_constant_index(folder, width), size); return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); } @@ -87,18 +92,32 @@ for (auto en : llvm::enumerate(subView.getRanges())) { auto rank = en.index(); auto rangeValue = en.value(); + // Use the smallest constant result of an affine.min as a static upper bound. Value d = rangeValue.size; + if (auto affineMinOp = dyn_cast_or_null(d.getDefiningOp())) { + int64_t minConst = std::numeric_limits::max(); + for (auto e : affineMinOp.getAffineMap().getResults()) + if (auto cst = e.dyn_cast()) + minConst = std::min(minConst, cst.getValue()); + if (minConst != std::numeric_limits::max()) + d = folded_std_constant_index(folder, minConst); + } allocSize = folded_std_muli(folder, allocSize, d).getValue(); fullRanges.push_back(d); - partialRanges.push_back( - folded_linalg_range(folder, zero, std_dim(subView, rank), one)); + partialRanges.push_back(folded_std_dim(folder, subView, rank)); } SmallVector
dynSizes(fullRanges.size(), -1); auto buffer = - allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); + allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder); - auto fullLocalView = std_view( + auto fullLocalView = folded_std_view(folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); - auto partialLocalView = linalg_slice(fullLocalView, partialRanges); + // clang-format off + auto partialLocalView = folded_std_subview(folder, + fullLocalView, + SmallVector(fullRanges.size(), zero), + partialRanges, + SmallVector(fullRanges.size(), one)); + // clang-format on return PromotionInfo{buffer, fullLocalView, partialLocalView}; } diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -linalg-matmul-to-vector -canonicalize \ +// RUN: -linalg-matmul-to-vector | FileCheck %s + +// TODO(ntv, rriddle): Need ViewOp canonicalizer to enable vectorization +// without phase ordering. Unfortunately there is some pattern interaction +// because ViewOp and MemRefCastOp canonicalizations need to happen +// after AffineApplyOp canonicalization to get fully static shapes. +// Unfortunately bumping up the priority of AffineApplyOp canonicalization +// does not help and making ViewOp canonicalization know about AffineApplyOp +// does not work either due to cyclic dependencies between dialects. +// So for now, we give up and we still have phase ordering issues.
+ +func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { + linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} : + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]> + return +} +// Expect the inner [8, 12, 16] tiles to be vectorized into a contraction. +// CHECK-LABEL:func @matmul_perm +// CHECK: vector.contract +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -5,3 +5,7 @@ set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td) mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) + +set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td) # Patterns for the -linalg-matmul-to-vector test pass. +mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen) diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td @@ -0,0 +1,43 @@ +//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the pattern definition file for the declarative Linalg matmul to +// vector transformation tests. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS +#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS + +include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" +include "mlir/Dialect/Vector/VectorTransformPatterns.td" + +//===----------------------------------------------------------------------===// +// Linalg tiling, permutation and promotion patterns. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp), + [(Constraint>), + (Constraint>)]>; + +//===----------------------------------------------------------------------===// +// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===// +def : Pattern<(MatmulOp:$op $_, $_, $_), + [(VectorizeLinalgOp)], + [(Constraint, + PreconditionVectorizeLinalgOp]>>)]>; + +#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ TestGpuMemoryPromotion.cpp TestGpuParallelLoopMapping.cpp TestInlining.cpp + TestLinalgMatmulToVector.cpp TestLinalgTransforms.cpp TestLiveness.cpp TestLoopMapping.cpp @@ -24,6 +25,7 @@ DEPENDS MLIRStandardOpsIncGen + MLIRTestLinalgMatmulToVectorPatternsIncGen MLIRTestLinalgTransformPatternsIncGen MLIRTestVectorTransformPatternsIncGen ) diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp @@ -0,0 +1,50 @@ +//===- TestLinalgMatmulToVector.cpp - Test matmul-to-vector lowering ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::linalg; +using namespace mlir::vector; + +namespace { +#include "TestLinalgMatmulToVectorPatterns.h.inc" +/// Greedily applies the generated matmul patterns and the canonicalizations. +struct DeclarativeTransforms : public FunctionPass { + void runOnFunction() override { + OwningRewritePatternList patterns; + auto *context = &getContext(); + AffineApplyOp::getCanonicalizationPatterns(patterns, context); + AffineMinOp::getCanonicalizationPatterns(patterns, context); + AffineMaxOp::getCanonicalizationPatterns(patterns, context); + AllocOp::getCanonicalizationPatterns(patterns, context); + SubViewOp::getCanonicalizationPatterns(patterns, context); + ViewOp::getCanonicalizationPatterns(patterns, context); + populateWithGenerated(context, &patterns); // Tablegen'd patterns (.h.inc). + applyPatternsGreedily(getFunction(), patterns); // Result is not checked. + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestLinalgMatmulToVectorPass() { // Called from mlir-opt.cpp. + PassRegistration pass( + "linalg-matmul-to-vector", + "Test declarative transform patterns for matmul 3-D tiling + promotion" + " + vectorization"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -39,6 +39,7 @@ void registerSymbolTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAllReduceLoweringPass(); +void registerTestLinalgMatmulToVectorPass(); void registerTestLoopPermutationPass(); void registerTestCallGraphPass(); void
registerTestConstantFold(); @@ -101,6 +102,7 @@ registerSymbolTestPasses(); registerTestAffineDataCopyPass(); registerTestAllReduceLoweringPass(); + registerTestLinalgMatmulToVectorPass(); registerTestLoopPermutationPass(); registerTestCallGraphPass(); registerTestConstantFold();