diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -91,6 +91,7 @@ PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, PredOpTrait<"third operand acc and result have same element type", TCresVTEtIsSameAsOpBase<0, 2>>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, @@ -632,6 +633,7 @@ def Vector_FMAOp : Op, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ] # ElementwiseMappable.traits>, Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs, @@ -923,7 +925,8 @@ PredOpTrait<"lhs operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, PredOpTrait<"rhs operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, + TCresVTEtIsSameAsOpBase<0, 1>>, + DeclareOpInterfaceMethods]>, Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc, DefaultValuedAttr:$kind)>, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1107,12 +1107,48 @@ VectorType vType = fmaOp.getVectorType(); if (vType.getRank() > 1) return failure(); + + // Masked fmas are lowered separately. + auto maskableOp = cast(fmaOp.getOperation()); + if (maskableOp.isMasked()) + return failure(); + rewriter.replaceOpWithNewOp( fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; +/// Conversion pattern that turns a masked vector.fma on a 1-D vector into its +/// LLVM counterpart representation. Non-side-effecting VP intrinsics are not +/// fully supported by some backends, including x86, and they don't support +/// pass-through values either. 
For these reasons, we generate an unmasked +/// fma followed by a select instruction to emulate the masking behavior. +/// This pattern is peepholed by some backends with support for masked fma +/// instructions. This pattern does not match vectors of rank >= 2. +class MaskedFMAOp1DConversion + : public VectorMaskOpConversionBase { +public: + using VectorMaskOpConversionBase::VectorMaskOpConversionBase; + + MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr) + : VectorMaskOpConversionBase(converter) {} + + virtual LogicalResult matchAndRewriteMaskableOp( + vector::MaskOp maskOp, MaskableOpInterface maskableOp, + ConversionPatternRewriter &rewriter) const override { + auto fmaOp = cast(maskableOp.getOperation()); + Type llvmType = typeConverter->convertType(fmaOp.getVectorType()); + + Value fmulAddOp = rewriter.create( + fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(), + fmaOp.getAcc()); + rewriter.replaceOpWithNewOp( + maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc()); + return success(); + } +}; + class VectorInsertElementOpConversion : public ConvertOpToLLVMPattern { public: @@ -1279,6 +1315,11 @@ if (vType.getRank() < 2) return failure(); + // Masked fmas are lowered separately. + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + auto loc = op.getLoc(); auto elemType = vType.getElementType(); Value zero = rewriter.create( @@ -1707,9 +1748,10 @@ patterns .add, VectorLoadStoreConversion, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -889,6 +889,34 @@ return success(); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. Mostly used for +/// verification purposes. It requires the operation to be vectorized. 
+Type ContractionOp::getExpectedMaskType() { + auto indexingMaps = this->getIndexingMapsArray(); + AffineMap lhsIdxMap = indexingMaps[0]; + AffineMap rhsIdxMap = indexingMaps[1]; + VectorType lhsType = this->getLhsType(); + VectorType rhsType = this->getRhsType(); + + unsigned numVecDims = lhsIdxMap.getNumDims(); + SmallVector maskShape(numVecDims, ShapedType::kDynamic); + + // Using the information in the indexing maps, extract the size of each + // dimension in the vector.contract operation from the two input operands. + for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) + maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize; + for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) + maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; + + assert(!ShapedType::isDynamicShape(maskShape) && + "Mask shape couldn't be computed"); + + return VectorType::get(maskShape, + IntegerType::get(lhsType.getContext(), /*width=*/1)); +} + SmallVector ContractionOp::getTraitAttrNames() { return SmallVector{getIndexingMapsAttrName(), getIteratorTypesAttrName(), getKindAttrName()}; @@ -1760,6 +1788,16 @@ return llvm::to_vector<4>(getVectorType().getShape()); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. Mostly used for +/// verification purposes. It requires the operation to be vectorized. +Type FMAOp::getExpectedMaskType() { + auto vecType = this->getVectorType(); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1)); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -2762,6 +2800,16 @@ return success(); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. Mostly used for +/// verification purposes. It requires the operation to be vectorized. 
+Type OuterProductOp::getExpectedMaskType() { + auto vecType = this->getVectorType(); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1)); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -147,13 +147,13 @@ } /// Helper to create arithmetic operation associated with a kind of contraction. -static std::optional createContractArithOp(Location loc, Value x, - Value y, Value acc, - vector::CombiningKind kind, - PatternRewriter &rewriter, - bool isInt) { +static std::optional +createContractArithOp(Location loc, Value x, Value y, Value acc, + vector::CombiningKind kind, PatternRewriter &rewriter, + bool isInt, Optional maybeMask = std::nullopt) { using vector::CombiningKind; Value mul; + if (isInt) { if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) // Only valid for floating point types. @@ -169,11 +169,17 @@ return std::nullopt; // Special case for fused multiply-add. if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { - return std::optional( - rewriter.create(loc, x, y, acc)); + Operation *fmaOp = rewriter.create(loc, x, y, acc); + if (maybeMask.has_value() && maybeMask.value()) + fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value()); + return fmaOp->getResult(0); } mul = rewriter.create(loc, x, y); } + + assert((!maybeMask.has_value() || !maybeMask.value()) && + "Unsupported masked case"); + if (!acc) return std::optional(mul); return makeArithReduction(rewriter, loc, kind, mul, acc); @@ -550,14 +556,27 @@ Value acc = (op.getAcc().empty()) ? 
nullptr : op.getAcc()[0]; vector::CombiningKind kind = op.getKind(); + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = cast(op.getOperation()); + Operation *rootOp; + Value mask; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + mask = maskableOp.getMaskingOp().getMask(); + } else { + rootOp = op; + } + if (!rhsType) { // Special case: AXPY operation. Value b = rewriter.create(loc, lhsType, op.getRhs()); std::optional mult = createContractArithOp( - loc, op.getLhs(), b, acc, kind, rewriter, isInt); + loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); if (!mult.has_value()) return failure(); - rewriter.replaceOp(op, *mult); + rewriter.replaceOp(rootOp, *mult); return success(); } @@ -571,13 +590,14 @@ Value r = nullptr; if (acc) r = rewriter.create(loc, rhsType, acc, pos); - std::optional m = - createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt); + std::optional m = createContractArithOp( + loc, a, op.getRhs(), r, kind, rewriter, isInt, mask); if (!m.has_value()) return failure(); result = rewriter.create(loc, resType, *m, result, pos); } - rewriter.replaceOp(op, result); + + rewriter.replaceOp(rootOp, result); return success(); } }; @@ -601,7 +621,12 @@ LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - // TODO: implement masks + // TODO: Support vector.mask. + auto maskableOp = cast(contractOp.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? if (!contractOp.getMasks().empty()) return failure(); @@ -1429,7 +1454,12 @@ LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rew) const { - // TODO: implement masks + // TODO: Support vector.mask. 
+ auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? if (!op.getMasks().empty()) return failure(); if (vectorTransformOptions.vectorContractLowering != @@ -1525,10 +1555,16 @@ UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op) : StructuredGenerator(b, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), - res(op.getAcc()), lhsType(op.getLhsType()) {} + res(op.getAcc()), lhsType(op.getLhsType()) { + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + mask = maskableOp.getMaskingOp().getMask(); + } Value t(Value v) { static constexpr std::array perm = {1, 0}; + if (!v) + return v; return rewriter.create(loc, v, perm); } @@ -1547,16 +1583,27 @@ return rewriter.create(loc, promotedType, v); } - Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { + FailureOr outerProd(Value lhs, Value rhs, Value res, int reductionSize, + Optional maybeMask = std::nullopt) { assert(reductionSize > 0); + // Incremental support for masking. + if (mask && !maybeMask.has_value()) + return failure(); + Type resElementType = res.getType().cast().getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { Value extractA = rewriter.create(loc, lhs, k); Value extractB = rewriter.create(loc, rhs, k); extractA = promote(extractA, resElementType); extractB = promote(extractB, resElementType); - res = rewriter.create(loc, res.getType(), extractA, - extractB, res, kind); + Value extractMask; + if (maybeMask.has_value() && maybeMask.value()) + extractMask = + rewriter.create(loc, maybeMask.value(), k); + + Operation *outerProdOp = rewriter.create( + loc, res.getType(), extractA, extractB, res, kind); + res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0); } return res; } @@ -1607,7 +1654,7 @@ // Case mat-vec: transpose. 
if (layout({{m, k}, {k}, {m}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); + return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask)); // Case mat-trans-vec: ready to go. if (layout({{k, m}, {k}, {m}})) return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); @@ -1646,7 +1693,7 @@ private: vector::CombiningKind kind; - Value lhs, rhs, res; + Value lhs, rhs, res, mask; VectorType lhsType; }; } // namespace @@ -1668,7 +1715,7 @@ /// otherwise supports any layout permutation of the matrix-multiply. LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: implement masks + // TODO: Remove native masks from contraction op? if (!op.getMasks().empty()) return failure(); @@ -1679,20 +1726,31 @@ if (failed(filter(op))) return failure(); + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = cast(op.getOperation()); + Operation *rootOp; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + } else { + rootOp = op; + } + UnrolledOuterProductGenerator e(rewriter, op); FailureOr matmatRes = e.matmat(); if (succeeded(matmatRes)) { - rewriter.replaceOp(op, *matmatRes); + rewriter.replaceOp(rootOp, *matmatRes); return success(); } FailureOr matvecRes = e.matvec(); if (succeeded(matvecRes)) { - rewriter.replaceOp(op, *matvecRes); + rewriter.replaceOp(rootOp, *matvecRes); return success(); } FailureOr tmatvecRes = e.tmatvec(); if (succeeded(tmatvecRes)) { - rewriter.replaceOp(op, *tmatvecRes); + rewriter.replaceOp(rootOp, *tmatvecRes); return success(); } @@ -1702,7 +1760,12 @@ LogicalResult ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: implement masks + // TODO: Support vector.mask. 
+ auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? if (!op.getMasks().empty()) return failure(); @@ -1834,7 +1897,12 @@ LogicalResult ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: implement masks. + // TODO: Support vector.mask. + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? if (!op.getMasks().empty()) return failure(); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -416,6 +416,18 @@ // CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32> // CHECK: return %[[T19]] : vector<2x3xf32> +// ----- + +func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// We can't check for the intermediate 'vector.mask { vector.fma }' state so we +// just make sure the vector.fma is lowered. 
+ +// CHECK: llvm.intr.fmuladd +// CHECK: llvm.select // ----- @@ -2145,3 +2157,17 @@ %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32> return %0 : vector<8xf32> } + +// ----- + +// CHECK-LABEL: func.func @masked_vector_fma( +// CHECK-SAME: %[[INPUT:.*]]: vector<8xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32> +// CHECK: %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32> +// CHECK: llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32> + +func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> { + %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32> + return %0 : vector<8xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1196,3 +1196,27 @@ %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 return %0 : f32 } + +func.func @masked_vector_contract(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>, + %m: vector<2x3xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// OUTERPRODUCT-LABEL: func.func @masked_vector_contract( +// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, +// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>, +// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>, +// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> +// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> +// 
OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct + +// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> +// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct + +// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> +// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct