diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -510,12 +510,12 @@
   vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
   // Lower one parallel dimension.
-  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                 int64_t rhsIndex,
-                                 PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
+                                 vector::ContractionOp op, int64_t lhsIndex,
+                                 int64_t rhsIndex, Value mask) const;
   // Lower one reduction dimension.
-  FailureOr<Value> lowerReduction(vector::ContractionOp op,
-                                  PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
+                                  vector::ContractionOp op, Value mask) const;
 };
 
 } // namespace vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1904,11 +1904,6 @@
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
-    return failure();
-
   // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
@@ -1944,15 +1939,25 @@
   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
     return success();
 
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Operation *rootOp = op;
+  Value mask;
+  if (op.isMasked()) {
+    rewriter.setInsertionPoint(op.getMaskingOp());
+    rootOp = op.getMaskingOp();
+    mask = op.getMaskingOp().getMask();
+  }
+
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
     int64_t lhsIndex = batchDimMap[0].first;
     int64_t rhsIndex = batchDimMap[0].second;
-    auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -1970,10 +1975,10 @@
   VectorType lhsType = op.getLhsType();
   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
     if (lhsContractingDimSet.count(lhsIndex) == 0) {
-      auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
 
@@ -1982,20 +1987,20 @@
   VectorType rhsType = op.getRhsType();
   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
     if (rhsContractingDimSet.count(rhsIndex) == 0) {
-      auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
 
   // Lower the first remaining reduction dimension.
   if (!contractingDimMap.empty()) {
-    auto newOp = lowerReduction(op, rewriter);
+    auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -2005,10 +2010,11 @@
 // Lower one parallel dimension.
 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
-FailureOr<Value>
-ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                     int64_t rhsIndex,
-                                     PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+                                                      vector::ContractionOp op,
+                                                      int64_t lhsIndex,
+                                                      int64_t rhsIndex,
+                                                      Value mask) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   VectorType resType = op.getResultType().cast<VectorType>();
@@ -2046,6 +2052,7 @@
         diag << "expected the dimension for iterIndex=" << iterIndex
              << " to either appear in the result map, or to be a unit dimension";
       });
+
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
@@ -2058,22 +2065,29 @@
   Location loc = op.getLoc();
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
+
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
-    Value lowContract = rewriter.create<vector::ContractionOp>(
+
+    Value lowMask;
+    if (mask)
+      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *lowContract = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, acc, lowAffine, lowIter);
-    result =
-        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
+    lowContract = maskOperation(rewriter, lowContract, lowMask);
+    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+                          resIndex, d, rewriter);
   }
   return result;
 }
 
 // Lower one reduction dimension.
-FailureOr<Value>
-ContractionOpLowering::lowerReduction(vector::ContractionOp op,
-                                      PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
   auto loc = op.getLoc();
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
@@ -2110,10 +2124,12 @@
           op, "When LHS has rank 1, expected also RHS to have rank 1");
     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
     auto kind = vector::CombiningKind::ADD;
-    if (auto acc = op.getAcc())
-      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
-          .getResult();
-    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
+
+    Value acc = op.getAcc();
+    Operation *reductionOp =
+        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+            : rewriter.create<vector::ReductionOp>(loc, kind, m);
+    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
@@ -2131,8 +2147,14 @@
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
-    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
-                                                    lowAffine, lowIter);
+    Value newMask;
+    if (mask)
+      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *newContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, result, lowAffine, lowIter);
+    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
   }
   return result;
 }
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -28,6 +28,18 @@
   return %0 : f32
 }
 
+// CHECK-LABEL: func @masked_extract_contract1
+// CHECK-SAME:  %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
+// CHECK-SAME:  %[[M:.*]]: vector<4xi1>
+// CHECK:       %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
+// CHECK:       %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// CHECK:       return %[[R]] : f32
+
+func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: func @extract_contract1_int
 // CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
 // CHECK-SAME: %[[B:.*1]]: vector<4xi32>,