diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -101,7 +101,7 @@ GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">; // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops. -def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>; +def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>; class MMAMatrixOf allowedTypes> : ContainerType, IsMMAMatrixTypePred, diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1150,6 +1150,10 @@ matrix which eventually allows the lowering to determine the size of each row. If the `transpose` attribute is present then the op does a transposed load. + For integer types, the resulting `!gpu.mma_matrix` type needs to specify the + signedness of the data if the matrix type is an `A` or `B` operand for + `gpu.subgroup_mma_compute`. + This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and `gpu.subgroup_mma_compute`. @@ -1201,7 +1205,7 @@ ``` }]; - let arguments = (ins Arg>:$src, + let arguments = (ins Arg>:$src, Arg:$dstMemref, Variadic:$indices, IndexAttr:$leadDimension, @@ -1227,11 +1231,15 @@ as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of the operation held by all threads in a subgroup. `a_transpose` or `b_transpose` if present, signify that the respective operand was loaded in a - transposed manner. The transpose opernads are required to map to correct + transposed manner. 
The transpose operands are required to map to correct underlying intrisics but they currently do not seem to affect correctness even if they are absent given that the operands were loaded correctly using the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op. + For integer types, the `A` and `B` matrices carry their signedness with their + types. The accumulator type is expected to be signless and imply a signed integer + with a greater width than the other two operands. + This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and `gpu.subgroup_mma_load_matrix` ops. @@ -1244,9 +1252,9 @@ ``` }]; - let arguments = (ins Arg>:$opA, - Arg>:$opB, - Arg>:$opC, + let arguments = (ins Arg>:$opA, + Arg>:$opB, + Arg>:$opC, OptionalAttr:$a_transpose, OptionalAttr:$b_transpose); @@ -1288,7 +1296,7 @@ ``` }]; - let arguments = (ins AnyTypeOf<[F16, F32]>:$value); + let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value); let results = (outs GPU_MMAMatrix:$res); diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -37,7 +37,8 @@ /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td /// WMMA_REGS structure. 
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
-                                             mlir::NVVM::MMAFrag frag,
+                                             mlir::NVVM::MMAFrag frag, int nRow,
+                                             int nCol,
                                              mlir::MLIRContext *context);
 } // namespace NVVM
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -385,16 +385,20 @@
   list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
     [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
     ["f16"], [], ["f16", "f32"], []>.ret;
+  list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
+    [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+    ["s8", "u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
     tf32_wmma_ops,
-    fp_wmma_ops);
+    fp_wmma_ops,
+    i8_wmma_ops);
 
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
     [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-    ["a", "b"], ["f16"]>.ret;
+    ["a", "b"], ["f16", "s8", "u8"]>.ret;
   list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
     [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-    ["c", "d"], ["f16", "f32"]>.ret;
+    ["c", "d"], ["f16", "f32", "s32"]>.ret;
   list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
     [GEOM<16, 16, 8>],
     ["a", "b"], ["tf32"]>.ret;
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -57,6 +57,14 @@
   if (type.getElementType().isF32())
     return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                            : NVVM::MMATypes::tf32;
+
+  if (type.getElementType().isSignedInteger(8))
+    return NVVM::MMATypes::s8;
+  if (type.getElementType().isUnsignedInteger(8))
+    return NVVM::MMATypes::u8;
+  // Accumulator type is signless and implies signed.
+  if (type.getElementType().isInteger(32))
+    return NVVM::MMATypes::s32;
   llvm_unreachable("Unsupported type");
 }
 
@@ -366,8 +374,10 @@
 LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
+  auto nRow = type.getShape()[0];
+  auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
-      NVVM::inferMMAType(eltType, frag, type.getContext());
+      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -140,6 +140,12 @@
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+
+  // Only allow integer types if the signedness can be inferred.
+  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+    if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
+      return false;
+
   AffineMap map = readOp.getPermutationMap();
   OpBuilder b(readOp.getContext());
   AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
@@ -185,8 +191,16 @@
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
   return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+         broadcastOp.getSource().getType().isIntOrFloat();
+}
+
+/// Return true if this signed extend op can be folded into a contract op.
+static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+  return extOp.getOperand().getDefiningOp<vector::TransferReadOp>() &&
+         llvm::all_of(extOp->getUsers(), [](Operation *user) {
+           return isa<vector::ContractionOp>(user);
+         });
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +282,8 @@
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto extend = dyn_cast<arith::ExtSIOp>(op))
+    return signedExtendSupportsMMAMatrixType(extend);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -411,8 +427,18 @@
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    auto transferReadOp =
-        op.getVector().getDefiningOp<vector::TransferReadOp>();
+    // Look through integer extend ops; keep the transposed (result) shape.
+    Value source = op.getVector();
+    auto extOp = source.getDefiningOp<arith::ExtSIOp>();
+    auto resultType = op.getType().cast<VectorType>();
+    if (extOp) {
+      source = extOp.getOperand();
+      resultType =
+          VectorType::get(resultType.getShape(),
+                          source.getType().cast<VectorType>().getElementType());
+    }
+
+    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
     if (!transferReadOp)
       return failure();
 
