diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,8 @@
   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
   Type f64Ty = Float64Type::get(ctx);
   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f32Ty = Float32Type::get(ctx);
+  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
   if (a.getElementType() == f16x2Ty) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -37,6 +39,15 @@
   if (a.getElementType() == f64x2Ty) {
     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
   }
+  if (a.getElementType() == f32x2Ty) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx,
+        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
+  }
+  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
+  }
   return vectorResultType;
 }
 
@@ -52,10 +63,13 @@
   auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
   auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
   Type i32Ty = rewriter.getI32Type();
+  Type f32Ty = rewriter.getF32Type();
   Type f64Ty = rewriter.getF64Type();
   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
 
   auto makeConst = [&](int32_t index) -> Value {
     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -65,21 +79,31 @@
   if (arrayType) {
     SmallVector<Value> elements;
 
-    if (arrayType.getElementType() == f16x2Ty) {
+    // The intrinsic returns 32-bit wide elements in a form which can be
+    // directly bitcasted and inserted into the result vector.
+    if (arrayType.getElementType() == f16x2Ty ||
+        arrayType.getElementType() == f32x1Ty) {
       for (unsigned i = 0; i < structType.getBody().size(); i++) {
-        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        Value el = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i], intrinsicResult,
-            rewriter.getI64ArrayAttr(i)));
+            rewriter.getI64ArrayAttr(i));
+        el = rewriter.createOrFold<LLVM::BitcastOp>(
+            loc, arrayType.getElementType(), el);
+        elements.push_back(el);
       }
     }
 
-    // The intrinsic returns i32 and f64 values as individual scalars. We need
-    // to extract them from the struct and pack them into vectors.
+    // The intrinsic returns i32, f64, and f32 values as individual scalars,
+    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
+    // need to extract them from the struct and pack them into the 64-bit wide
+    // rows of the vector result.
     if (arrayType.getElementType() == i32x2Ty ||
-        arrayType.getElementType() == f64x2Ty) {
-      Value vec =
-          rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+        arrayType.getElementType() == f64x2Ty ||
+        arrayType.getElementType() == f32x2Ty) {
+
       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+        Value vec =
+            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
            loc, structType.getBody()[i * 2], intrinsicResult,
            rewriter.getI64ArrayAttr(i * 2));
@@ -90,8 +114,8 @@
                                                      x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                      x2, makeConst(1));
+        elements.push_back(vec);
       }
-      elements.push_back(vec);
     }
 
     // Create the final vectorized result.
@@ -113,12 +137,15 @@
 /// scalars of certain types. This function helps unpack the `vector` arguments
 /// and cast them to the types expected by `nvvm.mma.sync`.
 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
-                                              Location loc, Value operand) {
+                                              Location loc, Value operand,
+                                              NVVM::MMATypes operandPtxType) {
   SmallVector<Value> result;
   Type i32Ty = rewriter.getI32Type();
   Type f64Ty = rewriter.getF64Type();
+  Type f32Ty = rewriter.getF32Type();
   Type i8Ty = rewriter.getI8Type();
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -127,18 +154,21 @@
 
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
-    if (arrayTy.getElementType() == i8x4Ty) {
+    if (arrayTy.getElementType() == i8x4Ty ||
+        (arrayTy.getElementType() == f32x1Ty &&
+         operandPtxType == NVVM::MMATypes::tf32)) {
       result.push_back(
           rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
       continue;
     }
 
-    // For some element types (i32, f64), we need to unpack the inner
+    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
     VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
-                         innerArrayTy.getElementType() == f64Ty)) {
+                         innerArrayTy.getElementType() == f64Ty ||
+                         innerArrayTy.getElementType() == f32Ty)) {
       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
            idx < innerSize; idx++) {
         result.push_back(rewriter.create<LLVM::ExtractElementOp>(
@@ -229,37 +259,47 @@
     // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
     auto aType = op.matrixA().getType().cast<VectorType>();
+    auto cType = op.matrixC().getType().cast<VectorType>();
 
     int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
     int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
     int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
     std::array<int64_t, 3> gemmShape{m, n, k};
 
-    SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.matrixA());
-    SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.matrixB());
-    SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.matrixC());
-
     NVVM::MMATypes ptxTypeA;
     NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
+        cType.getElementType(), /*isAccumulator=*/true);
+    if (!ptxTypeC) {
+      return op->emitError(
+          "could not infer the PTX type for the accumulator/result");
+    }
+
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
     if (aType.getElementType().isInteger(8)) {
       ptxTypeA = NVVM::MMATypes::s8;
       ptxTypeB = NVVM::MMATypes::s8;
       overflow = NVVM::MMAIntOverflow::satfinite;
-
     } else if (aType.getElementType().isF16()) {
       ptxTypeA = NVVM::MMATypes::f16;
       ptxTypeB = NVVM::MMATypes::f16;
     } else if (aType.getElementType().isF64()) {
       ptxTypeA = NVVM::MMATypes::f64;
       ptxTypeB = NVVM::MMATypes::f64;
+    } else if (aType.getElementType().isF32()) {
+      ptxTypeA = NVVM::MMATypes::tf32;
+      ptxTypeB = NVVM::MMATypes::tf32;
     } else {
       return op->emitError("could not deduce operand PTX types");
     }
 
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA);
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB);
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC);
+
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
diff --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
@@ -24,6 +24,34 @@
 
 // -----
 
+// Same as above but with an fp32 accumulation type.
+
+// CHECK-LABEL: @m16n8k16_fp16_fp32
+func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // We just need to check the mma instruction and the manipulation of the result.
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @m16n8k8_fp16
 func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
@@ -125,3 +153,33 @@
   // CHECK: llvm.insertvalue
   return %a : vector<1x2xf16>
 }
+
+// -----
+
+// CHECK-LABEL: @m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {
+  // The A and B operands should be bitcast to i32.
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
+
+  // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}]
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<tf32>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
+  // CHECK-SAME: shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
+  return %d : vector<4x1xf32>
+}
\ No newline at end of file
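For local verification, a minimal sketch of how the updated test file is usually exercised, assuming a standard upstream build of mlir-opt and FileCheck and the registered `convert-nvgpu-to-nvvm` pass flag (the RUN line inside the test file remains the authoritative invocation):

  mlir-opt --split-input-file --convert-nvgpu-to-nvvm \
      mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir \
    | FileCheck mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir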