diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -96,6 +96,19 @@
   }];
 }
 
+def GPU_LaneIdOp : GPU_Op<"lane_id", [NoSideEffect]> {
+  let description = [{
+    Returns the lane id within the subgroup (warp/wave).
+
+    Example:
+    ```mlir
+    %laneId = gpu.lane_id
+    ```
+  }];
+  let results = (outs Index:$result);
+  let assemblyFormat = "attr-dict";
+}
+
 def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [NoSideEffect]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
@@ -1354,4 +1367,58 @@
   }];
 }
 
+def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
+                               [MemoryEffects<[MemRead]>]> {
+  let description = [{
+    The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
+    memory. The load source and result type must be compatible with lowering
+    to the `nvvm.ldmatrix` instruction. This op is meant to represent the
+    distributed version of a `vector.transfer_read` as an intermediate step
+    between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
+
+    Example:
+
+    ```mlir
+    gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                       Variadic<Index>:$indices, BoolAttr:$transpose,
+                       I32Attr:$numTiles);
+  let results = (outs AnyVector:$res);
+  let assemblyFormat = [{
+    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+}
+
+def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
+  let description = [{
+    The `gpu.mma.sync` op represents the distributed form of a collective
+    matrix-multiply-and-accumulate (mma) operation that is compatible with
+    `nvvm.mma.sync`. The operands and results are fragments of the full
+    matrix operands. The full shape of the distributed mma operation is given
+    by the `mmaShape` attribute in the form of a list of dimensions
+    `[m, n, k]`.
+
+    This operation is meant to be lowered to the `nvvm.mma.sync` instruction
+    and is an intermediate step between lowering from `vector.contract` to
+    `nvvm.mma.sync`.
+
+    Example:
+
+    ```mlir
+    gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    ```
+  }];
+  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
+                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+
+  let results = (outs AnyVector:$res);
+
+  let assemblyFormat = [{
+    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+  }];
+}
+
 #endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -208,6 +208,314 @@
   }
 };
 
+struct MmaLdMatrixOpToNVVM
+    : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = getContext();
+    Location loc = op->getLoc();
+
+    // The result type of ldmatrix will always be a struct of 32-bit integer
+    // registers if more than one 32-bit value is returned. Otherwise, the
+    // result is a single i32. The result type of the GPU operation is always
+    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
+    // is the vector type of the result and always 32 bits long.
+    // We bitcast the result of the `nvvm.ldmatrix` op to this vector type.
+    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vectorResultType) {
+      return failure();
+    }
+    Type innerVectorType = LLVM::getFixedVectorType(
+        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+
+    int64_t num32BitRegs = vectorResultType.getDimSize(0);
+
+    Type ldMatrixResultType;
+    if (num32BitRegs > 1) {
+      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
+    } else {
+      ldMatrixResultType = rewriter.getI32Type();
+    }
+
+    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType,
+                                        adaptor.srcMemref(),
+                                        adaptor.indices(), rewriter);
+    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+        loc, ldMatrixResultType, srcPtr,
+        /*num=*/op.numTiles(),
+        /*layout=*/op.transpose() ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row);
+
+    // The ldmatrix operation returns either a single i32 value or a struct
+    // of i32 values. Here we unpack those values and cast them back to their
+    // actual vector type (still 32 bits wide) and repack them into a result
+    // struct.
+    Type finalResultType = typeConverter->convertType(vectorResultType);
+    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
+      Value i32Register =
+          num32BitRegs > 1 ? rewriter.create<LLVM::ExtractValueOp>(
+                                 loc, rewriter.getI32Type(), ldMatrixResult,
+                                 rewriter.getI64ArrayAttr(i))
+                           : ldMatrixResult;
+      Value casted =
+          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Checks if all the operands of the op being lowered are of LLVM types. The
+/// types are expected to be converted by the `LLVMTypeConverter` before the
+/// op is actually lowered. If the type of an operand is not already
+/// converted, this hints at a missing type conversion and failure is
+/// returned in that case.
+LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                              ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      })) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of LLVM type.");
+  }
+
+  return success();
+}
+
+/// Returns the type for the intrinsic given the vectorResultType of the
+/// `gpu.mma.sync` operation.
+Type inferIntrinsicResultType(Type vectorResultType) {
+  MLIRContext *ctx = vectorResultType.getContext();
+  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  auto i32Ty = IntegerType::get(ctx, 32);
+  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  Type f64Ty = Float64Type::get(ctx);
+  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  if (a.getElementType() == f16x2Ty) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+  }
+  if (a.getElementType() == i32x2Ty) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx,
+        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
+  }
+  if (a.getElementType() == f64x2Ty) {
+    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
+  }
+  return vectorResultType;
+}
+
+/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
+/// always an LLVM struct) into a fragment that is compatible with the vector
+/// type of this operation.
+/// This involves extracting elements from the struct and inserting them into
+/// an LLVM array. These extra data-movement operations should be
+/// canonicalized away by the LLVM backend.
+Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+                             Type resultType, Value intrinsicResult,
+                             RewriterBase &rewriter) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+  auto makeConst = [&](int32_t index) -> Value {
+    return rewriter.create<LLVM::ConstantOp>(
+        loc, IntegerType::get(ctx, 32), rewriter.getI32IntegerAttr(index));
+  };
+
+  if (arrayType) {
+    SmallVector<Value> elements;
+
+    if (arrayType.getElementType() == f16x2Ty) {
+      for (unsigned i = 0; i < structType.getBody().size(); i++) {
+        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i], intrinsicResult,
+            rewriter.getI64ArrayAttr(i)));
+      }
+    }
+
+    // The intrinsic returns i32 and f64 values as individual scalars. We need
+    // to extract them from the struct and pack them into vectors.
+    if (arrayType.getElementType() == i32x2Ty ||
+        arrayType.getElementType() == f64x2Ty) {
+      Value vec =
+          rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2));
+        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2 + 1));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x1, makeConst(0));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x2, makeConst(1));
+      }
+      elements.push_back(vec);
+    }
+
+    // Create the final vectorized result.
+    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+    for (const auto &el : llvm::enumerate(elements)) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, arrayType, result, el.value(),
+          rewriter.getI64ArrayAttr(el.index()));
+    }
+    return result;
+  }
+
+  return intrinsicResult;
+}
+
+/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
+/// given as 2D `vector`s where the rows are 32b or 64b wide. The
+/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
+/// scalars of certain types. This function helps unpack the `vector`
+/// arguments and cast them to the types expected by `nvvm.mma.sync`.
+SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, Location loc,
+                                       Value operand) {
+  SmallVector<Value> result;
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type i8Ty = rewriter.getI8Type();
+  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+
+  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+    // scalar types.
+    if (arrayTy.getElementType() == i8x4Ty) {
+      result.push_back(
+          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+      continue;
+    }
+
+    // For some element types (i32, f64), we need to unpack the inner
+    // vector/array type as well because the intrinsic expects individual
+    // scalars to be provided.
+    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
+                         innerArrayTy.getElementType() == f64Ty)) {
+      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
+           idx < innerSize; idx++) {
+        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+            loc, toUse,
+            rewriter.create<LLVM::ConstantOp>(
+                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+      }
+      continue;
+    }
+    result.push_back(toUse);
+  }
+  return result;
+}
+
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
+      return failure();
+    }
+
+    // Get the shapes of the MMAMatrix type being used. The shapes will
+    // choose which intrinsic this op will be lowered to.
+    auto aType = op.matrixA().getType().cast<VectorType>();
+
+    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+    std::array<int64_t, 3> gemmShape{m, n, k};
+
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.matrixA());
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.matrixB());
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+    NVVM::MMATypes ptxTypeA;
+    NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+    if (aType.getElementType().isInteger(8)) {
+      ptxTypeA = NVVM::MMATypes::s8;
+      ptxTypeB = NVVM::MMATypes::s8;
+      overflow = NVVM::MMAIntOverflow::satfinite;
+    } else if (aType.getElementType().isF16()) {
+      ptxTypeA = NVVM::MMATypes::f16;
+      ptxTypeB = NVVM::MMATypes::f16;
+    } else if (aType.getElementType().isF64()) {
+      ptxTypeA = NVVM::MMATypes::f64;
+      ptxTypeB = NVVM::MMATypes::f64;
+    } else {
+      return op->emitError("could not deduce operand PTX types");
+    }
+
+    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+    Type intrinsicResTy = inferIntrinsicResultType(
+        typeConverter->convertType(op->getResultTypes()[0]));
+    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+        op.getLoc(), intrinsicResTy, matA, matB, matC,
+        /*shape=*/gemmShape,
+        /*b1Op=*/llvm::None,
+        /*intOverflow=*/overflow,
+        /*multiplicandPtxTypes=*/
+        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+        /*multiplicandLayouts=*/
+        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+                                       NVVM::MMALayout::col});
+    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
+                                                  desiredRetTy,
+                                                  intrinsicResult, rewriter));
+    return success();
+  }
+};
+
+struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
+  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+    // Truncate or extend the result depending on the index bitwidth
+    // specified by the LLVMTypeConverter options.
+    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
+    if (indexBitwidth > 32) {
+      newOp = rewriter.create<LLVM::SExtOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = rewriter.create<LLVM::TruncOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    }
+    rewriter.replaceOp(op, {newOp});
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -303,7 +611,8 @@
                                       NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                       NVVM::GridDimYOp, NVVM::GridDimZOp>,
-          GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+          GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering,
+          MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -6,7 +6,8 @@
   // CHECK32-LABEL: func @gpu_index_ops()
   func.func @gpu_index_ops()
       -> (index, index, index, index, index, index,
-          index, index, index, index, index, index) {
+          index, index, index, index, index, index,
+          index) {
     // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
 
     // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
@@ -49,10 +50,17 @@
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %gDimZ = gpu.grid_dim z
 
+
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+
     func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
-                %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+                %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
+                %laneId
         : index, index, index, index, index, index,
-          index, index, index, index, index, index
+          index, index, index, index, index, index,
+          index
   }
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK-LABEL: @m16n8k16_fp16
+  func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k8_fp16
+  func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k32_int8
+  func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+    // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
+    return %d : vector<2x2xi32>
+  }
+
+  // CHECK-LABEL: @m8n8k4_f64
+  func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+    // CHECK: llvm.extractvalue %arg0
+    // CHECK: llvm.extractvalue %arg1
+    // CHECK: llvm.extractvalue %arg2
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+    // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+    // CHECK: llvm.mlir.undef : vector<2xf64>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+    // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
+    return %d : vector<1x2xf64>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x4
+  func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
+    %c0 = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<4x2xf16>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x1
+  func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
+    %c0 = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<1x2xf16>
+  }
+}
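For readers of the patch, the snippet below is a minimal usage sketch (not part of the diff) of how the two new ops are intended to compose: a fragment loaded with `gpu.mma.ldmatrix` feeds the A operand of a `gpu.mma.sync` with the same `m16n8k16` shape exercised by the tests above. The module, function, and argument names are illustrative only, and the B and C fragments are assumed to be already distributed across the subgroup.

```mlir
gpu.module @example_module {
  gpu.func @m16n8k16_example(%a: memref<16x16xf16, 3>,
                             %b: vector<2x2xf16>,
                             %c: vector<2x2xf16>) kernel {
    %c0 = arith.constant 0 : index
    // Each lane loads its four 2xf16 registers of the A fragment.
    %fragA = gpu.mma.ldmatrix %a[%c0, %c0] {transpose = false, numTiles = 4 : i32}
        : memref<16x16xf16, 3> -> vector<4x2xf16>
    // Distributed 16x8x16 multiply-accumulate; lowers to nvvm.mma.sync.
    %d = gpu.mma.sync (%fragA, %b, %c) {mmaShape = [16, 8, 16]}
        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
    gpu.return
  }
}
```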