diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1863,9 +1863,16 @@
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
-  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
-  assert(lookup.hasValue() && "parallel index not listed in reduction");
-  int64_t resIndex = lookup.getValue();
+  // getValueOr(-1) means that we tolerate a dimension not appearing
+  // in the result map. That can't happen for actual parallel iterators, but
+  // the caller ContractionOpLowering::matchAndRewrite is currently calling
+  // lowerParallel also for the case of unit-size reduction dims appearing only
+  // on one of LHS or RHS, not both. At the moment, such cases are created by
+  // CastAwayContractionLeadingOneDim, so we need to either support that or
+  // modify that pattern.
+  int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
+  assert((resIndex != -1 || dimSize == 1) &&
+         "index absent from the result map must be a unit dimension");
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -858,6 +858,22 @@
   return %0 : vector<2x1x7xi1>
 }
 
+// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
+// Checks only that a one-sided unit reduction dim (d0) does not crash/assert.
+
+func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<8x2xi32>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+  %res = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d1)>
+    ],
+    iterator_types = ["reduction", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<8x2xi32>, vector<8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,