diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -191,7 +191,7 @@
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
+                         Value v1, Value acc, Value mask = Value());
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
@@ -214,8 +214,17 @@
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
 /// maskable operation itself.
-Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
-                         Value mask);
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
+                         Value mask, Value passthru = Value());
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is
+/// used to propagate the pass-thru value for masked-out or speculatively
+/// executed lanes. VP intrinsics do not support pass-thru values and every
+/// masked-out lane is set to poison. LLVM backends are usually able to match
+/// op + select patterns and fold them into native target instructions.
+Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
+                     Value passthru);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -633,7 +633,6 @@
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
-       DeclareOpInterfaceMethods<MaskableOpInterface>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -704,11 +704,6 @@
     Value acc = adaptor.getAcc();
     Location loc = reductionOp.getLoc();
 
-    // Masked reductions are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       Value result;
@@ -1108,47 +1103,12 @@
     if (vType.getRank() > 1)
       return failure();
 
-    // Masked fmas are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
 
-/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
-/// LLVM counterpart representation. Non side effecting VP intrinsics are not
-/// fully supported by some backends, including x86, and they don't support
-/// pass-through values either. For these reasons, we generate an unmasked
-/// fma followed by a select instrution to emulate the masking behavior.
-/// This pattern is peepholed by some backends with support for masked fma
-/// instructions. This pattern does not match vectors of n >= 2 rank.
-class MaskedFMAOp1DConversion
-    : public VectorMaskOpConversionBase<vector::FMAOp> {
-public:
-  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
-
-  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
-      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
-
-  virtual LogicalResult matchAndRewriteMaskableOp(
-      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
-      ConversionPatternRewriter &rewriter) const override {
-    auto fmaOp = cast<vector::FMAOp>(maskableOp.getOperation());
-    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
-
-    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
-        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
-        fmaOp.getAcc());
-    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
-        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
-    return success();
-  }
-};
-
 class VectorInsertElementOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
 public:
@@ -1315,11 +1275,6 @@
     if (vType.getRank() < 2)
       return failure();
 
-    // Masked fmas are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     auto loc = op.getLoc();
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -1748,10 +1703,9 @@
   patterns
       .add,
           VectorLoadStoreConversion,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1790,16 +1790,6 @@
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-// MaskableOpInterface methods.
-
-/// Returns the mask type expected by this operation. Mostly used for
-/// verification purposes. It requires the operation to be vectorized."
-Type FMAOp::getExpectedMaskType() {
-  auto vecType = this->getVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1));
-}
-
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -5807,53 +5797,71 @@
 }
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
-                                       CombiningKind kind, Value v1, Value v2) {
+                                       CombiningKind kind, Value v1, Value acc,
+                                       Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
-  Type t2 = getElementTypeOrSelf(v2.getType());
+  Type tAcc = getElementTypeOrSelf(acc.getType());
+  Value result;
+
   switch (kind) {
   case CombiningKind::ADD:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for ADD reduction");
+    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
+    else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+    else
+      llvm_unreachable("invalid value types for ADD reduction");
+    break;
   case CombiningKind::AND:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MAXF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+    assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
            "expected float values");
-    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+    result = b.createOrFold<arith::MaxFOp>(loc, v1, acc);
+    break;
   case CombiningKind::MINF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+    assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
           "expected float values");
-    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+    result = b.createOrFold<arith::MinFOp>(loc, v1, acc);
+    break;
  case CombiningKind::MAXSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
+    break;
  case CombiningKind::MINSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
+    break;
  case CombiningKind::MAXUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
+    break;
  case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
+    break;
  case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
+    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
+    else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+    else
+      llvm_unreachable("invalid value types for MUL reduction");
+    break;
  case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
+    break;
  case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
+    break;
  };
-  llvm_unreachable("unknown CombiningKind");
+
+  assert(result && "unknown CombiningKind");
+  return selectPassthru(b, mask, result, acc);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5875,13 +5883,34 @@
 /// Creates a vector.mask operation around a maskable operation. Returns
 /// the vector.mask operation if the mask provided is valid. Otherwise, returns
 /// the maskable operation itself.
-Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
-                                       Operation *maskableOp, Value mask) {
+Operation *mlir::vector::maskOperation(OpBuilder &builder,
+                                       Operation *maskableOp, Value mask,
+                                       Value passthru) {
   if (!mask)
     return maskableOp;
-  return rewriter.create<MaskOp>(maskableOp->getLoc(),
-                                 maskableOp->getResultTypes(), mask, maskableOp,
-                                 createMaskOpRegion);
+  if (passthru)
+    return builder.create<MaskOp>(maskableOp->getLoc(),
+                                  maskableOp->getResultTypes(), mask, passthru,
+                                  maskableOp, createMaskOpRegion);
+  return builder.create<MaskOp>(maskableOp->getLoc(),
+                                maskableOp->getResultTypes(), mask, maskableOp,
+                                createMaskOpRegion);
+}
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is
+/// used to propagate the pass-thru value of vector.mask or for cases where
+/// only the pass-thru value propagation is needed.
+/// VP intrinsics do not support pass-thru values and every masked-out lane is
+/// set to poison. LLVM backends are usually able to match op + select patterns
+/// and fold them into native target instructions.
+Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
+                                   Value newValue, Value passthru) {
+  if (!mask)
+    return newValue;
+
+  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
+                                         mask, newValue, passthru);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -151,8 +151,7 @@
 static std::optional<Value>
 createContractArithOp(Location loc, Value x, Value y, Value acc,
                       vector::CombiningKind kind, PatternRewriter &rewriter,
-                      bool isInt,
-                      std::optional<Value> maybeMask = std::nullopt) {
+                      bool isInt, Value mask = Value()) {
   using vector::CombiningKind;
   Value mul;
 
@@ -171,20 +170,20 @@
       return std::nullopt;
     // Special case for fused multiply-add.
     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-      Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
-      if (maybeMask.has_value() && maybeMask.value())
-        fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
-      return fmaOp->getResult(0);
+      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+      if (mask)
+        // The fma op doesn't need explicit masking. However, fma ops used in
+        // reductions must preserve previous 'acc' values for masked-out lanes.
+        fma = selectPassthru(rewriter, mask, fma, acc);
+      return fma;
     }
     mul = rewriter.create<arith::MulFOp>(loc, x, y);
   }
 
-  assert((!maybeMask.has_value() || !maybeMask.value()) &&
-         "Unsupported masked case");
-
   if (!acc)
     return std::optional<Value>(mul);
-  return makeArithReduction(rewriter, loc, kind, mul, acc);
+
+  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
 }
 
 /// Return the positions of the reductions in the given map.
@@ -587,13 +586,17 @@
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
       auto pos = rewriter.getI64ArrayAttr(d);
       Value x =
-          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
+          rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
       Value r = nullptr;
       if (acc)
-        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+        r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+      Value extrMask;
+      if (mask)
+        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
       std::optional<Value> m = createContractArithOp(
-          loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
       if (!m.has_value())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
@@ -638,6 +641,7 @@
     if (vectorTransformOptions.vectorContractLowering !=
         vector::VectorContractLowering::ParallelArith)
       return failure();
+
     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
@@ -1564,8 +1568,7 @@
       mask = maskableOp.getMaskingOp().getMask();
   }
 
-  Value t(Value v) {
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
+  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
     if (!v)
       return v;
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
@@ -1620,7 +1623,8 @@
     bindDims(rewriter.getContext(), m, n, k);
     // Classical row-major matmul: Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+                       t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -418,16 +418,132 @@
 
 // -----
 
-func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
   %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 
-// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
-// just make sure the vector.fma is lowered.
+// CHECK-LABEL: func.func @masked_float_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
-// CHECK: llvm.intr.fmuladd
-// CHECK: llvm.select
+// -----
+
+func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxsi>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minui>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<and>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_and_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<or>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_or_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
 
 // -----
 
@@ -2157,17 +2273,3 @@
   %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
   return %0 : vector<8xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @masked_vector_fma(
-// CHECK-SAME:    %[[INPUT:.*]]: vector<8xf32>,
-// CHECK-SAME:    %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
-// CHECK: %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-// CHECK: llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
-
-func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
-  %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -76,6 +76,30 @@
   return %0 : vector<2xf32>
 }
 
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2(
+// OUTERPRODUCT-SAME:    %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME:    %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME:    %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME:    %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct
+
+func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
+                                    %arg1: vector<3xf32>,
+                                    %arg2: vector<2xf32>,
+                                    %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: func @extract_contract2_int
 // CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
@@ -182,6 +206,32 @@
   return %0 : vector<2x2xf32>
 }
 
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4(
+// OUTERPRODUCT-SAME:    %[[VAL_0:.*]]: vector<3x5xf32>,
+// OUTERPRODUCT-SAME:    %[[VAL_1:.*]]: vector<5x7xf32>,
+// OUTERPRODUCT-SAME:    %[[VAL_2:.*]]: vector<3x7xf32>,
+// OUTERPRODUCT-SAME:    %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
+                                    %arg1: vector<5x7xf32>,
+                                    %arg2: vector<3x7xf32>,
+                                    %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+  return %0 : vector<3x7xf32>
+}
+
 #contraction2d_accesses = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>,
@@ -1197,26 +1247,4 @@
   return %0 : f32
 }
 
-func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
-                                  %arg1: vector<3xf32>,
-                                  %arg2: vector<2xf32>,
-                                  %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
-    : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
-// OUTERPRODUCT-LABEL: func.func @masked_vector_contract(
-// OUTERPRODUCT-SAME:    %[[VAL_0:.*]]: vector<2x3xf32>,
-// OUTERPRODUCT-SAME:    %[[VAL_1:.*]]: vector<3xf32>,
-// OUTERPRODUCT-SAME:    %[[VAL_2:.*]]: vector<2xf32>,
-// OUTERPRODUCT-SAME:    %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
-// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
-// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct
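Usage sketch (editor's illustration, not part of the patch): the snippet below shows how downstream code could thread an optional mask through the new `makeArithReduction`/`selectPassthru` utilities added in VectorOps.h. The helper name `emitMaskedMulAdd` and the surrounding boilerplate are hypothetical; only the two `mlir::vector` utilities and the arith ops come from the patched APIs.

// Editor's sketch -- hypothetical helper, not part of this patch.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

/// Emits `acc + lhs * rhs` and, when `mask` is non-null, keeps the original
/// `acc` value in masked-out lanes. With a null mask this degenerates to the
/// plain mul + add combination, mirroring how createContractArithOp threads
/// its optional mask into makeArithReduction.
static Value emitMaskedMulAdd(OpBuilder &b, Location loc, Value lhs, Value rhs,
                              Value acc, Value mask = Value()) {
  Value mul = b.create<arith::MulFOp>(loc, lhs, rhs);
  // makeArithReduction applies the ADD combining op and, if `mask` is set,
  // wraps the result in an arith.select against `acc` via selectPassthru, so
  // masked-out lanes carry the previous accumulator value instead of poison.
  return vector::makeArithReduction(b, loc, vector::CombiningKind::ADD, mul,
                                    acc, mask);
}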