diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -91,7 +91,7 @@ Example: ```mlir - // Simple dot product (K = 0). + // Simple DOT product (K = 0). #contraction_accesses = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, @@ -668,19 +668,36 @@ } def Vector_OuterProductOp : - Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>, - Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>, + Vector_Op<"outerproduct", [NoSideEffect, + PredOpTrait<"lhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>, + PredOpTrait<"rhs operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 1>>]>, + Arguments<(ins AnyType:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>, Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; let description = [{ - Takes 2 1-D vectors and returns the 2-D vector containing the outer-product. + Takes 2 1-D vectors and returns the 2-D vector containing the outer-product, + as illustrated below: + ``` + outer | [c, d] + ------+------------ + [a, | [ [a*c, a*d], + b] | [b*c, b*d] ] + ``` + This operation also accepts a scalar lhs and a 1-D vector rhs. In this + case a simple AXPY operation is performed, which returns a 1-D vector. + ``` + a * [b, c] = [a*b, a*c] + ``` - An optional extra 2-D vector argument may be specified in which case the - operation returns the sum of the outer-product and the extra vector. In this - multiply-accumulate scenario, the rounding mode is that obtained by - guaranteeing that a fused-multiply add operation is emitted. When lowered to - the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to - lower to actual `fma` instructions on x86. 
+ An optional extra vector argument with the same shape as the output + vector may be specified in which case the operation returns the sum of + the outer-product and the extra vector. In this multiply-accumulate + scenario for floating-point arguments, the rounding mode is enforced + by guaranteeing that a fused-multiply add operation is emitted. When + lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which + is guaranteed to lower to actual `fma` instructions on x86. Example: @@ -691,6 +708,10 @@ %3 = vector.outerproduct %0, %1, %2: vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + + %5 = vector.outerproduct %4, %1: f32, vector<8xf32> + return %5: vector<8xf32> + ``` }]; let builders = [ @@ -699,15 +720,16 @@ "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " "Value acc">]; let extraClassDeclaration = [{ - VectorType getOperandVectorTypeLHS() { - return lhs().getType().cast<VectorType>(); + Type getOperandTypeLHS() { + return lhs().getType(); } VectorType getOperandVectorTypeRHS() { return rhs().getType().cast<VectorType>(); } VectorType getOperandVectorTypeACC() { - return (llvm::size(acc()) == 0) ? VectorType() : - (*acc().begin()).getType().cast<VectorType>(); + return (llvm::size(acc()) == 0) + ? 
VectorType() + : (*acc().begin()).getType().cast<VectorType>(); } VectorType getVectorType() { return getResult().getType().cast<VectorType>(); diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-f32.mlir @@ -11,6 +11,8 @@ !vector_type_Y = type vector<3xf32> !vector_type_Z = type vector<2x3xf32> +!vector_type_R = type vector<7xf32> + func @vector_outerproduct_splat_8x8(%fa: f32, %fb: f32, %fc: f32) -> !vector_type_C { %a = splat %fa: !vector_type_A %b = splat %fb: !vector_type_B @@ -33,6 +35,7 @@ } func @entry() { + %f0 = constant 0.0: f32 %f1 = constant 1.0: f32 %f2 = constant 2.0: f32 %f3 = constant 3.0: f32 @@ -72,5 +75,26 @@ // // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) ) + %3 = vector.broadcast %f0 : f32 to !vector_type_R + %4 = vector.insert %f1, %3[1] : f32 into !vector_type_R + %5 = vector.insert %f2, %4[2] : f32 into !vector_type_R + %6 = vector.insert %f3, %5[3] : f32 into !vector_type_R + %7 = vector.insert %f4, %6[4] : f32 into !vector_type_R + %8 = vector.insert %f5, %7[5] : f32 into !vector_type_R + %9 = vector.insert %f10, %8[6] : f32 into !vector_type_R + + %o = vector.broadcast %f1 : f32 to !vector_type_R + + %axpy1 = vector.outerproduct %f2, %9 : f32, !vector_type_R + %axpy2 = vector.outerproduct %f2, %9, %o : f32, !vector_type_R + + vector.print %axpy1 : !vector_type_R + vector.print %axpy2 : !vector_type_R + // + // axpy operations: + // + // CHECK: ( 0, 2, 4, 6, 8, 10, 20 ) + // CHECK: ( 1, 3, 5, 7, 9, 11, 21 ) + return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-outerproduct-i64.mlir @@ -11,6 +11,8 @@ 
!vector_type_Y = type vector<3xi64> !vector_type_Z = type vector<2x3xi64> +!vector_type_R = type vector<7xi64> + func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C { %a = splat %ia: !vector_type_A %b = splat %ib: !vector_type_B @@ -33,6 +35,7 @@ } func @entry() { + %i0 = constant 0: i64 %i1 = constant 1: i64 %i2 = constant 2: i64 %i3 = constant 3: i64 @@ -72,5 +75,26 @@ // // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) ) + %3 = vector.broadcast %i0 : i64 to !vector_type_R + %4 = vector.insert %i1, %3[1] : i64 into !vector_type_R + %5 = vector.insert %i2, %4[2] : i64 into !vector_type_R + %6 = vector.insert %i3, %5[3] : i64 into !vector_type_R + %7 = vector.insert %i4, %6[4] : i64 into !vector_type_R + %8 = vector.insert %i5, %7[5] : i64 into !vector_type_R + %9 = vector.insert %i10, %8[6] : i64 into !vector_type_R + + %o = vector.broadcast %i1 : i64 to !vector_type_R + + %axpy1 = vector.outerproduct %i2, %9 : i64, !vector_type_R + %axpy2 = vector.outerproduct %i2, %9, %o : i64, !vector_type_R + + vector.print %axpy1 : !vector_type_R + vector.print %axpy2 : !vector_type_R + // + // axpy operations: + // + // CHECK: ( 0, 2, 4, 6, 8, 10, 20 ) + // CHECK: ( 1, 3, 5, 7, 9, 11, 21 ) + return } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1114,10 +1114,13 @@ "expected at least 2 operands"); VectorType vLHS = tLHS.dyn_cast<VectorType>(); VectorType vRHS = tRHS.dyn_cast<VectorType>(); - if (!vLHS || !vRHS) - return parser.emitError(parser.getNameLoc(), "expected 2 vector types"); - VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType()); + if (!vRHS) + return parser.emitError(parser.getNameLoc(), + "expected vector type for operand #2"); + VectorType resType = + vLHS ? 
VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vRHS.getElementType()) + : VectorType::get({vRHS.getDimSize(0)}, vRHS.getElementType()); return failure( parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || @@ -1127,19 +1130,33 @@ } static LogicalResult verify(OuterProductOp op) { - VectorType vLHS = op.getOperandVectorTypeLHS(), + Type tLHS = op.getOperandTypeLHS(); + VectorType vLHS = tLHS.dyn_cast<VectorType>(), vRHS = op.getOperandVectorTypeRHS(), vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); - if (vLHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #1"); - if (vRHS.getRank() != 1) - return op.emitOpError("expected 1-d vector for operand #2"); - if (vRES.getRank() != 2) - return op.emitOpError("expected 2-d vector result"); - if (vLHS.getDimSize(0) != vRES.getDimSize(0)) - return op.emitOpError("expected #1 operand dim to match result dim #1"); - if (vRHS.getDimSize(0) != vRES.getDimSize(1)) - return op.emitOpError("expected #2 operand dim to match result dim #2"); + + if (vLHS) { + // Proper OUTER operation. + if (vLHS.getRank() != 1) + return op.emitOpError("expected 1-d vector for operand #1"); + if (vRHS.getRank() != 1) + return op.emitOpError("expected 1-d vector for operand #2"); + if (vRES.getRank() != 2) + return op.emitOpError("expected 2-d vector result"); + if (vLHS.getDimSize(0) != vRES.getDimSize(0)) + return op.emitOpError("expected #1 operand dim to match result dim #1"); + if (vRHS.getDimSize(0) != vRES.getDimSize(1)) + return op.emitOpError("expected #2 operand dim to match result dim #2"); + } else { + // An AXPY operation.
+ if (vRHS.getRank() != 1) + return op.emitOpError("expected 1-d vector for operand #2"); + if (vRES.getRank() != 1) + return op.emitOpError("expected 1-d vector result"); + if (vRHS.getDimSize(0) != vRES.getDimSize(0)) + return op.emitOpError("expected #2 operand dim to match result dim #1"); + } + if (vACC && vACC != vRES) return op.emitOpError("expected operand #3 of same type as result type"); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1278,33 +1278,45 @@ VectorType rhsType = op.getOperandVectorTypeRHS(); VectorType resType = op.getVectorType(); Type eltType = resType.getElementType(); + bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + if (resType.getRank() == 1) { + // Special case: AXPY operation. + Value a = rewriter.create(loc, rhsType, op.lhs()); + rewriter.replaceOp(op, genMult(loc, a, op.rhs(), acc, isInt, rewriter)); + return success(); + } + Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); Value x = rewriter.create(loc, eltType, op.lhs(), pos); - Value b = rewriter.create(loc, rhsType, x); - Value m; - if (acc) { - Value e = rewriter.create(loc, rhsType, acc, pos); - if (eltType.isa()) - m = rewriter.create( - loc, rewriter.create(loc, b, op.rhs()), e); - else - m = rewriter.create(loc, b, op.rhs(), e); - } else { - if (eltType.isa()) - m = rewriter.create(loc, b, op.rhs()); - else - m = rewriter.create(loc, b, op.rhs()); - } + Value a = rewriter.create(loc, rhsType, x); + Value r = nullptr; + if (acc) + r = rewriter.create(loc, rhsType, acc, pos); + Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter); result = rewriter.create(loc, resType, m, result, pos); } rewriter.replaceOp(op, result); return 
success(); } + +private: + static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt, + PatternRewriter &rewriter) { + if (acc) { + if (isInt) + return rewriter.create(loc, rewriter.create(loc, x, y), + acc); + return rewriter.create(loc, x, y, acc); + } + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); + } }; /// Progressive lowering of ConstantMaskOp. diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -180,14 +180,14 @@ // ----- -func @outerproduct_num_operands(%arg0: f32) { +func @outerproduct_num_operands(%arg0: vector<4xf32>) { // expected-error@+1 {{expected at least 2 operands}} - %1 = vector.outerproduct %arg0 : f32, f32 + %1 = vector.outerproduct %arg0 : vector<4xf32>, vector<4xf32> } // ----- -func @outerproduct_non_vector_operand(%arg0: f32) { - // expected-error@+1 {{expected 2 vector types}} +func @outerproduct_non_vector_rhs_operand(%arg0: f32) { + // expected-error@+1 {{expected vector type for operand #2}} %1 = vector.outerproduct %arg0, %arg0 : f32, f32 } @@ -228,6 +228,27 @@ // ----- +func @outerproduct_axpy_operand_2(%arg0: f32, %arg1: vector<4x8xf32>) { + // expected-error@+1 {{expected 1-d vector for operand #2}} + %1 = vector.outerproduct %arg0, %arg1 : f32, vector<4x8xf32> +} + +// ----- + +func @outerproduct_axpy_result_generic(%arg0: f32, %arg1: vector<4xf32>) { + // expected-error@+1 {{expected 1-d vector result}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (f32, vector<4xf32>) -> (vector<4x8xf32>) +} + +// ----- + +func @outerproduct_axpy_operand_2_dim_generic(%arg0: f32, %arg1: vector<8xf32>) { + // expected-error@+1 {{expected #2 operand dim to match result dim #1}} + %1 = "vector.outerproduct" (%arg0, %arg1) : (f32, vector<8xf32>) -> (vector<16xf32>) +} + +// ----- + func @outerproduct_operand_3_result_type_generic(%arg0: vector<4xf32>, %arg1: 
vector<8xf32>, %arg2: vector<4x16xf32>) { // expected-error@+1 {{expected operand #3 of same type as result type}} %1 = "vector.outerproduct" (%arg0, %arg1, %arg2) : (vector<4xf32>, vector<8xf32>, vector<4x16xf32>) -> (vector<4x8xf32>) diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -337,6 +337,53 @@ return %0: vector<2x3xi32> } +// CHECK-LABEL: func @axpy_fp( +// CHECK-SAME: %[[A:.*0]]: f32, +// CHECK-SAME: %[[B:.*1]]: vector<16xf32>) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<16xf32> +// CHECK: %[[T1:.*]] = mulf %[[T0]], %[[B]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func @axpy_fp(%arg0: f32, %arg1: vector<16xf32>) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1: f32, vector<16xf32> + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_fp_add( +// CHECK-SAME: %[[A:.*0]]: f32, +// CHECK-SAME: %[[B:.*1]]: vector<16xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<16xf32> +// CHECK: %[[T1:.*]] = vector.fma %[[T0]], %[[B]], %[[C]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func @axpy_fp_add(%arg0: f32, %arg1: vector<16xf32>, %arg2 : vector<16xf32>) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: f32, vector<16xf32> + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_int( +// CHECK-SAME: %[[A:.*0]]: i32, +// CHECK-SAME: %[[B:.*1]]: vector<16xi32>) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<16xi32> +// CHECK: %[[T1:.*]] = muli %[[T0]], %[[B]] : vector<16xi32> +// CHECK: return %[[T1]] : vector<16xi32> +func @axpy_int(%arg0: i32, %arg1: vector<16xi32>) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1: i32, vector<16xi32> + return %0: vector<16xi32> +} + +// CHECK-LABEL: func @axpy_int_add( +// CHECK-SAME: 
%[[A:.*0]]: i32, +// CHECK-SAME: %[[B:.*1]]: vector<16xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<16xi32> +// CHECK: %[[T1:.*]] = muli %[[T0]], %[[B]] : vector<16xi32> +// CHECK: %[[T2:.*]] = addi %[[T1]], %[[C]] : vector<16xi32> +// CHECK: return %[[T2]] : vector<16xi32> +func @axpy_int_add(%arg0: i32, %arg1: vector<16xi32>, %arg2: vector<16xi32>) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: i32, vector<16xi32> + return %0: vector<16xi32> +} + // CHECK-LABEL: func @transpose23 // CHECK-SAME: %[[A:.*]]: vector<2x3xf32> // CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>