diff --git a/mlir/include/mlir/IR/BlockUtilities.h b/mlir/include/mlir/IR/BlockUtilities.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/BlockUtilities.h
@@ -0,0 +1,45 @@
+//===- BlockUtilities.h - MLIR Block utility functions-----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines utility functions working on Blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_BLOCKUTILITIES_H
+#define MLIR_IR_BLOCKUTILITIES_H
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Visitors.h"
+
+namespace mlir {
+
+/// Return the unique instance of `OpType` nested under `block` if it is indeed
+/// unique. Return null if there are zero or more than one such instances.
+///
+/// The search is recursive: `Block::walk` also visits operations nested in the
+/// regions of the operations contained in `block`. Iterate over
+/// `block.getOps<OpType>()` instead to only consider the operations
+/// immediately contained in `block`.
+template <typename OpType>
+OpType getSingleOpOfType(Block &block) {
+  OpType res = nullptr;
+  block.walk([&](OpType op) {
+    // A second match means the instance is not unique: clear and stop early.
+    if (res) {
+      res = nullptr;
+      return WalkResult::interrupt();
+    }
+    res = op;
+    return WalkResult::advance();
+  });
+  return res;
+}
+
+} // end namespace mlir
+
+#endif // MLIR_IR_BLOCKUTILITIES_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockUtilities.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/SmallSet.h"
 
@@ -40,22 +41,6 @@
 };
 }
 
-/// Return the unique instance of OpType in `block` if it is indeed unique.
-/// Return null if none or more than 1 instances exist.
-template <typename OpType>
-static OpType getSingleOpOfType(Block &block) {
-  OpType res = nullptr;
-  block.walk([&](OpType op) {
-    if (res) {
-      res = nullptr;
-      return WalkResult::interrupt();
-    }
-    res = op;
-    return WalkResult::advance();
-  });
-  return res;
-}
-
 /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
 /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
 /// unary operations that may change the type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BlockUtilities.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -39,22 +40,6 @@
 
 #define DEBUG_TYPE "linalg-vectorization"
 
-/// Return the unique instance of OpType in `block` if it is indeed unique.
-/// Return null if none or more than 1 instances exist.
-template <typename OpType>
-static OpType getSingleOpOfType(Block &block) {
-  OpType res;
-  block.walk([&](OpType op) {
-    if (res) {
-      res = nullptr;
-      return WalkResult::interrupt();
-    }
-    res = op;
-    return WalkResult::advance();
-  });
-  return res;
-}
-
 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
 /// projectedPermutation, compress the unused dimensions to serve as a
 /// permutation_map for a vector transfer operation.