diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -102,7 +102,7 @@
     GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1201,7 +1201,7 @@
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
                   Arg<GPU_MMAMemRef, "", [MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1227,11 +1227,14 @@
     as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
     the operation held by all threads in a subgroup. `a_transpose` or
     `b_transpose` if present, signify that the respective operand was loaded in a
-    transposed manner. The transpose opernads are required to map to correct
+    transposed manner. The transpose operands are required to map to correct
     underlying intrisics but they currently do not seem to affect correctness
     even if they are absent given that the operands were loaded correctly using
     the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op.
 
+    For integer types the `A` and `B` matrices use the signedness of the operand
+    types in the multiplication. The accumulator type is implicitly signed.
+
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_load_matrix` ops.
 
@@ -1244,9 +1247,9 @@
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
-                   Arg<MMAMatrixOf<[F16, F32]>>:$opB,
-                   Arg<MMAMatrixOf<[F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
+                   Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
+                   Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
                    OptionalAttr<UnitAttr>:$a_transpose,
                    OptionalAttr<UnitAttr>:$b_transpose);
 
@@ -1288,7 +1291,7 @@
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
+  let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value);
 
   let results = (outs GPU_MMAMatrix:$res);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -57,6 +57,12 @@
   if (type.getElementType().isF32())
     return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                            : NVVM::MMATypes::tf32;
+
+  // MMA types with integer element types imply signed.
+  if (type.getElementType().isInteger(8))
+    return NVVM::MMATypes::s8;
+  if (type.getElementType().isInteger(32))
+    return NVVM::MMATypes::s32;
   llvm_unreachable("Unsupported type");
 }
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -250,6 +250,7 @@
 mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
   ArrayRef<int64_t> retTypeShape = type.getShape();
   Type elementType = type.getElementType();
+
   return spirv::CooperativeMatrixNVType::get(
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -185,8 +185,14 @@
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+  return broadcastOp.getVectorType().getRank() == 2;
+}
+
+/// Return true if this signed extend op can be folded into a contract op.
+static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+  return extOp->hasOneUse() &&
+         isa<vector::ContractionOp>(*extOp->user_begin()) &&
+         isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp());
+}
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +274,8 @@
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto extend = dyn_cast<arith::ExtSIOp>(op))
+    return signedExtendSupportsMMAMatrixType(extend);
   return elementwiseSupportsMMAMatrixType(op);
 }
@@ -479,14 +487,26 @@
     stride = 0;
   }
   assert(stride);
+  Value mappingResult = op.getResult();
+  auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
+  if (op->hasOneUse()) {
+    auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
+    // Infer the signedness of the mma type from the signed extend.
+    if (extOp) {
+      elType = IntegerType::get(op.getContext(),
+                                elType.cast<IntegerType>().getWidth(),
+                                IntegerType::Signed);
+      mappingResult = extOp.getResult();
+      fragType = inferFragType(extOp);
+    }
+  }
   gpu::MMAMatrixType type =
-      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
-                              op.getVectorType().getElementType(), fragType);
+      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
       b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
-  valueMapping[op.getResult()] = load;
+  valueMapping[mappingResult] = load;
 }
 
 static void convertTransferWriteOp(vector::TransferWriteOp op,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -78,7 +78,8 @@
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32();
+  return elementType.isF16() || elementType.isF32() ||
+         elementType.isInteger(8) || elementType.isInteger(32);
 }
 
 LogicalResult
@@ -93,7 +94,7 @@
     return emitError() << "MMAMatrixType must have exactly two dimensions";
 
   if (!MMAMatrixType::isValidElementType(elementType))
-    return emitError() << "MMAMatrixType elements must be F16 or F32";
+    return emitError() << "MMAMatrixType elements must be I8, I32, F16, or F32";
 
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -25,6 +25,25 @@
   return
 }
 
+// CHECK-LABEL: func @matmul_int8
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xi8, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xi8, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xi8, "AOp">, !gpu.mma_matrix<16x16xi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
+func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}
+
 // CHECK-LABEL: func @matmul_cst
 // CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -485,8 +485,8 @@
 func.func @mmamatrix_invalid_element_type(){
   %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
   %i = arith.constant 16 : index
-  // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
-  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
+  // expected-error @+1 {{MMAMatrixType elements must be I8, I32, F16, or F32}}
+  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
   return
 }
 
@@ -505,7 +505,7 @@
 // -----
 
 func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
-  // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+  // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
   %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
   return
 }
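
For illustration only (not part of the patch or its tests): a minimal sketch of the integer MMA written directly in the GPU dialect. The `si8` element type on the `A`/`B` fragments and the signless `i32` accumulator follow the signedness rules added to the `gpu.subgroup_mma_compute` documentation above; the 16x16 shapes and the function name are assumptions.

```mlir
// Illustrative sketch: an i8 x i8 -> i32 subgroup MMA. The A/B fragments carry
// their signedness in the mma_matrix element type (si8); the accumulator is a
// signless i32 matrix, which the op treats as implicitly signed.
func.func @wmma_i8(%a: memref<16x16xi8>, %b: memref<16x16xi8>, %c: memref<16x16xi32>) {
  %c0 = arith.constant 0 : index
  %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
  %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
  %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
  %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
  gpu.subgroup_mma_store_matrix %D, %c[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
  return
}
```

On the NVVM path, such integer fragments map to the signed MMA types (`s8` for the 8-bit fragments, `s32` for the accumulator) per the WmmaOpsToNvvm.cpp change above.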
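On the vector-dialect side, the new `signedExtendSupportsMMAMatrixType` hook matches contractions whose operands are sign-extended from a `vector.transfer_read`, but the lit test added here exercises a direct i8-accumulating contract instead. Below is a hypothetical sketch of the `arith.extsi` form the hook is written to match. The shapes, function name, and locally defined affine maps (not the test file's `#map0`-`#map3`) are assumptions, and the expectation that the extension folds into an `si8` load follows from the `IntegerType::Signed` construction in `convertTransferReadOp`.

```mlir
// Hypothetical input: sign-extended i8 operands feeding a vector.contract.
// Each arith.extsi whose only user is the contraction and whose operand is a
// vector.transfer_read is expected to fold away, with the corresponding
// gpu.subgroup_mma_load_matrix result taking an si8 element type.
#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @matmul_extsi(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
  %c0 = arith.constant 0 : index
  %cst_i8 = arith.constant 0 : i8
  %cst_i32 = arith.constant 0 : i32
  %a8 = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
  %b8 = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
  %acc = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
  // Signed extensions that the conversion is meant to fold into the loads.
  %a = arith.extsi %a8 : vector<16x16xi8> to vector<16x16xi32>
  %b = arith.extsi %b8 : vector<16x16xi8> to vector<16x16xi32>
  %d = vector.contract {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %a, %b, %acc : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
  vector.transfer_write %d, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
  return
}
```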