diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -87,6 +87,13 @@
 /// Pattern to convert TiledLoopOp to SCF loops.
 void populateTiledLoopToSCFPattern(RewritePatternSet &patterns);
 
+/// Populates `patterns` with a rewrite of statically shaped linalg.matmul ops
+/// on tensors into linalg.mmt4d ops. `M0`, `N0` and `K0` are the inner tile
+/// sizes of the M, N and K dimensions; the rewrite only applies when each
+/// dimension is a multiple of its tile size.
+void populateMatmulToMMT4DPatterns(RewritePatternSet &patterns, int M0, int N0,
+                                   int K0);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By
@@ -911,8 +918,8 @@
 /// scattering magic constants throughout the code base, the patterns must be
 /// added with this function. `baseBenefit` can be used to offset the benefit
 /// of all PadTensorOp vectorization patterns by a certain value.
-void populatePadTensorOpVectorizationPatterns(
-    RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
+void populatePadTensorOpVectorizationPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit baseBenefit = 1);
 
 /// Match and rewrite for the pattern:
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MatmulToMMT4d.cpp
   Promotion.cpp
   Tiling.cpp
   Transforms.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMMT4d.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMMT4d.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMMT4d.cpp
@@ -0,0 +1,172 @@
+//===- MatmulToMMT4d.cpp - Convert linalg.matmul to linalg.mmt4d ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert linalg.matmul into linalg.mmt4d.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Rewrites a linalg.matmul whose tensor operands all have static shapes
+/// into a linalg.mmt4d. Each operand is expanded from 2-D to 4-D by
+/// splitting every dimension into (outer, inner tile) and transposed so that
+/// the inner tile dimensions come last; the linalg.mmt4d result is then
+/// transposed and collapsed back to the original 2-D shape. The rewrite only
+/// applies when M, N and K are multiples of the tile sizes M0, N0 and K0.
+class LinalgStaticMatmulOpToLinalgMMT4dOpPattern
+    : public OpRewritePattern<MatmulOp> {
+public:
+  LinalgStaticMatmulOpToLinalgMMT4dOpPattern(MLIRContext *context, int M0,
+                                             int N0, int K0,
+                                             PatternBenefit benefit = 1)
+      : OpRewritePattern<MatmulOp>(context, benefit), M0Size(M0), N0Size(N0),
+        K0Size(K0) {}
+
+  LogicalResult matchAndRewrite(MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    // Non-positive tile sizes would make the divisibility checks below divide
+    // by zero or yield negative extents, so never match with them.
+    if (M0Size <= 0 || N0Size <= 0 || K0Size <= 0)
+      return failure();
+
+    auto loc = matmulOp.getLoc();
+
+    Value lhs = matmulOp.getInputOperand(0)->get();
+    Value rhs = matmulOp.getInputOperand(1)->get();
+    Value dst = matmulOp.getOutputOperand(0)->get();
+
+    RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType dstType = dst.getType().dyn_cast<RankedTensorType>();
+
+    // All three operands are reshaped below, so each must be a ranked tensor
+    // with a static shape.
+    if (!lhsType || !rhsType || !dstType || !lhsType.hasStaticShape() ||
+        !rhsType.hasStaticShape() || !dstType.hasStaticShape())
+      return failure();
+
+    // Tensor extents are int64_t; keep that width to avoid silent truncation.
+    int64_t m = lhsType.getShape()[0];
+    int64_t n = rhsType.getShape()[1];
+    int64_t k = rhsType.getShape()[0];
+
+    if (m % M0Size != 0 || n % N0Size != 0 || k % K0Size != 0)
+      return failure();
+
+    int64_t m1 = m / M0Size;
+    int64_t n1 = n / N0Size;
+    int64_t k1 = k / K0Size;
+
+    // Expands a 2d tensor operand to 4d given its target shape.
+    auto expandTo4D = [&](Value operand,
+                          ArrayRef<int64_t> targetShape) -> Value {
+      auto operandType = operand.getType().cast<RankedTensorType>();
+      auto targetType =
+          RankedTensorType::get(targetShape, operandType.getElementType());
+      SmallVector<ReassociationIndices> expandIndices = {{0, 1}, {2, 3}};
+      Value reshapedOperand = rewriter.create<linalg::TensorExpandShapeOp>(
+          loc, targetType, operand, expandIndices);
+      return reshapedOperand;
+    };
+
+    auto lhs4D = expandTo4D(lhs, {m1, M0Size, k1, K0Size});
+    auto rhs4D = expandTo4D(rhs, {k1, K0Size, n1, N0Size});
+    auto dst4D = expandTo4D(dst, {m1, M0Size, n1, N0Size});
+
+    // Transposes a 4d tensor operand via a linalg.generic copy: result
+    // dimension i is operand dimension indices[i].
+    auto transposeOperand = [&](Value operand,
+                                ArrayRef<int64_t> indices) -> Value {
+      RankedTensorType operandTensorType =
+          operand.getType().cast<RankedTensorType>();
+      auto nloops = indices.size();
+      auto inputShape = operandTensorType.getShape();
+
+      SmallVector<AffineExpr, 4> exprs = llvm::to_vector<4>(
+          llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
+            return rewriter.getAffineDimExpr(index);
+          }));
+
+      SmallVector<int64_t, 4> targetShape = llvm::to_vector<4>(
+          llvm::map_range(indices, [&](int64_t index) -> int64_t {
+            return inputShape[index];
+          }));
+
+      Value outputTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, targetShape, operandTensorType.getElementType());
+
+      SmallVector<StringRef, 4> loopAttributeTypes(nloops, "parallel");
+
+      SmallVector<AffineMap, 2> indexingMaps = {
+          inversePermutation(
+              AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
+          AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};
+
+      auto transposedOp = rewriter.create<linalg::GenericOp>(
+          loc, outputTensor.getType(),
+          /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
+          loopAttributeTypes,
+          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+            nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+          });
+
+      return transposedOp.getResult(0);
+    };
+
+    // Move the inner tile dimensions last: LHS -> [m1, k1, M0, K0],
+    // RHS -> [n1, k1, N0, K0] and DST -> [m1, n1, M0, N0].
+    auto lhs4DT = transposeOperand(lhs4D, {0, 2, 1, 3});
+    auto rhs4DT = transposeOperand(rhs4D, {2, 0, 3, 1});
+    auto dst4DT = transposeOperand(dst4D, {0, 2, 1, 3});
+
+    auto mmt4DResult = rewriter.create<linalg::Mmt4DOp>(
+        loc, dst4DT.getType(), ValueRange{lhs4DT, rhs4DT}, ValueRange{dst4DT});
+
+    // Undo the DST transposition and collapse back to the original 2d shape.
+    auto mmt4dResultTransposed =
+        transposeOperand(mmt4DResult.getResult(0), {0, 2, 1, 3});
+
+    auto collapseTo2D = [&](Value operand,
+                            ArrayRef<int64_t> targetShape) -> Value {
+      auto operandType = operand.getType().cast<RankedTensorType>();
+      auto targetType =
+          RankedTensorType::get(targetShape, operandType.getElementType());
+      SmallVector<ReassociationIndices> collapseIndices = {{0, 1}, {2, 3}};
+      Value reshapedOperand = rewriter.create<linalg::TensorCollapseShapeOp>(
+          loc, targetType, operand, collapseIndices);
+      return reshapedOperand;
+    };
+
+    Value result = collapseTo2D(mmt4dResultTransposed, {m, n});
+
+    rewriter.replaceOp(matmulOp, result);
+
+    return success();
+  }
+
+private:
+  int M0Size;
+  int N0Size;
+  int K0Size;
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMMT4DPatterns(RewritePatternSet &patterns,
+                                                 int M0, int N0, int K0) {
+  auto *context = patterns.getContext();
+  patterns.add<LinalgStaticMatmulOpToLinalgMMT4dOpPattern>(context, M0, N0, K0);
+}
diff --git a/mlir/test/Dialect/Linalg/matmul_to_mmt4d.mlir b/mlir/test/Dialect/Linalg/matmul_to_mmt4d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul_to_mmt4d.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -split-input-file --test-linalg-matmul-to-mmt4d %s | FileCheck --check-prefix=CHECK %s
+
+func @check_mmt4d(%arg0: tensor<24x8xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<24x32xf32>) -> tensor<24x32xf32> {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x8xf32>, tensor<8x32xf32>) outs(%arg2 : tensor<24x32xf32>) -> tensor<24x32xf32>
+  return %0 : tensor<24x32xf32>
+}
+// CHECK-DAG:#[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
+// CHECK-DAG:#[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:#[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0, d2)>
+// CHECK: @check_mmt4d(%[[LHS:.+]]: tensor<24x8xf32>, %[[RHS:.+]]: tensor<8x32xf32>, %[[DST:.+]]: tensor<24x32xf32>)
+// CHECK: %[[LHS4D:.+]] = linalg.tensor_expand_shape %[[LHS]]
+// CHECK-SAME: tensor<24x8xf32> into tensor<6x4x2x4xf32>
+// CHECK: %[[RHS4D:.+]] = linalg.tensor_expand_shape %[[RHS]]
+// CHECK-SAME: tensor<8x32xf32> into tensor<2x4x8x4xf32>
+// CHECK: %[[DST4D:.+]] = linalg.tensor_expand_shape %[[DST]]
+// CHECK-SAME: tensor<24x32xf32> into tensor<6x4x8x4xf32>
+// CHECK: %[[LHS4DT_INIT:.+]] = linalg.init_tensor [6, 2, 4, 4] : tensor<6x2x4x4xf32>
+// CHECK: %[[LHS4DT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[LHS4D]] : tensor<6x4x2x4xf32>) outs(%[[LHS4DT_INIT]] : tensor<6x2x4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<6x2x4x4xf32>
+// CHECK: %[[RHS4DT_INIT:.+]] = linalg.init_tensor [8, 2, 4, 4] : tensor<8x2x4x4xf32>
+// CHECK: %[[RHS4DT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP1]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[RHS4D]] : tensor<2x4x8x4xf32>) outs(%[[RHS4DT_INIT]] : tensor<8x2x4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<8x2x4x4xf32>
+// CHECK-NEXT: %[[DST4DT_INIT:.+]] = linalg.init_tensor [6, 8, 4, 4] : tensor<6x8x4x4xf32>
+// CHECK: %[[DST4DT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[DST4D]] : tensor<6x4x8x4xf32>) outs(%[[DST4DT_INIT]] : tensor<6x8x4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<6x8x4x4xf32>
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d ins(%[[LHS4DT]], %[[RHS4DT]] : tensor<6x2x4x4xf32>, tensor<8x2x4x4xf32>) outs(%[[DST4DT]] : tensor<6x8x4x4xf32>) -> tensor<6x8x4x4xf32>
+// CHECK: %[[MMT4DT_INIT:.+]] = linalg.init_tensor [6, 4, 8, 4] : tensor<6x4x8x4xf32>
+// CHECK: %[[MMT4DT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[MMT4D]] : tensor<6x8x4x4xf32>) outs(%[[MMT4DT_INIT]] : tensor<6x4x8x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %{{.*}}: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<6x4x8x4xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[MMT4DT]]
+// CHECK-SAME: tensor<6x4x8x4xf32> into tensor<24x32xf32>
+// CHECK: return %[[RESULT]] : tensor<24x32xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
+  TestLinalgMatmulToMMT4d.cpp
   TestLinalgTransforms.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgMatmulToMMT4d.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgMatmulToMMT4d.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgMatmulToMMT4d.cpp
