Guard the folding of vector.broadcast into a vector.contract user.

The rewrite is skipped when the indexing maps it computed would not describe
a valid vector.contract:
  * no reduction dimension is accessed by both the LHS and the RHS map, or
  * neither the LHS nor the RHS map accesses a parallel dimension (a full
    reduction) while the accumulator is still a vector.
Two lit tests check that the broadcast and the contract are both kept in
these situations.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -1066,6 +1067,21 @@
   }
 };
 
+/// Returns true if the given `lhsMap` and `rhsMap` from a vector.contract op
+/// use a reduction dimension access pair.
+static bool usesReductionPair(AffineMap lhsMap, AffineMap rhsMap,
+                              ArrayAttr iteratorTypes) {
+  for (const auto &it : llvm::enumerate(iteratorTypes)) {
+    if (!isReductionIterator(it.value()))
+      continue;
+    auto lhsDim = getResultIndex(lhsMap, it.index());
+    auto rhsDim = getResultIndex(rhsMap, it.index());
+    if (lhsDim && rhsDim)
+      return true;
+  }
+  return false;
+}
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1135,6 +1151,26 @@
     }
     if (!changed)
       return failure();
+
+    // We need to make sure at least one reduction dimension pair is actually
+    // used to generate valid vector.contract ops.
+    ArrayAttr iteratorTypes = contractOp.getIteratorTypes();
+    if (!usesReductionPair(maps[0], maps[1], iteratorTypes))
+      return failure();
+
+    // Also, if there are no parallel dimensions used in LHS and RHS, it's a
+    // full reduction. For such cases, contraction ops expect to have a scalar
+    // accumulator/result to be valid.
+    auto hasParallelDim = [&](AffineMap map) {
+      for (int i = 0, e = map.getNumResults(); i < e; ++i)
+        if (isParallelIterator(iteratorTypes[map.getDimPosition(i)]))
+          return true;
+      return false;
+    };
+    if (contractOp.getAccType().isa<VectorType>() &&
+        (!hasParallelDim(maps[0]) && !hasParallelDim(maps[1])))
+      return failure();
+
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         contractOp, lhs, rhs, contractOp.getAcc(),
         rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -86,6 +86,34 @@
   return %1 : vector<8x32xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @contract_broadcast_no_reduction_pair
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+func.func @contract_broadcast_no_reduction_pair(%a: vector<1xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %bcast = vector.broadcast %b : vector<4xf32> to vector<1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>
+  } %a, %bcast, %c : vector<1xf32>, vector<1x4xf32> into vector<4xf32>
+  return %contract: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @contract_broadcast_no_parallel_pair
+// CHECK: vector.broadcast
+// CHECK: vector.contract
+func.func @contract_broadcast_no_parallel_pair(%a: vector<4xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x4xf32>
+  %contract = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>
+  } %bcast_a, %bcast_a, %c : vector<1x4xf32>, vector<1x4xf32> into vector<1xf32>
+  return %contract: vector<1xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Reorder casting ops and vector ops. The casting ops have almost identical
 // pattern, so only arith.extsi op is tested.