diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -110,6 +110,11 @@ num_batch_dims (see dimension type descriptions below)). For K = 0 (no free or batch dimensions), the accumulator and output are a scalar. + If operands and the result have types of different bitwidths, operands are + promoted to have the same bitwidth as the result before performing the + contraction. For integer types, only signless integer types are supported, + and the promotion happens via sign extension. + Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp) specify the dynamic dimension sizes of valid data within the lhs/rhs vector arguments. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -803,10 +803,17 @@ } LogicalResult ContractionOp::verify() { - auto lhsType = getLhsType(); - auto rhsType = getRhsType(); - auto accType = getAccType(); - auto resType = getResultType(); + VectorType lhsType = getLhsType(); + VectorType rhsType = getRhsType(); + Type accType = getAccType(); + Type resType = getResultType(); + + // Integer operands are promoted to the result bitwidth via sign extension, + // so the lhs element type must be a signless integer. + if (lhsType.getElementType().isa<IntegerType>()) { + if (!lhsType.getElementType().isSignlessInteger()) + return emitOpError("only supports signless integer types"); + } // Verify that an indexing map was specified for each vector operand.
if (getIndexingMapsArray().size() != 3) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1652,3 +1652,15 @@ // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}} %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32> } + +// ----- + +func.func @integer_vector_contract(%arg0: vector<16x32xsi8>, %arg1: vector<32x16xsi8>, %arg2: vector<16x16xsi32>) -> vector<16x16xsi32> { + // Signed (non-signless) integer element types must be rejected by the verifier. + // expected-error@+1 {{op only supports signless integer types}} + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add> + } %arg0, %arg1, %arg2 : vector<16x32xsi8>, vector<32x16xsi8> into vector<16x16xsi32> + return %0: vector<16x16xsi32> +}