@@ -0,0 +1,66 @@
+//===- TestLinalgMatmulToMMT4d.cpp - Test matmul to mmt4d conversion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing linalg.matmul to linalg.mmt4d
+// conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Test pass that applies the linalg.matmul -> linalg.mmt4d patterns to a
+/// function, with the inner tile sizes M0/N0/K0 taken from pass options.
+struct TestLinalgMatmulToMMT4D
+    : PassWrapper<TestLinalgMatmulToMMT4D, FunctionPass> {
+  TestLinalgMatmulToMMT4D() = default;
+  TestLinalgMatmulToMMT4D(const TestLinalgMatmulToMMT4D &pass) {}
+
+  StringRef getArgument() const final { return "test-linalg-matmul-to-mmt4d"; }
+  StringRef getDescription() const final {
+    return "Test Linalg matmul -> mmt4d functions.";
+  }
+  void runOnFunction() override;
+
+  Option<int> testWithInnerDimM0{
+      *this, "test-with-inner-dim-m0",
+      llvm::cl::desc("Inner tile size M0 of the M dimension"),
+      llvm::cl::init(4)};
+
+  Option<int> testWithInnerDimN0{
+      *this, "test-with-inner-dim-n0",
+      llvm::cl::desc("Inner tile size N0 of the N dimension"),
+      llvm::cl::init(4)};
+
+  Option<int> testWithInnerDimK0{
+      *this, "test-with-inner-dim-k0",
+      llvm::cl::desc("Inner tile size K0 of the K dimension"),
+      llvm::cl::init(4)};
+};
+
+void TestLinalgMatmulToMMT4D::runOnFunction() {
+  RewritePatternSet patterns(&getContext());
+  populateMatmulToMMT4DPatterns(patterns, testWithInnerDimM0,
+                                testWithInnerDimN0, testWithInnerDimK0);
+  (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+}
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgMatmulToMMT4D() {
+  PassRegistration<TestLinalgMatmulToMMT4D>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@
 void registerTestLinalgTiledLoopFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
+void registerTestLinalgMatmulToMMT4D();
 void registerTestLinalgTileAndFuseSequencePass();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
@@ -167,6 +168,7 @@
   test::registerTestLinalgTiledLoopFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
+  test::registerTestLinalgMatmulToMMT4D();
   test::registerTestLinalgTileAndFuseSequencePass();
   test::registerTestLinalgTransforms();
   test::registerTestLivenessPass();