diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1067,6 +1067,22 @@
   }
 };
 
+// Helper to return a bitvector where each bit set indicates a dimension that
+// is not used by any of the maps in the input array `maps`.
+// `maps` must be non-empty: the dimension count is read from `maps[0]` and
+// every map is queried against that same count.
+static llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
+  unsigned numDims = maps[0].getNumDims();
+  llvm::SmallBitVector unusedDims(numDims, true);
+  for (const auto &m : maps) {
+    for (unsigned i = 0; i < numDims; ++i) {
+      if (m.isFunctionOfDim(i))
+        unusedDims.reset(i);
+    }
+  }
+  return unusedDims;
+}
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1155,21 +1171,14 @@
 
     // Determine which dims are usused, now that the maps have been composed
     // with the broadcast maps.
-    unsigned numDims = maps[0].getNumDims();
-    llvm::SmallBitVector unusedDims(numDims, true);
-    for (const auto &m : maps) {
-      for (unsigned i = 0; i < numDims; ++i) {
-        if (m.isFunctionOfDim(i))
-          unusedDims.reset(i);
-      }
-    }
+    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
     // Compress unused dims.
     for (auto &m : maps)
-      m = compressDims(m, unusedDims);
+      m = compressDims(m, unusedDimsBitVector);
     // Compute the combined iterators.
     SmallVector<Attribute> iterators;
-    for (unsigned i = 0; i < numDims; ++i) {
-      if (!unusedDims.test(i))
+    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
+      if (!unusedDimsBitVector.test(i))
         iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
     }
     // Check that compressing unused dims isn't removing all reduction
@@ -1179,7 +1188,10 @@
     // a reduction iterator.
     if (!llvm::any_of(iterators, isReductionIterator))
       return failure();
-
+    // If the compressed maps have a dimension that is not used by either LHS or
+    // RHS then the combine is illegal.
+    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+      return failure();
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
         rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -159,6 +159,10 @@
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 // CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
@@ -178,6 +182,37 @@
   return %result : vector<8x8xi32>
 }
 
+// -----
+
+// Test that CombineContractBroadcast is not combining this case, as that would
+// result in a dimension being unused in the LHS and RHS maps, which is illegal.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+
+func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> {
+  %1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32>
+  %result = vector.contract {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
+  return %result : vector<1xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.