diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -145,7 +145,9 @@
   Type f64Ty = rewriter.getF64Type();
   Type f32Ty = rewriter.getF32Type();
   Type i8Ty = rewriter.getI8Type();
+  Type i4Ty = rewriter.getIntegerType(4);
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
 
@@ -156,6 +158,7 @@
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
     if (arrayTy.getElementType() == i8x4Ty ||
+        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
@@ -281,6 +284,10 @@
     ptxTypeA = NVVM::MMATypes::s8;
     ptxTypeB = NVVM::MMATypes::s8;
     overflow = NVVM::MMAIntOverflow::satfinite;
+  } else if (aType.getElementType().isInteger(4)) {
+    ptxTypeA = NVVM::MMATypes::s4;
+    ptxTypeB = NVVM::MMATypes::s4;
+    overflow = NVVM::MMAIntOverflow::satfinite;
   } else if (aType.getElementType().isF16()) {
     ptxTypeA = NVVM::MMATypes::f16;
     ptxTypeB = NVVM::MMATypes::f16;
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -102,6 +102,54 @@
 
 // -----
 
+// CHECK-LABEL: @m16n8k32_i4
+func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<1 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @m16n8k64_i4
+func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: shape = {k = 64 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @m8n8k4_f64
 func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
   // CHECK: llvm.extractvalue
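
Note: the new test cases above are driven by the file's existing RUN line, which is outside this hunk and therefore not shown in the diff. Assuming the standard invocation for this test file, the i4 lowerings can be exercised with:

// RUN: mlir-opt --convert-nvgpu-to-nvvm --split-input-file %s | FileCheck %s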