diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1161,13 +1161,24 @@
       if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
     }
-    // Check that compressing unused dims isn't removing all reduction
-    // iterators. For example, if the vector.contract had only one reduction
+    // Check that compressing unused dims isn't removing all reduction dimension
+    // pairs. For example, if the vector.contract had only one reduction
     // iterator and that was a unit-dimension created by a broadcast,
     // then we should bail here, otherwise we would create a contract without
-    // a reduction iterator.
-    if (!llvm::any_of(iterators, isReductionIterator))
+    // a reduction dimension pair.
+    bool hasReductionIteratorApplyingOnBothSides = false;
+    for (unsigned i = 0; i < iterators.size(); ++i) {
+      if (!isReductionIterator(iterators[i]))
+        continue;
+      // A reduction dimension pair requires compressed dim `i` to be indexed
+      // by both the LHS (maps[0]) and RHS (maps[1]) indexing maps.
+      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+        hasReductionIteratorApplyingOnBothSides = true;
+        break;
+      }
+    }
+    if (!hasReductionIteratorApplyingOnBothSides)
       return failure();
     // If the compressed maps have a dimension that is not used by either LHS or
     // RHS then the ContractionOp verifier would fail.
     if (getUnusedDimsBitVector({maps[0], maps[1]}).any())