diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -935,6 +935,21 @@
   }
 };
 
+/// Returns true if the given `lhsMap` and `rhsMap` from a vector.contract op
+/// both have a result accessing the same reduction dimension.
+static bool hasReductionPair(AffineMap lhsMap, AffineMap rhsMap,
+                             ArrayAttr iteratorTypes) {
+  for (const auto &it : llvm::enumerate(iteratorTypes)) {
+    if (!isReductionIterator(it.value()))
+      continue;
+    auto lhsDim = getResultIndex(lhsMap, it.index());
+    auto rhsDim = getResultIndex(rhsMap, it.index());
+    if (lhsDim && rhsDim)
+      return true;
+  }
+  return false;
+}
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1004,9 +1019,30 @@
     }
     if (!changed)
       return failure();
-    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
+
+    ArrayAttr iteratorTypes = contractOp.getIteratorTypes();
+    // We need to make sure at least one reduction dimension is actually used to
+    // generate valid vector.contract ops.
+    if (hasReductionPair(maps[0], maps[1], iteratorTypes)) {
+      rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+          contractOp, lhs, rhs, contractOp.getAcc(),
+          rewriter.getAffineMapArrayAttr(maps), iteratorTypes);
+    } else {
+      // After combining, all accesses are through parallel dimensions, so the
+      // contraction computes the elementwise `acc + lhs * rhs`. That can be
+      // expressed as vector.fma only if all maps are the same (so lhs, rhs and
+      // acc share one shape), the combining kind is `add`, and lhs, rhs and acc
+      // have the same floating-point vector type, as vector.fma requires.
+      if (!llvm::is_splat(maps) ||
+          contractOp.getKind() != vector::CombiningKind::ADD)
+        return failure();
+      auto accType = contractOp.getAcc().getType().dyn_cast<VectorType>();
+      if (!accType || !accType.getElementType().isa<FloatType>() ||
+          lhs.getType() != accType || rhs.getType() != accType)
+        return failure();
+      rewriter.replaceOpWithNewOp<vector::FMAOp>(contractOp, lhs, rhs,
+                                                 contractOp.getAcc());
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,53 @@
   return %1 : vector<8x32xf32>
 }
 
+// -----
+
+// CHECK-LABEL: contract_broadcast_fma
+// CHECK-SAME: (%[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>)
+// CHECK:      %[[FMA:.+]] = vector.fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+// CHECK:      return %[[FMA]] : vector<4xf32>
+func @contract_broadcast_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x1x4xf32>
+  %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
+    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>
+  } %bcast_a, %bcast_b, %c : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
+// -----
+
+// Non-`add` combining kinds are not a multiply-add and must not become vector.fma.
+// CHECK-LABEL: func @contract_broadcast_no_fma_for_non_add_kind
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+// CHECK-NOT: vector.fma
+func @contract_broadcast_no_fma_for_non_add_kind(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x1x4xf32>
+  %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
+    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<maxf>
+  } %bcast_a, %bcast_b, %c : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @contract_broadcast_no_reduction_pair
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+func @contract_broadcast_no_reduction_pair(%a: vector<1xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast = vector.broadcast %b : vector<4xf32> to vector<1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>
+  } %a, %bcast, %c : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.