@@ -431,11 +457,23 @@
         AffineMap::getPermutationMap(permU, op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        op, op.getType(), transferReadOp.getSource(),
-        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
-        transferReadOp.getPadding(), transferReadOp.getMask(),
-        transferReadOp.getInBoundsAttr());
+
+    auto loc = op.getLoc();
+    Value result =
+        rewriter
+            .create<vector::TransferReadOp>(
+                loc, resultType, transferReadOp.getSource(),
+                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+                transferReadOp.getPadding(), transferReadOp.getMask(),
+                transferReadOp.getInBoundsAttr())
+            .getResult();
+
+    // Fuse through the integer extend op.
+ if (extOp) + result = rewriter.create(loc, op.getType(), result) + .getResult(); + + rewriter.replaceOp(op, result); return success(); } }; @@ -479,14 +517,26 @@ stride = 0; } assert(stride); + Value mappingResult = op.getResult(); + auto elType = op.getVectorType().getElementType(); const char *fragType = inferFragType(op); + if (op->hasOneUse()) { + auto extOp = dyn_cast(*op->user_begin()); + // Infer the signedness of the mma type from the signed extend. + if (extOp) { + elType = IntegerType::get(op.getContext(), + elType.cast().getWidth(), + IntegerType::Signed); + mappingResult = extOp.getResult(); + fragType = inferFragType(extOp); + } + } gpu::MMAMatrixType type = - gpu::MMAMatrixType::get(op.getVectorType().getShape(), - op.getVectorType().getElementType(), fragType); + gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType); Value load = b.create( op.getLoc(), type, op.getSource(), op.getIndices(), b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr()); - valueMapping[op.getResult()] = load; + valueMapping[mappingResult] = load; } static void convertTransferWriteOp(vector::TransferWriteOp op, diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -78,7 +78,9 @@ StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32(); + return elementType.isF16() || elementType.isF32() || + elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) || + elementType.isInteger(32); } LogicalResult @@ -93,7 +95,8 @@ return emitError() << "MMAMatrixType must have exactly two dimensions"; if (!MMAMatrixType::isValidElementType(elementType)) - return emitError() << "MMAMatrixType elements must be F16 or F32"; + return emitError() + << "MMAMatrixType elements must be SI8, UI8, 
I32, F16, or F32"; return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -537,7 +537,8 @@ } std::pair NVVM::inferMMAType(NVVM::MMATypes type, - NVVM::MMAFrag frag, + NVVM::MMAFrag frag, int nRow, + int nCol, MLIRContext *context) { unsigned numberElements = 0; Type elementType; @@ -555,11 +556,48 @@ } else if (type == NVVM::MMATypes::tf32) { elementType = builder.getI32Type(); numberElements = 4; + } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) { + elementType = builder.getI32Type(); + int parallelSize = 0; + if (frag == NVVM::MMAFrag::a) + parallelSize = nRow; + if (frag == NVVM::MMAFrag::b) + parallelSize = nCol; + + // m == 16 && n == 16 && k == 16 + if (parallelSize == 16) + numberElements = 2; + // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16 + else if (parallelSize == 8) + numberElements = 1; + else if (parallelSize == 32) + numberElements = 4; + } else if (type == NVVM::MMATypes::s32) { + elementType = builder.getI32Type(); + numberElements = 8; } assert(numberElements != 0 && elementType != nullptr); return std::make_pair(elementType, numberElements); } +static std::pair +inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, + int k, MLIRContext *context) { + int nRow, nCol; + if (frag == NVVM::MMAFrag::a) { + nRow = m; + nCol = k; + } else if (frag == NVVM::MMAFrag::b) { + nRow = k; + nCol = n; + } else { + nRow = m; + nCol = n; + } + assert(nRow && nCol); + return inferMMAType(type, frag, nRow, nCol, context); +} + LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = getPtr().getType().cast().getAddressSpace(); @@ -570,8 +608,8 @@ if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), getEltype(), getFrag()) == 0) return emitOpError() << "invalid attribute combination"; - std::pair typeInfo = - 
inferMMAType(getEltype(), getFrag(), getContext()); + std::pair typeInfo = inferMMATypeFromMNK( + getEltype(), getFrag(), getM(), getN(), getK(), getContext()); Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector(typeInfo.second, typeInfo.first)); if (getType() != dstType) @@ -590,8 +628,8 @@ if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), getEltype()) == 0) return emitOpError() << "invalid attribute combination"; - std::pair typeInfo = - inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext()); + std::pair typeInfo = inferMMATypeFromMNK( + getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); if (getArgs().size() != typeInfo.second) return emitOpError() << "expected " << typeInfo.second << " data operands"; if (llvm::any_of(getArgs(), [&typeInfo](Value operands) { @@ -606,12 +644,12 @@ getLayoutB(), getEltypeA(), getEltypeB()) == 0) return emitOpError() << "invalid attribute combination"; - std::pair typeInfoA = - inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext()); - std::pair typeInfoB = - inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext()); - std::pair typeInfoC = - inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext()); + std::pair typeInfoA = inferMMATypeFromMNK( + getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext()); + std::pair typeInfoB = inferMMATypeFromMNK( + getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext()); + std::pair typeInfoC = inferMMATypeFromMNK( + getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); SmallVector arguments; arguments.append(typeInfoA.second, typeInfoA.first); arguments.append(typeInfoB.second, typeInfoB.first); diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -40,6 +40,45 @@ // ----- +gpu.module 
@test_module { + + // CHECK-LABEL: func @gpu_wmma_int8_load_op() -> + // CHECK-SAME: !llvm.struct<(i32, i32)> + // CHECK32-LABEL: func @gpu_wmma_int8_load_op() -> + func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp"> + // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] + // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64 + // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64 + // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64 + // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] + // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> + // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)> + + // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 + // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] + // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32 + // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32 + // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr + // CHECK32: 
%[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] + // CHECK32-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(i32, i32)> + // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)> + return %0 : !gpu.mma_matrix<16x16xsi8, "AOp"> + } +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: func @gpu_wmma_store_op @@ -124,6 +163,35 @@ // ----- +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_mma_int8_op + // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>) + func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) { + %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp"> + // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)> + // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)> + // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)> + // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)> + // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)> + // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + 
// CHECK: %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]] + // CHECK-SAME: {eltypeA = #nvvm.mma_type, eltypeB = #nvvm.mma_type, k = 16 : i32, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, m = 32 : i32, n = 8 : i32} : ( + // CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + return %D : !gpu.mma_matrix<32x8xi32, "COp"> + } +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: func @gpu_wmma_mma_loop_op diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -225,3 +225,44 @@ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } + +// Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred. 
+// CHECK-LABEL: func @matmul_no_extend_int8 +// CHECK-DAG: %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> +// CHECK-DAG: %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> +// CHECK-DAG: %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> +// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> +func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) { + %cst_0 = arith.constant dense<0> : vector<16x16xi8> + %c0 = arith.constant 0 : index + %cst_i8 = arith.constant 0 : i8 + %cst_i32 = arith.constant 0 : i32 + %A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> + return +} + +// CHECK-LABEL: func @matmul_int8 +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : 
memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp"> +// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32> +func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) { + %cst_0 = arith.constant dense<0> : vector<16x16xi8> + %c0 = arith.constant 0 : index + %cst_i8 = arith.constant 0 : i8 + %cst_i32 = arith.constant 0 : i32 + %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> + %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> + %Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32> + %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> + return +} diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir 
@@ -485,8 +485,8 @@ func.func @mmamatrix_invalid_element_type(){ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index - // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}} - %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp"> + // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp"> return } @@ -505,7 +505,7 @@ // ----- func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) { - // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}} + // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}} %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp"> return }