diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -25,6 +25,8 @@ auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); Type f64Ty = Float64Type::get(ctx); Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + Type f32Ty = Float32Type::get(ctx); + Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); if (a.getElementType() == f16x2Ty) { return LLVM::LLVMStructType::getLiteral( ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); @@ -37,6 +39,11 @@ if (a.getElementType() == f64x2Ty) { return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); } + if (a.getElementType() == f32x2Ty) { + return LLVM::LLVMStructType::getLiteral( + ctx, + SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); + } return vectorResultType; } @@ -52,10 +59,12 @@ auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>(); auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>(); Type i32Ty = rewriter.getI32Type(); + Type f32Ty = rewriter.getF32Type(); Type f64Ty = rewriter.getF64Type(); Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); auto makeConst = [&](int32_t index) -> Value { return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), @@ -73,13 +82,15 @@ } } - // The intrinsic returns i32 and f64 values as individual scalars. We need - // to extract them from the struct and pack them into vectors. + // The intrinsic returns i32, f64, and f32 values as individual scalars. We + // need to extract them from the struct and pack them into vectors. 
if (arrayType.getElementType() == i32x2Ty || - arrayType.getElementType() == f64x2Ty) { - Value vec = - rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType()); + arrayType.getElementType() == f64x2Ty || + arrayType.getElementType() == f32x2Ty) { + for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { + Value vec = + rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType()); Value x1 = rewriter.create<LLVM::ExtractValueOp>( loc, structType.getBody()[i * 2], intrinsicResult, rewriter.getI64ArrayAttr(i * 2)); @@ -90,8 +101,8 @@ x1, makeConst(0)); vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, x2, makeConst(1)); + elements.push_back(vec); } - elements.push_back(vec); } // Create the final vectorized result. @@ -117,6 +128,7 @@ SmallVector<Value> result; Type i32Ty = rewriter.getI32Type(); Type f64Ty = rewriter.getF64Type(); + Type f32Ty = rewriter.getF32Type(); Type i8Ty = rewriter.getI8Type(); Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>(); @@ -133,12 +145,13 @@ continue; } - // For some element types (i32, f64), we need to unpack the inner + // For some element types (i32, f32, f64), we need to unpack the inner // vector/array type as well because the intrinsic expects individual // scalars to be provided. VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>(); if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || - innerArrayTy.getElementType() == f64Ty)) { + innerArrayTy.getElementType() == f64Ty || + innerArrayTy.getElementType() == f32Ty)) { for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); idx < innerSize; idx++) { result.push_back(rewriter.create<LLVM::ExtractElementOp>( diff --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir @@ -24,6 +24,34 @@ // ----- + +// Same as above but with fp32 accumulation type. 
+ +// CHECK-LABEL: @m16n8k16_fp16_fp32 +func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> { + // We just need to check the mma instruction and the manipulation of the result. + // CHECK: [[d:%.+]] = nvvm.mma.sync + // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32} + // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32> + // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32> + // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)> + // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)> + // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32> + // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32> + + // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32> + // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)> + // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)> + // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32> + // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32> + + // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>> + // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>> + return %d : vector<2x2xf32> +} + +// ----- + // CHECK-LABEL: @m16n8k8_fp16 func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>