diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1126,6 +1127,23 @@
       // relaxed we can remove this case.
       if (innerDimBroadcast)
         continue;
+
+      // It would be incorrect to fold a broadcast onto a reduction dimension
+      // of non-unit size. Broadcast dim `i` is a result of this operand's
+      // indexing map, so translate it to the contraction's iteration dim
+      // before looking up its iterator type.
+      bool nonUnitDimReductionBroadcast = false;
+      for (int64_t i = 0; i < rankDiff; ++i) {
+        if (broadcast.getVectorType().getDimSize(i) != 1 &&
+            isReductionIterator(contractOp.getIteratorTypes()
+                                    .getValue()[map.getDimPosition(i)])) {
+          nonUnitDimReductionBroadcast = true;
+          break;
+        }
+      }
+      if (nonUnitDimReductionBroadcast)
+        continue;
+
       AffineMap broadcastMap =
           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                          contractOp.getContext());
@@ -1133,11 +1151,37 @@
       *operand = broadcast.getSource();
       changed = true;
     }
+
     if (!changed)
       return failure();
+
+    // Determine which dims are unused, now that the maps have been composed
+    // with the broadcast maps.
+    unsigned numDims = maps[0].getNumDims();
+    llvm::SmallBitVector unusedDims(numDims, true);
+    for (const auto &m : maps) {
+      for (unsigned i = 0; i < numDims; ++i) {
+        if (m.isFunctionOfDim(i))
+          unusedDims.reset(i);
+      }
+    }
+    // Compress unused dims.
+    for (auto &m : maps)
+      m = compressDims(m, unusedDims);
+    // Compute the iterator types of the dims that remain after compression.
+    SmallVector<Attribute> iterators;
+    for (unsigned i = 0; i < numDims; ++i) {
+      if (!unusedDims.test(i))
+        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
+    }
+    // Check that compressing unused dims isn't removing all reduction
+    // iterators.
+    if (!llvm::any_of(iterators, isReductionIterator))
+      return failure();
+
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
+        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,98 @@
   return %1 : vector<8x32xf32>
 }
 
+// -----
+// Test that CombineContractBroadcast is able to combine a broadcast that
+// creates a unit dim that is consumed by a reduction iterator, dropping that
+// reduction iterator, as long as there is another reduction iterator left.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
+  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
+  %result = vector.contract {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>
+  } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
+  return %result : vector<8x8xi32>
+}
+
+// -----
+
+// Test that CombineContractBroadcast is not combining this case, as that would
+// result in dropping this contract's only reduction iterator.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
+// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"]
+// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
+func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
+  %0 = vector.broadcast %arg0 : vector<8xi32> to vector<1x8xi32>
+  %1 = vector.broadcast %arg1 : vector<8xi32> to vector<1x8xi32>
+  %result = vector.contract {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["reduction", "parallel", "parallel"],
+      kind = #vector.kind<add>
+  } %0, %1, %arg2 : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
+  return %result : vector<8x8xi32>
+}
+
+// -----
+
+// Test that CombineContractBroadcast does not fold a broadcast of a non-unit
+// dim onto a reduction iterator, even when the broadcast dim position differs
+// from the iterator position: broadcast dim 0 of both operands maps to the
+// reduction iterator d3 (size 2) while iterator d0 is parallel. Folding here
+// would drop d3 and lose a factor of 2 in the result.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_permuted
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
+// CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32>
+// CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32>
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_non_unit_dim_reduction_permuted(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
+  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
+  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<2x8x4xi32>
+  %result = vector.contract {
+      indexing_maps = [#map0, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "reduction", "reduction"],
+      kind = #vector.kind<add>
+  } %0, %1, %arg2 : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
+  return %result : vector<8x8xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.