diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -69,11 +69,6 @@
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
-/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
-/// Assumes `op` is a LinalgOp.
-void getDimsOfType(Operation *op, StringRef iteratorTypeName,
-                   SmallVectorImpl<unsigned> &res);
-
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -206,7 +206,8 @@
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getDimsOfType($_op, getParallelIteratorTypeName(), res);
+        return findPositionsOfType($_op.iterator_types(),
+                                   getParallelIteratorTypeName(), res);
       }]
     >,
     InterfaceMethod<
@@ -231,7 +232,8 @@
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getDimsOfType($_op, getReductionIteratorTypeName(), res);
+        return findPositionsOfType($_op.iterator_types(),
+                                   getReductionIteratorTypeName(), res);
       }]
     >,
     InterfaceMethod<
@@ -256,7 +258,8 @@
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res);
+        return findPositionsOfType($_op.iterator_types(),
+                                   getWindowIteratorTypeName(), res);
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -110,6 +110,21 @@
   return res;
 }
 
+/// Append to `res` the positions in `iteratorTypes` whose value equals
+/// `iteratorTypeName`. Existing contents of `res` are kept; a null
+/// `iteratorTypes` attribute appends nothing.
+inline void findPositionsOfType(ArrayAttr iteratorTypes,
+                                StringRef iteratorTypeName,
+                                SmallVectorImpl<unsigned> &res) {
+  if (!iteratorTypes)
+    return;
+  for (const auto &en :
+       llvm::enumerate(iteratorTypes.getAsValueRange<StringAttr>())) {
+    if (en.value() == iteratorTypeName)
+      res.push_back(en.index());
+  }
+}
+
 /// Helper StructuredGenerator class to manipulate and rewrite ops with
 /// `StructuredOpInterface`. This is templated for now because VectorOps do not
 /// yet implement the StructuredOpInterface itself.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1540,22 +1540,6 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
 
-/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
-/// Assumes `op` is a LinalgOp.
-void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
-                                 SmallVectorImpl<unsigned> &res) {
-  if (!cast<LinalgOp>(op).iterator_types())
-    return;
-
-  unsigned dim = 0;
-  for (auto tn :
-       cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
-    if (tn == iteratorTypeName)
-      res.push_back(dim);
-    ++dim;
-  }
-}
-
 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
                                              unsigned rank,
                                              MLIRContext *context) {