diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -78,7 +78,6 @@ def Vector_ContractionOp : Vector_Op<"contract", [ NoSideEffect, - PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, PredOpTrait<"third operand acc and result have same element type", TCresVTEtIsSameAsOpBase<0, 2>>, DeclareOpInterfaceMethods @@ -857,11 +856,7 @@ } def Vector_OuterProductOp : - Vector_Op<"outerproduct", [NoSideEffect, - PredOpTrait<"lhs operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"rhs operand and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, + Vector_Op<"outerproduct", [NoSideEffect,]>, Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic:$acc, DefaultValuedAttr:$kind)>, diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1824,29 +1824,36 @@ p.printOptionalAttrDict(op->getAttrs()); } p << " : " << op.lhs().getType() << ", " << op.rhs().getType(); + if (!op.acc().empty()) { + p << ", " << op.acc().getType(); + } } static ParseResult parseOuterProductOp(OpAsmParser &parser, OperationState &result) { SmallVector operandsInfo; - Type tLHS, tRHS; + SmallVector types; if (parser.parseOperandList(operandsInfo) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(tLHS) || parser.parseComma() || - parser.parseType(tRHS)) + parser.parseColonTypeList(types)) return failure(); if (operandsInfo.size() < 2) return parser.emitError(parser.getNameLoc(), "expected at least 2 operands"); - VectorType vLHS = tLHS.dyn_cast(); - VectorType vRHS = tRHS.dyn_cast(); + VectorType vLHS = types[0].dyn_cast(); + VectorType vRHS = types[1].dyn_cast(); if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); - VectorType resType = - vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType()) - : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + VectorType resType; + if (types.size() == 3) { + resType = types[2].dyn_cast(); + } else { + resType = + vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vLHS.getElementType()) + : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + } if (!result.attributes.get(OuterProductOp::getKindAttrName())) { result.attributes.append( @@ -1856,8 +1863,8 @@ } return failure( - parser.resolveOperand(operandsInfo[0], tLHS, result.operands) || - parser.resolveOperand(operandsInfo[1], tRHS, result.operands) || + parser.resolveOperand(operandsInfo[0], types[0], result.operands) || + parser.resolveOperand(operandsInfo[1], types[1], result.operands) || (operandsInfo.size() > 2 && parser.resolveOperand(operandsInfo[2], resType, result.operands)) || parser.addTypeToList(resType, result.types)); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -242,6 +242,25 @@ namespace { +Value promoteVector(Location loc, Value inputVector, Type promotedElementType, + PatternRewriter &rewriter) { + VectorType inputVectorType = inputVector.getType().cast(); + if (!promotedElementType || + inputVectorType.getElementType() == promotedElementType) { + return inputVector; + } else { + auto promotedVectorType = + VectorType::get(inputVectorType.getShape(), promotedElementType); + if (promotedElementType.isa()) { + return rewriter.create(loc, inputVector, + promotedVectorType); + } else { + return rewriter.create(loc, inputVector, + promotedVectorType); + } + } +} + struct UnrollTransferReadPattern : public OpRewritePattern { UnrollTransferReadPattern(MLIRContext *context, @@ -766,9 +785,10 @@ VectorType lhsType = op.getOperandVectorTypeLHS(); VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); - VectorType resType = op.getVectorType(); - Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); + VectorType accType = op.getVectorType(); + Type lhsEltType = lhsType.getElementType(); + Type accEltType = accType.getElementType(); + bool isInt = accEltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; vector::CombiningKind kind = op.kind(); @@ -785,19 +805,27 @@ } Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); - for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { + loc, accType, rewriter.getZeroAttr(accType)); + for (int64_t d = 0, e = accType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, eltType, op.lhs(), pos); - Value a = rewriter.create(loc, rhsType, x); + Value x = + rewriter.create(loc, lhsEltType, op.lhs(), pos); + VectorType rhsShapedVectorWithLhsEltType = + VectorType::get(rhsType.getShape(), lhsEltType); + Value a = rewriter.create( + loc, rhsShapedVectorWithLhsEltType, x); Value r = nullptr; - if (acc) - r = rewriter.create(loc, rhsType, acc, pos); + if (acc) { + VectorType rhsShapedVectorWithAccEltType = + VectorType::get(rhsType.getShape(), accEltType); + r = rewriter.create( + loc, rhsShapedVectorWithAccEltType, acc, pos); + } Optional m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) : genMultF(loc, a, op.rhs(), r, kind, rewriter); if (!m.hasValue()) return failure(); - result = rewriter.create(loc, resType, m.getValue(), + result = rewriter.create(loc, accType, m.getValue(), result, pos); } rewriter.replaceOp(op, result); @@ -809,8 +837,13 @@ vector::CombiningKind kind, PatternRewriter &rewriter) { using vector::CombiningKind; - - auto mul = rewriter.create(loc, x, y); + Type accEltType; + if (acc) { + accEltType = acc.getType().cast().getElementType(); + } + Value promotedX = promoteVector(loc, x, accEltType, rewriter); + Value promotedY = promoteVector(loc, y, accEltType, rewriter); + auto mul = rewriter.create(loc, promotedX, promotedY); if (!acc) return Optional(mul); @@ -854,13 +887,20 @@ vector::CombiningKind kind, PatternRewriter &rewriter) { using vector::CombiningKind; + Type accEltType; + if (acc) { + accEltType = acc.getType().cast().getElementType(); + } + Value promotedX = promoteVector(loc, x, accEltType, rewriter); + Value promotedY = promoteVector(loc, y, accEltType, rewriter); // Special case for fused multiply-add. if (acc && kind == CombiningKind::ADD) { - return Optional(rewriter.create(loc, x, y, acc)); + return Optional( + rewriter.create(loc, promotedX, promotedY, acc)); } - auto mul = rewriter.create(loc, x, y); + auto mul = rewriter.create(loc, promotedX, promotedY); if (!acc) return Optional(mul); @@ -1127,11 +1167,14 @@ /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using /// operands `x and `y`. -static Value createMul(Location loc, Value x, Value y, bool isInt, - PatternRewriter &rewriter) { - if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); +static Value createMul(Location loc, Value x, Value y, Type resultElementType, + bool isInt, PatternRewriter &rewriter) { + Value promotedX = promoteVector(loc, x, resultElementType, rewriter); + Value promotedY = promoteVector(loc, y, resultElementType, rewriter); + if (isInt) { + return rewriter.create(loc, promotedX, promotedY); + } + return rewriter.create(loc, promotedX, promotedY); } namespace mlir { @@ -1212,13 +1255,15 @@ VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); rhs = rew.create(loc, flattenedRHSType, rhs); - Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, - rhsColumns); - mul = rew.create( - loc, - VectorType::get({lhsRows, rhsColumns}, - getElementTypeOrSelf(op.acc().getType())), - mul); + Type resultElementType = getElementTypeOrSelf(op.acc().getType()); + VectorType unflattenedResultType = + VectorType::get({lhsRows, rhsColumns}, resultElementType); + Type flattenedResultType = VectorType::get( + unflattenedResultType.getNumElements(), resultElementType); + + Value mul = rew.create(loc, flattenedResultType, lhs, rhs, + lhsRows, lhsColumns, rhsColumns); + mul = rew.create(loc, unflattenedResultType, mul); // ACC must be C(m, n) or C(n, m). auto accMap = op.indexing_maps()[2].cast().getValue(); @@ -1516,7 +1561,8 @@ Value b = rank == 1 ? rhs : rewriter.create(op.getLoc(), rhs, c); - Value m = createMul(op.getLoc(), a, b, isInt, rewriter); + Value m = createMul(op.getLoc(), a, b, dstType.getElementType(), isInt, + rewriter); Value reduced = rewriter.create( op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"), m, ValueRange{}); @@ -1559,12 +1605,6 @@ if (failed(filter(op))) return failure(); - // TODO: support mixed mode contract lowering. - if (op.getLhsType().getElementType() != - getElementTypeOrSelf(op.getAccType()) || - op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) - return failure(); - // TODO: implement benefits, cost models. MLIRContext *ctx = op.getContext(); ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); @@ -1698,7 +1738,10 @@ // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); + Value lhs = op.lhs(); + Value rhs = op.rhs(); + Type accElementType = getElementTypeOrSelf(op.getAccType()); + Value m = createMul(loc, lhs, rhs, accElementType, isInt, rewriter); StringAttr kind = rewriter.getStringAttr("add"); Value res = rewriter.create(loc, resType, kind, m, ValueRange{}); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -796,25 +796,6 @@ // ----- -#contraction_accesses = [ - affine_map<(i, j, k) -> (i, k)>, - affine_map<(i, j, k) -> (k, j)>, - affine_map<(i, j, k) -> (i, j)> - ] -#contraction_trait = { - indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "reduction"] - } -func @contraction(%arg0: vector<4x3xi32>, - %arg1: vector<3x7xf32>, - %arg2: vector<4x7xf32>) -> vector<4x7xf32> { - // expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}} - %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2 - : vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32> -} - -// ----- - #contraction_accesses = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT +// R-UN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX +// R-UN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT +// R-UN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -43,6 +43,24 @@ return %0 : i32 } +// CHECK-LABEL: func @extract_contract1_int_mixed +// CHECK-SAME: %[[A:.*0]]: vector<4xi8>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi16>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[Aext:.*]] = arith.extsi %arg0 : vector<4xi8> to vector<4xi32> +// CHECK: %[[Bext:.*]] = arith.extsi %arg1 : vector<4xi16> to vector<4xi32> +// CHECK: %[[F:.*]] = arith.muli %[[Aext]], %[[Bext]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32 +// CHECK: %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32 +// CHECK: return %[[ACC]] : i32 + +func @extract_contract1_int_mixed(%arg0: vector<4xi8>, %arg1: vector<4xi16>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi8>, vector<4xi16> into i32 + return %0 : i32 +} + + #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (j)>, @@ -77,6 +95,35 @@ return %0 : vector<2xf32> } +// CHECK-LABEL: func @extract_contract2_mixed +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf16>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf64> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf16> +// CHECK: %[[T0ext:.*]] = arith.extf %[[T0]] : vector<3xf16> to vector<3xf64> +// CHECK: %[[Bext:.*]] = arith.extf %[[B]] : vector<3xf32> to vector<3xf64> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0ext]], %[[Bext]] : vector<3xf64> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf64> into f64 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f64 into vector<2xf64> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf16> +// CHECK: %[[T5ext:.*]] = arith.extf %[[T5]] : vector<3xf16> to vector<3xf64> +// CHECK: %[[Bext2:.*]] = arith.extf %[[B]] : vector<3xf32> to vector<3xf64> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5ext]], %[[Bext2]] : vector<3xf64> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf64> into f64 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f64 into vector<2xf64> +// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf64> +// CHECK: return %[[T10]] : vector<2xf64> + +func @extract_contract2_mixed(%arg0: vector<2x3xf16>, + %arg1: vector<3xf32>, + %arg2: vector<2xf64>) -> vector<2xf64> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf16>, vector<3xf32> into vector<2xf64> + return %0 : vector<2xf64> +} + + // CHECK-LABEL: func @extract_contract2_int // CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32>, @@ -92,14 +139,44 @@ // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> // CHECK: return %[[T10]] : vector<2xi32> + func @extract_contract2_int(%arg0: vector<2x3xi32>, - %arg1: vector<3xi32>, + %arg1: vector<3xi32>, %arg2: vector<2xi32>) -> vector<2xi32> { %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 : vector<2x3xi32>, vector<3xi32> into vector<2xi32> return %0 : vector<2xi32> } + +// CHECK-LABEL: func @extract_contract2_int_mixed +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi8>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi8>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi8> +// CHECK: %[[T0ext:.*]] = arith.extsi %[[T0]] : vector<3xi8> to vector<3xi32> +// CHECK: %[[Bext:.*]] = arith.extsi %arg1 : vector<3xi8> to vector<3xi32> +// CHECK: %[[T2:.*]] = arith.muli %[[T0ext]], %[[Bext]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi8> +// CHECK: %[[T5ext:.*]] = arith.extsi %[[T5]] : vector<3xi8> to vector<3xi32> +// CHECK: %[[Bext2:.*]] = arith.extsi %arg1 : vector<3xi8> to vector<3xi32> +// CHECK: %[[T7:.*]] = arith.muli %[[T5ext]], %[[Bext2]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> + +func @extract_contract2_int_mixed(%arg0: vector<2x3xi8>, + %arg1: vector<3xi8>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi8>, vector<3xi8> into vector<2xi32> + return %0 : vector<2xi32> +} + #vecmat_accesses = [ affine_map<(i, j) -> (j)>, affine_map<(i, j) -> (i, j)>, @@ -184,6 +261,54 @@ return %0 : vector<2x2xf32> } +// CHECK-LABEL: func @extract_contract4_int_mixed +// CHECK-SAME: %[[A:.*0]]: vector<2x2xi8>, +// CHECK-SAME: %[[B:.*1]]: vector<2x2xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x2xi64> +// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2x2xi64> +// ... bunch of extract insert to transpose B into Bt +// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : i32 into vector<2x2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xi8> +// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xi32> +// CHECK: %[[T0ext:.*]] = arith.extsi %[[T0]] : vector<2xi8> to vector<2xi64> +// CHECK: %[[T2ext:.*]] = arith.extsi %[[T2]] : vector<2xi32> to vector<2xi64> +// CHECK: %[[T9:.*]] = arith.muli %[[T0ext]], %[[T2ext]] : vector<2xi64> +// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xi64> into i64 +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : i64 into vector<2x2xi64> +// +// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xi32> +// CHECK: %[[T0ext2:.*]] = arith.extsi %[[T0]] : vector<2xi8> to vector<2xi64> +// CHECK: %[[T12ext:.*]] = arith.extsi %[[T12]] : vector<2xi32> to vector<2xi64> +// CHECK: %[[T19:.*]] = arith.muli %[[T0ext2]], %[[T12ext]] : vector<2xi64> +// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xi64> into i64 +// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : i64 into vector<2x2xi64> +// +// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xi8> +// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xi32> +// CHECK: %[[T23ext:.*]] = arith.extsi %[[T23]] : vector<2xi8> to vector<2xi64> +// CHECK: %[[T24ext:.*]] = arith.extsi %[[T24]] : vector<2xi32> to vector<2xi64> +// CHECK: %[[T32:.*]] = arith.muli %[[T23ext]], %[[T24ext]] : vector<2xi64> +// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xi64> into i64 +// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : i64 into vector<2x2xi64> +// +// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xi32> +// CHECK: %[[T23ext2:.*]] = arith.extsi %[[T23]] : vector<2xi8> to vector<2xi64> +// CHECK: %[[T40ext:.*]] = arith.extsi %[[T40]] : vector<2xi32> to vector<2xi64> +// CHECK: %[[T41:.*]] = arith.muli %[[T23ext2]], %[[T40ext]] : vector<2xi64> +// CHECK: %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xi64> into i64 +// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : i64 into vector<2x2xi64> +// +// CHECK: %[[T52:.*]] = arith.addi %[[T43]], %[[C]] : vector<2x2xi64> +// CHECK: return %[[T52]] : vector<2x2xi64> + +func @extract_contract4_int_mixed(%arg0: vector<2x2xi8>, + %arg1: vector<2x2xi32>, + %arg2: vector<2x2xi64>) -> vector<2x2xi64> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x2xi8>, vector<2x2xi32> into vector<2x2xi64> + return %0 : vector<2x2xi64> +} + #contraction2d_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, @@ -218,6 +343,34 @@ return %0 : f32 } +// CHECK-LABEL: func @full_contract1_int_mixed +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi16>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xi16>, +// CHECK-SAME: %[[C:.*2]]: i64 +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi16> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xi16> +// CHECK: %[[T0ext:.*]] = arith.extsi %[[T0]] : vector<3xi16> to vector<3xi64> +// CHECK: %[[T1ext:.*]] = arith.extsi %[[T1]] : vector<3xi16> to vector<3xi64> +// CHECK: %[[T2:.*]] = arith.muli %[[T0ext]], %[[T1ext]] : vector<3xi64> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi64> into i64 +// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[C]] : i64 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi16> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xi16> +// CHECK: %[[T5ext:.*]] = arith.extsi %[[T5]] : vector<3xi16> to vector<3xi64> +// CHECK: %[[T6ext:.*]] = arith.extsi %[[T6]] : vector<3xi16> to vector<3xi64> +// CHECK: %[[T7:.*]] = arith.muli %[[T5ext]], %[[T6ext]] : vector<3xi64> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi64> into i64 +// CHECK: %[[T9:.*]] = arith.addi %[[T8]], %[[T4]] : i64 +// CHECK: return %[[T9]] : i64 + +func @full_contract1_int_mixed(%arg0: vector<2x3xi16>, + %arg1: vector<2x3xi16>, + %arg2: i64) -> i64 { + %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 + : vector<2x3xi16>, vector<2x3xi16> into i64 + return %0 : i64 +} + #contraction2d_trans_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (j, i)>, @@ -245,7 +398,7 @@ // CHECK: %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf +// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> // CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> // CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> @@ -352,6 +505,36 @@ return %0: vector<2x3xi32> } +// CHECK-LABEL: func @outerproduct_acc_int_mixed +// CHECK-SAME: %[[A:.*0]]: vector<2xi8>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi16>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> +// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi8> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi8> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> +// CHECK: %[[T1ext:.*]] = arith.extsi %[[T1]] : vector<3xi8> to vector<3xi32> +// CHECK: %[[Bext:.*]] = arith.extsi %[[B]] : vector<3xi16> to vector<3xi32> +// CHECK: %[[T3:.*]] = arith.muli %[[T1ext]], %[[Bext]] : vector<3xi32> +// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi8> +// CHECK: %[[T7:.*]] = splat %[[T6]] : vector<3xi8> +// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> +// CHECK: %[[T7ext:.*]] = arith.extsi %[[T7]] : vector<3xi8> to vector<3xi32> +// CHECK: %[[Bext2:.*]] = arith.extsi %[[B]] : vector<3xi16> to vector<3xi32> +// CHECK: %[[T9:.*]] = arith.muli %[[T7ext]], %[[Bext2]] : vector<3xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T11]] : vector<2x3xi32> + +func @outerproduct_acc_int_mixed(%arg0: vector<2xi8>, + %arg1: vector<3xi16>, + %arg2: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi8>, vector<3xi16>, vector<2x3xi32> + return %0: vector<2x3xi32> +} + // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32)