diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -621,6 +621,36 @@
         return failure();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Given a dimension of the iteration space of a Linalg operation, finds
+        all the operands in the operation that are defined on such dimension.
+        Returns all the operand values found and their dimension positions in
+        `operandDimPairs`.
+
+        Only operands whose indexing map is a projected permutation are
+        considered. Results are appended to `operandDimPairs`; entries already
+        present in the vector are not cleared.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"mapIterationSpaceDimToAllOperandDims",
+      /*args=*/(ins "unsigned":$dimPos,
+                    "mlir::SmallVectorImpl<std::pair<Value, unsigned>>&":$operandDimPairs),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // The i-th indexing map is paired with operand i. NOTE(review): this
+        // relies on the indexing maps being ordered like the op's operands.
+        // The operand dim is the position of `d<dimPos>` among the map results.
+        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
+          if (idxMap.isProjectedPermutation()) {
+            if (auto mayOperandDim = idxMap.getResultPosition(
+                getAffineDimExpr(dimPos, idxMap.getContext()))) {
+              operandDimPairs.push_back({$_op->getOperand(i), *mayOperandDim});
+            }
+          }
+        }
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//