diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1161,13 +1161,24 @@ if (!unusedDimsBitVector.test(i)) iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); } - // Check that compressing unused dims isn't removing all reduction - // iterators. For example, if the vector.contract had only one reduction + // Check that compressing unused dims isn't removing all reduction dimension + // pairs. For example, if the vector.contract had only one reduction // iterator and that was a unit-dimension created by a broadcast, // then we should bail here, otherwise we would create a contract without - // a reduction iterator. - if (!llvm::any_of(iterators, isReductionIterator)) + // a reduction dimension pair. + bool hasReductionIteratorApplyingOnBothSides = false; + for (unsigned i = 0; i < iterators.size(); ++i) { + if (!isReductionIterator(iterators[i])) + continue; + // Only count reduction dim 'i' if both the LHS and RHS maps index it. + if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { + hasReductionIteratorApplyingOnBothSides = true; + break; + } + } + if (!hasReductionIteratorApplyingOnBothSides) return failure(); + // If the compressed maps have a dimension that is not used by either LHS or // RHS then the ContractionOp verifier would fail. 
if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -213,6 +213,38 @@ return %result : vector<1xi32> } +// ----- + +// Test that CombineContractBroadcast is not combining this case, as that would +// result in a vector.contract without a reduction dimension pair, as the only +// reduction dimension would be used by only one side among LHS, RHS. +// This is arguably a convoluted edge case (the affine_maps here look weird!) +// but it is something that we actually ran into from linalg.matmul tests that +// were exercising 1x1 shapes, and using various drop-unit-dims patterns. + +#map0 = affine_map<(d0, d1) -> (d1)> +#map1 = affine_map<(d0, d1) -> (d1, d0)> +#map2 = affine_map<(d0, d1) -> (d0)> + +// CHECK-LABEL: contract_broadcast_would_have_no_reduction_dim_pair +// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>) +// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32> +// CHECK: vector.contract +// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32> + +func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>, %arg2 : vector<1xf32>) -> vector<1xf32> { + %1 = vector.broadcast %arg1 : vector<1xf32> to vector<1x1xf32> + %result = vector.contract { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind<add> + } %arg0, %1, %arg2 : vector<1xf32>, vector<1x1xf32> into vector<1xf32> + return %result : vector<1xf32> +} + + 
//===----------------------------------------------------------------------===// // Reorder casting ops and vector ops. The casting ops have almost identical // pattern, so only arith.extsi op is tested.