diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -527,11 +527,12 @@ vector::VectorTransformsOptions vectorTransformOptions; FilterConstraintType filter; // Lower one parallel dimension. - Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, - int64_t rhsIndex, PatternRewriter &rewriter) const; + FailureOr lowerParallel(vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, + PatternRewriter &rewriter) const; // Lower one reduction dimension. - Value lowerReduction(vector::ContractionOp op, - PatternRewriter &rewriter) const; + FailureOr lowerReduction(vector::ContractionOp op, + PatternRewriter &rewriter) const; }; } // namespace vector diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1794,7 +1794,10 @@ if (!batchDimMap.empty()) { int64_t lhsIndex = batchDimMap[0].first; int64_t rhsIndex = batchDimMap[0].second; - rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); + auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(op, newOp.value()); return success(); } @@ -1812,8 +1815,10 @@ VectorType lhsType = op.getLhsType(); for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { if (lhsContractingDimSet.count(lhsIndex) == 0) { - rewriter.replaceOp( - op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); + auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(op, newOp.value()); return success(); } } @@ 
-1822,15 +1827,20 @@ VectorType rhsType = op.getRhsType(); for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { if (rhsContractingDimSet.count(rhsIndex) == 0) { - rewriter.replaceOp( - op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); + auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(op, newOp.value()); return success(); } } // Lower the first remaining reduction dimension. if (!contractingDimMap.empty()) { - rewriter.replaceOp(op, lowerReduction(op, rewriter)); + auto newOp = lowerReduction(op, rewriter); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(op, newOp.value()); return success(); } @@ -1838,10 +1848,12 @@ } // Lower one parallel dimension. +// Incidentally also tolerates unit-size (hence trivial) reduction dimensions. // TODO: consider reusing existing contract unrolling -Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, - int64_t lhsIndex, int64_t rhsIndex, - PatternRewriter &rewriter) const { +FailureOr +ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, + PatternRewriter &rewriter) const { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); VectorType resType = op.getResultType().cast(); @@ -1851,18 +1863,34 @@ int64_t dimSize = -1; if (lhsIndex >= 0) { iterIndex = iMap[0].getDimPosition(lhsIndex); - assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && - "parallel index should be free in LHS or batch in LHS/RHS"); + if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex + << " to map to the same dimension"; + }); dimSize = lhsType.getDimSize(lhsIndex); - } else { - assert(rhsIndex >= 0 && "missing parallel index"); + } else if (rhsIndex >= 0) { iterIndex = 
iMap[1].getDimPosition(rhsIndex); dimSize = rhsType.getDimSize(rhsIndex); } - assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); - Optional lookup = getResultIndex(iMap[2], iterIndex); - assert(lookup.has_value() && "parallel index not listed in reduction"); - int64_t resIndex = lookup.getValue(); + if (iterIndex < 0) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected either lhsIndex=" << lhsIndex + << " or rhsIndex=" << rhsIndex << " to be nonnegative"; + }); + // getValueOr(-1) means that we tolerate a dimension not appearing + // in the result map. That can't happen for actual parallel iterators, but + // the caller ContractionOpLowering::matchAndRewrite is currently calling + // lowerParallel also for the case of unit-size reduction dims appearing only + // on one of LHS or RHS, not both. At the moment, such cases are created by + // CastAwayContractionLeadingOneDim, so we need to either support that or + // modify that pattern. + int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1); + if (resIndex == -1 && dimSize != 1) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected the dimension for iterIndex=" << iterIndex + << " to either appear in the result map, or to be a unit dimension"; + }); // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { adjustMap(iMap[0], iterIndex, rewriter), @@ -1888,33 +1916,49 @@ } // Lower one reduction dimension. 
-Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, - PatternRewriter &rewriter) const { +FailureOr +ContractionOpLowering::lowerReduction(vector::ContractionOp op, + PatternRewriter &rewriter) const { auto loc = op.getLoc(); VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); - assert(!resType.isa()); + if (resType.isa()) + return rewriter.notifyMatchFailure(op, + "did not expect a VectorType result"); bool isInt = resType.isa(); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMaps(); Optional lookupLhs = getResultIndex(iMap[0], iterIndex); Optional lookupRhs = getResultIndex(iMap[1], iterIndex); - assert(lookupLhs.has_value() && "missing LHS parallel index"); - assert(lookupRhs.has_value() && "missing RHS parallel index"); + if (!lookupLhs.hasValue()) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected iterIndex=" << iterIndex << " to map to a LHS dimension"; + }); + if (!lookupRhs.hasValue()) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected iterIndex=" << iterIndex << " to map to a RHS dimension"; + }); int64_t lhsIndex = lookupLhs.getValue(); int64_t rhsIndex = lookupRhs.getValue(); int64_t dimSize = lhsType.getDimSize(lhsIndex); - assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); + if (dimSize != rhsType.getDimSize(rhsIndex)) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expect LHS dimension " << lhsIndex + << " to have the same size as RHS dimension " << rhsIndex; + }); // Base case. 
if (lhsType.getRank() == 1) { - assert(rhsType.getRank() == 1 && "corrupt contraction"); + if (rhsType.getRank() != 1) + return rewriter.notifyMatchFailure( + op, "When LHS has rank 1, expected also RHS to have rank 1"); Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); auto kind = vector::CombiningKind::ADD; if (auto acc = op.getAcc()) - return rewriter.create(loc, kind, m, acc); - return rewriter.create(loc, kind, m); + return rewriter.create(loc, kind, m, acc) + .getResult(); + return rewriter.create(loc, kind, m).getResult(); } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -858,6 +858,34 @@ return %0 : vector<2x1x7xi1> } +// CHECK-LABEL: @contract_one_sided_unit_reduction_dim +// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) +// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> +// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> +// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> +// CHECK: %[[R0:.+]] = vector.reduction , %[[M0]] : vector<2xi32> into i32 +// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> +// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> +// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> +// CHECK: %[[R1:.+]] = vector.reduction , %[[M1]] : vector<2xi32> into i32 +// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> +// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> +// CHECK: return %[[S]] : vector<2xi32> + +func.func 
@contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { + %res = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d1)> + ], + iterator_types = ["reduction", "parallel", "reduction"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> + return %res : vector<2xi32> +} + #matmat_accesses_0 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>,