diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -91,7 +91,7 @@ Example: ```mlir - // Simple dot product (K = 0). + // Simple DOT product (K = 0). #contraction_accesses = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, @@ -668,19 +668,36 @@ } def Vector_OuterProductOp : - Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic:$acc)>, + Vector_Op<"outerproduct", [NoSideEffect, + PredOpTrait<"lhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"rhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc)>, Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; let description = [{ - Takes 2 1-D vectors and returns the 2-D vector containing the outer-product. + Takes 2 1-D vectors and returns the 2-D vector containing the outer-product, + as illustrated below: + ``` + outer | [c, d] + ------+------------ + [a, | [ [a*c, a*d], + b] | [b*c, b*d] ] + ``` + This operation also accepts a 1-D vector lhs and a scalar rhs. In this + case a simple AXPY operation is performed, which returns a 1-D vector. + ``` + [a, b] * c = [a*c, b*c] + ``` - An optional extra 2-D vector argument may be specified in which case the - operation returns the sum of the outer-product and the extra vector. In this - multiply-accumulate scenario, the rounding mode is that obtained by - guaranteeing that a fused-multiply add operation is emitted. When lowered to - the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to - lower to actual `fma` instructions on x86. 
+ An optional extra vector argument with the same shape as the output + vector may be specified, in which case the operation returns the sum of + the outer-product and the extra vector. In this multiply-accumulate + scenario for floating-point arguments, the result is rounded only once; + this is enforced by guaranteeing that a fused multiply-add operation is emitted. When + lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which + is guaranteed to lower to actual `fma` instructions on x86. Example: @@ -691,6 +708,10 @@ %3 = vector.outerproduct %0, %1, %2: vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + + %6 = vector.outerproduct %4, %5: vector<10xf32>, f32 + return %6: vector<10xf32> + ``` }]; let builders = [ @@ -702,12 +723,13 @@ VectorType getOperandVectorTypeLHS() { return lhs().getType().cast(); } - VectorType getOperandVectorTypeRHS() { - return rhs().getType().cast(); + Type getOperandTypeRHS() { + return rhs().getType(); } VectorType getOperandVectorTypeACC() { - return (llvm::size(acc()) == 0) ? VectorType() : - (*acc().begin()).getType().cast(); + return (llvm::size(acc()) == 0) + ?
VectorType() + : (*acc().begin()).getType().cast(); } VectorType getVectorType() { return getResult().getType().cast(); diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir @@ -11,6 +11,8 @@ !vector_type_Y = type vector<3xf32> !vector_type_Z = type vector<2x3xf32> +!vector_type_R = type vector<7xf32> + func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C { %a = splat %fa: !vector_type_A %b = splat %fb: !vector_type_B @@ -33,6 +35,7 @@ } func @entry() { + %f0 = constant 0.0: f32 %f1 = constant 1.0: f32 %f2 = constant 2.0: f32 %f3 = constant 3.0: f32 @@ -72,5 +75,26 @@ // // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) ) + %3 = vector.broadcast %f0 : f32 to !vector_type_R + %4 = vector.insert %f1, %3[1] : f32 into !vector_type_R + %5 = vector.insert %f2, %4[2] : f32 into !vector_type_R + %6 = vector.insert %f3, %5[3] : f32 into !vector_type_R + %7 = vector.insert %f4, %6[4] : f32 into !vector_type_R + %8 = vector.insert %f5, %7[5] : f32 into !vector_type_R + %9 = vector.insert %f10, %8[6] : f32 into !vector_type_R + + %o = vector.broadcast %f1 : f32 to !vector_type_R + + %axpy1 = vector.outerproduct %9, %f2 : !vector_type_R, f32 + %axpy2 = vector.outerproduct %9, %f2, %o : !vector_type_R, f32 + + vector.print %axpy1 : !vector_type_R + vector.print %axpy2 : !vector_type_R + // + // axpy operations: + // + // CHECK: ( 0, 2, 4, 6, 8, 10, 20 ) + // CHECK: ( 1, 3, 5, 7, 9, 11, 21 ) + return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir @@ -11,6 +11,8 @@ 
!vector_type_Y = type vector<3xi64> !vector_type_Z = type vector<2x3xi64> +!vector_type_R = type vector<7xi64> + func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C { %a = splat %ia: !vector_type_A %b = splat %ib: !vector_type_B @@ -33,6 +35,7 @@ } func @entry() { + %i0 = constant 0: i64 %i1 = constant 1: i64 %i2 = constant 2: i64 %i3 = constant 3: i64 @@ -72,5 +75,26 @@ // // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) ) + %3 = vector.broadcast %i0 : i64 to !vector_type_R + %4 = vector.insert %i1, %3[1] : i64 into !vector_type_R + %5 = vector.insert %i2, %4[2] : i64 into !vector_type_R + %6 = vector.insert %i3, %5[3] : i64 into !vector_type_R + %7 = vector.insert %i4, %6[4] : i64 into !vector_type_R + %8 = vector.insert %i5, %7[5] : i64 into !vector_type_R + %9 = vector.insert %i10, %8[6] : i64 into !vector_type_R + + %o = vector.broadcast %i1 : i64 to !vector_type_R + + %axpy1 = vector.outerproduct %9, %i2 : !vector_type_R, i64 + %axpy2 = vector.outerproduct %9, %i2, %o : !vector_type_R, i64 + + vector.print %axpy1 : !vector_type_R + vector.print %axpy2 : !vector_type_R + // + // axpy operations: + // + // CHECK: ( 0, 2, 4, 6, 8, 10, 20 ) + // CHECK: ( 1, 3, 5, 7, 9, 11, 21 ) + return } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1203,10 +1203,13 @@ "expected at least 2 operands"); VectorType vLHS = tLHS.dyn_cast(); VectorType vRHS = tRHS.dyn_cast(); - if (!vLHS || !vRHS) - return parser.emitError(parser.getNameLoc(), "expected 2 vector types"); - VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType()); + if (!vLHS) + return parser.emitError(parser.getNameLoc(), + "expected vector type for operand #1"); + VectorType resType = + vRHS ? 
VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vLHS.getElementType()) + : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); return failure( parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || @@ -1216,19 +1219,32 @@ } static LogicalResult verify(OuterProductOp op) { + Type tRHS = op.getOperandTypeRHS(); VectorType vLHS = op.getOperandVectorTypeLHS(), - vRHS = op.getOperandVectorTypeRHS(), + vRHS = tRHS.dyn_cast(), vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); + if (vLHS.getRank() != 1) return op.emitOpError("expected 1-d vector for operand #1"); - if (vRHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #2"); - if (vRES.getRank() != 2) - return op.emitOpError("expected 2-d vector result"); - if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); - if (vRHS.getDimSize(0) != vRES.getDimSize(1)) - return op.emitOpError("expected #2 operand dim to match result dim #2"); + + if (vRHS) { + // Proper OUTER operation. + if (vRHS.getRank() != 1) + return op.emitOpError("expected 1-d vector for operand #2"); + if (vRES.getRank() != 2) + return op.emitOpError("expected 2-d vector result"); + if (vLHS.getDimSize(0) != vRES.getDimSize(0)) + return op.emitOpError("expected #1 operand dim to match result dim #1"); + if (vRHS.getDimSize(0) != vRES.getDimSize(1)) + return op.emitOpError("expected #2 operand dim to match result dim #2"); + } else { + // An AXPY operation. 
+ if (vRES.getRank() != 1) + return op.emitOpError("expected 1-d vector result"); + if (vLHS.getDimSize(0) != vRES.getDimSize(0)) + return op.emitOpError("expected #1 operand dim to match result dim #1"); + } + if (vACC && vACC != vRES) return op.emitOpError("expected operand #3 of same type as result type"); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1262,7 +1262,7 @@ /// %0 = vector.extract %lhs[0] /// %1 = vector.broadcast %0 /// %2 = vector.extract %acc[0] -/// %3 = vector.fma %1, %arg1, %2 +/// %3 = vector.fma %1, %rhs, %2 /// %4 = vector.insert %3, %z[0] /// .. /// %x = vector.insert %.., %..[N-1] @@ -1275,36 +1275,49 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - VectorType rhsType = op.getOperandVectorTypeRHS(); + VectorType lhsType = op.getOperandVectorTypeLHS(); + VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); VectorType resType = op.getVectorType(); Type eltType = resType.getElementType(); + bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + if (!rhsType) { + // Special case: AXPY operation. 
+ Value b = rewriter.create(loc, lhsType, op.rhs()); + rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter)); + return success(); + } + Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); Value x = rewriter.create(loc, eltType, op.lhs(), pos); - Value b = rewriter.create(loc, rhsType, x); - Value m; - if (acc) { - Value e = rewriter.create(loc, rhsType, acc, pos); - if (eltType.isa()) - m = rewriter.create( - loc, rewriter.create(loc, b, op.rhs()), e); - else - m = rewriter.create(loc, b, op.rhs(), e); - } else { - if (eltType.isa()) - m = rewriter.create(loc, b, op.rhs()); - else - m = rewriter.create(loc, b, op.rhs()); - } + Value a = rewriter.create(loc, rhsType, x); + Value r = nullptr; + if (acc) + r = rewriter.create(loc, rhsType, acc, pos); + Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter); result = rewriter.create(loc, resType, m, result, pos); } rewriter.replaceOp(op, result); return success(); } + +private: + static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt, + PatternRewriter &rewriter) { + if (acc) { + if (isInt) + return rewriter.create(loc, rewriter.create(loc, x, y), + acc); + return rewriter.create(loc, x, y, acc); + } + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); + } }; /// Progressive lowering of ConstantMaskOp. 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -187,7 +187,7 @@ // ----- func @outerproduct_non_vector_operand(%arg0: f32) { - // expected-error@+1 {{expected 2 vector types}} + // expected-error@+1 {{expected vector type for operand #1}} %1 = vector.outerproduct %arg0, %arg0 : f32, f32 } @@ -228,6 +228,27 @@ // ----- +func @outerproduct_axpy_operand(%arg0: vector<4x8xf32>, %arg1: f32) { + // expected-error@+1 {{expected 1-d vector for operand #1}} + %1 = vector.outerproduct %arg0, %arg1 : vector<4x8xf32>, f32 +} + +// ----- + +func @outerproduct_axpy_result_generic(%arg0: vector<4xf32>, %arg1: f32) { + // expected-error@+1 {{expected 1-d vector result}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<4xf32>, f32) -> (vector<4x8xf32>) +} + +// ----- + +func @outerproduct_axpy_operand_dim_generic(%arg0: vector<8xf32>, %arg1: f32) { + // expected-error@+1 {{expected #1 operand dim to match result dim #1}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (vector<8xf32>, f32) -> (vector<16xf32>) +} + +// ----- + func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x16xf32>) { // expected-error@+1 {{expected operand #3 of same type as result type}} %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>) diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -326,6 +326,53 @@ return %0: vector<2x3xi32> } +// CHECK-LABEL: func @axpy_fp( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32) +// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = mulf %[[A]], %[[T0]] : 
vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_fp_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32, +// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) +// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_int( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32) +// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: return %[[T1]] : vector<16xi32> +func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 + return %0: vector<16xi32> +} + +// CHECK-LABEL: func @axpy_int_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32, +// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) +// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: %[[T2:.*]] = addi %[[T1]], %[[C]] : vector<16xi32> +// CHECK: return %[[T2]] : vector<16xi32> +func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 + return %0: vector<16xi32> +} + // CHECK-LABEL: func @transpose23 // CHECK-SAME: %[[A:.*]]: vector<2x3xf32> // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>