diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -24,27 +24,23 @@
 
 namespace mlir {
 
-inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
-  auto context = indexingMaps.getContext();
-  AffineExpr m, n, k;
-  bindDims(context, m, n, k);
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
-  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
-  return indexingMaps == maps;
-}
-
-inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
-  auto context = indexingMaps.getContext();
-  AffineExpr m, n, k;
-  bindDims(context, m, n, k);
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
-  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
-  return indexingMaps == maps;
-}
+/// Tests whether the given maps describe a row major matmul, i.e. indexing maps
+/// (m, k), (k, n) -> (m, n). The test is permutation-invariant: m, n and k may
+/// be any of the three loop dimensions. Only the affine maps are checked, not
+/// the math performed within the reduction.
+bool isRowMajorMatmul(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a column major matmul. The test is
+/// permutation-invariant: the three loop dimensions may appear in any order.
+/// Only the affine maps of an operation are checked, not the math performed
+/// within the reduction.
+bool isColumnMajorMatmul(ArrayAttr indexingMaps);
+
+/// Tests whether the given maps describe a row major batch matmul, i.e. maps
+/// (b, m, k), (b, k, n) -> (b, m, n). The test is permutation-invariant: b, m,
+/// n and k may be any of the four loop dimensions. Only the affine maps are
+/// checked, not the math performed within the reduction.
+bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);
 
 /// Attribute name for the AffineArrayAttr which encodes the relationship
 /// between a structured op iterators' and its operands.
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -20,6 +20,7 @@
 add_subdirectory(StandardOps)
 add_subdirectory(Tensor)
 add_subdirectory(Tosa)
+add_subdirectory(Utils)
 add_subdirectory(Vector)
 
 set(LLVM_OPTIONAL_SOURCES
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_library(MLIRDialectUtils
+  StructuredOpsUtils.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -0,0 +1,92 @@
+//===- StructuredOpsUtils.cpp - Utilities used by structured ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+
+using namespace mlir;
+
+bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+
+  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
+  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
+  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+
+  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
+      map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+      map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+    return false;
+  }
+
+  // Extract dimensions for MxK * KxN -> MxN
+  AffineExpr m = map2.getResult(0);
+  AffineExpr n = map2.getResult(1);
+  AffineExpr k = map0.getResult(1);
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
+bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+
+  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
+  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
+  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+
+  if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
+      map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
+      map1.getNumInputs() != 3 || map2.getNumInputs() != 3) {
+    return false;
+  }
+
+  // Extract dimensions for KxN * MxK -> NxM
+  AffineExpr n = map2.getResult(0);
+  AffineExpr m = map2.getResult(1);
+  AffineExpr k = map0.getResult(0);
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
+
+bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return false;
+
+  auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
+  auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
+  auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+
+  if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
+      map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
+      map1.getNumInputs() != 4 || map2.getNumInputs() != 4) {
+    return false;
+  }
+
+  // Extract dimensions for BxMxK * BxKxN -> BxMxN
+  AffineExpr b = map2.getResult(0);
+  AffineExpr m = map2.getResult(1);
+  AffineExpr n = map2.getResult(2);
+  AffineExpr k = map0.getResult(2);
+  auto *context = indexingMaps.getContext();
+  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {b, k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {b, m, n}, context));
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
+  return indexingMaps == maps;
+}
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -15,6 +15,7 @@
   LINK_LIBS PUBLIC
   MLIRAffineEDSC
   MLIREDSC
+  MLIRDialectUtils
   MLIRIR
   MLIRStandard
   MLIRAffine
diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRDialectUtilsTests
+  StructuredOpsUtilsTest.cpp
+)
+target_link_libraries(MLIRDialectUtilsTests
+  PRIVATE
+  MLIRDialectUtils)
diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
@@ -0,0 +1,256 @@
+//===- StructuredOpsUtilsTest.cpp - StructuredOpsUtils unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using testing::Not;
+using testing::Truly;
+
+namespace {
+
+TEST(isRowMajorMatmul, Simple) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorMatmul));
+}
+
+TEST(isRowMajorMatmul, BindingShifted) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, k, m, n); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorMatmul));
+}
+
+TEST(isRowMajorMatmul, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, k, n, m); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorMatmul));
+}
+
+TEST(isRowMajorMatmul, ColumnMajor) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isRowMajorMatmul, FirstInputSwapped) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isRowMajorMatmul, TooFewMaps) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isRowMajorMatmul, TooManyMaps) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto mapD = AffineMapAttr::get(AffineMap::get(3, 0, {k, m}, &context));
+
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC, mapD});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isRowMajorMatmul, TooFewDims) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isRowMajorMatmul, TooFewOutputs) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorMatmul)));
+}
+
+TEST(isColumnMajorMatmul, Simple) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
+}
+
+TEST(isColumnMajorMatmul, BindingShifted) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, k, m, n); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
+}
+
+TEST(isColumnMajorMatmul, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, k, n, m); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isColumnMajorMatmul));
+}
+
+TEST(isColumnMajorMatmul, RowMajor) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
+}
+
+TEST(isColumnMajorMatmul, FirstInputSwapped) {
+  MLIRContext context;
+
+  AffineExpr m, n, k;
+  bindDims(&context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {n, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isColumnMajorMatmul)));
+}
+
+TEST(isRowMajorBatchMatmul, Simple) {
+  MLIRContext context;
+
+  AffineExpr batch, m, n, k;
+  bindDims(&context, batch, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
+}
+
+TEST(isRowMajorBatchMatmul, BindingShifted) {
+  MLIRContext context;
+
+  AffineExpr batch, m, n, k;
+  bindDims(&context, k, batch, m, n); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
+}
+
+TEST(isRowMajorBatchMatmul, BindingSwapped) {
+  MLIRContext context;
+
+  AffineExpr batch, m, n, k;
+  bindDims(&context, batch, k, n, m); // bind in different order
+  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, k}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Truly(isRowMajorBatchMatmul));
+}
+
+TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
+  MLIRContext context;
+
+  AffineExpr batch, m, n, k;
+  bindDims(&context, batch, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, m}, &context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(4, 0, {batch, k, n}, &context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(4, 0, {batch, m, n}, &context));
+  auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC});
+
+  EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul)));
+}
+
+} // namespace