diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -60,6 +60,8 @@ if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; + if (type.getElementType().isUnsignedInteger(8)) + return NVVM::MMATypes::u8; // Accumulator type is signless and implies signed. if (type.getElementType().isInteger(32)) return NVVM::MMATypes::s32; @@ -112,11 +114,8 @@ } NVVM::MMAFrag frag = convertOperand(retType.getOperand()); // Check that there is an exisiting instruction for the combination we need. - if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) { - llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k - << "\n"; + if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - } Type resType = convertMMAToLLVMType(retType); Location loc = op->getLoc(); @@ -245,6 +244,12 @@ destType) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); + NVVM::MMATypes bElementType = getElementType( + subgroupMmaComputeOp.getOpB().getType().cast()); + if (bElementType != sourceType) + return rewriter.notifyMatchFailure( + op, "WMMA compute op input matrix element types must match."); + unpackOp(adaptor.getOpA()); unpackOp(adaptor.getOpB()); unpackOp(adaptor.getOpC()); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -143,7 +143,8 @@ // Only allow integer types if the signedness can be inferred. if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8)) - if (!readOp->hasOneUse() || !isa(*readOp->user_begin())) + if (!readOp->hasOneUse() || (!isa(*readOp->user_begin()) && + !isa(*readOp->user_begin()))) return false; AffineMap map = readOp.getPermutationMap(); @@ -194,8 +195,9 @@ return broadcastOp.getVectorType().getRank() == 2; } -/// Return true if this signed extend op can be folded into a contract op. -static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) { +/// Return true if this integer extend op can be folded into a contract op. +template +static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) { if (!isa(extOp.getOperand().getDefiningOp())) return false; return llvm::all_of(extOp->getUsers(), [](Operation *user) { @@ -282,8 +284,10 @@ return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) return broadcastSupportsMMAMatrixType(broadcast); - if (auto extend = dyn_cast(op)) - return signedExtendSupportsMMAMatrixType(extend); + if (auto signedExtend = dyn_cast(op)) + return integerExtendSupportsMMAMatrixType(signedExtend); + if (auto unsignedExtend = dyn_cast(op)) + return integerExtendSupportsMMAMatrixType(unsignedExtend); return elementwiseSupportsMMAMatrixType(op); } @@ -429,10 +433,11 @@ PatternRewriter &rewriter) const override { // Look through integer extend ops. Value source = op.getVector(); - auto extOp = source.getDefiningOp(); auto resultType = op.getVectorType(); - if (extOp) { - source = extOp.getOperand(); + Operation *extOp; + if ((extOp = source.getDefiningOp()) || + (extOp = source.getDefiningOp())) { + source = extOp->getOperand(0); resultType = VectorType::get(resultType.getShape(), source.getType().cast().getElementType()); @@ -469,9 +474,14 @@ .getResult(); // Fuse through the integer extend op. - if (extOp) - result = rewriter.create(loc, op.getType(), result) - .getResult(); + if (extOp) { + if (isa(extOp)) + result = rewriter.create(loc, op.getType(), result) + .getResult(); + else + result = rewriter.create(loc, op.getType(), result) + .getResult(); + } rewriter.replaceOp(op, result); return success(); @@ -484,15 +494,14 @@ // Figure the right layout to use by looking at op uses. // TODO: Change the GPU dialect to abstract the layout at the this level and // only care about it during lowering to NVVM. -template -static const char *inferFragType(OpTy op) { +static const char *inferFragType(Operation *op) { for (Operation *users : op->getUsers()) { auto contract = dyn_cast(users); if (!contract) continue; - if (contract.getLhs() == op.getResult()) + if (contract.getLhs() == op->getResult(0)) return "AOp"; - if (contract.getRhs() == op.getResult()) + if (contract.getRhs() == op->getResult(0)) return "BOp"; } return "COp"; @@ -521,14 +530,15 @@ auto elType = op.getVectorType().getElementType(); const char *fragType = inferFragType(op); if (op->hasOneUse()) { - auto extOp = dyn_cast(*op->user_begin()); - // Infer the signedness of the mma type from the signed extend. - if (extOp) { - elType = IntegerType::get(op.getContext(), - elType.cast().getWidth(), - IntegerType::Signed); - mappingResult = extOp.getResult(); - fragType = inferFragType(extOp); + auto user = *op->user_begin(); + // Infer the signedness of the mma type from the integer extend. + bool isSignedExtend = isa(user); + if (isSignedExtend || isa(user)) { + elType = IntegerType::get( + op.getContext(), elType.cast().getWidth(), + isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned); + mappingResult = user->getResult(0); + fragType = inferFragType(user); } } gpu::MMAMatrixType type = diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -4028,9 +4028,19 @@ typeR.getScope() != typeB.getScope() || typeR.getScope() != typeC.getScope()) return op.emitOpError("matrix scope must match"); - if (typeA.getElementType() != typeB.getElementType() || - typeR.getElementType() != typeC.getElementType()) - return op.emitOpError("matrix element type must match"); + auto elementTypeA = typeA.getElementType(); + auto elementTypeB = typeB.getElementType(); + if (isa(elementTypeA) && isa(elementTypeB)) { + if (elementTypeA.cast().getWidth() != + elementTypeB.cast().getWidth()) + return op.emitOpError( + "matrix A and B integer element types must be the same bit width"); + } else if (elementTypeA != elementTypeB) { + return op.emitOpError( + "matrix A and B non-integer element types must match"); + } + if (typeR.getElementType() != typeC.getElementType()) + return op.emitOpError("matrix accumulator element type must match"); return success(); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -266,3 +266,24 @@ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return } + +// CHECK-LABEL: func @matmul_mixed_signedness_int8 +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp"> +// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xui8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32> +func.func @matmul_mixed_signedness_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) { + %cst_0 = arith.constant dense<0> : vector<16x16xi8> + %c0 = arith.constant 0 : index + %cst_i8 = arith.constant 0 : i8 + %cst_i32 = arith.constant 0 : i32 + %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> + %Ae = arith.extui %Ar : vector<16x16xi8> to vector<16x16xi32> + %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> + return +} diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir @@ -136,13 +136,21 @@ // ----- spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{matrix element type must match}} + // expected-error @+1 {{matrix A and B non-integer element types must match}} %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> spirv.Return } // ----- +spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" { + // expected-error @+1 {{matrix A and B integer element types must be the same bit width}} + %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup> + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { // expected-error @+1 {{Pointer must point to a scalar or vector type}} %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>