diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -165,6 +165,13 @@
   /// C += A*B. This function returns which operand in the given equation is
   /// held by this type. String returned can be one of "AOp", "BOp" and "COp".
   StringRef getOperand() const;
+
+  /// Returns whether this operand represents an accumulator or result type.
+  bool isAccOrResult() const { return getOperand() == "COp"; }
+
+  int64_t getElementTypeBitWidth() const {
+    return getElementType().getIntOrFloatBitWidth();
+  }
 };
 
 // Adds a `gpu.async.token` to the front of the argument list.
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -96,6 +96,19 @@
   }];
 }
 
+def GPU_LaneIdOp : GPU_Op<"lane_id", [NoSideEffect]> {
+  let description = [{
+    Returns the lane id within the subgroup (warp/wave).
+
+    Example:
+    ```mlir
+    %laneId = gpu.lane_id : index
+    ```
+  }];
+  let results = (outs Index:$result);
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
 def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [NoSideEffect]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
@@ -1354,4 +1367,58 @@
   }];
 }
 
+def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
+                               [MemoryEffects<[MemRead]>]> {
+  let description = [{
+    The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
+    memory. The load source and result type must be compatible with lowering
+    to the `nvvm.ldmatrix` instruction. This op represents the distributed
+    version of a `vector.transfer_read` and serves as an intermediate step
+    when lowering `vector.transfer_read` to `nvvm.ldmatrix`.
+
+    Example:
+
+    ```mlir
+    gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                   Variadic<Index>:$indices, BoolAttr:$transpose,
+                   I32Attr:$numTiles);
+  let results = (outs AnyVector:$res);
+  let assemblyFormat = [{
+    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+}
+
+def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
+  let description = [{
+    The `gpu.mma.sync` op represents the distributed form of a collective
+    matrix-multiply-and-accumulate (mma) operation that is compatible with
+    `nvvm.mma.sync`. The operands and results are per-lane fragments of the
+    full matrices. The full shape of the distributed mma operation is given
+    by the `mmaShape` attribute in the form of a list of dimensions
+    `[m, n, k]`.
+
+    This operation is meant to be lowered to the `nvvm.mma.sync` instruction
+    and serves as an intermediate step when lowering from `vector.contract`
+    to `nvvm.mma.sync`.
+
+    Example:
+
+    ```mlir
+    gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    ```
+  }];
+  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC,
+                   I64ArrayAttr:$mmaShape);
+
+  let results = (outs AnyVector:$res);
+
+  let assemblyFormat = [{
+    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+  }];
+}
+
 #endif // GPU_OPS
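The three new ops are designed to compose into a warp-level GEMM. A rough usage sketch (the buffer names `%shm_a`/`%shm_b`, the accumulator `%acc`, and the `numTiles` values below are illustrative and not taken from the patch; the shapes follow the m16n8k16 f16 case exercised by the tests):

```mlir
%c0   = arith.constant 0 : index
%lane = gpu.lane_id : index
%a = gpu.mma.ldmatrix %shm_a[%c0, %c0] {transpose = false, numTiles = 4 : i32}
       : memref<128x128xf16, 3> -> vector<4x2xf16>
%b = gpu.mma.ldmatrix %shm_b[%c0, %c0] {transpose = true, numTiles = 2 : i32}
       : memref<128x128xf16, 3> -> vector<2x2xf16>
%d = gpu.mma.sync (%a, %b, %acc) {mmaShape = [16, 8, 16]}
       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```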
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -62,6 +62,34 @@
   }
 };
 
+template <typename SrcOp, typename TargetOp>
+struct GPULaneIdIntrinsicOpLowering : ConvertOpToLLVMPattern<SrcOp> {
+private:
+  unsigned indexBitwidth;
+
+public:
+  explicit GPULaneIdIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<SrcOp>(typeConverter),
+        indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
+  // Emit the target intrinsic (returning i32) and adjust it to the index bitwidth.
+  LogicalResult
+  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Value newOp = rewriter.create<TargetOp>(loc, rewriter.getI32Type());
+    if (indexBitwidth > 32) {
+      newOp = rewriter.create<LLVM::SExtOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = rewriter.create<LLVM::TruncOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    }
+    rewriter.replaceOp(op, {newOp});
+    return success();
+  }
+};
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
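For reference, the expected lowering of the new lane-id op with a 64-bit index type, as exercised by the updated gpu-to-nvvm.mlir test further down (SSA names illustrative):

```mlir
// Before:
%laneId = gpu.lane_id : index

// After --convert-gpu-to-nvvm (64-bit index):
%0 = nvvm.read.ptx.sreg.laneid : i32
%1 = llvm.sext %0 : i32 to i64
```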
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -18,17 +18,27 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -208,6 +218,280 @@
   }
 };
 
+struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = getContext();
+    Location loc = op->getLoc();
+
+    // The result type of ldmatrix will always be a struct of 32-bit integer
+    // registers if more than one 32-bit value is returned. Otherwise, the
+    // result is a single i32. The result type of the GPU operation is always
+    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
+    // is the 32-bit wide vector type of the result. We bitcast the result of
+    // the NVVM::LdMatrixOp to this vector type.
+    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vectorResultType)
+      return failure();
+    Type innerVectorType = LLVM::getFixedVectorType(
+        vectorResultType.getElementType(), vectorResultType.getShape()[1]);
+
+    int64_t num32BitRegs = vectorResultType.getShape()[0];
+
+    Type ldMatrixResultType;
+    if (num32BitRegs > 1) {
+      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(vectorResultType.getShape()[0],
+                                 rewriter.getI32Type()));
+    } else {
+      ldMatrixResultType = rewriter.getI32Type();
+    }
+
+    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
+                                        adaptor.indices(), rewriter);
+    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+        loc, ldMatrixResultType, srcPtr,
+        /*num=*/op.numTiles(),
+        /*layout=*/op.transpose() ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row);
+
+    // Unpack the 32-bit registers, bitcast each one back to the inner vector
+    // type, and insert them into the converted result.
+    Type finalResultType = typeConverter->convertType(vectorResultType);
+    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    for (int i = 0; i < vectorResultType.getShape()[0]; i++) {
+      Value i32Register = num32BitRegs > 1
+                              ? rewriter.create<LLVM::ExtractValueOp>(
+                                    loc, rewriter.getI32Type(), ldMatrixResult,
+                                    rewriter.getI64ArrayAttr(i))
+                              : ldMatrixResult;
+      Value casted =
+          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+
+private:
+  /// Checks if all the operands of the op being lowered are of LLVM types. The
+  /// types are expected to be converted by the `LLVMTypeConverter` before the
+  /// op is actually lowered. If the type of an operand is not already
+  /// converted, this hints at a missing type conversion and failure is
+  /// returned in that case.
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return LLVM::isCompatibleType(value.getType());
+        })) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert if operands aren't of LLVM type.");
+    }
+
+    return success();
+  }
+
+  /// Returns the type for the intrinsic given the vectorResultType of the
+  /// `gpu.mma.sync` operation.
+  static Type inferIntrinsicResultType(Type vectorResultType) {
+    MLIRContext *ctx = vectorResultType.getContext();
+    auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+    auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+    auto i32Ty = IntegerType::get(ctx, 32);
+    auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+    Type f64Ty = Float64Type::get(ctx);
+    Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+    if (a.getElementType() == f16x2Ty) {
+      return LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+    }
+    if (a.getElementType() == i32x2Ty) {
+      return LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2,
+                                 i32Ty));
+    }
+    if (a.getElementType() == f64x2Ty) {
+      return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
+    }
+    return vectorResultType;
+  }
+
+  /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
+  /// always an LLVM struct) into a fragment that is compatible with the vector
+  /// type of this operation. This involves extracting elements from the struct
+  /// and inserting them into an LLVM array. These extra data-movement
+  /// operations should be canonicalized away by the LLVM backend.
+  static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+                                      Type resultType, Value intrinsicResult,
+                                      RewriterBase &rewriter) {
+    MLIRContext *ctx = rewriter.getContext();
+    auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+    auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+    Type i32Ty = rewriter.getI32Type();
+    Type f64Ty = rewriter.getF64Type();
+    Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+    Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+    Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, IntegerType::get(ctx, 32), rewriter.getI32IntegerAttr(index));
+    };
+
+    if (arrayType) {
+      SmallVector<Value> elements;
+
+      if (arrayType.getElementType() == f16x2Ty) {
+        for (unsigned i = 0; i < structType.getBody().size(); i++) {
+          elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+              loc, structType.getBody()[i], intrinsicResult,
+              rewriter.getI64ArrayAttr(i)));
+        }
+      }
+
+      // The intrinsic returns i32 and f64 values as individual scalars. We
+      // need to extract them from the struct and pack them into vectors.
+      if (arrayType.getElementType() == i32x2Ty ||
+          arrayType.getElementType() == f64x2Ty) {
+        Value vec =
+            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+        for (unsigned i = 0; i < structType.getBody().size() / 2; i++) {
+          Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+              loc, structType.getBody()[i * 2], intrinsicResult,
+              rewriter.getI64ArrayAttr(i * 2));
+          Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+              loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+              rewriter.getI64ArrayAttr(i * 2 + 1));
+          vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                       x1, makeConst(0));
+          vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                       x2, makeConst(1));
+        }
+        elements.push_back(vec);
+      }
+
+      // Create the final vectorized result.
+      Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+      for (auto el : llvm::enumerate(elements)) {
+        result = rewriter.create<LLVM::InsertValueOp>(
+            loc, arrayType, result, el.value(),
+            rewriter.getI64ArrayAttr(el.index()));
+      }
+      return result;
+    }
+
+    return intrinsicResult;
+  }
+
+  /// Flattens the LLVM array `operand` into the list of scalar or vector
+  /// values expected by the `nvvm.mma.sync` intrinsic for that operand.
+  static SmallVector<Value>
+  unpackOperandVector(RewriterBase &rewriter, Location loc, Value operand) {
+    SmallVector<Value> result;
+    Type i32Ty = rewriter.getI32Type();
+    Type f64Ty = rewriter.getF64Type();
+    Type i8Ty = rewriter.getI8Type();
+    auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+    for (size_t i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+          loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+      // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+      // scalar types.
+      Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+      if (arrayTy.getElementType() == i8x4Ty) {
+        result.push_back(rewriter.create<LLVM::BitcastOp>(
+            loc, rewriter.getI32Type(), toUse));
+        continue;
+      }
+
+      // For some element types (i32, f64), we need to unpack the inner
+      // vector/array type as well because the intrinsic expects individual
+      // scalars to be provided.
+      VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+      if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
+                           innerArrayTy.getElementType() == f64Ty)) {
+        for (int idx = 0; idx < innerArrayTy.getNumElements(); idx++) {
+          result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+              loc, toUse,
+              rewriter.create<LLVM::ConstantOp>(
+                  loc, rewriter.getI64Type(),
+                  rewriter.getI64IntegerAttr(idx))));
+        }
+        continue;
+      }
+      result.push_back(toUse);
+    }
+    return result;
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
+      return failure();
+    }
+
+    // Get the shape of the mma operation. The shape determines which
+    // intrinsic this op will be lowered to.
+    auto aType = op.matrixA().getType().cast<VectorType>();
+
+    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+    std::array<int64_t, 3> gemmShape{m, n, k};
+
+    auto matA = unpackOperandVector(rewriter, loc, adaptor.matrixA());
+    auto matB = unpackOperandVector(rewriter, loc, adaptor.matrixB());
+    auto matC = unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+    NVVM::MMATypes ptxTypeA;
+    NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+    if (aType.getElementType().isInteger(8)) {
+      ptxTypeA = NVVM::MMATypes::s8;
+      ptxTypeB = NVVM::MMATypes::s8;
+      overflow = NVVM::MMAIntOverflow::satfinite;
+    } else if (aType.getElementType().isF16()) {
+      ptxTypeA = NVVM::MMATypes::f16;
+      ptxTypeB = NVVM::MMATypes::f16;
+    } else if (aType.getElementType().isF64()) {
+      ptxTypeA = NVVM::MMATypes::f64;
+      ptxTypeB = NVVM::MMATypes::f64;
+    } else {
+      return op->emitError("could not deduce operand PTX types");
+    }
+
+    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+    Type intrinsicResTy = inferIntrinsicResultType(
+        typeConverter->convertType(op->getResultTypes()[0]));
+    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+        op.getLoc(), intrinsicResTy, matA, matB, matC,
+        /*shape=*/gemmShape,
+        /*b1Op=*/llvm::None,
+        /*intOverflow=*/overflow,
+        /*multiplicandPtxTypes=*/
+        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+        /*multiplicandLayouts=*/
+        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+                                       NVVM::MMALayout::col});
+    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
+                                                  desiredRetTy, intrinsicResult,
+                                                  rewriter));
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -303,7 +587,9 @@
                                            NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
                GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                            NVVM::GridDimYOp, NVVM::GridDimZOp>,
-               GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+               GPULaneIdIntrinsicOpLowering<gpu::LaneIdOp, NVVM::LaneIdOp>,
+               GPUShuffleOpLowering, GPUReturnOpLowering, MmaSyncOptoNVVM,
+               MmaLdMatrixOpToNVVM>(converter);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
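The overall effect of `MmaSyncOptoNVVM` on the f16 m16n8k16 case from the new test file is, schematically (SSA names invented, attribute list abbreviated):

```mlir
%d = gpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
       : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>

// becomes, after --convert-gpu-to-nvvm:
//   %a0 = llvm.extractvalue %a[0] : !llvm.array<4 x vector<2xf16>>
//   ...  three more extracts for %a, two for %b, two for %c ...
//   %r  = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
//           {shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}, ...}
//           -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
//   %u  = llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
//   ...  llvm.extractvalue from %r + llvm.insertvalue into %u ...
```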
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -60,7 +60,8 @@
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32();
+  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
+         elementType.isInteger(8) || elementType.isInteger(32);
 }
 
 LogicalResult
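With the relaxed check, `!gpu.mma_matrix` types with element types beyond f16/f32 now verify, for example (illustrative shapes):

```mlir
// Previously rejected with "MMAMatrixType elements must be F16 or F32":
!gpu.mma_matrix<16x16xi8, "AOp">
!gpu.mma_matrix<16x16xi32, "COp">
!gpu.mma_matrix<8x8xf64, "COp">
```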
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -6,7 +6,8 @@
   // CHECK32-LABEL: func @gpu_index_ops()
   func.func @gpu_index_ops()
       -> (index, index, index, index, index, index,
-          index, index, index, index, index, index) {
+          index, index, index, index, index, index,
+          index) {
     // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
 
     // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
@@ -49,10 +50,17 @@
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %gDimZ = gpu.grid_dim z
 
+
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id : index
+
     func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
-               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
+               %laneId
         : index, index, index, index, index, index,
-          index, index, index, index, index, index
+          index, index, index, index, index, index,
+          index
   }
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+
+  // CHECK-LABEL: @m16n8k16_fp16
+  func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k8_fp16
+  func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k32_int8
+  func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+    // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
+    return %d : vector<2x2xi32>
+  }
+
+  // CHECK-LABEL: @m8n8k4_f64
+  func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+    // CHECK: llvm.extractvalue %arg0
+    // CHECK: llvm.extractvalue %arg1
+    // CHECK: llvm.extractvalue %arg2
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+    // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+    // CHECK: llvm.mlir.undef : vector<2xf64>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+    // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
+    return %d : vector<1x2xf64>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x4
+  func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
+    %c0 = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<4x2xf16>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x1
+  func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
+    %c0 = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<1x2xf16>
+  }
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -476,16 +476,6 @@
 
 // -----
 
-func @mmamatrix_invalid_element_type(){
-  %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
-  %i = arith.constant 16 : index
-  // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
-  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
-  return
-}
-
-// -----
-
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
 func @mmaLoadOp_identity_layout(){