diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1004,9 +1004,48 @@
     }
     if (!changed)
       return failure();
-    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
+
+    ArrayAttr iteratorTypes = contractOp.getIteratorTypes();
+    // Returns true if one of `map`'s results is exactly the dimension `dim`.
+    auto usesDim = [](AffineMap map, unsigned dim) {
+      for (unsigned i = 0, e = map.getNumResults(); i < e; ++i)
+        if (map.getDimPosition(i) == dim)
+          return true;
+      return false;
+    };
+    // vector.contract requires at least one contracting dimension pair, i.e.
+    // a reduction dimension indexed by both the lhs and the rhs. Folding the
+    // broadcasts into the indexing maps may have removed every such pair (e.g.
+    // when the broadcast unit dimensions were the only reductions), in which
+    // case re-creating the contraction would produce an invalid op.
+    bool hasContractingDimPair = false;
+    for (unsigned dim = 0, e = iteratorTypes.size(); dim < e; ++dim) {
+      if (isReductionIterator(iteratorTypes[dim]) && usesDim(maps[0], dim) &&
+          usesDim(maps[1], dim)) {
+        hasContractingDimPair = true;
+        break;
+      }
+    }
+    if (hasContractingDimPair) {
+      rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+          contractOp, lhs, rhs, contractOp.getAcc(),
+          rewriter.getAffineMapArrayAttr(maps), iteratorTypes);
+      return success();
+    }
+
+    // Without a contracting dimension pair the op can still be rewritten when
+    // it is an elementwise multiply-accumulate: all indexing maps identical
+    // (so every access goes through the acc's parallel dimensions), `add`
+    // combining kind, and lhs, rhs and acc sharing one floating-point vector
+    // type, as vector.fma requires. Otherwise leave the op untouched.
+    Value acc = contractOp.getAcc();
+    auto accType = acc.getType().dyn_cast<VectorType>();
+    if (contractOp.getKind() != vector::CombiningKind::ADD ||
+        !llvm::is_splat(maps) || !accType ||
+        !accType.getElementType().isa<FloatType>() ||
+        lhs.getType() != accType || rhs.getType() != accType)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::FMAOp>(contractOp, lhs, rhs, acc);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,36 @@
   return %1 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: contract_broadcast_fma
+//  CHECK-SAME: (%[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>)
+//       CHECK:   %[[FMA:.+]] = vector.fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+//       CHECK:   return %[[FMA]] : vector<4xf32>
+func @contract_broadcast_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x1x4xf32>
+  %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
+    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>
+  } %bcast_a, %bcast_b, %c : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
+// The broadcasts must not be folded away when the combining kind is not `add`:
+// the result would have no contracting dimension pair and is not an fma.
+// CHECK-LABEL: contract_broadcast_mul_no_fma
+//   CHECK-NOT:   vector.fma
+//       CHECK:   vector.contract
+func @contract_broadcast_mul_no_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x1x4xf32>
+  %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
+    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<mul>
+  } %bcast_a, %bcast_b, %c : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.