diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -936,6 +936,21 @@
   }
 };
 
+/// Returns true if at least one reduction iterator of a vector.contract op is
+/// accessed by both `lhsMap` and `rhsMap`, i.e. a reduction pair is in use.
+static bool usesReductionPair(AffineMap lhsMap, AffineMap rhsMap,
+                              ArrayAttr iteratorTypes) {
+  for (const auto &it : llvm::enumerate(iteratorTypes)) {
+    if (!isReductionIterator(it.value()))
+      continue;
+    auto lhsDim = getResultIndex(lhsMap, it.index());
+    auto rhsDim = getResultIndex(rhsMap, it.index());
+    if (lhsDim && rhsDim)
+      return true;
+  }
+  return false;
+}
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1005,10 +1020,30 @@
     }
     if (!changed)
       return failure();
-    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
-    return success();
+
+    ArrayAttr iteratorTypes = contractOp.getIteratorTypes();
+    // Returns true if any result of `map` indexes a parallel iterator.
+    auto hasParallelDim = [&](AffineMap map) {
+      for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
+        if (isParallelIterator(iteratorTypes[map.getDimPosition(i)]))
+          return true;
+      return false;
+    };
+
+    // We need to make sure at least one reduction dimension pair is actually
+    // used to generate valid vector.contract ops. Also, if there are no
+    // parallel dimensions used in LHS and RHS, the accumulator/result should be
+    // a scalar.
+    bool isFullReduction = !hasParallelDim(maps[0]) && !hasParallelDim(maps[1]);
+    if (usesReductionPair(maps[0], maps[1], iteratorTypes) &&
+        !(contractOp.getAccType().isa<VectorType>() && isFullReduction)) {
+      rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+          contractOp, lhs, rhs, contractOp.getAcc(),
+          rewriter.getAffineMapArrayAttr(maps), iteratorTypes);
+      return success();
+    }
+
+    return failure();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,34 @@
   return %1 : vector<8x32xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @contract_broadcast_no_reduction_pair
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+func.func @contract_broadcast_no_reduction_pair(%a: vector<1xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast = vector.broadcast %b : vector<4xf32> to vector<1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>
+  } %a, %bcast, %c : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @contract_broadcast_no_parallel_pair
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+func.func @contract_broadcast_no_parallel_pair(%a: vector<4xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>
+  } %bcast_a, %bcast_a, %c : vector<1x4xf32>, vector<1x4xf32> into vector<1xf32>
+  return %contract: vector<1xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.