diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1841,6 +1841,7 @@
 }
 
 // Lower one parallel dimension.
+// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
                                            int64_t lhsIndex, int64_t rhsIndex,
@@ -1863,9 +1864,16 @@
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
-  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
-  assert(lookup.hasValue() && "parallel index not listed in reduction");
-  int64_t resIndex = lookup.getValue();
+  // getValueOr(-1) means that we tolerate a dimension not appearing
+  // in the result map. That can't happen for actual parallel iterators, but
+  // the caller ContractionOpLowering::matchAndRewrite is currently calling
+  // lowerParallel also for the case of unit-size reduction dims appearing only
+  // on one of LHS or RHS, not both. At the moment, such cases are created by
+  // CastAwayContractionLeadingOneDim, so we need to either support that or
+  // modify that pattern.
+  int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
+  assert((resIndex != -1 || dimSize == 1) &&
+         "expected either a parallel or a unit dim");
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -858,6 +858,34 @@
   return %0 : vector<2x1x7xi1>
 }
 
+// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
+// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
+// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
+// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
+// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
+// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
+// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
+// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
+// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
+// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
+// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
+// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
+// CHECK: return %[[S]] : vector<2xi32>
+
+func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
+  %res = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d1)>
+    ],
+    iterator_types = ["reduction", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
+  return %res : vector<2xi32>
+}
+
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,