[mlir][Linalg] Expose isElementwise and its helpers as Linalg utilities.

Moves allIndexingsAreProjectedPermutation, hasOnlyScalarElementwiseOp and
isElementwise out of Vectorization.cpp, where they were file-local `static`
functions, into mlir/lib/Dialect/Linalg/Utils/Utils.cpp (inside
`namespace mlir { namespace linalg {`), and declares them in
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h. It also adds a dependency to
mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt (MLIRFuncDialect) and to the
mlir BUILD.bazel overlay (":FuncDialect"). Utils.cpp now includes
"mlir/Dialect/Func/IR/FuncOps.h".

Behavior change: isElementwise now takes a LinalgOp instead of an
Operation *. The old version began with a dyn_cast and returned false for
ops that are not LinalgOps; the new one has no such check.
NOTE(review): the call sites are not part of this patch. Confirm that every
caller of isElementwise already holds a LinalgOp (or checks for one first).

NOTE(review): the Func dependency presumably exists because the isa<> list
in hasOnlyScalarElementwiseOp names a func dialect op. That list is not
visible in this copy (see the last note), so confirm.

NOTE(review): the Utils.h comment on hasOnlyScalarElementwiseOp omits two
conditions that the definition enforces: the region must have exactly one
block, and every result of every op must have an integer, index or float
type. The isElementwise comment does not spell out its criteria: all loops
parallel, all indexing maps projected permutations, output maps full
permutations, and a body accepted by hasOnlyScalarElementwiseOp.
allIndexingsAreProjectedPermutation passes allowZeroInResults=true, so maps
with constant-zero results also qualify. Expanding these comments would
change the Utils.h hunk's line counts.

NOTE(review): this copy of the patch is damaged. Angle-bracket template
argument lists have been stripped: see `isa(op)`, `dyn_cast(op)`,
`ArrayRef permutation` and `SmallVector indexCounts` below (`to_vector<4>`
survived). Each of the two `isa` lines also lost a line break. The
Vectorization.cpp hunk header declares 42 removed lines and the second
Utils.cpp hunk declares 35 added lines, but one line fewer is present in
each. Restore the original text first; as-is, the patch will not apply.

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -32,6 +32,15 @@
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// Return true if every indexing map of `op` is a projected permutation.
+bool allIndexingsAreProjectedPermutation(LinalgOp op);
+
+/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
+bool hasOnlyScalarElementwiseOp(Region &r);
+
+/// Check if a LinalgOp is an element-wise operation.
+bool isElementwise(LinalgOp op);
+
 /// Check if `permutation` is a permutation of the range
 /// `[0, permutation.size())`.
 bool isPermutation(ArrayRef permutation);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -417,48 +417,6 @@
                      llvm::to_vector<4>(returnTypes), op->getAttrs())};
 }
 
-/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
-static bool hasOnlyScalarElementwiseOp(Region &r) {
-  if (!llvm::hasSingleElement(r))
-    return false;
-  for (Operation &op : r.front()) {
-    if (!(isa(op) ||
-          OpTrait::hasElementwiseMappableTraits(&op)) ||
-        llvm::any_of(op.getResultTypes(),
-                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
-      return false;
-  }
-  return true;
-}
-
-/// Returns `true` if all indexing maps of the linalg op are projected
-/// permutations.
-static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
-  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
-    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
-  });
-}
-
-// Return true if the op is an element-wise linalg op.
-static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast(op);
-  if (!linalgOp)
-    return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
-    return false;
-
-  if (!allIndexingsAreProjectedPermutation(linalgOp))
-    return false;
-
-  // TODO: relax the restrictions on indexing map.
-  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
-      return false;
-  }
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
-}
-
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
 /// vector form. Generic vectorization proceeds as follows:
 ///   1. Verify the `linalgOp` has one non-empty region.
diff --git a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
@@ -9,6 +9,7 @@
   MLIRAffineAnalysis
   MLIRAffineUtils
   MLIRArithmeticDialect
+  MLIRFuncDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -141,6 +142,41 @@
 namespace mlir {
 namespace linalg {
 
+bool allIndexingsAreProjectedPermutation(LinalgOp op) {
+  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
+    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
+  });
+}
+
+bool hasOnlyScalarElementwiseOp(Region &r) {
+  if (!llvm::hasSingleElement(r))
+    return false;
+  for (Operation &op : r.front()) {
+    if (!(isa(op) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
+        llvm::any_of(op.getResultTypes(),
+                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+      return false;
+  }
+  return true;
+}
+
+bool isElementwise(LinalgOp op) {
+  if (op.getNumLoops() != op.getNumParallelLoops())
+    return false;
+
+  if (!allIndexingsAreProjectedPermutation(op))
+    return false;
+
+  // TODO: relax the requirement that output indexing maps be permutations.
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    if (!op.getTiedIndexingMap(opOperand).isPermutation())
+      return false;
+  }
+  return hasOnlyScalarElementwiseOp(op->getRegion(0));
+}
+
 bool isPermutation(ArrayRef permutation) {
   // Count the number of appearances for all indices.
   SmallVector indexCounts(permutation.size(), 0);
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7472,6 +7472,7 @@
         ":ArithmeticDialect",
         ":ArithmeticUtils",
         ":DialectUtils",
+        ":FuncDialect",
         ":IR",
         ":LinalgAnalysis",
         ":LinalgDialect",