diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1126,6 +1127,21 @@
       // relaxed we can remove this case.
       if (innerDimBroadcast)
         continue;
+
+      // It would be incorrect to fold a broadcast onto a reduction dimension
+      // of non-unit size.
+      bool nonUnitDimReductionBroadcast = false;
+      for (int64_t i = 0; i < rankDiff; ++i) {
+        if (operand->getType().cast<VectorType>().getShape()[i] != 1 &&
+            isReductionIterator(contractOp.getIteratorTypes().getValue()[i])) {
+          nonUnitDimReductionBroadcast = true;
+          break;
+        }
+      }
+      if (nonUnitDimReductionBroadcast) {
+        continue;
+      }
+
       AffineMap broadcastMap =
           AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                          contractOp.getContext());
@@ -1133,11 +1149,44 @@
       *operand = broadcast.getSource();
       changed = true;
     }
+
     if (!changed)
       return failure();
+
+    // Determine which dims are unused; note that the maps have been composed
+    // with the broadcast maps.
+    unsigned numDims = maps[0].getNumDims();
+    llvm::SmallBitVector unusedDims(maps[0].getNumDims(), true);
+    for (const auto &m : maps) {
+      m.walkExprs([&](AffineExpr expr) {
+        if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+          unusedDims.reset(dimExpr.getPosition());
+      });
+    }
+    // Compress unused dims.
+    for (auto &m : maps) {
+      m = compressDims(m, unusedDims);
+    }
+    // Check that compressing unused dims isn't removing all reduction
+    // iterators.
+    SmallVector<Attribute> iterators;
+    bool stillHasReduction = false;
+    for (unsigned i = 0; i < numDims; ++i) {
+      if (!unusedDims.test(i)) {
+        auto iteratorType = contractOp.getIteratorTypes().getValue()[i];
+        iterators.push_back(iteratorType);
+        if (isReductionIterator(iteratorType)) {
+          stillHasReduction = true;
+        }
+      }
+    }
+    if (!stillHasReduction) {
+      return failure();
+    }
+
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
+        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,34 @@
   return %1 : vector<8x32xf32>
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_input_side_unit_dim
+// CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
+// CHECK: vector.contract
+// CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_input_side_unit_dim(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
+  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
+  %result = vector.contract {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
+  return %result : vector<8x8xi32>
+}
+
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.